Skip to content

Commit adb23de

Browse files
committed
refactor: configured refdata loader to be instantiated when required without need for class vars
1 parent 61c0523 commit adb23de

17 files changed

Lines changed: 180 additions & 218 deletions

File tree

src/dve/core_engine/backends/base/backend.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,12 @@ def __init__( # pylint: disable=unused-argument
4141
self,
4242
contract: BaseDataContract[EntityType],
4343
steps: BaseStepImplementations[EntityType],
44-
reference_data_loader_type: Optional[type[BaseRefDataLoader[EntityType]]],
4544
logger: Optional[logging.Logger] = None,
4645
**kwargs: Any,
4746
) -> None:
4847
for component_name, component in (
4948
("Contract", contract),
5049
("Step implementation", steps),
51-
("Reference data loader", reference_data_loader_type),
5250
):
5351
component_entity_type = getattr(component, "__entity_type__", None)
5452
if component_entity_type != self.__entity_type__:
@@ -61,42 +59,15 @@ def __init__( # pylint: disable=unused-argument
6159
"""The data contract implementation used by the backend."""
6260
self.step_implementations = steps
6361
"""The step implementations used by the backend."""
64-
self.reference_data_loader_type = reference_data_loader_type
65-
"""
66-
The loader type to use for the reference data. If `None`, do not
67-
load any reference data and error if it is provided.
68-
69-
"""
7062
self.logger = logger or get_logger(type(self).__name__)
7163
"""The `logging.Logger instance for the backend."""
7264

7365
def load_reference_data(
7466
self,
7567
reference_entity_config: dict[EntityName, ReferenceConfigUnion],
7668
submission_info: Optional[SubmissionInfo],
77-
) -> Mapping[EntityName, EntityType]:
78-
"""Load the reference data as specified in the reference entity config."""
79-
sub_info_entity: Optional[EntityType] = None
80-
if submission_info:
81-
sub_info_entity = self.convert_submission_info(submission_info)
82-
83-
if self.reference_data_loader_type is None:
84-
if reference_entity_config:
85-
raise ValueError(
86-
"Reference data has been specified but no reference data loader is "
87-
+ "configured for this backend"
88-
)
89-
90-
reference_data_dict = {}
91-
if sub_info_entity is not None:
92-
reference_data_dict["dve_submission_info"] = sub_info_entity
93-
return reference_data_dict
94-
95-
reference_data_loader = self.reference_data_loader_type(reference_entity_config)
96-
if sub_info_entity is not None:
97-
reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity
98-
99-
return reference_data_loader
69+
) -> BaseRefDataLoader[EntityType]:
70+
raise NotImplementedError()
10071

10172
@abstractmethod
10273
def convert_submission_info(self, submission_info: SubmissionInfo) -> EntityType:

src/dve/core_engine/backends/base/reference_data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from dve.core_engine.backends.base.core import get_entity_type
1212
from dve.core_engine.backends.exceptions import (
1313
MissingRefDataEntity,
14+
NoRefDataConfigSupplied,
1415
RefdataLacksFileExtensionSupport,
1516
)
1617
from dve.core_engine.backends.types import EntityType
@@ -147,11 +148,11 @@ class variable for the subclass.
147148
# pylint: disable=unused-argument
148149
def __init__(
149150
self,
150-
reference_entity_config: dict[EntityName, ReferenceConfig],
151-
dataset_config_uri: Optional[URI] = None,
151+
reference_data_config: dict[EntityName, ReferenceConfig],
152+
dataset_config_uri: URI,
152153
**kwargs,
153154
) -> None:
154-
self.reference_entity_config = reference_entity_config
155+
self.reference_entity_config = reference_data_config
155156
self.dataset_config_uri = dataset_config_uri
156157
"""
157158
Configuration options for the reference data. This is likely to vary
@@ -207,6 +208,8 @@ def __getitem__(self, key: EntityName) -> EntityType:
207208
try:
208209
config = self.reference_entity_config[key]
209210
return self.load_entity(entity_name=key, config=config)
211+
except TypeError:
212+
raise NoRefDataConfigSupplied()
210213
except Exception as err:
211214
raise MissingRefDataEntity(entity_name=key) from err
212215

src/dve/core_engine/backends/exceptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def get_message_preamble(self) -> str:
118118
"""
119119
return f"Missing reference data entity {self.entity_name!r}"
120120

121+
class NoRefDataConfigSupplied(BackendError):
122+
"""An error raised when trying to load a refdata entity when no refdata
123+
config has been supplied.
124+
125+
"""
126+
127+
def __init__(self, *args: object) -> None:
128+
super().__init__(*args)
129+
130+
def get_message_preamble(self) -> EntityName:
131+
"""Message for logging purposes"""
132+
return f"Refdata loader not supplied with refdata config - unable to load refdata entities"
133+
121134

122135
class ConstraintError(ValueError, BackendErrorMixin):
123136
"""Raised when a given constraint is violated."""

src/dve/core_engine/backends/implementations/duckdb/reference_data.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from dve.core_engine.backends.base.reference_data import (
99
BaseRefDataLoader,
10+
ReferenceConfig,
1011
ReferenceConfigUnion,
1112
ReferenceTable,
1213
mark_refdata_file_extension,
@@ -17,19 +18,20 @@
1718

1819
# pylint: disable=too-few-public-methods
1920
class DuckDBRefDataLoader(BaseRefDataLoader[DuckDBPyRelation]):
20-
"""A reference data loader using already existing DuckDB tables."""
21-
22-
connection: DuckDBPyConnection
23-
"""The DuckDB connection for the backend."""
24-
dataset_config_uri: Optional[URI] = None
25-
"""The location of the dischema file"""
21+
"""A reference data loader using already existing DuckDB tables.
22+
reference_entity_config and dataset_config_uri (if config uses relative paths)
23+
should be supplied using setter methods for the dataset being processed before running."""
2624

2725
def __init__(
2826
self,
29-
reference_entity_config: dict[EntityName, ReferenceConfigUnion],
30-
**kwargs,
27+
connection: DuckDBPyConnection,
28+
reference_data_config: dict[EntityName, ReferenceConfig],
29+
dataset_config_uri: URI,
30+
**kwargs
3131
) -> None:
32-
super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs)
32+
super().__init__(reference_data_config, dataset_config_uri,**kwargs)
33+
34+
self.connection = connection
3335

3436
if not self.connection:
3537
raise AttributeError("DuckDBConnection must be specified")

src/dve/core_engine/backends/implementations/spark/backend.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
from pyspark.sql import DataFrame, SparkSession
77

88
from dve.core_engine.backends.base.backend import BaseBackend
9+
from dve.core_engine.backends.base.reference_data import ReferenceConfigUnion
910
from dve.core_engine.backends.implementations.spark.contract import SparkDataContract
1011
from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader
1112
from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations
1213
from dve.core_engine.backends.implementations.spark.spark_helpers import get_type_from_annotation
1314
from dve.core_engine.backends.implementations.spark.types import SparkEntities
15+
from dve.core_engine.backends.types import EntityType
1416
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
1517
from dve.core_engine.loggers import get_child_logger, get_logger
1618
from dve.core_engine.models import SubmissionInfo
17-
from dve.core_engine.type_hints import URI, EntityParquetLocations
18-
from dve.parser.file_handling import get_resource_exists, joinuri
19+
from dve.core_engine.type_hints import URI, EntityName, EntityParquetLocations
20+
from dve.parser.file_handling import get_resource_exists, joinuri, get_parent
1921

2022

2123
class SparkBackend(BaseBackend[DataFrame]):
@@ -26,7 +28,6 @@ def __init__(
2628
dataset_config_uri: Optional[URI] = None,
2729
contract: Optional[SparkDataContract] = None,
2830
steps: Optional[SparkStepImplementations] = None,
29-
reference_data_loader: Optional[type[SparkRefDataLoader]] = None,
3031
logger: Optional[logging.Logger] = None,
3132
spark_session: Optional[SparkSession] = None,
3233
**kwargs: Any,
@@ -36,6 +37,8 @@ def __init__(
3637

3738
self.spark_session = spark_session or SparkSession.builder.getOrCreate()
3839
"""The Spark session for the backend."""
40+
self.dataset_config_uri = dataset_config_uri
41+
"""The uri of the dischema specifying the DVE config"""
3942

4043
if contract is None:
4144
contract = SparkDataContract(
@@ -46,11 +49,23 @@ def __init__(
4649
steps = SparkStepImplementations.register_udfs(
4750
logger=get_child_logger("SparkStepImplementations", logger)
4851
)
49-
if reference_data_loader is None:
50-
reference_data_loader = SparkRefDataLoader
51-
reference_data_loader.spark = self.spark_session
52-
reference_data_loader.dataset_config_uri = dataset_config_uri
53-
super().__init__(contract, steps, reference_data_loader, logger, **kwargs)
52+
super().__init__(contract, steps, logger, **kwargs)
53+
54+
def load_reference_data(self,
55+
reference_entity_config: dict[EntityName, ReferenceConfigUnion],
56+
submission_info: Optional[SubmissionInfo],):
57+
"""Load the reference data as specified in the reference entity config."""
58+
sub_info_entity: Optional[EntityType] = None
59+
if submission_info:
60+
sub_info_entity = self.convert_submission_info(submission_info)
61+
62+
reference_data_loader = SparkRefDataLoader(spark=self.spark_session,
63+
reference_data_config=reference_entity_config,
64+
dataset_config_uri=self.dataset_config_uri)
65+
if sub_info_entity is not None:
66+
reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity
67+
68+
return reference_data_loader
5469

5570
def write_entities_to_parquet(
5671
self, entities: SparkEntities, cache_prefix: URI

src/dve/core_engine/backends/implementations/spark/reference_data.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@
1717

1818
# pylint: disable=too-few-public-methods
1919
class SparkRefDataLoader(BaseRefDataLoader[DataFrame]):
20-
"""A reference data loader using already existing Apache Spark Tables."""
21-
22-
spark: SparkSession
23-
"""The Spark session for the backend."""
24-
dataset_config_uri: Optional[URI] = None
25-
"""The location of the dischema file defining business rules"""
20+
"""A reference data loader using already existing Apache Spark Tables.
21+
reference_entity_config and dataset_config_uri (if config uses relative paths)
22+
should be supplied using setter methods for the dataset being processed before running."""
2623

2724
def __init__(
2825
self,
29-
reference_entity_config: dict[EntityName, ReferenceConfig],
26+
spark: SparkSession,
27+
reference_data_config: dict[EntityName, ReferenceConfig],
28+
dataset_config_uri: URI,
3029
**kwargs,
3130
) -> None:
32-
super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs)
31+
super().__init__(reference_data_config, dataset_config_uri, **kwargs)
32+
self.spark = spark
3333
if not self.spark:
3434
raise AttributeError("Spark session must be provided")
3535

src/dve/pipeline/duckdb_pipeline.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66
from duckdb import DuckDBPyConnection, DuckDBPyRelation
77

8-
from dve.core_engine.backends.base.reference_data import BaseRefDataLoader
8+
from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig
99
from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager
1010
from dve.core_engine.backends.implementations.duckdb.contract import DuckDBDataContract
1111
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_get_entity_count
12+
from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader
1213
from dve.core_engine.backends.implementations.duckdb.rules import DuckDBStepImplementations
1314
from dve.core_engine.models import SubmissionInfo
1415
from dve.core_engine.type_hints import URI
1516
from dve.pipeline.pipeline import BaseDVEPipeline
17+
import dve.parser.file_handling as fh
1618

1719

1820
# pylint: disable=abstract-method
@@ -30,7 +32,6 @@ def __init__(
3032
connection: DuckDBPyConnection,
3133
rules_path: Optional[URI],
3234
submitted_files_path: Optional[URI],
33-
reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
3435
job_run_id: Optional[int] = None,
3536
logger: Optional[logging.Logger] = None,
3637
):
@@ -42,11 +43,17 @@ def __init__(
4243
DuckDBStepImplementations.register_udfs(connection=self._connection),
4344
rules_path,
4445
submitted_files_path,
45-
reference_data_loader,
4646
job_run_id,
4747
logger,
4848
)
4949

50+
def get_reference_data_loader(self,
51+
reference_data_config: dict[str, ReferenceConfig],
52+
**kwargs) -> BaseRefDataLoader[DuckDBPyRelation]:
53+
return DuckDBRefDataLoader(connection=self._connection,
54+
reference_data_config=reference_data_config,
55+
dataset_config_uri=fh.get_parent(self._rules_path),
56+
**kwargs)
5057
# pylint: disable=arguments-differ
5158
def write_file_to_parquet( # type: ignore
5259
self, submission_file_uri: URI, submission_info: SubmissionInfo, output: URI

src/dve/pipeline/pipeline.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from dve.core_engine.backends.base.auditing import BaseAuditingManager
2727
from dve.core_engine.backends.base.contract import BaseDataContract
2828
from dve.core_engine.backends.base.core import EntityManager
29-
from dve.core_engine.backends.base.reference_data import BaseRefDataLoader
29+
from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig
3030
from dve.core_engine.backends.base.rules import BaseStepImplementations
3131
from dve.core_engine.backends.exceptions import MessageBearingError
3232
from dve.core_engine.backends.readers import BaseFileReader
@@ -36,7 +36,7 @@
3636
from dve.core_engine.loggers import get_logger
3737
from dve.core_engine.message import FeedbackMessage
3838
from dve.core_engine.models import SubmissionInfo, SubmissionStatisticsRecord
39-
from dve.core_engine.type_hints import URI, DVEStageName, FileURI, InfoURI
39+
from dve.core_engine.type_hints import URI, DVEStageName, EntityName, FileURI, InfoURI
4040
from dve.parser import file_handling as fh
4141
from dve.parser.file_handling.implementations.file import LocalFilesystemImplementation
4242
from dve.parser.file_handling.service import _get_implementation
@@ -62,14 +62,13 @@ def __init__(
6262
step_implementations: Optional[BaseStepImplementations[EntityType]],
6363
rules_path: Optional[URI],
6464
submitted_files_path: Optional[URI],
65-
reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
6665
job_run_id: Optional[int] = None,
6766
logger: Optional[logging.Logger] = None,
6867
):
6968
self._submitted_files_path = submitted_files_path
7069
self._processed_files_path = processed_files_path
7170
self._rules_path = rules_path
72-
self._reference_data_loader = reference_data_loader
71+
self._reference_data_loader = None
7372
self._job_run_id = job_run_id
7473
self._audit_tables = audit_tables
7574
self._data_contract = data_contract
@@ -113,6 +112,13 @@ def step_implementations(self) -> Optional[BaseStepImplementations[EntityType]]:
113112
def get_entity_count(entity: EntityType) -> int:
114113
"""Get a row count of an entity stored as parquet"""
115114
raise NotImplementedError()
115+
116+
def get_reference_data_loader(self,
117+
reference_data_config: dict[EntityName, ReferenceConfig],
118+
**kwargs) -> BaseRefDataLoader[EntityType]:
119+
"""Get reference data loader if required for business rules"""
120+
raise NotImplementedError()
121+
116122

117123
def get_submission_status(
118124
self, step_name: DVEStageName, submission_id: str
@@ -542,9 +548,6 @@ def apply_business_rules( # pylint: disable=R0914
542548
if not self.rules_path:
543549
raise AttributeError("business rules path not provided.")
544550

545-
if not self._reference_data_loader:
546-
raise AttributeError("reference data loader not provided.")
547-
548551
if not self.processed_files_path:
549552
raise AttributeError("processed files path has not been provided.")
550553

@@ -556,8 +559,8 @@ def apply_business_rules( # pylint: disable=R0914
556559
self._processed_files_path, submission_info.submission_id
557560
)
558561
ref_data = config.get_reference_data_config()
562+
reference_data = self.get_reference_data_loader(reference_data_config=ref_data)
559563
rules = config.get_rule_metadata()
560-
reference_data = self._reference_data_loader(ref_data) # type: ignore
561564
entities = {}
562565
contract = fh.joinuri(
563566
self.processed_files_path, submission_info.submission_id, "data_contract"

0 commit comments

Comments
 (0)