Skip to content

Commit a4d21d3

Browse files
authored
Merge pull request #60 from JesseLivezey/score_fix
fixed score T
2 parents 3629e9a + cbed0ad commit a4d21d3

2 files changed

Lines changed: 19 additions & 3 deletions

File tree

dca/dca.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def _fit_projection(self, d=None, T=None, record_V=False):
206206
raise ValueError
207207
if (2 * T) > self.cross_covs.shape[0]:
208208
raise ValueError('T must less than or equal to the value when ' +
209-
'`estimate_cross_covariance()` was called.')
209+
'`estimate_data_statistics()` was called.')
210210
self.T_fit = T
211211

212212
if self.cross_covs is None:
@@ -276,10 +276,12 @@ def score(self, X=None):
276276
Optional. If X is none, calculate PI from the training data.
277277
If X is given, calcuate the PI of X for the learned projections.
278278
"""
279+
T = self.T_fit
279280
if X is None:
280-
cross_covs = self.cross_covs
281+
cross_covs = self.cross_covs.cpu().numpy()
281282
else:
282-
cross_covs = calc_cross_cov_mats_from_data(X, T=self.T)
283+
cross_covs = calc_cross_cov_mats_from_data(X, 2 * self.T)
284+
cross_covs = cross_covs[:2 * T]
283285
if self.block_toeplitz:
284286
return calc_pi_from_cross_cov_mats_block_toeplitz(cross_covs, self.coef_)
285287
else:

tests/test_dca.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,17 @@ def test_stride_DCA(lorenz_dataset):
181181
model.estimate_data_statistics(X)
182182
ccms2 = model.cross_covs.numpy()
183183
assert_allclose(ccms1, ccms2)
184+
185+
186+
def test_DCA_validation_score(lorenz_dataset):
187+
"""Check that score() returns the same value for the training data using the
188+
default value of None or passing in the training data.
189+
"""
190+
X = lorenz_dataset
191+
model = DCA(T=5, d=2)
192+
model.fit(X)
193+
assert_allclose(model.score(), model.score(X))
194+
195+
model.fit_projection(T=3)
196+
assert_allclose(model.score(), model.score(X))
197+

0 commit comments

Comments
 (0)