Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 4 additions & 32 deletions src/dve/core_engine/backends/base/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping, MutableMapping
from collections.abc import MutableMapping
from typing import Any, ClassVar, Generic, Optional

from pyspark.sql import DataFrame, SparkSession
Expand Down Expand Up @@ -41,14 +41,12 @@ def __init__( # pylint: disable=unused-argument
self,
contract: BaseDataContract[EntityType],
steps: BaseStepImplementations[EntityType],
reference_data_loader_type: Optional[type[BaseRefDataLoader[EntityType]]],
logger: Optional[logging.Logger] = None,
**kwargs: Any,
) -> None:
for component_name, component in (
("Contract", contract),
("Step implementation", steps),
("Reference data loader", reference_data_loader_type),
):
component_entity_type = getattr(component, "__entity_type__", None)
if component_entity_type != self.__entity_type__:
Expand All @@ -61,42 +59,16 @@ def __init__( # pylint: disable=unused-argument
"""The data contract implementation used by the backend."""
self.step_implementations = steps
"""The step implementations used by the backend."""
self.reference_data_loader_type = reference_data_loader_type
"""
The loader type to use for the reference data. If `None`, do not
load any reference data and error if it is provided.

"""
self.logger = logger or get_logger(type(self).__name__)
"""The `logging.Logger` instance for the backend."""

def load_reference_data(
self,
reference_entity_config: dict[EntityName, ReferenceConfigUnion],
submission_info: Optional[SubmissionInfo],
) -> Mapping[EntityName, EntityType]:
"""Load the reference data as specified in the reference entity config."""
sub_info_entity: Optional[EntityType] = None
if submission_info:
sub_info_entity = self.convert_submission_info(submission_info)

if self.reference_data_loader_type is None:
if reference_entity_config:
raise ValueError(
"Reference data has been specified but no reference data loader is "
+ "configured for this backend"
)

reference_data_dict = {}
if sub_info_entity is not None:
reference_data_dict["dve_submission_info"] = sub_info_entity
return reference_data_dict

reference_data_loader = self.reference_data_loader_type(reference_entity_config)
if sub_info_entity is not None:
reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity

return reference_data_loader
) -> BaseRefDataLoader[EntityType]:
"""Supply configured reference data loader for use with business rules"""
raise NotImplementedError()

@abstractmethod
def convert_submission_info(self, submission_info: SubmissionInfo) -> EntityType:
Expand Down
9 changes: 6 additions & 3 deletions src/dve/core_engine/backends/base/reference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dve.core_engine.backends.base.core import get_entity_type
from dve.core_engine.backends.exceptions import (
MissingRefDataEntity,
NoRefDataConfigSupplied,
RefdataLacksFileExtensionSupport,
)
from dve.core_engine.backends.types import EntityType
Expand Down Expand Up @@ -147,11 +148,11 @@ class variable for the subclass.
# pylint: disable=unused-argument
def __init__(
self,
reference_entity_config: dict[EntityName, ReferenceConfig],
dataset_config_uri: Optional[URI] = None,
reference_data_config: dict[EntityName, ReferenceConfig],
dataset_config_uri: URI,
**kwargs,
) -> None:
self.reference_entity_config = reference_entity_config
self.reference_entity_config = reference_data_config
self.dataset_config_uri = dataset_config_uri
"""
Configuration options for the reference data. This is likely to vary
Expand Down Expand Up @@ -207,6 +208,8 @@ def __getitem__(self, key: EntityName) -> EntityType:
try:
config = self.reference_entity_config[key]
return self.load_entity(entity_name=key, config=config)
except TypeError as err:
raise NoRefDataConfigSupplied() from err
except Exception as err:
raise MissingRefDataEntity(entity_name=key) from err

Expand Down
14 changes: 14 additions & 0 deletions src/dve/core_engine/backends/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ def get_message_preamble(self) -> str:
return f"Missing reference data entity {self.entity_name!r}"


class NoRefDataConfigSupplied(BackendError):
    """Raised when a refdata entity load is attempted but no refdata
    config was supplied to the loader.

    Triggered from ``BaseRefDataLoader.__getitem__`` when indexing the
    reference-entity config raises a ``TypeError`` (i.e. the config is
    ``None`` rather than a mapping).
    """

    # NOTE: no __init__ override — the previous pass-through
    # (``def __init__(self, *args): super().__init__(*args)``) added nothing.

    def get_message_preamble(self) -> str:
        """Return the fixed log-message preamble for this error.

        Annotated ``-> str`` (not ``EntityName``): the value is a human-readable
        message, matching the ``get_message_preamble`` convention of the other
        exceptions in this module.
        """
        return "Refdata loader not supplied with refdata config - unable to load refdata entities"


class ConstraintError(ValueError, BackendErrorMixin):
"""Raised when a given constraint is violated."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def get_duckdb_cast_statement_from_annotation(
stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301
return stmt
if issubclass(type_, time):
stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301
stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301
return stmt
duck_type = get_duckdb_type_from_annotation(type_)
if duck_type:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""A reference data loader for duckdb."""

from typing import Optional

from duckdb import DuckDBPyConnection, DuckDBPyRelation
from pyarrow import ipc # type: ignore

from dve.core_engine.backends.base.reference_data import (
BaseRefDataLoader,
ReferenceConfigUnion,
ReferenceConfig,
ReferenceTable,
mark_refdata_file_extension,
)
Expand All @@ -19,17 +17,16 @@
class DuckDBRefDataLoader(BaseRefDataLoader[DuckDBPyRelation]):
"""A reference data loader using already existing DuckDB tables."""

connection: DuckDBPyConnection
"""The DuckDB connection for the backend."""
dataset_config_uri: Optional[URI] = None
"""The location of the dischema file"""

def __init__(
self,
reference_entity_config: dict[EntityName, ReferenceConfigUnion],
connection: DuckDBPyConnection,
reference_data_config: dict[EntityName, ReferenceConfig],
dataset_config_uri: URI,
**kwargs,
) -> None:
super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs)
super().__init__(reference_data_config, dataset_config_uri, **kwargs)

self.connection = connection

if not self.connection:
raise AttributeError("DuckDBConnection must be specified")
Expand Down
32 changes: 25 additions & 7 deletions src/dve/core_engine/backends/implementations/spark/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pyspark.sql import DataFrame, SparkSession

from dve.core_engine.backends.base.backend import BaseBackend
from dve.core_engine.backends.base.reference_data import ReferenceConfigUnion
from dve.core_engine.backends.implementations.spark.contract import SparkDataContract
from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader
from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations
Expand All @@ -14,7 +15,7 @@
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
from dve.core_engine.loggers import get_child_logger, get_logger
from dve.core_engine.models import SubmissionInfo
from dve.core_engine.type_hints import URI, EntityParquetLocations
from dve.core_engine.type_hints import URI, EntityName, EntityParquetLocations
from dve.parser.file_handling import get_resource_exists, joinuri


Expand All @@ -26,7 +27,6 @@ def __init__(
dataset_config_uri: Optional[URI] = None,
contract: Optional[SparkDataContract] = None,
steps: Optional[SparkStepImplementations] = None,
reference_data_loader: Optional[type[SparkRefDataLoader]] = None,
logger: Optional[logging.Logger] = None,
spark_session: Optional[SparkSession] = None,
**kwargs: Any,
Expand All @@ -36,6 +36,8 @@ def __init__(

self.spark_session = spark_session or SparkSession.builder.getOrCreate()
"""The Spark session for the backend."""
self.dataset_config_uri = dataset_config_uri
"""The uri of the dischema specifying the DVE config"""

if contract is None:
contract = SparkDataContract(
Expand All @@ -46,11 +48,27 @@ def __init__(
steps = SparkStepImplementations.register_udfs(
logger=get_child_logger("SparkStepImplementations", logger)
)
if reference_data_loader is None:
reference_data_loader = SparkRefDataLoader
reference_data_loader.spark = self.spark_session
reference_data_loader.dataset_config_uri = dataset_config_uri
super().__init__(contract, steps, reference_data_loader, logger, **kwargs)
super().__init__(contract, steps, logger, **kwargs)

def load_reference_data(
self,
reference_entity_config: dict[EntityName, ReferenceConfigUnion],
submission_info: Optional[SubmissionInfo],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default?

):
"""Load the reference data as specified in the reference entity config."""
sub_info_entity: Optional[DataFrame] = None
if submission_info:
sub_info_entity = self.convert_submission_info(submission_info)

reference_data_loader = SparkRefDataLoader(
spark=self.spark_session,
reference_data_config=reference_entity_config,
dataset_config_uri=self.dataset_config_uri, # type: ignore
)
if sub_info_entity is not None:
reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity

return reference_data_loader

def write_entities_to_parquet(
self, entities: SparkEntities, cache_prefix: URI
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# pylint: disable=no-member
"""A reference data loader for Spark."""

from typing import Optional

from pyspark.sql import DataFrame, SparkSession

from dve.core_engine.backends.base.reference_data import (
Expand All @@ -19,17 +17,15 @@
class SparkRefDataLoader(BaseRefDataLoader[DataFrame]):
"""A reference data loader using already existing Apache Spark Tables."""

spark: SparkSession
"""The Spark session for the backend."""
dataset_config_uri: Optional[URI] = None
"""The location of the dischema file defining business rules"""

def __init__(
self,
reference_entity_config: dict[EntityName, ReferenceConfig],
spark: SparkSession,
reference_data_config: dict[EntityName, ReferenceConfig],
dataset_config_uri: URI,
**kwargs,
) -> None:
super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs)
super().__init__(reference_data_config, dataset_config_uri, **kwargs)
self.spark = spark
if not self.spark:
raise AttributeError("Spark session must be provided")

Expand Down
16 changes: 13 additions & 3 deletions src/dve/pipeline/duckdb_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from duckdb import DuckDBPyConnection, DuckDBPyRelation

from dve.core_engine.backends.base.reference_data import BaseRefDataLoader
import dve.parser.file_handling as fh
from dve.core_engine.backends.base.reference_data import ReferenceConfig
from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager
from dve.core_engine.backends.implementations.duckdb.contract import DuckDBDataContract
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_get_entity_count
from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader
from dve.core_engine.backends.implementations.duckdb.rules import DuckDBStepImplementations
from dve.core_engine.models import SubmissionInfo
from dve.core_engine.type_hints import URI
Expand All @@ -30,7 +32,6 @@ def __init__(
connection: DuckDBPyConnection,
rules_path: Optional[URI],
submitted_files_path: Optional[URI],
reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
job_run_id: Optional[int] = None,
logger: Optional[logging.Logger] = None,
):
Expand All @@ -42,11 +43,20 @@ def __init__(
DuckDBStepImplementations.register_udfs(connection=self._connection),
rules_path,
submitted_files_path,
reference_data_loader,
job_run_id,
logger,
)

    def init_reference_data_loader(
        self, reference_data_config: dict[str, ReferenceConfig], **kwargs
    ) -> DuckDBRefDataLoader:
        """Build a `DuckDBRefDataLoader` bound to this pipeline's connection.

        The loader is configured with the supplied per-entity reference
        config and the parent directory of the pipeline's rules path as
        the dataset config URI. Extra keyword arguments are forwarded to
        the loader unchanged.

        # NOTE(review): assumes self._rules_path is non-None here — the
        # `# type: ignore` suggests it is Optional; confirm callers always
        # set it before reference data is loaded.
        """
        return DuckDBRefDataLoader(
            connection=self._connection,
            reference_data_config=reference_data_config,
            # dischema directory: refdata paths are resolved relative to the rules file
            dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore
            **kwargs
        )

# pylint: disable=arguments-differ
def write_file_to_parquet( # type: ignore
self, submission_file_uri: URI, submission_info: SubmissionInfo, output: URI
Expand Down
Loading
Loading