
Commit 32fe565

fix: configured refdata loader to be instantiated when required without need for class vars (#98)
* refactor: configured refdata loader to be instantiated when required without need for class vars
* style: address formatting, linting and type checking issues
* style: address review comments and linting issues
1 parent 61c0523 commit 32fe565

18 files changed

Lines changed: 186 additions & 226 deletions


src/dve/core_engine/backends/base/backend.py

Lines changed: 4 additions & 32 deletions
```diff
@@ -3,7 +3,7 @@
 import logging
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Mapping, MutableMapping
+from collections.abc import MutableMapping
 from typing import Any, ClassVar, Generic, Optional
 
 from pyspark.sql import DataFrame, SparkSession
@@ -41,14 +41,12 @@ def __init__( # pylint: disable=unused-argument
         self,
         contract: BaseDataContract[EntityType],
         steps: BaseStepImplementations[EntityType],
-        reference_data_loader_type: Optional[type[BaseRefDataLoader[EntityType]]],
         logger: Optional[logging.Logger] = None,
         **kwargs: Any,
     ) -> None:
         for component_name, component in (
             ("Contract", contract),
             ("Step implementation", steps),
-            ("Reference data loader", reference_data_loader_type),
         ):
             component_entity_type = getattr(component, "__entity_type__", None)
             if component_entity_type != self.__entity_type__:
@@ -61,42 +59,16 @@ def __init__( # pylint: disable=unused-argument
         """The data contract implementation used by the backend."""
         self.step_implementations = steps
         """The step implementations used by the backend."""
-        self.reference_data_loader_type = reference_data_loader_type
-        """
-        The loader type to use for the reference data. If `None`, do not
-        load any reference data and error if it is provided.
-
-        """
         self.logger = logger or get_logger(type(self).__name__)
         """The `logging.Logger instance for the backend."""
 
     def load_reference_data(
         self,
         reference_entity_config: dict[EntityName, ReferenceConfigUnion],
         submission_info: Optional[SubmissionInfo],
-    ) -> Mapping[EntityName, EntityType]:
-        """Load the reference data as specified in the reference entity config."""
-        sub_info_entity: Optional[EntityType] = None
-        if submission_info:
-            sub_info_entity = self.convert_submission_info(submission_info)
-
-        if self.reference_data_loader_type is None:
-            if reference_entity_config:
-                raise ValueError(
-                    "Reference data has been specified but no reference data loader is "
-                    + "configured for this backend"
-                )
-
-            reference_data_dict = {}
-            if sub_info_entity is not None:
-                reference_data_dict["dve_submission_info"] = sub_info_entity
-            return reference_data_dict
-
-        reference_data_loader = self.reference_data_loader_type(reference_entity_config)
-        if sub_info_entity is not None:
-            reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity
-
-        return reference_data_loader
+    ) -> BaseRefDataLoader[EntityType]:
+        """Supply configured reference data loader for use with business rules"""
+        raise NotImplementedError()
 
     @abstractmethod
     def convert_submission_info(self, submission_info: SubmissionInfo) -> EntityType:
```

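The base class no longer builds a loader itself: `load_reference_data` is now a hook that each concrete backend overrides. A minimal sketch of such an override, assuming a hypothetical `MyBackend`/`MyRefDataLoader` pair (illustrative names, not part of this commit):

```python
# A minimal sketch, assuming a hypothetical backend; MyBackend and
# MyRefDataLoader are illustrative names, not classes in this commit.
from typing import Optional

class MyBackend(BaseBackend[DataFrame]):
    def load_reference_data(
        self,
        reference_entity_config: dict[EntityName, ReferenceConfigUnion],
        submission_info: Optional[SubmissionInfo],
    ) -> BaseRefDataLoader[DataFrame]:
        # Build the loader per call instead of storing a loader *type* on the
        # backend (and mutating its class variables) at construction time.
        loader = MyRefDataLoader(
            reference_data_config=reference_entity_config,
            dataset_config_uri=self.dataset_config_uri,  # assumed instance attr
        )
        if submission_info is not None:
            loader.entity_cache["dve_submission_info"] = self.convert_submission_info(
                submission_info
            )
        return loader
```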
src/dve/core_engine/backends/base/reference_data.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -11,6 +11,7 @@
 from dve.core_engine.backends.base.core import get_entity_type
 from dve.core_engine.backends.exceptions import (
     MissingRefDataEntity,
+    NoRefDataConfigSupplied,
     RefdataLacksFileExtensionSupport,
 )
 from dve.core_engine.backends.types import EntityType
@@ -147,11 +148,11 @@ class variable for the subclass.
     # pylint: disable=unused-argument
     def __init__(
         self,
-        reference_entity_config: dict[EntityName, ReferenceConfig],
-        dataset_config_uri: Optional[URI] = None,
+        reference_data_config: dict[EntityName, ReferenceConfig],
+        dataset_config_uri: URI,
         **kwargs,
     ) -> None:
-        self.reference_entity_config = reference_entity_config
+        self.reference_entity_config = reference_data_config
         self.dataset_config_uri = dataset_config_uri
         """
         Configuration options for the reference data. This is likely to vary
@@ -207,6 +208,8 @@ def __getitem__(self, key: EntityName) -> EntityType:
         try:
             config = self.reference_entity_config[key]
             return self.load_entity(entity_name=key, config=config)
+        except TypeError as err:
+            raise NoRefDataConfigSupplied() from err
         except Exception as err:
             raise MissingRefDataEntity(entity_name=key) from err
```

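The new `except TypeError` branch targets the case where a loader was built without any refdata config: subscripting `None` raises `TypeError`, which previously fell through to the generic handler and was misreported as a missing entity. A standalone sketch of the failure mode (the entity name is illustrative):

```python
# Sketch of the failure mode the new branch handles (entity name illustrative).
reference_entity_config = None  # loader constructed with no refdata config

try:
    config = reference_entity_config["org_codes"]  # TypeError: 'NoneType' object is not subscriptable
except TypeError as err:
    # Previously this fell into `except Exception` and surfaced as
    # MissingRefDataEntity; now it is reported accurately.
    raise NoRefDataConfigSupplied() from err
```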
src/dve/core_engine/backends/exceptions.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -119,6 +119,20 @@ def get_message_preamble(self) -> str:
         return f"Missing reference data entity {self.entity_name!r}"
 
 
+class NoRefDataConfigSupplied(BackendError):
+    """An error raised when trying to load a refdata entity when no refdata
+    config has been supplied.
+
+    """
+
+    def __init__(self, *args: object) -> None:
+        super().__init__(*args)
+
+    def get_message_preamble(self) -> EntityName:
+        """Message for logging purposes"""
+        return "Refdata loader not supplied with refdata config - unable to load refdata entities"
+
+
 class ConstraintError(ValueError, BackendErrorMixin):
     """Raised when a given constraint is violated."""
```

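For callers, the new exception makes two failure modes distinguishable: no refdata config supplied at all versus a config that lacks a particular entity. A hypothetical handler (`loader` and the entity name are illustrative):

```python
# Hypothetical caller-side handling; `loader` and the entity name are assumed.
import logging

logger = logging.getLogger(__name__)

try:
    org_codes = loader["org_codes"]
except NoRefDataConfigSupplied:
    # The loader was never given a refdata config.
    logger.error("No reference data config supplied to the loader")
    raise
except MissingRefDataEntity:
    # A config exists, but this entity is not defined in it.
    logger.error("Entity 'org_codes' missing from the reference data config")
    raise
```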
src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -411,7 +411,7 @@ def get_duckdb_cast_statement_from_annotation(
         stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END"  # pylint: disable=C0301
         return stmt
     if issubclass(type_, time):
-        stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END"  # pylint: disable=C0301
+        stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END"  # pylint: disable=C0301
         return stmt
     duck_type = get_duckdb_type_from_annotation(type_)
     if duck_type:
```

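The change above appears to be whitespace-only (the before and after read identically in plain text), likely from the commit's formatting pass. For context, the pattern the helper emits guards `TRY_CAST` behind `REGEXP_MATCHES` so malformed values become NULL instead of cast errors. A self-contained illustration with an assumed regex and sample values (not the module's own):

```python
# Standalone demo of the regex-guarded TRY_CAST pattern; the regex and
# sample values are illustrative, not taken from duckdb_helpers.
import duckdb

time_regex = r"\d{2}:\d{2}:\d{2}"
query = f"""
    SELECT CASE
        WHEN REGEXP_MATCHES(TRIM(raw), '{time_regex}')
        THEN TRY_CAST(TRIM(raw) AS TIME)
        ELSE NULL
    END AS parsed
    FROM (VALUES (' 12:34:56 '), ('not a time')) AS t(raw)
"""
print(duckdb.sql(query).fetchall())  # [(datetime.time(12, 34, 56),), (None,)]
```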
src/dve/core_engine/backends/implementations/duckdb/reference_data.py

Lines changed: 7 additions & 10 deletions
```diff
@@ -1,13 +1,11 @@
 """A reference data loader for duckdb."""
 
-from typing import Optional
-
 from duckdb import DuckDBPyConnection, DuckDBPyRelation
 from pyarrow import ipc  # type: ignore
 
 from dve.core_engine.backends.base.reference_data import (
     BaseRefDataLoader,
-    ReferenceConfigUnion,
+    ReferenceConfig,
     ReferenceTable,
     mark_refdata_file_extension,
 )
@@ -19,17 +17,16 @@
 class DuckDBRefDataLoader(BaseRefDataLoader[DuckDBPyRelation]):
     """A reference data loader using already existing DuckDB tables."""
 
-    connection: DuckDBPyConnection
-    """The DuckDB connection for the backend."""
-    dataset_config_uri: Optional[URI] = None
-    """The location of the dischema file"""
-
     def __init__(
         self,
-        reference_entity_config: dict[EntityName, ReferenceConfigUnion],
+        connection: DuckDBPyConnection,
+        reference_data_config: dict[EntityName, ReferenceConfig],
+        dataset_config_uri: URI,
         **kwargs,
     ) -> None:
-        super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs)
+        super().__init__(reference_data_config, dataset_config_uri, **kwargs)
+
+        self.connection = connection
 
         if not self.connection:
             raise AttributeError("DuckDBConnection must be specified")
```

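All of the loader's dependencies now arrive through `__init__` rather than as class attributes. A construction sketch (the config value and URI are illustrative):

```python
# Construction sketch; `org_codes_config` stands in for a ReferenceConfig
# value and the URI is illustrative.
import duckdb

loader = DuckDBRefDataLoader(
    connection=duckdb.connect(),
    reference_data_config={"org_codes": org_codes_config},
    dataset_config_uri="file:///configs/dataset",
)
relation = loader["org_codes"]  # a DuckDBPyRelation, loaded on demand
```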
src/dve/core_engine/backends/implementations/spark/backend.py

Lines changed: 25 additions & 7 deletions
```diff
@@ -6,6 +6,7 @@
 from pyspark.sql import DataFrame, SparkSession
 
 from dve.core_engine.backends.base.backend import BaseBackend
+from dve.core_engine.backends.base.reference_data import ReferenceConfigUnion
 from dve.core_engine.backends.implementations.spark.contract import SparkDataContract
 from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader
 from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations
@@ -14,7 +15,7 @@
 from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
 from dve.core_engine.loggers import get_child_logger, get_logger
 from dve.core_engine.models import SubmissionInfo
-from dve.core_engine.type_hints import URI, EntityParquetLocations
+from dve.core_engine.type_hints import URI, EntityName, EntityParquetLocations
 from dve.parser.file_handling import get_resource_exists, joinuri
 
 
@@ -26,7 +27,6 @@ def __init__(
         dataset_config_uri: Optional[URI] = None,
         contract: Optional[SparkDataContract] = None,
         steps: Optional[SparkStepImplementations] = None,
-        reference_data_loader: Optional[type[SparkRefDataLoader]] = None,
         logger: Optional[logging.Logger] = None,
         spark_session: Optional[SparkSession] = None,
         **kwargs: Any,
@@ -36,6 +36,8 @@
 
         self.spark_session = spark_session or SparkSession.builder.getOrCreate()
         """The Spark session for the backend."""
+        self.dataset_config_uri = dataset_config_uri
+        """The uri of the dischema specifying the DVE config"""
 
         if contract is None:
             contract = SparkDataContract(
@@ -46,11 +48,27 @@
             steps = SparkStepImplementations.register_udfs(
                 logger=get_child_logger("SparkStepImplementations", logger)
             )
-        if reference_data_loader is None:
-            reference_data_loader = SparkRefDataLoader
-        reference_data_loader.spark = self.spark_session
-        reference_data_loader.dataset_config_uri = dataset_config_uri
-        super().__init__(contract, steps, reference_data_loader, logger, **kwargs)
+        super().__init__(contract, steps, logger, **kwargs)
+
+    def load_reference_data(
+        self,
+        reference_entity_config: dict[EntityName, ReferenceConfigUnion],
+        submission_info: Optional[SubmissionInfo],
+    ):
+        """Load the reference data as specified in the reference entity config."""
+        sub_info_entity: Optional[DataFrame] = None
+        if submission_info:
+            sub_info_entity = self.convert_submission_info(submission_info)
+
+        reference_data_loader = SparkRefDataLoader(
+            spark=self.spark_session,
+            reference_data_config=reference_entity_config,
+            dataset_config_uri=self.dataset_config_uri,  # type: ignore
+        )
+        if sub_info_entity is not None:
+            reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity
+
+        return reference_data_loader
 
     def write_entities_to_parquet(
         self, entities: SparkEntities, cache_prefix: URI
```

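The Spark backend now returns a fresh, fully configured loader per call instead of mutating `SparkRefDataLoader` class attributes. A usage sketch (the `backend` instance, config dict, and `submission_info` are assumed):

```python
# Illustrative call pattern; `backend`, `org_codes_config` and
# `submission_info` are assumptions, not defined in this diff.
loader = backend.load_reference_data(
    reference_entity_config={"org_codes": org_codes_config},
    submission_info=submission_info,
)
org_codes_df = loader["org_codes"]                        # loaded via the config
sub_info_df = loader.entity_cache["dve_submission_info"]  # pre-seeded entity
```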
src/dve/core_engine/backends/implementations/spark/reference_data.py

Lines changed: 5 additions & 9 deletions
```diff
@@ -1,8 +1,6 @@
 # pylint: disable=no-member
 """A reference data loader for Spark."""
 
-from typing import Optional
-
 from pyspark.sql import DataFrame, SparkSession
 
 from dve.core_engine.backends.base.reference_data import (
@@ -19,17 +17,15 @@
 class SparkRefDataLoader(BaseRefDataLoader[DataFrame]):
     """A reference data loader using already existing Apache Spark Tables."""
 
-    spark: SparkSession
-    """The Spark session for the backend."""
-    dataset_config_uri: Optional[URI] = None
-    """The location of the dischema file defining business rules"""
-
     def __init__(
         self,
-        reference_entity_config: dict[EntityName, ReferenceConfig],
+        spark: SparkSession,
+        reference_data_config: dict[EntityName, ReferenceConfig],
+        dataset_config_uri: URI,
         **kwargs,
     ) -> None:
-        super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs)
+        super().__init__(reference_data_config, dataset_config_uri, **kwargs)
+        self.spark = spark
         if not self.spark:
             raise AttributeError("Spark session must be provided")
```

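Because `spark` and `dataset_config_uri` were previously class variables, configuring one loader leaked into every other instance of the class; as instance state, two loaders can now coexist with different configs. Sketch (configs and URIs illustrative):

```python
# cfg_a / cfg_b stand in for dict[EntityName, ReferenceConfig] values;
# the URIs are illustrative.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
loader_a = SparkRefDataLoader(
    spark=spark, reference_data_config=cfg_a,
    dataset_config_uri="s3://bucket/dataset-a",
)
loader_b = SparkRefDataLoader(
    spark=spark, reference_data_config=cfg_b,
    dataset_config_uri="s3://bucket/dataset-b",
)
# Each instance keeps its own config; with the old class-level
# dataset_config_uri, loader_a and loader_b could not differ.
```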
src/dve/pipeline/duckdb_pipeline.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -5,10 +5,12 @@
 
 from duckdb import DuckDBPyConnection, DuckDBPyRelation
 
-from dve.core_engine.backends.base.reference_data import BaseRefDataLoader
+import dve.parser.file_handling as fh
+from dve.core_engine.backends.base.reference_data import ReferenceConfig
 from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager
 from dve.core_engine.backends.implementations.duckdb.contract import DuckDBDataContract
 from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_get_entity_count
+from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader
 from dve.core_engine.backends.implementations.duckdb.rules import DuckDBStepImplementations
 from dve.core_engine.models import SubmissionInfo
 from dve.core_engine.type_hints import URI
@@ -30,7 +32,6 @@ def __init__(
         connection: DuckDBPyConnection,
         rules_path: Optional[URI],
         submitted_files_path: Optional[URI],
-        reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
         job_run_id: Optional[int] = None,
         logger: Optional[logging.Logger] = None,
     ):
@@ -42,11 +43,20 @@
             DuckDBStepImplementations.register_udfs(connection=self._connection),
             rules_path,
             submitted_files_path,
-            reference_data_loader,
             job_run_id,
             logger,
         )
 
+    def init_reference_data_loader(
+        self, reference_data_config: dict[str, ReferenceConfig], **kwargs
+    ) -> DuckDBRefDataLoader:
+        return DuckDBRefDataLoader(
+            connection=self._connection,
+            reference_data_config=reference_data_config,
+            dataset_config_uri=fh.get_parent(self._rules_path),  # type: ignore
+            **kwargs
+        )
+
     # pylint: disable=arguments-differ
     def write_file_to_parquet(  # type: ignore
         self, submission_file_uri: URI, submission_info: SubmissionInfo, output: URI
```

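The pipeline now exposes a factory hook and derives the loader's `dataset_config_uri` from the parent of the rules path via `fh.get_parent`. A hypothetical usage sketch (the pipeline class name, paths, and config value are assumptions, not shown in this diff):

```python
# Hypothetical usage; the pipeline class name, paths and org_codes_config
# are illustrative assumptions.
import duckdb

pipeline = DuckDBPipeline(
    connection=duckdb.connect(),
    rules_path="file:///configs/dataset/rules.dischema.json",
    submitted_files_path="file:///incoming/",
)
# Under these assumptions, the loader's dataset_config_uri resolves to
# fh.get_parent(rules_path), i.e. "file:///configs/dataset".
loader = pipeline.init_reference_data_loader(
    reference_data_config={"org_codes": org_codes_config}
)
```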