66from pyspark .sql import DataFrame , SparkSession
77
88from dve .core_engine .backends .base .backend import BaseBackend
9+ from dve .core_engine .backends .base .reference_data import ReferenceConfigUnion
910from dve .core_engine .backends .implementations .spark .contract import SparkDataContract
1011from dve .core_engine .backends .implementations .spark .reference_data import SparkRefDataLoader
1112from dve .core_engine .backends .implementations .spark .rules import SparkStepImplementations
1213from dve .core_engine .backends .implementations .spark .spark_helpers import get_type_from_annotation
1314from dve .core_engine .backends .implementations .spark .types import SparkEntities
15+ from dve .core_engine .backends .types import EntityType
1416from dve .core_engine .constants import RECORD_INDEX_COLUMN_NAME
1517from dve .core_engine .loggers import get_child_logger , get_logger
1618from dve .core_engine .models import SubmissionInfo
17- from dve .core_engine .type_hints import URI , EntityParquetLocations
18- from dve .parser .file_handling import get_resource_exists , joinuri
19+ from dve .core_engine .type_hints import URI , EntityName , EntityParquetLocations
20+ from dve .parser .file_handling import get_resource_exists , joinuri , get_parent
1921
2022
2123class SparkBackend (BaseBackend [DataFrame ]):
@@ -26,7 +28,6 @@ def __init__(
2628 dataset_config_uri : Optional [URI ] = None ,
2729 contract : Optional [SparkDataContract ] = None ,
2830 steps : Optional [SparkStepImplementations ] = None ,
29- reference_data_loader : Optional [type [SparkRefDataLoader ]] = None ,
3031 logger : Optional [logging .Logger ] = None ,
3132 spark_session : Optional [SparkSession ] = None ,
3233 ** kwargs : Any ,
@@ -36,6 +37,8 @@ def __init__(
3637
3738 self .spark_session = spark_session or SparkSession .builder .getOrCreate ()
3839 """The Spark session for the backend."""
40+ self .dataset_config_uri = dataset_config_uri
41+ """The uri of the dischema specifying the DVE config"""
3942
4043 if contract is None :
4144 contract = SparkDataContract (
@@ -46,11 +49,23 @@ def __init__(
4649 steps = SparkStepImplementations .register_udfs (
4750 logger = get_child_logger ("SparkStepImplementations" , logger )
4851 )
49- if reference_data_loader is None :
50- reference_data_loader = SparkRefDataLoader
51- reference_data_loader .spark = self .spark_session
52- reference_data_loader .dataset_config_uri = dataset_config_uri
53- super ().__init__ (contract , steps , reference_data_loader , logger , ** kwargs )
52+ super ().__init__ (contract , steps , logger , ** kwargs )
53+
def load_reference_data(
    self,
    reference_entity_config: dict[EntityName, ReferenceConfigUnion],
    submission_info: Optional[SubmissionInfo],
):
    """Create a Spark reference data loader for the given reference entity config.

    Args:
        reference_entity_config: Mapping of entity name to its reference
            data configuration; handed to the loader unchanged.
        submission_info: Optional submission details. When truthy, it is
            converted with ``self.convert_submission_info`` and, if that
            yields a value, stored in the loader's entity cache under the
            ``"dve_submission_info"`` key.

    Returns:
        A ``SparkRefDataLoader`` built with this backend's Spark session
        and dataset config URI.
    """
    # Convert the submission info first (before the loader is built) so the
    # order of any work done by the conversion and the loader is unchanged.
    submission_entity: Optional[EntityType] = (
        self.convert_submission_info(submission_info) if submission_info else None
    )

    loader = SparkRefDataLoader(
        spark=self.spark_session,
        reference_data_config=reference_entity_config,
        dataset_config_uri=self.dataset_config_uri,
    )

    # Only seed the cache when a submission info entity was actually produced.
    if submission_entity is not None:
        loader.entity_cache["dve_submission_info"] = submission_entity

    return loader
5469
5570 def write_entities_to_parquet (
5671 self , entities : SparkEntities , cache_prefix : URI
0 commit comments