@@ -23,51 +23,58 @@ def set_axis(ax, x, y, letter=None):
2323def plot_surface (curve_surf , errsy , save_as ):
2424 fsize = 18
2525 lambdas = np .logspace (- 7 , - 3 , 50 )
26- fig = plt .figure (figsize = (20 , 9 ), dpi = 300 )
27- gs = gridspec .GridSpec (16 , 12 , hspace = 2 , wspace = 2 )
26+ fig = plt .figure (figsize = (15 , 6 ), dpi = 300 )
27+ gs = gridspec .GridSpec (16 , 12 , hspace = 1 , wspace = 1 )
2828 ax = plt .subplot (gs [0 :16 , 0 :6 ])
2929 set_axis (ax , - 0.05 , 1.05 , letter = 'A' )
30- plt .pcolormesh (lambdas , np .arange (9 ), curve_surf ,
31- cmap = 'BrBG' , vmin = - 2 , vmax = 2 )
30+
31+ plt .pcolormesh (lambdas , np .arange (8 ), curve_surf ,
32+ cmap = 'BrBG' , vmin = - 3 , vmax = 3 )
3233 plt .colorbar ()
3334 for i ,m in enumerate (curve_surf .argmax (axis = 1 )):
34- plt .scatter ([lambdas [m ]], [i + 0.5 ], s = 50 , color = 'red' , alpha = 0.7 )
35+ plt .scatter ([lambdas [m ]], [i ], s = 50 , color = 'red' , alpha = 0.7 )
3536 if i == 7 :
36- plt .scatter ([lambdas [m ]], [i + 0.5 ], s = 50 , color = 'red' ,
37- label = 'Maximum Curvature ' , alpha = 0.7 )
37+ plt .scatter ([lambdas [m ]], [i ], s = 50 , color = 'red' ,
38+ label = 'Maximum \n Curvature ' , alpha = 0.7 )
3839 plt .xlim (lambdas [1 ],lambdas [- 1 ])
3940 plt .title ('L-curve regularization' , fontsize = fsize )
40- plt .legend (loc = 'center' , bbox_to_anchor = (0.5 , - 0.12 ), ncol = 1 ,
41+ # plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
42+ # frameon = False, fontsize = fsize)
43+ plt .legend (loc = 'upper left' , ncol = 1 ,
4144 frameon = False , fontsize = fsize )
42- plt .yticks (np .arange (8 )+ 0.5 , [str (x )+ 'x' for x in range (1 ,9 )])
45+ plt .yticks (np .arange (8 ), [str (x )+ 'x' for x in range (1 ,9 )])
4346 plt .xscale ('log' )
4447 plt .ylabel ('Parameter $R$ in electrode distance' , fontsize = fsize , labelpad = 15 )
4548 plt .xlabel ('$\lambda$' ,fontsize = fsize )
4649 ax = plt .subplot (gs [0 :16 , 6 :12 ])
4750 set_axis (ax , - 0.05 , 1.05 , letter = 'B' )
48- plt .pcolormesh (lambdas , np .arange (9 ), errsy , cmap = 'Greys' )
51+ plt .pcolormesh (lambdas , np .arange (8 ), errsy , cmap = 'Greys' , vmin = 0.01 , vmax = 0.02 )
4952 plt .colorbar ()
5053 for i ,m in enumerate (errsy .argmin (axis = 1 )):
51- plt .scatter ([lambdas [m ]], [i + 0.5 ], s = 50 , color = 'red' , alpha = 0.7 )
54+ plt .scatter ([lambdas [m ]], [i ], s = 50 , color = 'red' , alpha = 0.7 )
5255 if i == 7 :
53- plt .scatter ([lambdas [m ]], [i + 0.5 ], s = 50 , color = 'red' ,
54- label = 'Minimum Error ' , alpha = 0.7 )
56+ plt .scatter ([lambdas [m ]], [i ], s = 50 , color = 'red' ,
57+ label = 'Minimum \n Error ' , alpha = 0.7 )
5558 plt .xlim (lambdas [1 ],lambdas [- 1 ])
56- plt .legend (loc = 'center' , bbox_to_anchor = (0.5 , - 0.12 ), ncol = 1 ,
59+ # plt.legend(loc='center', bbox_to_anchor=(0.5, -0.12), ncol=1,
60+ # frameon = False, fontsize = fsize)
61+ plt .legend (loc = 'upper left' , ncol = 1 ,
5762 frameon = False , fontsize = fsize )
5863 plt .title ('Cross-validation regularization' , fontsize = fsize )
59- plt .yticks (np .arange (8 )+ 0.5 , [str (x )+ 'x' for x in range (1 ,9 )])
64+ plt .yticks (np .arange (8 ), [str (x )+ 'x' for x in range (1 ,9 )])
6065 plt .xscale ('log' )
6166 plt .xlabel ('$\lambda$' , fontsize = fsize )
6267 fig .savefig (save_as + '.png' )
6368
6469if __name__ == '__main__' :
65- # os.chdir("./LCurve/")
70+ os .chdir ("./LCurve/" )
6671 noises = 3
6772 noise_lvl = np .linspace (0 , 0.5 , noises )
68- # df = np.load('LC2/data_fig4_and_fig13_lc_noise25.0.npz')
73+ print (os .getcwd ())
74+ df = np .load (os .path .join ('LC2' , 'data_fig4_and_fig13_LC_noise25.0.npz' ))
75+
6976 Rs = np .linspace (0.025 , 8 * 0.025 , 8 )
7077 title = ['nazwa_pliku' ]
7178 save_as = 'noise'
72- # plot_surface(df['curve_surf'], df['errsy'], save_as+'surf')
73- plt .close ('all' )
79+ plot_surface (df ['curve_surf' ], df ['errsy' ], save_as + 'surf' )
80+ plt .close ('all' )
0 commit comments