|
| 1 | +""" |
| 2 | +.. _ex-epoch-quality: |
| 3 | +
|
| 4 | +======================================== |
| 5 | +Exploring epoch quality before rejection |
| 6 | +======================================== |
| 7 | +
|
| 8 | +This example shows an approach for identifying epochs containing potential artifacts and |
| 9 | +rejecting these bad epochs. We compute per-epoch outlier scores using peak-to-peak |
| 10 | +amplitude, variance, and kurtosis — inspired by FASTER :footcite:`NolanEtAl2010` and |
| 11 | +:footcite:t:`DelormeEtAl2007` — and use them to rank epochs from cleanest to noisiest to |
| 12 | +inform rejection decisions. |
| 13 | +""" |
| 14 | +# Authors: Aman Srivastava |
| 15 | +# |
| 16 | +# License: BSD-3-Clause |
| 17 | +# Copyright the MNE-Python contributors. |
| 18 | + |
| 19 | +# %% |
| 20 | +import matplotlib.pyplot as plt |
| 21 | +import numpy as np |
| 22 | +from scipy.stats import kurtosis |
| 23 | + |
| 24 | +import mne |
| 25 | +from mne.datasets import eegbci |
| 26 | + |
| 27 | +print(__doc__) |
| 28 | + |
# %%
# Load the EEGBCI dataset and create epochs
# -----------------------------------------
raw_fname = eegbci.load_data(subjects=3, runs=(3,))[0]
raw = mne.io.read_raw(raw_fname, preload=True)
# Rename channels to the standard 10-05 scheme so the montage matches.
eegbci.standardize(raw)
raw.set_montage(mne.channels.make_standard_montage("standard_1005"))

# Epoch the continuous data around the annotated events.
events, event_id = mne.events_from_annotations(raw)
epochs = mne.Epochs(
    raw, events, tmin=-0.2, tmax=0.5, baseline=(None, 0), preload=True
)
# %%
# Compute per-epoch outlier scores
# --------------------------------
# Peak-to-peak amplitude, variance, and kurtosis are computed per epoch. Each feature is
# z-scored robustly using median absolute deviation across epochs, and averaged into a
# single outlier score normalised between [0, 1]. Scores close to 1 indicate a likely
# presence of artifacts in the epoch.

data = epochs.get_data()  # (n_epochs, n_channels, n_times)

# Per-epoch features: peak-to-peak amplitude and variance averaged over channels,
# plus kurtosis over all samples of the epoch (channels and times pooled).
ptp = np.ptp(data, axis=-1).mean(axis=-1)
var = data.var(axis=-1).mean(axis=-1)
# kurtosis() is vectorized, so flatten each epoch and compute all epochs in one call
# instead of looping in Python.
kurt = kurtosis(data.reshape(len(data), -1), axis=-1)

# Robust z-score of each feature via the median absolute deviation; the small
# epsilon guards against division by zero for a constant feature.
features = np.column_stack([ptp, var, kurt])
median = np.median(features, axis=0)
mad = np.median(np.abs(features - median), axis=0) + 1e-10
z = np.abs((features - median) / mad)

# Average the three feature z-scores and min-max rescale to [0, 1] across epochs.
raw_score = z.mean(axis=-1)
scores = (raw_score - raw_score.min()) / (raw_score.max() - raw_score.min() + 1e-10)
# %%
# Determining outlier epochs
# --------------------------
# Below, epochs are ranked from cleanest to noisiest. We need to find an appropriate
# threshold to flag those epochs likely containing artifacts. The threshold to use will
# vary depending on the dataset and analysis goals. In the plot, we show two example
# thresholds: a more lenient threshold of 0.6; and a stricter threshold of 0.3.
fig, ax = plt.subplots(layout="constrained")
order = np.argsort(scores)  # rank epochs from cleanest to noisiest
ax.bar(np.arange(scores.size), scores[order], color="steelblue")
ax.axhline(0.6, color="red", linestyle="--", label="More lenient threshold (0.6)")
ax.axhline(0.3, color="orange", linestyle="--", label="Stricter threshold (0.3)")
ax.set(
    xlabel="Epoch (sorted by score)",
    ylabel="Outlier score",
    title="Epoch quality scores (0 = clean, 1 = likely artifact)",
)
ax.legend()

# Report how many epochs each example threshold would flag.
for threshold in [0.6, 0.3]:
    n_flagged = np.count_nonzero(scores > threshold)
    print(
        f"Threshold {threshold}: {n_flagged} epochs flagged "
        f"out of {len(epochs)} total"
    )
# %%
# Epochs flagged by the thresholds can be inspected using the :meth:`mne.Epochs.plot`
# method. This is a crucial step in identifying the optimal threshold. First, we show
# those epochs with the worst scores (≥ 0.6), containing a number of amplitude spikes.
picks = np.arange(17, 40, dtype=int)  # channels with notable amplitude spikes
worst_idx = np.flatnonzero(scores >= 0.6)
epochs[worst_idx].plot(
    picks=picks, title="Scores ≥ 0.6", scalings=dict(eeg=70e-6), n_channels=len(picks)
)
# %%
# In contrast, the threshold of 0.3 captures epochs with less severe artifact activity,
# which may be overly conservative to exclude from the analysis.
moderate_idx = np.flatnonzero((scores >= 0.3) & (scores < 0.6))
epochs[moderate_idx].plot(
    picks=picks,
    title="0.3 ≤ scores < 0.6",
    scalings=dict(eeg=70e-6),
    n_channels=len(picks),
)
| 106 | + |
# %%
# Dropping suspicious epochs
# --------------------------
# Following visual inspection, bad epochs can be discarded using the
# :meth:`mne.Epochs.drop` method. Here, we remove the worst scoring epochs (≥ 0.6)
# which contained strong artifact activity. The remaining good epochs can then be used
# for further analysis.
epochs.drop(np.flatnonzero(scores >= 0.6))
print(f"Epochs remaining after dropping scores ≥ 0.6: {len(epochs)}")
| 116 | + |
| 117 | +# %% |
| 118 | +# References |
| 119 | +# ---------- |
| 120 | +# .. footbibliography:: |
0 commit comments