Skip to content

Commit 00b66a6

Browse files
committed
refactor: add in backend kwargs for readers to allow reader args not determinable at config write time to be passed
1 parent 8fa895e commit 00b66a6

7 files changed

Lines changed: 110 additions & 102 deletions

File tree

poetry.lock

Lines changed: 81 additions & 89 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/dve/core_engine/backends/implementations/duckdb/readers/csv.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,7 @@
66

77
import duckdb as ddb
88
import polars as pl
9-
from duckdb import (
10-
DuckDBPyConnection,
11-
DuckDBPyRelation,
12-
StarExpression,
13-
read_csv,
14-
)
9+
from duckdb import DuckDBPyConnection, DuckDBPyRelation, StarExpression, read_csv
1510
from pydantic import BaseModel
1611

1712
from dve.core_engine.backends.base.reader import BaseFileReader, read_function

src/dve/core_engine/backends/implementations/duckdb/readers/json.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,7 @@ def read_to_relation( # pylint: disable=unused-argument
5353
}
5454

5555
return self.add_record_index(
56-
self._connection.read_json(resource, columns=ddb_schema, format=self._json_format) # type: ignore
56+
self._connection.read_json(
57+
resource, columns=ddb_schema, format=self._json_format # type: ignore
58+
)
5759
)

src/dve/pipeline/duckdb_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
submitted_files_path,
4646
job_run_id,
4747
logger,
48+
{"connection": self._connection},
4849
)
4950

5051
def init_reference_data_loader(

src/dve/pipeline/pipeline.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from functools import lru_cache
1010
from itertools import starmap
1111
from threading import Lock
12-
from typing import Optional, Union
12+
from typing import Any, Optional, Union
1313
from uuid import uuid4
1414

1515
import polars as pl
@@ -49,6 +49,7 @@
4949
)
5050

5151

52+
# pylint: disable=R0904
5253
class BaseDVEPipeline:
5354
"""
5455
Base class for running a DVE Pipeline either by a given step or a full e2e process.
@@ -64,6 +65,7 @@ def __init__(
6465
submitted_files_path: Optional[URI],
6566
job_run_id: Optional[int] = None,
6667
logger: Optional[logging.Logger] = None,
68+
backend_reader_kwargs: Optional[dict[str, Any]] = None,
6769
):
6870
self._submitted_files_path = submitted_files_path
6971
self._processed_files_path = processed_files_path
@@ -76,6 +78,7 @@ def __init__(
7678
self._summary_lock = Lock()
7779
self._rec_tracking_lock = Lock()
7880
self._aggregates_lock = Lock()
81+
self._backend_reader_kwargs = backend_reader_kwargs
7982

8083
if self._data_contract:
8184
self._data_contract.logger = self._logger
@@ -107,6 +110,12 @@ def step_implementations(self) -> Optional[BaseStepImplementations[EntityType]]:
107110
"""The step implementations to apply the business rules to a given dataset"""
108111
return self._step_implementations
109112

113+
@property
114+
def backend_reader_kwargs(self) -> dict[str, Any] | None:
115+
"""Important required arguments for all readers related to the specific backend
116+
that can't be specified at time of writing config eg. duckdb connection"""
117+
return self._backend_reader_kwargs
118+
110119
@staticmethod
111120
def get_entity_count(entity: EntityType) -> int:
112121
"""Get a row count of an entity stored as parquet"""
@@ -203,7 +212,9 @@ def write_file_to_parquet(
203212

204213
for model_name, model in models.items():
205214
self._logger.info(f"Transforming {model_name} to stringified parquet")
206-
reader: BaseFileReader = load_reader(dataset, model_name, ext)
215+
reader: BaseFileReader = load_reader(
216+
dataset, model_name, ext, self.backend_reader_kwargs
217+
)
207218
try:
208219
if not entity_type:
209220
reader.write_parquet(

src/dve/pipeline/utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import json
55
from threading import Lock
6-
from typing import Optional
6+
from typing import Any, Optional
77

88
from pydantic.main import ModelMetaclass
99
from pyspark.sql import SparkSession
@@ -45,10 +45,17 @@ def load_config(
4545
return models, config, dataset
4646

4747

48-
def load_reader(dataset: Dataset, model_name: str, file_extension: str):
48+
def load_reader(
49+
dataset: Dataset,
50+
model_name: str,
51+
file_extension: str,
52+
backend_reader_kwargs: Optional[dict[str, Any]] = None,
53+
):
4954
"""Loads the readers for the given feed, model name and file extension"""
5055
reader_config = dataset[model_name].reader_config[f".{file_extension.lower()}"]
51-
reader = _READER_REGISTRY[reader_config.reader](**reader_config.kwargs_)
56+
reader = _READER_REGISTRY[reader_config.reader](
57+
**reader_config.kwargs_, **backend_reader_kwargs if backend_reader_kwargs else {}
58+
)
5259
return reader
5360

5461

tests/test_pipeline/test_foundry_ddb_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
@pytest.fixture(scope="function")
3232
def prep_multithreading_test():
3333
sub_details: dict[str, tuple[DuckDBPyConnection, str, DDBAuditingManager]] = {}
34-
for idx in range(1, 10):
34+
for idx in range(1, 4):
3535
db = f"dve_{uuid4().hex}"
3636
tmp_dir = tempfile.mkdtemp(prefix="ddb_foundry_testing")
3737
db_file = Path(tmp_dir, db + ".duckdb")

0 commit comments

Comments
 (0)