@@ -90,8 +90,8 @@ def compute_monopolar_triangulation(
9090 unit_ids = sorting_analyzer_or_templates .unit_ids
9191 else :
9292 unit_ids = np .asanyarray (unit_ids )
93- keep = np .isin (sorting_analyzer_or_templates .unit_ids , unit_ids )
94- templates = templates [keep , :, :]
93+ keep = [ np .where (sorting_analyzer_or_templates .unit_ids == unit_id )[ 0 ][ 0 ] for unit_id in unit_ids ]
94+ templates = templates [np . array ( keep ) , :, :]
9595
9696 if enforce_decrease :
9797 neighbours_mask = np .zeros ((templates .shape [0 ], templates .shape [2 ]), dtype = bool )
@@ -175,11 +175,11 @@ def compute_center_of_mass(
175175 if unit_ids is None :
176176 unit_ids = sorting_analyzer_or_templates .unit_ids
177177 else :
178- unit_ids = np . asanyarray ( unit_ids )
179- keep = np .isin ( sorting_analyzer_or_templates . unit_ids , unit_ids )
180- templates = templates [keep , :, :]
178+ all_unit_ids = list ( sorting_analyzer_or_templates . unit_ids )
179+ keep_unit_indices = np .array ([ all_unit_ids . index ( unit_id ) for unit_id in unit_ids ] )
180+ templates = templates [keep_unit_indices , :, :]
181181
182- unit_location = np .zeros ((unit_ids . size , 2 ), dtype = "float64" )
182+ unit_location = np .zeros ((len ( unit_ids ) , 2 ), dtype = "float64" )
183183 for i , unit_id in enumerate (unit_ids ):
184184 chan_inds = sparsity .unit_id_to_channel_indices [unit_id ]
185185 local_contact_locations = contact_locations [chan_inds , :]
@@ -258,9 +258,9 @@ def compute_grid_convolution(
258258 if unit_ids is None :
259259 unit_ids = sorting_analyzer_or_templates .unit_ids
260260 else :
261- unit_ids = np . asanyarray ( unit_ids )
262- keep = np .isin ( sorting_analyzer_or_templates . unit_ids , unit_ids )
263- templates = templates [keep , :, :]
261+ all_unit_ids = list ( sorting_analyzer_or_templates . unit_ids )
262+ keep_unit_indices = np .array ([ all_unit_ids . index ( unit_id ) for unit_id in unit_ids ] )
263+ templates = templates [keep_unit_indices , :, :]
264264
265265 fs = sorting_analyzer_or_templates .sampling_frequency
266266 percentile = 100 - percentile
@@ -283,7 +283,7 @@ def compute_grid_convolution(
283283 weights_sparsity_mask = weights > 0
284284
285285 nb_weights = weights .shape [0 ]
286- unit_location = np .zeros ((unit_ids . size , 3 ), dtype = "float64" )
286+ unit_location = np .zeros ((len ( unit_ids ) , 3 ), dtype = "float64" )
287287
288288 for i , unit_id in enumerate (unit_ids ):
289289 main_chan = peak_channels [unit_id ]
0 commit comments