|
1 | 1 | import pytest |
2 | 2 | from pathlib import Path |
| 3 | + |
3 | 4 | from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path |
4 | 5 | from spikeinterface.curation.model_based_curation import ModelBasedClassification |
5 | | -from spikeinterface.curation import auto_label_units, load_model |
| 6 | +from spikeinterface.curation import model_based_label_units, load_model |
6 | 7 | from spikeinterface.curation.train_manual_curation import _get_computed_metrics |
| 8 | +from spikeinterface.curation import unitrefine_label_units |
| 9 | + |
7 | 10 |
|
8 | 11 | import numpy as np |
9 | 12 |
|
@@ -39,21 +42,21 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model): |
39 | 42 |
|
40 | 43 |
|
41 | 44 | def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path): |
42 | | - """The function `auto_label_units` needs the correct metrics to have been computed. However, |
| 45 | + """The function `model_based_label_units` needs the correct metrics to have been computed. However, |
43 | 46 | it should be independent of the order of computation. We test this here.""" |
44 | 47 |
|
45 | 48 | sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"]) |
46 | 49 | sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"]) |
47 | 50 |
|
48 | | - prediction_prob_dataframe_1 = auto_label_units( |
| 51 | + prediction_prob_dataframe_1 = model_based_label_units( |
49 | 52 | sorting_analyzer=sorting_analyzer_for_curation, |
50 | 53 | model_folder=trained_pipeline_path, |
51 | 54 | trusted=["numpy.dtype"], |
52 | 55 | ) |
53 | 56 |
|
54 | 57 | sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"]) |
55 | 58 |
|
56 | | - prediction_prob_dataframe_2 = auto_label_units( |
| 59 | + prediction_prob_dataframe_2 = model_based_label_units( |
57 | 60 | sorting_analyzer=sorting_analyzer_for_curation, |
58 | 61 | model_folder=trained_pipeline_path, |
59 | 62 | trusted=["numpy.dtype"], |
@@ -168,3 +171,83 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura |
168 | 171 | model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) |
169 | 172 | model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) |
170 | 173 | model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) |
| 174 | + |
| 175 | + |
| 176 | +def test_unitrefine_label_units_hf(sorting_analyzer_for_curation): |
| 177 | + """Test the `unitrefine_label_units` function.""" |
| 178 | + sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True) |
| 179 | + sorting_analyzer_for_curation.compute("quality_metrics") |
| 180 | + |
| 181 | + # test passing both classifiers |
| 182 | + labels = unitrefine_label_units( |
| 183 | + sorting_analyzer_for_curation, |
| 184 | + noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", |
| 185 | + sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", |
| 186 | + ) |
| 187 | + |
| 188 | + assert "label" in labels.columns |
| 189 | + assert "probability" in labels.columns |
| 190 | + assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) |
| 191 | + |
| 192 | + # test only noise neural classifier |
| 193 | + labels = unitrefine_label_units( |
| 194 | + sorting_analyzer_for_curation, |
| 195 | + noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", |
| 196 | + sua_mua_classifier=None, |
| 197 | + ) |
| 198 | + |
| 199 | + assert "label" in labels.columns |
| 200 | + assert "probability" in labels.columns |
| 201 | + assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) |
| 202 | + |
| 203 | + # test only sua mua classifier |
| 204 | + labels = unitrefine_label_units( |
| 205 | + sorting_analyzer_for_curation, |
| 206 | + noise_neural_classifier=None, |
| 207 | + sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", |
| 208 | + ) |
| 209 | + |
| 210 | + assert "label" in labels.columns |
| 211 | + assert "probability" in labels.columns |
| 212 | + assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) |
| 213 | + |
| 214 | + # test passing none |
| 215 | + with pytest.raises(ValueError): |
| 216 | + labels = unitrefine_label_units( |
| 217 | + sorting_analyzer_for_curation, |
| 218 | + noise_neural_classifier=None, |
| 219 | + sua_mua_classifier=None, |
| 220 | + ) |
| 221 | + |
| 222 | + # test warnings when unexpected labels are returned |
| 223 | + with pytest.warns(UserWarning): |
| 224 | + labels = unitrefine_label_units( |
| 225 | + sorting_analyzer_for_curation, |
| 226 | + noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", |
| 227 | + sua_mua_classifier=None, |
| 228 | + ) |
| 229 | + |
| 230 | + with pytest.warns(UserWarning): |
| 231 | + labels = unitrefine_label_units( |
| 232 | + sorting_analyzer_for_curation, |
| 233 | + noise_neural_classifier=None, |
| 234 | + sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", |
| 235 | + ) |
| 236 | + |
| 237 | + |
| 238 | +def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path): |
| 239 | + # test with trained local models |
| 240 | + sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True) |
| 241 | + sorting_analyzer_for_curation.compute("quality_metrics") |
| 242 | + |
| 243 | + # test passing model folder |
| 244 | + labels = unitrefine_label_units( |
| 245 | + sorting_analyzer_for_curation, |
| 246 | + noise_neural_classifier=trained_pipeline_path, |
| 247 | + ) |
| 248 | + |
| 249 | + # test passing model folder |
| 250 | + labels = unitrefine_label_units( |
| 251 | + sorting_analyzer_for_curation, |
| 252 | + noise_neural_classifier=trained_pipeline_path / "best_model.skops", |
| 253 | + ) |
0 commit comments