Skip to content

Commit 49ff2ed

Browse files
committed
feat: connection pooling for all the supported databases
1 parent 6809c75 commit 49ff2ed

9 files changed

Lines changed: 250 additions & 262 deletions

File tree

canyon_core/src/canyon.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
use crate::connection::conn_errors::DatasourceNotFound;
22
use crate::connection::database_type::DatabaseType;
33
use crate::connection::datasources::{CanyonSqlConfig, DatasourceConfig, Datasources};
4-
use crate::connection::{
5-
CANYON_INSTANCE, db_connector, get_canyon_tokio_runtime,
6-
};
4+
use crate::connection::{CANYON_INSTANCE, db_connector, get_canyon_tokio_runtime};
75
use db_connector::DatabaseConnection;
86
use std::collections::HashMap;
97
use std::sync::Arc;
108
use std::{error::Error, fs};
119
use tokio::sync::Mutex;
12-
use crate::connection::pool::CanyonConnection;
1310

1411
pub type SharedConnection = Arc<Mutex<DatabaseConnection>>;
1512

@@ -54,14 +51,14 @@ pub type SharedConnection = Arc<Mutex<DatabaseConnection>>;
5451
/// - `find_datasource_by_name_or_default`: Finds a datasource by name or returns the default.
5552
/// - `get_connection`: Retrieves a read-only connection from the cache.
5653
/// - `get_mut_connection`: Retrieves a mutable connection from the cache.
57-
pub struct Canyon<'a> {
54+
pub struct Canyon {
5855
config: Datasources,
59-
connections: HashMap<&'static str, CanyonConnection<'a>>,
56+
connections: HashMap<&'static str, DatabaseConnection>,
6057
default_connection: Option<DatabaseConnection>,
6158
default_db_type: Option<DatabaseType>,
6259
}
6360

64-
impl<'a> Canyon<'a> {
61+
impl Canyon {
6562
/// Returns the global singleton instance of `Canyon`.
6663
///
6764
/// This function allows access to the singleton instance of the Canyon engine
@@ -115,8 +112,8 @@ impl<'a> Canyon<'a> {
115112
let config_content = fs::read_to_string(&path)?;
116113
let config: Datasources = toml::from_str::<CanyonSqlConfig>(&config_content)?.canyon_sql;
117114

118-
let mut connections: HashMap<&str, CanyonConnection<'a>> = HashMap::new();
119-
let mut default_connection: Option<CanyonConnection<'a>> = None;
115+
let mut connections: HashMap<&str, DatabaseConnection> = HashMap::new();
116+
let mut default_connection: Option<DatabaseConnection> = None;
120117
let mut default_db_type: Option<DatabaseType> = None;
121118

122119
for ds in config.datasources.iter() {
@@ -248,7 +245,7 @@ mod __impl {
248245
})
249246
}
250247

251-
pub(crate) async fn process_new_conn_by_datasource<'a>(
248+
pub(crate) async fn process_new_conn_by_datasource(
252249
ds: &DatasourceConfig,
253250
connections: &mut HashMap<&str, DatabaseConnection>,
254251
default: &mut Option<DatabaseConnection>,
Lines changed: 35 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,130 +1,46 @@
1-
use crate::{query::parameters::QueryParameter, rows::CanyonRows};
2-
#[cfg(feature = "mssql")]
3-
use async_std::net::TcpStream;
1+
use crate::connection::{MsManager, SqlServerConnectionPool};
2+
use crate::query::parameters::QueryParameter;
3+
use bb8::PooledConnection;
44
use std::error::Error;
5+
#[cfg(feature = "mssql")]
56
use tiberius::Query;
67

78
/// A connection with a `SqlServer` database
8-
#[cfg(feature = "mssql")]
9-
pub struct SqlServerConnection {
10-
pub client: &'static mut tiberius::Client<TcpStream>,
9+
#[cfg(feature = "mssql")] // TODO: remove the local cfg and put them at module level
10+
pub struct SqlServerConnection(SqlServerConnectionPool);
11+
12+
impl SqlServerConnection {
13+
pub fn new(pool: SqlServerConnectionPool) -> Result<Self, Box<dyn Error + Send + Sync>> {
14+
Ok(Self(pool))
15+
}
16+
pub async fn get_pooled(
17+
&self,
18+
) -> Result<PooledConnection<'_, MsManager>, Box<dyn Error + Send + Sync>> {
19+
Ok(self.0.get().await?)
20+
}
1121
}
1222

1323
#[cfg(feature = "mssql")]
1424
pub(crate) mod sqlserver_query_launcher {
1525
use super::*;
16-
use crate::mapper::RowMapper;
17-
use crate::rows::FromSqlOwnedValue;
1826
use tiberius::QueryStream;
1927

20-
#[inline(always)]
21-
pub(crate) async fn query<S, R>(
22-
stmt: S,
23-
params: &[&'_ dyn QueryParameter],
24-
conn: &SqlServerConnection,
25-
) -> Result<Vec<R>, Box<dyn Error + Send + Sync>>
26-
where
27-
S: AsRef<str> + Send,
28-
R: RowMapper,
29-
Vec<R>: FromIterator<<R as RowMapper>::Output>,
30-
{
31-
Ok(execute_query(stmt.as_ref(), params, conn)
32-
.await?
33-
.into_results()
34-
.await?
35-
.into_iter()
36-
.flatten()
37-
.flat_map(|row| R::deserialize_sqlserver(&row))
38-
.collect::<Vec<R>>())
39-
}
40-
41-
#[inline(always)]
42-
pub(crate) async fn query_rows(
43-
stmt: &str,
44-
params: &[&'_ dyn QueryParameter],
45-
conn: &SqlServerConnection,
46-
) -> Result<CanyonRows, Box<dyn Error + Send + Sync>> {
47-
let result = execute_query(stmt, params, conn)
48-
.await?
49-
.into_results()
50-
.await?
51-
.into_iter()
52-
.flatten()
53-
.collect();
54-
55-
Ok(CanyonRows::Tiberius(result))
56-
}
57-
58-
pub(crate) async fn query_one<R>(
59-
stmt: &str,
60-
params: &[&'_ dyn QueryParameter],
61-
conn: &SqlServerConnection,
62-
) -> Result<Option<R::Output>, Box<dyn Error + Send + Sync>>
63-
where
64-
R: RowMapper,
65-
{
66-
let result = execute_query(stmt, params, conn).await?.into_row().await?;
67-
68-
match result {
69-
Some(r) => Ok(Some(R::deserialize_sqlserver(&r)?)),
70-
None => Ok(None),
71-
}
72-
}
73-
74-
pub(crate) async fn query_one_for<T: FromSqlOwnedValue<T>>(
75-
stmt: &str,
76-
params: &[&'_ dyn QueryParameter],
77-
conn: &SqlServerConnection,
78-
) -> Result<T, Box<dyn Error + Send + Sync>> {
79-
let row = execute_query(stmt, params, conn)
80-
.await?
81-
.into_row()
82-
.await?
83-
.ok_or_else(|| format!("Failure executing 'query_one_for' while retrieving the first row with stmt: {:?}", stmt))?;
84-
85-
Ok(row
86-
.into_iter()
87-
.map(T::from_sql_owned)
88-
.collect::<Vec<_>>()
89-
.remove(0)?
90-
.ok_or_else(|| format!("Failure executing 'query_one_for' while retrieving the first column value on the first row with stmt: {:?}", stmt))?
91-
)
92-
}
93-
94-
pub(crate) async fn execute(
28+
pub(crate) async fn execute_query<'a>(
9529
stmt: &str,
96-
params: &[&'_ dyn QueryParameter],
97-
conn: &SqlServerConnection,
98-
) -> Result<u64, Box<dyn Error + Send + Sync>> {
99-
let mssql_query = generate_mssql_stmt(stmt, params).await;
100-
101-
#[allow(mutable_transmutes)] // TODO: pls solve this elegantly someday :(
102-
let sqlservconn =
103-
unsafe { std::mem::transmute::<&SqlServerConnection, &mut SqlServerConnection>(conn) };
104-
105-
mssql_query
106-
.execute(sqlservconn.client)
107-
.await
108-
.map(|r| r.total())
109-
.map_err(From::from)
30+
params: &[&dyn QueryParameter],
31+
conn: &'a mut bb8::PooledConnection<'_, bb8_tiberius::ConnectionManager>,
32+
) -> Result<QueryStream<'a>, Box<dyn Error + Send + Sync>> {
33+
let mssql_query = generate_mssql_query_client(stmt, params).await;
34+
mssql_query.query(conn).await.map_err(From::from)
11035
}
11136

112-
async fn execute_query<'a>(
37+
pub(crate) async fn generate_mssql_query_client<'a>(
11338
stmt: &str,
11439
params: &[&'a dyn QueryParameter],
115-
conn: &SqlServerConnection,
116-
) -> Result<QueryStream<'a>, Box<dyn Error + Send + Sync>> {
117-
let mssql_query = generate_mssql_stmt(stmt, params).await;
118-
119-
#[allow(mutable_transmutes)] // TODO: pls solve this elegantly someday :(
120-
let sqlservconn =
121-
unsafe { std::mem::transmute::<&SqlServerConnection, &mut SqlServerConnection>(conn) };
122-
Ok(mssql_query.query(sqlservconn.client).await?)
123-
}
124-
125-
async fn generate_mssql_stmt<'a>(stmt: &str, params: &[&'a dyn QueryParameter]) -> Query<'a> {
40+
) -> Query<'a> {
12641
let mut stmt = String::from(stmt);
12742
if stmt.contains("RETURNING") {
43+
// TODO: when the InsertQuerybuilder with a api on the builder for the returning clause
12844
let c = stmt.clone();
12945
let temp = c.split_once("RETURNING").unwrap();
13046
let temp2 = temp.0.split_once("VALUES").unwrap();
@@ -137,13 +53,19 @@ pub(crate) mod sqlserver_query_launcher {
13753
);
13854
}
13955

140-
// TODO: We must address the query generation
141-
// NOTE: ready to apply the change now that the querybuilder knows what's the underlying db type
142-
let mut mssql_query = Query::new(stmt.to_owned().replace('$', "@P"));
56+
let stmt = stmt.replace('$', "@P"); // TODO: this should be solved by the querybuilder
57+
generate_query_and_bind_params(stmt, params)
58+
}
59+
60+
// Query and parameters are generated in this procedure together to avoid lifetime errors
61+
fn generate_query_and_bind_params<'a>(
62+
stmt: String,
63+
params: &[&'a (dyn QueryParameter + 'a)],
64+
) -> Query<'a> {
65+
let mut mssql_query = Query::new(stmt);
14366
params.iter().for_each(|param| {
14467
mssql_query.bind(*param);
14568
});
146-
14769
mssql_query
14870
}
14971
}

canyon_core/src/connection/clients/mysql.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ use std::error::Error;
1010

1111
/// A connection with a `Mysql` database
1212
#[cfg(feature = "mysql")]
13-
pub struct MysqlConnection {
14-
pub client: Pool,
13+
pub struct MySQLConnector(mysql_async::Pool);
14+
15+
impl MySQLConnector {
16+
pub fn new(pool: Pool) -> Self {
17+
Self(pool)
18+
}
1519
}
1620

1721
#[cfg(feature = "mysql")]
@@ -33,7 +37,7 @@ pub(crate) mod mysql_query_launcher {
3337
pub async fn query<S, R>(
3438
stmt: S,
3539
params: &[&'_ dyn QueryParameter],
36-
conn: &MysqlConnection,
40+
conn: &MySQLConnector,
3741
) -> Result<Vec<R>, Box<dyn Error + Send + Sync>>
3842
where
3943
S: AsRef<str> + Send,
@@ -51,7 +55,7 @@ pub(crate) mod mysql_query_launcher {
5155
pub(crate) async fn query_rows(
5256
stmt: &str,
5357
params: &[&'_ dyn QueryParameter],
54-
conn: &MysqlConnection,
58+
conn: &MySQLConnector,
5559
) -> Result<CanyonRows, Box<dyn Error + Send + Sync>> {
5660
Ok(CanyonRows::MySQL(execute_query(stmt, params, conn).await?))
5761
}
@@ -60,7 +64,7 @@ pub(crate) mod mysql_query_launcher {
6064
pub(crate) async fn query_one<R>(
6165
stmt: &str,
6266
params: &[&'_ dyn QueryParameter],
63-
conn: &MysqlConnection,
67+
conn: &MySQLConnector,
6468
) -> Result<Option<R::Output>, Box<dyn Error + Send + Sync>>
6569
where
6670
R: RowMapper,
@@ -77,7 +81,7 @@ pub(crate) mod mysql_query_launcher {
7781
pub(crate) async fn query_one_for<T: FromSqlOwnedValue<T>>(
7882
stmt: &str,
7983
params: &[&'_ dyn QueryParameter],
80-
conn: &MysqlConnection,
84+
conn: &MySQLConnector,
8185
) -> Result<T, Box<dyn Error + Send + Sync>> {
8286
Ok(execute_query(stmt, params, conn)
8387
.await?
@@ -92,12 +96,12 @@ pub(crate) mod mysql_query_launcher {
9296
async fn execute_query<S>(
9397
stmt: S,
9498
params: &[&'_ dyn QueryParameter],
95-
conn: &MysqlConnection,
99+
conn: &MySQLConnector,
96100
) -> Result<Vec<Row>, Box<dyn Error + Send + Sync>>
97101
where
98102
S: AsRef<str> + Send,
99103
{
100-
let mysql_connection = conn.client.get_conn().await?;
104+
let mysql_connection = conn.0.get_conn().await?;
101105
let is_insert = stmt.as_ref().find(" RETURNING");
102106
let mysql_stmt = generate_mysql_stmt(stmt.as_ref(), params)?;
103107

@@ -122,12 +126,12 @@ pub(crate) mod mysql_query_launcher {
122126
pub(crate) async fn execute<S>(
123127
stmt: S,
124128
params: &[&'_ dyn QueryParameter],
125-
conn: &MysqlConnection,
129+
conn: &MySQLConnector,
126130
) -> Result<u64, Box<dyn Error + Send + Sync>>
127131
where
128132
S: AsRef<str> + Send,
129133
{
130-
let mysql_connection = conn.client.get_conn().await?;
134+
let mysql_connection = conn.0.get_conn().await?;
131135
let mysql_stmt = generate_mysql_stmt(stmt.as_ref(), params)?;
132136

133137
Ok(mysql_stmt.run(mysql_connection).await?.affected_rows())

0 commit comments

Comments
 (0)