1+ import logging
12import numpy as np
23import scipy as sp
34import collections
910from sklearn .utils import check_random_state
1011
1112
13+ logging .basicConfig ()
14+
15+
1216def form_lag_matrix (X , T , stride = 1 , stride_tricks = True , rng = None , writeable = False ):
1317 """Form the data matrix with `T` lags.
1418
@@ -69,7 +73,7 @@ def form_lag_matrix(X, T, stride=1, stride_tricks=True, rng=None, writeable=Fals
6973 return X_with_lags
7074
7175
72- def rectify_spectrum (cov , epsilon = 1e-6 , verbose = False ):
76+ def rectify_spectrum (cov , epsilon = 1e-6 , logger = None ):
7377 """Rectify the spectrum of a covariance matrix.
7478
7579 Parameters
@@ -81,11 +85,13 @@ def rectify_spectrum(cov, epsilon=1e-6, verbose=False):
8185 verbose : bool
8286 Whethere to print when the spectrum needs to be rectified.
8387 """
84- min_eig = np .min (sp .linalg .eigvalsh (cov ))
85- if min_eig < 0 :
86- cov += (- min_eig + epsilon ) * np .eye (cov .shape [0 ])
87- if verbose :
88- print ("Warning: non-PSD matrix (had to increase eigenvalues)" )
88+ eigvals = sp .linalg .eigvalsh (cov )
89+ n_neg = np .sum (eigvals <= 0. )
90+ if n_neg > 0 :
91+ cov += (- np .min (eigvals ) + epsilon ) * np .eye (cov .shape [0 ])
92+ if logger is not None :
93+ string = 'Non-PSD matrix, {} of {} eigenvalues were not positive.'
94+ logger .info (string .format (n_neg , eigvals .size ))
8995
9096
9197def toeplitzify (cov , T , N , symmetrize = True ):
@@ -171,7 +177,7 @@ def calc_chunked_cov(X, T, stride, chunks, cov_est=None, rng=None, stride_tricks
171177
172178def calc_cross_cov_mats_from_data (X , T , mean = None , chunks = None , stride = 1 ,
173179 rng = None , regularization = None , reg_ops = None ,
174- stride_tricks = True ):
180+ stride_tricks = True , logger = None ):
175181 """Compute the N-by-N cross-covariance matrix, where N is the data dimensionality,
176182 for each time lag up to T-1.
177183
@@ -271,7 +277,7 @@ def calc_cross_cov_mats_from_data(X, T, mean=None, chunks=None, stride=1,
271277 else :
272278 raise ValueError
273279
274- rectify_spectrum (cov_est , verbose = True )
280+ rectify_spectrum (cov_est , logger = logger )
275281 cross_cov_mats = calc_cross_cov_mats_from_cov (cov_est , T , N )
276282 return cross_cov_mats
277283
0 commit comments