Skip to content

Commit 6676bae

Browse files
authored
Merge pull request #151 from Neuroinflab/figs2023
figure for submission
2 parents 4630c45 + 0a88f76 commit 6676bae

10 files changed

Lines changed: 72 additions & 53 deletions

figures/kCSD_properties/L_curve_simulation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def main_loop(src_width, total_ele, inpos, lpos, nm, noise=0, srcs=1):
112112
t_csd_x, t_csd_y, true_csd = generate_csd_1D(src_width, nm, srcs=srcs,
113113
start_x=0, end_x=1.,
114114
start_y=0, end_y=1,
115-
res_x=100, res_y=100)
115+
res_x=101, res_y=101)
116116
if type(noise) == float: n_spec = [noise]
117117
else: n_spec = noise
118118
for i, noise in enumerate(n_spec):
@@ -142,6 +142,7 @@ def main_loop(src_width, total_ele, inpos, lpos, nm, noise=0, srcs=1):
142142
'pots':pots, 'estm_x':k.estm_x, 'est_pot':est_pot,
143143
'est_csd':est_csd, 'noreg_csd':noreg_csd, 'errsy':errsy}
144144
np.savez('data_fig4_and_fig13_'+save_as, **vals_to_save)
145+
print(true_csd.shape, est_csd[:,0].shape)
145146
RMS_wek[0, i] = np.linalg.norm(true_csd/np.linalg.norm(true_csd) - est_csd[:,0]/np.linalg.norm(est_csd[:,0]))
146147
RMS_wek[1, i] = np.linalg.norm(true_csd/np.linalg.norm(true_csd) - est_csd_cv[:,0]/np.linalg.norm(est_csd_cv[:,0]))
147148

figures/kCSD_properties/figure_LC.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ def make_plots(title, m_norm, m_resi, true_csd, curveseq, ele_y,
103103
# os.chdir("./LCurve/LC2")
104104
noises = 3
105105
noise_lvl = np.linspace(0, 0.5, noises)
106-
# df = np.load('data_fig4_and_fig13_lc_noise25.0.npz')
106+
#df = np.load('data_fig4_and_fig13_lc_noise25.0.npz')
107107
Rs = np.linspace(0.025, 8*0.025, 8)
108108
title = ['nazwa_pliku']
109109
save_as = 'noise'
110-
# make_plots(title, df['m_norm'], df['m_resi'], df['true_csd'],
111-
# df['curve_surf'], df['ele_y'], df['pots_n'],
112-
# df['pots'], df['estm_x'], df['est_pot'], df['est_csd'],
113-
# df['noreg_csd'], save_as)
110+
make_plots(title, df['m_norm'], df['m_resi'], df['true_csd'],
111+
df['curve_surf'], df['ele_y'], df['pots_n'],
112+
df['pots'], df['estm_x'], df['est_pot'], df['est_csd'],
113+
df['noreg_csd'], save_as)

figures/kCSD_properties/figure_LCandCV.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,51 +23,58 @@ def set_axis(ax, x, y, letter=None):
2323
def plot_surface(curve_surf, errsy, save_as):
2424
fsize = 18
2525
lambdas = np.logspace(-7, -3, 50)
26-
fig = plt.figure(figsize = (20,9), dpi = 300)
27-
gs = gridspec.GridSpec(16, 12, hspace=2, wspace=2)
26+
fig = plt.figure(figsize = (15, 6), dpi = 300)
27+
gs = gridspec.GridSpec(16, 12, hspace=1, wspace=1)
2828
ax = plt.subplot(gs[0:16, 0:6])
2929
set_axis(ax, -0.05, 1.05, letter='A')
30-
plt.pcolormesh(lambdas, np.arange(9), curve_surf,
31-
cmap = 'BrBG', vmin = -2, vmax=2)
30+
31+
plt.pcolormesh(lambdas, np.arange(8), curve_surf,
32+
cmap = 'BrBG', vmin = -3, vmax=3)
3233
plt.colorbar()
3334
for i,m in enumerate(curve_surf.argmax(axis=1)):
34-
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red', alpha = 0.7)
35+
plt.scatter([lambdas[m]], [i], s=50, color='red', alpha = 0.7)
3536
if i==7:
36-
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red',
37-
label = 'Maximum Curvature', alpha = 0.7)
37+
plt.scatter([lambdas[m]], [i], s=50, color='red',
38+
label = 'Maximum \nCurvature', alpha = 0.7)
3839
plt.xlim(lambdas[1],lambdas[-1])
3940
plt.title('L-curve regularization', fontsize = fsize)
40-
plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
41+
# plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
42+
# frameon = False, fontsize = fsize)
43+
plt.legend(loc='upper left', ncol=1,
4144
frameon = False, fontsize = fsize)
42-
plt.yticks(np.arange(8)+0.5, [str(x)+'x' for x in range(1,9)])
45+
plt.yticks(np.arange(8), [str(x)+'x' for x in range(1,9)])
4346
plt.xscale('log')
4447
plt.ylabel('Parameter $R$ in electrode distance', fontsize=fsize, labelpad = 15)
4548
plt.xlabel('$\lambda$',fontsize=fsize)
4649
ax = plt.subplot(gs[0:16, 6:12])
4750
set_axis(ax, -0.05, 1.05, letter='B')
48-
plt.pcolormesh(lambdas, np.arange(9), errsy, cmap = 'Greys')
51+
plt.pcolormesh(lambdas, np.arange(8), errsy, cmap='Greys', vmin=0.01, vmax=0.02)
4952
plt.colorbar()
5053
for i,m in enumerate(errsy.argmin(axis=1)):
51-
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red', alpha = 0.7)
54+
plt.scatter([lambdas[m]], [i], s=50, color='red', alpha = 0.7)
5255
if i==7:
53-
plt.scatter([lambdas[m]], [i+0.5], s=50, color='red',
54-
label = 'Minimum Error', alpha = 0.7)
56+
plt.scatter([lambdas[m]], [i], s=50, color='red',
57+
label = 'Minimum \nError', alpha = 0.7)
5558
plt.xlim(lambdas[1],lambdas[-1])
56-
plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
59+
# plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
60+
# frameon = False, fontsize = fsize)
61+
plt.legend(loc='upper left', ncol=1,
5762
frameon = False, fontsize = fsize)
5863
plt.title('Cross-validation regularization', fontsize = fsize)
59-
plt.yticks(np.arange(8)+0.5, [str(x)+'x' for x in range(1,9)])
64+
plt.yticks(np.arange(8), [str(x)+'x' for x in range(1,9)])
6065
plt.xscale('log')
6166
plt.xlabel('$\lambda$', fontsize=fsize)
6267
fig.savefig(save_as+'.png')
6368

6469
if __name__=='__main__':
65-
# os.chdir("./LCurve/")
70+
os.chdir("./LCurve/")
6671
noises = 3
6772
noise_lvl = np.linspace(0, 0.5, noises)
68-
# df = np.load('LC2/data_fig4_and_fig13_lc_noise25.0.npz')
73+
print(os.getcwd())
74+
df = np.load(os.path.join('LC2', 'data_fig4_and_fig13_LC_noise25.0.npz'))
75+
6976
Rs = np.linspace(0.025, 8*0.025, 8)
7077
title = ['nazwa_pliku']
7178
save_as = 'noise'
72-
# plot_surface(df['curve_surf'], df['errsy'], save_as+'surf')
73-
plt.close('all')
79+
plot_surface(df['curve_surf'], df['errsy'], save_as+'surf')
80+
plt.close('all')

figures/kCSD_properties/figure_LCandCVperformance.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def make_plot_perf(sim_results):
2525
lam_lc = sim_results[0, 0]
2626
rms_cv = sim_results[1, 2]
2727
lam_cv = sim_results[1, 0]
28-
fig = plt.figure(figsize = (9,12), dpi = 300)
29-
widths = [10]
30-
heights = [1, 1]
31-
gs = gridspec.GridSpec(2, 1, height_ratios=heights, width_ratios=widths,
28+
fig = plt.figure(figsize = (12,7), dpi = 300)
29+
widths = [1, 1]
30+
heights = [1]
31+
gs = gridspec.GridSpec(1, 2, height_ratios=heights, width_ratios=widths,
3232
hspace=0.45, wspace=0.3)
3333
ax1 = plt.subplot(gs[0])
3434
if np.min(rms_cv) < np.min(rms_lc):
@@ -50,7 +50,8 @@ def make_plot_perf(sim_results):
5050
ax1.spines['right'].set_visible(False)
5151
ax1.spines['top'].set_visible(False)
5252
set_axis(ax1, -0.05, 1.05, letter='A')
53-
plt.title('Performance of regularization methods')
53+
ax1.legend(loc='upper left', frameon=False)
54+
# plt.title('Performance of regularization methods')
5455

5556
'''second plot'''
5657
ax2 = plt.subplot(gs[1])
@@ -69,14 +70,15 @@ def make_plot_perf(sim_results):
6970
plt.xlabel('Relative Noise Level', labelpad = 15)
7071
set_axis(ax2, -0.05, 1.05, letter='B')
7172
ht, lh = ax2.get_legend_handles_labels()
72-
fig.legend(ht, lh, loc='lower center', ncol=2, frameon=False)
73+
#fig.legend(ht, lh, loc='upper center', ncol=2, frameon=False)
7374
ax2.spines['right'].set_visible(False)
7475
ax2.spines['top'].set_visible(False)
76+
ax2.legend(loc='upper left', frameon=False)
7577
fig.savefig('stats.png')
7678

7779
if __name__=='__main__':
78-
# os.chdir("./LCurve/")
80+
os.chdir("./LCurve/")
7981
noises = 9
8082
noise_lvl = np.linspace(0, 0.5, noises)
81-
# sim_results = np.load('sim_results.npy')
82-
# make_plot_perf(sim_results)
83+
sim_results = np.load('sim_results.npy')
84+
make_plot_perf(sim_results)

figures/kCSD_properties/figure_properties.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@
1717
})
1818

1919

20+
def cm_to_inches(vals):
21+
return [0.393701*ii for ii in vals]
1.77 KB
Binary file not shown.

figures/kCSD_properties/tutorial_basic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def grid(x, y, z):
3636
def set_axis(ax, letter=None):
3737
ax.text(
3838
-0.05,
39-
1.05,
39+
1.10,
4040
letter,
4141
fontsize=20,
4242
weight='bold',
@@ -63,7 +63,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None, xl
6363
if ylabel:
6464
ax.set_ylabel('Y (mm)')
6565
if title is not None:
66-
ax.set_title(title)
66+
ax.set_title(title, pad=10)
6767
ax.set_xticks([0, 0.5, 1])
6868
ax.set_yticks([0, 0.5, 1])
6969
ticks = np.linspace(-1 * t_max, t_max, 3, endpoint=True)
@@ -128,6 +128,7 @@ def generate_figure(small_seed, large_seed):
128128
cax = plt.subplot(gs[1, 0])
129129
t_max_1 = 0.50
130130
make_subplot(ax, 'csd', csd_x, csd_y, true_csd, cax, 'True CSD', xlabel=True, ylabel=True, letter='A', t_max=t_max_1)
131+
ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
131132
ax = plt.subplot(gs[0, 1])
132133
cax = plt.subplot(gs[1, 1])
133134
make_subplot(ax, 'pot', pot_X, pot_Y, pot_Z, cax, 'Interpolated potentials', xlabel=True, ele_pos=ele_pos, letter='B')
@@ -145,9 +146,10 @@ def generate_figure(small_seed, large_seed):
145146
cax = plt.subplot(gs[1, 0])
146147
t_max_2 = 0.52
147148
make_subplot(ax, 'csd', csd_x, csd_y, true_csd, cax, ylabel=True, xlabel=True, letter='E', t_max=t_max_2)
149+
ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
148150
ax = plt.subplot(gs[0, 1])
149151
cax = plt.subplot(gs[1, 1])
150-
make_subplot(ax, 'pot', pot_X, pot_Y, pot_Z, cax, xlabel=True, ele_pos=ele_pos, letter='F')
152+
make_subplot(ax, 'pot', pot_X, pot_Y, pot_Z, cax, xlabel=True, ele_pos=ele_pos, letter='F', t_max=1)
151153
ax = plt.subplot(gs[0, 2])
152154
cax = plt.subplot(gs[1, 2])
153155
make_subplot(ax, 'csd', k.estm_x, k.estm_y, est_csd_pre_cv[:, :, 0], cax, xlabel=True, letter='G', t_max=t_max_2)

figures/kCSD_properties/tutorial_broken_electrodes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def load_files(folderpaths, seeds):
4040
def set_axis(ax, letter=None):
4141
ax.text(
4242
-0.05,
43-
1.05,
43+
1.1,
4444
letter,
4545
fontsize=20,
4646
weight='bold',
@@ -121,7 +121,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None, xl
121121
if ylabel:
122122
ax.set_ylabel('Y (mm)')
123123
if title is not None:
124-
ax.set_title(title)
124+
ax.set_title(title, pad=10)
125125
ax.set_xticks([0, 0.5, 1])
126126
ax.set_yticks([0, 0.5, 1])
127127
ticks = np.linspace(0, t_max, 3, endpoint=True)
@@ -154,6 +154,7 @@ def generate_figure():
154154
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
155155
cax=cax, title='Error CSD', xlabel=True, ylabel=True, letter='A',
156156
t_max=err_max)
157+
ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
157158
ax = plt.subplot(gs[0, 1])
158159
cax = plt.subplot(gs[1, 1])
159160
make_subplot(ax, 'err', csd_x, csd_y, errs[1], ele_pos=electrode_positions(missing_ele=5),
@@ -179,6 +180,7 @@ def generate_figure():
179180
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
180181
cax=cax, xlabel=True, ylabel=True, letter='E',
181182
t_max=err_max)
183+
ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
182184
ax = plt.subplot(gs[0, 1])
183185
cax = plt.subplot(gs[1, 1])
184186
make_subplot(ax, 'err', csd_x, csd_y, errs[1], ele_pos=electrode_positions(missing_ele=5),

figures/kCSD_properties/tutorial_broken_electrodes_diff_err.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def load_files(folderpaths, seeds):
3838
def set_axis(ax, letter=None):
3939
ax.text(
4040
-0.05,
41-
1.05,
41+
1.10,
4242
letter,
4343
fontsize=20,
4444
weight='bold',
@@ -145,7 +145,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None,
145145
if ylabel:
146146
ax.set_ylabel('Y (mm)')
147147
if title is not None:
148-
ax.set_title(title)
148+
ax.set_title(title, pad=10)
149149
ax.set_xticks([0, 0.5, 1])
150150
ax.set_yticks([0, 0.5, 1])
151151
ticks = np.linspace(0, t_max, 3, endpoint=True)
@@ -277,20 +277,22 @@ def generate_figure2():
277277
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
278278
cax=cax, title='Error CSD', xlabel=True, ylabel=True, letter='A',
279279
t_max=.2)
280+
ax.text(-0.4, 0.5, 'Small+Large sources', fontsize=20, rotation=90, va='center')
281+
280282
ax = plt.subplot(gs[0, 1])
281283
cax = plt.subplot(gs[1, 1])
282284
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[1] - errs[0]), ele_pos=electrode_positions(missing_ele=5),
283-
cax=cax, title='Error Diff CSD 5 broken', xlabel=True, letter='B',
285+
cax=cax, title='5 broken - Error CSD ', xlabel=True, letter='B',
284286
t_max=err_max)
285287
ax = plt.subplot(gs[0, 2])
286288
cax = plt.subplot(gs[1, 2])
287289
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[2] - errs[0]), ele_pos=electrode_positions(missing_ele=10),
288-
cax=cax, title='Error Diff CSD 10 broken', xlabel=True, letter='C',
290+
cax=cax, title='10 broken - Error CSD', xlabel=True, letter='C',
289291
t_max=err_max)
290292
ax = plt.subplot(gs[0, 3])
291293
cax = plt.subplot(gs[1, 3])
292294
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[3] - errs[0]), ele_pos=electrode_positions(missing_ele=20),
293-
cax=cax, title='Error Diff CSD 20 broken', xlabel=True, letter='D',
295+
cax=cax, title='20 broken - Error CSD', xlabel=True, letter='D',
294296
t_max=err_max)
295297

296298
errs = fetch_values('small')
@@ -302,6 +304,7 @@ def generate_figure2():
302304
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
303305
cax=cax, xlabel=True, ylabel=True, letter='E',
304306
t_max=.2)
307+
ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
305308
ax = plt.subplot(gs[0, 1])
306309
cax = plt.subplot(gs[1, 1])
307310
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[1] - errs[0]), ele_pos=electrode_positions(missing_ele=5),
@@ -327,6 +330,7 @@ def generate_figure2():
327330
make_subplot(ax, 'err', csd_x, csd_y, errs[0], ele_pos=electrode_positions(missing_ele=0),
328331
cax=cax, xlabel=True, ylabel=True, letter='I',
329332
t_max=.2)
333+
ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
330334
ax = plt.subplot(gs[0, 1])
331335
cax = plt.subplot(gs[1, 1])
332336
make_subplot(ax, 'err', csd_x, csd_y, abs(errs[1] - errs[0]), ele_pos=electrode_positions(missing_ele=5),

figures/kCSD_properties/tutorial_noisy_electrodes.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def grid(x, y, z):
3636
def set_axis(ax, letter=None):
3737
ax.text(
3838
-0.05,
39-
1.05,
39+
1.1,
4040
letter,
4141
fontsize=20,
4242
weight='bold',
@@ -77,7 +77,7 @@ def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None, xl
7777
if ylabel:
7878
ax.set_ylabel('Y (mm)')
7979
if title is not None:
80-
ax.set_title(title)
80+
ax.set_title(title, pad=10)
8181
ax.set_xticks([0, 0.5, 1])
8282
ax.set_yticks([0, 0.5, 1])
8383
ticks = np.linspace(0, t_max, 3, endpoint=True)
@@ -97,12 +97,11 @@ def do_kcsd(CSD_PROFILE, csd_seed, noise_level):
9797
# R_final = np.linspace(0.1, 1.5, 15)
9898
R_final = np.linspace(0.05, 1., 20)
9999
# True CSD_PROFILE
100-
csd_at = np.mgrid[0.:1.:100j,
101-
0.:1.:100j]
100+
csd_at = np.mgrid[0.:1.:101j,
101+
0.:1.:101j]
102102
csd_x, csd_y = csd_at
103103
# Small source
104104
true_csd = CSD_PROFILE(csd_at, seed=csd_seed)
105-
106105
# Electrode positions
107106
ele_x, ele_y = np.mgrid[0.05: 0.95: 10j,
108107
0.05: 0.95: 10j]
@@ -153,12 +152,12 @@ def generate_figure(small_seed, large_seed):
153152
cax = plt.subplot(gs[1, 0])
154153
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
155154
cax=cax, title='Error CSD', xlabel=True, ylabel=True, letter='A', t_max=t_max_1)
156-
155+
ax.text(-0.4, 0.5, 'Small sources', fontsize=20, rotation=90, va='center')
157156
csd_x, csd_y, err, ele_pos = do_kcsd(CSD.gauss_2d_small, csd_seed=small_seed, noise_level=5)
158157
ax = plt.subplot(gs[0, 1])
159158
cax = plt.subplot(gs[1, 1])
160159
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
161-
cax=cax, title='Error CSD 5% noise', xlabel=True, ylabel=True, letter='B', t_max=t_max_1)
160+
cax=cax, title='Error CSD 5% noise', xlabel=True, letter='B', t_max=t_max_1)
162161

163162
csd_x, csd_y, err, ele_pos = do_kcsd(CSD.gauss_2d_small, csd_seed=small_seed, noise_level=10)
164163
ax = plt.subplot(gs[0, 2])
@@ -180,7 +179,7 @@ def generate_figure(small_seed, large_seed):
180179
cax = plt.subplot(gs[1, 0])
181180
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
182181
cax=cax, xlabel=True, ylabel=True, letter='E', t_max=0.55)
183-
182+
ax.text(-0.4, 0.5, 'Large sources', fontsize=20, rotation=90, va='center')
184183
csd_x, csd_y, err, ele_pos = do_kcsd(CSD.gauss_2d_large, csd_seed=large_seed, noise_level=5)
185184
ax = plt.subplot(gs[0, 1])
186185
cax = plt.subplot(gs[1, 1])
@@ -199,7 +198,7 @@ def generate_figure(small_seed, large_seed):
199198
make_subplot(ax, 'err', csd_x, csd_y, err, ele_pos=ele_pos,
200199
cax=cax, xlabel=True, letter='H', t_max=0.55)
201200
plt.savefig('tutorial_noise.png', dpi=300)
202-
plt.show()
201+
# plt.show()
203202

204203
if __name__ == '__main__':
205204
small_seed = 15

0 commit comments

Comments
 (0)