Skip to content

Commit ee322b9

Browse files
author
Joanna Jędrzejewska-Szmek
authored
Add lost function check_estimated_shape (#101)
* Add lost function check_estimated_shape This function transforms 1D output of shape (n,) to a 2D array of shape (n, 1). * Add tests for check_estimated_shape
1 parent 8cf977a commit ee322b9

2 files changed

Lines changed: 29 additions & 0 deletions

File tree

kcsd/sKCSD_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,15 @@ def load_elpos(path):
168168
raise Exception('Unknown electrode position file format.')
169169
return ele_pos
170170

171+
172+
def check_estimated_shape(to_estimate):
173+
if len(to_estimate.shape) == 1:
174+
estimated = np.ndarray((to_estimate.shape[0], 1))
175+
estimated[:, 0] = to_estimate
176+
return estimated
177+
return to_estimate
178+
179+
171180
def _bresenhamline_nslope(slope):
172181
"""
173182
Normalize slope for Bresenham's line algorithm.

kcsd/tests/test_sKCSDutils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import print_function, division, absolute_import
2+
from kcsd.sKCSD_utils import check_estimated_shape
3+
import os
4+
import unittest
5+
import numpy as np
6+
7+
class testCheckEstimatedShape(unittest.TestCase):
8+
def test_unchanged(self):
9+
array = np.ones((1, 5))
10+
out = check_estimated_shape(array)
11+
self.assertEqual(array.shape, out.shape)
12+
13+
def test_changed(self):
14+
array = np.ones((5, ))
15+
out = check_estimated_shape(array)
16+
self.assertEqual(out.shape, (5, 1))
17+
18+
19+
if __name__ == '__main__':
20+
unittest.main()

0 commit comments

Comments
 (0)