Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13745.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for :func:`mne.concatenate_epochs` with :class:`~mne.time_frequency.EpochsTFR` instances, by ``aman-coder03``. (:gh:`13745`)
62 changes: 62 additions & 0 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4693,6 +4693,66 @@ def _concatenate_epochs(
)


def _concatenate_epochs_tfr(epochs_list, add_offset=True):
"""Concatenate a list of EpochsTFR instances."""
for ii, ep in enumerate(epochs_list):
if type(ep).__name__ != "EpochsTFR":
raise TypeError(
f"epochs_list[{ii}] must be an instance of EpochsTFR, got {type(ep)}"
)
ref = epochs_list[0]
for ii, ep in enumerate(epochs_list[1:], 1):
if not np.array_equal(ep.freqs, ref.freqs):
raise ValueError(f"epochs_list[{ii}] freqs do not match epochs_list[0]")
if not np.array_equal(ep.times, ref.times):
raise ValueError(f"epochs_list[{ii}] times do not match epochs_list[0]")
_ensure_infos_match(ep.info, ref.info, f"epochs_list[{ii}]")

data = np.concatenate([ep.data for ep in epochs_list], axis=0)

shift = np.int64((10 + ref.times[-1]) * ref.info["sfreq"])
events_offset = int(np.max(epochs_list[0].events[:, 0])) + shift
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
shift = np.int64((10 + ref.times[-1]) * ref.info["sfreq"])
events_offset = int(np.max(epochs_list[0].events[:, 0])) + shift
shift = len(ref.times)
events_offset = ref.events[-1, 0] + shift

all_events = [epochs_list[0].events.copy()]
for ep in epochs_list[1:]:
evs = ep.events.copy()
if add_offset:
evs[:, 0] += events_offset
events_offset += int(np.max(ep.events[:, 0])) + shift
all_events.append(evs)
events = np.concatenate(all_events, axis=0)

event_id = deepcopy(ref.event_id)
for ep in epochs_list[1:]:
event_id.update(ep.event_id)

selection = np.concatenate([ep.selection for ep in epochs_list])
drop_log = sum([ep.drop_log for ep in epochs_list], ())

metadatas = [ep.metadata for ep in epochs_list]
n_have = sum(m is not None for m in metadatas)
if n_have == 0:
metadata = None
elif n_have != len(metadatas):
raise ValueError(
f"{n_have} of {len(metadatas)} EpochsTFR instances have metadata, "
"all or none must have metadata"
)
else:
pd = _check_pandas_installed(strict=False)
metadata = pd.concat(metadatas) if pd is not False else sum(metadatas, list())

state = ref.__getstate__()
state["data"] = data
state["events"] = events
state["event_id"] = event_id
state["selection"] = selection
state["drop_log"] = drop_log
state["metadata"] = metadata
out = type(epochs_list[0]).__new__(type(epochs_list[0]))
out.__setstate__(state)
return out


@verbose
def concatenate_epochs(
epochs_list, add_offset=True, *, on_mismatch="raise", verbose=None
Expand Down Expand Up @@ -4725,6 +4785,8 @@ def concatenate_epochs(
-----
.. versionadded:: 0.9.0
"""
if epochs_list and type(epochs_list[0]).__name__ == "EpochsTFR":
return _concatenate_epochs_tfr(epochs_list, add_offset=add_offset)
(
info,
data,
Expand Down
50 changes: 50 additions & 0 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,3 +1932,53 @@ def test_combine_tfr_error_catch(average_tfr):
match="Aggregating multitaper tapers across TFR datasets is not supported.",
):
combine_tfr([average_tfr_taper, average_tfr_taper])


def test_concatenate_epochs_tfr():
"""Test concatenate_epochs() works for EpochsTFR instances."""
from mne.epochs import concatenate_epochs

info = create_info(ch_names=["EEG1", "EEG2", "EEG3"], sfreq=256.0, ch_types="eeg")
data = np.random.randn(20, 3, 512)
events = np.column_stack([np.arange(20) * 512, np.zeros(20, int), np.ones(20, int)])
epochs = EpochsArray(data, info, events=events, tmin=0)
freqs = np.arange(8, 12)
tfr1 = epochs[:10].compute_tfr("morlet", freqs=freqs)
tfr2 = epochs[10:20].compute_tfr("morlet", freqs=freqs)

# basic concatenation
out = concatenate_epochs([tfr1, tfr2])
assert isinstance(out, EpochsTFR)
assert out.shape[0] == 20 # 10 + 10 epochs
assert out.shape[1] == 3 # channels unchanged
assert out.shape[2] == len(freqs) # freqs unchanged
assert len(out.events) == 20
assert_array_equal(out.freqs, tfr1.freqs)
assert_array_equal(out.times, tfr1.times)

# event offset: second block's events should be shifted forward
assert np.all(out.events[10:, 0] > out.events[:10, 0])

# add_offset=False: event times should be preserved as-is
out_no_offset = concatenate_epochs([tfr1, tfr2], add_offset=False)
assert_array_equal(out_no_offset.events[:10, 0], tfr1.events[:, 0])
assert_array_equal(out_no_offset.events[10:, 0], tfr2.events[:, 0])

# data integrity: first 10 epochs should match tfr1.data
assert_array_equal(out.data[:10], tfr1.data)
assert_array_equal(out.data[10:], tfr2.data)

# mismatched freqs should raise
tfr_bad_freqs = epochs[:10].compute_tfr("morlet", freqs=np.arange(13, 17))
with pytest.raises(ValueError, match="freqs do not match"):
concatenate_epochs([tfr1, tfr_bad_freqs])

# mismatched times should raise
tfr_bad_times = epochs[:10].compute_tfr("morlet", freqs=freqs)
tfr_bad_times._set_times(tfr_bad_times.times + 1.0) # shift times to force mismatch
with pytest.raises(ValueError, match="times do not match"):
concatenate_epochs([tfr1, tfr_bad_times])

# passing a non-EpochsTFR should raise
with pytest.raises(TypeError, match="must be an instance of EpochsTFR"):
concatenate_epochs([tfr1, epochs[:10]])
Loading