2323 _picks_by_type ,
2424 pick_channels ,
2525 pick_info ,
26- pick_types ,
2726)
2827from .._fiff .proj import make_projector
2928from .._freesurfer import _check_mri , _mri_orientation , _read_mri_info , _reorient_image
5352)
5453
5554
56- def _index_info_cov (info , cov , exclude ):
57- if exclude == "bads" :
58- exclude = info ["bads" ]
59- info = pick_info (info , pick_channels (info ["ch_names" ], cov ["names" ], exclude ))
60- del exclude
55+ def _get_ch_type_metadata (info , ch_names ):
56+ """Get indices, titles, units, scalings, and types for plottable channel types."""
57+ info_ch_names = info ["ch_names" ]
6158 picks_list = _picks_by_type (info , meg_combined = False , ref_meg = False , exclude = ())
6259 picks_by_type = dict (picks_list )
6360
64- ch_names = [n for n in cov .ch_names if n in info ["ch_names" ]]
65- ch_idx = [cov .ch_names .index (n ) for n in ch_names ]
66-
67- info_ch_names = info ["ch_names" ]
6861 idx_by_type = defaultdict (list )
6962 for ch_type , sel in picks_by_type .items ():
7063 idx_by_type [ch_type ] = [
7164 ch_names .index (info_ch_names [c ])
7265 for c in sel
7366 if info_ch_names [c ] in ch_names
7467 ]
68+
69+ indices = []
70+ titles = []
71+ units = []
72+ scalings = []
73+ ch_types = []
74+ for key in _DATA_CH_TYPES_SPLIT :
75+ if len (idx_by_type [key ]) > 0 :
76+ indices .append (idx_by_type [key ])
77+ titles .append (DEFAULTS ["titles" ][key ])
78+ units .append (DEFAULTS ["units" ][key ])
79+ scalings .append (DEFAULTS ["scalings" ][key ])
80+ ch_types .append (key )
81+ if len (indices ) == 0 :
82+ raise RuntimeError (
83+ "No plottable channel types found. "
84+ f"Allowed types are: { _DATA_CH_TYPES_SPLIT } "
85+ )
86+ return indices , titles , units , scalings , ch_types
87+
88+
89+ def _index_info_cov (info , cov , exclude ):
90+ """Pick cov data and get metadata for present, plottable data channel types."""
91+ if exclude == "bads" :
92+ exclude = info ["bads" ]
93+ info = pick_info (info , pick_channels (info ["ch_names" ], cov ["names" ], exclude ))
94+ del exclude
95+
96+ ch_names = [n for n in cov .ch_names if n in info ["ch_names" ]]
97+ ch_idx = [cov .ch_names .index (n ) for n in ch_names ]
98+
99+ indices , titles , units , scalings , ch_types = _get_ch_type_metadata (info , ch_names )
75100 idx_names = [
76- (
77- idx_by_type [key ],
78- f"{ DEFAULTS ['titles' ][key ]} covariance" ,
79- DEFAULTS ["units" ][key ],
80- DEFAULTS ["scalings" ][key ],
81- key ,
101+ (idx , f"{ title } covariance" , unit , scaling , key )
102+ for idx , title , unit , scaling , key in zip (
103+ indices , titles , units , scalings , ch_types
82104 )
83- for key in _DATA_CH_TYPES_SPLIT
84- if len (idx_by_type [key ]) > 0
85105 ]
86106 C = cov .data [ch_idx ][:, ch_idx ]
87107 return info , C , ch_names , idx_names
@@ -1483,39 +1503,16 @@ def plot_csd(
14831503 raise ValueError ('"mode" should be either "csd" or "coh".' )
14841504
14851505 if info is not None :
1486- info_ch_names = info ["ch_names" ]
1487- sel_eeg = pick_types (info , meg = False , eeg = True , ref_meg = False , exclude = [])
1488- sel_mag = pick_types (info , meg = "mag" , eeg = False , ref_meg = False , exclude = [])
1489- sel_grad = pick_types (info , meg = "grad" , eeg = False , ref_meg = False , exclude = [])
1490- idx_eeg = [
1491- csd .ch_names .index (info_ch_names [c ])
1492- for c in sel_eeg
1493- if info_ch_names [c ] in csd .ch_names
1494- ]
1495- idx_mag = [
1496- csd .ch_names .index (info_ch_names [c ])
1497- for c in sel_mag
1498- if info_ch_names [c ] in csd .ch_names
1499- ]
1500- idx_grad = [
1501- csd .ch_names .index (info_ch_names [c ])
1502- for c in sel_grad
1503- if info_ch_names [c ] in csd .ch_names
1504- ]
1505- indices = [idx_eeg , idx_mag , idx_grad ]
1506- titles = ["EEG" , "Magnetometers" , "Gradiometers" ]
1507-
1508- if mode == "csd" :
1509- # The units in which to plot the CSD
1510- units = dict (eeg = "µV²" , grad = "fT²/cm²" , mag = "fT²" )
1511- scalings = dict (eeg = 1e12 , grad = 1e26 , mag = 1e30 )
1506+ indices , titles , units , scalings , ch_types = _get_ch_type_metadata (
1507+ info , csd .ch_names
1508+ )
15121509 else :
15131510 indices = [np .arange (len (csd .ch_names ))]
1511+ units = ["" ]
1512+ scalings = [1 ]
1513+ ch_types = [None ]
15141514 if mode == "csd" :
15151515 titles = ["Cross-spectral density" ]
1516- # Units and scaling unknown
1517- units = dict ()
1518- scalings = dict ()
15191516 elif mode == "coh" :
15201517 titles = ["Coherence" ]
15211518
@@ -1526,10 +1523,9 @@ def plot_csd(
15261523 n_rows = int (np .ceil (n_freqs / float (n_cols )))
15271524
15281525 figs = []
1529- for ind , title , ch_type in zip (indices , titles , ["eeg" , "mag" , "grad" ]):
1530- if len (ind ) == 0 :
1531- continue
1532-
1526+ for ind , title , unit , scaling , ch_type in zip (
1527+ indices , titles , units , scalings , ch_types
1528+ ):
15331529 fig , axes = plt .subplots (
15341530 n_rows ,
15351531 n_cols ,
@@ -1542,7 +1538,7 @@ def plot_csd(
15421538 for i in range (len (csd .frequencies )):
15431539 cm = csd .get_data (index = i )[ind ][:, ind ]
15441540 if mode == "csd" :
1545- cm = np .abs (cm ) * scalings . get ( ch_type , 1 )
1541+ cm = np .abs (cm ) * scaling ** 2
15461542 elif mode == "coh" :
15471543 # Compute coherence from the CSD matrix
15481544 psd = np .diag (cm ).real
@@ -1566,8 +1562,10 @@ def plot_csd(
15661562 cb = plt .colorbar (im , ax = [a for ax_ in axes for a in ax_ ])
15671563 if mode == "csd" :
15681564 label = "CSD"
1569- if ch_type in units :
1570- label += f" ({ units [ch_type ]} )"
1565+ if ch_type is not None :
1566+ if "/" in unit :
1567+ unit = f"({ unit } )"
1568+ label += f" ({ unit } ²)"
15711569 cb .set_label (label )
15721570 elif mode == "coh" :
15731571 cb .set_label ("Coherence" )
0 commit comments