Skip to content

Commit 39ee092

Browse files
ENH: compute CSD directly for upper-triangle channel pairs in Fourier/multitaper (#13719)
1 parent b3524b6 commit 39ee092

2 files changed

Lines changed: 18 additions & 29 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve memory usage and runtime of :func:`mne.time_frequency.csd_fourier`, :func:`mne.time_frequency.csd_multitaper`, :func:`mne.time_frequency.csd_array_fourier`, and :func:`mne.time_frequency.csd_array_multitaper` by avoiding unnecessary full-matrix CSD construction, by `Pragnya Khandelwal`_.

mne/time_frequency/csd.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,22 +1406,14 @@ def _csd_fourier(X, sfreq, n_times, freq_mask, n_fft):
14061406
x_mt, _ = _mt_spectra(X, np.hanning(n_times), sfreq, n_fft)
14071407

14081408
# Hack so we can sum over axis=-2
1409-
weights = np.array([1.0])[:, np.newaxis, np.newaxis, np.newaxis]
1409+
weights = np.array([1.0])[np.newaxis, :, np.newaxis]
14101410

14111411
x_mt = x_mt[:, :, freq_mask]
14121412

1413-
# Calculating CSD
1414-
# Tiling x_mt so that we can easily use _csd_from_mt()
1415-
x_mt = x_mt[:, np.newaxis, :, :]
1416-
x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1])
1417-
y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3])
1418-
weights_y = np.transpose(weights, axes=[1, 0, 2, 3])
1419-
csds = _csd_from_mt(x_mt, y_mt, weights, weights_y)
1420-
1421-
# FIXME: don't compute full matrix in the first place
1422-
csds = np.array(
1423-
[_sym_mat_to_vector(csds[:, :, i]) for i in range(csds.shape[-1])]
1424-
).T
1413+
# Calculate CSD for upper-triangle channel pairs directly.
1414+
# This avoids computing/storing the full channel x channel matrix.
1415+
ii, jj = np.triu_indices(x_mt.shape[0])
1416+
csds = _csd_from_mt(x_mt[ii], x_mt[jj], weights, weights)
14251417

14261418
# Scaling by number of samples and compensating for loss of power
14271419
# due to windowing (see section 11.5.2 in Bendat & Piersol).
@@ -1445,27 +1437,23 @@ def _csd_multitaper(
14451437
_, weights = _psd_from_mt_adaptive(
14461438
x_mt, eigvals, freq_mask, max_iter, return_weights=True
14471439
)
1448-
# Tiling weights so that we can easily use _csd_from_mt()
1449-
weights = weights[:, np.newaxis, :, :]
1450-
weights = np.tile(weights, [1, x_mt.shape[0], 1, 1])
14511440
else:
14521441
# Do not use adaptive weights
1453-
weights = np.sqrt(eigvals)[np.newaxis, np.newaxis, :, np.newaxis]
1442+
weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis]
14541443

14551444
x_mt = x_mt[:, :, freq_mask]
14561445

1457-
# Calculating CSD
1458-
# Tiling x_mt so that we can easily use _csd_from_mt()
1459-
x_mt = x_mt[:, np.newaxis, :, :]
1460-
x_mt = np.tile(x_mt, [1, x_mt.shape[0], 1, 1])
1461-
y_mt = np.transpose(x_mt, axes=[1, 0, 2, 3])
1462-
weights_y = np.transpose(weights, axes=[1, 0, 2, 3])
1463-
csds = _csd_from_mt(x_mt, y_mt, weights, weights_y)
1464-
1465-
# FIXME: don't compute full matrix in the first place
1466-
csds = np.array(
1467-
[_sym_mat_to_vector(csds[:, :, i]) for i in range(csds.shape[-1])]
1468-
).T
1446+
# Calculate CSD for upper-triangle channel pairs directly.
1447+
# This avoids computing/storing the full channel x channel matrix.
1448+
ii, jj = np.triu_indices(x_mt.shape[0])
1449+
x_mt_i = x_mt[ii]
1450+
x_mt_j = x_mt[jj]
1451+
if adaptive:
1452+
weights_i = weights[ii]
1453+
weights_j = weights[jj]
1454+
else:
1455+
weights_i = weights_j = weights
1456+
csds = _csd_from_mt(x_mt_i, x_mt_j, weights_i, weights_j)
14691457

14701458
# Scaling by sampling frequency for compatibility with Matlab
14711459
csds /= sfreq

0 commit comments

Comments
 (0)