Skip to content

Commit c45d7b0

Browse files
authored
ENH: allow per-channel colors in raw.plot via channel name dict keys (#13765)
1 parent 7e172ad commit c45d7b0

4 files changed

Lines changed: 38 additions & 2 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow per-channel color overrides in :func:`mne.viz.plot_raw` via channel name keys in the ``color`` dict, by :newcontrib:`Hansuja Budhiraja`.

mne/viz/_mpl_figure.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,8 +2006,13 @@ def _draw_traces(self):
20062006
)
20072007
offsets = self.mne.trace_offsets[offset_ixs]
20082008
bad_bool = np.isin(ch_names, self.mne.info["bads"])
2009-
# colors
2010-
good_ch_colors = [self.mne.ch_color_dict[_type] for _type in ch_types]
2009+
# colors: allow overrides by channel name, then by channel type
2010+
good_ch_colors = []
2011+
for _name, _type in zip(ch_names, ch_types):
2012+
if _name in self.mne.ch_color_dict:
2013+
good_ch_colors.append(self.mne.ch_color_dict[_name])
2014+
else:
2015+
good_ch_colors.append(self.mne.ch_color_dict[_type])
20112016
ch_colors = to_rgba_array(
20122017
[
20132018
self.mne.ch_color_bad if _bad else _color

mne/viz/raw.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ def plot_raw(
9696
emg='k', ref_meg='steelblue', misc='k', stim='k',
9797
resp='k', chpi='k')
9898
99+
If a dict, keys can be channel *types* (e.g., ``'eeg'``) and/or
100+
channel *names* (e.g., ``'SFG, Left'``); name-based entries
101+
take precedence over type-based ones.
102+
99103
bad_color : color object
100104
Color to make bad channels.
101105
%(event_color)s

mne/viz/tests/test_raw.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from copy import deepcopy
88
from pathlib import Path
99

10+
import matplotlib.colors as mcolors
1011
import matplotlib.pyplot as plt
1112
import numpy as np
1213
import pytest
@@ -317,6 +318,31 @@ def test_scale_bar(browser_backend):
317318
bar_lims = bar.get_ydata()
318319
assert_allclose(y_lims, bar_lims, atol=1e-4)
319320

321+
# Per-channel color overrides via channel names (matplotlib only).
322+
if ismpl:
323+
sfreq = 100.0
324+
ch_names = ["SFG, Left", "SFG, Right", "MFG, Left"]
325+
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
326+
data = np.zeros((len(ch_names), int(sfreq))) # 1 second of zeros
327+
raw2 = RawArray(data, info)
328+
329+
color = {"eeg": "k", "SFG, Left": "red"}
330+
browser_backend._close_all()
331+
fig2 = plot_raw(raw2, color=color, show=False)
332+
333+
# ch_colors stores the "good" (non-bad) colors, in visible channel order
334+
assert fig2.mne.ch_colors[0] == "red"
335+
assert fig2.mne.ch_colors[1] == "k"
336+
assert fig2.mne.ch_colors[2] == "k"
337+
338+
# check colours on the plot are also correct
339+
for trace, ch_color in zip(fig2.mne.traces, fig2.mne.ch_colors):
340+
assert np.allclose(
341+
mcolors.to_rgba(trace.get_color()), mcolors.to_rgba(ch_color)
342+
), f"Expected {ch_color}, got {trace.get_color()}"
343+
344+
browser_backend._close_all()
345+
320346

321347
def test_plot_raw_selection(raw, browser_backend):
322348
"""Test selection mode of plot_raw()."""

0 commit comments

Comments
 (0)