Skip to content

Commit 679cb2e

Browse files
Ensure compute unit locations respects order of unit ids (#4329)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7b93c50 commit 679cb2e

2 files changed

Lines changed: 23 additions & 10 deletions

File tree

src/spikeinterface/postprocessing/localization_tools.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def compute_monopolar_triangulation(
9090
unit_ids = sorting_analyzer_or_templates.unit_ids
9191
else:
9292
unit_ids = np.asanyarray(unit_ids)
93-
keep = np.isin(sorting_analyzer_or_templates.unit_ids, unit_ids)
94-
templates = templates[keep, :, :]
93+
keep = [np.where(sorting_analyzer_or_templates.unit_ids == unit_id)[0][0] for unit_id in unit_ids]
94+
templates = templates[np.array(keep), :, :]
9595

9696
if enforce_decrease:
9797
neighbours_mask = np.zeros((templates.shape[0], templates.shape[2]), dtype=bool)
@@ -175,11 +175,11 @@ def compute_center_of_mass(
175175
if unit_ids is None:
176176
unit_ids = sorting_analyzer_or_templates.unit_ids
177177
else:
178-
unit_ids = np.asanyarray(unit_ids)
179-
keep = np.isin(sorting_analyzer_or_templates.unit_ids, unit_ids)
180-
templates = templates[keep, :, :]
178+
all_unit_ids = list(sorting_analyzer_or_templates.unit_ids)
179+
keep_unit_indices = np.array([all_unit_ids.index(unit_id) for unit_id in unit_ids])
180+
templates = templates[keep_unit_indices, :, :]
181181

182-
unit_location = np.zeros((unit_ids.size, 2), dtype="float64")
182+
unit_location = np.zeros((len(unit_ids), 2), dtype="float64")
183183
for i, unit_id in enumerate(unit_ids):
184184
chan_inds = sparsity.unit_id_to_channel_indices[unit_id]
185185
local_contact_locations = contact_locations[chan_inds, :]
@@ -258,9 +258,9 @@ def compute_grid_convolution(
258258
if unit_ids is None:
259259
unit_ids = sorting_analyzer_or_templates.unit_ids
260260
else:
261-
unit_ids = np.asanyarray(unit_ids)
262-
keep = np.isin(sorting_analyzer_or_templates.unit_ids, unit_ids)
263-
templates = templates[keep, :, :]
261+
all_unit_ids = list(sorting_analyzer_or_templates.unit_ids)
262+
keep_unit_indices = np.array([all_unit_ids.index(unit_id) for unit_id in unit_ids])
263+
templates = templates[keep_unit_indices, :, :]
264264

265265
fs = sorting_analyzer_or_templates.sampling_frequency
266266
percentile = 100 - percentile
@@ -283,7 +283,7 @@ def compute_grid_convolution(
283283
weights_sparsity_mask = weights > 0
284284

285285
nb_weights = weights.shape[0]
286-
unit_location = np.zeros((unit_ids.size, 3), dtype="float64")
286+
unit_location = np.zeros((len(unit_ids), 3), dtype="float64")
287287

288288
for i, unit_id in enumerate(unit_ids):
289289
main_chan = peak_channels[unit_id]

src/spikeinterface/postprocessing/tests/common_extension_tests.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer
99
from spikeinterface.core import estimate_sparsity
1010

11+
extensions_which_allow_unit_ids = ["unit_locations"]
12+
1113

1214
def get_dataset():
1315
recording, sorting = generate_ground_truth_recording(
@@ -140,6 +142,17 @@ def _check_one(self, sorting_analyzer, extension_class, params):
140142
merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0)
141143
assert len(merged.unit_ids) == num_units_after_merge
142144

145+
# Test that order of units doesn't change things
146+
if extension_class.extension_name in extensions_which_allow_unit_ids:
147+
reversed_unit_ids = some_unit_ids[::-1]
148+
sliced_reversed = sorting_analyzer.select_units(reversed_unit_ids, format="memory")
149+
ext = sorting_analyzer.compute(
150+
extension_class.extension_name, unit_ids=reversed_unit_ids, **params, **job_kwargs
151+
)
152+
recomputed_data = ext.get_data()
153+
sliced_data = sliced_reversed.get_extension(extension_class.extension_name).get_data()
154+
np.testing.assert_allclose(recomputed_data, sliced_data)
155+
143156
# test roundtrip
144157
if sorting_analyzer.format in ("binary_folder", "zarr"):
145158
sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder)

0 commit comments

Comments
 (0)