Skip to content

Commit 2e3cbcc

Browse files
Refactor API for auto_label_units (#4338)
Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com>
1 parent da0cf77 commit 2e3cbcc

10 files changed

Lines changed: 278 additions & 50 deletions

File tree

doc/api.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,10 @@ spikeinterface.curation
373373
.. autofunction:: remove_redundant_units
374374
.. autofunction:: remove_duplicated_spikes
375375
.. autofunction:: remove_excess_spikes
376-
.. autofunction:: auto_label_units
376+
.. autofunction:: model_based_label_units
377377
.. autofunction:: load_model
378378
.. autofunction:: train_model
379+
.. autofunction:: unitrefine_label_units
379380

380381
Curation Model
381382
~~~~~~~~~~~~~~

doc/how_to/auto_curation_prediction.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ repo's URL after huggingface.co/) and that we trust the model.
1616

1717
.. code::
1818
19-
from spikeinterface.curation import auto_label_units
19+
from spikeinterface.curation import model_based_label_units
2020
21-
labels_and_probabilities = auto_label_units(
21+
labels_and_probabilities = model_based_label_units(
2222
sorting_analyzer = sorting_analyzer,
2323
repo_id = "SpikeInterface/toy_tetrode_model",
2424
trust_model = True
@@ -29,7 +29,7 @@ create the labels:
2929

3030
.. code::
3131
32-
labels_and_probabilities = si.auto_label_units(
32+
labels_and_probabilities = si.model_based_label_units(
3333
sorting_analyzer = sorting_analyzer,
3434
model_folder = "my_folder_with_a_model_in_it",
3535
)
@@ -39,5 +39,5 @@ are also saved as a property of your ``sorting_analyzer`` and can be accessed li
3939

4040
.. code::
4141
42-
labels = sorting_analyzer.sorting.get_property("classifier_label")
43-
probabilities = sorting_analyzer.sorting.get_property("classifier_probability")
42+
labels = sorting_analyzer.get_sorting_property("classifier_label")
43+
probabilities = sorting_analyzer.get_sorting_property("classifier_probability")

doc/references.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ If you use the default "similarity_correlograms" preset in the :code:`compute_me
9292

9393
If you use the "slay" preset in the :code:`compute_merge_unit_groups` method, please cite [Koukuntla]_
9494

95-
If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_
95+
If you use :code:`unitrefine_label_units`, :code:`model_based_label_units` or :code:`train_model`, please cite [Jain]_
9696

9797
Benchmark
9898
---------

examples/tutorials/curation/plot_1_automated_curation.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@
8383

8484
##############################################################################
8585
# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly
86-
# to the ``auto_label_units`` function. This returns a dictionary containing a label and
86+
# to the ``model_based_label_units`` function. This returns a dictionary containing a label and
8787
# a confidence for each unit contained in the ``sorting_analyzer``.
8888

89-
labels = sc.auto_label_units(
89+
labels = sc.model_based_label_units(
9090
sorting_analyzer = sorting_analyzer,
9191
repo_id = "SpikeInterface/toy_tetrode_model",
9292
trusted = ['numpy.dtype']
@@ -211,16 +211,16 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
211211
# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
212212
# V1,SC and ALM: https://huggingface.co/SpikeInterface/UnitRefine_noise_neural_classifier/ and
213213
# https://huggingface.co/SpikeInterface/UnitRefine_sua_mua_classifier/. One will classify units into
214-
# `noise` or `not-noise` and the other will classify the `not-noise` units into single
214+
# `noise` or `neural` and the other will classify the `neural` units into single
215215
# unit activity (sua) units and multi-unit activity (mua) units.
216216
#
217217
# There is more information about the model on the model's HuggingFace page. Take a look!
218-
# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
218+
# The idea here is to first apply the noise/neural classifier, then the sua/mua one.
219219
# We can do so as follows:
220220
#
221221

222-
# Apply the noise/not-noise model
223-
noise_neuron_labels = sc.auto_label_units(
222+
# Apply the noise/neural model
223+
noise_neuron_labels = sc.model_based_label_units(
224224
sorting_analyzer=sorting_analyzer,
225225
repo_id="SpikeInterface/UnitRefine_noise_neural_classifier",
226226
trust_model=True,
@@ -230,7 +230,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
230230
analyzer_neural = sorting_analyzer.remove_units(noise_units.index)
231231

232232
# Apply the sua/mua model
233-
sua_mua_labels = sc.auto_label_units(
233+
sua_mua_labels = sc.model_based_label_units(
234234
sorting_analyzer=analyzer_neural,
235235
repo_id="SpikeInterface/UnitRefine_sua_mua_classifier",
236236
trust_model=True,
@@ -239,6 +239,18 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
239239
all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
240240
print(all_labels)
241241

242+
##############################################################################
243+
# Both steps can be done in one go using the ``unitrefine_label_units`` function:
244+
#
245+
246+
all_labels = sc.unitrefine_label_units(
247+
sorting_analyzer,
248+
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier",
249+
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier",
250+
)
251+
print(all_labels)
252+
253+
242254
##############################################################################
243255
# If you run this without the ``trust_model=True`` parameter, you will receive an error:
244256
#
@@ -252,7 +264,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
252264
#
253265
# .. dropdown:: More about security
254266
#
255-
# Sharing models, with are Python objects, is complicated.
267+
# Sharing models, which are Python objects, is complicated.
256268
# We have chosen to use the `skops format <https://skops.readthedocs.io/en/stable/>`_, instead
257269
# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues
258270
# `here <https://lwn.net/Articles/964392/>`_). While unpacking the ``.skops`` file, each function
@@ -276,7 +288,7 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
276288
#
277289
# .. code-block::
278290
#
279-
# labels = sc.auto_label_units(
291+
# labels = sc.model_based_label_units(
280292
# sorting_analyzer = sorting_analyzer,
281293
# model_folder = "path/to/model/folder",
282294
# )

examples/tutorials/curation/plot_3_upload_a_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@
112112
#
113113
# ` ` ` python (NOTE: you should remove the spaces between each backtick. This is just formatting for the notebook you are reading)
114114
#
115-
# from spikeinterface.curation import auto_label_units
116-
# labels = auto_label_units(
115+
# from spikeinterface.curation import model_based_label_units
116+
# labels = model_based_label_units(
117117
# sorting_analyzer = sorting_analyzer,
118118
# repo_id = "SpikeInterface/toy_tetrode_model",
119119
# trust_model=True
@@ -123,9 +123,9 @@
123123
# or you can download the entire repositry to `a_folder_for_a_model`, and use
124124
#
125125
# ` ` ` python
126-
# from spikeinterface.curation import auto_label_units
126+
# from spikeinterface.curation import model_based_label_units
127127
#
128-
# labels = auto_label_units(
128+
# labels = model_based_label_units(
129129
# sorting_analyzer = sorting_analyzer,
130130
# model_folder = "path/to/a_folder_for_a_model",
131131
# trusted = ['numpy.dtype']

src/spikeinterface/curation/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
from .sortingview_curation import apply_sortingview_curation
2121

2222
# automated curation
23-
from .model_based_curation import auto_label_units, load_model
23+
from .model_based_curation import model_based_label_units, load_model, auto_label_units
2424
from .train_manual_curation import train_model, get_default_classifier_search_spaces
25+
from .unitrefine_curation import unitrefine_label_units

src/spikeinterface/curation/model_based_curation.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ def predict_labels(
119119
)
120120

121121
# Set predictions and probability as sorting properties
122-
self.sorting_analyzer.sorting.set_property("classifier_label", predictions)
123-
self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities)
122+
self.sorting_analyzer.set_sorting_property("classifier_label", predictions)
123+
self.sorting_analyzer.set_sorting_property("classifier_probability", probabilities)
124124

125125
if export_to_phy:
126126
self._export_to_phy(classified_units)
@@ -204,11 +204,11 @@ def _export_to_phy(self, classified_df):
204204
classified_df.to_csv(f"{sorting_path}/cluster_prediction.tsv", sep="\t", index_label="cluster_id")
205205

206206

207-
def auto_label_units(
207+
def model_based_label_units(
208208
sorting_analyzer: SortingAnalyzer,
209209
model_folder=None,
210-
model_name=None,
211210
repo_id=None,
211+
model_name=None,
212212
label_conversion=None,
213213
trust_model=False,
214214
trusted=None,
@@ -227,11 +227,11 @@ def auto_label_units(
227227
----------
228228
sorting_analyzer : SortingAnalyzer
229229
The sorting analyzer object containing the spike sorting results.
230-
model_folder : str or Path, defualt: None
230+
model_folder : str or Path, default: None
231231
The path to the folder containing the model
232-
repo_id : str | Path, default: None
232+
repo_id : str, default: None
233233
Hugging face repo id which contains the model e.g. 'username/model'
234-
model_name: str | Path, default: None
234+
model_name: str, default: None
235235
Filename of model e.g. 'my_model.skops'. If None, uses first model found.
236236
label_conversion : dic | None, default: None
237237
A dictionary for converting the predicted labels (which are integers) to custom labels. If None,
@@ -281,6 +281,19 @@ def auto_label_units(
281281
return classified_units
282282

283283

284+
def auto_label_units(*args, **kwargs):
285+
"""
286+
Deprecated function. Please use `model_based_label_units` instead.
287+
"""
288+
warnings.warn(
289+
"`auto_label_units` is deprecated and will be removed in v0.105.0. "
290+
"Please use `model_based_label_units` instead.",
291+
DeprecationWarning,
292+
stacklevel=2,
293+
)
294+
return model_based_label_units(*args, **kwargs)
295+
296+
284297
def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=False, trusted=None):
285298
"""
286299
Loads a model and model_info from a HuggingFaceHub repo or a local folder.
@@ -289,9 +302,9 @@ def load_model(model_folder=None, repo_id=None, model_name=None, trust_model=Fal
289302
----------
290303
model_folder : str or Path, defualt: None
291304
The path to the folder containing the model
292-
repo_id : str | Path, default: None
305+
repo_id : str, default: None
293306
Hugging face repo id which contains the model e.g. 'username/model'
294-
model_name: str | Path, default: None
307+
model_name: str, default: None
295308
Filename of model e.g. 'my_model.skops'. If None, uses first model found.
296309
trust_model : bool, default: False
297310
Whether to trust the model. If True, the `trusted` parameter that is passed to `skops.load` to load the model will be

src/spikeinterface/curation/tests/common.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,22 +91,25 @@ def trained_pipeline_path():
9191
If the model already exists, this function does nothing.
9292
"""
9393
trained_model_folder = Path(__file__).parent / Path("trained_pipeline")
94-
analyzer = make_sorting_analyzer(sparse=True)
95-
analyzer.compute(
96-
{
97-
"quality_metrics": {"metric_names": ["snr", "num_spikes"]},
98-
"template_metrics": {"metric_names": ["half_width"]},
99-
}
100-
)
101-
train_model(
102-
analyzers=[analyzer] * 5,
103-
labels=[[1, 0, 1, 0, 1]] * 5,
104-
folder=trained_model_folder,
105-
classifiers=["RandomForestClassifier"],
106-
imputation_strategies=["median"],
107-
scaling_techniques=["standard_scaler"],
108-
)
109-
yield trained_model_folder
94+
if trained_model_folder.is_dir():
95+
yield trained_model_folder
96+
else:
97+
analyzer = make_sorting_analyzer(sparse=True)
98+
analyzer.compute(
99+
{
100+
"quality_metrics": {"metric_names": ["snr", "num_spikes"]},
101+
"template_metrics": {"metric_names": ["half_width"]},
102+
}
103+
)
104+
train_model(
105+
analyzers=[analyzer] * 5,
106+
labels=[[1, 0, 1, 0, 1]] * 5,
107+
folder=trained_model_folder,
108+
classifiers=["RandomForestClassifier"],
109+
imputation_strategies=["median"],
110+
scaling_techniques=["standard_scaler"],
111+
)
112+
yield trained_model_folder
110113

111114

112115
if __name__ == "__main__":

src/spikeinterface/curation/tests/test_model_based_curation.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import pytest
22
from pathlib import Path
3+
34
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path
45
from spikeinterface.curation.model_based_curation import ModelBasedClassification
5-
from spikeinterface.curation import auto_label_units, load_model
6+
from spikeinterface.curation import model_based_label_units, load_model
67
from spikeinterface.curation.train_manual_curation import _get_computed_metrics
8+
from spikeinterface.curation import unitrefine_label_units
9+
710

811
import numpy as np
912

@@ -39,21 +42,21 @@ def test_model_based_classification_init(sorting_analyzer_for_curation, model):
3942

4043

4144
def test_metric_ordering_independence(sorting_analyzer_for_curation, trained_pipeline_path):
42-
"""The function `auto_label_units` needs the correct metrics to have been computed. However,
45+
"""The function `model_based_label_units` needs the correct metrics to have been computed. However,
4346
it should be independent of the order of computation. We test this here."""
4447

4548
sorting_analyzer_for_curation.compute("template_metrics", metric_names=["half_width"])
4649
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["num_spikes", "snr"])
4750

48-
prediction_prob_dataframe_1 = auto_label_units(
51+
prediction_prob_dataframe_1 = model_based_label_units(
4952
sorting_analyzer=sorting_analyzer_for_curation,
5053
model_folder=trained_pipeline_path,
5154
trusted=["numpy.dtype"],
5255
)
5356

5457
sorting_analyzer_for_curation.compute("quality_metrics", metric_names=["snr", "num_spikes"])
5558

56-
prediction_prob_dataframe_2 = auto_label_units(
59+
prediction_prob_dataframe_2 = model_based_label_units(
5760
sorting_analyzer=sorting_analyzer_for_curation,
5861
model_folder=trained_pipeline_path,
5962
trusted=["numpy.dtype"],
@@ -168,3 +171,83 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
168171
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
169172
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
170173
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)
174+
175+
176+
def test_unitrefine_label_units_hf(sorting_analyzer_for_curation):
177+
"""Test the `unitrefine_label_units` function."""
178+
sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
179+
sorting_analyzer_for_curation.compute("quality_metrics")
180+
181+
# test passing both classifiers
182+
labels = unitrefine_label_units(
183+
sorting_analyzer_for_curation,
184+
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
185+
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
186+
)
187+
188+
assert "label" in labels.columns
189+
assert "probability" in labels.columns
190+
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
191+
192+
# test only noise neural classifier
193+
labels = unitrefine_label_units(
194+
sorting_analyzer_for_curation,
195+
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
196+
sua_mua_classifier=None,
197+
)
198+
199+
assert "label" in labels.columns
200+
assert "probability" in labels.columns
201+
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
202+
203+
# test only sua mua classifier
204+
labels = unitrefine_label_units(
205+
sorting_analyzer_for_curation,
206+
noise_neural_classifier=None,
207+
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
208+
)
209+
210+
assert "label" in labels.columns
211+
assert "probability" in labels.columns
212+
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
213+
214+
# test passing none
215+
with pytest.raises(ValueError):
216+
labels = unitrefine_label_units(
217+
sorting_analyzer_for_curation,
218+
noise_neural_classifier=None,
219+
sua_mua_classifier=None,
220+
)
221+
222+
# test warnings when unexpected labels are returned
223+
with pytest.warns(UserWarning):
224+
labels = unitrefine_label_units(
225+
sorting_analyzer_for_curation,
226+
noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
227+
sua_mua_classifier=None,
228+
)
229+
230+
with pytest.warns(UserWarning):
231+
labels = unitrefine_label_units(
232+
sorting_analyzer_for_curation,
233+
noise_neural_classifier=None,
234+
sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
235+
)
236+
237+
238+
def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path):
239+
# test with trained local models
240+
sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
241+
sorting_analyzer_for_curation.compute("quality_metrics")
242+
243+
# test passing model folder
244+
labels = unitrefine_label_units(
245+
sorting_analyzer_for_curation,
246+
noise_neural_classifier=trained_pipeline_path,
247+
)
248+
249+
# test passing model folder
250+
labels = unitrefine_label_units(
251+
sorting_analyzer_for_curation,
252+
noise_neural_classifier=trained_pipeline_path / "best_model.skops",
253+
)

0 commit comments

Comments
 (0)