@@ -177,7 +177,7 @@ def calc_chunked_cov(X, T, stride, chunks, cov_est=None, rng=None, stride_tricks
177177
178178def calc_cross_cov_mats_from_data (X , T , mean = None , chunks = None , stride = 1 ,
179179 rng = None , regularization = None , reg_ops = None ,
180- stride_tricks = True , logger = None ):
180+ stride_tricks = True , logger = None , method = 'toeplitzify' ):
181181 """Compute the N-by-N cross-covariance matrix, where N is the data dimensionality,
182182 for each time lag up to T-1.
183183
@@ -205,6 +205,11 @@ def calc_cross_cov_mats_from_data(X, T, mean=None, chunks=None, stride=1,
205205 stride_tricks : bool
206206 Whether to use numpy stride tricks in form_lag_matrix. True will use less
207207 memory for large T.
208+ logger : logger
209+ Logger.
210+ method : str
211+ 'ML' for EM-based maximum likelihood block toeplitz estimation, 'toeplitzify' for naive
212+ block averaging.
208213
209214 Returns
210215 -------
@@ -263,7 +268,12 @@ def calc_cross_cov_mats_from_data(X, T, mean=None, chunks=None, stride=1,
263268 cov_est /= (n_samples - 1. )
264269
265270 if regularization is None :
266- cov_est = toeplitzify (cov_est , T , N )
271+ if method == 'toeplitzify' :
272+ cov_est = toeplitzify (cov_est , T , N )
273+ elif method == 'ML' :
274+ cov_est = block_toeplitz_covariance (X = None , S_X = cov_est , T = T )
275+ else :
276+ raise ValueError ('`method` should be "toeplitzify" or "ML".' )
267277 elif regularization == 'kron' :
268278 num_folds = reg_ops .get ('num_folds' , 5 )
269279 r_vals = np .arange (1 , min (2 * T , N ** 2 + 1 ))
@@ -556,8 +566,8 @@ def calc_block_toeplitz_logdets(cross_cov_mats, proj=None):
556566 if ii > 1 :
557567 As = torch .stack ([A [ii - 2 , ii - jj - 1 ] for jj in range (1 , ii )])
558568 D = ccms [ii ] - torch .matmul (As , ccms [1 :ii ]).sum (dim = 0 )
559- A [(ii - 1 , ii - 1 )] = torch .solve (D . t (), vb [ii - 1 ].t ())[ 0 ] .t ()
560- Ab [(ii - 1 , ii - 1 )] = torch .solve (D , v .t ())[ 0 ] .t ()
569+ A [(ii - 1 , ii - 1 )] = torch .linalg . solve (vb [ii - 1 ].t (), D . t ()) .t ()
570+ Ab [(ii - 1 , ii - 1 )] = torch .linalg . solve (v .t (), D ) .t ()
561571
562572 for kk in range (1 , ii ):
563573 A [(ii - 1 , kk - 1 )] = (A [(ii - 2 , kk - 1 )]
@@ -633,6 +643,116 @@ def calc_pi_from_cross_cov_mats_block_toeplitz(cross_cov_mats, proj=None):
633643 return sum (logdets [:T // 2 ]) - 0.5 * sum (logdets )
634644
635645
def extract_diag_blocks(cov, Q):
    """Extract the Q diagonal (P x P) blocks from a square matrix.

    Parameters
    ----------
    cov : ndarray (Q*P, Q*P)
        Matrix to extract diagonal blocks from.
    Q : int
        The number of blocks.

    Returns
    -------
    blocks : list of ndarray (P, P)
        Views of the Q diagonal blocks, in order.
    """
    P = cov.shape[0] // Q
    # Block ii occupies rows/cols [ii*P, (ii+1)*P). The original kept a
    # redundant `start = 0` initialization that was overwritten every
    # iteration; a comprehension makes the indexing self-evident.
    return [cov[ii * P:(ii + 1) * P, ii * P:(ii + 1) * P] for ii in range(Q)]
663+
664+
def one_step(S_X, R_X, R_Y, A, A_R_Y, Q):
    """Perform one EM step for ML block toeplitz covariance estimation.

    Updates the latent covariance from the residual between the sample and
    model covariance, projects it back onto the block-diagonal constraint
    set, and rebuilds the observed-space covariance.

    Parameters
    ----------
    S_X : ndarray
        Sample covariance matrix.
    R_X : ndarray
        Observed covariance matrix.
    R_Y : ndarray
        Latent covariance matrix.
    A : ndarray
        Projection from latent to observed variables.
    A_R_Y : ndarray
        Precomputed A @ R_Y.
    Q : int
        Number of diagonal blocks in R_Y.

    Returns
    -------
    R_X, R_Y, A_R_Y : ndarray
        Updated observed covariance, latent covariance, and A @ R_Y.
    """
    # E-step: update the latent covariance from the data/model mismatch.
    solved = np.linalg.solve(R_X, A_R_Y)
    correction = np.linalg.solve(R_X, S_X) @ solved - solved
    R_Y = R_Y + A_R_Y.conj().T @ correction
    # M-step: enforce the structural constraint by keeping only the
    # diagonal blocks of the latent covariance.
    R_Y = sp.linalg.block_diag(*extract_diag_blocks(R_Y, Q))
    # Map back to observed space; discard the numerically-tiny imaginary part.
    A_R_Y = A @ R_Y
    R_X = (A_R_Y @ A.conj().T).real
    return R_X, R_Y, A_R_Y
688+
689+
def _block_toeplitz_loglik(X_N, S_X, R_X, NP):
    """Average log-likelihood of the current estimate R_X.

    Uses the exact Gaussian log-pdf of the lagged data when X_N is
    available; otherwise a Gaussian KL-style fit between S_X and R_X.
    """
    if X_N is not None:
        return sp.stats.multivariate_normal.logpdf(X_N, mean=np.zeros(NP), cov=R_X).mean()
    d = S_X.shape[0]
    tr = np.trace(np.linalg.solve(R_X, S_X))
    logdets = np.linalg.slogdet(R_X)[1] - np.linalg.slogdet(S_X)[1]
    return -.5 * (tr + logdets - d)


def block_toeplitz_covariance(X, S_X, T, max_iter=100, tol=1e-6):
    """Estimate the ML block toeplitz covariance matrix using:

    Fuhrmann, Daniel R., and T. A. Barton.
    "Estimation of block-Toeplitz covariance matrices."
    1990 Conference Record Twenty-Fourth Asilomar Conference on Signals, Systems and Computers.

    Parameters
    ----------
    X : ndarray (time, features) or (batches, time, features)
        Data used to calculate the PI. Can be None if S_X is given.
    S_X : ndarray
        Sample covariance matrix. Can be None if X is given.
    T : int
        This T should be 2 * T_pi. This T sets the joint window length not the
        past or future window length.
    max_iter : int
        Maximum number of EM iterations.
    tol : float
        LL tolerance for stopping EM iterations.

    Returns
    -------
    R_X : ndarray (T*P, T*P)
        Estimated block-Toeplitz covariance matrix.
    """
    N = T
    if X is not None:
        P = X.shape[1]
        X_N = form_lag_matrix(X, N)
        S_X = np.cov(X_N.T)
    else:
        X_N = None
        P = S_X.shape[0] // N

    rectify_spectrum(S_X)
    Q = 4 * N
    NP = N * P
    QP = Q * P

    # Embed the N-window process into a Q-periodic one: A selects the first
    # NP coordinates after the block DFT, which diagonalizes the circulant
    # extension (Fuhrmann & Barton's construction).
    I_NP = np.eye(NP)
    I_P = np.eye(P)
    W_Q = sp.linalg.dft(Q, scale='sqrtn')
    I_NP_0 = np.concatenate([I_NP, np.zeros((NP, QP - NP))], axis=1)
    W_Q_I_P = np.kron(W_Q, I_P)
    A = I_NP_0 @ W_Q_I_P

    R_Y = np.eye(QP)
    A_R_Y = A @ R_Y
    R_X = np.eye(NP)
    ll = _block_toeplitz_loglik(X_N, S_X, R_X, NP)

    for ii in range(max_iter):
        R_X, R_Y, A_R_Y = one_step(S_X, R_X, R_Y, A, A_R_Y, Q)
        new_ll = _block_toeplitz_loglik(X_N, S_X, R_X, NP)
        # Relative LL change as the convergence criterion. Bug fix: the
        # original never updated `ll`, so every iterate was compared to the
        # *initial* log-likelihood rather than the previous one, making the
        # stopping test meaningless after the first step.
        if abs(new_ll - ll) / max([1., abs(new_ll), abs(ll)]) < tol:
            break
        ll = new_ll
    return R_X
754+
755+
636756"""
637757====================================================================================================
638758====================================================================================================
0 commit comments