Skip to content

Commit e761ad4

Browse files
committed
added some checks for even T in PI calculation in cov_util
1 parent 8660a44 commit e761ad4

1 file changed

Lines changed: 12 additions & 0 deletions

File tree

src/dca/cov_util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ def calc_pi_from_data(X, T, proj=None, stride=1, rng=None):
412412
PI : float
413413
Mutual information in nats.
414414
"""
415+
if T % 2 != 0:
416+
raise ValueError('T must be even (This T sets the joint window length,'
417+
+ ' not the past or future length')
418+
415419
ccms = calc_cross_cov_mats_from_data(X, T, stride=stride, rng=rng)
416420

417421
return calc_pi_from_cross_cov_mats(ccms, proj=proj)
@@ -432,7 +436,11 @@ def calc_pi_from_cov(cov_2_T_pi):
432436
Mutual information in nats.
433437
"""
434438

439+
if cov_2_T_pi.shape[0] % 2 != 0:
440+
raise ValueError('cov_2_T_pi must have even shape')
441+
435442
T_pi = cov_2_T_pi.shape[0] // 2
443+
436444
use_torch = isinstance(cov_2_T_pi, torch.Tensor)
437445

438446
cov_T_pi = cov_2_T_pi[:T_pi, :T_pi]
@@ -513,6 +521,10 @@ def calc_pi_from_cross_cov_mats(cross_cov_mats, proj=None):
513521
PI : float
514522
Mutual information in nats.
515523
"""
524+
525+
if len(cross_cov_mats) % 2 != 0:
526+
raise ValueError('number of cross covariance matrices provided must be even (equal to joint window length)')
527+
516528
if proj is not None:
517529
cross_cov_mats_proj = project_cross_cov_mats(cross_cov_mats, proj)
518530
else:

0 commit comments

Comments
 (0)