Skip to content

Commit 2a0c146

Browse files
committed
fix no tokio reactor running
1 parent c001fff commit 2a0c146

1 file changed

Lines changed: 65 additions & 40 deletions

File tree

src/lib.rs

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -47,61 +47,58 @@ fn connect(
4747
None,
4848
None,
4949
);
50-
println!("connecting");
5150
tokio::pin!(fut);
5251
let result = rt.block_on(check_signals(py, fut));
53-
// let result = rt.block_on(async {
54-
// loop {
55-
// tokio::select! {
56-
// out = &mut fut => {
57-
// break out;
58-
// }
59-
60-
// _ = tokio::time::sleep(std::time::Duration::from_millis(300)) => {
61-
// py.check_signals().unwrap();
62-
// }
63-
// }
64-
// }
65-
// });
66-
println!("done connecting");
6752
result.map_err(to_py_err)?
6853
}
6954
None => libsql_core::Database::open(database).map_err(to_py_err)?,
7055
}
7156
};
7257
let autocommit = isolation_level.is_none();
7358
let conn = db.connect().map_err(to_py_err)?;
74-
let conn = Arc::new(conn);
7559
Ok(Connection {
7660
db,
77-
conn,
61+
conn: Arc::new(ConnectionGuard {
62+
conn: Some(conn),
63+
handle: rt.handle().clone(),
64+
}),
7865
rt,
7966
isolation_level,
8067
autocommit,
8168
})
8269
}
8370

84-
async fn check_signals<F, R>(py: Python<'_>, mut fut: std::pin::Pin<&mut F>) -> R
85-
where
86-
F: std::future::Future<Output = R>,
87-
{
88-
loop {
89-
tokio::select! {
90-
out = &mut fut => {
91-
break out;
92-
}
71+
// We need to add a drop guard that runs when we finally drop our
72+
// only reference to libsql_core::Connection. This is because when
73+
// hrana is enabled it needs access to the tokio api to spawn a close
74+
// call in the background. So this adds the ability that when drop is called
75+
// on ConnectionGuard it will drop the connection with a tokio context entered.
76+
struct ConnectionGuard {
77+
conn: Option<libsql_core::Connection>,
78+
handle: tokio::runtime::Handle,
79+
}
9380

94-
_ = tokio::time::sleep(std::time::Duration::from_millis(300)) => {
95-
py.check_signals().unwrap();
96-
}
81+
impl std::ops::Deref for ConnectionGuard {
82+
type Target = libsql_core::Connection;
83+
84+
fn deref(&self) -> &Self::Target {
85+
&self.conn.as_ref().expect("Connection already dropped")
86+
}
87+
}
88+
89+
impl Drop for ConnectionGuard {
90+
fn drop(&mut self) {
91+
let _enter = self.handle.enter();
92+
if let Some(conn) = self.conn.take() {
93+
drop(conn);
9794
}
9895
}
9996
}
10097

10198
#[pyclass]
10299
pub struct Connection {
103100
db: libsql_core::Database,
104-
conn: Arc<libsql_core::Connection>,
101+
conn: Arc<ConnectionGuard>,
105102
rt: tokio::runtime::Runtime,
106103
isolation_level: Option<String>,
107104
autocommit: bool,
@@ -126,7 +123,10 @@ impl Connection {
126123
}
127124

128125
fn sync(self_: PyRef<'_, Self>, py: Python<'_>) -> PyResult<()> {
129-
let fut = self_.db.sync();
126+
let fut = {
127+
let _enter = self_.rt.enter();
128+
self_.db.sync()
129+
};
130130
tokio::pin!(fut);
131131

132132
self_
@@ -141,7 +141,7 @@ impl Connection {
141141
if !self_.conn.is_autocommit() {
142142
self_
143143
.rt
144-
.block_on(self_.conn.execute("COMMIT", ()))
144+
.block_on(async { self_.conn.execute("COMMIT", ()).await })
145145
.map_err(to_py_err)?;
146146
}
147147
Ok(())
@@ -152,7 +152,7 @@ impl Connection {
152152
if !self_.conn.is_autocommit() {
153153
self_
154154
.rt
155-
.block_on(self_.conn.execute("ROLLBACK", ()))
155+
.block_on(async { self_.conn.execute("ROLLBACK", ()).await })
156156
.map_err(to_py_err)?;
157157
}
158158
Ok(())
@@ -165,7 +165,7 @@ impl Connection {
165165
) -> PyResult<Cursor> {
166166
let cursor = Connection::cursor(&self_)?;
167167
let rt = self_.rt.handle();
168-
rt.block_on(execute(&cursor, sql, parameters))?;
168+
rt.block_on(async { execute(&cursor, sql, parameters).await })?;
169169
Ok(cursor)
170170
}
171171

@@ -179,7 +179,7 @@ impl Connection {
179179
let parameters = parameters.extract::<&PyTuple>()?;
180180
self_
181181
.rt
182-
.block_on(execute(&cursor, sql.clone(), Some(parameters)))?;
182+
.block_on(async { execute(&cursor, sql.clone(), Some(parameters)).await })?;
183183
}
184184
Ok(cursor)
185185
}
@@ -200,7 +200,7 @@ pub struct Cursor {
200200
#[pyo3(get, set)]
201201
arraysize: usize,
202202
rt: tokio::runtime::Handle,
203-
conn: Arc<libsql_core::Connection>,
203+
conn: Arc<ConnectionGuard>,
204204
stmt: RefCell<Option<libsql_core::Statement>>,
205205
rows: RefCell<Option<libsql_core::Rows>>,
206206
rowcount: RefCell<i64>,
@@ -218,7 +218,9 @@ impl Cursor {
218218
sql: String,
219219
parameters: Option<&PyTuple>,
220220
) -> PyResult<pyo3::PyRef<'a, Self>> {
221-
self_.rt.block_on(execute(&self_, sql, parameters))?;
221+
self_
222+
.rt
223+
.block_on(async { execute(&self_, sql, parameters).await })?;
222224
Ok(self_)
223225
}
224226

@@ -231,7 +233,7 @@ impl Cursor {
231233
let parameters = parameters.extract::<&PyTuple>()?;
232234
self_
233235
.rt
234-
.block_on(execute(&self_, sql.clone(), Some(parameters)))?;
236+
.block_on(async { execute(&self_, sql.clone(), Some(parameters)).await })?;
235237
}
236238
Ok(self_)
237239
}
@@ -286,7 +288,10 @@ impl Cursor {
286288
// done before iterating.
287289
if !*self_.done.borrow() {
288290
for _ in 0..size {
289-
let row = self_.rt.block_on(rows.next()).map_err(to_py_err)?;
291+
let row = self_
292+
.rt
293+
.block_on(async { rows.next().await })
294+
.map_err(to_py_err)?;
290295
match row {
291296
Some(row) => {
292297
let row = convert_row(self_.py(), row, rows.column_count())?;
@@ -311,7 +316,10 @@ impl Cursor {
311316
Some(rows) => {
312317
let mut elements: Vec<Py<PyAny>> = vec![];
313318
loop {
314-
let row = self_.rt.block_on(rows.next()).map_err(to_py_err)?;
319+
let row = self_
320+
.rt
321+
.block_on(async { rows.next().await })
322+
.map_err(to_py_err)?;
315323
match row {
316324
Some(row) => {
317325
let row = convert_row(self_.py(), row, rows.column_count())?;
@@ -433,3 +441,20 @@ fn libsql_experimental(py: Python, m: &PyModule) -> PyResult<()> {
433441
m.add_class::<Cursor>()?;
434442
Ok(())
435443
}
444+
445+
async fn check_signals<F, R>(py: Python<'_>, mut fut: std::pin::Pin<&mut F>) -> R
446+
where
447+
F: std::future::Future<Output = R>,
448+
{
449+
loop {
450+
tokio::select! {
451+
out = &mut fut => {
452+
break out;
453+
}
454+
455+
_ = tokio::time::sleep(std::time::Duration::from_millis(300)) => {
456+
py.check_signals().unwrap();
457+
}
458+
}
459+
}
460+
}

0 commit comments

Comments
 (0)