diff --git a/src/dve/core_engine/backends/base/backend.py b/src/dve/core_engine/backends/base/backend.py
index 29e8644..f627412 100644
--- a/src/dve/core_engine/backends/base/backend.py
+++ b/src/dve/core_engine/backends/base/backend.py
@@ -3,7 +3,7 @@
 import logging
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Mapping, MutableMapping
+from collections.abc import MutableMapping
 from typing import Any, ClassVar, Generic, Optional

 from pyspark.sql import DataFrame, SparkSession
@@ -41,14 +41,12 @@ def __init__(  # pylint: disable=unused-argument
         self,
         contract: BaseDataContract[EntityType],
         steps: BaseStepImplementations[EntityType],
-        reference_data_loader_type: Optional[type[BaseRefDataLoader[EntityType]]],
         logger: Optional[logging.Logger] = None,
         **kwargs: Any,
     ) -> None:
         for component_name, component in (
             ("Contract", contract),
             ("Step implementation", steps),
-            ("Reference data loader", reference_data_loader_type),
         ):
             component_entity_type = getattr(component, "__entity_type__", None)
             if component_entity_type != self.__entity_type__:
@@ -61,12 +59,6 @@
         """The data contract implementation used by the backend."""
         self.step_implementations = steps
         """The step implementations used by the backend."""
-        self.reference_data_loader_type = reference_data_loader_type
-        """
-        The loader type to use for the reference data. If `None`, do not
-        load any reference data and error if it is provided.
-
-        """
         self.logger = logger or get_logger(type(self).__name__)
         """The `logging.Logger instance for the backend."""

@@ -74,29 +66,9 @@ def load_reference_data(
         self,
         reference_entity_config: dict[EntityName, ReferenceConfigUnion],
         submission_info: Optional[SubmissionInfo],
-    ) -> Mapping[EntityName, EntityType]:
-        """Load the reference data as specified in the reference entity config."""
-        sub_info_entity: Optional[EntityType] = None
-        if submission_info:
-            sub_info_entity = self.convert_submission_info(submission_info)
-
-        if self.reference_data_loader_type is None:
-            if reference_entity_config:
-                raise ValueError(
-                    "Reference data has been specified but no reference data loader is "
-                    + "configured for this backend"
-                )
-
-            reference_data_dict = {}
-            if sub_info_entity is not None:
-                reference_data_dict["dve_submission_info"] = sub_info_entity
-            return reference_data_dict
-
-        reference_data_loader = self.reference_data_loader_type(reference_entity_config)
-        if sub_info_entity is not None:
-            reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity
-
-        return reference_data_loader
+    ) -> BaseRefDataLoader[EntityType]:
+        """Supply a configured reference data loader for use with business rules."""
+        raise NotImplementedError()

     @abstractmethod
     def convert_submission_info(self, submission_info: SubmissionInfo) -> EntityType:
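With this change the base backend no longer materialises reference data itself; each concrete backend returns its own configured loader, which behaves like a lazy mapping from entity name to entity. A minimal usage sketch, assuming an already constructed backend; `backend`, `ref_config`, `sub_info`, and the `movies_sequels` entity name are illustrative, not part of this changeset:

    # Hypothetical usage of the reworked hook (all names are examples).
    loader = backend.load_reference_data(ref_config, sub_info)
    movies = loader["movies_sequels"]  # entity is loaded lazily on first access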
diff --git a/src/dve/core_engine/backends/base/reference_data.py b/src/dve/core_engine/backends/base/reference_data.py
index 5be0ec0..9010e8d 100644
--- a/src/dve/core_engine/backends/base/reference_data.py
+++ b/src/dve/core_engine/backends/base/reference_data.py
@@ -11,6 +11,7 @@
 from dve.core_engine.backends.base.core import get_entity_type
 from dve.core_engine.backends.exceptions import (
     MissingRefDataEntity,
+    NoRefDataConfigSupplied,
     RefdataLacksFileExtensionSupport,
 )
 from dve.core_engine.backends.types import EntityType
@@ -147,11 +148,11 @@ class variable for the subclass.
     # pylint: disable=unused-argument
     def __init__(
         self,
-        reference_entity_config: dict[EntityName, ReferenceConfig],
-        dataset_config_uri: Optional[URI] = None,
+        reference_data_config: dict[EntityName, ReferenceConfig],
+        dataset_config_uri: URI,
         **kwargs,
     ) -> None:
-        self.reference_entity_config = reference_entity_config
+        self.reference_entity_config = reference_data_config
         self.dataset_config_uri = dataset_config_uri
         """
         Configuration options for the reference data. This is likely to vary
@@ -207,6 +208,8 @@ def __getitem__(self, key: EntityName) -> EntityType:
         try:
             config = self.reference_entity_config[key]
             return self.load_entity(entity_name=key, config=config)
+        except TypeError as err:
+            raise NoRefDataConfigSupplied() from err
         except Exception as err:
             raise MissingRefDataEntity(entity_name=key) from err

diff --git a/src/dve/core_engine/backends/exceptions.py b/src/dve/core_engine/backends/exceptions.py
index 8dd50ef..6878fc2 100644
--- a/src/dve/core_engine/backends/exceptions.py
+++ b/src/dve/core_engine/backends/exceptions.py
@@ -119,6 +119,20 @@ def get_message_preamble(self) -> str:
         return f"Missing reference data entity {self.entity_name!r}"


+class NoRefDataConfigSupplied(BackendError):
+    """An error raised when trying to load a refdata entity when no refdata
+    config has been supplied.
+
+    """
+
+    def __init__(self, *args: object) -> None:
+        super().__init__(*args)
+
+    def get_message_preamble(self) -> str:
+        """Message for logging purposes."""
+        return "Refdata loader not supplied with refdata config - unable to load refdata entities"
+
+
 class ConstraintError(ValueError, BackendErrorMixin):
     """Raised when a given constraint is violated."""

diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
index 394cd01..627822b 100644
--- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
+++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py
@@ -411,7 +411,7 @@ def get_duckdb_cast_statement_from_annotation(
         stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END"  # pylint: disable=C0301
         return stmt
     if issubclass(type_, time):
-        stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301
+        stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END"  # pylint: disable=C0301
         return stmt
     duck_type = get_duckdb_type_from_annotation(type_)
     if duck_type:
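Together with the `__getitem__` change above, a loader built without any refdata config now fails with a dedicated error rather than a bare TypeError when an entity is requested. A sketch of the behaviour, using the DuckDB loader whose new constructor appears below; `conn` is assumed to be an existing DuckDBPyConnection and the entity name is an example:

    from dve.core_engine.backends.exceptions import NoRefDataConfigSupplied

    loader = DuckDBRefDataLoader(
        connection=conn,
        reference_data_config=None,  # type: ignore[arg-type] -- nothing supplied
        dataset_config_uri="/tmp/config",
    )
    try:
        loader["movies_sequels"]  # subscripting None raises TypeError internally
    except NoRefDataConfigSupplied:
        pass  # mapped by BaseRefDataLoader.__getitem__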
- """The DuckDB connection for the backend.""" - dataset_config_uri: Optional[URI] = None - """The location of the dischema file""" - def __init__( self, - reference_entity_config: dict[EntityName, ReferenceConfigUnion], + connection: DuckDBPyConnection, + reference_data_config: dict[EntityName, ReferenceConfig], + dataset_config_uri: URI, **kwargs, ) -> None: - super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs) + super().__init__(reference_data_config, dataset_config_uri, **kwargs) + + self.connection = connection if not self.connection: raise AttributeError("DuckDBConnection must be specified") diff --git a/src/dve/core_engine/backends/implementations/spark/backend.py b/src/dve/core_engine/backends/implementations/spark/backend.py index 3999b62..126e07a 100644 --- a/src/dve/core_engine/backends/implementations/spark/backend.py +++ b/src/dve/core_engine/backends/implementations/spark/backend.py @@ -6,6 +6,7 @@ from pyspark.sql import DataFrame, SparkSession from dve.core_engine.backends.base.backend import BaseBackend +from dve.core_engine.backends.base.reference_data import ReferenceConfigUnion from dve.core_engine.backends.implementations.spark.contract import SparkDataContract from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations @@ -14,7 +15,7 @@ from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME from dve.core_engine.loggers import get_child_logger, get_logger from dve.core_engine.models import SubmissionInfo -from dve.core_engine.type_hints import URI, EntityParquetLocations +from dve.core_engine.type_hints import URI, EntityName, EntityParquetLocations from dve.parser.file_handling import get_resource_exists, joinuri @@ -26,7 +27,6 @@ def __init__( dataset_config_uri: Optional[URI] = None, contract: Optional[SparkDataContract] = None, steps: Optional[SparkStepImplementations] = None, - reference_data_loader: Optional[type[SparkRefDataLoader]] = None, logger: Optional[logging.Logger] = None, spark_session: Optional[SparkSession] = None, **kwargs: Any, @@ -36,6 +36,8 @@ def __init__( self.spark_session = spark_session or SparkSession.builder.getOrCreate() """The Spark session for the backend.""" + self.dataset_config_uri = dataset_config_uri + """The uri of the dischema specifying the DVE config""" if contract is None: contract = SparkDataContract( @@ -46,11 +48,27 @@ def __init__( steps = SparkStepImplementations.register_udfs( logger=get_child_logger("SparkStepImplementations", logger) ) - if reference_data_loader is None: - reference_data_loader = SparkRefDataLoader - reference_data_loader.spark = self.spark_session - reference_data_loader.dataset_config_uri = dataset_config_uri - super().__init__(contract, steps, reference_data_loader, logger, **kwargs) + super().__init__(contract, steps, logger, **kwargs) + + def load_reference_data( + self, + reference_entity_config: dict[EntityName, ReferenceConfigUnion], + submission_info: Optional[SubmissionInfo], + ): + """Load the reference data as specified in the reference entity config.""" + sub_info_entity: Optional[DataFrame] = None + if submission_info: + sub_info_entity = self.convert_submission_info(submission_info) + + reference_data_loader = SparkRefDataLoader( + spark=self.spark_session, + reference_data_config=reference_entity_config, + dataset_config_uri=self.dataset_config_uri, # type: ignore + ) + if sub_info_entity is not None: + 
reference_data_loader.entity_cache["dve_submission_info"] = sub_info_entity + + return reference_data_loader def write_entities_to_parquet( self, entities: SparkEntities, cache_prefix: URI diff --git a/src/dve/core_engine/backends/implementations/spark/reference_data.py b/src/dve/core_engine/backends/implementations/spark/reference_data.py index 90ba4f6..44f49af 100644 --- a/src/dve/core_engine/backends/implementations/spark/reference_data.py +++ b/src/dve/core_engine/backends/implementations/spark/reference_data.py @@ -1,8 +1,6 @@ # pylint: disable=no-member """A reference data loader for Spark.""" -from typing import Optional - from pyspark.sql import DataFrame, SparkSession from dve.core_engine.backends.base.reference_data import ( @@ -19,17 +17,15 @@ class SparkRefDataLoader(BaseRefDataLoader[DataFrame]): """A reference data loader using already existing Apache Spark Tables.""" - spark: SparkSession - """The Spark session for the backend.""" - dataset_config_uri: Optional[URI] = None - """The location of the dischema file defining business rules""" - def __init__( self, - reference_entity_config: dict[EntityName, ReferenceConfig], + spark: SparkSession, + reference_data_config: dict[EntityName, ReferenceConfig], + dataset_config_uri: URI, **kwargs, ) -> None: - super().__init__(reference_entity_config, self.dataset_config_uri, **kwargs) + super().__init__(reference_data_config, dataset_config_uri, **kwargs) + self.spark = spark if not self.spark: raise AttributeError("Spark session must be provided") diff --git a/src/dve/pipeline/duckdb_pipeline.py b/src/dve/pipeline/duckdb_pipeline.py index 87e927d..0370106 100644 --- a/src/dve/pipeline/duckdb_pipeline.py +++ b/src/dve/pipeline/duckdb_pipeline.py @@ -5,10 +5,12 @@ from duckdb import DuckDBPyConnection, DuckDBPyRelation -from dve.core_engine.backends.base.reference_data import BaseRefDataLoader +import dve.parser.file_handling as fh +from dve.core_engine.backends.base.reference_data import ReferenceConfig from dve.core_engine.backends.implementations.duckdb.auditing import DDBAuditingManager from dve.core_engine.backends.implementations.duckdb.contract import DuckDBDataContract from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import duckdb_get_entity_count +from dve.core_engine.backends.implementations.duckdb.reference_data import DuckDBRefDataLoader from dve.core_engine.backends.implementations.duckdb.rules import DuckDBStepImplementations from dve.core_engine.models import SubmissionInfo from dve.core_engine.type_hints import URI @@ -30,7 +32,6 @@ def __init__( connection: DuckDBPyConnection, rules_path: Optional[URI], submitted_files_path: Optional[URI], - reference_data_loader: Optional[type[BaseRefDataLoader]] = None, job_run_id: Optional[int] = None, logger: Optional[logging.Logger] = None, ): @@ -42,11 +43,20 @@ def __init__( DuckDBStepImplementations.register_udfs(connection=self._connection), rules_path, submitted_files_path, - reference_data_loader, job_run_id, logger, ) + def init_reference_data_loader( + self, reference_data_config: dict[str, ReferenceConfig], **kwargs + ) -> DuckDBRefDataLoader: + return DuckDBRefDataLoader( + connection=self._connection, + reference_data_config=reference_data_config, + dataset_config_uri=fh.get_parent(self._rules_path), # type: ignore + **kwargs + ) + # pylint: disable=arguments-differ def write_file_to_parquet( # type: ignore self, submission_file_uri: URI, submission_info: SubmissionInfo, output: URI diff --git a/src/dve/pipeline/pipeline.py 
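Pipelines now own loader construction via the `init_reference_data_loader` factory shown above, deriving the dataset config location from the rules path. A sketch of the resolution this relies on; the paths are example values, and `fh.get_parent` is assumed to strip the final URI segment, as its use here suggests:

    import dve.parser.file_handling as fh

    rules_path = "/data/movies/movies_ddb.dischema.json"
    dataset_config_uri = fh.get_parent(rules_path)  # assumed: "/data/movies"
    # A ReferenceFile(type="filename", filename="./movies_sequels.parquet") is
    # then resolved by the loader relative to dataset_config_uri.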
diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py
index 26b682e..91ff2ee 100644
--- a/src/dve/pipeline/pipeline.py
+++ b/src/dve/pipeline/pipeline.py
@@ -26,7 +26,7 @@
 from dve.core_engine.backends.base.auditing import BaseAuditingManager
 from dve.core_engine.backends.base.contract import BaseDataContract
 from dve.core_engine.backends.base.core import EntityManager
-from dve.core_engine.backends.base.reference_data import BaseRefDataLoader
+from dve.core_engine.backends.base.reference_data import BaseRefDataLoader, ReferenceConfig
 from dve.core_engine.backends.base.rules import BaseStepImplementations
 from dve.core_engine.backends.exceptions import MessageBearingError
 from dve.core_engine.backends.readers import BaseFileReader
@@ -36,7 +36,7 @@
 from dve.core_engine.loggers import get_logger
 from dve.core_engine.message import FeedbackMessage
 from dve.core_engine.models import SubmissionInfo, SubmissionStatisticsRecord
-from dve.core_engine.type_hints import URI, DVEStageName, FileURI, InfoURI
+from dve.core_engine.type_hints import URI, DVEStageName, EntityName, FileURI, InfoURI
 from dve.parser import file_handling as fh
 from dve.parser.file_handling.implementations.file import LocalFilesystemImplementation
 from dve.parser.file_handling.service import _get_implementation
@@ -62,14 +62,12 @@
         step_implementations: Optional[BaseStepImplementations[EntityType]],
         rules_path: Optional[URI],
         submitted_files_path: Optional[URI],
-        reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
         job_run_id: Optional[int] = None,
         logger: Optional[logging.Logger] = None,
     ):
         self._submitted_files_path = submitted_files_path
         self._processed_files_path = processed_files_path
         self._rules_path = rules_path
-        self._reference_data_loader = reference_data_loader
         self._job_run_id = job_run_id
         self._audit_tables = audit_tables
         self._data_contract = data_contract
@@ -114,6 +112,12 @@ def get_entity_count(entity: EntityType) -> int:
         """Get a row count of an entity stored as parquet"""
         raise NotImplementedError()

+    def init_reference_data_loader(
+        self, reference_data_config: dict[EntityName, ReferenceConfig], **kwargs
+    ) -> BaseRefDataLoader:
+        """Get a reference data loader if required for business rules."""
+        raise NotImplementedError()
+
     def get_submission_status(
         self, step_name: DVEStageName, submission_id: str
     ) -> SubmissionStatus:
@@ -527,7 +531,7 @@

         return processed_files, failed_processing

-    def apply_business_rules( # pylint: disable=R0914
+    def apply_business_rules(  # pylint: disable=R0914
         self, submission_info: SubmissionInfo, submission_status: Optional[SubmissionStatus] = None
     ) -> tuple[SubmissionInfo, SubmissionStatus]:
         """Apply the business rules to a given submission, the submission may have failed at the
@@ -542,9 +546,6 @@
         if not self.rules_path:
             raise AttributeError("business rules path not provided.")

-        if not self._reference_data_loader:
-            raise AttributeError("reference data loader not provided.")
-
         if not self.processed_files_path:
             raise AttributeError("processed files path has not been provided.")
@@ -556,8 +557,10 @@
             self._processed_files_path, submission_info.submission_id
         )
         ref_data = config.get_reference_data_config()
+        reference_data: BaseRefDataLoader = self.init_reference_data_loader(
+            reference_data_config=ref_data
+        )
         rules = config.get_rule_metadata()
-        reference_data = self._reference_data_loader(ref_data)  # type: ignore
         entities = {}

         contract = fh.joinuri(
             self.processed_files_path, submission_info.submission_id, "data_contract"
@@ -582,10 +585,7 @@
         key_fields = {model: conf.reporting_fields for model, conf in model_config.items()}

         _errors_uri, rules_success = self.step_implementations.apply_rules(  # type: ignore
-            working_directory,
-            entity_manager,
-            rules,
-            key_fields
+            working_directory, entity_manager, rules, key_fields
         )

         rule_messages = load_feedback_messages(
diff --git a/src/dve/pipeline/spark_pipeline.py b/src/dve/pipeline/spark_pipeline.py
index 71fdb32..201abbf 100644
--- a/src/dve/pipeline/spark_pipeline.py
+++ b/src/dve/pipeline/spark_pipeline.py
@@ -6,9 +6,11 @@

 from pyspark.sql import DataFrame, SparkSession

-from dve.core_engine.backends.base.reference_data import BaseRefDataLoader
+import dve.parser.file_handling as fh
+from dve.core_engine.backends.base.reference_data import ReferenceConfig
 from dve.core_engine.backends.implementations.spark.auditing import SparkAuditingManager
 from dve.core_engine.backends.implementations.spark.contract import SparkDataContract
+from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader
 from dve.core_engine.backends.implementations.spark.rules import SparkStepImplementations
 from dve.core_engine.backends.implementations.spark.spark_helpers import spark_get_entity_count
 from dve.core_engine.models import SubmissionInfo
@@ -31,7 +33,6 @@
         audit_tables: SparkAuditingManager,
         rules_path: Optional[URI],
         submitted_files_path: Optional[URI],
-        reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
         spark: Optional[SparkSession] = None,
         job_run_id: Optional[int] = None,
         logger: Optional[logging.Logger] = None,
@@ -44,11 +45,20 @@
             SparkStepImplementations.register_udfs(self._spark),
             rules_path,
             submitted_files_path,
-            reference_data_loader,
             job_run_id,
             logger,
         )

+    def init_reference_data_loader(
+        self, reference_data_config: dict[str, ReferenceConfig], **kwargs
+    ) -> SparkRefDataLoader:
+        return SparkRefDataLoader(
+            spark=self._spark,
+            reference_data_config=reference_data_config,
+            dataset_config_uri=fh.get_parent(self._rules_path),  # type: ignore
+            **kwargs
+        )
+
     # pylint: disable=arguments-differ
     def write_file_to_parquet(  # type: ignore
         self, submission_file_uri: URI, submission_info: SubmissionInfo, output: URI
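For callers the visible change is the dropped `reference_data_loader` argument: a pipeline builds its loader internally when `apply_business_rules` runs. An illustrative construction after this change; paths are example values and `audit_manager` is assumed to be an existing DDBAuditingManager:

    import duckdb

    from dve.pipeline.duckdb_pipeline import DDBDVEPipeline

    pipeline = DDBDVEPipeline(
        processed_files_path="/data/processed",
        audit_tables=audit_manager,
        connection=duckdb.connect("dve.duckdb"),
        rules_path="/data/movies/movies_ddb.dischema.json",
        submitted_files_path="/data/submitted",
    )
    # apply_business_rules now calls pipeline.init_reference_data_loader(...)
    # itself; no loader type is injected at construction time.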
diff --git a/tests/features/steps/steps_pipeline.py b/tests/features/steps/steps_pipeline.py
index fa1e848..55acadd 100644
--- a/tests/features/steps/steps_pipeline.py
+++ b/tests/features/steps/steps_pipeline.py
@@ -48,9 +48,6 @@ def setup_spark_pipeline(
     schema_file_name = f"{dataset_id}.dischema.json" if not schema_file_name else schema_file_name
     rules_path = get_test_file_path(f"{dataset_id}/{schema_file_name}").resolve().as_uri()

-    # configure reference data
-    SparkRefDataLoader.spark = spark
-    SparkRefDataLoader.dataset_config_uri = fh.get_parent(rules_path)

     return SparkDVEPipeline(
         processed_files_path=processing_path.as_uri(),
@@ -61,7 +58,6 @@
         job_run_id=12345,
         rules_path=rules_path,
         submitted_files_path=processing_path.as_uri(),
-        reference_data_loader=SparkRefDataLoader,
         spark=spark,
     )

@@ -78,9 +74,6 @@
     # create duckdbpyconnection with dve database file in context.tempdir
     # TODO - doesn't like file scheme - need to provide absolute path
     db_file = Path(processing_path, "dve.duckdb")
-    # configure refdata
-    DuckDBRefDataLoader.connection = connection
-    DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(rules_path)

     return DDBDVEPipeline(
         processed_files_path=processing_path.as_posix(),
         audit_tables=DDBAuditingManager(
@@ -91,8 +84,7 @@
         job_run_id=12345,
         connection=connection,
         rules_path=rules_path,
-        submitted_files_path=processing_path.as_posix(),
-        reference_data_loader=DuckDBRefDataLoader
+        submitted_files_path=processing_path.as_posix()
     )

@@ -314,18 +306,17 @@ def create_refdata_tables(context: Context, database: str):
         record = row.as_dict()
         refdata_tables[record["table_name"]] = record["parquet_path"]
     pipeline = ctxt.get_pipeline(context)
-    refdata_loader = getattr(pipeline, "_reference_data_loader")
-    if refdata_loader == SparkRefDataLoader:
-        refdata_loader.spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}")
+    if isinstance(pipeline, SparkDVEPipeline):
+        pipeline._spark.sql(f"CREATE DATABASE IF NOT EXISTS {database}")
         for tbl, source in refdata_tables.items():
-            (refdata_loader.spark.read.parquet(source)
+            (pipeline._spark.read.parquet(source)
                 .write.saveAsTable(f"{database}.{tbl}"))

-    if refdata_loader == DuckDBRefDataLoader:
+    if isinstance(pipeline, DDBDVEPipeline):
         ref_db_file = Path(ctxt.get_processing_location(context), f"{database}.duckdb").as_posix()
-        refdata_loader.connection.sql(f"ATTACH '{ref_db_file}' AS {database}")
+        pipeline._connection.sql(f"ATTACH '{ref_db_file}' AS {database}")
         for tbl, source in refdata_tables.items():
-            refdata_loader.connection.read_parquet(source).to_table(f"{database}.{tbl}")
+            pipeline._connection.read_parquet(source).to_table(f"{database}.{tbl}")
diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py
index 7ae4858..ff73f85 100644
--- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py
+++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_ddb_refdata.py
@@ -19,110 +19,118 @@ def temp_working_dir():
         shutil.copytree(refdata_path.as_posix(), tmp, dirs_exist_ok=True)
         yield tmp

-@pytest.fixture(scope="function")
-def ddb_refdata_loader(temp_working_dir, temp_ddb_conn):
-    _, conn = temp_ddb_conn
-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = temp_working_dir
-    yield DuckDBRefDataLoader, temp_working_dir

 @pytest.fixture(scope="function")
-def ddb_refdata_table(ddb_refdata_loader):
-    refdata_loader, _ = ddb_refdata_loader
+def ddb_refdata_table(temp_ddb_conn):
+    _, conn = temp_ddb_conn
     schema = "dve_" + uuid4().hex
     tbl = "movies_sequels"
-    refdata_loader.connection.sql(f"CREATE SCHEMA IF NOT EXISTS {schema}")
-    refdata_loader.connection.read_parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).to_table(f"{schema}.{tbl}")
+    conn.sql(f"CREATE SCHEMA IF NOT EXISTS {schema}")
+    conn.read_parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).to_table(f"{schema}.{tbl}")
     yield schema, tbl
-    refdata_loader.connection.sql(f"DROP TABLE IF EXISTS {schema}.{tbl}")
-    refdata_loader.connection.sql(f"DROP SCHEMA IF EXISTS {schema}")
+    conn.sql(f"DROP TABLE IF EXISTS {schema}.{tbl}")
+    conn.sql(f"DROP SCHEMA IF EXISTS {schema}")

-def test_load_arrow_file(ddb_refdata_loader):
-    refdata_loader, _ = ddb_refdata_loader
+def test_load_arrow_file(temp_working_dir, temp_ddb_conn):
+    _, conn = temp_ddb_conn
     config = {
         "test_refdata": ReferenceFile(type="filename", filename="./movies_sequels.arrow")
     }
-    duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config)
+    duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn,
+                                                                     reference_data_config=config,
+                                                                     dataset_config_uri=temp_working_dir)
     test = duckdb_refdata_loader.load_file(config.get("test_refdata"))
     assert test.shape == (3, 3)

-def test_load_parquet_file(ddb_refdata_loader):
-    refdata_loader, _ = ddb_refdata_loader
+def test_load_parquet_file(temp_working_dir, temp_ddb_conn):
+    _, conn = temp_ddb_conn
     config = {
         "test_refdata": ReferenceFile(type="filename", filename="./movies_sequels.parquet")
     }
-    duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config)
+    duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn,
+                                                                     reference_data_config=config,
+                                                                     dataset_config_uri=temp_working_dir)
     test = duckdb_refdata_loader.load_file(config.get("test_refdata"))
     assert test.shape == (2, 3)

-def test_load_uri_parquet(ddb_refdata_loader):
-    refdata_dir: Path
-    refdata_loader, refdata_dir = ddb_refdata_loader
+def test_load_uri_parquet(temp_working_dir, temp_ddb_conn):
+    _, conn = temp_ddb_conn
     config = {
         "test_refdata": ReferenceURI(type="uri",
-                                     uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix())
+                                     uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix())
     }
-    duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config)
+    duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn,
+                                                                     reference_data_config=config,
+                                                                     dataset_config_uri=temp_working_dir)
     test = duckdb_refdata_loader.load_uri(config.get("test_refdata"))
     assert test.shape == (2, 3)

-def test_load_uri_arrow(ddb_refdata_loader):
-    refdata_loader, refdata_dir = ddb_refdata_loader
+def test_load_uri_arrow(temp_working_dir, temp_ddb_conn):
+    _, conn = temp_ddb_conn
     config = {
         "test_refdata": ReferenceURI(type="uri",
-                                     uri=Path(refdata_dir).joinpath("movies_sequels.arrow").as_posix())
+                                     uri=Path(temp_working_dir).joinpath("movies_sequels.arrow").as_posix())
     }
-    duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config)
+    duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn,
+                                                                     reference_data_config=config,
+                                                                     dataset_config_uri=temp_working_dir)
     test = duckdb_refdata_loader.load_uri(config.get("test_refdata"))
     assert test.shape == (3, 3)

-def test_table_read(ddb_refdata_loader, ddb_refdata_table):
-    refdata_loader, _ = ddb_refdata_loader
+def test_table_read(temp_working_dir, temp_ddb_conn, ddb_refdata_table):
+    _, conn = temp_ddb_conn
     db, tbl = ddb_refdata_table
     config = {
         "test_refdata": ReferenceTable(type="table", table_name=tbl, database=db)
     }
-    duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config)
+    duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn,
+                                                                     reference_data_config=config,
+                                                                     dataset_config_uri=temp_working_dir)
     test = duckdb_refdata_loader.load_table(config.get("test_refdata"))
     assert test.shape == (2, 3)

-def test_via_entity_manager(ddb_refdata_loader, ddb_refdata_table):
-    refdata_loader, refdata_dir = ddb_refdata_loader
+def test_via_entity_manager(temp_working_dir, temp_ddb_conn, ddb_refdata_table):
+    _, conn = temp_ddb_conn
     db, tbl = ddb_refdata_table
     config = {
         "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.arrow"),
         "test_refdata_uri": ReferenceURI(type="uri",
-                                         uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()),
+                                         uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix()),
         "test_refdata_table": ReferenceTable(type="table",
                                              table_name=tbl, database=db)
     }
-    em = EntityManager({}, reference_data=refdata_loader(config))
+    refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn,
+                                                              reference_data_config=config,
+                                                              dataset_config_uri=temp_working_dir)
+    em = EntityManager({}, reference_data=refdata_loader)
     assert em.get("refdata_test_refdata_file").shape == (3, 3)
     assert em.get("refdata_test_refdata_uri").shape == (2, 3)
     assert em.get("refdata_test_refdata_table").shape == (2, 3)

-def test_refdata_error(ddb_refdata_loader):
-    refdata_loader, refdata_dir = ddb_refdata_loader
+def test_refdata_error(temp_working_dir, temp_ddb_conn):
+    _, conn = temp_ddb_conn
     config = {
         "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.arrow")
     }
-    duckdb_refdata_loader: DuckDBRefDataLoader = refdata_loader(config)
+    duckdb_refdata_loader: DuckDBRefDataLoader = DuckDBRefDataLoader(connection=conn,
+                                                                     reference_data_config=config,
+                                                                     dataset_config_uri=temp_working_dir)
     with pytest.raises(MissingRefDataEntity):
         duckdb_refdata_loader["missing_refdata"]
\ No newline at end of file
diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py
index b50b9bb..8c60619 100644
--- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py
+++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_refdata.py
@@ -2,7 +2,7 @@
 import shutil

 import pytest

-from dve.core_engine.backends.exceptions import MissingRefDataEntity, RefdataLacksFileExtensionSupport
+from dve.core_engine.backends.exceptions import MissingRefDataEntity
 from dve.core_engine.backends.implementations.spark.reference_data import SparkRefDataLoader
 from dve.core_engine.backends.base.core import EntityManager
 from dve.core_engine.backends.base.reference_data import ReferenceFile, ReferenceTable, ReferenceURI
@@ -19,83 +19,84 @@ def temp_working_dir():
     yield tmp

 @pytest.fixture(scope="function")
-def spark_refdata_loader(spark, temp_working_dir):
-    SparkRefDataLoader.spark = spark
-    SparkRefDataLoader.dataset_config_uri = temp_working_dir
-    yield SparkRefDataLoader, temp_working_dir
-
-@pytest.fixture(scope="function")
-def spark_refdata_table(spark_refdata_loader, spark_test_database):
-    refdata_loader, _ = spark_refdata_loader
+def spark_refdata_table(spark, spark_test_database):
     tbl = "movies_sequels"
-    refdata_loader.spark.read.parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).write.saveAsTable(f"{spark_test_database}.{tbl}")
+    spark.read.parquet(get_test_file_path("movies/refdata/movies_sequels.parquet").as_posix()).write.saveAsTable(f"{spark_test_database}.{tbl}")
     yield spark_test_database, tbl
-    refdata_loader.spark.sql(f"DROP TABLE IF EXISTS {spark_test_database}.{tbl}")
+    spark.sql(f"DROP TABLE IF EXISTS {spark_test_database}.{tbl}")

-def test_load_parquet_file(spark_refdata_loader):
-    refdata_loader, _ = spark_refdata_loader
+def test_load_parquet_file(spark, temp_working_dir):
     config = {
         "test_refdata": ReferenceFile(type="filename", filename="./movies_sequels.parquet")
     }
-    spk_refdata_loader: SparkRefDataLoader = refdata_loader(config)
+    spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark,
+                                                                reference_data_config=config,
+                                                                dataset_config_uri=temp_working_dir)
     test = spk_refdata_loader.load_file(config.get("test_refdata"))
     assert test.count() == 2
-def test_load_uri_parquet(spark_refdata_loader):
-    refdata_dir: Path
-    refdata_loader, refdata_dir = spark_refdata_loader
+def test_load_uri_parquet(spark, temp_working_dir):
     config = {
         "test_refdata": ReferenceURI(type="uri",
-                                     uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix())
+                                     uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix())
     }
-    spk_refdata_loader: SparkRefDataLoader = refdata_loader(config)
+    spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark,
+                                                                reference_data_config=config,
+                                                                dataset_config_uri=temp_working_dir)
     test = spk_refdata_loader.load_uri(config.get("test_refdata"))
     assert test.count() == 2

-def test_table_read(spark_refdata_loader, spark_refdata_table):
-    refdata_loader, _ = spark_refdata_loader
+def test_table_read(spark, temp_working_dir, spark_refdata_table):
     db, tbl = spark_refdata_table
     config = {
         "test_refdata": ReferenceTable(type="table", table_name=tbl, database=db)
     }
-    spk_refdata_loader: SparkRefDataLoader = refdata_loader(config)
+    spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark,
+                                                                reference_data_config=config,
+                                                                dataset_config_uri=temp_working_dir)
     test = spk_refdata_loader.load_table(config.get("test_refdata"))
     assert test.count() == 2

-def test_via_entity_manager(spark_refdata_loader, spark_refdata_table):
-    refdata_loader, refdata_dir = spark_refdata_loader
+def test_via_entity_manager(spark, temp_working_dir, spark_refdata_table):
     db, tbl = spark_refdata_table
     config = {
         "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.parquet"),
         "test_refdata_uri": ReferenceURI(type="uri",
-                                         uri=Path(refdata_dir).joinpath("movies_sequels.parquet").as_posix()),
+                                         uri=Path(temp_working_dir).joinpath("movies_sequels.parquet").as_posix()),
         "test_refdata_table": ReferenceTable(type="table", table_name=tbl, database=db)
     }
-    em = EntityManager({}, reference_data=refdata_loader(config))
+
+    spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark,
+                                                                reference_data_config=config,
+                                                                dataset_config_uri=temp_working_dir)
+    em = EntityManager({}, reference_data=spk_refdata_loader)
     assert em.get("refdata_test_refdata_file").count() == 2
     assert em.get("refdata_test_refdata_uri").count() == 2
     assert em.get("refdata_test_refdata_table").count() == 2

-def test_refdata_error(spark_refdata_loader):
-    refdata_loader, _ = spark_refdata_loader
+def test_refdata_error(spark, temp_working_dir):
     config = {
         "test_refdata_file": ReferenceFile(type="filename", filename="./movies_sequels.arrow")
     }
-    em = EntityManager({}, reference_data=refdata_loader(config))
+
+    spk_refdata_loader: SparkRefDataLoader = SparkRefDataLoader(spark=spark,
+                                                                reference_data_config=config,
+                                                                dataset_config_uri=temp_working_dir)
+    em = EntityManager({}, reference_data=spk_refdata_loader)
     with pytest.raises(MissingRefDataEntity):
         em["refdata_missing"]

     em["refdata_test_refdata_file"]
diff --git a/tests/test_core_engine/test_engine.py b/tests/test_core_engine/test_engine.py
index ef23d71..7e0fd6e 100644
--- a/tests/test_core_engine/test_engine.py
+++ b/tests/test_core_engine/test_engine.py
@@ -29,8 +29,7 @@ def test_dummy_planet_run(self, spark: SparkSession, temp_dir: str):
             dataset_config_path=config_path.as_posix(),
             output_prefix=Path(temp_dir),
             backend=SparkBackend(dataset_config_uri=config_path.parent.as_posix(),
-                                 spark_session=spark,
-                                 reference_data_loader=refdata_loader)
+                                 spark_session=spark)
         )

         with test_instance:
diff --git a/tests/test_pipeline/test_duckdb_pipeline.py b/tests/test_pipeline/test_duckdb_pipeline.py
index 29e0734..aa65516 100644
--- a/tests/test_pipeline/test_duckdb_pipeline.py
+++ b/tests/test_pipeline/test_duckdb_pipeline.py
@@ -148,9 +148,6 @@ def test_business_rule_step(
     db_file, conn = temp_ddb_conn
     sub_info, processed_files_path = planets_data_after_data_contract

-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)
-
     with DDBAuditingManager(db_file.as_uri(), ThreadPoolExecutor(1), conn) as audit_manager:
         dve_pipeline = DDBDVEPipeline(
             processed_files_path=processed_files_path,
@@ -159,7 +156,6 @@
             connection=conn,
             rules_path=PLANETS_RULES_PATH,
             submitted_files_path=None,
-            reference_data_loader=DuckDBRefDataLoader,
         )

         audit_manager.add_new_submissions([sub_info], job_run_id=1)
@@ -187,9 +183,6 @@ def test_error_report_step(
     db_file, conn = temp_ddb_conn
     submitted_file_info, processed_files_path, status = planets_data_after_business_rules

-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)
-
     with DDBAuditingManager(db_file.as_uri(), ThreadPoolExecutor(1), conn) as audit_manager:
         dve_pipeline = DDBDVEPipeline(
             processed_files_path=processed_files_path,
@@ -198,7 +191,6 @@
             connection=conn,
             rules_path=None,
             submitted_files_path=None,
-            reference_data_loader=DuckDBRefDataLoader,
         )

         reports = dve_pipeline.error_report_step(
@@ -222,7 +214,6 @@ def test_get_submission_status(temp_ddb_conn):
         connection=conn,
         rules_path=None,
         submitted_files_path=None,
-        reference_data_loader=DuckDBRefDataLoader,
     )
     dve_pipeline._logger = Mock(spec=logging.Logger)
     # add four submissions
diff --git a/tests/test_pipeline/test_foundry_ddb_pipeline.py b/tests/test_pipeline/test_foundry_ddb_pipeline.py
index 350b990..666bd90 100644
--- a/tests/test_pipeline/test_foundry_ddb_pipeline.py
+++ b/tests/test_pipeline/test_foundry_ddb_pipeline.py
@@ -34,10 +34,6 @@ def test_foundry_runner_validation_fail(planet_test_files, temp_ddb_conn):
     shutil.copytree(planet_test_files, sub_folder)

-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)
-
-
     with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager:
         dve_pipeline = FoundryDDBPipeline(
             processed_files_path=processing_folder,
@@ -45,7 +41,6 @@
             connection=conn,
             rules_path=get_test_file_path("planets/planets_ddb.dischema.json").as_posix(),
             submitted_files_path=None,
-            reference_data_loader=DuckDBRefDataLoader,
         )
         output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info)
         assert fh.get_resource_exists(report_uri)
@@ -69,11 +64,7 @@ def test_foundry_runner_validation_success(movies_test_files, temp_ddb_conn):
                               datetime_received=datetime(2025,11,5))
     sub_folder = processing_folder + f"/{sub_id}"

-    shutil.copytree(movies_test_files, sub_folder)
-
-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = None
-
+    shutil.copytree(movies_test_files, sub_folder)

     with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager:
         dve_pipeline = FoundryDDBPipeline(
@@ -82,7 +73,6 @@
             connection=conn,
             rules_path=get_test_file_path("movies/movies_ddb.dischema.json").as_posix(),
             submitted_files_path=None,
-            reference_data_loader=DuckDBRefDataLoader,
         )
         output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info)
         assert fh.get_resource_exists(report_uri)
@@ -100,10 +90,6 @@ def test_foundry_runner_error(planet_test_files, temp_ddb_conn):
     shutil.copytree(planet_test_files, sub_folder)

-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)
-
-
     with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager:
         dve_pipeline = FoundryDDBPipeline(
             processed_files_path=processing_folder,
@@ -111,7 +97,6 @@
             connection=conn,
             rules_path=get_test_file_path("planets/planets.dischema.json").as_posix(),
             submitted_files_path=None,
-            reference_data_loader=DuckDBRefDataLoader,
         )
         output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info)
         assert not fh.get_resource_exists(report_uri)
@@ -174,9 +159,6 @@ def test_foundry_runner_with_submitted_files_path(movies_test_files, temp_ddb_co
                               datetime_received=datetime(2025,11,5)
     )

-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = None
-
     with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager:
         dve_pipeline = FoundryDDBPipeline(
             processed_files_path=processing_folder,
@@ -184,7 +166,6 @@
             connection=conn,
             rules_path=get_test_file_path("movies/movies_ddb.dischema.json").as_posix(),
             submitted_files_path=submitted_files_path,
-            reference_data_loader=DuckDBRefDataLoader,
         )

         output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info)
@@ -209,9 +190,6 @@ def test_foundry_runner_error_at_bi_rules(movies_test_files, temp_ddb_conn):
                               datetime_received=datetime(2025,11,5)
     )

-    DuckDBRefDataLoader.connection = conn
-    DuckDBRefDataLoader.dataset_config_uri = None
-
     with DDBAuditingManager(db_file.as_uri(), None, conn) as audit_manager:
         dve_pipeline = FoundryDDBPipeline(
             processed_files_path=processing_folder,
@@ -219,7 +197,6 @@
             connection=conn,
             rules_path=get_test_file_path("movies/movies_ddb.dischema.json").as_posix(),
             submitted_files_path=submitted_files_path,
-            reference_data_loader=DuckDBRefDataLoader,
         )

         output_loc, report_uri, audit_files = dve_pipeline.run_pipeline(sub_info)
diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py
index 38418d6..a8f59c7 100644
--- a/tests/test_pipeline/test_pipeline.py
+++ b/tests/test_pipeline/test_pipeline.py
@@ -25,7 +25,6 @@ def test_get_submission_files_for_run(planet_test_files):  # pylint: disable=red
         rules_path=None,
         processed_files_path=planet_test_files,
         submitted_files_path=planet_test_files,
-        reference_data_loader=None,
     )

     result = list(dve_pipeline._get_submission_files_for_run())
@@ -42,7 +41,6 @@ def test_write_file_to_parquet(planet_test_files):  # pylint: disable=redefined-
         rules_path=PLANETS_RULES_PATH,
         processed_files_path=planet_test_files,
         submitted_files_path=planet_test_files,
-        reference_data_loader=None,
     )

     sub_id = uuid4().hex
@@ -80,7 +78,6 @@ def test_file_transformation(planet_test_files):  # pylint: disable=redefined-ou
         rules_path=PLANETS_RULES_PATH,
         processed_files_path=tdir,
         submitted_files_path=planet_test_files,
-        reference_data_loader=None,
     )

     sub_id = uuid4().hex
diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py
index 262d84f..b3048a1 100644
--- a/tests/test_pipeline/test_spark_pipeline.py
+++ b/tests/test_pipeline/test_spark_pipeline.py
@@ -49,7 +49,6 @@ def test_audit_received_step(planet_test_files, spark, spark_test_database):
         job_run_id=1,
         rules_path=None,
         submitted_files_path=planet_test_files,
-        reference_data_loader=None,
     )

     sub_ids: Dict[str, SubmissionInfo] = {}
@@ -91,7 +90,6 @@ def test_file_transformation_step(
         job_run_id=1,
         rules_path=PLANETS_RULES_PATH,
         submitted_files_path=planet_test_files,
-        reference_data_loader=None,
         spark=spark,
     )
     sub_id = uuid4().hex
@@ -129,7 +127,6 @@ def test_apply_data_contract_success(
         job_run_id=1,
         rules_path=PLANETS_RULES_PATH,
         submitted_files_path=None,
-        reference_data_loader=None,
         spark=spark,
     )
     sub_status = SubmissionStatus()
@@ -150,7 +147,6 @@ def test_apply_data_contract_failed(  # pylint: disable=redefined-outer-name
         job_run_id=1,
         rules_path=PLANETS_RULES_PATH,
         submitted_files_path=None,
-        reference_data_loader=None,
         spark=spark,
     )
     sub_status = SubmissionStatus()
@@ -228,7 +224,6 @@ def test_data_contract_step(
         job_run_id=1,
         rules_path=PLANETS_RULES_PATH,
         submitted_files_path=None,
-        reference_data_loader=None,
     )

     success, failed = dve_pipeline.data_contract_step(
@@ -252,9 +247,6 @@ def test_apply_business_rules_success(
 ):  # pylint: disable=redefined-outer-name
     sub_info, processed_file_path = planets_data_after_data_contract

-    SparkRefDataLoader.spark = spark
-    SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)
-
     with SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark) as audit_manager:
         dve_pipeline = SparkDVEPipeline(
             processed_files_path=processed_file_path,
@@ -262,7 +254,6 @@
             job_run_id=1,
             rules_path=PLANETS_RULES_PATH,
             submitted_files_path=None,
-            reference_data_loader=SparkRefDataLoader,
             spark=spark,
         )

@@ -296,9 +287,6 @@ def test_apply_business_rules_with_data_errors(  # pylint: disable=redefined-out
     spark_test_database,
 ):
     sub_info, processed_file_path = planets_data_after_data_contract_that_break_business_rules
-
-    SparkRefDataLoader.spark = spark
-    SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)

     with SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark) as audit_manager:
         dve_pipeline = SparkDVEPipeline(
@@ -307,7 +295,6 @@
             job_run_id=1,
             rules_path=PLANETS_RULES_PATH,
             submitted_files_path=None,
-            reference_data_loader=SparkRefDataLoader,
             spark=spark,
         )

@@ -380,9 +367,6 @@ def test_business_rule_step(
 ):  # pylint: disable=redefined-outer-name
     sub_info, processed_files_path = planets_data_after_data_contract

-    SparkRefDataLoader.spark = spark
-    SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)
-
     with SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark) as audit_manager:
         dve_pipeline = SparkDVEPipeline(
             processed_files_path=processed_files_path,
@@ -390,7 +374,6 @@
             job_run_id=1,
             rules_path=PLANETS_RULES_PATH,
             submitted_files_path=None,
-            reference_data_loader=SparkRefDataLoader,
             spark=spark,
         )
         audit_manager.add_new_submissions([sub_info], job_run_id=1)
@@ -416,15 +399,12 @@ def test_error_report_where_report_is_expected(  # pylint: disable=redefined-out
 ):
     sub_info, processed_file_path = error_data_after_business_rules

-    SparkRefDataLoader.spark = spark
-
     dve_pipeline = SparkDVEPipeline(
         processed_files_path=processed_file_path,
         audit_tables=None,
         job_run_id=1,
         rules_path=PLANETS_RULES_PATH,
         submitted_files_path=None,
-        reference_data_loader=SparkRefDataLoader,
         spark=spark,
     )

@@ -545,7 +525,6 @@ def test_error_report_step(
         job_run_id=1,
         rules_path=None,
         submitted_files_path=None,
-        reference_data_loader=None,
         spark=spark,
     )

@@ -564,8 +543,7 @@ def test_error_report_step(
 def test_cluster_pipeline_run(
     spark: SparkSession, planet_test_files: str, spark_test_database
 ):  # pylint: disable=redefined-outer-name
-    SparkRefDataLoader.spark = spark
-    SparkRefDataLoader.dataset_config_uri = fh.get_parent(PLANETS_RULES_PATH)
+
     audit_manager = SparkAuditingManager(spark_test_database, ThreadPoolExecutor(1), spark)

     dve_pipeline = SparkDVEPipeline(
@@ -574,7 +552,6 @@
         job_run_id=1,
         rules_path=PLANETS_RULES_PATH,
         submitted_files_path=planet_test_files,
-        reference_data_loader=SparkRefDataLoader,
         spark=spark,
     )

@@ -595,7 +572,6 @@ def test_get_submission_status(spark, spark_test_database):
         job_run_id=1,
         rules_path=None,
         submitted_files_path=None,
-        reference_data_loader=None,
         spark=spark,
     )
     dve_pipeline._logger = Mock(spec=logging.Logger)
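A side effect worth noting in the test updates: loaders are now configured per instance instead of through mutable class attributes, so state no longer leaks between tests. In miniature (fixture names are illustrative):

    # before: class-level configuration shared by every test in the session
    SparkRefDataLoader.spark = spark
    SparkRefDataLoader.dataset_config_uri = temp_working_dir

    # after: explicit per-instance construction
    loader = SparkRefDataLoader(
        spark=spark,
        reference_data_config=config,
        dataset_config_uri=temp_working_dir,
    )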