
Commit bda5e52 (parent: 5053d79)

182 files changed: 4915 additions & 1641 deletions


dev/_downloads/0ada35f18bb95235ba2842270483081f/plot_finetune_foundation_model.ipynb

Lines changed: 9 additions & 27 deletions
@@ -40,7 +40,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Define Dataset parameters\n\nWe extract the sampling frequency and ensure that it is consistent across\nall recordings. We also extract the window size from the annotations and\ninformation about the EEG channels (names, positions, etc.).\n\n\n"
+"### Preprocessing to match the pretrained model\n\nThe pretrained SignalJEPA checkpoint expects 19 EEG channels at 128 Hz\nwith 2-second windows. We adapt the dataset accordingly: keep only\nEEG channels, pick the first 19, and resample.\n\n\n"
 ]
 },
 {
@@ -51,32 +51,14 @@
 },
 "outputs": [],
 "source": [
-"# Extract sampling frequency\nsfreq = dataset.datasets[0].raw.info[\"sfreq\"]\nassert all([ds.raw.info[\"sfreq\"] == sfreq for ds in dataset.datasets])\n\n# Extract and validate window size from annotations\nwindow_size_seconds = dataset.datasets[0].raw.annotations.duration[0]\nassert all(\n d == window_size_seconds\n for ds in dataset.datasets\n for d in ds.raw.annotations.duration\n)\n\n# Extract channel information\nchs_info = dataset.datasets[0].raw.info[\"chs\"] # Channel information\n\nprint(f\"{sfreq=}, {window_size_seconds=}, {len(chs_info)=}\")"
+"for ds in dataset.datasets:\n ds.raw.pick_types(eeg=True) # drop EOG / stim channels\n ds.raw.pick(ds.raw.ch_names[:19]) # match pretrained channel count\n ds.raw.resample(128) # match pretrained sampling frequency"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Create Windows from Events\n\nWe use the `create_windows_from_events` function from Braindecode to segment\nthe dataset into windows based on events.\n\n\n"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {
-"collapsed": false
-},
-"outputs": [],
-"source": [
-"classes = [\"feet\", \"left_hand\", \"right_hand\"]\nclasses_mapping = {c: i for i, c in enumerate(classes)}\n\nwindows_dataset = create_windows_from_events(\n dataset,\n preload=True, # Preload the data into memory for faster processing\n mapping=classes_mapping,\n)\nmetadata = windows_dataset.get_metadata()\nprint(metadata.head(10))"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Loading a pre-trained foundation model\n\n### Download and Load Pre-trained Weights\n\nWe download the pre-trained weights for the SignalJEPA model from the Hugging Face Hub.\nThese weights will serve as the starting point for finetuning.\n\n\n"
+"### Define Dataset parameters\n\nWe extract the sampling frequency and channel information after\npreprocessing so they match the pretrained model.\n\n\n"
 ]
 },
 {
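The removed cell's consistency asserts are gone, so it is worth checking that the new preprocessing loop actually produced the format the checkpoint expects. A minimal sanity-check sketch, assuming the same `dataset` object and the MNE `Raw` API used in the diff:

    # Every recording should now match the pretrained SignalJEPA input
    # format: 19 EEG channels sampled at 128 Hz.
    for ds in dataset.datasets:
        assert ds.raw.info["sfreq"] == 128, ds.raw.info["sfreq"]
        assert len(ds.raw.ch_names) == 19, ds.raw.ch_names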
@@ -87,14 +69,14 @@
 },
 "outputs": [],
 "source": [
-"model_state_dict = torch.hub.load_state_dict_from_url(\n url=\"https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth\"\n)\n# print(model_state_dict.keys())"
+"sfreq = dataset.datasets[0].raw.info[\"sfreq\"]\nchs_info = dataset.datasets[0].raw.info[\"chs\"]\n\nprint(f\"{sfreq=}, {len(chs_info)=}\")"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Instantiate the Foundation Model\n\nWe create an instance of the SignalJEPA model using the pre-local downstream\narchitecture. The model is initialized with the dataset's sampling frequency,\nwindow size, and channel information.\n\n\n"
+"### Create Windows from Events\n\nWe use the `create_windows_from_events` function from Braindecode to segment\nthe dataset into windows based on events.\n\n\n"
 ]
 },
 {
@@ -105,14 +87,14 @@
 },
 "outputs": [],
 "source": [
-"model = SignalJEPA_PreLocal(\n sfreq=sfreq,\n input_window_seconds=window_size_seconds,\n chs_info=chs_info,\n n_outputs=len(classes),\n)\nprint(model)"
+"classes = [\"feet\", \"left_hand\", \"right_hand\"]\nclasses_mapping = {c: i for i, c in enumerate(classes)}\n\nwindows_dataset = create_windows_from_events(\n dataset,\n preload=True, # Preload the data into memory for faster processing\n mapping=classes_mapping,\n window_size_samples=256, # 2 s at 128 Hz \u2014 matches pretrained model\n window_stride_samples=256,\n)\nmetadata = windows_dataset.get_metadata()\nprint(metadata.head(10))"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Load the Pre-trained Weights into the Model\n\nWe load the pre-trained weights into the model. The transformer layers are excluded\nas this module is not used in the pre-local downstream architecture (see [1]_).\n\n\n"
+"## Loading a pre-trained foundation model\n\n### Load Pre-trained Weights from the Hub\n\nWe load the pre-trained SignalJEPA downstream model from the Hugging Face\nHub using ``from_pretrained``. The ``SignalJEPA_PreLocal`` checkpoint\nalready bundles the SSL backbone together with the downstream\nclassification layers, so a single call is all that is needed.\n\nFor other foundation models (BENDR, BIOT, Labram, etc.) the same\none-line pattern applies \u2014 see `load-pretrained-models`.\n\n\n"
 ]
 },
 {
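Since the new windowing call fixes `window_size_samples=256`, a quick shape check confirms that the windows match the 2-second, 128 Hz format the checkpoint expects. A sketch, assuming braindecode windows datasets yield `(X, y, crop_inds)` tuples:

    # Each window should be n_channels x n_times = 19 x 256.
    X, y, _ = windows_dataset[0]
    print(X.shape)  # expected: (19, 256)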
@@ -123,7 +105,7 @@
 },
 "outputs": [],
 "source": [
-"# Define layers to exclude from the pre-trained weights\nnew_layers = {\n \"spatial_conv.1.weight\",\n \"spatial_conv.1.bias\",\n \"final_layer.1.weight\",\n \"final_layer.1.bias\",\n}\n\n# Filter out transformer weights and load the state dictionary\nmodel_state_dict = {\n k: v for k, v in model_state_dict.items() if not k.startswith(\"transformer.\")\n}\nmissing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)\n\n# Ensure no unexpected keys and validate missing keys\nassert unexpected_keys == [], f\"{unexpected_keys=}\"\nassert set(missing_keys) == new_layers, f\"{missing_keys=}\""
+"model = SignalJEPA_PreLocal.from_pretrained(\n \"braindecode/SignalJEPA-PreLocal-pretrained\",\n n_outputs=len(classes),\n)\nprint(model)"
 ]
 },
 {
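The markdown cell above states that the same one-line `from_pretrained` pattern applies to the other foundation models. A hedged sketch of that pattern for Labram; the class name exists in `braindecode.models`, but the Hub repository ID below is a hypothetical placeholder, not a checkpoint confirmed by this commit:

    from braindecode.models import Labram

    labram = Labram.from_pretrained(
        "braindecode/Labram-pretrained",  # hypothetical repo ID; check the Hub
        n_outputs=len(classes),
    )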
@@ -141,7 +123,7 @@
 },
 "outputs": [],
 "source": [
-"for name, param in model.named_parameters():\n if name not in new_layers:\n param.requires_grad = False\n\nprint(\"Trainable parameters:\")\nother_modules = set()\nfor name, param in model.named_parameters():\n if param.requires_grad:\n print(name)\n else:\n other_modules.add(name.split(\".\")[0])\n\nprint(\"\\nOther modules:\")\nprint(other_modules)"
+"# Keep the task-specific head layers (spatial_conv and final_layer)\n# trainable and freeze the pretrained backbone.\nnew_layers = {\n name\n for name, _ in model.named_parameters()\n if name.startswith((\"spatial_conv.\", \"final_layer.\"))\n}\n\nfor name, param in model.named_parameters():\n if name not in new_layers:\n param.requires_grad = False\n\nprint(\"Trainable parameters:\")\nother_modules = set()\nfor name, param in model.named_parameters():\n if param.requires_grad:\n print(name)\n else:\n other_modules.add(name.split(\".\")[0])\n\nprint(\"\\nOther modules:\")\nprint(other_modules)"
 ]
 },
 {
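With the backbone frozen by the cell above, only the head parameters need to reach the optimizer. A minimal follow-up sketch using standard PyTorch; the learning rate is an illustrative placeholder, not taken from the diff:

    import torch

    # Pass only the unfrozen head parameters (spatial_conv, final_layer).
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=1e-3,  # illustrative placeholder
    )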
