Commit d757ed4

updating the version
1 parent 30d647b commit d757ed4

513 files changed

Lines changed: 136764 additions & 9045 deletions


stable/.buildinfo

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 # Sphinx build info version 1
-# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: dfa3b204b4ec02ee979e4bf266ee2b9c
+# This file records the configuration used when building these files. When it is not found, a full rebuild will be done.
+config: 74f6f6faf2b4ae6e758eaa194da67492
 tags: 645f666f9bcd5a90fca523b33c5a78b7
Binary file not shown.

stable/_downloads/090305d06248840b75133975e5121f41/plot_sleep_staging_chambon2018.ipynb

Lines changed: 1 addition & 1 deletion
@@ -283,7 +283,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.12.11"
+"version": "3.12.12"
 }
 },
 "nbformat": 4,

stable/_downloads/0a8b8bc2f1b933515b7b4101626dd179/plot_bcic_iv_2a_moabb_trial.py

Lines changed: 17 additions & 11 deletions
@@ -148,7 +148,18 @@
 
 
 ######################################################################
-# Now we create the deep learning model! Braindecode comes with some
+# Now we create the deep learning model!
+# First thing we need to do is know the properties of our signals.
+# For this, we use the :func:`braindecode.datautil.infer_signal_properties` function:
+#
+from braindecode.datautil import infer_signal_properties
+
+sig_props = infer_signal_properties(train_set, mode="classification")
+print(sig_props)
+
+
+######################################################################
+# Braindecode comes with some
 # predefined convolutional neural network architectures for raw
 # time-domain EEG. Here, we use the :class:`ShallowFBCSPNet
 # <braindecode.models.ShallowFBCSPNet>` model from [3]_. These models are
@@ -175,16 +186,10 @@
 seed = 20200220
 set_random_seeds(seed=seed, cuda=cuda)
 
-n_classes = 4
-classes = list(range(n_classes))
-# Extract number of chans and time steps from dataset
-n_chans = train_set[0][0].shape[0]
-n_times = train_set[0][0].shape[1]
-
 model = ShallowFBCSPNet(
-    n_chans,
-    n_classes,
-    n_times=n_times,
+    n_chans=sig_props["n_chans"],
+    n_outputs=sig_props["n_outputs"],
+    n_times=sig_props["n_times"],
     final_conv_length="auto",
 )
 
@@ -234,6 +239,7 @@
 
 batch_size = 64
 n_epochs = 4
+classes = list(range(sig_props["n_outputs"]))
 
 clf = EEGClassifier(
     model,
@@ -364,6 +370,6 @@
 # .. [3] Schirrmeister, R.T., Springenberg, J.T., Fiederer, L.D.J., Glasstetter, M.,
 #    Eggensperger, K., Tangermann, M., Hutter, F., Burgard, W. and Ball, T. (2017),
 #    Deep learning with convolutional neural networks for EEG decoding and visualization.
-#    Hum. Brain Mapping, 38: 5391-5420. https://doi.org/10.1002/hbm.23730.
+#    Hum. Brain Mapping, 38: 5391-5420. https://onlinelibrary.wiley.com/doi/10.1002/hbm.23730.
 #
 # .. include:: /links.inc
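The hunks above replace the hand-rolled shape extraction (`n_chans = train_set[0][0].shape[0]`, `n_times = ...`, hard-coded `n_classes = 4`) with a single `infer_signal_properties` call whose returned dict feeds the model constructor. A rough sketch of what such a helper has to compute, as a hypothetical pure-Python re-implementation (only the return keys `n_chans`, `n_times`, and `n_outputs` are taken from the diff; the real braindecode internals may differ):

```python
def infer_signal_properties_sketch(dataset):
    """Infer basic properties of a windowed classification dataset.

    Hypothetical stand-in for braindecode's helper: each dataset item is
    assumed to be an (X, y) pair where X has shape (n_channels, n_times)
    and y is a class label.
    """
    first_x, _first_y = dataset[0]
    n_chans = len(first_x)       # channels = rows of the window
    n_times = len(first_x[0])    # time steps = columns
    labels = {y for _x, y in dataset}  # distinct class labels
    return {"n_chans": n_chans, "n_times": n_times, "n_outputs": len(labels)}


# Toy "dataset": three 2-channel x 4-sample windows with labels 0/1
toy = [([[0.0] * 4, [0.0] * 4], y) for y in (0, 1, 1)]
print(infer_signal_properties_sketch(toy))
# {'n_chans': 2, 'n_times': 4, 'n_outputs': 2}
```

With a dict like this, the model construction in the diff (`n_chans=sig_props["n_chans"]`, and so on) no longer needs the dataset's class count spelled out by hand.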
Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"\n# Fine-tuning a Foundation Model (Signal-JEPA)\n\nFoundation models are large-scale pre-trained models that serve as a starting point\nfor a wide range of downstream tasks, leveraging their generalization capabilities.\nFine-tuning these models is necessary to adapt them to specific tasks or datasets,\nensuring optimal performance in specialized applications.\n\nIn this tutorial, we demonstrate how to load a pre-trained foundation model\nand fine-tune it for a specific task. We use the Signal-JEPA model [1]_\nand a MOABB motor-imagery dataset for this tutorial.\n :depth: 2\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>\n#\n# License: BSD (3-clause)\n#\nimport mne\nimport numpy as np\nimport torch\n\nfrom braindecode import EEGClassifier\nfrom braindecode.datasets import MOABBDataset\nfrom braindecode.models import SignalJEPA_PreLocal\nfrom braindecode.preprocessing import create_windows_from_events\n\ntorch.use_deterministic_algorithms(True)\ntorch.manual_seed(12)\nnp.random.seed(12)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Loading and preparing the data\n\n### Loading a dataset\n\nWe start by loading a MOABB dataset, a single subject only for speed.\nThe dataset contains motor imagery EEG recordings, which we will preprocess and use for fine-tuning.\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"subject_id = 3 # Just one subject for speed\ndataset = MOABBDataset(dataset_name=\"BNCI2014_001\", subject_ids=[subject_id])\n\n# Set the standard 10-20 montage for EEG channel locations\nmontage = mne.channels.make_standard_montage(\"standard_1020\")\nfor ds in dataset.datasets:\n ds.raw.set_montage(montage)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Define Dataset parameters\n\nWe extract the sampling frequency and ensure that it is consistent across\nall recordings. We also extract the window size from the annotations and\ninformation about the EEG channels (names, positions, etc.).\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"# Extract sampling frequency\nsfreq = dataset.datasets[0].raw.info[\"sfreq\"]\nassert all([ds.raw.info[\"sfreq\"] == sfreq for ds in dataset.datasets])\n\n# Extract and validate window size from annotations\nwindow_size_seconds = dataset.datasets[0].raw.annotations.duration[0]\nassert all(\n d == window_size_seconds\n for ds in dataset.datasets\n for d in ds.raw.annotations.duration\n)\n\n# Extract channel information\nchs_info = dataset.datasets[0].raw.info[\"chs\"] # Channel information\n\nprint(f\"{sfreq=}, {window_size_seconds=}, {len(chs_info)=}\")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Create Windows from Events\n\nWe use the `create_windows_from_events` function from Braindecode to segment\nthe dataset into windows based on events.\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"classes = [\"feet\", \"left_hand\", \"right_hand\"]\nclasses_mapping = {c: i for i, c in enumerate(classes)}\n\nwindows_dataset = create_windows_from_events(\n dataset,\n preload=True, # Preload the data into memory for faster processing\n mapping=classes_mapping,\n)\nmetadata = windows_dataset.get_metadata()\nprint(metadata.head(10))"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Loading a pre-trained foundation model\n\n### Download and Load Pre-trained Weights\n\nWe download the pre-trained weights for the SignalJEPA model from the Hugging Face Hub.\nThese weights will serve as the starting point for finetuning.\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"model_state_dict = torch.hub.load_state_dict_from_url(\n url=\"https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth\"\n)\n# print(model_state_dict.keys())"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Instantiate the Foundation Model\n\nWe create an instance of the SignalJEPA model using the pre-local downstream\narchitecture. The model is initialized with the dataset's sampling frequency,\nwindow size, and channel information.\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"model = SignalJEPA_PreLocal(\n sfreq=sfreq,\n input_window_seconds=window_size_seconds,\n chs_info=chs_info,\n n_outputs=len(classes),\n)\nprint(model)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Load the Pre-trained Weights into the Model\n\nWe load the pre-trained weights into the model. The transformer layers are excluded\nas this module is not used in the pre-local downstream architecture (see [1]_).\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"# Define layers to exclude from the pre-trained weights\nnew_layers = {\n \"spatial_conv.1.weight\",\n \"spatial_conv.1.bias\",\n \"final_layer.1.weight\",\n \"final_layer.1.bias\",\n}\n\n# Filter out transformer weights and load the state dictionary\nmodel_state_dict = {\n k: v for k, v in model_state_dict.items() if not k.startswith(\"transformer.\")\n}\nmissing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)\n\n# Ensure no unexpected keys and validate missing keys\nassert unexpected_keys == [], f\"{unexpected_keys=}\"\nassert set(missing_keys) == new_layers, f\"{missing_keys=}\""
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Fine-tuning the Model\n\nSignal-JEPA is a model trained in a self-supervised manner on a masked\nprediction task. In this task, the model is configured in a many-to-many\nfashion, which is not suited for a classification task. Therefore, we need to\nadjust the model architecture for finetuning. This is what is done by the\n:class:`SignalJEPA_PreLocal`, :class:`SignalJEPA_Contextual`, and\n:class:`SignalJEPA_PostLocal` classes. In these classes, new layers are added\nspecifically for classification, as described in the article [1]_ and in the following figure:\n\n<img src=\"file://_static/model/sjepa_pre-local.jpg\" alt=\"Signal-JEPA Pre-Local Downstream Architecture\" align=\"center\">\n\nWith this downstream architecture, two options are possible for fine-tuning:\n\n1) Fine-tune only the newly added layers\n2) Fine-tune the entire model\n\n### Freezing Pre-trained Layers\n\nAs the second option is rather straightforward to implement,\nwe will focus on the first option here.\nWe will freeze all layers except the newly added ones.\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"for name, param in model.named_parameters():\n if name not in new_layers:\n param.requires_grad = False\n\nprint(\"Trainable parameters:\")\nother_modules = set()\nfor name, param in model.named_parameters():\n if param.requires_grad:\n print(name)\n else:\n other_modules.add(name.split(\".\")[0])\n\nprint(\"\\nOther modules:\")\nprint(other_modules)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Fine-tuning Procedure\n\nFinally, we set up the fine-tuning procedure using Braindecode's\n:class:`EEGClassifier`. We define the loss function, optimizer, and training\nparameters. We then fit the model to the windows dataset.\n\nWe only train for a few epochs for demonstration purposes.\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"clf = EEGClassifier(\n model,\n criterion=torch.nn.CrossEntropyLoss,\n optimizer=torch.optim.AdamW,\n optimizer__lr=0.005,\n batch_size=16,\n callbacks=[\"accuracy\"],\n classes=range(3),\n)\n_ = clf.fit(windows_dataset, y=metadata[\"target\"], epochs=10)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### All-in-one Implementation\n\nIn the implementation above, we manually loaded the weights and froze the layers.\nThis forces us to pass an initialized model to :class:`EEGClassifier`, which may\ncreate issues if we use it in a cross-validation setting.\n\nInstead, we can implement the same procedure in a more compact and reproducible way,\nby using skorch's callback system.\n\nHere, we import a callback to freeze layers and define a custom\ncallback to load the pre-trained weights at the beginning of training:\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"from skorch.callbacks import Callback, Freezer\n\n\nclass WeightsLoader(Callback):\n def __init__(self, url, strict=False):\n self.url = url\n self.strict = strict\n\n def on_train_begin(self, net, X=None, y=None, **kwargs):\n state_dict = torch.hub.load_state_dict_from_url(url=self.url)\n net.module_.load_state_dict(state_dict, strict=self.strict)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"We can now define a classifier with those callbacks, without having\nto pass an initialized model, and fit it as before:\n\n\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"collapsed": false
+},
+"outputs": [],
+"source": [
+"clf = EEGClassifier(\n \"SignalJEPA_PreLocal\",\n criterion=torch.nn.CrossEntropyLoss,\n optimizer=torch.optim.AdamW,\n optimizer__lr=0.005,\n batch_size=16,\n callbacks=[\n \"accuracy\",\n WeightsLoader(\n url=\"https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth\"\n ),\n Freezer(patterns=\"feature_encoder.*\"),\n ],\n classes=range(3),\n)\n_ = clf.fit(windows_dataset, y=metadata[\"target\"], epochs=10)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Conclusion and Next Steps\n\nIn this tutorial, we demonstrated how to fine-tune a pre-trained foundation\nmodel, Signal-JEPA, for a motor imagery classification task. We now have a basic\nimplementation that can automatically load pre-trained weights and freeze specific layers.\n\nThis setup can easily be extended to explore different fine-tuning techniques,\nbase foundation models, and downstream tasks.\n\n\n"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## References\n\n.. [1] Guetschel, P., Moreau, T., and Tangermann, M. (2024)\n \u201cS-JEPA: towards seamless cross-dataset transfer\n through dynamic spatial attention\u201d. https://arxiv.org/abs/2403.11772\n\n"
+]
+}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "Python 3",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.12.12"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 0
+}
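The notebook's weight-loading cell drops the `transformer.*` entries from the downloaded state dict, loads the rest with `strict=False`, and then asserts that the only missing keys are the freshly added classification layers. The bookkeeping behind that check can be sketched with plain dicts (hypothetical helper and key names for illustration; the real code relies on `torch.nn.Module.load_state_dict`):

```python
def load_filtered(model_keys, pretrained, exclude_prefix):
    """Mimic the missing/unexpected-key bookkeeping of strict=False loading."""
    # Drop weights for modules absent from the downstream architecture
    filtered = {k: v for k, v in pretrained.items()
                if not k.startswith(exclude_prefix)}
    loaded = {k: v for k, v in filtered.items() if k in model_keys}
    missing = [k for k in model_keys if k not in filtered]      # stay at init values
    unexpected = [k for k in filtered if k not in model_keys]   # would be rejected if strict
    return loaded, missing, unexpected


# Hypothetical key names, shaped like the notebook's new_layers set
model_keys = ["feature_encoder.0.weight", "spatial_conv.1.weight", "final_layer.1.weight"]
pretrained = {"feature_encoder.0.weight": "w0", "transformer.attn.weight": "w1"}
new_layers = {"spatial_conv.1.weight", "final_layer.1.weight"}

loaded, missing, unexpected = load_filtered(model_keys, pretrained, "transformer.")
assert unexpected == [], f"{unexpected=}"
assert set(missing) == new_layers, f"{missing=}"  # only the new heads start untrained
```

The two final assertions are the same invariant the notebook enforces after its real `load_state_dict` call: nothing pre-trained is silently dropped, and only the new classification layers remain to be trained.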
Binary file not shown.

stable/_downloads/0f2bf063e08b7d05b80e0004fcbbb6f9/benchmark_lazy_eager_loading.ipynb

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.12.11"
+"version": "3.12.12"
 }
 },
 "nbformat": 4,

stable/_downloads/0f763ae384277e558103757157e170fb/plot_data_augmentation_search.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@
 from braindecode.datasets import MOABBDataset
 
 subject_id = 3
-dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])
+dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
 
 ######################################################################
 # Preprocessing
