|
| 1 | +import numpy as np |
| 2 | +from mne.stats import permutation_cluster_test |
| 3 | +import matplotlib.pyplot as plt |
| 4 | + |
| 5 | +from plot_funcs import choose_hist_rule |
| 6 | + |
| 7 | + |
def cluster_analysis(X1, X2, n_permutations=1000, p_alpha=0.05, **kwargs):
    """
    Perform a cluster-based permutation test comparing two conditions.

    Parameters
    ----------
    X1, X2 : array-like, shape (n_observations, n_timepoints)
        Data for the two conditions. They are passed to MNE as
        ``[X2, X1]`` (condition 2 first), matching the original call.
    n_permutations : int
        Number of permutations used to build the null distribution.
    p_alpha : float
        Unused here; kept for interface compatibility (callers apply the
        p-value threshold themselves, e.g. in ``plot_clusters``).
    **kwargs
        Unused; accepted for interface compatibility.

    Returns
    -------
    clusters : list of tuples
        All clusters (tuple of slices).
    cluster_p_values : list of float
        p-value of each cluster from permutation test.
    cluster_mass : list of float
        Sum of t-values per cluster (raw cluster-mass); NaN sums are
        replaced with 0.
    normalized_mass : list of float
        Mean t-value per cluster (length-normalized).
    """
    X = [X2, X1]

    # Fixed cluster-forming threshold (t = 2.5), one-tailed test, fixed
    # seed for reproducibility. (A previously-built TFCE parameter dict was
    # never passed to MNE and has been removed as dead code.)
    T_obs, clusters, cluster_p_values, H0 = permutation_cluster_test(
        X, n_permutations=n_permutations, tail=1, n_jobs=1, seed=42,
        threshold=2.5, verbose=False,
    )

    cluster_mass = [T_obs[c].sum() for c in clusters]
    # BUGFIX: the original used `e != np.nan`, which is always True because
    # NaN compares unequal to everything (including itself), so NaN masses
    # were never zeroed. np.isnan is the correct test.
    cluster_mass = [0 if np.isnan(e) else e for e in cluster_mass]
    normalized_mass = [T_obs[c].mean() for c in clusters]

    return clusters, cluster_p_values, cluster_mass, normalized_mass
| 47 | + |
| 48 | + |
def plot_clusters(x_ser, clusters, cluster_p_values,plot,p_alpha=0.05,plot_y=0,plot_kwargs=None):
    """
    Draw a horizontal bar at height ``plot_y`` over each significant cluster.

    A cluster is significant when its p-value is strictly below ``p_alpha``.
    ``plot`` is an existing ``(fig, ax)`` pair; ``plot_kwargs`` are forwarded
    to ``ax.plot``. Returns the same ``(fig, ax)`` pair.
    """
    fig, ax = plot

    extra = {} if plot_kwargs is None else plot_kwargs

    for cluster, p_val in zip(clusters, cluster_p_values):
        if p_val >= p_alpha:
            continue
        # Span from the first to the last sample of the cluster.
        x_span = [x_ser[cluster[0]], x_ser[cluster[-1]]]
        ax.plot(x_span, [plot_y, plot_y], **extra)

    return fig, ax
| 64 | + |
| 65 | + |
def plot_cluster_stats(cluster_mass, shuff_cluster_mass,plot=None, plot_kwargs=None,
                       plot_raw=True, plot_normed=True):
    """
    Plot the total observed cluster mass against the shuffled null distribution.

    Parameters
    ----------
    cluster_mass : sequence of float
        Per-cluster masses for the observed data; summed into one total.
    shuff_cluster_mass : sequence of sequence of float
        Per-cluster masses for each shuffle; each shuffle is summed.
    plot : (fig, ax) tuple | None
        Existing figure/axes to draw on; a new one is created when None.
    plot_kwargs : dict | None
        Currently unused; kept for interface compatibility.
    plot_raw, plot_normed : bool
        Currently unused; kept for interface compatibility.

    Returns
    -------
    fig, ax
        The figure and axes drawn on.
    """
    if plot is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = plot

    if plot_kwargs is None:
        plot_kwargs = {}

    total_mass = np.sum(cluster_mass)
    total_shuff_mass = [np.sum(shuff) for shuff in shuff_cluster_mass]
    # Trim the null distribution to its 1%-99% range for display.
    # Percentiles are hoisted out of the comprehension (the original
    # recomputed both per element, O(n^2)).
    lo, hi = np.percentile(total_shuff_mass, [1, 99])
    total_shuff_mass = [m for m in total_shuff_mass if lo <= m <= hi]

    ax.hist(total_shuff_mass, bins=choose_hist_rule(total_shuff_mass), fc='gray')
    ax.axvline(total_mass, color='goldenrod', ls='-')
    ax.set_ylabel("Frequency", fontdict={'size': 5})
    ax.set_xlabel("Total mass", fontdict={'size': 5})
    ax.yaxis.labelpad = 1
    ax.xaxis.labelpad = 2
    ax.tick_params(axis='both', which='both', pad=2, labelsize=5)

    # Switch to symlog when the observed mass dwarfs the null scale.
    # BUGFIX: the original divided by np.min(total_shuff_mass) without abs(),
    # so a negative minimum made the ratio negative (the check could never
    # fire) and a zero minimum produced a divide-by-zero warning.
    denom = np.min(np.abs(total_shuff_mass))
    if denom > 0 and abs(total_mass) / denom > 20:
        ax.set_xscale('symlog', linthresh=1)
    else:
        # linthresh scales with the spread of the null distribution.
        # (The original set 'linear' here and immediately overrode it with
        # 'symlog'; the dead statement is removed.)
        spread = np.max(total_shuff_mass) - np.min(total_shuff_mass)
        ax.set_xscale('symlog', linthresh=abs(spread) * 0.1 + 0.1)

    ax.set_title("", fontdict={'size': 5})
    ax.locator_params(axis='y', nbins=3)
    fig.tight_layout()
    fig.show()

    return fig, ax
| 108 | + |
| 109 | + |
def split_clusters(T_obs, clusters, cluster_p_values, H0=None,
                   depth_fraction=0.5, min_gap=1):
    """
    Split clusters (arrays of indices) wherever a valley appears inside them.

    A valley is a contiguous run of at least ``min_gap`` samples whose
    absolute t-value falls to or below ``peak * depth_fraction``, where
    ``peak`` is the cluster's maximum absolute t-value. Each child cluster
    inherits its parent's p-value. Output format matches the input exactly.

    Parameters
    ----------
    T_obs : ndarray, shape (n_timepoints,)
        Observed t-values.
    clusters : list of np.ndarray
        Each cluster is an array of indices (sorted, contiguous).
    cluster_p_values : ndarray
        P-values corresponding to clusters.
    H0 : ndarray | None
        Permutation null distribution (returned unchanged).
    depth_fraction : float
        Relative valley depth that triggers a split.
    min_gap : int
        Minimum valley length (in samples) to split on.

    Returns
    -------
    T_obs : ndarray
        Same as input.
    new_clusters : list of np.ndarray
        Same format as input, with splits applied.
    new_cluster_p_values : ndarray
        One entry per output cluster; children inherit the parent's p-value.
    H0 : ndarray | None
        Same as input.
    """
    out_clusters, out_pvals = [], []

    for cluster, p in zip(clusters, cluster_p_values):
        # np.unique both deduplicates and sorts; empty clusters are dropped.
        cluster = np.unique(cluster)
        if cluster.size == 0:
            continue

        abs_t = np.abs(T_obs[cluster])
        peak = np.nanmax(abs_t)
        if peak == 0:
            # Flat cluster — nothing to split on.
            out_clusters.append(cluster)
            out_pvals.append(p)
            continue

        # Positions (relative to the cluster) sitting at or below the valley line.
        valley_pos = np.flatnonzero(abs_t <= peak * depth_fraction)
        if valley_pos.size == 0:
            out_clusters.append(cluster)
            out_pvals.append(p)
            continue

        # Break the valley positions into contiguous runs.
        run_starts = np.flatnonzero(np.diff(valley_pos) > 1) + 1
        cuts = []
        for run in np.split(valley_pos, run_starts):
            if run.size >= min_gap:
                after = run[-1] + 1
                if after < cluster.size:  # cut must stay inside the cluster
                    cuts.append(cluster[after])

        if not cuts:
            out_clusters.append(cluster)
            out_pvals.append(p)
            continue

        # Partition the cluster at the cut indices; children share p.
        bounds = [cluster[0], *cuts, cluster[-1] + 1]
        for left, right in zip(bounds, bounds[1:]):
            piece = cluster[(cluster >= left) & (cluster < right)]
            if piece.size > 0:
                out_clusters.append(piece)
                out_pvals.append(p)

    return T_obs, out_clusters, np.array(out_pvals), H0
0 commit comments