Skip to content

Commit b2175d1

Browse files
committed
Up tests
1 parent 80ec5ba commit b2175d1

10 files changed

Lines changed: 283 additions & 49 deletions

File tree

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
## Project information
1818

1919
project = "Neba"
20-
copyright = "2023, Clément Haëck"
20+
copyright = "2026, Clément Haëck"
2121
author = "Clément Haëck"
2222

2323
version = neba.__version__

src/neba/data/source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def root_directory(self) -> Path:
8383

8484
return rootdir
8585

86-
def get_source(self, relative: bool = False, _warn: bool = True) -> list[str]: # type: ignore
86+
def get_source(self, relative: bool = False, _warn: bool = True) -> list[str]: # type: ignore[override]
8787
"""Return list of filenames.
8888
8989
Parameters

src/neba/data/xarray.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import os
77
from collections.abc import Callable, Hashable, Mapping, Sequence
8-
from typing import TYPE_CHECKING, Any, Literal, assert_never, cast, overload
8+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
99

1010
import xarray as xr
1111

@@ -162,7 +162,7 @@ def send_single_call(
162162
kwargs = self.to_zarr_kwargs | kwargs
163163
return ds.to_zarr(outfile, **kwargs)
164164

165-
assert_never(format)
165+
raise ValueError(f"File format '{format}' not supported.")
166166

167167
def add_metadata(
168168
self,
@@ -247,11 +247,8 @@ def send_calls_together(
247247
grouped_calls = calls[slc]
248248
delayed = [self.send_single_call(c, **kwargs) for c in grouped_calls]
249249

250-
# Compute them all at once
251-
# This loop is super important, this create all the futures for computation
252-
# and remove them as soon as they are completed (and the variable `future`
253-
# goes out of scope). That way the data does not pile up, it is freed.
254-
# We only care about the side effect of writing to disk, not the result data.
250+
# Futures are deleted as soon as they go out of scope. They do not pile up
251+
# but we still return only when all are completed.
255252
for future in distributed.as_completed(client.compute(delayed)):
256253
log.debug("\t\tfuture completed: %s", future)
257254

@@ -510,7 +507,7 @@ def split_by_time(
510507
"Resampling frequency is equal to that of dataset. "
511508
"Will not resample."
512509
)
513-
return [ds_unit for _, ds_unit in ds.groupby("time", squeeze=False)]
510+
return [ds.isel(time=[i]) for i in range(ds.time.size)]
514511

515512
resample = ds.resample(time=freq)
516513
return [ds_unit for _, ds_unit in resample]

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev").lower())
2020

2121

22+
# register dask.distributed fixtures
23+
pytest_plugins = "distributed.utils_test"
24+
2225
def pytest_configure(config: pytest.Config):
2326
config.addinivalue_line("markers", "todo")
2427

tests/data/test_interface.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,29 @@ def setup(self):
240240
is_setup.add("B")
241241
super().setup()
242242

243+
def _lines(self):
244+
return "SourceB_repr"
245+
246+
243247
class MyDataInterface(DataInterface):
244248
Source = ModuleMix.create([SourceA, SourceB])
245249

246250
MyDataInterface()
247251
assert is_setup == {"A", "B", "C"}
252+
253+
def test_repr(self):
254+
255+
class MyDataInterface(DataInterface):
256+
class SourceA(SourceAbstract):
257+
def _lines(self):
258+
return ["SourceA_repr"]
259+
class SourceB(SourceAbstract):
260+
def _lines(self):
261+
return ["SourceB_repr"]
262+
263+
Source = ModuleMix.create([SourceA, SourceB])
264+
265+
di = MyDataInterface()
266+
assert repr(di.source) == "SourceA\n\tSourceA_repr\nSourceB\n\tSourceB_repr"
267+
268+
# more tests in test_saource.TestModuleMix

tests/data/test_params.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -300,31 +300,49 @@ class sub(Section):
300300
di.sub.b = 1
301301

302302

303-
def test_autocached():
304-
class MyDataInterface(DataInterface):
305-
Parameters = ParametersDict
303+
class TestCachedModule:
306304

307-
class Loader(LoaderAbstract, CachedModule):
308-
@property
309-
@autocached
310-
def test_property(self):
311-
return 0
305+
def get_interface(self):
306+
class MyDataInterface(DataInterface):
307+
Parameters = ParametersDict
308+
309+
class Loader(LoaderAbstract, CachedModule):
310+
@property
311+
@autocached
312+
def test_property(self):
313+
return 0
314+
315+
@autocached
316+
def test_method(self):
317+
return 1
318+
319+
return MyDataInterface
320+
321+
def test_autocached(self):
322+
di = self.get_interface()()
323+
assert len(di.loader.cache) == 0
324+
325+
_ = di.loader.test_property
326+
assert di.loader.cache == dict(test_property=0)
327+
328+
_ = di.loader.test_method()
329+
assert di.loader.cache == dict(test_property=0, test_method=1)
312330

313-
@autocached
314-
def test_method(self):
315-
return 1
331+
di.trigger_callbacks()
332+
assert len(di.loader.cache) == 0
316333

317-
di = MyDataInterface()
318-
assert len(di.loader.cache) == 0
334+
def test_disable(self):
335+
di_cls = self.get_interface()
336+
di_cls.Loader._add_void_callback = False
337+
di = di_cls()
319338

320-
_ = di.loader.test_property
321-
assert di.loader.cache == dict(test_property=0)
339+
assert len(di._reset_callbacks) == 0
322340

323-
_ = di.loader.test_method()
324-
assert di.loader.cache == dict(test_property=0, test_method=1)
341+
_ = di.loader.test_property
342+
assert "test_property" in di.loader.cache
343+
di.trigger_callbacks()
344+
assert "test_property" in di.loader.cache
325345

326-
di.trigger_callbacks()
327-
assert len(di.loader.cache) == 0
328346

329347

330348
class TestParamsExcursion:

tests/data/test_source.py

Lines changed: 111 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,23 @@ def test_intersection(self):
5151
mix = mix_cls()
5252
assert mix.get_source() == ["b", "c"]
5353

54-
def test_file_select(self):
54+
def test_no_select(self):
55+
SourceA = get_simple_source("SourceA", ["a", "b", "c"])
56+
SourceB = get_simple_source("SourceB", ["b", "c", "d", "e"])
57+
mix_cls = SourceIntersection.create([SourceA, SourceB])
58+
mix = mix_cls()
59+
with pytest.raises(ValueError):
60+
mix.select()
61+
62+
def select(mod, *args, **kwargs):
63+
return "SourceA"
64+
65+
mix.set_select(select)
66+
assert isinstance(mix.select(), SourceA)
67+
68+
def test_select_unbound(self):
69+
"""Select function is defined outside of interface."""
70+
5571
class SourceA(SimpleSource):
5672
source_loc = "a"
5773

@@ -81,6 +97,42 @@ class DataInterfaceMix(DataInterface):
8197

8298
assert di.source.apply_select("get_filename") == "file_a_0"
8399

100+
def get_interface(self):
101+
class DataInterfaceMix(DataInterface):
102+
Parameters = ParametersDict
103+
104+
class SourceA(SimpleSource):
105+
source_loc = "a"
106+
107+
def get_filename(self, **fixes):
108+
param = fixes.get("param", self.parameters["param"])
109+
return f"file_a_{param}"
110+
111+
class SourceB(SimpleSource):
112+
source_loc = "b"
113+
114+
def get_filename(self, **fixes):
115+
param = fixes.get("param", self.parameters["param"])
116+
return f"file_b_{param}"
117+
118+
@staticmethod
119+
def select(module, **kwargs):
120+
return kwargs.get("selected", module.parameters["selected"])
121+
122+
Source = SourceUnion.create([SourceA, SourceB], select_func=select)
123+
124+
return DataInterfaceMix
125+
126+
def test_select_bound(self):
127+
"""Select is defined as static method."""
128+
di = self.get_interface()(param=0, selected="SourceA")
129+
assert di.source.apply_select("get_filename") == "file_a_0"
130+
131+
def test_file_select(self):
132+
di = self.get_interface()(param=0, selected="SourceA")
133+
134+
assert di.source.apply_select("get_filename") == "file_a_0"
135+
84136
di.parameters["param"] = 1
85137
di.parameters["selected"] = "SourceB"
86138
assert di.source.apply_select("get_filename") == "file_b_1"
@@ -94,8 +146,49 @@ class DataInterfaceMix(DataInterface):
94146
== "file_a_2"
95147
)
96148

97-
# automatic dispatch
98-
assert di.source.get_filename() == "file_b_1"
149+
def test_apply(self):
150+
di = self.get_interface()(param=0, selected="SourceA")
151+
assert di.source.apply("get_filename", all=True, param=1) == [
152+
"file_a_1",
153+
"file_b_1",
154+
]
155+
156+
assert di.source.apply("get_filename", all=False, param=1) == "file_a_1"
157+
158+
def test_automatic_dispatch(self):
159+
di = self.get_interface()(param=0, selected="SourceA")
160+
161+
assert di.source.get_filename() == "file_a_0"
162+
di.parameters["selected"] = "SourceB"
163+
assert di.source.get_filename() == "file_b_0"
164+
165+
# disabled
166+
di.source._auto_dispatch_getattr = False
167+
with pytest.raises(AttributeError):
168+
di.source.get_filename()
169+
170+
# attribute does not exist in base classes
171+
with pytest.raises(AttributeError):
172+
di.source.unknown_attribute()
173+
174+
# exception in selection function: no infinie recursion
175+
def select(mod, **kwargs):
176+
raise ValueError
177+
178+
di.source.set_select(select)
179+
with pytest.raises(ValueError):
180+
di.source.select()
181+
with pytest.raises(AttributeError):
182+
di.source.get_filename()
183+
184+
def test_bad_select(self):
185+
di = self.get_interface()
186+
187+
def select(mod, **kwargs):
188+
return "NonExistentBaseModule"
189+
190+
with pytest.raises(AttributeError):
191+
di.source.select()
99192

100193

101194
def setup_multiple_files(tmpdir, var: str = "A") -> list[str]:
@@ -129,6 +222,7 @@ class MyDataInterface(DataInterface):
129222
Parameters = ParametersDict
130223

131224
class Source(GlobSource):
225+
# here we test root_directory as a simple str
132226
def get_root_directory(self):
133227
return str(tmpdir)
134228

@@ -140,6 +234,11 @@ def get_glob_pattern(self):
140234
di = MyDataInterface(var="A")
141235
assert di.get_source() == ref_filenames
142236

237+
# test relative
238+
assert di.get_source(relative=True) == [
239+
f.removeprefix(str(tmpdir) + "/") for f in ref_filenames
240+
]
241+
143242
# check files cached
144243
assert di.source.cache["datafiles"] == ref_filenames
145244

@@ -155,8 +254,9 @@ class MyDataInterface(DataInterface):
155254
Parameters = ParametersDict
156255

157256
class Source(FileFinderSource):
257+
# here we test root_directory as a list
158258
def get_root_directory(self):
159-
return str(tmpdir)
259+
return [str(tmpdir), "subdir"]
160260

161261
def get_filename_pattern(self):
162262
var = self.parameters["var"]
@@ -165,11 +265,16 @@ def get_filename_pattern(self):
165265
return MyDataInterface
166266

167267
def test_get_source(self, tmpdir):
168-
ref_filenames = setup_multiple_files(tmpdir, var="A")
268+
ref_filenames = setup_multiple_files(tmpdir / "subdir", var="A")
169269

170270
di = self.setup_interface(tmpdir)(var="A")
171271
assert di.get_source() == ref_filenames
172272

273+
# test relative
274+
assert di.get_source(relative=True) == [
275+
f.removeprefix(str(tmpdir / "subdir") + "/") for f in ref_filenames
276+
]
277+
173278
# check files cached
174279
assert di.source.cache["datafiles"] == ref_filenames
175280

@@ -179,7 +284,7 @@ def test_get_source(self, tmpdir):
179284
assert len(di.get_source()) == 0
180285

181286
def test_fixes(self, tmpdir):
182-
ref_filenames = setup_multiple_files(tmpdir, var="A")
287+
ref_filenames = setup_multiple_files(tmpdir / "subdir", var="A")
183288

184289
di = self.setup_interface(tmpdir)(var="A", Y="2010")
185290

0 commit comments

Comments
 (0)