Skip to content

Commit 6b373a2

Browse files
committed
FCA fixes
1 parent 902e75e commit 6b373a2

2 files changed

Lines changed: 32 additions & 39 deletions

File tree

dca/methods_comparison.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66
from scipy.optimize import minimize
77
from sklearn.decomposition import FactorAnalysis as FA, PCA
88
from sklearn.exceptions import ConvergenceWarning
9+
from sklearn.utils import check_random_state
910
from functools import partial
1011

1112
import torch
1213
from torch.nn import functional as F
1314

14-
from .dca import ortho_reg_fn
15+
16+
from .base import init_coef
17+
1518

1619
__all__ = ['GaussianProcessFactorAnalysis',
1720
'SlowFeatureAnalysis',
@@ -66,22 +69,26 @@ class ForecastableComponentsAnalysis(object):
6669
6770
"""
6871
def __init__(self, d, T, init="random_ortho", n_init=1, tol=1e-6,
69-
ortho_lambda=10., verbose=False,
70-
device="cpu", dtype=torch.float64):
72+
verbose=False, device="cpu", dtype=torch.float64,
73+
rng_or_seed=20200818):
7174
self.d = d
75+
if d > 1:
76+
raise ValueError
7277
self.T = T
7378
self.init = init
7479
self.n_init = n_init
7580
self.tol = tol
76-
self.ortho_lambda = ortho_lambda
7781
self.verbose = verbose
7882
self.device = device
7983
self.dtype = dtype
8084
self.cross_covs = None
85+
self.rng = check_random_state(rng_or_seed)
8186

8287
def fit(self, X, d=None, T=None, n_init=None):
8388
if d is None:
8489
d = self.d
90+
if d > 1:
91+
raise ValueError
8592
if T is None:
8693
T = self.T
8794
self.pca = PCA(whiten=True)
@@ -100,21 +107,10 @@ def fit(self, X, d=None, T=None, n_init=None):
100107
def _fit_projection(self, X, d=None):
101108
if d is None:
102109
d = self.d
103-
104-
N = X.shape[1]
105-
if type(self.init) == str:
106-
if self.init == "random":
107-
V_init = np.random.normal(0, 1, (N, d))
108-
elif self.init == "random_ortho":
109-
V_init = scipy.stats.ortho_group.rvs(N)[:, :d]
110-
elif self.init == "uniform":
111-
V_init = np.ones((N, d)) / np.sqrt(N)
112-
V_init = V_init + np.random.normal(0, 1e-3, V_init.shape)
113-
else:
114-
raise ValueError
115-
else:
110+
if d > 1:
116111
raise ValueError
117-
V_init /= np.linalg.norm(V_init, axis=0, keepdims=True)
112+
N = X.shape[1]
113+
V_init = init_coef(N, d, self.rng, self.init)
118114

119115
v = torch.tensor(V_init, requires_grad=True,
120116
device=self.device, dtype=self.dtype)
@@ -125,17 +121,15 @@ def _fit_projection(self, X, d=None):
125121

126122
if self.verbose:
127123
def callback(v_flat):
128-
v_flat_torch = torch.tensor(v_flat,
129-
requires_grad=True,
130-
device=self.device,
131-
dtype=self.dtype)
132-
v_torch = v_flat_torch.reshape(N, d)
133-
ent = ent_loss_fn(Xt, v_torch, self.T)
134-
reg_val = ortho_reg_fn(v_torch, self.ortho_lambda)
135-
ent = ent.detach().cpu().numpy()
136-
reg_val = reg_val.detach().cpu().numpy()
137-
print("Ent: {} bits, reg: {}".format(str(np.round(ent, 4)),
138-
str(np.round(reg_val, 4))))
124+
with torch.no_grad():
125+
v_flat_torch = torch.tensor(v_flat,
126+
device=self.device,
127+
dtype=self.dtype)
128+
v_torch = v_flat_torch.reshape(N, d)
129+
v_torch = v_torch / torch.norm(v_torch, dim=0, keepdim=True)
130+
ent = ent_loss_fn(Xt, v_torch, self.T)
131+
ent = ent.detach().cpu().numpy()
132+
print("Ent: {} bits".format(str(np.round(ent, 4))))
139133
callback(V_init)
140134
else:
141135
callback = None
@@ -146,9 +140,8 @@ def f_df(v_flat):
146140
device=self.device,
147141
dtype=self.dtype)
148142
v_torch = v_flat_torch.reshape(N, d)
149-
ent = ent_loss_fn(Xt, v_torch, self.T)
150-
reg_val = ortho_reg_fn(v_torch, self.ortho_lambda)
151-
loss = ent + reg_val
143+
v_torch = v_torch / torch.norm(v_torch, dim=0, keepdim=True)
144+
loss = ent_loss_fn(Xt, v_torch, self.T)
152145
loss.backward()
153146
grad = v_flat_torch.grad
154147
return (loss.detach().cpu().numpy().astype(float),
@@ -161,12 +154,12 @@ def f_df(v_flat):
161154

162155
# Orthonormalize the basis prior to returning it
163156
V_opt = scipy.linalg.orth(v)
164-
v_flat_torch = torch.tensor(V_opt.ravel(),
165-
requires_grad=True,
166-
device=self.device,
167-
dtype=self.dtype)
168-
v_torch = v_flat_torch.reshape(N, d)
169-
final_pi = ent_loss_fn(Xt, v_torch, self.T).detach().cpu().numpy()
157+
with torch.no_grad():
158+
v_flat_torch = torch.tensor(V_opt.ravel(),
159+
device=self.device,
160+
dtype=self.dtype)
161+
v_torch = v_flat_torch.reshape(N, d)
162+
final_pi = ent_loss_fn(Xt, v_torch, self.T).detach().cpu().numpy()
170163
return V_opt, final_pi
171164

172165
def transform(self, X):

tests/test_methods_comparison.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_FCA(noise_dataset):
2121
"""Test that a FCA model can be fit with no errors.
2222
"""
2323
X = noise_dataset
24-
model = FCA(d=3, T=10)
24+
model = FCA(d=1, T=10)
2525
model.fit(X)
2626
model.transform(X)
2727
model.fit_transform(X)

0 commit comments

Comments
 (0)