Skip to content

Commit 9dab794

Browse files
committed
fix and tested mongodb data loader
1 parent 6a5f0ed commit 9dab794

2 files changed

Lines changed: 132 additions & 42 deletions

File tree

py-src/data_formulator/data_loader/mongodb_data_loader.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def list_params() -> bool:
2424
{"name": "username", "type": "string", "required": False, "default": "", "description": ""},
2525
{"name": "password", "type": "string", "required": False, "default": "", "description": ""},
2626
{"name": "database", "type": "string", "required": True, "default": "", "description": ""},
27-
{"name": "collection", "type": "string", "required": False, "default": "", "description": "If specified, only this collection will be accessed"}
27+
{"name": "collection", "type": "string", "required": False, "default": "", "description": "If specified, only this collection will be accessed"},
28+
{"name": "authSource", "type": "string", "required": False, "default": "", "description": "Authentication database (defaults to target database if empty)"}
2829
]
2930
return params_list
3031

@@ -67,9 +68,17 @@ def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnecti
6768
password = self.params.get("password", "")
6869
database = self.params.get("database", "")
6970
collection = self.params.get("collection", "")
71+
auth_source = self.params.get("authSource", "") or database # Default to target database
7072

7173
if username and password:
72-
self.mongo_client = pymongo.MongoClient(host=host, port=port, username=username, password=password)
74+
# Use authSource to specify which database contains user credentials
75+
self.mongo_client = pymongo.MongoClient(
76+
host=host,
77+
port=port,
78+
username=username,
79+
password=password,
80+
authSource=auth_source
81+
)
7382
else:
7483
self.mongo_client = pymongo.MongoClient(host=host, port=port)
7584

@@ -81,6 +90,28 @@ def __init__(self, params: Dict[str, Any], duck_db_conn: duckdb.DuckDBPyConnecti
8190
except Exception as e:
8291
raise Exception(f"Failed to connect to MongoDB: {e}")
8392

93+
def close(self):
94+
"""Close the MongoDB connection"""
95+
if hasattr(self, 'mongo_client') and self.mongo_client is not None:
96+
try:
97+
self.mongo_client.close()
98+
self.mongo_client = None
99+
except Exception as e:
100+
print(f"Warning: Failed to close MongoDB connection: {e}")
101+
102+
def __enter__(self):
103+
"""Context manager entry"""
104+
return self
105+
106+
def __exit__(self, exc_type, exc_val, exc_tb):
107+
"""Context manager exit - ensures connection is closed"""
108+
self.close()
109+
return False
110+
111+
def __del__(self):
112+
"""Destructor to ensure connection is closed"""
113+
self.close()
114+
84115
@staticmethod
85116
def _flatten_document(doc: Dict[str, Any], parent_key: str = '', sep: str = '_') -> Dict[str, Any]:
86117
"""
@@ -237,7 +268,7 @@ def ingest_data(self, table_name: str, name_as: Optional[str] = None, size: int
237268
return
238269

239270

240-
def view_query_sample(self, query: str) -> str:
271+
def view_query_sample(self, query: str) -> List[Dict[str, Any]]:
241272

242273
self._existed_collections_in_duckdb()
243274
self._difference_collections()
@@ -271,17 +302,27 @@ def ingest_data_from_query(self, query: str, name_as: str) -> pd.DataFrame:
271302
self._difference_collections()
272303
self._preload_all_collections(self.collection.name if self.collection else "")
273304

274-
df = self.duck_db_conn.execute(query).df()
305+
query_result_df = self.duck_db_conn.execute(query).df()
275306

276307
self._drop_all_loaded_tables()
277308

278-
for collection_name, df in self.existed_collections.items():
279-
self._load_dataframe_to_duckdb(df, collection_name)
309+
for collection_name, existing_df in self.existed_collections.items():
310+
self._load_dataframe_to_duckdb(existing_df, collection_name)
280311

281-
self._load_dataframe_to_duckdb(df, name_as)
312+
self._load_dataframe_to_duckdb(query_result_df, name_as)
282313

283-
return df
314+
return query_result_df
284315

316+
@staticmethod
317+
def _quote_identifier(name: str) -> str:
318+
"""
319+
Safely quote a SQL identifier to prevent SQL injection.
320+
Double quotes are escaped by doubling them.
321+
"""
322+
# Escape any double quotes in the identifier by doubling them
323+
escaped = name.replace('"', '""')
324+
return f'"{escaped}"'
325+
285326
def _existed_collections_in_duckdb(self):
286327
"""
287328
Return the names and contents of tables already loaded into DuckDB
@@ -290,7 +331,8 @@ def _existed_collections_in_duckdb(self):
290331
duckdb_tables = self.duck_db_conn.execute("SHOW TABLES").df()
291332
for _, row in duckdb_tables.iterrows():
292333
collection_name = row['name']
293-
df = self.duck_db_conn.execute(f"SELECT * FROM {collection_name}").df()
334+
quoted_name = self._quote_identifier(collection_name)
335+
df = self.duck_db_conn.execute(f"SELECT * FROM {quoted_name}").df()
294336
self.existed_collections[collection_name] = df
295337

296338

@@ -311,7 +353,8 @@ def _drop_all_loaded_tables(self):
311353
"""
312354
for table_name in self.loaded_tables.values():
313355
try:
314-
self.duck_db_conn.execute(f"DROP TABLE IF EXISTS main.{table_name}")
356+
quoted_name = self._quote_identifier(table_name)
357+
self.duck_db_conn.execute(f"DROP TABLE IF EXISTS main.{quoted_name}")
315358
print(f"Dropped loaded table: {table_name}")
316359
except Exception as e:
317360
print(f"Warning: Failed to drop table '{table_name}': {e}")
@@ -366,5 +409,10 @@ def _load_dataframe_to_duckdb(self, df: pd.DataFrame, table_name: str, size: int
366409

367410
self.duck_db_conn.register(temp_view_name, df)
368411
# Use CREATE OR REPLACE to directly replace existing table
369-
self.duck_db_conn.execute(f"CREATE OR REPLACE TABLE main.{table_name} AS SELECT * FROM {temp_view_name} LIMIT {size}")
370-
self.duck_db_conn.execute(f"DROP VIEW {temp_view_name}")
412+
# Quote identifiers to prevent SQL injection
413+
quoted_table_name = self._quote_identifier(table_name)
414+
quoted_temp_view = self._quote_identifier(temp_view_name)
415+
# Ensure size is an integer to prevent injection via size parameter
416+
safe_size = int(size)
417+
self.duck_db_conn.execute(f"CREATE OR REPLACE TABLE main.{quoted_table_name} AS SELECT * FROM {quoted_temp_view} LIMIT {safe_size}")
418+
self.duck_db_conn.execute(f"DROP VIEW {quoted_temp_view}")

src/views/DBTableManager.tsx

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ import {
3434
ToggleButtonGroup,
3535
ToggleButton,
3636
useTheme,
37-
Link
37+
Link,
38+
Checkbox
3839
} from '@mui/material';
3940

4041
import DeleteIcon from '@mui/icons-material/Delete';
@@ -820,9 +821,14 @@ export const DBTableSelectionDialog: React.FC<{
820821
onImport={() => {
821822
setIsUploading(true);
822823
}}
823-
onFinish={(status, message) => {
824+
onFinish={(status, message, importedTables) => {
824825
setIsUploading(false);
825-
fetchTables();
826+
fetchTables().then(() => {
827+
// Navigate to the first imported table after tables are fetched
828+
if (status === "success" && importedTables && importedTables.length > 0) {
829+
setSelectedTabKey(importedTables[0]);
830+
}
831+
});
826832
if (status === "error") {
827833
setSystemMessage(message, "error");
828834
}
@@ -1030,7 +1036,7 @@ export const DataLoaderForm: React.FC<{
10301036
paramDefs: {name: string, default: string, type: string, required: boolean, description: string}[],
10311037
authInstructions: string,
10321038
onImport: () => void,
1033-
onFinish: (status: "success" | "error", message: string) => void
1039+
onFinish: (status: "success" | "error", message: string, importedTables?: string[]) => void
10341040
}> = ({dataLoaderType, paramDefs, authInstructions, onImport, onFinish}) => {
10351041

10361042
const dispatch = useDispatch();
@@ -1039,6 +1045,7 @@ export const DataLoaderForm: React.FC<{
10391045

10401046
const [tableMetadata, setTableMetadata] = useState<Record<string, any>>({}); let [displaySamples, setDisplaySamples] = useState<Record<string, boolean>>({});
10411047
let [tableFilter, setTableFilter] = useState<string>("");
1048+
const [selectedTables, setSelectedTables] = useState<Set<string>>(new Set());
10421049

10431050
const [displayAuthInstructions, setDisplayAuthInstructions] = useState(false);
10441051

@@ -1086,7 +1093,12 @@ export const DataLoaderForm: React.FC<{
10861093
return [
10871094
<TableRow
10881095
key={tableName}
1089-
sx={{ '&:last-child td, &:last-child th': { border: 0 }, '& .MuiTableCell-root': { padding: 0.25, wordWrap: 'break-word', whiteSpace: 'normal' }}}
1096+
sx={{
1097+
'&:last-child td, &:last-child th': { border: 0 },
1098+
'& .MuiTableCell-root': { padding: 0.25, wordWrap: 'break-word', whiteSpace: 'normal' },
1099+
backgroundColor: selectedTables.has(tableName) ? 'action.selected' : 'inherit',
1100+
'&:hover': { backgroundColor: selectedTables.has(tableName) ? 'action.selected' : 'action.hover' }
1101+
}}
10901102
>
10911103
<TableCell sx={{borderBottom: displaySamples[tableName] ? 'none' : '1px solid rgba(0, 0, 0, 0.1)'}}>
10921104
<IconButton size="small" onClick={() => toggleDisplaySamples(tableName)}>
@@ -1103,33 +1115,20 @@ export const DataLoaderForm: React.FC<{
11031115
<Chip key={column.name} label={column.name} sx={{fontSize: 11, margin: 0.25, height: 20}} size="small" />
11041116
))}
11051117
</TableCell>
1106-
<TableCell sx={{width: 60}}>
1107-
<Button size="small" onClick={() => {
1108-
onImport();
1109-
fetch(getUrls().DATA_LOADER_INGEST_DATA, {
1110-
method: 'POST',
1111-
headers: {
1112-
'Content-Type': 'application/json',
1113-
},
1114-
body: JSON.stringify({
1115-
data_loader_type: dataLoaderType,
1116-
data_loader_params: params, table_name: tableName
1117-
})
1118-
})
1119-
.then(response => response.json())
1120-
.then(data => {
1121-
1122-
if (data.status === "success") {
1123-
onFinish("success", "Data ingested successfully");
1118+
<TableCell sx={{width: 40}} padding="checkbox">
1119+
<Checkbox
1120+
size="small"
1121+
checked={selectedTables.has(tableName)}
1122+
onChange={(e) => {
1123+
const newSelected = new Set(selectedTables);
1124+
if (e.target.checked) {
1125+
newSelected.add(tableName);
11241126
} else {
1125-
onFinish("error", data.error);
1127+
newSelected.delete(tableName);
11261128
}
1127-
})
1128-
.catch(error => {
1129-
console.error('Failed to ingest data:', error);
1130-
onFinish("error", `Failed to ingest data: ${error}`);
1131-
});
1132-
}}>Import</Button>
1129+
setSelectedTables(newSelected);
1130+
}}
1131+
/>
11331132
</TableCell>
11341133
</TableRow>,
11351134
<TableRow key={`${tableName}-sample`}>
@@ -1155,6 +1154,49 @@ export const DataLoaderForm: React.FC<{
11551154
</TableBody>
11561155
</Table>
11571156
</TableContainer>,
1157+
mode === "view tables" && Object.keys(tableMetadata).length > 0 && <Box sx={{ display: 'flex', justifyContent: 'flex-end', mt: 1 }}>
1158+
<Button
1159+
variant="contained"
1160+
size="small"
1161+
disabled={selectedTables.size === 0}
1162+
onClick={() => {
1163+
const tablesToImport = Array.from(selectedTables);
1164+
onImport();
1165+
1166+
// Import all selected tables concurrently (one request per table, awaited together via Promise.all)
1167+
const importPromises = tablesToImport.map(tableName =>
1168+
fetch(getUrls().DATA_LOADER_INGEST_DATA, {
1169+
method: 'POST',
1170+
headers: {
1171+
'Content-Type': 'application/json',
1172+
},
1173+
body: JSON.stringify({
1174+
data_loader_type: dataLoaderType,
1175+
data_loader_params: params,
1176+
table_name: tableName
1177+
})
1178+
}).then(response => response.json())
1179+
);
1180+
1181+
Promise.all(importPromises)
1182+
.then(results => {
1183+
const errors = results.filter(r => r.status !== "success");
1184+
if (errors.length === 0) {
1185+
setSelectedTables(new Set());
1186+
onFinish("success", `Successfully imported ${tablesToImport.length} table(s)`, tablesToImport);
1187+
} else {
1188+
onFinish("error", `Failed to import some tables: ${errors.map(e => e.error).join(", ")}`);
1189+
}
1190+
})
1191+
.catch(error => {
1192+
console.error('Failed to ingest data:', error);
1193+
onFinish("error", `Failed to ingest data: ${error}`);
1194+
});
1195+
}}
1196+
>
1197+
Import Selected ({selectedTables.size})
1198+
</Button>
1199+
</Box>,
11581200
mode === "query" && <DataQueryForm
11591201
dataLoaderType={dataLoaderType}
11601202
availableTables={Object.keys(tableMetadata).map(t => ({name: t, fields: tableMetadata[t].columns.map((c: any) => c.name)}))}

0 commit comments

Comments
 (0)