Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ metrics = [
]

test_core = [
"pandas<3",
Comment thread
alejoe91 marked this conversation as resolved.
"pytest<8.4.0",
"psutil",

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import pytest
from pathlib import Path

import json
import shutil

import numpy as np
import pandas as pd

from spikeinterface.core import generate_ground_truth_recording, SortingAnalyzer
from spikeinterface.core.core_tools import SIJsonEncoder

from spikeinterface.core.waveforms_extractor_backwards_compatibility import MockWaveformExtractor
from spikeinterface.core.waveforms_extractor_backwards_compatibility import extract_waveforms as mock_extract_waveforms
Expand Down Expand Up @@ -104,6 +107,146 @@ def test_read_old_waveforms_extractor_binary():
print(type(data))


def _create_legacy_we_folder(recording, sorting, folder):
    """Build a minimal legacy WaveformExtractor binary folder on disk.

    Writes only the skeleton that extension sub-folders hang off of, intended
    to be just enough for ``_read_old_waveforms_extractor_binary`` to load:
    the top-level ``params.json`` and
    ``recording_info/recording_attributes.json``. No waveform data and no
    serialised sorting object are written.

    Parameters
    ----------
    recording
        Recording whose attributes are collected with ``get_rec_attributes``.
    sorting
        Not used by this helper; nothing about the sorting is written to disk.
    folder : pathlib.Path
        Destination folder. It is created (including parents) when missing.

    Returns
    -------
    pathlib.Path
        The same ``folder`` that was passed in.
    """
    from spikeinterface.core.recording_tools import get_rec_attributes

    folder.mkdir(parents=True, exist_ok=True)

    # params.json
    params = {
        "ms_before": 1.0,
        "ms_after": 2.0,
        "return_scaled": True,
        "dtype": "float32",
    }
    with open(folder / "params.json", "w") as f:
        json.dump(params, f)

    # recording_info/
    rec_info_folder = folder / "recording_info"
    # exist_ok mirrors the parent mkdir above, so calling the helper again on
    # an existing folder rewrites the JSON files instead of raising
    # FileExistsError.
    rec_info_folder.mkdir(exist_ok=True)
    rec_attributes = get_rec_attributes(recording)
    # The probegroup entry is blanked before dumping.
    # NOTE(review): presumably the legacy loader does not need it — confirm.
    rec_attributes["probegroup"] = None
    with open(rec_info_folder / "recording_attributes.json", "w") as f:
        json.dump(rec_attributes, f, cls=SIJsonEncoder)

    # The sorting is deliberately not serialised here: callers are expected to
    # pass it directly via the ``sorting`` argument of
    # ``load_waveforms_backwards``.

    return folder


def _add_legacy_quality_metrics(folder, unit_ids):
    """Write a ``quality_metrics/`` extension folder using 0.100-era params.

    Creates ``folder / "quality_metrics"`` and fills it with two files:

    * ``params.json`` holding the legacy parameter layout (``qm_params``,
      ``peak_sign``, ``seed``, ``skip_pc_metrics``) together with the legacy
      metric names.
    * ``metrics.csv`` holding one row per entry of ``unit_ids`` with constant
      ``num_spikes``, ``firing_rate`` and ``snr`` columns.

    Raises ``FileExistsError`` if the extension folder is already present,
    because the folder is created without ``exist_ok``.
    """
    target_dir = folder / "quality_metrics"
    target_dir.mkdir()

    legacy_metric_names = [
        "num_spikes",
        "firing_rate",
        "snr",
        "isolation_distance",
        "l_ratio",
    ]
    legacy_qm_params = {
        "num_spikes": {},
        "firing_rate": {},
        "snr": {"peak_sign": "neg", "peak_mode": "extremum"},
        "isolation_distance": {},
        "l_ratio": {},
        "amplitude_cutoff": {"peak_sign": "neg"},
        "amplitude_median": {"peak_sign": "neg"},
    }
    legacy_params = {
        "metric_names": legacy_metric_names,
        "qm_params": legacy_qm_params,
        "peak_sign": "neg",
        "seed": None,
        "skip_pc_metrics": False,
    }
    with open(target_dir / "params.json", "w") as params_file:
        json.dump(legacy_params, params_file)

    # Constant placeholder values: every unit gets the same row.
    constant_columns = {"num_spikes": 100, "firing_rate": 5.0, "snr": 10.0}
    table = pd.DataFrame(index=unit_ids, columns=list(constant_columns))
    for column_name, constant in constant_columns.items():
        table[column_name] = constant
    table.to_csv(target_dir / "metrics.csv")


def _add_legacy_template_metrics(folder, unit_ids):
    """Write a ``template_metrics/`` extension folder using 0.100-era params.

    Creates ``folder / "template_metrics"`` and fills it with two files:

    * ``params.json`` holding the legacy metric names and the legacy
      ``metrics_kwargs`` entry.
    * ``metrics.csv`` holding one row per entry of ``unit_ids`` with constant
      ``peak_to_valley``, ``peak_trough_ratio`` and ``half_width`` columns.

    Raises ``FileExistsError`` if the extension folder is already present,
    because the folder is created without ``exist_ok``.
    """
    target_dir = folder / "template_metrics"
    target_dir.mkdir()

    legacy_metric_names = [
        "peak_to_valley",
        "peak_trough_ratio",
        "half_width",
    ]
    legacy_kwargs = {
        "upsampling_factor": 10,
        "window_slope_ms": 0.7,
    }
    legacy_params = {
        "metric_names": legacy_metric_names,
        "metrics_kwargs": legacy_kwargs,
    }
    with open(target_dir / "params.json", "w") as params_file:
        json.dump(legacy_params, params_file)

    # Constant placeholder values: every unit gets the same row.
    constant_columns = {"peak_to_valley": 0.5, "peak_trough_ratio": 2.0, "half_width": 0.3}
    table = pd.DataFrame(index=unit_ids, columns=list(constant_columns))
    for column_name, constant in constant_columns.items():
        table[column_name] = constant
    table.to_csv(target_dir / "metrics.csv")


def test_load_legacy_we_with_deprecated_metrics(create_cache_folder, tmp_path):
    """Regression test for GH-4508.

    Builds a legacy WaveformExtractor folder whose
    ``quality_metrics/params.json`` and ``template_metrics/params.json``
    carry deprecated metric names (e.g. ``l_ratio``, ``peak_to_valley``) and
    deprecated parameter keys, then checks that loading it as a
    ``SortingAnalyzer`` succeeds and that both the names and the keys have
    been migrated by the backward-compatibility handler.
    """
    recording, sorting = get_dataset()

    legacy_folder = tmp_path / "legacy_we_deprecated_metrics"
    _create_legacy_we_folder(recording, sorting, legacy_folder)
    for add_extension in (_add_legacy_quality_metrics, _add_legacy_template_metrics):
        add_extension(legacy_folder, sorting.unit_ids)

    # Before the fix this call raised ValueError during params validation.
    analyzer = load_waveforms_backwards(legacy_folder, sorting=sorting, output="SortingAnalyzer")
    assert isinstance(analyzer, SortingAnalyzer)

    # (extension name, metric names that must be gone, params key that must
    # have been renamed to "metric_params")
    migration_checks = (
        ("quality_metrics", ("l_ratio", "isolation_distance"), "qm_params"),
        ("template_metrics", ("peak_to_valley", "peak_trough_ratio"), "metrics_kwargs"),
    )
    for extension_name, removed_names, legacy_key in migration_checks:
        extension = analyzer.get_extension(extension_name)
        assert extension is not None

        # Deprecated metric names should no longer be listed.
        migrated_names = extension.params["metric_names"]
        for removed in removed_names:
            assert removed not in migrated_names

        # The legacy params key should have been renamed to metric_params.
        assert legacy_key not in extension.params
        assert "metric_params" in extension.params


# @pytest.mark.skip("This test is run locally")
# def test_read_old_waveforms_extractor_zarr():
# import pandas as pd
Expand Down
Loading