@@ -24,7 +24,8 @@ def list_params() -> bool:
2424 {"name" : "username" , "type" : "string" , "required" : False , "default" : "" , "description" : "" },
2525 {"name" : "password" , "type" : "string" , "required" : False , "default" : "" , "description" : "" },
2626 {"name" : "database" , "type" : "string" , "required" : True , "default" : "" , "description" : "" },
27- {"name" : "collection" , "type" : "string" , "required" : False , "default" : "" , "description" : "If specified, only this collection will be accessed" }
27+ {"name" : "collection" , "type" : "string" , "required" : False , "default" : "" , "description" : "If specified, only this collection will be accessed" },
28+ {"name" : "authSource" , "type" : "string" , "required" : False , "default" : "" , "description" : "Authentication database (defaults to target database if empty)" }
2829 ]
2930 return params_list
3031
@@ -67,9 +68,17 @@ def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnecti
6768 password = self .params .get ("password" , "" )
6869 database = self .params .get ("database" , "" )
6970 collection = self .params .get ("collection" , "" )
71+ auth_source = self .params .get ("authSource" , "" ) or database # Default to target database
7072
7173 if username and password :
72- self .mongo_client = pymongo .MongoClient (host = host , port = port , username = username , password = password )
74+ # Use authSource to specify which database contains user credentials
75+ self .mongo_client = pymongo .MongoClient (
76+ host = host ,
77+ port = port ,
78+ username = username ,
79+ password = password ,
80+ authSource = auth_source
81+ )
7382 else :
7483 self .mongo_client = pymongo .MongoClient (host = host , port = port )
7584
@@ -81,6 +90,28 @@ def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnecti
8190 except Exception as e :
8291 raise Exception (f"Failed to connect to MongoDB: { e } " )
8392
def close(self):
    """Close the MongoDB connection held by this object, if any.

    Safe to call repeatedly and on a partially initialised object: when
    there is no ``mongo_client`` attribute, or it is already ``None``, the
    call does nothing.

    This is a best-effort shutdown. A failure while closing is reported as
    a warning on stdout and is never raised to the caller.
    """
    client = getattr(self, 'mongo_client', None)
    if client is None:
        return
    try:
        client.close()
    except Exception as e:
        # Cleanup must not raise: this method is also invoked from
        # __exit__ and __del__.
        print(f"Warning: Failed to close MongoDB connection: {e}")
    finally:
        # Drop the reference even when close() failed, so a later call
        # does not retry (and re-warn about) the same broken client.
        self.mongo_client = None
101+
def __enter__(self):
    """Enter a ``with`` block; the object itself is what ``as`` binds."""
    return self
105+
def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave a ``with`` block: close the connection, suppress nothing."""
    self.close()
    # A false return value tells Python to propagate any in-flight exception.
    suppress_exception = False
    return suppress_exception
110+
def __del__(self):
    """Finalizer: close the connection as a last-resort cleanup."""
    self.close()
114+
84115 @staticmethod
85116 def _flatten_document (doc : Dict [str , Any ], parent_key : str = '' , sep : str = '_' ) -> Dict [str , Any ]:
86117 """
@@ -237,7 +268,7 @@ def ingest_data(self, table_name: str, name_as: Optional[str] = None, size: int
237268 return
238269
239270
240- def view_query_sample (self , query : str ) -> str :
271+ def view_query_sample (self , query : str ) -> List [ Dict [ str , Any ]] :
241272
242273 self ._existed_collections_in_duckdb ()
243274 self ._difference_collections ()
@@ -271,17 +302,27 @@ def ingest_data_from_query(self, query: str, name_as: str) -> pd.DataFrame:
271302 self ._difference_collections ()
272303 self ._preload_all_collections (self .collection .name if self .collection else "" )
273304
274- df = self .duck_db_conn .execute (query ).df ()
305+ query_result_df = self .duck_db_conn .execute (query ).df ()
275306
276307 self ._drop_all_loaded_tables ()
277308
278- for collection_name , df in self .existed_collections .items ():
279- self ._load_dataframe_to_duckdb (df , collection_name )
309+ for collection_name , existing_df in self .existed_collections .items ():
310+ self ._load_dataframe_to_duckdb (existing_df , collection_name )
280311
281- self ._load_dataframe_to_duckdb (df , name_as )
312+ self ._load_dataframe_to_duckdb (query_result_df , name_as )
282313
283- return df
314+ return query_result_df
284315
@staticmethod
def _quote_identifier(name: str) -> str:
    """
    Return ``name`` as a double-quoted SQL identifier.

    Embedded double quotes are escaped by doubling them, so the result can
    be interpolated into a SQL statement without terminating the identifier
    early (guards against SQL injection through table/collection names).

    Example: ``my"table`` becomes ``"my""table"``.
    """
    escaped = name.replace('"', '""')
    # The quotes must wrap the escaped text exactly: any padding inside
    # them would become part of the identifier and name a different table.
    return '"' + escaped + '"'
325+
285326 def _existed_collections_in_duckdb (self ):
286327 """
287328 Return the names and contents of tables already loaded into DuckDB
@@ -290,7 +331,8 @@ def _existed_collections_in_duckdb(self):
290331 duckdb_tables = self .duck_db_conn .execute ("SHOW TABLES" ).df ()
291332 for _ , row in duckdb_tables .iterrows ():
292333 collection_name = row ['name' ]
293- df = self .duck_db_conn .execute (f"SELECT * FROM { collection_name } " ).df ()
334+ quoted_name = self ._quote_identifier (collection_name )
335+ df = self .duck_db_conn .execute (f"SELECT * FROM { quoted_name } " ).df ()
294336 self .existed_collections [collection_name ] = df
295337
296338
@@ -311,7 +353,8 @@ def _drop_all_loaded_tables(self):
311353 """
312354 for table_name in self .loaded_tables .values ():
313355 try :
314- self .duck_db_conn .execute (f"DROP TABLE IF EXISTS main.{ table_name } " )
356+ quoted_name = self ._quote_identifier (table_name )
357+ self .duck_db_conn .execute (f"DROP TABLE IF EXISTS main.{ quoted_name } " )
315358 print (f"Dropped loaded table: { table_name } " )
316359 except Exception as e :
317360 print (f"Warning: Failed to drop table '{ table_name } ': { e } " )
@@ -366,5 +409,10 @@ def _load_dataframe_to_duckdb(self, df: pd.DataFrame, table_name: str, size: int
366409
367410 self .duck_db_conn .register (temp_view_name , df )
368411 # Use CREATE OR REPLACE to directly replace existing table
369- self .duck_db_conn .execute (f"CREATE OR REPLACE TABLE main.{ table_name } AS SELECT * FROM { temp_view_name } LIMIT { size } " )
370- self .duck_db_conn .execute (f"DROP VIEW { temp_view_name } " )
412+ # Quote identifiers to prevent SQL injection
413+ quoted_table_name = self ._quote_identifier (table_name )
414+ quoted_temp_view = self ._quote_identifier (temp_view_name )
415+ # Ensure size is an integer to prevent injection via size parameter
416+ safe_size = int (size )
417+ self .duck_db_conn .execute (f"CREATE OR REPLACE TABLE main.{ quoted_table_name } AS SELECT * FROM { quoted_temp_view } LIMIT { safe_size } " )
418+ self .duck_db_conn .execute (f"DROP VIEW { quoted_temp_view } " )
0 commit comments