1- use crate :: { query :: parameters :: QueryParameter , rows :: CanyonRows } ;
2- # [ cfg ( feature = "mssql" ) ]
3- use async_std :: net :: TcpStream ;
1+ use crate :: connection :: { MsManager , SqlServerConnectionPool } ;
2+ use crate :: query :: parameters :: QueryParameter ;
3+ use bb8 :: PooledConnection ;
44use std:: error:: Error ;
5+ #[ cfg( feature = "mssql" ) ]
56use tiberius:: Query ;
67
78/// A connection with a `SqlServer` database
8- #[ cfg( feature = "mssql" ) ]
9- pub struct SqlServerConnection {
10- pub client : & ' static mut tiberius:: Client < TcpStream > ,
9+ #[ cfg( feature = "mssql" ) ] // TODO: remove the local cfg and put them at module level
10+ pub struct SqlServerConnection ( SqlServerConnectionPool ) ;
11+
12+ impl SqlServerConnection {
13+ pub fn new ( pool : SqlServerConnectionPool ) -> Result < Self , Box < dyn Error + Send + Sync > > {
14+ Ok ( Self ( pool) )
15+ }
16+ pub async fn get_pooled (
17+ & self ,
18+ ) -> Result < PooledConnection < ' _ , MsManager > , Box < dyn Error + Send + Sync > > {
19+ Ok ( self . 0 . get ( ) . await ?)
20+ }
1121}
1222
1323#[ cfg( feature = "mssql" ) ]
1424pub ( crate ) mod sqlserver_query_launcher {
1525 use super :: * ;
16- use crate :: mapper:: RowMapper ;
17- use crate :: rows:: FromSqlOwnedValue ;
1826 use tiberius:: QueryStream ;
1927
20- #[ inline( always) ]
21- pub ( crate ) async fn query < S , R > (
22- stmt : S ,
23- params : & [ & ' _ dyn QueryParameter ] ,
24- conn : & SqlServerConnection ,
25- ) -> Result < Vec < R > , Box < dyn Error + Send + Sync > >
26- where
27- S : AsRef < str > + Send ,
28- R : RowMapper ,
29- Vec < R > : FromIterator < <R as RowMapper >:: Output > ,
30- {
31- Ok ( execute_query ( stmt. as_ref ( ) , params, conn)
32- . await ?
33- . into_results ( )
34- . await ?
35- . into_iter ( )
36- . flatten ( )
37- . flat_map ( |row| R :: deserialize_sqlserver ( & row) )
38- . collect :: < Vec < R > > ( ) )
39- }
40-
41- #[ inline( always) ]
42- pub ( crate ) async fn query_rows (
43- stmt : & str ,
44- params : & [ & ' _ dyn QueryParameter ] ,
45- conn : & SqlServerConnection ,
46- ) -> Result < CanyonRows , Box < dyn Error + Send + Sync > > {
47- let result = execute_query ( stmt, params, conn)
48- . await ?
49- . into_results ( )
50- . await ?
51- . into_iter ( )
52- . flatten ( )
53- . collect ( ) ;
54-
55- Ok ( CanyonRows :: Tiberius ( result) )
56- }
57-
58- pub ( crate ) async fn query_one < R > (
59- stmt : & str ,
60- params : & [ & ' _ dyn QueryParameter ] ,
61- conn : & SqlServerConnection ,
62- ) -> Result < Option < R :: Output > , Box < dyn Error + Send + Sync > >
63- where
64- R : RowMapper ,
65- {
66- let result = execute_query ( stmt, params, conn) . await ?. into_row ( ) . await ?;
67-
68- match result {
69- Some ( r) => Ok ( Some ( R :: deserialize_sqlserver ( & r) ?) ) ,
70- None => Ok ( None ) ,
71- }
72- }
73-
74- pub ( crate ) async fn query_one_for < T : FromSqlOwnedValue < T > > (
75- stmt : & str ,
76- params : & [ & ' _ dyn QueryParameter ] ,
77- conn : & SqlServerConnection ,
78- ) -> Result < T , Box < dyn Error + Send + Sync > > {
79- let row = execute_query ( stmt, params, conn)
80- . await ?
81- . into_row ( )
82- . await ?
83- . ok_or_else ( || format ! ( "Failure executing 'query_one_for' while retrieving the first row with stmt: {:?}" , stmt) ) ?;
84-
85- Ok ( row
86- . into_iter ( )
87- . map ( T :: from_sql_owned)
88- . collect :: < Vec < _ > > ( )
89- . remove ( 0 ) ?
90- . ok_or_else ( || format ! ( "Failure executing 'query_one_for' while retrieving the first column value on the first row with stmt: {:?}" , stmt) ) ?
91- )
92- }
93-
94- pub ( crate ) async fn execute (
28+ pub ( crate ) async fn execute_query < ' a > (
9529 stmt : & str ,
96- params : & [ & ' _ dyn QueryParameter ] ,
97- conn : & SqlServerConnection ,
98- ) -> Result < u64 , Box < dyn Error + Send + Sync > > {
99- let mssql_query = generate_mssql_stmt ( stmt, params) . await ;
100-
101- #[ allow( mutable_transmutes) ] // TODO: pls solve this elegantly someday :(
102- let sqlservconn =
103- unsafe { std:: mem:: transmute :: < & SqlServerConnection , & mut SqlServerConnection > ( conn) } ;
104-
105- mssql_query
106- . execute ( sqlservconn. client )
107- . await
108- . map ( |r| r. total ( ) )
109- . map_err ( From :: from)
30+ params : & [ & dyn QueryParameter ] ,
31+ conn : & ' a mut bb8:: PooledConnection < ' _ , bb8_tiberius:: ConnectionManager > ,
32+ ) -> Result < QueryStream < ' a > , Box < dyn Error + Send + Sync > > {
33+ let mssql_query = generate_mssql_query_client ( stmt, params) . await ;
34+ mssql_query. query ( conn) . await . map_err ( From :: from)
11035 }
11136
112- async fn execute_query < ' a > (
37+ pub ( crate ) async fn generate_mssql_query_client < ' a > (
11338 stmt : & str ,
11439 params : & [ & ' a dyn QueryParameter ] ,
115- conn : & SqlServerConnection ,
116- ) -> Result < QueryStream < ' a > , Box < dyn Error + Send + Sync > > {
117- let mssql_query = generate_mssql_stmt ( stmt, params) . await ;
118-
119- #[ allow( mutable_transmutes) ] // TODO: pls solve this elegantly someday :(
120- let sqlservconn =
121- unsafe { std:: mem:: transmute :: < & SqlServerConnection , & mut SqlServerConnection > ( conn) } ;
122- Ok ( mssql_query. query ( sqlservconn. client ) . await ?)
123- }
124-
125- async fn generate_mssql_stmt < ' a > ( stmt : & str , params : & [ & ' a dyn QueryParameter ] ) -> Query < ' a > {
40+ ) -> Query < ' a > {
12641 let mut stmt = String :: from ( stmt) ;
12742 if stmt. contains ( "RETURNING" ) {
43+ // TODO: when the InsertQuerybuilder with a api on the builder for the returning clause
12844 let c = stmt. clone ( ) ;
12945 let temp = c. split_once ( "RETURNING" ) . unwrap ( ) ;
13046 let temp2 = temp. 0 . split_once ( "VALUES" ) . unwrap ( ) ;
@@ -137,13 +53,19 @@ pub(crate) mod sqlserver_query_launcher {
13753 ) ;
13854 }
13955
140- // TODO: We must address the query generation
141- // NOTE: ready to apply the change now that the querybuilder knows what's the underlying db type
142- let mut mssql_query = Query :: new ( stmt. to_owned ( ) . replace ( '$' , "@P" ) ) ;
56+ let stmt = stmt. replace ( '$' , "@P" ) ; // TODO: this should be solved by the querybuilder
57+ generate_query_and_bind_params ( stmt, params)
58+ }
59+
60+ // Query and parameters are generated in this procedure together to avoid lifetime errors
61+ fn generate_query_and_bind_params < ' a > (
62+ stmt : String ,
63+ params : & [ & ' a ( dyn QueryParameter + ' a ) ] ,
64+ ) -> Query < ' a > {
65+ let mut mssql_query = Query :: new ( stmt) ;
14366 params. iter ( ) . for_each ( |param| {
14467 mssql_query. bind ( * param) ;
14568 } ) ;
146-
14769 mssql_query
14870 }
14971}
0 commit comments