1111import unittest
1212import numpy as np
1313from kcsd import ValidateKCSD1D , ValidateKCSD2D , ValidateKCSD3D
14- from kcsd import csd_profile as CSD
14+ from kcsd import csd_profile as CSDp
1515from kcsd import KCSD1D , KCSD2D , MoIKCSD , KCSD3D , oKCSD2D , oKCSD3D
1616
1717
@@ -21,7 +21,7 @@ def setUp(self):
2121 utils = ValidateKCSD1D (csd_seed = 42 )
2222 self .ele_pos = utils .generate_electrodes (total_ele = 10 ,
2323 ele_lims = [0.1 , 0.9 ])
24- self .csd_profile = CSD .gauss_1d_mono
24+ self .csd_profile = CSDp .gauss_1d_mono
2525 self .csd_at , self .csd = utils .generate_csd (self .csd_profile ,
2626 csd_seed = 42 )
2727 pots = utils .calculate_potential (self .csd , self .csd_at , self .ele_pos ,
@@ -42,6 +42,25 @@ def test_kcsd1d_estimate(self, cv_params={}):
4242 rms /= np .linalg .norm (true_csd )
4343 self .assertLess (rms , 0.5 , msg = 'RMS between trueCSD and estimate > 0.5' )
4444
45+ def test_lcurve (self ):
46+ result = KCSD1D (self .ele_pos , self .pots ,
47+ ** self .test_params )
48+ result .L_curve ()
49+ vals = result .values ()
50+ pvals = result .values ('POT' )
51+ true_csd = self .csd_profile (result .estm_x , 42 )
52+ rms = np .linalg .norm (np .array (vals [:, 0 ]) - true_csd )
53+ rms /= np .linalg .norm (true_csd )
54+ self .assertLess (rms , 0.5 , msg = 'RMS between trueCSD and estimate > 0.5' )
55+
56+ # def test_method_generic_lim(self):
57+ # self.test_params.update({'src_type': 'gauss_lim'})
58+ # self.test_kcsd1d_estimate()
59+
60+ # def test_method_generic_step(self):
61+ # self.test_params.update({'src_type': 'step'})
62+ # self.test_kcsd1d_estimate()
63+
4564 def test_valid_inputs (self ):
4665 self .test_method = 'KCSD1D'
4766 self .test_params = {'src_type' : 22 }
@@ -51,15 +70,15 @@ def test_valid_inputs(self):
5170 self .assertRaises (TypeError , self .test_kcsd1d_estimate )
5271 cv_params = {'InvalidCVArg' : np .array ((0.1 , 0.25 , 0.5 ))}
5372 self .assertRaises (TypeError , self .test_kcsd1d_estimate , cv_params )
54-
73+
5574
5675class KCSD2D_TestCase (unittest .TestCase ):
5776 def setUp (self ):
5877 dim = 2
5978 utils = ValidateKCSD2D (csd_seed = 43 )
6079 self .ele_pos = utils .generate_electrodes (total_ele = 49 ,
6180 ele_lims = [0.1 , 0.9 ])
62- self .csd_profile = CSD .gauss_2d_large
81+ self .csd_profile = CSDp .gauss_2d_large
6382 self .csd_at , self .csd = utils .generate_csd (self .csd_profile ,
6483 csd_seed = 43 )
6584 pots = utils .calculate_potential (self .csd , self .csd_at , self .ele_pos ,
@@ -95,6 +114,26 @@ def test_moi_estimate(self):
95114 self .assertLess (rms , 2.5 , msg = 'RMS ' + str (rms ) +
96115 'between trueCSD and estimate > 2.5' )
97116
117+ def test_lcurve (self ):
118+ result = KCSD2D (self .ele_pos , self .pots ,
119+ ** self .test_params )
120+ result .L_curve ()
121+ vals = result .values ()
122+ pvals = result .values ('POT' )
123+ true_csd = self .csd_profile (result .estm_pos , 43 )
124+ rms = np .linalg .norm (np .array (vals [:, :, 0 ]) - true_csd )
125+ rms /= np .linalg .norm (true_csd )
126+ self .assertLess (rms , 0.5 , msg = 'RMS ' + str (rms ) +
127+ 'between trueCSD and estimate > 0.5' )
128+
129+ def test_method_generic_lim (self ):
130+ self .test_params .update ({'src_type' : 'gauss_lim' })
131+ self .test_kcsd2d_estimate ()
132+
133+ # def test_method_generic_step(self):
134+ # self.test_params.update({'src_type': 'step'})
135+ # self.test_kcsd2d_estimate()
136+
98137 def test_valid_inputs (self ):
99138 self .test_method = 'KCSD2D'
100139 self .test_params = {'src_type' : 22 }
@@ -111,7 +150,7 @@ def setUp(self):
111150 utils = ValidateKCSD3D (csd_seed = 44 )
112151 self .ele_pos = utils .generate_electrodes (total_ele = 64 ,
113152 ele_lims = [0.1 , 0.9 ])
114- self .csd_profile = CSD .gauss_3d_large
153+ self .csd_profile = CSDp .gauss_3d_large
115154 self .csd_at , self .csd = utils .generate_csd (self .csd_profile ,
116155 csd_seed = 44 )
117156 pots = utils .calculate_potential (self .csd , self .csd_at , self .ele_pos ,
@@ -130,13 +169,22 @@ def test_kcsd3d_estimate(self, cv_params={}):
130169 ** self .test_params )
131170 result .cross_validate ()
132171 vals = result .values ()
172+ pvals = result .values ('POT' )
133173 true_csd = self .csd_profile (result .estm_pos , 44 )
134- print (true_csd .shape , vals .shape ) # Meh here!
174+ # print(true_csd.shape, vals.shape) # Meh here!
135175 rms = np .linalg .norm (np .array (vals [:, :, :, 0 ]) - true_csd )
136176 rms /= np .linalg .norm (true_csd )
137177 self .assertLess (rms , 0.5 , msg = 'RMS ' + str (rms ) +
138178 'between trueCSD and estimate > 0.5' )
139179
180+ def test_method_generic_lim (self ):
181+ self .test_params .update ({'src_type' : 'gauss_lim' })
182+ self .test_kcsd3d_estimate ()
183+
184+ # def test_method_generic_step(self):
185+ # self.test_params.update({'src_type': 'step'})
186+ # self.test_kcsd3d_estimate()
187+
140188 def test_valid_inputs (self ):
141189 self .test_method = 'KCSD3D'
142190 self .test_params = {'src_type' : 22 }
0 commit comments