Skip to content

Commit 91889c4

Browse files
galenlynchclaudepre-commit-ci[bot]alejoe91
authored
Add regression test for GH-4508 (deprecated metric names in legacy WE folders) (#4514)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alessio Buccino <alejoe9187@gmail.com>
1 parent 2c6f251 commit 91889c4

2 files changed

Lines changed: 147 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ metrics = [
122122
]
123123

124124
test_core = [
125+
"pandas<3",
125126
"pytest<8.4.0",
126127
"psutil",
127128

src/spikeinterface/core/tests/test_waveforms_extractor_backwards_compatibility.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import pytest
22
from pathlib import Path
33

4+
import json
45
import shutil
56

67
import numpy as np
78

89
from spikeinterface.core import generate_ground_truth_recording, SortingAnalyzer
10+
from spikeinterface.core.core_tools import SIJsonEncoder
911

1012
from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor
1113
from spikeinterface.core.waveforms_extractor_backwards_compatibility import extract_waveforms as mock_extract_waveforms
@@ -104,6 +106,150 @@ def test_read_old_waveforms_extractor_binary():
104106
print(type(data))
105107

106108

109+
def _create_legacy_we_folder(recording, sorting, folder):
110+
"""Build a minimal legacy WaveformExtractor binary folder on disk.
111+
112+
Creates just enough structure for ``_read_old_waveforms_extractor_binary``
113+
to load: the top-level ``params.json``, ``recording_info/``, and a
114+
serialised sorting object. No waveform data is written – only the
115+
skeleton that extension sub-folders hang off of.
116+
"""
117+
from spikeinterface.core.recording_tools import get_rec_attributes
118+
119+
folder.mkdir(parents=True, exist_ok=True)
120+
121+
# params.json
122+
params = {
123+
"ms_before": 1.0,
124+
"ms_after": 2.0,
125+
"return_scaled": True,
126+
"dtype": "float32",
127+
}
128+
with open(folder / "params.json", "w") as f:
129+
json.dump(params, f)
130+
131+
# recording_info/
132+
rec_info_folder = folder / "recording_info"
133+
rec_info_folder.mkdir()
134+
rec_attributes = get_rec_attributes(recording)
135+
rec_attributes["probegroup"] = None
136+
with open(rec_info_folder / "recording_attributes.json", "w") as f:
137+
json.dump(rec_attributes, f, cls=SIJsonEncoder)
138+
139+
# No need to serialize the sorting on disk – the test passes it
140+
# directly via the ``sorting`` argument of ``load_waveforms_backwards``.
141+
142+
return folder
143+
144+
145+
def _add_legacy_quality_metrics(folder, unit_ids):
146+
"""Add a ``quality_metrics/`` sub-folder with deprecated 0.100-era params."""
147+
import pandas as pd
148+
149+
ext_folder = folder / "quality_metrics"
150+
ext_folder.mkdir()
151+
152+
deprecated_params = {
153+
"metric_names": [
154+
"num_spikes",
155+
"firing_rate",
156+
"snr",
157+
"isolation_distance",
158+
"l_ratio",
159+
],
160+
"qm_params": {
161+
"num_spikes": {},
162+
"firing_rate": {},
163+
"snr": {"peak_sign": "neg", "peak_mode": "extremum"},
164+
"isolation_distance": {},
165+
"l_ratio": {},
166+
"amplitude_cutoff": {"peak_sign": "neg"},
167+
"amplitude_median": {"peak_sign": "neg"},
168+
},
169+
"peak_sign": "neg",
170+
"seed": None,
171+
"skip_pc_metrics": False,
172+
}
173+
with open(ext_folder / "params.json", "w") as f:
174+
json.dump(deprecated_params, f)
175+
176+
metrics_df = pd.DataFrame(index=unit_ids, columns=["num_spikes", "firing_rate", "snr"])
177+
metrics_df["num_spikes"] = 100
178+
metrics_df["firing_rate"] = 5.0
179+
metrics_df["snr"] = 10.0
180+
metrics_df.to_csv(ext_folder / "metrics.csv")
181+
182+
183+
def _add_legacy_template_metrics(folder, unit_ids):
184+
"""Add a ``template_metrics/`` sub-folder with deprecated 0.100-era params."""
185+
import pandas as pd
186+
187+
ext_folder = folder / "template_metrics"
188+
ext_folder.mkdir()
189+
190+
deprecated_params = {
191+
"metric_names": [
192+
"peak_to_valley",
193+
"peak_trough_ratio",
194+
"half_width",
195+
],
196+
"metrics_kwargs": {
197+
"upsampling_factor": 10,
198+
"window_slope_ms": 0.7,
199+
},
200+
}
201+
with open(ext_folder / "params.json", "w") as f:
202+
json.dump(deprecated_params, f)
203+
204+
metrics_df = pd.DataFrame(index=unit_ids, columns=["peak_to_valley", "peak_trough_ratio", "half_width"])
205+
metrics_df["peak_to_valley"] = 0.5
206+
metrics_df["peak_trough_ratio"] = 2.0
207+
metrics_df["half_width"] = 0.3
208+
metrics_df.to_csv(ext_folder / "metrics.csv")
209+
210+
211+
def test_load_legacy_we_with_deprecated_metrics(create_cache_folder, tmp_path):
212+
"""Regression test for GH-4508.
213+
214+
A legacy WaveformExtractor folder whose ``quality_metrics/params.json``
215+
or ``template_metrics/params.json`` contains deprecated metric names
216+
(e.g. ``l_ratio``, ``peak_to_valley``) must load without raising a
217+
``ValueError``. The backward-compatibility handler must migrate the
218+
deprecated names before validation runs.
219+
"""
220+
recording, sorting = get_dataset()
221+
222+
we_folder = tmp_path / "legacy_we_deprecated_metrics"
223+
_create_legacy_we_folder(recording, sorting, we_folder)
224+
_add_legacy_quality_metrics(we_folder, sorting.unit_ids)
225+
_add_legacy_template_metrics(we_folder, sorting.unit_ids)
226+
227+
# This would raise ValueError on main before the fix
228+
sorting_analyzer = load_waveforms_backwards(we_folder, sorting=sorting, output="SortingAnalyzer")
229+
assert isinstance(sorting_analyzer, SortingAnalyzer)
230+
231+
# quality_metrics: deprecated names should be migrated
232+
qm = sorting_analyzer.get_extension("quality_metrics")
233+
assert qm is not None
234+
qm_names = qm.params["metric_names"]
235+
# The compat handler should have removed the deprecated names
236+
assert "l_ratio" not in qm_names
237+
assert "isolation_distance" not in qm_names
238+
# qm_params should have been renamed to metric_params
239+
assert "qm_params" not in qm.params
240+
assert "metric_params" in qm.params
241+
242+
# template_metrics: deprecated names should be migrated
243+
tm = sorting_analyzer.get_extension("template_metrics")
244+
assert tm is not None
245+
tm_names = tm.params["metric_names"]
246+
assert "peak_to_valley" not in tm_names
247+
assert "peak_trough_ratio" not in tm_names
248+
# metrics_kwargs should have been renamed to metric_params
249+
assert "metrics_kwargs" not in tm.params
250+
assert "metric_params" in tm.params
251+
252+
107253
# @pytest.mark.skip("This test is run locally")
108254
# def test_read_old_waveforms_extractor_zarr():
109255
# import pandas as pd

0 commit comments

Comments
 (0)