Skip to content

Commit c530c20

Browse files
Raising meaningful warnings/errors for interpolate_bads, when supplie… (#13518)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 781eb3f commit c530c20

3 files changed

Lines changed: 55 additions & 0 deletions

File tree

doc/changes/dev/13518.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make :meth:`~mne.io.Raw.interpolate_bads` method flexible (ignore, warn, raise) about how to handle interpolation of channels with invalid positions, by `Himanshu Mahor`_.

mne/channels/channels.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ def interpolate_bads(
829829
origin="auto",
830830
method=None,
831831
exclude=(),
832+
on_bad_position="warn",
832833
verbose=None,
833834
):
834835
"""Interpolate bad MEG and EEG channels.
@@ -874,6 +875,12 @@ def interpolate_bads(
874875
exclude : list | tuple
875876
The channels to exclude from interpolation. If excluded a bad
876877
channel will stay in bads.
878+
on_bad_position : "raise" | "warn" | "ignore"
879+
What to do when one or more sensor positions are invalid (zero or NaN).
880+
If ``"warn"`` or ``"ignore"``, channels with invalid positions will be
881+
filled with :data:`~numpy.nan`.
882+
883+
.. versionadded:: 1.12
877884
%(verbose)s
878885
879886
Returns
@@ -898,6 +905,32 @@ def interpolate_bads(
898905

899906
_check_preload(self, "interpolation")
900907
_validate_type(method, (dict, str, None), "method")
908+
909+
# check for channels with invalid position(s)
910+
invalid_chs = []
911+
for ch in self.info["bads"]:
912+
loc = self.info["chs"][self.ch_names.index(ch)]["loc"][:3]
913+
if np.allclose(loc, 0.0, rtol=0, atol=1e-16) or np.isnan(loc).any():
914+
invalid_chs.append(ch)
915+
916+
if invalid_chs:
917+
if on_bad_position == "raise":
918+
msg = (
919+
f"Channel(s) {invalid_chs} have invalid sensor position(s). "
920+
"Interpolation cannot proceed correctly. If you want to "
921+
"continue despite missing positions, set "
922+
"on_bad_position='warn' or 'ignore', which outputs all "
923+
"NaN values (np.nan) for the interpolated channel(s)."
924+
)
925+
else:
926+
msg = (
927+
f"Channel(s) {invalid_chs} have invalid sensor position(s) "
928+
"and cannot be interpolated. The values of these channels "
929+
"will be all NaN. To ignore this warning, pass "
930+
"on_bad_position='ignore'."
931+
)
932+
_on_missing(on_bad_position, msg)
933+
901934
method = _handle_default("interpolation_method", method)
902935
ch_types = self.get_channel_types(unique=True)
903936
# figure out if we have "mag" for "meg", "hbo" for "fnirs", ... to filter the

mne/channels/tests/test_interpolation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,27 @@ def test_nan_interpolation(raw):
421421
raw.interpolate_bads(method="nan", reset_bads=False)
422422
assert raw.info["bads"] == ch_to_interp
423423

424+
store = raw.info["chs"][1]["loc"]
425+
# for on_bad_position="raise"
426+
raw.info["bads"] = ch_to_interp
427+
raw.info["chs"][1]["loc"] = np.full(12, np.nan)
428+
with pytest.raises(ValueError, match="have invalid sensor position"):
429+
# DOES NOT interpolates at all. So raw.info["bads"] remains as is
430+
raw.interpolate_bads(on_bad_position="raise")
431+
432+
# for on_bad_position="warn"
433+
with pytest.warns(RuntimeWarning, match="have invalid sensor position"):
434+
# this DOES the interpolation BUT with a warning
435+
# so raw.info["bad"] will be empty again,
436+
# and interpolated channel with be all np.nan
437+
raw.interpolate_bads(on_bad_position="warn")
438+
439+
# for on_bad_position="ignore"
440+
raw.info["bads"] = ch_to_interp
441+
assert raw.interpolate_bads(on_bad_position="ignore")
442+
assert np.isnan(bad_chs).all, "Interpolated channel should be all NaN"
443+
raw.info["chs"][1]["loc"] = store
444+
424445
# make sure other channels are untouched
425446
raw.drop_channels(ch_to_interp)
426447
good_chs = raw.get_data()

0 commit comments

Comments
 (0)