Skip to content

Commit d10ee08

Browse files
wsredniawaccluri
authored andcommitted
Figures (#90)
* Update KCSD.py * Update kCSD1D_test_fig1_and_fig2.py * Update kCSD1D_test_fig1_and_fig2.py * Update KCSD.py In cross-validation name errs changed to self.errs * KCSD and L-curve update (#72) * Update KCSD.py * Update kCSD1D_test_fig1_and_fig2.py * Update kCSD1D_test_fig1_and_fig2.py * Update KCSD.py In cross-validation name errs changed to self.errs * changing the folder structure * static_R_code update * fig fonts and layout correction issue #74 * Delete kCSD1D_test_fig1_and_fig2_staticR.py * Delete kCSD1D_test_fig1_and_fig2.py * issue #74 updatte * Update kCSD1D_test_fig1_and_fig2_v2.py * New atribute in KCSD class I would like to include new auxiliary atribute of the KCSD class "own_est" which is an array of shape (number of dimensions 2 or 3, number of csd points coordinates) . It allows to reconstruct CSD only in given set of points (i.e. body of neuron) and work only when the array is not empty (default value is np.array([])). * Update KCSD.py Two new classes oKCSD2D and oKCSD3D to estimate CSD at exact given location. Method takes parameter own_src(#dimensions, #spatial coordinates) * Update KCSD.py * Update KCSD.py * update 74 Tests for oKCSD2D and oKCSD3d * kcsd_test with oKCSD2D and oKCSD3D * Update kcsd_tests.py * Update figures/kCSD_properties/kCSD1D_test_fig1_and_fig2_v2.py Co-Authored-By: wsredniawa <32837003+wsredniawa@users.noreply.github.com> * tests update oKCSD * Update kcsd_tests.py * Update kcsd_tests.py * Update kcsd_tests.py * Correct super execution Make super execution compatible with Python2. Switch kwargs.pop('own_est') with super execution otherwise an exception is raised in KCSD because of unnkown attribute. * Change setUp to test_2D and test_3D Tests were not running * Update kCSD1D_test_fig1_and_fig2_v2.py Corrected parameter name noise_lvl_n to noises * Seperation of src and est locations There are two corrections in the code. In method suggest Lambda there was a mistake in taking maximal value of lambda (should be log10 from that value). Second in classes oKCSD2D and 3D from now on locations of sources and estimations of potential/csd are seperated and are called own_src and own_est respectively. if only own_src are used own_est will have default positions the same as are in own_src array. * Seperation of src and est locations There are two corrections in the code. In method suggest Lambda there was a mistake in taking maximal value of lambda (should be log10 from that value). Second in classes oKCSD2D and 3D from now on locations of sources and estimations of potential/csd are seperated and are called own_src and own_est respectively. if only own_src are used own_est will have default positions the same as are in own_src array. * oKCSD2D and oKCSD3D will raise error if user will not provide own_src array Following error will be raised in oKCSD2D and oKCSD3D class in case when user will not provide own_src array: "'"own_src" is required argument to use oKCSD2D. If you would like to reconstruct in default region of interest please use KCSD2D') * Update kcsd_tests.py * New tests to test own_est and own_src parameters problems New tests for oKCSD2D and oKCSD3D that raises error when neither own_est nor own_src are provided * Corrected tests Checking if own_est is correctly overwritten with own_src * Update kcsd_tests.py * Create make_fig_lcurve.py script to generate figures for l-curve part * L-Curve figure correction
1 parent b71c707 commit d10ee08

5 files changed

Lines changed: 547 additions & 69 deletions

File tree

figures/kCSD_properties/kCSD1D_test_fig1_and_fig2_v2.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -52,38 +52,37 @@ def make_plot_fig2(sim_results):
5252
trans = np.min(np.mean(rms_lc, axis=0))
5353
mn_rms = np.mean(rms_lc, axis=0) - trans
5454
st_rms = st.sem(rms_lc, axis=0)
55-
plt.plot(noise_lvl, mn_rms, marker = 'o', color = 'blue', label = 'L-curve')
55+
plt.plot(noise_lvl, mn_rms, marker = 'o', color = 'blue', label = 'kCSD L-Curve')
5656
plt.fill_between(noise_lvl, mn_rms - st_rms,
5757
mn_rms + st_rms, alpha = 0.3, color = 'blue')
5858
mn_rms = np.mean(rms_cv, axis=0) - trans
5959
st_rms = st.sem(rms_cv, axis=0)
60-
plt.plot(noise_lvl, mn_rms, marker = 'o', color = 'green', label = 'cross-validation')
60+
plt.plot(noise_lvl, mn_rms, marker = 'o', color = 'green', label = 'kCSD Cross-Validation')
6161
plt.fill_between(noise_lvl, mn_rms - st_rms,
6262
mn_rms + st_rms, alpha = 0.3, color = 'green')
63-
plt.ylabel('Estimation error')
64-
plt.xlabel('Relative noise level')
63+
plt.ylabel('Estimation error', labelpad = 30)
64+
plt.xlabel('Relative noise level', labelpad = 15)
6565
ax1.spines['right'].set_visible(False)
6666
ax1.spines['top'].set_visible(False)
6767
set_axis(ax1, -0.05, 1.05, letter='A')
68-
ht, lh = ax1.get_legend_handles_labels()
69-
fig.legend(ht, lh, loc='center', ncol=2, frameon=False)
70-
68+
plt.title('Comparison of the regularization methods performance')
69+
7170
'''second plot'''
7271
ax2 = plt.subplot(gs[1])
7372
mn_lam = np.mean(lam_lc, axis=0)
7473
st_lam = st.sem(lam_lc, axis=0)
75-
plt.plot(noise_lvl, mn_lam, marker = 'o', color = 'blue', label = 'L-curve')
74+
plt.plot(noise_lvl, mn_lam, marker = 'o', color = 'blue', label = 'kCSD L-Curve')
7675
plt.fill_between(noise_lvl, mn_lam - st_lam,
7776
mn_lam + st_lam, alpha = 0.3, color = 'blue')
7877
mn_lam = np.mean(lam_cv, axis=0)
7978
st_lam = st.sem(lam_cv, axis=0)
80-
plt.plot(noise_lvl, mn_lam, marker = 'o', color = 'green', label = 'cross-validation')
79+
plt.plot(noise_lvl, mn_lam, marker = 'o', color = 'green', label = 'kCSD Cross-Validation')
8180
plt.fill_between(noise_lvl, mn_lam - st_lam,
8281
mn_lam + st_lam, alpha = 0.3, color = 'green')
8382
# ax2.set_yscale('log')
8483
ax2.ticklabel_format(style='sci', axis='y', scilimits=((0.0, 0.0)))
85-
plt.ylabel('Lambda')
86-
plt.xlabel('Relative noise level')
84+
plt.ylabel('Lambda', labelpad = 20)
85+
plt.xlabel('Relative noise level', labelpad = 15)
8786
set_axis(ax2, -0.05, 1.05, letter='B')
8887
ht, lh = ax2.get_legend_handles_labels()
8988
fig.legend(ht, lh, loc='lower center', ncol=2, frameon=False)
@@ -117,53 +116,41 @@ def make_plots(title, m_norm, m_resi, true_csd, curveseq, ele_y,
117116
1_ LFP measured (with added noise) and estimated LFP with kCSD method
118117
2_ true CSD and reconstructed CSD with kCSD
119118
3_ L-curve of the model
120-
3_ Surface of parameters R and Lambda with scores for optimal paramater selection with L-curve or cross-validation
119+
4_ Surface of parameters R and Lambda with scores for optimal paramater selection with L-curve or cross-validation
121120
"""
122121
#True CSD
123122
fig = plt.figure(figsize=(12, 12), dpi=300)
124123
widths = [1, 1]
125124
heights = [1, 1]
125+
xpad=5
126+
ypad = 10
126127
gs = gridspec.GridSpec(2, 2, height_ratios=heights, width_ratios=widths,
127-
hspace=0.45, wspace=0.3)
128+
hspace=0.5, wspace=0.5)
128129
xrange = np.linspace(0, 1, len(true_csd))
129130
ax1 = plt.subplot(gs[0])
130-
ax1.plot(ele_y, pots*1e3, 'r', marker='o', linewidth=0, label='measured potential')
131-
ax1.scatter(ele_y, np.zeros(len(ele_y)), 8, color='black')
132-
ax1.plot(xrange, est_pot*1e3, label='recon. potential', color='blue')
133-
ax1.set_ylabel('Potential [mV]')
134-
ax1.set_xlabel('Distance')
131+
ax1.plot(ele_y, pots*1e3, 'r', marker='o', linewidth=0, label='Measured potential')
132+
ax1.scatter(ele_y, np.zeros(len(ele_y)), 8, color='black', label = "Electrode position")
133+
ax1.plot(xrange, est_pot*1e3, label='Reconstructed potential', color='blue')
134+
ax1.set_ylabel('Potential ($mV$)', labelpad = ypad)
135+
ax1.set_xlabel('Distance', labelpad = xpad)
135136
ax1.tick_params(axis='both', which='major')
136137
ax1.spines['right'].set_visible(False)
137138
ax1.spines['top'].set_visible(False)
138139
set_axis(ax1, -0.05, 1.05, letter='A')
139-
ht, lh = ax1.get_legend_handles_labels()
140-
fig.legend(ht, lh, loc='center left', ncol=2, frameon=False)
140+
ax1.legend(bbox_to_anchor=(1.5, -0.16), ncol=2, frameon=False)
141141

142-
ax2 = plt.subplot(gs[2])
143-
plt.plot(xrange, true_csd, label='true csd', color='red', linestyle = '--')
144-
plt.plot(xrange, est_csd, label='recon. csd with reg.', color='blue')
145-
plt.plot(xrange, noreg_csd, label='recon. csd no reg.', color='darkgreen', alpha = 0.6)
146-
plt.ylim(-1, 1)
147-
plt.scatter(ele_y, np.zeros(len(ele_y)), 8, color='black')
148-
ax2.set_ylabel('CSD [mA/mm]')
149-
ax2.set_xlabel('Distance')
150-
ax2.tick_params(axis='both', which='major')
151-
ax2.spines['right'].set_visible(False)
152-
ax2.spines['top'].set_visible(False)
153-
set_axis(ax2, -0.05, 1.05, letter='C')
154-
ht, lh = ax2.get_legend_handles_labels()
155-
fig.legend(ht, lh, loc='lower left', ncol=2, frameon=False)
156142
ax_L = plt.subplot(gs[1])
157143
Lamb = [-7, -3]
158144
if name == 'lc':
159145
imax = np.argmax(curveseq[np.argmax(np.max(curveseq, axis=-1))])
160-
plt.ylabel("Norm of the model")
161-
plt.xlabel("Norm of the prediction error")
146+
plt.ylabel("Norm of the model", labelpad = ypad)
147+
plt.xlabel("Norm of the prediction error", labelpad = xpad)
148+
ax_L.plot(m_resi, m_norm, marker=".", c="green", label = 'L-Curve')
162149
else:
163150
imax = np.argmin(m_norm)
164-
plt.xlabel("Lambda")
165-
plt.ylabel("CV error")
166-
ax_L.plot(m_resi, m_norm, marker=".", c="green")
151+
plt.xlabel("Lambda", labelpad = xpad)
152+
plt.ylabel("CV error", labelpad = ypad)
153+
ax_L.plot(m_resi, m_norm, marker=".", c="green", label = 'CV curve')
167154
#print(m_resi, m_norm)
168155
ax_L.plot([m_resi[imax]], [m_norm[imax]], marker="o", c="red")
169156
x = [m_resi[0], m_resi[imax], m_resi[-1]]
@@ -175,39 +162,50 @@ def make_plots(title, m_norm, m_resi, true_csd, curveseq, ele_y,
175162
ax_L.spines['right'].set_visible(False)
176163
ax_L.spines['top'].set_visible(False)
177164
set_axis(ax_L, -0.05, 1.05, letter='B')
165+
ax_L.legend(bbox_to_anchor=(0.7, -0.16), ncol=1, frameon=False)
166+
167+
ax2 = plt.subplot(gs[2])
168+
plt.plot(xrange, true_csd, label='True CSD', color='red', linestyle = '--')
169+
plt.plot(xrange, est_csd, label='kCSD + regularization', color='blue')
170+
plt.plot(xrange, noreg_csd, label='kCSD', color='darkgreen', alpha = 0.6)
171+
plt.ylim(-1, 1)
172+
plt.scatter(ele_y, np.zeros(len(ele_y)), 8, color='black', label = "Electrode position")
173+
ax2.set_ylabel('CSD ($mA/mm$)', labelpad = ypad)
174+
ax2.set_xlabel('Distance', labelpad = xpad)
175+
ax2.tick_params(axis='both', which='major')
176+
ax2.spines['right'].set_visible(False)
177+
ax2.spines['top'].set_visible(False)
178+
ax2.legend(bbox_to_anchor=(-.25, -0.25), ncol=2, frameon=False, loc = 'center left')
179+
set_axis(ax2, -0.05, 1.05, letter='C')
180+
178181
ax4 = plt.subplot(gs[3])
179182
if name == 'lc':
180183
Lamb = [-7, -3]
181184
lambdas = np.logspace(Lamb[0], Lamb[1], 50, base=10)
182-
ax4.plot(lambdas, curveseq[0], marker=".")
185+
ax4.plot(lambdas, curveseq[0], marker=".", label = 'Curvature evaluation')
183186
ax4.plot([lambdas[imax]], [curveseq[0][imax]], marker="o", c="red")
184-
# im = plt.imshow(curveseq, extent=[Lamb[0], Lamb[1], ery[-1], ery[0]],
185-
# interpolation='none', aspect='auto',
186-
# cmap='BrBG_r', vmax=np.max(curveseq), vmin=-np.max(curveseq))
187-
# divider = make_axes_locatable(ax4)
188-
# cax4 = divider.append_axes("right", size="5%", pad=0.05)
189-
# fig.colorbar(im, cax = cax4)
190-
ax4.set_ylabel('Curvature')
191-
ax4.set_xlabel('Lambda')
187+
ax4.set_ylabel('Curvature', labelpad = ypad)
188+
ax4.set_xlabel('Lambda', labelpad = xpad)
192189
ax4.set_xscale('log')
193190
ax4.tick_params(axis='both', which='major')
194191
ax4.spines['right'].set_visible(False)
195192
ax4.spines['top'].set_visible(False)
193+
ax4.legend(bbox_to_anchor=(1, -0.16), ncol=2, frameon=False)
196194
else:
197195
imax = np.argmin(m_norm)
198-
plt.xlabel("Lambda")
199-
plt.ylabel("CV error")
200-
ax_L.plot(m_resi, m_norm, marker=".", c="green")
201-
#print(m_resi, m_norm)
202-
ax_L.plot([m_resi[imax]], [m_norm[imax]], marker="o", c="red")
196+
plt.ylabel("CV error", labelpad = ypad)
197+
plt.xlabel("Lambda", labelpad = xpad)
198+
ax4.plot(m_resi, m_norm, marker=".", c="green", label = 'CV curve')
199+
ax4.plot([m_resi[imax]], [m_norm[imax]], marker="o", c="red")
203200
x = [m_resi[0], m_resi[imax], m_resi[-1]]
204201
y = [m_norm[0], m_norm[imax], m_norm[-1]]
205-
ax_L.fill(x, y, alpha=0.2)
206-
ax_L.set_xscale('log')
207-
ax_L.set_yscale('log')
208-
ax_L.tick_params(axis='both', which='major')
209-
ax_L.spines['right'].set_visible(False)
210-
ax_L.spines['top'].set_visible(False)
202+
ax4.fill(x, y, alpha=0.2)
203+
ax4.set_xscale('log')
204+
ax4.set_yscale('log')
205+
ax4.tick_params(axis='both', which='major')
206+
ax4.spines['right'].set_visible(False)
207+
ax4.spines['top'].set_visible(False)
208+
ax4.legend(bbox_to_anchor=(1, -0.16), ncol=2, frameon=False)
211209
set_axis(ax4, -0.05, 1.05, letter='D')
212210
fig.savefig(save_as+'.jpg')
213211
true_csd_error = np.linalg.norm(true_csd/np.max(abs(true_csd)) - est_csd/np.max(abs(est_csd)))
@@ -328,17 +326,22 @@ def main_loop(src_width, total_ele, inpos, lpos, nm, noise=0, srcs=1):
328326

329327
if __name__=='__main__':
330328
saveDir = "./LCurve/"
331-
os.chdir(saveDir)
329+
try:
330+
os.chdir(saveDir)
331+
except FileNotFoundError:
332+
os.mkdir(saveDir)
333+
os.chdir(saveDir)
332334
figs1_and_fig2 = True
333335
total_ele = 32
334336
names = ['lc', 'cv']
335337
src_width = 0.001
336-
noise_lvl = np.linspace(0, 0.5, 10)
338+
noises = 3
337339
seeds = 3
338340
inpos = [0.5, 0.1]#od dolu
339341
lpos = [0.5, 0.9]
342+
noise_lvl = np.linspace(0, 0.5, noises)
340343
ery = np.linspace(3*0.025, 0.025*16, 1)
341-
sim_results = np.zeros((2, 4, seeds, 10))
344+
sim_results = np.zeros((2, 4, seeds, noises))
342345
sim_results[:, 3, :, :] = noise_lvl
343346
if figs1_and_fig2:
344347
for iname, name in enumerate(names):
@@ -352,11 +355,11 @@ def main_loop(src_width, total_ele, inpos, lpos, nm, noise=0, srcs=1):
352355
lpos, name, noise=noise_lvl, srcs=src)
353356
sim_results[iname,:2, src] = LandR
354357
sim_results[iname, 2, src] = RMS_wek
355-
np.save('sim_results', sim_results)
356358
os.chdir('..')
359+
np.save('sim_results', sim_results)
357360
sim_results = np.load('sim_results.npy')
358361
make_plot_fig2(sim_results)
359362
else:
360363
name = 'lc'
361364
main_loop(src_width, total_ele, inpos,
362-
lpos, name, noise=noise_lvl[:1], srcs=0)
365+
lpos, name, noise=noise_lvl[:1], srcs=0)

0 commit comments

Comments
 (0)