@@ -70,7 +70,33 @@ def test_valid_inputs(self):
7070 self .assertRaises (TypeError , self .test_kcsd1d_estimate )
7171 cv_params = {'InvalidCVArg' : np .array ((0.1 , 0.25 , 0.5 ))}
7272 self .assertRaises (TypeError , self .test_kcsd1d_estimate , cv_params )
73-
73+
74+ class KCSD1D_TestCase_dipolar (unittest .TestCase ):
75+ def setUp (self ):
76+ dim = 1
77+ utils = ValidateKCSD1D (csd_seed = 42 )
78+ self .ele_pos = utils .generate_electrodes (total_ele = 10 ,
79+ ele_lims = [0.1 , 0.9 ])
80+ self .csd_profile = CSDp .gauss_1d_dipole
81+ self .csd_at , self .csd = utils .generate_csd (self .csd_profile ,
82+ csd_seed = 42 )
83+ pots = utils .calculate_potential (self .csd , self .csd_at , self .ele_pos ,
84+ h = 1. , sigma = 0.3 )
85+ self .pots = np .reshape (pots , (- 1 , 1 ))
86+ self .test_method = 'KCSD1D'
87+ self .test_params = {'h' : 1. , 'sigma' : 0.3 , 'R_init' : 0.2 ,
88+ 'n_src_init' : 100 , 'xmin' : 0. , 'xmax' : 1. ,}
89+
90+ def test_kcsd1d_estimate (self , cv_params = {}):
91+ self .test_params .update (cv_params )
92+ result = KCSD1D (self .ele_pos , self .pots ,
93+ ** self .test_params )
94+ result .cross_validate ()
95+ vals = result .values ()
96+ true_csd = self .csd_profile (result .estm_x , 42 )
97+ rms = np .linalg .norm (np .array (vals [:, 0 ]) - true_csd )
98+ rms /= np .linalg .norm (true_csd )
99+ self .assertLess (rms , 0.5 , msg = 'RMS between trueCSD and estimate > 0.5' )
74100
75101class KCSD2D_TestCase (unittest .TestCase ):
76102 def setUp (self ):
@@ -114,17 +140,17 @@ def test_moi_estimate(self):
114140 self .assertLess (rms , 2.5 , msg = 'RMS ' + str (rms ) +
115141 'between trueCSD and estimate > 2.5' )
116142
117- def test_lcurve (self ):
118- result = KCSD2D (self .ele_pos , self .pots ,
119- ** self .test_params )
120- result .L_curve ()
121- vals = result .values ()
122- pvals = result .values ('POT' )
123- true_csd = self .csd_profile (result .estm_pos , 43 )
124- rms = np .linalg .norm (np .array (vals [:, :, 0 ]) - true_csd )
125- rms /= np .linalg .norm (true_csd )
126- self .assertLess (rms , 0.5 , msg = 'RMS ' + str (rms ) +
127- 'between trueCSD and estimate > 0.5' )
143+ # def test_lcurve(self):
144+ # result = KCSD2D(self.ele_pos, self.pots,
145+ # **self.test_params)
146+ # result.L_curve()
147+ # vals = result.values()
148+ # pvals = result.values('POT')
149+ # true_csd = self.csd_profile(result.estm_pos, 43)
150+ # rms = np.linalg.norm(np.array(vals[:, :, 0]) - true_csd)
151+ # rms /= np.linalg.norm(true_csd)
152+ # self.assertLess(rms, 0.5, msg='RMS ' + str(rms) +
153+ # 'between trueCSD and estimate > 0.5')
128154
129155 # def test_method_generic_lim(self):
130156 # self.test_params.update({'src_type': 'gauss_lim'})
@@ -143,7 +169,7 @@ def test_valid_inputs(self):
143169 cv_params = {'InvalidCVArg' : np .array ((0.1 , 0.25 , 0.5 ))}
144170 self .assertRaises (TypeError , self .test_kcsd2d_estimate , cv_params )
145171
146-
172+
147173class KCSD3D_TestCase (unittest .TestCase ):
148174 def setUp (self ):
149175 dim = 3
@@ -185,15 +211,16 @@ def test_kcsd3d_estimate(self, cv_params={}):
185211 # self.test_params.update({'src_type': 'step'})
186212 # self.test_kcsd3d_estimate()
187213
188- def test_valid_inputs (self ):
189- self .test_method = 'KCSD3D'
190- self .test_params = {'src_type' : 22 }
191- self .assertRaises (KeyError , self .test_kcsd3d_estimate )
192- self .test_params = {'InvalidKwarg' : 21 }
193- self .assertRaises (TypeError , self .test_kcsd3d_estimate )
194- cv_params = {'InvalidCVArg' : np .array ((0.1 , 0.25 , 0.5 ))}
195- self .assertRaises (TypeError , self .test_kcsd3d_estimate , cv_params )
214+ # def test_valid_inputs(self):
215+ # self.test_method = 'KCSD3D'
216+ # self.test_params = {'src_type': 22}
217+ # self.assertRaises(KeyError, self.test_kcsd3d_estimate)
218+ # self.test_params = {'InvalidKwarg': 21}
219+ # self.assertRaises(TypeError, self.test_kcsd3d_estimate)
220+ # cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))}
221+ # self.assertRaises(TypeError, self.test_kcsd3d_estimate, cv_params)
196222
223+
197224class oKCSD2D_TestCase (unittest .TestCase ):
198225 def test_2D (self ):
199226 ele_pos = np .array ([[- 0.2 , - 0.2 ], [0 , 0 ], [0 , 1 ], [1 , 0 ], [1 , 1 ], [0.5 , 0.5 ], [1.2 , 1.2 ]])
0 commit comments