
Commit c1d2ed6

missing files

committed by Dammy Desktop
1 parent 832243a commit c1d2ed6

4 files changed: 329 additions & 0 deletions

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
import random

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import tensorflow as tf
from tensorflow.keras import Sequential, regularizers
from tensorflow.keras.layers import Conv1D, LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# -----------------------------
# Step 1: Fix seeds for reproducibility
# -----------------------------
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

# -----------------------------
# Step 2: Simulate example data with n conditions
# -----------------------------
def simulate_data(n_conditions=3, n_trials=100, n_timesteps=200):
    # Each condition is a sinusoid of a different frequency plus Gaussian noise.
    conditions = []
    for i in range(n_conditions):
        base = np.sin(np.linspace(0, 2 * np.pi * (i + 1), n_timesteps))
        cond = np.array([
            base + 0.2 * np.random.randn(n_timesteps)
            for _ in range(n_trials)
        ])
        conditions.append(cond)
    return np.array(conditions)  # shape (n_conditions, n_trials, n_timesteps)

# -----------------------------
# Step 3: Prepare data
# -----------------------------
def prepare_data(data):
    # Flatten (condition, trial) into one trial axis and build integer labels.
    # The channel dimension for Conv1D is added once, in preprocess_data;
    # adding it here as well would produce a 4-D input and break Conv1D.
    n_conditions, n_trials, n_timesteps = data.shape
    X = data.reshape(-1, n_timesteps)
    y = np.repeat(np.arange(n_conditions), n_trials)
    return X, y


def preprocess_data(X):
    # Normalize per trial (zero mean, unit variance along the time axis).
    X = (X - X.mean(axis=1, keepdims=True)) / (X.std(axis=1, keepdims=True) + 1e-8)
    if X.ndim == 2:
        X = X[..., np.newaxis]  # add channel dimension for Conv1D
    return X
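# Illustrative sketch (not part of the commit): the shapes expected through
# the preparation steps above, assuming the default simulation parameters.
def _example_shapes():
    data = simulate_data(n_conditions=3, n_trials=100, n_timesteps=200)  # (3, 100, 200)
    X, y = prepare_data(data)   # X: (300, 200), y: (300,)
    X = preprocess_data(X)      # X: (300, 200, 1), ready for Conv1D
    return X, y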
# -----------------------------
# Step 4: Build classifier (Conv1D + BatchNorm + LSTM)
# -----------------------------
def build_model(n_timesteps, n_classes, learning_rate=1e-3, dropout_rate=0.3, l2_reg=1e-2):
    model = Sequential([
        Conv1D(4, kernel_size=5, activation='relu',
               kernel_regularizer=regularizers.l2(l2_reg),
               input_shape=(n_timesteps, 1)),
        BatchNormalization(),
        Conv1D(8, kernel_size=5, activation='relu',
               kernel_regularizer=regularizers.l2(l2_reg)),
        BatchNormalization(),
        LSTM(8, return_sequences=False, kernel_regularizer=regularizers.l2(l2_reg)),
        Dropout(dropout_rate),
        Dense(8, activation='relu', kernel_regularizer=regularizers.l2(l2_reg)),
        Dense(n_classes, activation='softmax')
    ])
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# -----------------------------
# Step 5: Train classifier
# -----------------------------
def train_model(X, y, epochs=50):
    rlrop = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, verbose=1)
    # With patience=50 and epochs=50 the early stop never fires, but
    # restore_best_weights still reloads the best validation-loss weights.
    early_stop = EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True)

    model = build_model(X.shape[1], len(np.unique(y)))
    history = model.fit(X, y, validation_split=0.2, epochs=epochs, batch_size=16,
                        shuffle=True, callbacks=[rlrop, early_stop], verbose=1)

    # Plot accuracy and loss
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    ax[0].plot(history.history['accuracy'], label='Train Acc')
    ax[0].plot(history.history['val_accuracy'], label='Val Acc')
    ax[0].set_title('Accuracy')
    ax[0].set_xlabel('Epoch')
    ax[0].set_ylabel('Accuracy')
    ax[0].legend()

    ax[1].plot(history.history['loss'], label='Train Loss')
    ax[1].plot(history.history['val_loss'], label='Val Loss')
    ax[1].set_title('Loss')
    ax[1].set_xlabel('Epoch')
    ax[1].set_ylabel('Loss')
    ax[1].legend()

    fig.tight_layout()
    plt.show()

    return model

# -----------------------------
# Step 6: Permutation test
# -----------------------------
def permutation_test(X, y, model_builder, n_permutations=100, epochs=50):
    # Observed statistic: training accuracy of a model fit on the true labels.
    observed_model = model_builder(X.shape[1], len(np.unique(y)))
    observed_model.fit(X, y, epochs=epochs, batch_size=16, shuffle=True, verbose=0)
    y_pred = np.argmax(observed_model.predict(X, verbose=0), axis=1)
    observed_acc = accuracy_score(y, y_pred)

    # Null distribution: the same fit/score procedure on shuffled labels.
    null_accs = []
    for _ in range(n_permutations):
        y_perm = np.random.permutation(y)
        perm_model = model_builder(X.shape[1], len(np.unique(y)))
        perm_model.fit(X, y_perm, epochs=epochs, batch_size=16, shuffle=True, verbose=0)
        y_pred_perm = np.argmax(perm_model.predict(X, verbose=0), axis=1)
        null_accs.append(accuracy_score(y_perm, y_pred_perm))

    # Add-one correction keeps the p-value strictly above zero.
    p_value = (np.sum(np.array(null_accs) >= observed_acc) + 1) / (n_permutations + 1)
    return observed_acc, null_accs, p_value

# -----------------------------
# Step 7: Run
# -----------------------------
if __name__ == "__main__":
    pass
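The __main__ block above is left empty in this commit. A plausible driver, sketched under the assumption that the steps are meant to run in order (the hyperparameters here are illustrative, chosen small to keep the run cheap):

    data = simulate_data()                  # (3, 100, 200)
    X, y = prepare_data(data)
    X = preprocess_data(X)
    model = train_model(X, y, epochs=50)    # also plots the training curves
    obs_acc, null_accs, p = permutation_test(X, y, build_model,
                                             n_permutations=10, epochs=10)
    print(f"observed accuracy = {obs_acc:.3f}, p = {p:.3f}")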
Binary file not shown.
Binary file not shown.
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
import numpy as np
import matplotlib.pyplot as plt
from mne.stats import permutation_cluster_test

from plot_funcs import choose_hist_rule


def cluster_analysis(X1, X2, n_permutations=1000, p_alpha=0.05, **kwargs):
    """
    Perform a cluster-based permutation test on two conditions.

    Returns
    -------
    clusters : list of tuples
        All clusters (tuple of slices).
    cluster_p_values : list of float
        p-value of each cluster from the permutation test.
    cluster_mass : list of float
        Sum of t-values per cluster (raw cluster mass).
    normalized_mass : list of float
        Mean t-value per cluster (length-normalized).
    """
    # Condition order fixes the direction of the contrast tested with tail=1.
    X = [X2, X1]

    # TFCE parameters, currently unused: a fixed cluster-forming threshold
    # of 2.5 is passed to permutation_cluster_test below instead.
    tfce_params = dict(
        E=1,        # extent exponent (sensitivity to cluster width)
        H=1.0,      # height exponent (sensitivity to effect size)
        start=0.2,  # start threshold
        step=0.1,   # step size (smaller = more accurate)
    )

    T_obs, clusters, cluster_p_values, H0 = permutation_cluster_test(
        X, n_permutations=n_permutations, tail=1, n_jobs=1, seed=42,
        threshold=2.5, verbose=False,
    )
    # T_obs, clusters, cluster_p_values, H0 = split_clusters(T_obs, clusters, cluster_p_values, H0)
    # clusters = [c[0] for c in clusters]

    cluster_mass = [T_obs[c].sum() for c in clusters]
    cluster_mass = [0 if np.isnan(e) else e for e in cluster_mass]  # guard against NaN masses
    normalized_mass = [T_obs[c].mean() for c in clusters]

    return clusters, cluster_p_values, cluster_mass, normalized_mass
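# Illustrative usage (not part of the commit): cluster_analysis expects one
# (n_observations, n_timepoints) array per condition. The data here are made
# up, with an effect injected between 0.4 and 0.6 on a simulated time base.
def _example_cluster_analysis():
    rng = np.random.default_rng(0)
    t = np.linspace(0, 1, 200)
    X1 = rng.normal(size=(40, 200)) + ((t > 0.4) & (t < 0.6))
    X2 = rng.normal(size=(40, 200))
    clusters, pvals, mass, norm_mass = cluster_analysis(X1, X2, n_permutations=500)
    for c, p in zip(clusters, pvals):
        print(c, p)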
def plot_clusters(x_ser, clusters, cluster_p_values, plot, p_alpha=0.05,
                  plot_y=0, plot_kwargs=None):
    """
    Plot significant clusters as horizontal bars at height plot_y.
    Each cluster `c` is assumed to be an array of time indices into x_ser.
    """
    fig, ax = plot

    if plot_kwargs is None:
        plot_kwargs = {}

    for i, c in enumerate(clusters):
        if cluster_p_values[i] < p_alpha:
            # Horizontal segment spanning the cluster's first to last sample.
            ax.plot([x_ser[c[0]], x_ser[c[-1]]], [plot_y, plot_y], **plot_kwargs)

    return fig, ax


def plot_cluster_stats(cluster_mass, shuff_cluster_mass, plot=None, plot_kwargs=None,
                       plot_raw=True, plot_normed=True):
    """
    Plot total cluster mass against the null distribution of total mass
    (sums across all clusters per shuffle).
    (plot_raw and plot_normed are currently unused.)
    """
    if plot is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = plot

    if plot_kwargs is None:
        plot_kwargs = {}

    total_mass = np.sum(cluster_mass)
    total_shuff_mass = [np.sum(shuff) for shuff in shuff_cluster_mass]
    # Plot only the 1%-99% range of the null distribution.
    lo, hi = np.percentile(total_shuff_mass, [1, 99])
    total_shuff_mass = [m for m in total_shuff_mass if lo <= m <= hi]

    ax.hist(total_shuff_mass, bins=choose_hist_rule(total_shuff_mass), fc='gray')
    ax.axvline(total_mass, color='goldenrod', ls='-')
    ax.set_ylabel("Frequency", fontdict={'size': 5})
    ax.set_xlabel("Total mass", fontdict={'size': 5})
    ax.yaxis.labelpad = 1
    ax.xaxis.labelpad = 2
    ax.tick_params(axis='both', which='both', pad=2, labelsize=5)
    # Use a symlog x-scale when the observed mass dwarfs the null range.
    if abs(total_mass) / np.min(total_shuff_mass) > 20:
        ax.set_xscale('symlog', linthresh=1)
    else:
        # Linear region scaled to the spread of the null distribution.
        ax.set_xscale('symlog',
                      linthresh=abs(np.max(total_shuff_mass) - np.min(total_shuff_mass)) * 0.1 + 0.1)

    ax.locator_params(axis='y', nbins=3)
    fig.tight_layout()
    fig.show()

    return fig, ax
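# Hypothetical wiring (not part of the commit): plot_cluster_stats expects the
# observed per-cluster masses plus one mass list per shuffle. One way to build
# that null, assuming it comes from re-running cluster_analysis on data with
# shuffled condition membership:
def _example_cluster_stats(X1, X2, n_shuffles=200):
    _, _, mass, _ = cluster_analysis(X1, X2)
    pooled = np.concatenate([X1, X2], axis=0)
    rng = np.random.default_rng(1)
    shuff_cluster_mass = []
    for _ in range(n_shuffles):
        perm = rng.permutation(len(pooled))
        S1, S2 = pooled[perm[:len(X1)]], pooled[perm[len(X1):]]
        _, _, shuff_mass, _ = cluster_analysis(S1, S2, n_permutations=100)
        shuff_cluster_mass.append(shuff_mass)
    return plot_cluster_stats(mass, shuff_cluster_mass)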
def split_clusters(T_obs, clusters, cluster_p_values, H0=None,
                   depth_fraction=0.5, min_gap=1):
    """
    Split clusters (arrays of indices) if valleys appear inside them.
    Returns results in the exact same format as the input:
        T_obs, clusters, cluster_p_values, H0

    Parameters
    ----------
    T_obs : ndarray, shape (n_timepoints,)
        Observed t-values.
    clusters : list of np.ndarray
        Each cluster is an array of indices (sorted, contiguous).
    cluster_p_values : ndarray
        P-values corresponding to clusters.
    H0 : ndarray | None
        Permutation null distribution (unchanged).
    depth_fraction : float
        Split when |t| dips below (peak * depth_fraction) inside a cluster.
    min_gap : int
        Minimum valley length (in samples) to split on.

    Returns
    -------
    T_obs : ndarray
        Same as input.
    new_clusters : list of np.ndarray
        Same format as input, but with splits applied.
    new_cluster_p_values : ndarray
        Same length as new_clusters; children inherit the parent's p-value.
    H0 : ndarray | None
        Same as input.
    """
    new_clusters = []
    new_pvals = []

    for idx, pval in zip(clusters, cluster_p_values):
        # Ensure indices are sorted and unique.
        idx = np.sort(np.unique(idx))
        if idx.size == 0:
            continue

        tvals = T_obs[idx]
        peak = np.nanmax(np.abs(tvals))
        if peak == 0:  # flat cluster, nothing to split
            new_clusters.append(idx)
            new_pvals.append(pval)
            continue

        valley_threshold = peak * depth_fraction
        below = np.where(np.abs(tvals) <= valley_threshold)[0]

        if below.size == 0:
            # No valley, keep the cluster as is.
            new_clusters.append(idx)
            new_pvals.append(pval)
            continue

        # Contiguous runs of below-threshold points.
        runs = np.split(below, np.where(np.diff(below) > 1)[0] + 1)

        cut_points = []
        for run in runs:
            if run.size >= min_gap:
                cut_rel = run[-1] + 1
                if cut_rel < idx.size:  # stay inside the cluster
                    cut_points.append(idx[cut_rel])

        if not cut_points:
            new_clusters.append(idx)
            new_pvals.append(pval)
            continue

        # Split the cluster at the cut points.
        split_idx = [idx[0]] + cut_points + [idx[-1] + 1]
        for a, b in zip(split_idx[:-1], split_idx[1:]):
            sub = idx[(idx >= a) & (idx < b)]
            if sub.size > 0:
                new_clusters.append(sub)
                new_pvals.append(pval)

    return T_obs, new_clusters, np.array(new_pvals), H0
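
A toy run of split_clusters (made-up numbers) may help: one cluster with two peaks separated by a valley at indices 3-4 is split at the end of the valley, and both children inherit the parent's p-value.

    T_obs = np.array([3.0, 5.0, 3.0, 1.0, 1.0, 3.0, 5.0, 3.0])
    clusters = [np.arange(8)]
    pvals = np.array([0.01])
    _, new_clusters, new_pvals, _ = split_clusters(T_obs, clusters, pvals)
    print([c.tolist() for c in new_clusters])  # [[0, 1, 2, 3, 4], [5, 6, 7]]
    print(new_pvals)                           # [0.01 0.01]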
