Skip to content

Commit 2e50768

Browse files
committed
adding more tests to improve coverage
1 parent cee9bb6 commit 2e50768

3 files changed

Lines changed: 58 additions & 6 deletions

File tree

README.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ Also included here are installation instructions, authors and their
4242
contributions, citation policy, contacts etc.,
4343

4444

45+
Also included are `Tutorials!`_
46+
47+
.. _Tutorials!: /docs/source/TUTORIALS.rst
48+
4549
.. include:: ./docs/source/TUTORIALS.rst
4650

4751

-523 Bytes
Binary file not shown.

kcsd/tests/kcsd_tests.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import unittest
1212
import numpy as np
1313
from kcsd import ValidateKCSD1D, ValidateKCSD2D, ValidateKCSD3D
14-
from kcsd import csd_profile as CSD
14+
from kcsd import csd_profile as CSDp
1515
from kcsd import KCSD1D, KCSD2D, MoIKCSD, KCSD3D, oKCSD2D, oKCSD3D
1616

1717

@@ -21,7 +21,7 @@ def setUp(self):
2121
utils = ValidateKCSD1D(csd_seed=42)
2222
self.ele_pos = utils.generate_electrodes(total_ele=10,
2323
ele_lims=[0.1, 0.9])
24-
self.csd_profile = CSD.gauss_1d_mono
24+
self.csd_profile = CSDp.gauss_1d_mono
2525
self.csd_at, self.csd = utils.generate_csd(self.csd_profile,
2626
csd_seed=42)
2727
pots = utils.calculate_potential(self.csd, self.csd_at, self.ele_pos,
@@ -42,6 +42,25 @@ def test_kcsd1d_estimate(self, cv_params={}):
4242
rms /= np.linalg.norm(true_csd)
4343
self.assertLess(rms, 0.5, msg='RMS between trueCSD and estimate > 0.5')
4444

45+
def test_lcurve(self):
46+
result = KCSD1D(self.ele_pos, self.pots,
47+
**self.test_params)
48+
result.L_curve()
49+
vals = result.values()
50+
pvals = result.values('POT')
51+
true_csd = self.csd_profile(result.estm_x, 42)
52+
rms = np.linalg.norm(np.array(vals[:, 0]) - true_csd)
53+
rms /= np.linalg.norm(true_csd)
54+
self.assertLess(rms, 0.5, msg='RMS between trueCSD and estimate > 0.5')
55+
56+
# def test_method_generic_lim(self):
57+
# self.test_params.update({'src_type': 'gauss_lim'})
58+
# self.test_kcsd1d_estimate()
59+
60+
# def test_method_generic_step(self):
61+
# self.test_params.update({'src_type': 'step'})
62+
# self.test_kcsd1d_estimate()
63+
4564
def test_valid_inputs(self):
4665
self.test_method = 'KCSD1D'
4766
self.test_params = {'src_type': 22}
@@ -51,15 +70,15 @@ def test_valid_inputs(self):
5170
self.assertRaises(TypeError, self.test_kcsd1d_estimate)
5271
cv_params = {'InvalidCVArg': np.array((0.1, 0.25, 0.5))}
5372
self.assertRaises(TypeError, self.test_kcsd1d_estimate, cv_params)
54-
73+
5574

5675
class KCSD2D_TestCase(unittest.TestCase):
5776
def setUp(self):
5877
dim = 2
5978
utils = ValidateKCSD2D(csd_seed=43)
6079
self.ele_pos = utils.generate_electrodes(total_ele=49,
6180
ele_lims=[0.1, 0.9])
62-
self.csd_profile = CSD.gauss_2d_large
81+
self.csd_profile = CSDp.gauss_2d_large
6382
self.csd_at, self.csd = utils.generate_csd(self.csd_profile,
6483
csd_seed=43)
6584
pots = utils.calculate_potential(self.csd, self.csd_at, self.ele_pos,
@@ -95,6 +114,26 @@ def test_moi_estimate(self):
95114
self.assertLess(rms, 2.5, msg='RMS ' + str(rms) +
96115
'between trueCSD and estimate > 2.5')
97116

117+
def test_lcurve(self):
118+
result = KCSD2D(self.ele_pos, self.pots,
119+
**self.test_params)
120+
result.L_curve()
121+
vals = result.values()
122+
pvals = result.values('POT')
123+
true_csd = self.csd_profile(result.estm_pos, 43)
124+
rms = np.linalg.norm(np.array(vals[:, :, 0]) - true_csd)
125+
rms /= np.linalg.norm(true_csd)
126+
self.assertLess(rms, 0.5, msg='RMS ' + str(rms) +
127+
'between trueCSD and estimate > 0.5')
128+
129+
def test_method_generic_lim(self):
130+
self.test_params.update({'src_type': 'gauss_lim'})
131+
self.test_kcsd2d_estimate()
132+
133+
# def test_method_generic_step(self):
134+
# self.test_params.update({'src_type': 'step'})
135+
# self.test_kcsd2d_estimate()
136+
98137
def test_valid_inputs(self):
99138
self.test_method = 'KCSD2D'
100139
self.test_params = {'src_type': 22}
@@ -111,7 +150,7 @@ def setUp(self):
111150
utils = ValidateKCSD3D(csd_seed=44)
112151
self.ele_pos = utils.generate_electrodes(total_ele=64,
113152
ele_lims=[0.1, 0.9])
114-
self.csd_profile = CSD.gauss_3d_large
153+
self.csd_profile = CSDp.gauss_3d_large
115154
self.csd_at, self.csd = utils.generate_csd(self.csd_profile,
116155
csd_seed=44)
117156
pots = utils.calculate_potential(self.csd, self.csd_at, self.ele_pos,
@@ -130,13 +169,22 @@ def test_kcsd3d_estimate(self, cv_params={}):
130169
**self.test_params)
131170
result.cross_validate()
132171
vals = result.values()
172+
pvals = result.values('POT')
133173
true_csd = self.csd_profile(result.estm_pos, 44)
134-
print(true_csd.shape, vals.shape) # Meh here!
174+
# print(true_csd.shape, vals.shape) # Meh here!
135175
rms = np.linalg.norm(np.array(vals[:, :, :, 0]) - true_csd)
136176
rms /= np.linalg.norm(true_csd)
137177
self.assertLess(rms, 0.5, msg='RMS ' + str(rms) +
138178
'between trueCSD and estimate > 0.5')
139179

180+
def test_method_generic_lim(self):
181+
self.test_params.update({'src_type': 'gauss_lim'})
182+
self.test_kcsd3d_estimate()
183+
184+
# def test_method_generic_step(self):
185+
# self.test_params.update({'src_type': 'step'})
186+
# self.test_kcsd3d_estimate()
187+
140188
def test_valid_inputs(self):
141189
self.test_method = 'KCSD3D'
142190
self.test_params = {'src_type': 22}

0 commit comments

Comments
 (0)