77from pydantic import BaseModel , Field
88from typing_extensions import Annotated , Literal
99
10+ import dve .parser .file_handling as fh
1011from dve .core_engine .backends .base .core import get_entity_type
11- from dve .core_engine .backends .exceptions import MissingRefDataEntity , RefdataLacksFileExtensionSupport
12+ from dve .core_engine .backends .exceptions import (
13+ MissingRefDataEntity ,
14+ RefdataLacksFileExtensionSupport ,
15+ )
1216from dve .core_engine .backends .types import EntityType
1317from dve .core_engine .type_hints import URI , EntityName
14- import dve .parser .file_handling as fh
1518from dve .parser .file_handling .implementations .file import LocalFilesystemImplementation
1619from dve .parser .file_handling .service import _get_implementation
1720
1821_FILE_EXTENSION_NAME : str = "_REFDATA_FILE_EXTENSION"
1922"""Name of attribute added to methods where they relate
2023 to loading a particular reference file type."""
2124
25+
2226def mark_refdata_file_extension (file_extension ):
2327 """Mark a method for loading a particular file extension"""
28+
2429 def wrapper (func : Callable ):
2530 setattr (func , _FILE_EXTENSION_NAME , file_extension )
2631 return func
32+
2733 return wrapper
2834
2935
@@ -52,9 +58,11 @@ class ReferenceFile(BaseModel, frozen=True):
5258 """The object type."""
5359 filename : str
5460 """The path to the reference data relative to the contract."""
61+
5562 @property
5663 def file_extension (self ) -> str :
57- return fh .get_file_suffix (self .filename )
64+ """The file extension of the reference file"""
65+ return fh .get_file_suffix (self .filename ) # type: ignore
5866
5967
6068class ReferenceURI (BaseModel , frozen = True ):
@@ -64,9 +72,11 @@ class ReferenceURI(BaseModel, frozen=True):
6472 """The object type."""
6573 uri : str
6674 """The absolute URI of the reference data (as Parquet)."""
75+
6776 @property
6877 def file_extension (self ) -> str :
69- return fh .get_file_suffix (self .uri )
78+ """The file extension of the reference uri"""
79+ return fh .get_file_suffix (self .uri ) # type: ignore
7080
7181
7282ReferenceConfig = Union [ReferenceFile , ReferenceTable , ReferenceURI ]
@@ -91,7 +101,7 @@ class BaseRefDataLoader(Generic[EntityType], Mapping[EntityName, EntityType], AB
91101 A mapping between refdata config types and functions to call to load these configs
92102 into reference data entities
93103 """
94-
104+
95105 __reader_functions__ : ClassVar [dict [str , Callable ]] = {}
96106 """
97107 A mapping between file extensions and functions to load the file uris
@@ -108,6 +118,9 @@ class variable for the subclass.
108118 if cls is not BaseRefDataLoader :
109119 cls .__entity_type__ = get_entity_type (cls , "BaseRefDataLoader" )
110120
121+ # ensure that dicts are specific to each subclass - redefine rather
122+ # than keep the same reference
123+ cls .__reader_functions__ = {}
111124 cls .__step_functions__ = {}
112125
113126 for method_name in dir (cls ):
@@ -117,7 +130,7 @@ class variable for the subclass.
117130 method = getattr (cls , method_name , None )
118131 if method is None or not callable (method ):
119132 continue
120-
133+
121134 if ext := getattr (method , _FILE_EXTENSION_NAME , None ):
122135 cls .__reader_functions__ [ext ] = method
123136 continue
@@ -136,7 +149,7 @@ def __init__(
136149 self ,
137150 reference_entity_config : dict [EntityName , ReferenceConfig ],
138151 dataset_config_uri : Optional [URI ] = None ,
139- ** kwargs
152+ ** kwargs ,
140153 ) -> None :
141154 self .reference_entity_config = reference_entity_config
142155 self .dataset_config_uri = dataset_config_uri
@@ -164,8 +177,8 @@ def load_file(self, config: ReferenceFile) -> EntityType:
164177 try :
165178 impl = self .__reader_functions__ [config .file_extension ]
166179 return impl (self , target_location )
167- except KeyError :
168- raise RefdataLacksFileExtensionSupport (file_extension = config .file_extension )
180+ except KeyError as exc :
181+ raise RefdataLacksFileExtensionSupport (file_extension = config .file_extension ) from exc
169182
170183 def load_uri (self , config : ReferenceURI ) -> EntityType :
171184 "Load reference entity from an absolute URI"
@@ -176,8 +189,8 @@ def load_uri(self, config: ReferenceURI) -> EntityType:
176189 try :
177190 impl = self .__reader_functions__ [config .file_extension ]
178191 return impl (self , target_location )
179- except KeyError :
180- raise RefdataLacksFileExtensionSupport (file_extension = config .file_extension )
192+ except KeyError as exc :
193+ raise RefdataLacksFileExtensionSupport (file_extension = config .file_extension ) from exc
181194
182195 def load_entity (self , entity_name : EntityName , config : ReferenceConfig ) -> EntityType :
183196 """Load a reference entity given the reference config"""
0 commit comments