Commit 360af80

sharifhsn and claude committed:

PERF: Batch spectrogram calls in Welch PSD computation

Replace np.apply_along_axis (which calls scipy.signal.spectrogram once per row) with chunked 2D calls. scipy.signal.spectrogram handles multi-row input efficiently via vectorized FFT, so processing ~10 MB chunks instead of individual rows eliminates per-call Python dispatch overhead. On 320 epochs x 376 channels (120K rows), psd_array_welch goes from ~5.0 s to ~0.19 s (~26x speedup).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent 14d0916
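The claim in the commit message can be checked with a small sketch (not MNE code; sizes and the sampling rate are made up for illustration): scipy.signal.spectrogram accepts N-D input and transforms along the last axis, so one 2D call produces the same result as one Python-level call per row.

```python
import numpy as np
from scipy.signal import spectrogram

rng = np.random.default_rng(0)
data = rng.standard_normal((64, 2048))  # hypothetical: 64 rows of 2048 samples

# Per-row: one Python-level call per row (what np.apply_along_axis incurs).
per_row = np.stack([spectrogram(row, fs=1000.0, nperseg=256)[2] for row in data])

# Batched: a single call on the 2D array, transformed along the last axis.
f, t, batched = spectrogram(data, fs=1000.0, nperseg=256)

# Same numbers, far fewer Python-level dispatches.
assert np.allclose(per_row, batched)
```

The batched result has shape `(n_rows, len(f), len(t))`, i.e. the row axis is simply carried through, which is what makes the chunked rewrite in this commit a drop-in replacement.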

1 file changed: mne/time_frequency/psd.py (15 additions & 6 deletions)
@@ -62,14 +62,23 @@ def _decomp_aggregate_mask(epoch, func, average, freq_sl):


 def _spect_func(epoch, func, freq_sl, average, *, output="power"):
     """Aux function."""
-    # Decide if we should split this to save memory or not, since doing
-    # multiple calls will incur some performance overhead. Eventually we might
-    # want to write (really, go back to) our own spectrogram implementation
-    # that, if possible, averages after each transform, but this will incur
-    # a lot of overhead because of the many Python calls required.
+    # Process in chunks to balance vectorization (scipy.signal.spectrogram
+    # handles multi-row input efficiently) against memory usage.
     kwargs = dict(func=func, average=average, freq_sl=freq_sl)
     if epoch.nbytes > 10e6:
-        spect = np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **kwargs)
+        # Process in chunks of rows instead of one-by-one. Each chunk is
+        # passed to spectrogram as a 2D array, which is much faster than
+        # calling spectrogram per-row via np.apply_along_axis.
+        n_rows = epoch.shape[0]
+        # Target ~10 MB per chunk (same threshold as the original code)
+        row_bytes = epoch[0].nbytes
+        chunk_size = max(1, int(10e6 / row_bytes))
+        parts = []
+        for start in range(0, n_rows, chunk_size):
+            parts.append(
+                _decomp_aggregate_mask(epoch[start : start + chunk_size], **kwargs)
+            )
+        spect = np.concatenate(parts, axis=0)
     else:
         spect = _decomp_aggregate_mask(epoch, **kwargs)
     return spect
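As a self-contained illustration of the new chunked path, here is a runnable sketch of the same pattern with a stand-in reducer in place of _decomp_aggregate_mask (a hypothetical Welch-style mean over spectrogram segments; the array shape, sampling rate, and nperseg are assumptions, not values from MNE):

```python
import numpy as np
from scipy.signal import spectrogram

def psd_rows(x):
    # Stand-in for _decomp_aggregate_mask: mean power per frequency bin,
    # computed over spectrogram segments along the last axis.
    return spectrogram(x, fs=1000.0, nperseg=256)[2].mean(axis=-1)

epoch = np.random.default_rng(1).standard_normal((500, 4096))

# Target ~10 MB per chunk, as in the committed code.
chunk_size = max(1, int(10e6 / epoch[0].nbytes))
parts = [
    psd_rows(epoch[start : start + chunk_size])
    for start in range(0, len(epoch), chunk_size)
]
spect = np.concatenate(parts, axis=0)

# Chunking changes only how the work is batched, not the result.
assert np.allclose(spect, psd_rows(epoch))
```

Because each chunk is a 2D slice transformed in one vectorized call, the number of Python-level spectrogram calls drops from one per row to one per ~10 MB chunk, which is where the reported speedup comes from.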
