Skip to content

Commit 9a24eca

Browse files
m-kowalskaccluri
authored andcommitted
Scripts producing figures for kCSD paper (#104)
* corrected algorithm for distribution of basis sources in 2D * updated documentation: validation part * Added script that produces plot for unitary potential propagation * code refactoring * Renamed script that produces reliability map plot * Renamed script that produces example of kCSD reconstruction with reliability map in 2D * Code and documentation refactoring * Code refactoring, changed scripts names * Common scale on the plot * Common scale on the plot, error formulation as in the paper * Fixed typo * Removed redundant variables. * Fixed round problem - generating electrodes in 3D * Fixed potential unit * Electrodes markers on the X-axis
1 parent a92889e commit 9a24eca

14 files changed

Lines changed: 986 additions & 315 deletions

figures/kCSD_properties/figure_Tbasis.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,6 @@ def make_subplot(ax, true_csd, est_csd, estm_x, title=None, ele_pos=None,
8888
l3 = ax.plot(estm_x, est_csd_LC, '.', label='kCSD L-Curve', lw=2.)
8989
else:
9090
l2 = ax.plot(estm_x, est_csd, '--', label='kCSD', lw=2.)
91-
s1 = ax.scatter(ele_pos, np.zeros(len(ele_pos)), 17, 'k', label='Electrodes')
92-
# ax.legend(fontsize=10)
9391
ax.set_xlim([0, 1])
9492
if xlabel:
9593
ax.set_xlabel('Depth ($mm$)')
@@ -99,8 +97,23 @@ def make_subplot(ax, true_csd, est_csd, estm_x, title=None, ele_pos=None,
9997
ax.set_title(title)
10098
if np.max(est_csd) < 1.2:
10199
ax.set_ylim(-0.2, 1.2)
102-
elif np.max(est_csd) > 500:
103-
ax.set_yticks([-5000, 0, 5000], [-5000, 0, 5000])
100+
s1 = ax.scatter(ele_pos, np.zeros(len(ele_pos))-0.2, 17, 'k', label='Electrodes')
101+
s1.set_clip_on(False)
102+
elif np.max(est_csd) < 1.7:
103+
ax.set_ylim(-10000, 10000)
104+
s3 = ax.scatter(ele_pos, np.zeros(len(ele_pos))-10000, 17, 'k', label='Electrodes')
105+
s3.set_clip_on(False)
106+
ax.set_yticks([-7000, 0, 7000])
107+
if np.max(est_csd) > 500:
108+
ax.set_ylim(-7000, 7000)
109+
s3 = ax.scatter(ele_pos, np.zeros(len(ele_pos))-7000, 17, 'k', label='Electrodes')
110+
s3.set_clip_on(False)
111+
ax.set_yticks([-5000, 0, 5000])
112+
elif np.max(est_csd) > 50:
113+
ax.set_ylim(-100, 100)
114+
s2 = ax.scatter(ele_pos, np.zeros(len(ele_pos))-100, 17, 'k', label='Electrodes')
115+
s2.set_clip_on(False)
116+
ax.set_yticks([-70, 0, 70])
104117
ax.set_xticks([0, 0.5, 1])
105118
set_axis(ax, letter=letter)
106119
# ax.legend(frameon=False, loc='upper center', ncol=3)
@@ -493,10 +506,10 @@ def generate_figure_CVLC(R, MU, n_src, true_csd_xlims, total_ele, save_path,
493506
MU = 0.25
494507
method = 'cross-validation' # L-curve
495508
# method = 'L-curve'
496-
Rs = np.arange(0.1, 0.4, 0.05)
497-
# Rs = np.array([0.2])
509+
# Rs = np.arange(0.1, 0.4, 0.05)
510+
Rs = np.array([0.2])
498511
lambdas = np.zeros(1)
499-
# generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
500-
# method=method, Rs=Rs, lambdas=lambdas, noise=0)
501-
generate_figure_CVLC(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
502-
Rs=Rs, lambdas=None, noise=10)
512+
generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
513+
method=method, Rs=Rs, lambdas=lambdas, noise=0)
514+
# generate_figure_CVLC(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH,
515+
# Rs=Rs, lambdas=None, noise=10)

figures/kCSD_properties/figure_vectors_v2.py renamed to figures/kCSD_properties/figure_eigensources_M_1D.py

Lines changed: 33 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
import os
55
from os.path import expanduser
66
import numpy as np
7+
from numpy.linalg import LinAlgError
78
import matplotlib.pyplot as plt
89
from figure_properties import *
910
import matplotlib.gridspec as gridspec
10-
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
11-
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
12-
from matplotlib.ticker import FuncFormatter
1311
import datetime
1412
import time
1513

16-
from kcsd import SpectralStructure, KCSD1D
14+
from kcsd import KCSD1D
1715
import targeted_basis as tb
1816

1917
__abs_file__ = os.path.abspath(__file__)
@@ -23,38 +21,24 @@ def _html(r, g, b):
2321
return "#{:02X}{:02X}{:02X}".format(r, g, b)
2422

2523

26-
def stability_M(csd_profile, n_src, ele_lims, true_csd_xlims,
27-
total_ele, ele_pos, pots,
28-
method='cross-validation', Rs=None, lambdas=None):
24+
def stability_M(n_src, total_ele, ele_pos, pots, R_init=0.23):
2925
"""
3026
Investigates stability of reconstruction for different number of basis
3127
sources
3228
3329
Parameters
3430
----------
35-
csd_profile: function
36-
Function to produce csd profile.
3731
n_src: int
3832
Number of basis sources.
39-
ele_lims: list
40-
Boundaries for electrodes placement.
41-
true_csd_xlims: list
42-
Boundaries for ground truth space.
4333
total_ele: int
4434
Number of electrodes.
4535
ele_pos: numpy array
4636
Electrodes positions.
4737
pots: numpy array
4838
Values of potentials at ele_pos.
49-
method: string
50-
Determines the method of regularization.
51-
Default: cross-validation.
52-
Rs: numpy 1D array
53-
Basis source parameter for crossvalidation.
54-
Default: None.
55-
lambdas: numpy 1D array
56-
Regularization parameter for crossvalidation.
57-
Default: None.
39+
R_init: float
40+
Initial value of R parameter - width of basis source
41+
Default: 0.23.
5842
5943
Returns
6044
-------
@@ -70,14 +54,19 @@ def stability_M(csd_profile, n_src, ele_lims, true_csd_xlims,
7054
for i, value in enumerate(n_src):
7155
pots = pots.reshape((len(ele_pos), 1))
7256
obj = KCSD1D(ele_pos, pots, src_type='gauss', sigma=0.3, h=0.25,
73-
gdx=0.01, n_src_init=n_src[i], ext_x=0, xmin=0, xmax=1)
74-
if method == 'cross-validation':
75-
obj.cross_validate(Rs=Rs, lambdas=lambdas)
76-
elif method == 'L-curve':
77-
obj.L_curve(Rs=Rs, lambdas=lambdas)
78-
ss = SpectralStructure(obj)
79-
eigenvectors[i], eigenvalues[i] = ss.evd()
80-
57+
gdx=0.01, n_src_init=n_src[i], ext_x=0, xmin=0, xmax=1,
58+
R_init=R_init)
59+
try:
60+
eigenvalue, eigenvector = np.linalg.eigh(obj.k_pot +
61+
obj.lambd *
62+
np.identity
63+
(obj.k_pot.shape[0]))
64+
except LinAlgError:
65+
raise LinAlgError('EVD is failing - try moving the electrodes'
66+
'slightly')
67+
idx = eigenvalue.argsort()[::-1]
68+
eigenvalues[i] = eigenvalue[idx]
69+
eigenvectors[i] = eigenvector[:, idx]
8170
obj_all.append(obj)
8271
return obj_all, eigenvalues, eigenvectors
8372

@@ -112,8 +101,7 @@ def set_axis(ax, x, y, letter=None):
112101

113102

114103
def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
115-
save_path, method='cross-validation', Rs=None,
116-
lambdas=None, noise=0):
104+
save_path, noise=0, R_init=0.23):
117105
"""
118106
Generates figure for spectral structure decomposition.
119107
@@ -135,18 +123,12 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
135123
Electrodes limits.
136124
save_path: string
137125
Directory.
138-
method: string
139-
Determines the method of regularization.
140-
Default: cross-validation.
141-
Rs: numpy 1D array
142-
Basis source parameter for crossvalidation.
143-
Default: None.
144-
lambdas: numpy 1D array
145-
Regularization parameter for crossvalidation.
146-
Default: None.
147126
noise: float
148127
Determines the level of noise in the data.
149128
Default: 0.
129+
R_init: float
130+
Initial value of R parameter - width of basis source
131+
Default: 0.23.
150132
151133
Returns
152134
-------
@@ -159,19 +141,16 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
159141
noise=noise)
160142

161143
n_src_M = [2, 4, 8, 16, 32, 64, 128, 256, 512]
162-
OBJ_M, eigenval_M, eigenvec_M = stability_M(csd_profile, n_src_M,
163-
ele_lims, true_csd_xlims,
144+
OBJ_M, eigenval_M, eigenvec_M = stability_M(n_src_M,
164145
total_ele, ele_pos, pots,
165-
method=method, Rs=Rs,
166-
lambdas=lambdas)
146+
R_init=R_init)
167147

168148
plt_cord = [(2, 0), (2, 2), (2, 4),
169149
(3, 0), (3, 2), (3, 4),
170150
(4, 0), (4, 2), (4, 4),
171151
(5, 0), (5, 2), (5, 4)]
172152

173-
174-
letters = ['C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'O']
153+
letters = ['C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N']
175154

176155
BLACK = _html(0, 0, 0)
177156
ORANGE = _html(230, 159, 0)
@@ -199,10 +178,8 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
199178
linestyle=linestyles[indx], color=colors[indx],
200179
marker=markers[indx], label='M='+str(n_src_M[i]),
201180
markersize=10)
202-
# ax.set_title(' ', fontsize=12)
203181
ht, lh = ax.get_legend_handles_labels()
204182
set_axis(ax, -0.05, 1.05, letter='A')
205-
# ax.legend(loc='lower left')
206183
ax.set_xlabel('Number of components')
207184
ax.set_ylabel('Eigenvalues')
208185
ax.set_yscale('log')
@@ -213,7 +190,6 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
213190
ax = fig.add_subplot(gs[0, 3:])
214191
ax.plot(n_src_M, eigenval_M[:, 0], marker='s', color='k', markersize=5,
215192
linestyle=' ')
216-
#ax.set_title(' ', fontsize=12)
217193
set_axis(ax, -0.05, 1.05, letter='B')
218194
ax.set_xlabel('Number of basis sources')
219195
ax.set_xscale('log')
@@ -229,13 +205,9 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
229205
eigenvec_M[j, :, i]),
230206
linestyle=linestyles[idx], color=colors[idx],
231207
label='M='+str(n_src_M[j]), lw=2)
232-
#ax.set_title(r"$\tilde{K}*v_{{%(i)d}}$" % {'i': i+1})
233-
ax.text(0.5, 1., r"$\tilde{K}*v_{{%(i)d}}$" % {'i': i+1},
234-
horizontalalignment='center', transform=ax.transAxes, fontsize=20)
235-
# ax.locator_params(axis='y', nbins=3)
236-
237-
# ax.set_xlabel('Depth (mm)', fontsize=12)
238-
# ax.set_ylabel('CSD (mA/mm)', fontsize=12)
208+
ax.text(0.5, 1., r"$\tilde{K}\cdot{v_{{%(i)d}}}$" % {'i': i+1},
209+
horizontalalignment='center', transform=ax.transAxes,
210+
fontsize=20)
239211
set_axis(ax, -0.10, 1.1, letter=letters[i])
240212
if i < 9:
241213
ax.get_xaxis().set_visible(False)
@@ -245,23 +217,12 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
245217
if i % 3 == 0:
246218
ax.set_ylabel('CSD ($mA/mm$)')
247219
ax.yaxis.set_label_coords(-0.18, 0.5)
248-
# ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
249-
# ax.tick_params(direction='out', pad=10)
250-
# ax.yaxis.get_major_formatter(FormatStrFormatter('%.2f'))
251220
ax.ticklabel_format(style='sci', axis='y', scilimits=((0.0, 0.0)))
252221
ax.spines['right'].set_visible(False)
253222
ax.spines['top'].set_visible(False)
254-
# ht, lh = ax.get_legend_handles_labels()
255-
256-
# ax = fig.add_subplot(gs[3, :])
257-
# ax.legend(ht, lh, fancybox=False, shadow=False, ncol=len(src_idx),
258-
# loc='upper center', frameon=False, bbox_to_anchor=(0.5, 0.0))
259-
# ax.axis('off')
260-
261-
# plt.tight_layout()
262223
fig.legend(ht, lh, loc='lower center', ncol=5, frameon=False)
263-
fig.savefig(os.path.join(save_path, 'vectors_' + method +
264-
'_noise_' + str(noise) + '.png'), dpi=300)
224+
fig.savefig(os.path.join(save_path, 'vectors_' + '_noise_' +
225+
str(noise) + 'R0_2' + '.png'), dpi=300)
265226

266227
plt.show()
267228

@@ -281,6 +242,6 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
281242
CSD_PROFILE = tb.csd_profile
282243
R = 0.2
283244
MU = 0.25
245+
R_init = 0.2
284246
generate_figure(CSD_PROFILE, R, MU, TRUE_CSD_XLIMS, TOTAL_ELE, ELE_LIMS,
285-
SAVE_PATH, method='cross-validation',
286-
Rs=np.arange(0.1, 0.5, 0.05), noise=None)
247+
SAVE_PATH, noise=None, R_init=R_init)

0 commit comments

Comments
 (0)