Skip to content

Commit 136fa4a

Browse files
committed
fixed analysis funcs from dca update
1 parent 2b96440 commit 136fa4a

1 file changed

Lines changed: 9 additions & 13 deletions

File tree

dca/analysis.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,14 @@ def run_analysis(X, Y, T_pi_vals, dim_vals, offset_vals, num_cv_folds, decoding_
114114
Y_test_ctd = Y_test - Y_mean
115115

116116
# compute cross-cov mats for DCA
117-
T_max = 2 * np.max(T_pi_vals)
118-
cross_cov_mats = calc_cross_cov_mats_from_data(X_train_ctd, T_max)
117+
dca_model = DynamicalComponentsAnalysis(T=np.max(T_pi_vals))
118+
dca_model.estimate_data_statistics(X_train_ctd)
119119

120120
# do PCA/SFA
121121
pca_model = PCA(svd_solver='full').fit(np.concatenate(X_train_ctd))
122122
sfa_model = SFA(1).fit(X_train_ctd)
123123

124124
# make DCA object
125-
opt = DynamicalComponentsAnalysis(d=1, T=1)
126125

127126
# loop over dimensionalities
128127
for dim_idx in range(len(dim_vals)):
@@ -147,9 +146,8 @@ def run_analysis(X, Y, T_pi_vals, dim_vals, offset_vals, num_cv_folds, decoding_
147146
# loop over T_pi vals
148147
for T_pi_idx in range(len(T_pi_vals)):
149148
T_pi = T_pi_vals[T_pi_idx]
150-
opt.cross_covs = cross_cov_mats[:2 * T_pi]
151-
opt.fit_projection(d=dim, n_init=n_init)
152-
V_dca = opt.coef_
149+
dca_model.fit_projection(d=dim, T=T_pi, n_init=n_init)
150+
V_dca = dca_model.coef_
153151

154152
# compute DCA R2 over offsets
155153
X_train_dca = [np.dot(Xi, V_dca) for Xi in X_train_ctd]
@@ -221,21 +219,19 @@ def run_dim_analysis_dca(X, Y, T_pi, dim_vals, offset, num_cv_folds, decoding_wi
221219
Y_train_ctd = [Yi - Y_mean for Yi in Y_train]
222220
Y_test_ctd = Y_test - Y_mean
223221

222+
# make DCA object
224223
# compute cross-cov mats for DCA
225-
cross_cov_mats = calc_cross_cov_mats_from_data(X_train_ctd, 2 * T_pi)
226-
227-
# make DCA object with arb parameters
228-
opt = DynamicalComponentsAnalysis(d=1, T=1)
224+
dca_model = DynamicalComponentsAnalysis(T=T_pi)
225+
dca_model.estimate_data_statistics(X_train_ctd)
229226

230227
# loop over dimensionalities
231228
for dim_idx in range(len(dim_vals)):
232229
dim = dim_vals[dim_idx]
233230
if verbose:
234231
print("dim", dim_idx + 1, "of", len(dim_vals))
235232

236-
opt.cross_covs = cross_cov_mats
237-
opt.fit_projection(d=dim, n_init=n_init)
238-
V_dca = opt.coef_
233+
dca_model.fit_projection(d=dim, n_init=n_init)
234+
V_dca = dca_model.coef_
239235

240236
# compute DCA R2 over offsets
241237
X_train_dca = [np.dot(Xi, V_dca) for Xi in X_train_ctd]

0 commit comments

Comments
 (0)