|
| 1 | +import numpy as np |
| 2 | +from scipy.integrate import simps |
| 3 | +from numpy import exp, linspace |
| 4 | +import matplotlib.pyplot as plt |
| 5 | +from matplotlib import cm |
| 6 | +from matplotlib import gridspec |
| 7 | +from matplotlib.mlab import griddata |
| 8 | +from scipy.spatial import distance |
| 9 | +from kcsd import csd_profile as CSD |
| 10 | +from kcsd import KCSD3D |
| 11 | + |
| 12 | +def generate_csd_3D(csd_profile, csd_seed, |
| 13 | + start_x=0., end_x=1., |
| 14 | + start_y=0., end_y=1., |
| 15 | + start_z=0., end_z=1., |
| 16 | + res_x=50, res_y=50, |
| 17 | + res_z=50): |
| 18 | + """ |
| 19 | + Gives CSD profile at the requested spatial location, at 'res' resolution |
| 20 | + """ |
| 21 | + csd_at = np.mgrid[start_x:end_x:np.complex(0,res_x), |
| 22 | + start_y:end_y:np.complex(0,res_y), |
| 23 | + start_z:end_z:np.complex(0,res_z)] |
| 24 | + f = csd_profile(csd_at, seed=csd_seed) |
| 25 | + return csd_at, f |
| 26 | + |
| 27 | +def grid(x, y, z, resX=100, resY=100): |
| 28 | + """ |
| 29 | + Convert 3 column data to matplotlib grid |
| 30 | + """ |
| 31 | + x = x.flatten() |
| 32 | + y = y.flatten() |
| 33 | + z = z.flatten() |
| 34 | + xi = linspace(min(x), max(x), resX) |
| 35 | + yi = linspace(min(y), max(y), resY) |
| 36 | + zi = griddata(x, y, z, xi, yi, interp='linear') |
| 37 | + return xi, yi, zi |
| 38 | + |
| 39 | +def generate_electrodes(xlims=[0.1,0.9], ylims=[0.1,0.9], zlims=[0.1,0.9], res=5): |
| 40 | + """ |
| 41 | + Places electrodes in a square grid |
| 42 | + """ |
| 43 | + ele_x, ele_y, ele_z = np.mgrid[xlims[0]:xlims[1]:np.complex(0,res), |
| 44 | + ylims[0]:ylims[1]:np.complex(0,res), |
| 45 | + zlims[0]:zlims[1]:np.complex(0,res)] |
| 46 | + ele_x = ele_x.flatten() |
| 47 | + ele_y = ele_y.flatten() |
| 48 | + ele_z = ele_z.flatten() |
| 49 | + return ele_x, ele_y, ele_z |
| 50 | + |
| 51 | +def make_plots(fig_title, |
| 52 | + csd_at, true_csd, |
| 53 | + ele_x, ele_y, ele_z, pots, |
| 54 | + k_csd_x, k_csd_y, k_csd_z, est_csd): |
| 55 | + """ |
| 56 | + Shows 3 plots |
| 57 | + 1_ true CSD generated based on the random seed given |
| 58 | + 2_ interpolated LFT (NOT kCSD pot though), generated by simpsons rule integration |
| 59 | + 3_ results from the kCSD 2D for the default values |
| 60 | + """ |
| 61 | + t_csd_x, t_csd_y, t_csd_z = csd_at |
| 62 | + fig = plt.figure(figsize=(10,16)) |
| 63 | + #True CSD |
| 64 | + z_steps = 5 |
| 65 | + height_ratios = [1 for i in range(z_steps)] |
| 66 | + height_ratios.append(0.1) |
| 67 | + gs = gridspec.GridSpec(z_steps+1, 3, height_ratios=height_ratios) |
| 68 | + t_max = np.max(np.abs(true_csd)) |
| 69 | + levels = np.linspace(-1*t_max, t_max, 16) |
| 70 | + ind_interest = np.mgrid[0:t_csd_z.shape[2]:np.complex(0,z_steps+2)] |
| 71 | + ind_interest = np.array(ind_interest, dtype=np.int)[1:-1] |
| 72 | + for ii, idx in enumerate(ind_interest): |
| 73 | + ax = plt.subplot(gs[ii, 0]) |
| 74 | + im = plt.contourf(t_csd_x[:,:,idx], t_csd_y[:,:,idx], true_csd[:,:,idx], |
| 75 | + levels=levels, cmap=cm.bwr_r) |
| 76 | + ax.get_xaxis().set_visible(False) |
| 77 | + ax.get_yaxis().set_visible(False) |
| 78 | + title = str(t_csd_z[:,:,idx][0][0])[:4] |
| 79 | + ax.set_title(label=title, fontdict={'x':0.8, 'y':0.8}) |
| 80 | + ax.set_aspect('equal') |
| 81 | + cax = plt.subplot(gs[z_steps,0]) |
| 82 | + cbar = plt.colorbar(im, cax=cax, orientation='horizontal') |
| 83 | + cbar.set_ticks(levels[::2]) |
| 84 | + cbar.set_ticklabels(np.around(levels[::2], decimals=2)) |
| 85 | + #Potentials |
| 86 | + v_max = np.max(np.abs(pots)) |
| 87 | + levels_pot = np.linspace(-1*v_max, v_max, 16) |
| 88 | + ele_res = int(np.ceil(len(pots)**(3**-1))) |
| 89 | + ele_x = ele_x.reshape(ele_res, ele_res, ele_res) |
| 90 | + ele_y = ele_y.reshape(ele_res, ele_res, ele_res) |
| 91 | + ele_z = ele_z.reshape(ele_res, ele_res, ele_res) |
| 92 | + pots = pots.reshape(ele_res, ele_res, ele_res) |
| 93 | + for idx in range(min(5,ele_res)): |
| 94 | + X,Y,Z = grid(ele_x[:,:,idx], ele_y[:,:,idx], pots[:,:,idx]) |
| 95 | + ax = plt.subplot(gs[idx, 1]) |
| 96 | + im = plt.contourf(X, Y, Z, levels=levels_pot, cmap=cm.PRGn) |
| 97 | + ax.hold(True) |
| 98 | + plt.scatter(ele_x[:,:,idx], ele_y[:,:,idx], 5) |
| 99 | + ax.get_xaxis().set_visible(False) |
| 100 | + ax.get_yaxis().set_visible(False) |
| 101 | + title = str(ele_z[:,:,idx][0][0])[:4] |
| 102 | + ax.set_title(label=title, fontdict={'x':0.8, 'y':0.8}) |
| 103 | + ax.set_aspect('equal') |
| 104 | + ax.set_xlim([0.,1.]) |
| 105 | + ax.set_ylim([0.,1.]) |
| 106 | + cax = plt.subplot(gs[z_steps,1]) |
| 107 | + cbar2 = plt.colorbar(im, cax=cax, orientation='horizontal') |
| 108 | + cbar2.set_ticks(levels_pot[::2]) |
| 109 | + cbar2.set_ticklabels(np.around(levels_pot[::2], decimals=2)) |
| 110 | + # #KCSD |
| 111 | + t_max = np.max(np.abs(est_csd[:,:,:,0])) |
| 112 | + levels_kcsd = np.linspace(-1*t_max, t_max, 16) |
| 113 | + ind_interest = np.mgrid[0:k_csd_z.shape[2]:np.complex(0,z_steps+2)] |
| 114 | + ind_interest = np.array(ind_interest, dtype=np.int)[1:-1] |
| 115 | + for ii, idx in enumerate(ind_interest): |
| 116 | + ax = plt.subplot(gs[ii, 2]) |
| 117 | + im = plt.contourf(k_csd_x[:,:,idx], k_csd_y[:,:,idx], est_csd[:,:,idx,0], |
| 118 | + levels=levels_kcsd, cmap=cm.bwr_r) |
| 119 | + #im = plt.contourf(k_csd_x[:,:,idx], k_csd_y[:,:,idx], est_csd[:,:,idx,0], |
| 120 | + # levels=levels, cmap=cm.bwr_r) |
| 121 | + ax.get_xaxis().set_visible(False) |
| 122 | + ax.get_yaxis().set_visible(False) |
| 123 | + title = str(k_csd_z[:,:,idx][0][0])[:4] |
| 124 | + ax.set_title(label=title, fontdict={'x':0.8, 'y':0.8}) |
| 125 | + ax.set_aspect('equal') |
| 126 | + cax = plt.subplot(gs[z_steps,2]) |
| 127 | + cbar3 = plt.colorbar(im, cax=cax, orientation='horizontal') |
| 128 | + cbar3.set_ticks(levels_kcsd[::2]) |
| 129 | + #cbar3.set_ticks(levels[::2]) |
| 130 | + cbar3.set_ticklabels(np.around(levels_kcsd[::2], decimals=2)) |
| 131 | + #cbar3.set_ticklabels(np.around(levels[::2], decimals=2)) |
| 132 | + fig.suptitle("Lambda,R,CV_Error,RMS_Error,Time = "+fig_title) |
| 133 | + gs.tight_layout(fig, rect=[0, 0.03, 1, 0.95]) |
| 134 | + # #Showing |
| 135 | + #plt.tight_layout() |
| 136 | + plt.show() |
| 137 | + return |
| 138 | + |
| 139 | +def integrate_3D(x, y, z, xlim, ylim, zlim, csd, xlin, ylin, zlin, X, Y, Z): |
| 140 | + """ |
| 141 | + X,Y - parts of meshgrid - Mihav's implementation |
| 142 | + """ |
| 143 | + Nz = zlin.shape[0] |
| 144 | + Ny = ylin.shape[0] |
| 145 | + m = np.sqrt((x - X)**2 + (y - Y)**2 + (z - Z)**2) |
| 146 | + m[m < 0.0000001] = 0.0000001 |
| 147 | + z = csd / m |
| 148 | + Iy = np.zeros(Ny) |
| 149 | + for j in range(Ny): |
| 150 | + Iz = np.zeros(Nz) |
| 151 | + for i in range(Nz): |
| 152 | + Iz[i] = simps(z[:,j,i], zlin) |
| 153 | + Iy[j] = simps(Iz, ylin) |
| 154 | + F = simps(Iy, xlin) |
| 155 | + return F |
| 156 | + |
| 157 | +def calculate_potential_3D(true_csd, ele_xx, ele_yy, ele_zz, |
| 158 | + csd_x, csd_y, csd_z): |
| 159 | + """ |
| 160 | + For Mihav's implementation to compute the LFP generated |
| 161 | + """ |
| 162 | + xlin = csd_x[:,0,0] |
| 163 | + ylin = csd_y[0,:,0] |
| 164 | + zlin = csd_z[0,0,:] |
| 165 | + xlims = [xlin[0], xlin[-1]] |
| 166 | + ylims = [ylin[0], ylin[-1]] |
| 167 | + zlims = [zlin[0], zlin[-1]] |
| 168 | + sigma = 1.0 |
| 169 | + pots = np.zeros(len(ele_xx)) |
| 170 | + for ii in range(len(ele_xx)): |
| 171 | + pots[ii] = integrate_3D(ele_xx[ii], ele_yy[ii], ele_zz[ii], |
| 172 | + xlims, ylims, zlims, true_csd, |
| 173 | + xlin, ylin, zlin, |
| 174 | + csd_x, csd_y, csd_z) |
| 175 | + print('Electrode:', ii) |
| 176 | + pots /= 4*np.pi*sigma |
| 177 | + return pots |
| 178 | + |
| 179 | + |
| 180 | +def electrode_config(ele_lims, ele_res, true_csd, csd_at): |
| 181 | + """ |
| 182 | + What is the configuration of electrode positions, between what and what positions |
| 183 | + """ |
| 184 | + #Potentials |
| 185 | + csd_x, csd_y, csd_z = csd_at |
| 186 | + ele_x_lims = ele_y_lims = ele_z_lims = ele_lims |
| 187 | + ele_x, ele_y, ele_z = generate_electrodes(ele_x_lims, ele_y_lims, ele_z_lims, ele_res) |
| 188 | + pots = calculate_potential_3D(true_csd, |
| 189 | + ele_x, ele_y, ele_z, |
| 190 | + csd_x, csd_y, csd_z) |
| 191 | + ele_pos = np.vstack((ele_x, ele_y, ele_z)).T #Electrode configs |
| 192 | + num_ele = ele_pos.shape[0] |
| 193 | + print('Number of electrodes:', num_ele) |
| 194 | + return ele_pos, pots |
| 195 | + |
| 196 | +def do_kcsd(ele_pos, pots, **params): |
| 197 | + """ |
| 198 | + Function that calls the KCSD3D module |
| 199 | + """ |
| 200 | + num_ele = len(ele_pos) |
| 201 | + pots = pots.reshape(num_ele, 1) |
| 202 | + k = KCSD3D(ele_pos, pots, **params) |
| 203 | + #k.cross_validate(Rs=np.arange(0.2,0.4,0.02)) |
| 204 | + #k.cross_validate(Rs=np.arange(0.02,0.27,0.01)) |
| 205 | + k.cross_validate(Rs=np.array(0.31).reshape(1)) |
| 206 | + est_csd = k.values('CSD') |
| 207 | + return k, est_csd |
| 208 | + |
| 209 | +def main_loop(csd_profile, csd_seed, total_ele, num_init_srcs=1000): |
| 210 | + """ |
| 211 | + Loop that decides the random number seed for the CSD profile, |
| 212 | + electrode configurations and etc. |
| 213 | + """ |
| 214 | + csd_name = csd_profile.__name__ |
| 215 | + print('Using sources %s - Seed: %d ' % (csd_name, csd_seed)) |
| 216 | + |
| 217 | + #TrueCSD |
| 218 | + csd_at, true_csd = generate_csd_3D(csd_profile, csd_seed, |
| 219 | + start_x=0., end_x=1., |
| 220 | + start_y=0., end_y=1., |
| 221 | + start_z=0., end_z=1., |
| 222 | + res_x=100, res_y=100, |
| 223 | + res_z=100) |
| 224 | + |
| 225 | + #Electrodes |
| 226 | + ele_lims = [0.15, 0.85] #square grid, xy min,max limits |
| 227 | + ele_res = int(np.ceil(total_ele**(3**-1))) #resolution of electrode grid |
| 228 | + ele_pos, pots = electrode_config(ele_lims, ele_res, true_csd, csd_at) |
| 229 | + ele_x = ele_pos[:, 0] |
| 230 | + ele_y = ele_pos[:, 1] |
| 231 | + ele_z = ele_pos[:, 2] |
| 232 | + |
| 233 | + #kCSD estimation |
| 234 | + gdX = 0.05 |
| 235 | + gdY = 0.05 |
| 236 | + gdZ = 0.05 |
| 237 | + x_lims = [.0,1.] #CSD estimation place |
| 238 | + y_lims = [.0,1.] |
| 239 | + z_lims = [.0,1.] |
| 240 | + params = {'h':50., 'src_type': 'gauss', |
| 241 | + 'gdX': gdX, 'gdY': gdY, 'gdZ': gdZ, |
| 242 | + 'xmin': x_lims[0], 'xmax': x_lims[1], |
| 243 | + 'ymin': y_lims[0], 'ymax': y_lims[1], |
| 244 | + 'zmin': y_lims[0], 'zmax': y_lims[1], |
| 245 | + 'n_src_init': num_init_srcs} |
| 246 | + k, est_csd = do_kcsd(ele_pos, pots, h=50., |
| 247 | + gdx=gdX, gdy= gdY, gdz=gdZ, |
| 248 | + xmin=x_lims[0], xmax=x_lims[1], |
| 249 | + ymin=y_lims[0], ymax=y_lims[1], |
| 250 | + zmin=z_lims[0], zmax=z_lims[1], |
| 251 | + n_src_init=num_init_srcs, src_type='step') |
| 252 | + |
| 253 | + #RMS of estimation - gives estimate of how good the reconstruction was |
| 254 | + chr_at, test_csd = generate_csd_3D(csd_profile, csd_seed, |
| 255 | + start_x=x_lims[0], end_x=x_lims[1], |
| 256 | + start_y=y_lims[0], end_y=y_lims[1], |
| 257 | + start_z=z_lims[0], end_z=z_lims[1], |
| 258 | + res_x=int((x_lims[1]-x_lims[0])/gdX), |
| 259 | + res_y=int((y_lims[1]-y_lims[0])/gdY), |
| 260 | + res_z=int((z_lims[1]-z_lims[0])/gdZ)) |
| 261 | + rms = np.linalg.norm(abs(test_csd - est_csd[:,:,:,0])) |
| 262 | + rms /= np.linalg.norm(test_csd) |
| 263 | + |
| 264 | + #Plots |
| 265 | + title = str(k.lambd)+','+str(k.R)+', '+str(k.cv_error)+', '+str(rms) |
| 266 | + save_as = csd_name+'_'+str(csd_seed)+'of'+str(total_ele) |
| 267 | + #save_as = csd_name+'_'+str(num_init_srcs)+'_'+str(total_ele) |
| 268 | + |
| 269 | + make_plots(title, |
| 270 | + chr_at, test_csd, |
| 271 | + ele_x, ele_y, ele_z, pots, |
| 272 | + k.estm_x, k.estm_y, k.estm_z, est_csd) |
| 273 | + #save |
| 274 | + result_kcsd = [k.lambd, k.R, k.cv_error, rms] |
| 275 | + return est_csd, result_kcsd |
| 276 | + |
| 277 | +if __name__=='__main__': |
| 278 | + total_ele = 9 |
| 279 | + #Normal run |
| 280 | + csd_seed = 54 #0-49 are small sources, 50-99 are large sources |
| 281 | + csd_profile = CSD.gauss_3d_small |
| 282 | + main_loop(csd_profile, csd_seed, total_ele, 1000) |
0 commit comments