@@ -47,61 +47,58 @@ fn connect(
4747 None ,
4848 None ,
4949 ) ;
50- println ! ( "connecting" ) ;
5150 tokio:: pin!( fut) ;
5251 let result = rt. block_on ( check_signals ( py, fut) ) ;
53- // let result = rt.block_on(async {
54- // loop {
55- // tokio::select! {
56- // out = &mut fut => {
57- // break out;
58- // }
59-
60- // _ = tokio::time::sleep(std::time::Duration::from_millis(300)) => {
61- // py.check_signals().unwrap();
62- // }
63- // }
64- // }
65- // });
66- println ! ( "done connecting" ) ;
6752 result. map_err ( to_py_err) ?
6853 }
6954 None => libsql_core:: Database :: open ( database) . map_err ( to_py_err) ?,
7055 }
7156 } ;
7257 let autocommit = isolation_level. is_none ( ) ;
7358 let conn = db. connect ( ) . map_err ( to_py_err) ?;
74- let conn = Arc :: new ( conn) ;
7559 Ok ( Connection {
7660 db,
77- conn,
61+ conn : Arc :: new ( ConnectionGuard {
62+ conn : Some ( conn) ,
63+ handle : rt. handle ( ) . clone ( ) ,
64+ } ) ,
7865 rt,
7966 isolation_level,
8067 autocommit,
8168 } )
8269}
8370
84- async fn check_signals < F , R > ( py : Python < ' _ > , mut fut : std :: pin :: Pin < & mut F > ) -> R
85- where
86- F : std :: future :: Future < Output = R > ,
87- {
88- loop {
89- tokio :: select! {
90- out = & mut fut => {
91- break out ;
92- }
71+ // We need to add a drop guard that runs when we finally drop our
72+ // only reference to libsql_core::Connection. This is because when
73+ // hrana is enabled it needs access to the tokio api to spawn a close
74+ // call in the background. So this adds the ability that when drop is called
75+ // on ConnectionGuard it will drop the connection with a tokio context entered.
76+ struct ConnectionGuard {
77+ conn : Option < libsql_core :: Connection > ,
78+ handle : tokio :: runtime :: Handle ,
79+ }
9380
94- _ = tokio:: time:: sleep( std:: time:: Duration :: from_millis( 300 ) ) => {
95- py. check_signals( ) . unwrap( ) ;
96- }
81+ impl std:: ops:: Deref for ConnectionGuard {
82+ type Target = libsql_core:: Connection ;
83+
84+ fn deref ( & self ) -> & Self :: Target {
85+ & self . conn . as_ref ( ) . expect ( "Connection already dropped" )
86+ }
87+ }
88+
89+ impl Drop for ConnectionGuard {
90+ fn drop ( & mut self ) {
91+ let _enter = self . handle . enter ( ) ;
92+ if let Some ( conn) = self . conn . take ( ) {
93+ drop ( conn) ;
9794 }
9895 }
9996}
10097
10198#[ pyclass]
10299pub struct Connection {
103100 db : libsql_core:: Database ,
104- conn : Arc < libsql_core :: Connection > ,
101+ conn : Arc < ConnectionGuard > ,
105102 rt : tokio:: runtime:: Runtime ,
106103 isolation_level : Option < String > ,
107104 autocommit : bool ,
@@ -126,7 +123,10 @@ impl Connection {
126123 }
127124
128125 fn sync ( self_ : PyRef < ' _ , Self > , py : Python < ' _ > ) -> PyResult < ( ) > {
129- let fut = self_. db . sync ( ) ;
126+ let fut = {
127+ let _enter = self_. rt . enter ( ) ;
128+ self_. db . sync ( )
129+ } ;
130130 tokio:: pin!( fut) ;
131131
132132 self_
@@ -141,7 +141,7 @@ impl Connection {
141141 if !self_. conn . is_autocommit ( ) {
142142 self_
143143 . rt
144- . block_on ( self_. conn . execute ( "COMMIT" , ( ) ) )
144+ . block_on ( async { self_. conn . execute ( "COMMIT" , ( ) ) . await } )
145145 . map_err ( to_py_err) ?;
146146 }
147147 Ok ( ( ) )
@@ -152,7 +152,7 @@ impl Connection {
152152 if !self_. conn . is_autocommit ( ) {
153153 self_
154154 . rt
155- . block_on ( self_. conn . execute ( "ROLLBACK" , ( ) ) )
155+ . block_on ( async { self_. conn . execute ( "ROLLBACK" , ( ) ) . await } )
156156 . map_err ( to_py_err) ?;
157157 }
158158 Ok ( ( ) )
@@ -165,7 +165,7 @@ impl Connection {
165165 ) -> PyResult < Cursor > {
166166 let cursor = Connection :: cursor ( & self_) ?;
167167 let rt = self_. rt . handle ( ) ;
168- rt. block_on ( execute ( & cursor, sql, parameters) ) ?;
168+ rt. block_on ( async { execute ( & cursor, sql, parameters) . await } ) ?;
169169 Ok ( cursor)
170170 }
171171
@@ -179,7 +179,7 @@ impl Connection {
179179 let parameters = parameters. extract :: < & PyTuple > ( ) ?;
180180 self_
181181 . rt
182- . block_on ( execute ( & cursor, sql. clone ( ) , Some ( parameters) ) ) ?;
182+ . block_on ( async { execute ( & cursor, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
183183 }
184184 Ok ( cursor)
185185 }
@@ -200,7 +200,7 @@ pub struct Cursor {
200200 #[ pyo3( get, set) ]
201201 arraysize : usize ,
202202 rt : tokio:: runtime:: Handle ,
203- conn : Arc < libsql_core :: Connection > ,
203+ conn : Arc < ConnectionGuard > ,
204204 stmt : RefCell < Option < libsql_core:: Statement > > ,
205205 rows : RefCell < Option < libsql_core:: Rows > > ,
206206 rowcount : RefCell < i64 > ,
@@ -218,7 +218,9 @@ impl Cursor {
218218 sql : String ,
219219 parameters : Option < & PyTuple > ,
220220 ) -> PyResult < pyo3:: PyRef < ' a , Self > > {
221- self_. rt . block_on ( execute ( & self_, sql, parameters) ) ?;
221+ self_
222+ . rt
223+ . block_on ( async { execute ( & self_, sql, parameters) . await } ) ?;
222224 Ok ( self_)
223225 }
224226
@@ -231,7 +233,7 @@ impl Cursor {
231233 let parameters = parameters. extract :: < & PyTuple > ( ) ?;
232234 self_
233235 . rt
234- . block_on ( execute ( & self_, sql. clone ( ) , Some ( parameters) ) ) ?;
236+ . block_on ( async { execute ( & self_, sql. clone ( ) , Some ( parameters) ) . await } ) ?;
235237 }
236238 Ok ( self_)
237239 }
@@ -286,7 +288,10 @@ impl Cursor {
286288 // done before iterating.
287289 if !* self_. done . borrow ( ) {
288290 for _ in 0 ..size {
289- let row = self_. rt . block_on ( rows. next ( ) ) . map_err ( to_py_err) ?;
291+ let row = self_
292+ . rt
293+ . block_on ( async { rows. next ( ) . await } )
294+ . map_err ( to_py_err) ?;
290295 match row {
291296 Some ( row) => {
292297 let row = convert_row ( self_. py ( ) , row, rows. column_count ( ) ) ?;
@@ -311,7 +316,10 @@ impl Cursor {
311316 Some ( rows) => {
312317 let mut elements: Vec < Py < PyAny > > = vec ! [ ] ;
313318 loop {
314- let row = self_. rt . block_on ( rows. next ( ) ) . map_err ( to_py_err) ?;
319+ let row = self_
320+ . rt
321+ . block_on ( async { rows. next ( ) . await } )
322+ . map_err ( to_py_err) ?;
315323 match row {
316324 Some ( row) => {
317325 let row = convert_row ( self_. py ( ) , row, rows. column_count ( ) ) ?;
@@ -433,3 +441,20 @@ fn libsql_experimental(py: Python, m: &PyModule) -> PyResult<()> {
433441 m. add_class :: < Cursor > ( ) ?;
434442 Ok ( ( ) )
435443}
444+
445+ async fn check_signals < F , R > ( py : Python < ' _ > , mut fut : std:: pin:: Pin < & mut F > ) -> R
446+ where
447+ F : std:: future:: Future < Output = R > ,
448+ {
449+ loop {
450+ tokio:: select! {
451+ out = & mut fut => {
452+ break out;
453+ }
454+
455+ _ = tokio:: time:: sleep( std:: time:: Duration :: from_millis( 300 ) ) => {
456+ py. check_signals( ) . unwrap( ) ;
457+ }
458+ }
459+ }
460+ }
0 commit comments