@@ -258,6 +258,9 @@ def update_R(self, R):
258258 R : float
259259
260260 """
261+ if self .R == R :
262+ return
263+
261264 self .R = R
262265 self .dist_max = max (np .max (self .src_ele_dists ),
263266 np .max (self .src_estm_dists )) + self .R
@@ -273,7 +276,7 @@ def update_lambda(self, lambd):
273276 """
274277 self .lambd = lambd
275278
276- def cross_validate (self , lambdas = None , Rs = None ):
279+ def cross_validate (self , lambdas = None , Rs = None , regression_test = False ):
277280 """Method defines the cross validation.
278281
279282 By default only cross_validates over lambda,
@@ -286,6 +289,7 @@ def cross_validate(self, lambdas=None, Rs=None):
286289 ----------
287290 lambdas : numpy array
288291 Rs : numpy array
292+ regression_test : bool
289293
290294 Returns
291295 -------
@@ -302,27 +306,54 @@ def cross_validate(self, lambdas=None, Rs=None):
302306 if Rs is None : # when None
303307 Rs = np .array ((self .R )).flatten () # Default over one R value
304308 self .errs = np .zeros ((Rs .size , lambdas .size ))
305- index_generator = []
306- for ii in range (self .n_ele ):
307- idx_test = [ii ]
308- idx_train = list (range (self .n_ele ))
309- idx_train .remove (ii ) # Leave one out
310- index_generator .append ((idx_train , idx_test ))
309+ if regression_test :
310+ index_generator = self ._get_index_generator ()
311+ self .errs_regression = self .errs .copy ()
312+
311313 for R_idx , R in enumerate (Rs ): # Iterate over R
312314 self .update_R (R )
313315 print ('Cross validating R (all lambda) :' , R )
314316 for lambd_idx , lambd in enumerate (lambdas ): # Iterate over lambdas
315- self .errs [R_idx , lambd_idx ] = self .compute_cverror (lambd ,
316- index_generator )
317+ self .errs [R_idx , lambd_idx ] = self ._compute_cverr (lambd )
318+ if regression_test :
319+ self .errs_regression [R_idx , lambd_idx ] = self .compute_cverror (lambd ,
320+ index_generator )
317321 err_idx = np .where (self .errs == np .min (self .errs )) # Index of the least error
318322 cv_R = Rs [err_idx [0 ]][0 ] # First occurance of the least error's
319323 cv_lambda = lambdas [err_idx [1 ]][0 ]
320324 self .cv_error = np .min (self .errs ) # otherwise is None
321325 self .update_R (cv_R ) # Update solver
322326 self .update_lambda (cv_lambda )
323327 print ('R, lambda :' , cv_R , cv_lambda )
328+ if regression_test :
329+ err_idx = np .where (self .errs_regression == np .min (self .errs_regression ))
330+ print ("R, lambda :" , Rs [err_idx [0 ]][0 ], lambdas [err_idx [1 ]][0 ], "OLD" )
331+ print ("max diff :" , abs (self .errs - self .errs_regression ).max ())
332+ print ("new min, old min :" , self .cv_error , np .min (self .errs_regression ))
324333 return cv_R , cv_lambda
325334
335+ def _get_index_generator (self ):
336+ index_generator = []
337+ for ii in range (self .n_ele ):
338+ idx_test = [ii ]
339+ idx_train = list (range (self .n_ele ))
340+ idx_train .remove (ii ) # Leave one out
341+ index_generator .append ((idx_train , idx_test ))
342+ return index_generator
343+
344+ def _compute_cverr (self , lambd ):
345+ n = self .k_pot .shape [0 ] # self.n_ele
346+ K = self .k_pot + np .identity (n ) * lambd
347+ IDX_N = np .arange (n )
348+ return sum (map (np .linalg .norm ,
349+ [self ._leave_one_out_estimate (K , i , IDX_N != i ) - ROW
350+ for i , ROW in enumerate (self .pots )]))
351+
352+ def _leave_one_out_estimate (self , KERNEL , i , IDX ):
353+ return np .matmul (self .k_pot [np .ix_ ([i ], IDX )],
354+ np .linalg .solve (KERNEL [np .ix_ (IDX , IDX )],
355+ self .pots [IDX ]))[0 ]
356+
326357 def compute_cverror (self , lambd , index_generator ):
327358 """Useful for Cross validation error calculations
328359
0 commit comments