|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | import os |
6 | | -from os.path import expanduser |
7 | 6 | import numpy as np |
8 | 7 | from figure_properties import * |
9 | 8 | import matplotlib.pyplot as plt |
10 | 9 | import matplotlib.gridspec as gridspec |
11 | | -import datetime |
12 | | -import time |
13 | 10 |
|
14 | 11 | import targeted_basis as tb |
15 | 12 |
|
@@ -122,7 +119,7 @@ def make_subplot(ax, true_csd, est_csd, estm_x, title=None, ele_pos=None, |
122 | 119 | return ax |
123 | 120 |
|
124 | 121 |
|
125 | | -def generate_figure(R, MU, n_src, true_csd_xlims, total_ele, save_path, |
| 122 | +def generate_figure(R, MU, n_src, true_csd_xlims, total_ele, |
126 | 123 | method='cross-validation', Rs=None, lambdas=None, |
127 | 124 | noise=0): |
128 | 125 | """ |
@@ -284,232 +281,21 @@ def generate_figure(R, MU, n_src, true_csd_xlims, total_ele, save_path, |
284 | 281 | handles, labels = ax.get_legend_handles_labels() |
285 | 282 | fig.legend(handles, labels, loc='lower center', ncol=3, frameon=False) |
286 | 283 |
|
287 | | - fig.savefig(os.path.join(save_path, 'targeted_basis_' + method + |
288 | | - '_noise_' + str(noise) + '.png'), dpi=300) |
289 | | - plt.show() |
290 | | - |
291 | | - |
292 | | -def generate_figure_CVLC(R, MU, n_src, true_csd_xlims, total_ele, save_path, |
293 | | - Rs=None, lambdas=None, noise=0): |
294 | | - """ |
295 | | - Generates figure for targeted basis investigation including results from |
296 | | - both cross-validation and L-curve. |
297 | | -
|
298 | | - Parameters |
299 | | - ---------- |
300 | | - R: float |
301 | | - Thickness of the groundtruth source. |
302 | | - Default: 0.2. |
303 | | - MU: float |
304 | | - Central position of Gaussian source |
305 | | - Default: 0.25. |
306 | | - nr_src: int |
307 | | - Number of basis sources. |
308 | | - true_csd_xlims: list |
309 | | - Boundaries for ground truth space. |
310 | | - total_ele: int |
311 | | - Number of electrodes. |
312 | | - save_path: string |
313 | | - Directory. |
314 | | - Rs: numpy 1D array |
315 | | - Basis source parameter for crossvalidation. |
316 | | - Default: None. |
317 | | - lambdas: numpy 1D array |
318 | | - Regularization parameter for crossvalidation. |
319 | | - Default: None. |
320 | | - noise: float |
321 | | - Determines the level of noise in the data. |
322 | | - Default: 0. |
323 | | -
|
324 | | - Returns |
325 | | - ------- |
326 | | - None |
327 | | - """ |
328 | | - |
329 | | - m_cv = 'cross-validation' |
330 | | - m_lc = 'L-curve' |
331 | | - method = 'CV_LC' |
332 | | - ele_lims = [0, 1.] |
333 | | - csd_at, true_csd, ele_pos, pots, val = tb.simulate_data(tb.csd_profile, |
334 | | - true_csd_xlims, R, |
335 | | - MU, total_ele, |
336 | | - ele_lims, |
337 | | - noise=noise) |
338 | | - |
339 | | - fig = plt.figure(figsize=(15, 12)) |
340 | | - widths = [1, 1, 1] |
341 | | - heights = [1, 1, 1] |
342 | | - gs = gridspec.GridSpec(3, 3, height_ratios=heights, width_ratios=widths, |
343 | | - hspace=0.45, wspace=0.3) |
344 | | - |
345 | | - ax = fig.add_subplot(gs[0, 0]) |
346 | | - xmin = 0 |
347 | | - xmax = 1 |
348 | | - ext_x = 0 |
349 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
350 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
351 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
352 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
353 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
354 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
355 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
356 | | - ele_pos=ele_pos, title='Basis limits = [0, 1]', xlabel=False, |
357 | | - ylabel=True, letter='A', est_csd_LC=obj_LC.values('CSD')) |
358 | | - |
359 | | - ax = fig.add_subplot(gs[0, 1]) |
360 | | - xmin = -0.5 |
361 | | - xmax = 1 |
362 | | - ext_x = -0.5 |
363 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
364 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
365 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
366 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
367 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
368 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
369 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
370 | | - ele_pos=ele_pos, title='Basis limits = [0, 0.5]', |
371 | | - xlabel=False, ylabel=False, letter='B', |
372 | | - est_csd_LC=obj_LC.values('CSD')) |
373 | | - |
374 | | - ax = fig.add_subplot(gs[0, 2]) |
375 | | - xmin = 0 |
376 | | - xmax = 1.5 |
377 | | - ext_x = -0.5 |
378 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
379 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
380 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
381 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
382 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
383 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
384 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
385 | | - ele_pos=ele_pos, title='Basis limits = [0.5, 1]', |
386 | | - xlabel=False, ylabel=False, letter='C', |
387 | | - est_csd_LC=obj_LC.values('CSD')) |
388 | | - |
389 | | - ele_lims = [0, 0.5] |
390 | | -# total_ele = 6 |
391 | | - csd_at, true_csd, ele_pos, pots, val = tb.simulate_data(tb.csd_profile, |
392 | | - true_csd_xlims, R, |
393 | | - MU, total_ele, |
394 | | - ele_lims, |
395 | | - noise=noise) |
396 | | - ax = fig.add_subplot(gs[1, 0]) |
397 | | - xmin = 0 |
398 | | - xmax = 1 |
399 | | - ext_x = 0 |
400 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
401 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
402 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
403 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
404 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
405 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
406 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
407 | | - ele_pos=ele_pos, title=None, xlabel=False, ylabel=True, |
408 | | - letter='D', est_csd_LC=obj_LC.values('CSD')) |
409 | | - |
410 | | - ax = fig.add_subplot(gs[1, 1]) |
411 | | - xmin = -0.5 |
412 | | - xmax = 1 |
413 | | - ext_x = -0.5 |
414 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
415 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
416 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
417 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
418 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
419 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
420 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
421 | | - ele_pos=ele_pos, title=None, xlabel=False, ylabel=False, |
422 | | - letter='E', est_csd_LC=obj_LC.values('CSD')) |
423 | | - |
424 | | - ax = fig.add_subplot(gs[1, 2]) |
425 | | - xmin = 0 |
426 | | - xmax = 1.5 |
427 | | - ext_x = -0.5 |
428 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
429 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
430 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
431 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
432 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
433 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
434 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
435 | | - ele_pos=ele_pos, title=None, xlabel=False, ylabel=False, |
436 | | - letter='F', est_csd_LC=obj_LC.values('CSD')) |
437 | | - |
438 | | - ele_lims = [0.5, 1.] |
439 | | - csd_at, true_csd, ele_pos, pots, val = tb.simulate_data(tb.csd_profile, |
440 | | - true_csd_xlims, R, |
441 | | - MU, total_ele, |
442 | | - ele_lims, |
443 | | - noise=noise) |
444 | | - ax = fig.add_subplot(gs[2, 0]) |
445 | | - xmin = 0 |
446 | | - xmax = 1 |
447 | | - ext_x = 0 |
448 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
449 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
450 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
451 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
452 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
453 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
454 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
455 | | - ele_pos=ele_pos, title=None, xlabel=True, ylabel=True, |
456 | | - letter='G', est_csd_LC=obj_LC.values('CSD')) |
457 | | - |
458 | | - ax = fig.add_subplot(gs[2, 1]) |
459 | | - xmin = -0.5 |
460 | | - xmax = 1 |
461 | | - ext_x = -0.5 |
462 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
463 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
464 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
465 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
466 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
467 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
468 | | - make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
469 | | - ele_pos=ele_pos, title=None, xlabel=True, ylabel=False, |
470 | | - letter='H', est_csd_LC=obj_LC.values('CSD')) |
471 | | - |
472 | | - ax = fig.add_subplot(gs[2, 2]) |
473 | | - xmin = 0 |
474 | | - xmax = 1.5 |
475 | | - ext_x = -0.5 |
476 | | - obj_CV = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
477 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
478 | | - xmax=xmax, method=m_cv, Rs=Rs, lambdas=lambdas) |
479 | | - obj_LC = tb.modified_bases(val, pots, ele_pos, n_src, h=0.25, |
480 | | - sigma=0.3, gdx=0.01, ext_x=ext_x, xmin=xmin, |
481 | | - xmax=xmax, method=m_lc, Rs=Rs, lambdas=lambdas) |
482 | | - ax = make_subplot(ax, true_csd, obj_CV.values('CSD'), obj_CV.estm_x, |
483 | | - ele_pos=ele_pos, title=None, xlabel=True, ylabel=False, |
484 | | - letter='I', est_csd_LC=obj_LC.values('CSD')) |
485 | | - handles, labels = ax.get_legend_handles_labels() |
486 | | - fig.legend(handles, labels, loc='lower center', ncol=4, frameon=False) |
487 | | - |
488 | | - fig.savefig(os.path.join(save_path, 'targeted_basis_' + method + |
| 284 | + fig.savefig(os.path.join('targeted_basis_' + method + |
489 | 285 | '_noise_' + str(noise) + '.png'), dpi=300) |
490 | 286 | plt.show() |
491 | 287 |
|
492 | 288 |
|
493 | 289 | if __name__ == '__main__': |
494 | | - home = expanduser('~') |
495 | | - DAY = datetime.datetime.now() |
496 | | - DAY = DAY.strftime('%Y%m%d') |
497 | | - TIMESTR = time.strftime("%H%M%S") |
498 | | - SAVE_PATH = home + "/kCSD_results/" + DAY + '/' + TIMESTR |
499 | | - tb.makemydir(SAVE_PATH) |
500 | | - tb.save_source_code(SAVE_PATH, time.strftime("%Y%m%d-%H%M%S")) |
501 | | - |
502 | 290 | N_SRC = 64 |
503 | 291 | TRUE_CSD_XLIMS = [0., 1.] |
504 | 292 | TOTAL_ELE = 12 |
505 | 293 | R = 0.2 |
506 | 294 | MU = 0.25 |
507 | 295 | method = 'cross-validation' # L-curve |
508 | 296 | # method = 'L-curve' |
509 | | -# Rs = np.arange(0.1, 0.4, 0.05) |
510 | 297 | Rs = np.array([0.2]) |
511 | 298 | lambdas = np.zeros(1) |
512 | | - generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH, |
| 299 | + generate_figure(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, |
513 | 300 | method=method, Rs=Rs, lambdas=lambdas, noise=0) |
514 | | -# generate_figure_CVLC(R, MU, N_SRC, TRUE_CSD_XLIMS, TOTAL_ELE, SAVE_PATH, |
515 | | -# Rs=Rs, lambdas=None, noise=10) |
| 301 | + |
0 commit comments