44import os
55from os .path import expanduser
66import numpy as np
7+ from numpy .linalg import LinAlgError
78import matplotlib .pyplot as plt
89from figure_properties import *
910import matplotlib .gridspec as gridspec
10- from mpl_toolkits .axes_grid1 .inset_locator import zoomed_inset_axes
11- from mpl_toolkits .axes_grid1 .inset_locator import mark_inset
12- from matplotlib .ticker import FuncFormatter
1311import datetime
1412import time
1513
16- from kcsd import SpectralStructure , KCSD1D
14+ from kcsd import KCSD1D
1715import targeted_basis as tb
1816
1917__abs_file__ = os .path .abspath (__file__ )
@@ -23,38 +21,24 @@ def _html(r, g, b):
2321 return "#{:02X}{:02X}{:02X}" .format (r , g , b )
2422
2523
26- def stability_M (csd_profile , n_src , ele_lims , true_csd_xlims ,
27- total_ele , ele_pos , pots ,
28- method = 'cross-validation' , Rs = None , lambdas = None ):
24+ def stability_M (n_src , total_ele , ele_pos , pots , R_init = 0.23 ):
2925 """
3026 Investigates stability of reconstruction for different number of basis
3127 sources
3228
3329 Parameters
3430 ----------
35- csd_profile: function
36- Function to produce csd profile.
3731 n_src: int
3832 Number of basis sources.
39- ele_lims: list
40- Boundaries for electrodes placement.
41- true_csd_xlims: list
42- Boundaries for ground truth space.
4333 total_ele: int
4434 Number of electrodes.
4535 ele_pos: numpy array
4636 Electrodes positions.
4737 pots: numpy array
4838 Values of potentials at ele_pos.
49- method: string
50- Determines the method of regularization.
51- Default: cross-validation.
52- Rs: numpy 1D array
53- Basis source parameter for crossvalidation.
54- Default: None.
55- lambdas: numpy 1D array
56- Regularization parameter for crossvalidation.
57- Default: None.
39+ R_init: float
40+ Initial value of R parameter - width of basis source
41+ Default: 0.23.
5842
5943 Returns
6044 -------
@@ -70,14 +54,19 @@ def stability_M(csd_profile, n_src, ele_lims, true_csd_xlims,
7054 for i , value in enumerate (n_src ):
7155 pots = pots .reshape ((len (ele_pos ), 1 ))
7256 obj = KCSD1D (ele_pos , pots , src_type = 'gauss' , sigma = 0.3 , h = 0.25 ,
73- gdx = 0.01 , n_src_init = n_src [i ], ext_x = 0 , xmin = 0 , xmax = 1 )
74- if method == 'cross-validation' :
75- obj .cross_validate (Rs = Rs , lambdas = lambdas )
76- elif method == 'L-curve' :
77- obj .L_curve (Rs = Rs , lambdas = lambdas )
78- ss = SpectralStructure (obj )
79- eigenvectors [i ], eigenvalues [i ] = ss .evd ()
80-
57+ gdx = 0.01 , n_src_init = n_src [i ], ext_x = 0 , xmin = 0 , xmax = 1 ,
58+ R_init = R_init )
59+ try :
60+ eigenvalue , eigenvector = np .linalg .eigh (obj .k_pot +
61+ obj .lambd *
62+ np .identity
63+ (obj .k_pot .shape [0 ]))
64+ except LinAlgError :
65+ raise LinAlgError ('EVD is failing - try moving the electrodes'
66+ 'slightly' )
67+ idx = eigenvalue .argsort ()[::- 1 ]
68+ eigenvalues [i ] = eigenvalue [idx ]
69+ eigenvectors [i ] = eigenvector [:, idx ]
8170 obj_all .append (obj )
8271 return obj_all , eigenvalues , eigenvectors
8372
@@ -112,8 +101,7 @@ def set_axis(ax, x, y, letter=None):
112101
113102
114103def generate_figure (csd_profile , R , MU , true_csd_xlims , total_ele , ele_lims ,
115- save_path , method = 'cross-validation' , Rs = None ,
116- lambdas = None , noise = 0 ):
104+ save_path , noise = 0 , R_init = 0.23 ):
117105 """
118106 Generates figure for spectral structure decomposition.
119107
@@ -135,18 +123,12 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
135123 Electrodes limits.
136124 save_path: string
137125 Directory.
138- method: string
139- Determines the method of regularization.
140- Default: cross-validation.
141- Rs: numpy 1D array
142- Basis source parameter for crossvalidation.
143- Default: None.
144- lambdas: numpy 1D array
145- Regularization parameter for crossvalidation.
146- Default: None.
147126 noise: float
148127 Determines the level of noise in the data.
149128 Default: 0.
129+ R_init: float
130+ Initial value of R parameter - width of basis source
131+ Default: 0.23.
150132
151133 Returns
152134 -------
@@ -159,19 +141,16 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
159141 noise = noise )
160142
161143 n_src_M = [2 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 ]
162- OBJ_M , eigenval_M , eigenvec_M = stability_M (csd_profile , n_src_M ,
163- ele_lims , true_csd_xlims ,
144+ OBJ_M , eigenval_M , eigenvec_M = stability_M (n_src_M ,
164145 total_ele , ele_pos , pots ,
165- method = method , Rs = Rs ,
166- lambdas = lambdas )
146+ R_init = R_init )
167147
168148 plt_cord = [(2 , 0 ), (2 , 2 ), (2 , 4 ),
169149 (3 , 0 ), (3 , 2 ), (3 , 4 ),
170150 (4 , 0 ), (4 , 2 ), (4 , 4 ),
171151 (5 , 0 ), (5 , 2 ), (5 , 4 )]
172152
173-
174- letters = ['C' , 'D' , 'E' , 'F' , 'G' , 'H' , 'I' , 'J' , 'K' , 'L' , 'M' , 'O' ]
153+ letters = ['C' , 'D' , 'E' , 'F' , 'G' , 'H' , 'I' , 'J' , 'K' , 'L' , 'M' , 'N' ]
175154
176155 BLACK = _html (0 , 0 , 0 )
177156 ORANGE = _html (230 , 159 , 0 )
@@ -199,10 +178,8 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
199178 linestyle = linestyles [indx ], color = colors [indx ],
200179 marker = markers [indx ], label = 'M=' + str (n_src_M [i ]),
201180 markersize = 10 )
202- # ax.set_title(' ', fontsize=12)
203181 ht , lh = ax .get_legend_handles_labels ()
204182 set_axis (ax , - 0.05 , 1.05 , letter = 'A' )
205- # ax.legend(loc='lower left')
206183 ax .set_xlabel ('Number of components' )
207184 ax .set_ylabel ('Eigenvalues' )
208185 ax .set_yscale ('log' )
@@ -213,7 +190,6 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
213190 ax = fig .add_subplot (gs [0 , 3 :])
214191 ax .plot (n_src_M , eigenval_M [:, 0 ], marker = 's' , color = 'k' , markersize = 5 ,
215192 linestyle = ' ' )
216- #ax.set_title(' ', fontsize=12)
217193 set_axis (ax , - 0.05 , 1.05 , letter = 'B' )
218194 ax .set_xlabel ('Number of basis sources' )
219195 ax .set_xscale ('log' )
@@ -229,13 +205,9 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
229205 eigenvec_M [j , :, i ]),
230206 linestyle = linestyles [idx ], color = colors [idx ],
231207 label = 'M=' + str (n_src_M [j ]), lw = 2 )
232- #ax.set_title(r"$\tilde{K}*v_{{%(i)d}}$" % {'i': i+1})
233- ax .text (0.5 , 1. , r"$\tilde{K}*v_{{%(i)d}}$" % {'i' : i + 1 },
234- horizontalalignment = 'center' , transform = ax .transAxes , fontsize = 20 )
235- # ax.locator_params(axis='y', nbins=3)
236-
237- # ax.set_xlabel('Depth (mm)', fontsize=12)
238- # ax.set_ylabel('CSD (mA/mm)', fontsize=12)
208+ ax .text (0.5 , 1. , r"$\tilde{K}\cdot{v_{{%(i)d}}}$" % {'i' : i + 1 },
209+ horizontalalignment = 'center' , transform = ax .transAxes ,
210+ fontsize = 20 )
239211 set_axis (ax , - 0.10 , 1.1 , letter = letters [i ])
240212 if i < 9 :
241213 ax .get_xaxis ().set_visible (False )
@@ -245,23 +217,12 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
245217 if i % 3 == 0 :
246218 ax .set_ylabel ('CSD ($mA/mm$)' )
247219 ax .yaxis .set_label_coords (- 0.18 , 0.5 )
248- # ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
249- # ax.tick_params(direction='out', pad=10)
250- # ax.yaxis.get_major_formatter(FormatStrFormatter('%.2f'))
251220 ax .ticklabel_format (style = 'sci' , axis = 'y' , scilimits = ((0.0 , 0.0 )))
252221 ax .spines ['right' ].set_visible (False )
253222 ax .spines ['top' ].set_visible (False )
254- # ht, lh = ax.get_legend_handles_labels()
255-
256- # ax = fig.add_subplot(gs[3, :])
257- # ax.legend(ht, lh, fancybox=False, shadow=False, ncol=len(src_idx),
258- # loc='upper center', frameon=False, bbox_to_anchor=(0.5, 0.0))
259- # ax.axis('off')
260-
261- # plt.tight_layout()
262223 fig .legend (ht , lh , loc = 'lower center' , ncol = 5 , frameon = False )
263- fig .savefig (os .path .join (save_path , 'vectors_' + method +
264- '_noise_' + str (noise ) + '.png' ), dpi = 300 )
224+ fig .savefig (os .path .join (save_path , 'vectors_' + '_noise_' +
225+ str (noise ) + 'R0_2' + '.png' ), dpi = 300 )
265226
266227 plt .show ()
267228
@@ -281,6 +242,6 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
281242 CSD_PROFILE = tb .csd_profile
282243 R = 0.2
283244 MU = 0.25
245+ R_init = 0.2
284246 generate_figure (CSD_PROFILE , R , MU , TRUE_CSD_XLIMS , TOTAL_ELE , ELE_LIMS ,
285- SAVE_PATH , method = 'cross-validation' ,
286- Rs = np .arange (0.1 , 0.5 , 0.05 ), noise = None )
247+ SAVE_PATH , noise = None , R_init = R_init )
0 commit comments