Skip to content

Commit 8a45645

Browse files
committed
Merge branch 'main' into improve-version-mismatch-warning
2 parents 415a0ec + 4c3a6f9 commit 8a45645

4 files changed

Lines changed: 100 additions & 23 deletions

File tree

src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,16 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
629629
pc_all[mask, ...] = pc_one
630630
ext.data["pca_projection"] = pc_all
631631

632-
# update params
633-
new_params = ext._set_params()
634-
updated_params = make_ext_params_up_to_date(ext, params, new_params)
635-
ext.set_params(**updated_params, save=False)
632+
# Install raw on-disk params and run compat handler first,
633+
# matching what AnalyzerExtension.load does for non-legacy folders.
634+
ext.params = dict(params)
636635
if ext.need_backward_compatibility_on_load:
637636
ext._handle_backward_compatibility_on_load()
637+
638+
# Now merge and validate — deprecated names are already migrated.
639+
new_params = ext._set_params()
640+
updated_params = make_ext_params_up_to_date(ext, ext.params, new_params)
641+
ext.set_params(**updated_params, save=False)
638642
ext.run_info = None
639643

640644
sorting_analyzer.extensions[new_name] = ext

src/spikeinterface/extractors/neoextractors/openephys.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ def __init__(
224224
experiment_names: str | list | None = None,
225225
all_annotations: bool = False,
226226
):
227+
folder_path = Path(folder_path)
228+
227229
# Handle experiment_names deprecation
228230
if experiment_names is not None:
229231
warnings.warn(
@@ -336,8 +338,15 @@ def __init__(
336338
if sample_shifts is not None:
337339
self.set_property("inter_sample_shift", sample_shifts)
338340

339-
# load synchronized timestamps and set_times to recording
340-
recording_folder = Path(folder_path) / record_node
341+
# folder_path can point to different levels of the OE folder structure
342+
# (root, record node, experiment, or recording). We need to find the root folder
343+
# in order to load the sync timestamps and set them as times to the recording.
344+
if record_node in folder_path.parts:
345+
root_index = len(folder_path.parts) - folder_path.parts.index(record_node) - 1
346+
root_folder = folder_path.parents[root_index]
347+
else:
348+
root_folder = folder_path
349+
recording_folder = root_folder / record_node
341350
stream_folders = []
342351
for segment_index, rec_id in enumerate(rec_ids):
343352
stream_folder = (

src/spikeinterface/extractors/tests/test_neoextractors.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,38 @@ def test_non_trivial_wiring(self):
156156
probe = recording.get_probe()
157157
np.testing.assert_array_equal(recording.channel_ids, probe.contact_annotations["settings_channel_key"])
158158

159+
def test_timestamp_loading_multi_level(self):
160+
"""
161+
Test that we can load the sync timestamps from different levels of the folder structure and
162+
that they are the same.
163+
"""
164+
recording_folder = (
165+
local_folder / "openephysbinary/v0.6.x_neuropixels_with_sync/Record Node 104/experiment1/recording1"
166+
)
167+
stream_name = "Record Node 104#Neuropix-PXI-100.ProbeA-AP"
168+
block_index = 0
169+
170+
recording_from_recording_folder = self.ExtractorClass(
171+
recording_folder,
172+
stream_name=stream_name,
173+
block_index=block_index,
174+
load_sync_timestamps=True,
175+
)
176+
assert recording_from_recording_folder.has_time_vector()
177+
timestamps_recording = recording_from_recording_folder.get_times()
178+
parent_folder = recording_folder
179+
for _ in range(3):
180+
parent_folder = parent_folder.parent
181+
recording_from_parent = self.ExtractorClass(
182+
parent_folder,
183+
stream_name=stream_name,
184+
block_index=block_index,
185+
load_sync_timestamps=True,
186+
)
187+
assert recording_from_parent.has_time_vector()
188+
timestamps_parent = recording_from_parent.get_times()
189+
np.testing.assert_array_equal(timestamps_recording, timestamps_parent)
190+
159191

160192
class OpenEphysBinaryEventTest(EventCommonTestSuite, unittest.TestCase):
161193
ExtractorClass = OpenEphysBinaryEventExtractor

src/spikeinterface/postprocessing/principal_component.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -629,31 +629,63 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
629629
if i0 == i1:
630630
return
631631

632+
# Since `get_traces` accounts for nbefore and nafter, all spikes in the chunk are valid and we can extract
633+
# all waveforms in one go without worrying about borders.
632634
start = int(spike_times[i0] - nbefore)
633635
end = int(spike_times[i1 - 1] + nafter)
634636
traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index)
635637

636-
for i in range(i0, i1):
637-
st = spike_times[i]
638-
if st - start - nbefore < 0:
639-
continue
640-
if st - start + nafter > traces.shape[0]:
641-
continue
638+
nsamples = nbefore + nafter
642639

643-
wf = traces[st - start - nbefore : st - start + nafter, :]
640+
# Extract all waveforms in the chunk at once
641+
spike_times_in_chunk = spike_times[i0:i1]
642+
# Offset spike times to be relative to the start of the traces buffer
643+
spike_times_offset = spike_times_in_chunk - start - nbefore
644+
spike_indices = np.arange(i0, i1)
644645

645-
unit_index = spike_labels[i]
646-
chan_inds = unit_channels[unit_index]
646+
# Build waveform array: (n_spikes, nsamples, n_channels)
647+
# Use fancy indexing to extract all snippets at once
648+
sample_indices = spike_times_offset[:, None] + np.arange(nsamples)[None, :] # (n_spikes, nsamples)
649+
all_wfs = traces[sample_indices] # (n_spikes, nsamples, n_channels)
647650

651+
# Vectorized PCA: batch by channel across all spikes in the chunk.
652+
# For each unique channel, find all spikes that use it (via their unit's
653+
# sparsity), extract waveforms, and call transform once.
654+
labels_in_chunk = spike_labels[spike_indices]
655+
656+
# Build a set of all channels used by spikes in this chunk
657+
unique_unit_indices = np.unique(labels_in_chunk)
658+
chan_info: dict[int, list[tuple[np.ndarray, int]]] = {}
659+
for unit_index in unique_unit_indices:
660+
chan_inds = unit_channels[unit_index]
661+
unit_mask = labels_in_chunk == unit_index
662+
unit_local_idxs = np.nonzero(unit_mask)[0]
648663
for c, chan_ind in enumerate(chan_inds):
649-
w = wf[:, chan_ind]
650-
if w.size > 0:
651-
w = w[None, :]
652-
try:
653-
all_pcs[i, :, c] = pca_model[chan_ind].transform(w)
654-
except:
655-
# this could happen if len(wfs) is less then n_comp for a channel
656-
pass
664+
if chan_ind not in chan_info:
665+
chan_info[chan_ind] = []
666+
chan_info[chan_ind].append((unit_local_idxs, c))
667+
668+
for chan_ind, unit_groups in chan_info.items():
669+
# Concatenate all spike indices for this channel across units
670+
all_local_idxs = np.concatenate([g[0] for g in unit_groups])
671+
global_idxs = spike_indices[all_local_idxs]
672+
673+
# Batch waveforms for this channel: (n_spikes, nsamples)
674+
wfs_batch = all_wfs[all_local_idxs, :, chan_ind]
675+
676+
if wfs_batch.size == 0:
677+
continue
678+
try:
679+
pcs_batch = pca_model[chan_ind].transform(wfs_batch)
680+
# Write results back — each unit group has a fixed channel position
681+
offset = 0
682+
for unit_local_idxs, c_pos in unit_groups:
683+
n = len(unit_local_idxs)
684+
all_pcs[global_idxs[offset : offset + n], :, c_pos] = pcs_batch[offset : offset + n]
685+
offset += n
686+
except Exception:
687+
# this could happen if len(wfs) is less than n_comp for a channel
688+
pass
657689

658690

659691
def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model):

0 commit comments

Comments
 (0)