|
1 | 1 | import pytest |
2 | 2 | from pathlib import Path |
3 | 3 |
|
| 4 | +import json |
4 | 5 | import shutil |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 |
|
8 | 9 | from spikeinterface.core import generate_ground_truth_recording, SortingAnalyzer |
| 10 | +from spikeinterface.core.core_tools import SIJsonEncoder |
9 | 11 |
|
10 | 12 | from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor |
11 | 13 | from spikeinterface.core.waveforms_extractor_backwards_compatibility import extract_waveforms as mock_extract_waveforms |
@@ -104,6 +106,150 @@ def test_read_old_waveforms_extractor_binary(): |
104 | 106 | print(type(data)) |
105 | 107 |
|
106 | 108 |
|
def _create_legacy_we_folder(recording, sorting, folder):
    """Build a minimal legacy WaveformExtractor binary folder on disk.

    Only the skeleton that extension sub-folders hang off of is written:

    * ``params.json`` with the waveform-extraction parameters, and
    * ``recording_info/recording_attributes.json`` with the attributes
      returned by ``get_rec_attributes`` (``probegroup`` forced to ``None``).

    No waveform data and no sorting object are written to disk.

    Parameters
    ----------
    recording
        Recording whose attributes are dumped into ``recording_info/``.
    sorting
        Unused. Nothing is serialised for it because the test in this module
        hands the sorting straight to ``load_waveforms_backwards`` through
        its ``sorting`` argument.
    folder : pathlib.Path
        Destination folder; created together with its parents when missing.

    Returns
    -------
    pathlib.Path
        ``folder``, unchanged.
    """
    from spikeinterface.core.recording_tools import get_rec_attributes

    folder.mkdir(parents=True, exist_ok=True)

    # params.json
    params = {
        "ms_before": 1.0,
        "ms_after": 2.0,
        "return_scaled": True,
        "dtype": "float32",
    }
    with open(folder / "params.json", "w") as f:
        json.dump(params, f)

    # recording_info/
    rec_info_folder = folder / "recording_info"
    # Tolerate an existing directory, exactly like the parent folder above,
    # so the helper can be re-run against the same destination.
    rec_info_folder.mkdir(exist_ok=True)
    rec_attributes = get_rec_attributes(recording)
    rec_attributes["probegroup"] = None
    with open(rec_info_folder / "recording_attributes.json", "w") as f:
        json.dump(rec_attributes, f, cls=SIJsonEncoder)

    return folder
| 143 | + |
| 144 | + |
def _add_legacy_quality_metrics(folder, unit_ids):
    """Write a ``quality_metrics/`` sub-folder with deprecated 0.100-era params.

    Two files are created under ``folder / "quality_metrics"``:

    * ``params.json`` using the old ``qm_params`` key and listing
      ``isolation_distance`` and ``l_ratio`` among the metric names;
    * ``metrics.csv`` with one row per entry of ``unit_ids`` and constant
      values for ``num_spikes``, ``firing_rate`` and ``snr`` only.

    Parameters
    ----------
    folder : pathlib.Path
        Existing legacy waveform folder that receives the sub-folder.
    unit_ids
        Unit identifiers used as the index of ``metrics.csv``.
    """
    import pandas as pd

    qm_folder = folder / "quality_metrics"
    qm_folder.mkdir()

    listed_metrics = ["num_spikes", "firing_rate", "snr", "isolation_distance", "l_ratio"]

    # Start every listed metric with empty params, then fill in the entries
    # that carry values. Insertion order is kept so the JSON layout is stable.
    per_metric_params = {metric: {} for metric in listed_metrics}
    per_metric_params["snr"] = {"peak_sign": "neg", "peak_mode": "extremum"}
    per_metric_params["amplitude_cutoff"] = {"peak_sign": "neg"}
    per_metric_params["amplitude_median"] = {"peak_sign": "neg"}

    legacy_params = {
        "metric_names": listed_metrics,
        "qm_params": per_metric_params,
        "peak_sign": "neg",
        "seed": None,
        "skip_pc_metrics": False,
    }
    (qm_folder / "params.json").write_text(json.dumps(legacy_params))

    # Scalars are broadcast over the index, giving one identical row per unit.
    constant_metrics = {"num_spikes": 100, "firing_rate": 5.0, "snr": 10.0}
    pd.DataFrame(constant_metrics, index=unit_ids).to_csv(qm_folder / "metrics.csv")
| 181 | + |
| 182 | + |
def _add_legacy_template_metrics(folder, unit_ids):
    """Write a ``template_metrics/`` sub-folder with deprecated 0.100-era params.

    Two files are created under ``folder / "template_metrics"``:

    * ``params.json`` using the old ``metrics_kwargs`` key and listing
      ``peak_to_valley``, ``peak_trough_ratio`` and ``half_width``;
    * ``metrics.csv`` with one row per entry of ``unit_ids`` and a constant
      value for each of those three metrics.

    Parameters
    ----------
    folder : pathlib.Path
        Existing legacy waveform folder that receives the sub-folder.
    unit_ids
        Unit identifiers used as the index of ``metrics.csv``.
    """
    import pandas as pd

    tm_folder = folder / "template_metrics"
    tm_folder.mkdir()

    # Constant per-metric values; the key order also fixes the CSV column
    # order and the ``metric_names`` order in params.json.
    constant_metrics = {
        "peak_to_valley": 0.5,
        "peak_trough_ratio": 2.0,
        "half_width": 0.3,
    }

    legacy_params = {
        "metric_names": list(constant_metrics),
        "metrics_kwargs": {
            "upsampling_factor": 10,
            "window_slope_ms": 0.7,
        },
    }
    (tm_folder / "params.json").write_text(json.dumps(legacy_params))

    # Scalars are broadcast over the index, giving one identical row per unit.
    pd.DataFrame(constant_metrics, index=unit_ids).to_csv(tm_folder / "metrics.csv")
| 209 | + |
| 210 | + |
def test_load_legacy_we_with_deprecated_metrics(create_cache_folder, tmp_path):
    """Regression test for GH-4508.

    Builds a legacy WaveformExtractor folder whose
    ``quality_metrics/params.json`` and ``template_metrics/params.json``
    carry deprecated metric names (e.g. ``l_ratio``, ``peak_to_valley``) and
    the old parameter keys, then loads it as a ``SortingAnalyzer``. Loading
    must not raise, the deprecated names must be gone from ``metric_names``,
    and the old parameter key must have been replaced by ``metric_params``.
    """
    recording, sorting = get_dataset()
    unit_ids = sorting.unit_ids

    legacy_folder = tmp_path / "legacy_we_deprecated_metrics"
    _create_legacy_we_folder(recording, sorting, legacy_folder)
    _add_legacy_quality_metrics(legacy_folder, unit_ids)
    _add_legacy_template_metrics(legacy_folder, unit_ids)

    # This would raise ValueError on main before the fix
    analyzer = load_waveforms_backwards(legacy_folder, sorting=sorting, output="SortingAnalyzer")
    assert isinstance(analyzer, SortingAnalyzer)

    # (extension name, metric names that must be gone, legacy params key that
    # must have been renamed to ``metric_params``)
    expectations = (
        ("quality_metrics", ("l_ratio", "isolation_distance"), "qm_params"),
        ("template_metrics", ("peak_to_valley", "peak_trough_ratio"), "metrics_kwargs"),
    )
    for extension_name, deprecated_names, legacy_key in expectations:
        extension = analyzer.get_extension(extension_name)
        assert extension is not None

        metric_names = extension.params["metric_names"]
        for deprecated_name in deprecated_names:
            assert deprecated_name not in metric_names

        assert legacy_key not in extension.params
        assert "metric_params" in extension.params
| 251 | + |
| 252 | + |
107 | 253 | # @pytest.mark.skip("This test is run locally") |
108 | 254 | # def test_read_old_waveforms_extractor_zarr(): |
109 | 255 | # import pandas as pd |
|
0 commit comments