Skip to content

Commit 6d49f69

Browse files
Aniketsytsbinnswmvanvliet
authored
Extend plot_csd support to SEEG, ECoG, and DBS channel types (#13713)
Co-authored-by: Thomas S. Binns <t.s.binns@outlook.com> Co-authored-by: Marijn van Vliet <w.m.vanvliet@gmail.com>
1 parent 39ee092 commit 6d49f69

3 files changed

Lines changed: 98 additions & 59 deletions

File tree

doc/changes/dev/13713.other.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Extended support for visualising cross-spectral densities in :func:`mne.viz.plot_csd` and :meth:`mne.time_frequency.CrossSpectralDensity.plot` to all types of :term:`data channels`, by :newcontrib:`Aniket Singh Yadav`.

mne/viz/misc.py

Lines changed: 52 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
_picks_by_type,
2424
pick_channels,
2525
pick_info,
26-
pick_types,
2726
)
2827
from .._fiff.proj import make_projector
2928
from .._freesurfer import _check_mri, _mri_orientation, _read_mri_info, _reorient_image
@@ -53,35 +52,56 @@
5352
)
5453

5554

56-
def _index_info_cov(info, cov, exclude):
57-
if exclude == "bads":
58-
exclude = info["bads"]
59-
info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude))
60-
del exclude
55+
def _get_ch_type_metadata(info, ch_names):
56+
"""Get indices, titles, units, scalings, and types for plottable channel types."""
57+
info_ch_names = info["ch_names"]
6158
picks_list = _picks_by_type(info, meg_combined=False, ref_meg=False, exclude=())
6259
picks_by_type = dict(picks_list)
6360

64-
ch_names = [n for n in cov.ch_names if n in info["ch_names"]]
65-
ch_idx = [cov.ch_names.index(n) for n in ch_names]
66-
67-
info_ch_names = info["ch_names"]
6861
idx_by_type = defaultdict(list)
6962
for ch_type, sel in picks_by_type.items():
7063
idx_by_type[ch_type] = [
7164
ch_names.index(info_ch_names[c])
7265
for c in sel
7366
if info_ch_names[c] in ch_names
7467
]
68+
69+
indices = []
70+
titles = []
71+
units = []
72+
scalings = []
73+
ch_types = []
74+
for key in _DATA_CH_TYPES_SPLIT:
75+
if len(idx_by_type[key]) > 0:
76+
indices.append(idx_by_type[key])
77+
titles.append(DEFAULTS["titles"][key])
78+
units.append(DEFAULTS["units"][key])
79+
scalings.append(DEFAULTS["scalings"][key])
80+
ch_types.append(key)
81+
if len(indices) == 0:
82+
raise RuntimeError(
83+
"No plottable channel types found. "
84+
f"Allowed types are: {_DATA_CH_TYPES_SPLIT}"
85+
)
86+
return indices, titles, units, scalings, ch_types
87+
88+
89+
def _index_info_cov(info, cov, exclude):
90+
"""Pick cov data and get metadata for present, plottable data channel types."""
91+
if exclude == "bads":
92+
exclude = info["bads"]
93+
info = pick_info(info, pick_channels(info["ch_names"], cov["names"], exclude))
94+
del exclude
95+
96+
ch_names = [n for n in cov.ch_names if n in info["ch_names"]]
97+
ch_idx = [cov.ch_names.index(n) for n in ch_names]
98+
99+
indices, titles, units, scalings, ch_types = _get_ch_type_metadata(info, ch_names)
75100
idx_names = [
76-
(
77-
idx_by_type[key],
78-
f"{DEFAULTS['titles'][key]} covariance",
79-
DEFAULTS["units"][key],
80-
DEFAULTS["scalings"][key],
81-
key,
101+
(idx, f"{title} covariance", unit, scaling, key)
102+
for idx, title, unit, scaling, key in zip(
103+
indices, titles, units, scalings, ch_types
82104
)
83-
for key in _DATA_CH_TYPES_SPLIT
84-
if len(idx_by_type[key]) > 0
85105
]
86106
C = cov.data[ch_idx][:, ch_idx]
87107
return info, C, ch_names, idx_names
@@ -1483,39 +1503,16 @@ def plot_csd(
14831503
raise ValueError('"mode" should be either "csd" or "coh".')
14841504

14851505
if info is not None:
1486-
info_ch_names = info["ch_names"]
1487-
sel_eeg = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude=[])
1488-
sel_mag = pick_types(info, meg="mag", eeg=False, ref_meg=False, exclude=[])
1489-
sel_grad = pick_types(info, meg="grad", eeg=False, ref_meg=False, exclude=[])
1490-
idx_eeg = [
1491-
csd.ch_names.index(info_ch_names[c])
1492-
for c in sel_eeg
1493-
if info_ch_names[c] in csd.ch_names
1494-
]
1495-
idx_mag = [
1496-
csd.ch_names.index(info_ch_names[c])
1497-
for c in sel_mag
1498-
if info_ch_names[c] in csd.ch_names
1499-
]
1500-
idx_grad = [
1501-
csd.ch_names.index(info_ch_names[c])
1502-
for c in sel_grad
1503-
if info_ch_names[c] in csd.ch_names
1504-
]
1505-
indices = [idx_eeg, idx_mag, idx_grad]
1506-
titles = ["EEG", "Magnetometers", "Gradiometers"]
1507-
1508-
if mode == "csd":
1509-
# The units in which to plot the CSD
1510-
units = dict(eeg="µV²", grad="fT²/cm²", mag="fT²")
1511-
scalings = dict(eeg=1e12, grad=1e26, mag=1e30)
1506+
indices, titles, units, scalings, ch_types = _get_ch_type_metadata(
1507+
info, csd.ch_names
1508+
)
15121509
else:
15131510
indices = [np.arange(len(csd.ch_names))]
1511+
units = [""]
1512+
scalings = [1]
1513+
ch_types = [None]
15141514
if mode == "csd":
15151515
titles = ["Cross-spectral density"]
1516-
# Units and scaling unknown
1517-
units = dict()
1518-
scalings = dict()
15191516
elif mode == "coh":
15201517
titles = ["Coherence"]
15211518

@@ -1526,10 +1523,9 @@ def plot_csd(
15261523
n_rows = int(np.ceil(n_freqs / float(n_cols)))
15271524

15281525
figs = []
1529-
for ind, title, ch_type in zip(indices, titles, ["eeg", "mag", "grad"]):
1530-
if len(ind) == 0:
1531-
continue
1532-
1526+
for ind, title, unit, scaling, ch_type in zip(
1527+
indices, titles, units, scalings, ch_types
1528+
):
15331529
fig, axes = plt.subplots(
15341530
n_rows,
15351531
n_cols,
@@ -1542,7 +1538,7 @@ def plot_csd(
15421538
for i in range(len(csd.frequencies)):
15431539
cm = csd.get_data(index=i)[ind][:, ind]
15441540
if mode == "csd":
1545-
cm = np.abs(cm) * scalings.get(ch_type, 1)
1541+
cm = np.abs(cm) * scaling**2
15461542
elif mode == "coh":
15471543
# Compute coherence from the CSD matrix
15481544
psd = np.diag(cm).real
@@ -1566,8 +1562,10 @@ def plot_csd(
15661562
cb = plt.colorbar(im, ax=[a for ax_ in axes for a in ax_])
15671563
if mode == "csd":
15681564
label = "CSD"
1569-
if ch_type in units:
1570-
label += f" ({units[ch_type]})"
1565+
if ch_type is not None:
1566+
if "/" in unit:
1567+
unit = f"({unit})"
1568+
label += f" ({unit}²)"
15711569
cb.set_label(label)
15721570
elif mode == "coh":
15731571
cb.set_label("Coherence")

mne/viz/tests/test_misc.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from mne import (
1313
SourceEstimate,
14+
create_info,
1415
pick_events,
1516
read_cov,
1617
read_dipole,
@@ -135,6 +136,12 @@ def test_plot_cov():
135136
cov["data"] = cov.data * (1 + 1j)
136137
fig1, fig2 = cov.plot(raw.info)
137138

139+
# Test no plottable channel types caught
140+
raw.del_proj("all") # allow us to change channel types
141+
raw.set_channel_types({ch: "misc" for ch in raw.ch_names}, on_unit_change="ignore")
142+
with pytest.raises(RuntimeError, match="No plottable channel types found"):
143+
cov.plot(raw.info, exclude=raw.ch_names[6:])
144+
138145

139146
@testing.requires_testing_data
140147
def test_plot_bem():
@@ -312,18 +319,51 @@ def test_plot_dipole_amplitudes():
312319
dipoles.plot_amplitudes(show=False)
313320

314321

315-
def test_plot_csd():
322+
@pytest.mark.parametrize(
323+
"ch_types",
324+
[
325+
None,
326+
"eeg",
327+
"ecog",
328+
["grad", "dbs"],
329+
"misc",
330+
],
331+
)
332+
def test_plot_csd(ch_types):
316333
"""Test plotting of CSD matrices."""
334+
if isinstance(ch_types, list):
335+
n_ch_types = len(ch_types)
336+
n_ch = 2 * n_ch_types
337+
ch_types = np.repeat(ch_types, 2).tolist()
338+
else:
339+
n_ch = 2
340+
ch_names = [f"CH{i + 1}" for i in range(n_ch)]
341+
n_data = n_ch * (n_ch + 1) // 2
342+
317343
csd = CrossSpectralDensity(
318-
[1, 2, 3],
319-
["CH1", "CH2"],
344+
np.arange(1, n_data + 1),
345+
ch_names,
320346
frequencies=[(10, 20)],
321347
n_fft=1,
322348
tmin=0,
323349
tmax=1,
324350
)
325-
plot_csd(csd, mode="csd") # Plot cross-spectral density
326-
plot_csd(csd, mode="coh") # Plot coherence
351+
352+
if ch_types is None:
353+
info = None
354+
expected_n_figs = 1
355+
else:
356+
info = create_info(ch_names, sfreq=1.0, ch_types=ch_types)
357+
unique_types = set(ch_types) if isinstance(ch_types, list) else {ch_types}
358+
expected_n_figs = len(unique_types)
359+
360+
for mode in ("csd", "coh"):
361+
if ch_types == "misc":
362+
with pytest.raises(RuntimeError, match="No plottable channel types"):
363+
plot_csd(csd, info=info, mode=mode, show=False)
364+
else:
365+
figs = plot_csd(csd, info=info, mode=mode, show=False)
366+
assert len(figs) == expected_n_figs
327367

328368

329369
@pytest.mark.slowtest # Slow on Azure

0 commit comments

Comments
 (0)