66from pyspark .sql import DataFrame , SparkSession
77
88from dve .core_engine .backends .base .backend import BaseBackend
9+ from dve .core_engine .backends .base .reference_data import ReferenceConfigUnion
910from dve .core_engine .backends .implementations .spark .contract import SparkDataContract
1011from dve .core_engine .backends .implementations .spark .reference_data import SparkRefDataLoader
1112from dve .core_engine .backends .implementations .spark .rules import SparkStepImplementations
1415from dve .core_engine .constants import RECORD_INDEX_COLUMN_NAME
1516from dve .core_engine .loggers import get_child_logger , get_logger
1617from dve .core_engine .models import SubmissionInfo
17- from dve .core_engine .type_hints import URI , EntityParquetLocations
18+ from dve .core_engine .type_hints import URI , EntityName , EntityParquetLocations
1819from dve .parser .file_handling import get_resource_exists , joinuri
1920
2021
@@ -26,7 +27,6 @@ def __init__(
2627 dataset_config_uri : Optional [URI ] = None ,
2728 contract : Optional [SparkDataContract ] = None ,
2829 steps : Optional [SparkStepImplementations ] = None ,
29- reference_data_loader : Optional [type [SparkRefDataLoader ]] = None ,
3030 logger : Optional [logging .Logger ] = None ,
3131 spark_session : Optional [SparkSession ] = None ,
3232 ** kwargs : Any ,
@@ -36,6 +36,8 @@ def __init__(
3636
3737 self .spark_session = spark_session or SparkSession .builder .getOrCreate ()
3838 """The Spark session for the backend."""
39+ self .dataset_config_uri = dataset_config_uri
40+ """The uri of the dischema specifying the DVE config"""
3941
4042 if contract is None :
4143 contract = SparkDataContract (
@@ -46,11 +48,27 @@ def __init__(
4648 steps = SparkStepImplementations .register_udfs (
4749 logger = get_child_logger ("SparkStepImplementations" , logger )
4850 )
49- if reference_data_loader is None :
50- reference_data_loader = SparkRefDataLoader
51- reference_data_loader .spark = self .spark_session
52- reference_data_loader .dataset_config_uri = dataset_config_uri
53- super ().__init__ (contract , steps , reference_data_loader , logger , ** kwargs )
51+ super ().__init__ (contract , steps , logger , ** kwargs )
52+
53+ def load_reference_data (
54+ self ,
55+ reference_entity_config : dict [EntityName , ReferenceConfigUnion ],
56+ submission_info : Optional [SubmissionInfo ],
57+ ):
58+ """Load the reference data as specified in the reference entity config."""
59+ sub_info_entity : Optional [DataFrame ] = None
60+ if submission_info :
61+ sub_info_entity = self .convert_submission_info (submission_info )
62+
63+ reference_data_loader = SparkRefDataLoader (
64+ spark = self .spark_session ,
65+ reference_data_config = reference_entity_config ,
66+ dataset_config_uri = self .dataset_config_uri , # type: ignore
67+ )
68+ if sub_info_entity is not None :
69+ reference_data_loader .entity_cache ["dve_submission_info" ] = sub_info_entity
70+
71+ return reference_data_loader
5472
5573 def write_entities_to_parquet (
5674 self , entities : SparkEntities , cache_prefix : URI
0 commit comments