Skip to content

Commit 56e217f

Browse files
committed
fix: ddb xml reader connection args consistent with other ddb readers
1 parent 00b66a6 commit 56e217f

3 files changed

Lines changed: 11 additions & 10 deletions

File tree

src/dve/core_engine/backends/implementations/duckdb/readers/xml.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from typing import Optional
55

6+
import duckdb
67
import polars as pl
78
from duckdb import DuckDBPyConnection, DuckDBPyRelation, default_connection
89
from pydantic import BaseModel
@@ -24,8 +25,8 @@
2425
class DuckDBXMLStreamReader(XMLStreamReader):
2526
"""A reader for XML files"""
2627

27-
def __init__(self, *, ddb_connection: Optional[DuckDBPyConnection] = None, **kwargs):
28-
self.ddb_connection = ddb_connection if ddb_connection else default_connection
28+
def __init__(self, *, connection: Optional[DuckDBPyConnection] = None, **kwargs):
29+
self._connection = connection if connection else duckdb.connect(":memory:")
2930
super().__init__(**kwargs)
3031

3132
@read_function(DuckDBPyRelation)
@@ -49,4 +50,4 @@ def read_to_relation(self, resource: URI, entity_name: str, schema: type[BaseMod
4950
data=self.read_to_py_iterator(resource, entity_name, schema), schema=polars_schema
5051
)
5152
)
52-
return self.ddb_connection.sql("select * from _lazy_frame")
53+
return self._connection.sql("select * from _lazy_frame")

tests/test_core_engine/test_backends/test_readers/test_ddb_xml.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def test_ddb_xml_reader_all_str(temp_xml_file):
6767
uri, header_model, header_data, class_data_model, class_data = temp_xml_file
6868
ddb_conn = default_connection
6969
header_reader = DuckDBXMLStreamReader(
70-
ddb_connection=ddb_conn, root_tag="root", record_tag="Header"
70+
connection=ddb_conn, root_tag="root", record_tag="Header"
7171
)
7272
class_reader = DuckDBXMLStreamReader(
73-
ddb_connection=ddb_conn, root_tag="root", record_tag="ClassData"
73+
connection=ddb_conn, root_tag="root", record_tag="ClassData"
7474
)
7575
header_rel: DuckDBPyRelation = header_reader.read_to_relation(
7676
uri.as_uri(), "header", header_model
@@ -90,10 +90,10 @@ def test_ddb_xml_reader_write_parquet(temp_xml_file):
9090
uri, header_model, header_data, class_data_model, class_data = temp_xml_file
9191
ddb_conn = default_connection
9292
header_reader = DuckDBXMLStreamReader(
93-
ddb_connection=ddb_conn, root_tag="root", record_tag="Header"
93+
connection=ddb_conn, root_tag="root", record_tag="Header"
9494
)
9595
class_reader = DuckDBXMLStreamReader(
96-
ddb_connection=ddb_conn, root_tag="root", record_tag="ClassData"
96+
connection=ddb_conn, root_tag="root", record_tag="ClassData"
9797
)
9898
header_rel: DuckDBPyRelation = header_reader.read_to_relation(
9999
uri.as_uri(), "header", header_model
@@ -105,10 +105,10 @@ def test_ddb_xml_reader_write_parquet(temp_xml_file):
105105
target_class_loc: Path = uri.parent.joinpath("class_parquet.parquet").as_posix()
106106
header_reader.write_parquet(entity=header_rel, target_location=target_header_loc)
107107
class_reader.write_parquet(entity=class_rel, target_location=target_class_loc)
108-
header_parquet_rel: DuckDBPyRelation = header_reader.ddb_connection.read_parquet(
108+
header_parquet_rel: DuckDBPyRelation = header_reader._connection.read_parquet(
109109
target_header_loc
110110
)
111-
class_parquet_rel: DuckDBPyRelation = class_reader.ddb_connection.read_parquet(target_class_loc)
111+
class_parquet_rel: DuckDBPyRelation = class_reader._connection.read_parquet(target_class_loc)
112112
assert header_parquet_rel.df().to_dict(orient="records") == header_rel.df().to_dict(
113113
orient="records"
114114
)

tests/test_pipeline/test_foundry_ddb_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
@pytest.fixture(scope="function")
3232
def prep_multithreading_test():
3333
sub_details: dict[str, tuple[DuckDBPyConnection, str, DDBAuditingManager]] = {}
34-
for idx in range(1, 4):
34+
for idx in range(1, 10):
3535
db = f"dve_{uuid4().hex}"
3636
tmp_dir = tempfile.mkdtemp(prefix="ddb_foundry_testing")
3737
db_file = Path(tmp_dir, db + ".duckdb")

0 commit comments

Comments
 (0)