@@ -629,31 +629,63 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
629629 if i0 == i1 :
630630 return
631631
632+ # Since `get_traces` accounts for nbefore and nafter, all spikes in the chunk are valid and we can extract
633+ # all waveforms in one go without worrying about borders.
632634 start = int (spike_times [i0 ] - nbefore )
633635 end = int (spike_times [i1 - 1 ] + nafter )
634636 traces = recording .get_traces (start_frame = start , end_frame = end , segment_index = segment_index )
635637
636- for i in range (i0 , i1 ):
637- st = spike_times [i ]
638- if st - start - nbefore < 0 :
639- continue
640- if st - start + nafter > traces .shape [0 ]:
641- continue
638+ nsamples = nbefore + nafter
642639
643- wf = traces [st - start - nbefore : st - start + nafter , :]
640+ # Extract all waveforms in the chunk at once
641+ spike_times_in_chunk = spike_times [i0 :i1 ]
642+ # Offset spike times to be relative to the start of the traces buffer
643+ spike_times_offset = spike_times_in_chunk - start - nbefore
644+ spike_indices = np .arange (i0 , i1 )
644645
645- unit_index = spike_labels [i ]
646- chan_inds = unit_channels [unit_index ]
646+ # Build waveform array: (n_spikes, nsamples, n_channels)
647+ # Use fancy indexing to extract all snippets at once
648+ sample_indices = spike_times_offset [:, None ] + np .arange (nsamples )[None , :] # (n_spikes, nsamples)
649+ all_wfs = traces [sample_indices ] # (n_spikes, nsamples, n_channels)
647650
651+ # Vectorized PCA: batch by channel across all spikes in the chunk.
652+ # For each unique channel, find all spikes that use it (via their unit's
653+ # sparsity), extract waveforms, and call transform once.
654+ labels_in_chunk = spike_labels [spike_indices ]
655+
656+ # Build a set of all channels used by spikes in this chunk
657+ unique_unit_indices = np .unique (labels_in_chunk )
658+ chan_info : dict [int , list [tuple [np .ndarray , int ]]] = {}
659+ for unit_index in unique_unit_indices :
660+ chan_inds = unit_channels [unit_index ]
661+ unit_mask = labels_in_chunk == unit_index
662+ unit_local_idxs = np .nonzero (unit_mask )[0 ]
648663 for c , chan_ind in enumerate (chan_inds ):
649- w = wf [:, chan_ind ]
650- if w .size > 0 :
651- w = w [None , :]
652- try :
653- all_pcs [i , :, c ] = pca_model [chan_ind ].transform (w )
654- except :
655- # this could happen if len(wfs) is less then n_comp for a channel
656- pass
664+ if chan_ind not in chan_info :
665+ chan_info [chan_ind ] = []
666+ chan_info [chan_ind ].append ((unit_local_idxs , c ))
667+
668+ for chan_ind , unit_groups in chan_info .items ():
669+ # Concatenate all spike indices for this channel across units
670+ all_local_idxs = np .concatenate ([g [0 ] for g in unit_groups ])
671+ global_idxs = spike_indices [all_local_idxs ]
672+
673+ # Batch waveforms for this channel: (n_spikes, nsamples)
674+ wfs_batch = all_wfs [all_local_idxs , :, chan_ind ]
675+
676+ if wfs_batch .size == 0 :
677+ continue
678+ try :
679+ pcs_batch = pca_model [chan_ind ].transform (wfs_batch )
680+ # Write results back — each unit group has a fixed channel position
681+ offset = 0
682+ for unit_local_idxs , c_pos in unit_groups :
683+ n = len (unit_local_idxs )
684+ all_pcs [global_idxs [offset : offset + n ], :, c_pos ] = pcs_batch [offset : offset + n ]
685+ offset += n
686+ except Exception :
687+ # this could happen if len(wfs) is less than n_comp for a channel
688+ pass
657689
658690
659691def _init_work_all_pc_extractor (recording , sorting , all_pcs_args , nbefore , nafter , unit_channels , pca_model ):
0 commit comments