@@ -21,6 +21,7 @@ fn is_remote_path(path: &str) -> bool {
2121#[ pyfunction]
2222#[ pyo3( signature = ( database, isolation_level="DEFERRED" . to_string( ) , check_same_thread=true , uri=false , sync_url=None , auth_token="" ) ) ]
2323fn connect (
24+ py : Python < ' _ > ,
2425 database : String ,
2526 isolation_level : Option < String > ,
2627 check_same_thread : bool ,
@@ -46,28 +47,58 @@ fn connect(
4647 None ,
4748 None ,
4849 ) ;
49- let result = rt. block_on ( fut) ;
50+ tokio:: pin!( fut) ;
51+ let result = rt. block_on ( check_signals ( py, fut) ) ;
5052 result. map_err ( to_py_err) ?
5153 }
5254 None => libsql_core:: Database :: open ( database) . map_err ( to_py_err) ?,
5355 }
5456 } ;
5557 let autocommit = isolation_level. is_none ( ) ;
5658 let conn = db. connect ( ) . map_err ( to_py_err) ?;
57- let conn = Arc :: new ( conn) ;
5859 Ok ( Connection {
5960 db,
60- conn,
61+ conn : Arc :: new ( ConnectionGuard {
62+ conn : Some ( conn) ,
63+ handle : rt. handle ( ) . clone ( ) ,
64+ } ) ,
6165 rt,
6266 isolation_level,
6367 autocommit,
6468 } )
6569}
6670
71+ // We need to add a drop guard that runs when we finally drop our
72+ // only reference to libsql_core::Connection. This is because when
73+ // hrana is enabled it needs access to the tokio api to spawn a close
74+ // call in the background. So this adds the ability that when drop is called
75+ // on ConnectionGuard it will drop the connection with a tokio context entered.
76+ struct ConnectionGuard {
77+ conn : Option < libsql_core:: Connection > ,
78+ handle : tokio:: runtime:: Handle ,
79+ }
80+
81+ impl std:: ops:: Deref for ConnectionGuard {
82+ type Target = libsql_core:: Connection ;
83+
84+ fn deref ( & self ) -> & Self :: Target {
85+ & self . conn . as_ref ( ) . expect ( "Connection already dropped" )
86+ }
87+ }
88+
89+ impl Drop for ConnectionGuard {
90+ fn drop ( & mut self ) {
91+ let _enter = self . handle . enter ( ) ;
92+ if let Some ( conn) = self . conn . take ( ) {
93+ drop ( conn) ;
94+ }
95+ }
96+ }
97+
6798#[ pyclass]
6899pub struct Connection {
69100 db : libsql_core:: Database ,
70- conn : Arc < libsql_core :: Connection > ,
101+ conn : Arc < ConnectionGuard > ,
71102 rt : tokio:: runtime:: Runtime ,
72103 isolation_level : Option < String > ,
73104 autocommit : bool ,
@@ -91,8 +122,17 @@ impl Connection {
91122 } )
92123 }
93124
94- fn sync ( self_ : PyRef < ' _ , Self > ) -> PyResult < ( ) > {
95- self_. rt . block_on ( self_. db . sync ( ) ) . map_err ( to_py_err) ?;
125+ fn sync ( self_ : PyRef < ' _ , Self > , py : Python < ' _ > ) -> PyResult < ( ) > {
126+ let fut = {
127+ let _enter = self_. rt . enter ( ) ;
128+ self_. db . sync ( )
129+ } ;
130+ tokio:: pin!( fut) ;
131+
132+ self_
133+ . rt
134+ . block_on ( check_signals ( py, fut) )
135+ . map_err ( to_py_err) ?;
96136 Ok ( ( ) )
97137 }
98138
@@ -101,7 +141,7 @@ impl Connection {
101141 if !self_. conn . is_autocommit ( ) {
102142 self_
103143 . rt
104- . block_on ( self_. conn . execute ( "COMMIT" , ( ) ) )
144+ . block_on ( async { self_. conn . execute ( "COMMIT" , ( ) ) . await } )
105145 . map_err ( to_py_err) ?;
106146 }
107147 Ok ( ( ) )
@@ -112,7 +152,7 @@ impl Connection {
112152 if !self_. conn . is_autocommit ( ) {
113153 self_
114154 . rt
115- . block_on ( self_. conn . execute ( "ROLLBACK" , ( ) ) )
155+ . block_on ( async { self_. conn . execute ( "ROLLBACK" , ( ) ) . await } )
116156 . map_err ( to_py_err) ?;
117157 }
118158 Ok ( ( ) )
@@ -125,7 +165,7 @@ impl Connection {
125165 ) -> PyResult < Cursor > {
126166 let cursor = Connection :: cursor ( & self_) ?;
127167 let rt = self_. rt . handle ( ) ;
128- rt. block_on ( execute ( & cursor, sql, parameters) ) ?;
168+ rt. block_on ( async { execute ( & cursor, sql, parameters) . await } ) ?;
129169 Ok ( cursor)
130170 }
131171
@@ -139,7 +179,7 @@ impl Connection {
139179 let parameters = parameters. extract :: < & PyTuple > ( ) ?;
140180 self_
141181 . rt
142- . block_on ( execute ( & cursor, sql. clone ( ) , Some ( parameters) ) ) ?;
182+ . block_on ( async { execute ( & cursor, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
143183 }
144184 Ok ( cursor)
145185 }
@@ -160,7 +200,7 @@ pub struct Cursor {
160200 #[ pyo3( get, set) ]
161201 arraysize : usize ,
162202 rt : tokio:: runtime:: Handle ,
163- conn : Arc < libsql_core :: Connection > ,
203+ conn : Arc < ConnectionGuard > ,
164204 stmt : RefCell < Option < libsql_core:: Statement > > ,
165205 rows : RefCell < Option < libsql_core:: Rows > > ,
166206 rowcount : RefCell < i64 > ,
@@ -178,7 +218,9 @@ impl Cursor {
178218 sql : String ,
179219 parameters : Option < & PyTuple > ,
180220 ) -> PyResult < pyo3:: PyRef < ' a , Self > > {
181- self_. rt . block_on ( execute ( & self_, sql, parameters) ) ?;
221+ self_
222+ . rt
223+ . block_on ( async { execute ( & self_, sql, parameters) . await } ) ?;
182224 Ok ( self_)
183225 }
184226
@@ -191,7 +233,7 @@ impl Cursor {
191233 let parameters = parameters. extract :: < & PyTuple > ( ) ?;
192234 self_
193235 . rt
194- . block_on ( execute ( & self_, sql. clone ( ) , Some ( parameters) ) ) ?;
236+ . block_on ( async { execute ( & self_, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
195237 }
196238 Ok ( self_)
197239 }
@@ -246,19 +288,22 @@ impl Cursor {
246288 // done before iterating.
247289 if !* self_. done . borrow ( ) {
248290 for _ in 0 ..size {
249- let row = self_. rt . block_on ( rows. next ( ) ) . map_err ( to_py_err) ?;
291+ let row = self_
292+ . rt
293+ . block_on ( async { rows. next ( ) . await } )
294+ . map_err ( to_py_err) ?;
250295 match row {
251296 Some ( row) => {
252297 let row = convert_row ( self_. py ( ) , row, rows. column_count ( ) ) ?;
253298 elements. push ( row. into ( ) ) ;
254299 }
255300 None => {
256301 self_. done . replace ( true ) ;
257- break
302+ break ;
258303 }
259304 }
260305 }
261- }
306+ }
262307 Ok ( Some ( PyList :: new ( self_. py ( ) , elements) ) )
263308 }
264309 None => Ok ( None ) ,
@@ -271,7 +316,10 @@ impl Cursor {
271316 Some ( rows) => {
272317 let mut elements: Vec < Py < PyAny > > = vec ! [ ] ;
273318 loop {
274- let row = self_. rt . block_on ( rows. next ( ) ) . map_err ( to_py_err) ?;
319+ let row = self_
320+ . rt
321+ . block_on ( async { rows. next ( ) . await } )
322+ . map_err ( to_py_err) ?;
275323 match row {
276324 Some ( row) => {
277325 let row = convert_row ( self_. py ( ) , row, rows. column_count ( ) ) ?;
@@ -393,3 +441,20 @@ fn libsql_experimental(py: Python, m: &PyModule) -> PyResult<()> {
393441 m. add_class :: < Cursor > ( ) ?;
394442 Ok ( ( ) )
395443}
444+
445+ async fn check_signals < F , R > ( py : Python < ' _ > , mut fut : std:: pin:: Pin < & mut F > ) -> R
446+ where
447+ F : std:: future:: Future < Output = R > ,
448+ {
449+ loop {
450+ tokio:: select! {
451+ out = & mut fut => {
452+ break out;
453+ }
454+
455+ _ = tokio:: time:: sleep( std:: time:: Duration :: from_millis( 300 ) ) => {
456+ py. check_signals( ) . unwrap( ) ;
457+ }
458+ }
459+ }
460+ }
0 commit comments