Skip to content

Commit 1393ac9

Browse files
authored
Optimized CV method (kESI's code) (#144)
* Optimized CV method (kESI's code) * Make `.update_R()` method lazy
1 parent 2e5f8f4 commit 1393ac9

1 file changed

Lines changed: 40 additions & 9 deletions

File tree

kcsd/KCSD.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,9 @@ def update_R(self, R):
258258
R : float
259259
260260
"""
261+
if self.R == R:
262+
return
263+
261264
self.R = R
262265
self.dist_max = max(np.max(self.src_ele_dists),
263266
np.max(self.src_estm_dists)) + self.R
@@ -273,7 +276,7 @@ def update_lambda(self, lambd):
273276
"""
274277
self.lambd = lambd
275278

276-
def cross_validate(self, lambdas=None, Rs=None):
279+
def cross_validate(self, lambdas=None, Rs=None, regression_test=False):
277280
"""Method defines the cross validation.
278281
279282
By default only cross_validates over lambda,
@@ -286,6 +289,7 @@ def cross_validate(self, lambdas=None, Rs=None):
286289
----------
287290
lambdas : numpy array
288291
Rs : numpy array
292+
regression_test : bool
289293
290294
Returns
291295
-------
@@ -302,27 +306,54 @@ def cross_validate(self, lambdas=None, Rs=None):
302306
if Rs is None: # when None
303307
Rs = np.array((self.R)).flatten() # Default over one R value
304308
self.errs = np.zeros((Rs.size, lambdas.size))
305-
index_generator = []
306-
for ii in range(self.n_ele):
307-
idx_test = [ii]
308-
idx_train = list(range(self.n_ele))
309-
idx_train.remove(ii) # Leave one out
310-
index_generator.append((idx_train, idx_test))
309+
if regression_test:
310+
index_generator = self._get_index_generator()
311+
self.errs_regression = self.errs.copy()
312+
311313
for R_idx, R in enumerate(Rs): # Iterate over R
312314
self.update_R(R)
313315
print('Cross validating R (all lambda) :', R)
314316
for lambd_idx, lambd in enumerate(lambdas): # Iterate over lambdas
315-
self.errs[R_idx, lambd_idx] = self.compute_cverror(lambd,
316-
index_generator)
317+
self.errs[R_idx, lambd_idx] = self._compute_cverr(lambd)
318+
if regression_test:
319+
self.errs_regression[R_idx, lambd_idx] = self.compute_cverror(lambd,
320+
index_generator)
317321
err_idx = np.where(self.errs == np.min(self.errs)) # Index of the least error
318322
cv_R = Rs[err_idx[0]][0] # First occurrence of the least error's
319323
cv_lambda = lambdas[err_idx[1]][0]
320324
self.cv_error = np.min(self.errs) # otherwise is None
321325
self.update_R(cv_R) # Update solver
322326
self.update_lambda(cv_lambda)
323327
print('R, lambda :', cv_R, cv_lambda)
328+
if regression_test:
329+
err_idx = np.where(self.errs_regression == np.min(self.errs_regression))
330+
print("R, lambda :", Rs[err_idx[0]][0], lambdas[err_idx[1]][0], "OLD")
331+
print("max diff :", abs(self.errs - self.errs_regression).max())
332+
print("new min, old min :", self.cv_error, np.min(self.errs_regression))
324333
return cv_R, cv_lambda
325334

335+
def _get_index_generator(self):
    """Build the leave-one-out (train, test) index pairs.

    Returns
    -------
    list of tuple
        One ``(idx_train, idx_test)`` pair per index ``ii`` in
        ``range(self.n_ele)``: ``idx_test`` is ``[ii]`` and ``idx_train``
        is every other index, in ascending order.
    """
    all_indices = range(self.n_ele)
    return [([jj for jj in all_indices if jj != ii], [ii])  # Leave one out
            for ii in all_indices]
343+
344+
def _compute_cverr(self, lambd):
    """Sum the leave-one-out prediction errors for one ``lambd``.

    Each row of ``self.pots`` is predicted from all the other rows via
    ``self._leave_one_out_estimate``; the norm of (prediction - row) is
    accumulated over all rows.

    Parameters
    ----------
    lambd : float
        Value added to the diagonal of ``self.k_pot`` before solving.

    Returns
    -------
    The accumulated error (``0`` when ``self.pots`` has no rows).
    """
    n_rows = self.k_pot.shape[0]  # self.n_ele
    regularized = self.k_pot + np.identity(n_rows) * lambd
    row_ids = np.arange(n_rows)
    total = 0
    for i, observed in enumerate(self.pots):
        # Boolean mask selecting every row except the one being predicted.
        predicted = self._leave_one_out_estimate(regularized, i, row_ids != i)
        total = total + np.linalg.norm(predicted - observed)
    return total
351+
352+
def _leave_one_out_estimate(self, KERNEL, i, IDX):
    """Estimate row ``i`` of ``self.pots`` from the rows selected by ``IDX``.

    Parameters
    ----------
    KERNEL : numpy array
        Square matrix whose ``IDX`` x ``IDX`` sub-block is solved against
        ``self.pots[IDX]``.
    i : int
        Row of ``self.k_pot`` used to project the solution.
    IDX : numpy array
        Selector of the rows/columns to keep (the caller visible here
        passes a boolean mask that excludes ``i``).

    Returns
    -------
    numpy array
        First (only) row of ``k_pot[i, IDX] @ solve(KERNEL[IDX, IDX], pots[IDX])``.
    """
    cross_kernel = self.k_pot[np.ix_([i], IDX)]
    train_kernel = KERNEL[np.ix_(IDX, IDX)]
    beta = np.linalg.solve(train_kernel, self.pots[IDX])
    return np.matmul(cross_kernel, beta)[0]
356+
326357
def compute_cverror(self, lambd, index_generator):
327358
"""Useful for Cross validation error calculations
328359

0 commit comments

Comments
 (0)