@@ -51,7 +51,23 @@ def test_intersection(self):
5151 mix = mix_cls ()
5252 assert mix .get_source () == ["b" , "c" ]
5353
54- def test_file_select (self ):
54+ def test_no_select (self ):
55+ SourceA = get_simple_source ("SourceA" , ["a" , "b" , "c" ])
56+ SourceB = get_simple_source ("SourceB" , ["b" , "c" , "d" , "e" ])
57+ mix_cls = SourceIntersection .create ([SourceA , SourceB ])
58+ mix = mix_cls ()
59+ with pytest .raises (ValueError ):
60+ mix .select ()
61+
62+ def select (mod , * args , ** kwargs ):
63+ return "SourceA"
64+
65+ mix .set_select (select )
66+ assert isinstance (mix .select (), SourceA )
67+
68+ def test_select_unbound (self ):
69+ """Select function is defined outside of interface."""
70+
5571 class SourceA (SimpleSource ):
5672 source_loc = "a"
5773
@@ -81,6 +97,42 @@ class DataInterfaceMix(DataInterface):
8197
8298 assert di .source .apply_select ("get_filename" ) == "file_a_0"
8399
100+ def get_interface (self ):
101+ class DataInterfaceMix (DataInterface ):
102+ Parameters = ParametersDict
103+
104+ class SourceA (SimpleSource ):
105+ source_loc = "a"
106+
107+ def get_filename (self , ** fixes ):
108+ param = fixes .get ("param" , self .parameters ["param" ])
109+ return f"file_a_{ param } "
110+
111+ class SourceB (SimpleSource ):
112+ source_loc = "b"
113+
114+ def get_filename (self , ** fixes ):
115+ param = fixes .get ("param" , self .parameters ["param" ])
116+ return f"file_b_{ param } "
117+
118+ @staticmethod
119+ def select (module , ** kwargs ):
120+ return kwargs .get ("selected" , module .parameters ["selected" ])
121+
122+ Source = SourceUnion .create ([SourceA , SourceB ], select_func = select )
123+
124+ return DataInterfaceMix
125+
126+ def test_select_bound (self ):
127+ """Select is defined as static method."""
128+ di = self .get_interface ()(param = 0 , selected = "SourceA" )
129+ assert di .source .apply_select ("get_filename" ) == "file_a_0"
130+
131+ def test_file_select (self ):
132+ di = self .get_interface ()(param = 0 , selected = "SourceA" )
133+
134+ assert di .source .apply_select ("get_filename" ) == "file_a_0"
135+
84136 di .parameters ["param" ] = 1
85137 di .parameters ["selected" ] = "SourceB"
86138 assert di .source .apply_select ("get_filename" ) == "file_b_1"
@@ -94,8 +146,49 @@ class DataInterfaceMix(DataInterface):
94146 == "file_a_2"
95147 )
96148
97- # automatic dispatch
98- assert di .source .get_filename () == "file_b_1"
149+ def test_apply (self ):
150+ di = self .get_interface ()(param = 0 , selected = "SourceA" )
151+ assert di .source .apply ("get_filename" , all = True , param = 1 ) == [
152+ "file_a_1" ,
153+ "file_b_1" ,
154+ ]
155+
156+ assert di .source .apply ("get_filename" , all = False , param = 1 ) == "file_a_1"
157+
158+ def test_automatic_dispatch (self ):
159+ di = self .get_interface ()(param = 0 , selected = "SourceA" )
160+
161+ assert di .source .get_filename () == "file_a_0"
162+ di .parameters ["selected" ] = "SourceB"
163+ assert di .source .get_filename () == "file_b_0"
164+
165+ # disabled
166+ di .source ._auto_dispatch_getattr = False
167+ with pytest .raises (AttributeError ):
168+ di .source .get_filename ()
169+
170+ # attribute does not exist in base classes
171+ with pytest .raises (AttributeError ):
172+ di .source .unknown_attribute ()
173+
174+ # exception in selection function: no infinie recursion
175+ def select (mod , ** kwargs ):
176+ raise ValueError
177+
178+ di .source .set_select (select )
179+ with pytest .raises (ValueError ):
180+ di .source .select ()
181+ with pytest .raises (AttributeError ):
182+ di .source .get_filename ()
183+
184+ def test_bad_select (self ):
185+ di = self .get_interface ()
186+
187+ def select (mod , ** kwargs ):
188+ return "NonExistentBaseModule"
189+
190+ with pytest .raises (AttributeError ):
191+ di .source .select ()
99192
100193
101194def setup_multiple_files (tmpdir , var : str = "A" ) -> list [str ]:
@@ -129,6 +222,7 @@ class MyDataInterface(DataInterface):
129222 Parameters = ParametersDict
130223
131224 class Source (GlobSource ):
225+ # here we test root_directory as a simple str
132226 def get_root_directory (self ):
133227 return str (tmpdir )
134228
@@ -140,6 +234,11 @@ def get_glob_pattern(self):
140234 di = MyDataInterface (var = "A" )
141235 assert di .get_source () == ref_filenames
142236
237+ # test relative
238+ assert di .get_source (relative = True ) == [
239+ f .removeprefix (str (tmpdir ) + "/" ) for f in ref_filenames
240+ ]
241+
143242 # check files cached
144243 assert di .source .cache ["datafiles" ] == ref_filenames
145244
@@ -155,8 +254,9 @@ class MyDataInterface(DataInterface):
155254 Parameters = ParametersDict
156255
157256 class Source (FileFinderSource ):
257+ # here we test root_directory as a list
158258 def get_root_directory (self ):
159- return str (tmpdir )
259+ return [ str (tmpdir ), "subdir" ]
160260
161261 def get_filename_pattern (self ):
162262 var = self .parameters ["var" ]
@@ -165,11 +265,16 @@ def get_filename_pattern(self):
165265 return MyDataInterface
166266
167267 def test_get_source (self , tmpdir ):
168- ref_filenames = setup_multiple_files (tmpdir , var = "A" )
268+ ref_filenames = setup_multiple_files (tmpdir / "subdir" , var = "A" )
169269
170270 di = self .setup_interface (tmpdir )(var = "A" )
171271 assert di .get_source () == ref_filenames
172272
273+ # test relative
274+ assert di .get_source (relative = True ) == [
275+ f .removeprefix (str (tmpdir / "subdir" ) + "/" ) for f in ref_filenames
276+ ]
277+
173278 # check files cached
174279 assert di .source .cache ["datafiles" ] == ref_filenames
175280
@@ -179,7 +284,7 @@ def test_get_source(self, tmpdir):
179284 assert len (di .get_source ()) == 0
180285
181286 def test_fixes (self , tmpdir ):
182- ref_filenames = setup_multiple_files (tmpdir , var = "A" )
287+ ref_filenames = setup_multiple_files (tmpdir / "subdir" , var = "A" )
183288
184289 di = self .setup_interface (tmpdir )(var = "A" , Y = "2010" )
185290
0 commit comments