Skip to content

Commit 1f01bf8

Browse files
committed
Maximum-likelihood (ML) block-Toeplitz covariance estimation
1 parent ece7024 commit 1f01bf8

2 files changed

Lines changed: 133 additions & 6 deletions

File tree

dca/cov_util.py

Lines changed: 124 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def calc_chunked_cov(X, T, stride, chunks, cov_est=None, rng=None, stride_tricks
177177

178178
def calc_cross_cov_mats_from_data(X, T, mean=None, chunks=None, stride=1,
179179
rng=None, regularization=None, reg_ops=None,
180-
stride_tricks=True, logger=None):
180+
stride_tricks=True, logger=None, method='toeplitzify'):
181181
"""Compute the N-by-N cross-covariance matrix, where N is the data dimensionality,
182182
for each time lag up to T-1.
183183
@@ -205,6 +205,11 @@ def calc_cross_cov_mats_from_data(X, T, mean=None, chunks=None, stride=1,
205205
stride_tricks : bool
206206
Whether to use numpy stride tricks in form_lag_matrix. True will use less
207207
memory for large T.
208+
logger : logger
209+
Logger.
210+
method : str
211+
'ML' for EM-based maximum likelihood block toeplitz estimation, 'toeplitzify' for naive
212+
block averaging.
208213
209214
Returns
210215
-------
@@ -263,7 +268,12 @@ def calc_cross_cov_mats_from_data(X, T, mean=None, chunks=None, stride=1,
263268
cov_est /= (n_samples - 1.)
264269

265270
if regularization is None:
266-
cov_est = toeplitzify(cov_est, T, N)
271+
if method == 'toeplitzify':
272+
cov_est = toeplitzify(cov_est, T, N)
273+
elif method == 'ML':
274+
cov_est = block_toeplitz_covariance(X=None, S_X=cov_est, T=T)
275+
else:
276+
raise ValueError('`method` should be "toeplitzify" or "ML".')
267277
elif regularization == 'kron':
268278
num_folds = reg_ops.get('num_folds', 5)
269279
r_vals = np.arange(1, min(2 * T, N**2 + 1))
@@ -556,8 +566,8 @@ def calc_block_toeplitz_logdets(cross_cov_mats, proj=None):
556566
if ii > 1:
557567
As = torch.stack([A[ii - 2, ii - jj - 1] for jj in range(1, ii)])
558568
D = ccms[ii] - torch.matmul(As, ccms[1:ii]).sum(dim=0)
559-
A[(ii - 1, ii - 1)] = torch.solve(D.t(), vb[ii - 1].t())[0].t()
560-
Ab[(ii - 1, ii - 1)] = torch.solve(D, v.t())[0].t()
569+
A[(ii - 1, ii - 1)] = torch.linalg.solve(vb[ii - 1].t(), D.t()).t()
570+
Ab[(ii - 1, ii - 1)] = torch.linalg.solve(v.t(), D).t()
561571

562572
for kk in range(1, ii):
563573
A[(ii - 1, kk - 1)] = (A[(ii - 2, kk - 1)]
@@ -633,6 +643,116 @@ def calc_pi_from_cross_cov_mats_block_toeplitz(cross_cov_mats, proj=None):
633643
return sum(logdets[:T // 2]) - 0.5 * sum(logdets)
634644

635645

646+
def extract_diag_blocks(cov, Q):
    """Extract the Q diagonal (P, P) blocks from a square matrix.

    Parameters
    ----------
    cov : ndarray (Q*P, Q*P)
        Matrix to extract diagonal blocks from.
    Q : int
        The number of blocks.

    Returns
    -------
    list of ndarray
        The Q diagonal blocks, each of shape (P, P), in order along the
        diagonal. Blocks are views into `cov`, not copies.
    """
    # Removed the dead `start = 0` initializer (it was overwritten on every
    # loop iteration) and replaced the append loop with a comprehension.
    P = cov.shape[0] // Q
    return [cov[ii * P:(ii + 1) * P, ii * P:(ii + 1) * P] for ii in range(Q)]
663+
664+
665+
def one_step(S_X, R_X, R_Y, A, A_R_Y, Q):
    """Run a single EM update for ML block-Toeplitz covariance estimation.

    Parameters
    ----------
    S_X : ndarray
        Sample covariance matrix.
    R_X : ndarray
        Observed covariance matrix.
    R_Y : ndarray
        Latent covariance matrix.
    A : ndarray
        Projection from latent to observed variables.
    A_R_Y : ndarray
        Precomputed ``A @ R_Y``.
    Q : int
        Number of diagonal blocks enforced in the latent covariance.

    Returns
    -------
    tuple of ndarray
        The updated ``(R_X, R_Y, A_R_Y)``.
    """
    # E-step / M-step combined: update the latent covariance toward the
    # data covariance, measured through the observed model R_X.
    solved = np.linalg.solve(R_X, A_R_Y)
    correction = np.linalg.solve(R_X, S_X) @ solved - solved
    R_Y_full = R_Y + A_R_Y.conj().T @ correction
    # Project back onto the block-diagonal constraint set (inlined
    # diagonal-block extraction).
    P = R_Y_full.shape[0] // Q
    diag_blocks = [R_Y_full[kk * P:(kk + 1) * P, kk * P:(kk + 1) * P]
                   for kk in range(Q)]
    R_Y_new = sp.linalg.block_diag(*diag_blocks)
    A_R_Y_new = A @ R_Y_new
    R_X_new = (A_R_Y_new @ A.conj().T).real
    return R_X_new, R_Y_new, A_R_Y_new
688+
689+
690+
def _block_toeplitz_ll(S_X, R_X, X_N, NP):
    """Log-likelihood of the data under the Gaussian model N(0, R_X).

    When raw (lagged) data `X_N` is available, returns the mean per-sample
    Gaussian log-pdf. Otherwise returns a covariance-based surrogate
    (negative KL-style divergence between S_X and R_X, zero when equal).
    """
    if X_N is not None:
        return sp.stats.multivariate_normal.logpdf(
            X_N, mean=np.zeros(NP), cov=R_X).mean()
    d = S_X.shape[0]
    tr = np.trace(np.linalg.solve(R_X, S_X))
    logdets = np.linalg.slogdet(R_X)[1] - np.linalg.slogdet(S_X)[1]
    return -.5 * (tr + logdets - d)


def block_toeplitz_covariance(X, S_X, T, max_iter=100, tol=1e-6):
    """Estimate the ML block toeplitz covariance matrix using:

    Fuhrmann, Daniel R., and T. A. Barton.
    "Estimation of block-Toeplitz covariance matrices."
    1990 Conference Record Twenty-Fourth Asilomar Conference on Signals, Systems and Computers.

    Parameters
    ----------
    X : ndarray (time, features) or (batches, time, features)
        Data used to calculate the PI. Can be None if S_X is given.
    S_X : ndarray
        Sample covariance matrix. Can be None if X is given.
    T : int
        This T should be 2 * T_pi. This T sets the joint window length not the
        past or future window length.
    max_iter : int
        Maximum number of EM iterations.
    tol : float
        Relative log-likelihood tolerance for stopping EM iterations.

    Returns
    -------
    R_X : ndarray (T*P, T*P)
        The ML block-Toeplitz covariance estimate.
    """
    N = T
    if X is not None:
        P = X.shape[1]
        X_N = form_lag_matrix(X, N)
        S_X = np.cov(X_N.T)
    else:
        P = S_X.shape[0] // N
        X_N = None

    rectify_spectrum(S_X)
    # Embed the length-N window in a length-Q (Q = 4N) circulant problem;
    # the DFT block-diagonalizes the circulant latent covariance.
    Q = 4 * N
    NP = N * P
    QP = Q * P

    I_NP = np.eye(NP)
    I_P = np.eye(P)
    W_Q = sp.linalg.dft(Q, scale='sqrtn')
    I_NP_0 = np.concatenate([I_NP, np.zeros((NP, QP - NP))], axis=1)
    W_Q_I_P = np.kron(W_Q, I_P)
    # A maps the QP latent (Fourier-domain) variables onto the observed
    # NP-dimensional window.
    A = I_NP_0 @ W_Q_I_P

    R_Y = np.eye(QP)
    A_R_Y = A @ R_Y
    R_X = np.eye(NP)
    ll = _block_toeplitz_ll(S_X, R_X, X_N, NP)

    for ii in range(max_iter):
        R_X, R_Y, A_R_Y = one_step(S_X, R_X, R_Y, A, A_R_Y, Q)
        new_ll = _block_toeplitz_ll(S_X, R_X, X_N, NP)
        # Stop once the relative per-iteration LL change is below tol.
        if abs(new_ll - ll) / max([1., abs(new_ll), abs(ll)]) < tol:
            break
        # BUG FIX: the original never updated `ll` inside the loop, so the
        # convergence test always compared against the *initial* LL and
        # effectively never fired once EM started improving.
        ll = new_ll
    return R_X
754+
755+
636756
"""
637757
====================================================================================================
638758
====================================================================================================

dca/dca.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ class DynamicalComponentsAnalysis(SingleProjectionComponentsAnalysis):
9999
block_toeplitz : bool
100100
If True, uses the block-Toeplitz logdet algorithm which is typically faster and less
101101
memory intensive on cpu for `T >~ 10` and `d >~ 40`.
102+
method : str
103+
'toeplitzify' for naive averaging to compute the covariance, which is faster but less
104+
accurate for very small datasets. 'ML' for maximum likelihood block toeplitz covariance
105+
estimation which can be very slow for large datasets.
102106
device : str
103107
What device to run the computation on in Pytorch.
104108
dtype : pytorch.dtype
@@ -123,7 +127,8 @@ class DynamicalComponentsAnalysis(SingleProjectionComponentsAnalysis):
123127
"""
124128
def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1,
125129
chunk_cov_estimate=None, tol=1e-6, ortho_lambda=10., verbose=False,
126-
block_toeplitz=None, device="cpu", dtype=torch.float64, rng_or_seed=None):
130+
block_toeplitz=None, method='toeplitzify', device="cpu", dtype=torch.float64,
131+
rng_or_seed=None):
127132

128133
super(DynamicalComponentsAnalysis,
129134
self).__init__(d=d, T=T, init=init, n_init=n_init, stride=stride,
@@ -142,6 +147,7 @@ def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1,
142147
else:
143148
self.block_toeplitz = block_toeplitz
144149
self.cross_covs = None
150+
self.method = method
145151

146152
def estimate_data_statistics(self, X, T=None, regularization=None, reg_ops=None):
147153
"""Estimate the cross covariance matrix from data.
@@ -174,7 +180,8 @@ def estimate_data_statistics(self, X, T=None, regularization=None, reg_ops=None)
174180
rng=self.rng,
175181
regularization=regularization,
176182
reg_ops=reg_ops,
177-
logger=self._logger)
183+
logger=self._logger,
184+
method=self.method)
178185
self.cross_covs = torch.tensor(cross_covs, device=self.device, dtype=self.dtype)
179186
delta_time = round((time.time() - start) / 60., 1)
180187
self._logger.info('Cross covariance estimate took {:0.1f} minutes.'.format(delta_time))

0 commit comments

Comments
 (0)