Skip to content

Commit 65e46dc

Browse files
committed
first pass 3D reconstruction
1 parent dd22f89 commit 65e46dc

1 file changed

Lines changed: 282 additions & 0 deletions

File tree

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
import numpy as np
2+
from scipy.integrate import simps
3+
from numpy import exp, linspace
4+
import matplotlib.pyplot as plt
5+
from matplotlib import cm
6+
from matplotlib import gridspec
7+
from matplotlib.mlab import griddata
8+
from scipy.spatial import distance
9+
from kcsd import csd_profile as CSD
10+
from kcsd import KCSD3D
11+
12+
def generate_csd_3D(csd_profile, csd_seed,
13+
start_x=0., end_x=1.,
14+
start_y=0., end_y=1.,
15+
start_z=0., end_z=1.,
16+
res_x=50, res_y=50,
17+
res_z=50):
18+
"""
19+
Gives CSD profile at the requested spatial location, at 'res' resolution
20+
"""
21+
csd_at = np.mgrid[start_x:end_x:np.complex(0,res_x),
22+
start_y:end_y:np.complex(0,res_y),
23+
start_z:end_z:np.complex(0,res_z)]
24+
f = csd_profile(csd_at, seed=csd_seed)
25+
return csd_at, f
26+
27+
def grid(x, y, z, resX=100, resY=100):
28+
"""
29+
Convert 3 column data to matplotlib grid
30+
"""
31+
x = x.flatten()
32+
y = y.flatten()
33+
z = z.flatten()
34+
xi = linspace(min(x), max(x), resX)
35+
yi = linspace(min(y), max(y), resY)
36+
zi = griddata(x, y, z, xi, yi, interp='linear')
37+
return xi, yi, zi
38+
39+
def generate_electrodes(xlims=[0.1,0.9], ylims=[0.1,0.9], zlims=[0.1,0.9], res=5):
40+
"""
41+
Places electrodes in a square grid
42+
"""
43+
ele_x, ele_y, ele_z = np.mgrid[xlims[0]:xlims[1]:np.complex(0,res),
44+
ylims[0]:ylims[1]:np.complex(0,res),
45+
zlims[0]:zlims[1]:np.complex(0,res)]
46+
ele_x = ele_x.flatten()
47+
ele_y = ele_y.flatten()
48+
ele_z = ele_z.flatten()
49+
return ele_x, ele_y, ele_z
50+
51+
def make_plots(fig_title,
52+
csd_at, true_csd,
53+
ele_x, ele_y, ele_z, pots,
54+
k_csd_x, k_csd_y, k_csd_z, est_csd):
55+
"""
56+
Shows 3 plots
57+
1_ true CSD generated based on the random seed given
58+
2_ interpolated LFT (NOT kCSD pot though), generated by simpsons rule integration
59+
3_ results from the kCSD 2D for the default values
60+
"""
61+
t_csd_x, t_csd_y, t_csd_z = csd_at
62+
fig = plt.figure(figsize=(10,16))
63+
#True CSD
64+
z_steps = 5
65+
height_ratios = [1 for i in range(z_steps)]
66+
height_ratios.append(0.1)
67+
gs = gridspec.GridSpec(z_steps+1, 3, height_ratios=height_ratios)
68+
t_max = np.max(np.abs(true_csd))
69+
levels = np.linspace(-1*t_max, t_max, 16)
70+
ind_interest = np.mgrid[0:t_csd_z.shape[2]:np.complex(0,z_steps+2)]
71+
ind_interest = np.array(ind_interest, dtype=np.int)[1:-1]
72+
for ii, idx in enumerate(ind_interest):
73+
ax = plt.subplot(gs[ii, 0])
74+
im = plt.contourf(t_csd_x[:,:,idx], t_csd_y[:,:,idx], true_csd[:,:,idx],
75+
levels=levels, cmap=cm.bwr_r)
76+
ax.get_xaxis().set_visible(False)
77+
ax.get_yaxis().set_visible(False)
78+
title = str(t_csd_z[:,:,idx][0][0])[:4]
79+
ax.set_title(label=title, fontdict={'x':0.8, 'y':0.8})
80+
ax.set_aspect('equal')
81+
cax = plt.subplot(gs[z_steps,0])
82+
cbar = plt.colorbar(im, cax=cax, orientation='horizontal')
83+
cbar.set_ticks(levels[::2])
84+
cbar.set_ticklabels(np.around(levels[::2], decimals=2))
85+
#Potentials
86+
v_max = np.max(np.abs(pots))
87+
levels_pot = np.linspace(-1*v_max, v_max, 16)
88+
ele_res = int(np.ceil(len(pots)**(3**-1)))
89+
ele_x = ele_x.reshape(ele_res, ele_res, ele_res)
90+
ele_y = ele_y.reshape(ele_res, ele_res, ele_res)
91+
ele_z = ele_z.reshape(ele_res, ele_res, ele_res)
92+
pots = pots.reshape(ele_res, ele_res, ele_res)
93+
for idx in range(min(5,ele_res)):
94+
X,Y,Z = grid(ele_x[:,:,idx], ele_y[:,:,idx], pots[:,:,idx])
95+
ax = plt.subplot(gs[idx, 1])
96+
im = plt.contourf(X, Y, Z, levels=levels_pot, cmap=cm.PRGn)
97+
ax.hold(True)
98+
plt.scatter(ele_x[:,:,idx], ele_y[:,:,idx], 5)
99+
ax.get_xaxis().set_visible(False)
100+
ax.get_yaxis().set_visible(False)
101+
title = str(ele_z[:,:,idx][0][0])[:4]
102+
ax.set_title(label=title, fontdict={'x':0.8, 'y':0.8})
103+
ax.set_aspect('equal')
104+
ax.set_xlim([0.,1.])
105+
ax.set_ylim([0.,1.])
106+
cax = plt.subplot(gs[z_steps,1])
107+
cbar2 = plt.colorbar(im, cax=cax, orientation='horizontal')
108+
cbar2.set_ticks(levels_pot[::2])
109+
cbar2.set_ticklabels(np.around(levels_pot[::2], decimals=2))
110+
# #KCSD
111+
t_max = np.max(np.abs(est_csd[:,:,:,0]))
112+
levels_kcsd = np.linspace(-1*t_max, t_max, 16)
113+
ind_interest = np.mgrid[0:k_csd_z.shape[2]:np.complex(0,z_steps+2)]
114+
ind_interest = np.array(ind_interest, dtype=np.int)[1:-1]
115+
for ii, idx in enumerate(ind_interest):
116+
ax = plt.subplot(gs[ii, 2])
117+
im = plt.contourf(k_csd_x[:,:,idx], k_csd_y[:,:,idx], est_csd[:,:,idx,0],
118+
levels=levels_kcsd, cmap=cm.bwr_r)
119+
#im = plt.contourf(k_csd_x[:,:,idx], k_csd_y[:,:,idx], est_csd[:,:,idx,0],
120+
# levels=levels, cmap=cm.bwr_r)
121+
ax.get_xaxis().set_visible(False)
122+
ax.get_yaxis().set_visible(False)
123+
title = str(k_csd_z[:,:,idx][0][0])[:4]
124+
ax.set_title(label=title, fontdict={'x':0.8, 'y':0.8})
125+
ax.set_aspect('equal')
126+
cax = plt.subplot(gs[z_steps,2])
127+
cbar3 = plt.colorbar(im, cax=cax, orientation='horizontal')
128+
cbar3.set_ticks(levels_kcsd[::2])
129+
#cbar3.set_ticks(levels[::2])
130+
cbar3.set_ticklabels(np.around(levels_kcsd[::2], decimals=2))
131+
#cbar3.set_ticklabels(np.around(levels[::2], decimals=2))
132+
fig.suptitle("Lambda,R,CV_Error,RMS_Error,Time = "+fig_title)
133+
gs.tight_layout(fig, rect=[0, 0.03, 1, 0.95])
134+
# #Showing
135+
#plt.tight_layout()
136+
plt.show()
137+
return
138+
139+
def integrate_3D(x, y, z, xlim, ylim, zlim, csd, xlin, ylin, zlin, X, Y, Z):
140+
"""
141+
X,Y - parts of meshgrid - Mihav's implementation
142+
"""
143+
Nz = zlin.shape[0]
144+
Ny = ylin.shape[0]
145+
m = np.sqrt((x - X)**2 + (y - Y)**2 + (z - Z)**2)
146+
m[m < 0.0000001] = 0.0000001
147+
z = csd / m
148+
Iy = np.zeros(Ny)
149+
for j in range(Ny):
150+
Iz = np.zeros(Nz)
151+
for i in range(Nz):
152+
Iz[i] = simps(z[:,j,i], zlin)
153+
Iy[j] = simps(Iz, ylin)
154+
F = simps(Iy, xlin)
155+
return F
156+
157+
def calculate_potential_3D(true_csd, ele_xx, ele_yy, ele_zz,
158+
csd_x, csd_y, csd_z):
159+
"""
160+
For Mihav's implementation to compute the LFP generated
161+
"""
162+
xlin = csd_x[:,0,0]
163+
ylin = csd_y[0,:,0]
164+
zlin = csd_z[0,0,:]
165+
xlims = [xlin[0], xlin[-1]]
166+
ylims = [ylin[0], ylin[-1]]
167+
zlims = [zlin[0], zlin[-1]]
168+
sigma = 1.0
169+
pots = np.zeros(len(ele_xx))
170+
for ii in range(len(ele_xx)):
171+
pots[ii] = integrate_3D(ele_xx[ii], ele_yy[ii], ele_zz[ii],
172+
xlims, ylims, zlims, true_csd,
173+
xlin, ylin, zlin,
174+
csd_x, csd_y, csd_z)
175+
print('Electrode:', ii)
176+
pots /= 4*np.pi*sigma
177+
return pots
178+
179+
180+
def electrode_config(ele_lims, ele_res, true_csd, csd_at):
181+
"""
182+
What is the configuration of electrode positions, between what and what positions
183+
"""
184+
#Potentials
185+
csd_x, csd_y, csd_z = csd_at
186+
ele_x_lims = ele_y_lims = ele_z_lims = ele_lims
187+
ele_x, ele_y, ele_z = generate_electrodes(ele_x_lims, ele_y_lims, ele_z_lims, ele_res)
188+
pots = calculate_potential_3D(true_csd,
189+
ele_x, ele_y, ele_z,
190+
csd_x, csd_y, csd_z)
191+
ele_pos = np.vstack((ele_x, ele_y, ele_z)).T #Electrode configs
192+
num_ele = ele_pos.shape[0]
193+
print('Number of electrodes:', num_ele)
194+
return ele_pos, pots
195+
196+
def do_kcsd(ele_pos, pots, **params):
197+
"""
198+
Function that calls the KCSD3D module
199+
"""
200+
num_ele = len(ele_pos)
201+
pots = pots.reshape(num_ele, 1)
202+
k = KCSD3D(ele_pos, pots, **params)
203+
#k.cross_validate(Rs=np.arange(0.2,0.4,0.02))
204+
#k.cross_validate(Rs=np.arange(0.02,0.27,0.01))
205+
k.cross_validate(Rs=np.array(0.31).reshape(1))
206+
est_csd = k.values('CSD')
207+
return k, est_csd
208+
209+
def main_loop(csd_profile, csd_seed, total_ele, num_init_srcs=1000):
210+
"""
211+
Loop that decides the random number seed for the CSD profile,
212+
electrode configurations and etc.
213+
"""
214+
csd_name = csd_profile.__name__
215+
print('Using sources %s - Seed: %d ' % (csd_name, csd_seed))
216+
217+
#TrueCSD
218+
csd_at, true_csd = generate_csd_3D(csd_profile, csd_seed,
219+
start_x=0., end_x=1.,
220+
start_y=0., end_y=1.,
221+
start_z=0., end_z=1.,
222+
res_x=100, res_y=100,
223+
res_z=100)
224+
225+
#Electrodes
226+
ele_lims = [0.15, 0.85] #square grid, xy min,max limits
227+
ele_res = int(np.ceil(total_ele**(3**-1))) #resolution of electrode grid
228+
ele_pos, pots = electrode_config(ele_lims, ele_res, true_csd, csd_at)
229+
ele_x = ele_pos[:, 0]
230+
ele_y = ele_pos[:, 1]
231+
ele_z = ele_pos[:, 2]
232+
233+
#kCSD estimation
234+
gdX = 0.05
235+
gdY = 0.05
236+
gdZ = 0.05
237+
x_lims = [.0,1.] #CSD estimation place
238+
y_lims = [.0,1.]
239+
z_lims = [.0,1.]
240+
params = {'h':50., 'src_type': 'gauss',
241+
'gdX': gdX, 'gdY': gdY, 'gdZ': gdZ,
242+
'xmin': x_lims[0], 'xmax': x_lims[1],
243+
'ymin': y_lims[0], 'ymax': y_lims[1],
244+
'zmin': y_lims[0], 'zmax': y_lims[1],
245+
'n_src_init': num_init_srcs}
246+
k, est_csd = do_kcsd(ele_pos, pots, h=50.,
247+
gdx=gdX, gdy= gdY, gdz=gdZ,
248+
xmin=x_lims[0], xmax=x_lims[1],
249+
ymin=y_lims[0], ymax=y_lims[1],
250+
zmin=z_lims[0], zmax=z_lims[1],
251+
n_src_init=num_init_srcs, src_type='step')
252+
253+
#RMS of estimation - gives estimate of how good the reconstruction was
254+
chr_at, test_csd = generate_csd_3D(csd_profile, csd_seed,
255+
start_x=x_lims[0], end_x=x_lims[1],
256+
start_y=y_lims[0], end_y=y_lims[1],
257+
start_z=z_lims[0], end_z=z_lims[1],
258+
res_x=int((x_lims[1]-x_lims[0])/gdX),
259+
res_y=int((y_lims[1]-y_lims[0])/gdY),
260+
res_z=int((z_lims[1]-z_lims[0])/gdZ))
261+
rms = np.linalg.norm(abs(test_csd - est_csd[:,:,:,0]))
262+
rms /= np.linalg.norm(test_csd)
263+
264+
#Plots
265+
title = str(k.lambd)+','+str(k.R)+', '+str(k.cv_error)+', '+str(rms)
266+
save_as = csd_name+'_'+str(csd_seed)+'of'+str(total_ele)
267+
#save_as = csd_name+'_'+str(num_init_srcs)+'_'+str(total_ele)
268+
269+
make_plots(title,
270+
chr_at, test_csd,
271+
ele_x, ele_y, ele_z, pots,
272+
k.estm_x, k.estm_y, k.estm_z, est_csd)
273+
#save
274+
result_kcsd = [k.lambd, k.R, k.cv_error, rms]
275+
return est_csd, result_kcsd
276+
277+
if __name__=='__main__':
278+
total_ele = 9
279+
#Normal run
280+
csd_seed = 54 #0-49 are small sources, 50-99 are large sources
281+
csd_profile = CSD.gauss_3d_small
282+
main_loop(csd_profile, csd_seed, total_ele, 1000)

0 commit comments

Comments
 (0)