Skip to content

Commit 2291222

Browse files
committed
cleanup
1 parent 6b373a2 commit 2291222

2 files changed

Lines changed: 36 additions & 65 deletions

File tree

dca/dca.py

Lines changed: 36 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class DynamicalComponentsAnalysis(SingleProjectionComponentsAnalysis):
121121
Projection matrix from fit.
122122
"""
123123
def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1, tol=1e-6,
124-
ortho_lambda=10., verbose=False, use_scipy=True, block_toeplitz=None,
124+
ortho_lambda=10., verbose=False, block_toeplitz=None,
125125
chunk_cov_estimate=None, device="cpu", dtype=torch.float64, rng_or_seed=None):
126126

127127
super(DynamicalComponentsAnalysis,
@@ -132,7 +132,6 @@ def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1, tol=
132132
self.chunk_cov_estimate = chunk_cov_estimate
133133
self.d = d
134134
self.d_fit = None
135-
self.use_scipy = use_scipy
136135
if block_toeplitz is None:
137136
try:
138137
if d > 40 and T > 10:
@@ -218,77 +217,52 @@ def _fit_projection(self, d=None, T=None, record_V=False):
218217
c = self.cross_covs[:2 * T]
219218
N = c.shape[1]
220219
V_init = init_coef(N, d, self.rng, self.init)
221-
v = torch.tensor(V_init, requires_grad=True,
222-
device=self.device, dtype=self.dtype)
223220

224221
if not isinstance(c, torch.Tensor):
225222
c = torch.tensor(c, device=self.device, dtype=self.dtype)
226223

227-
if self.use_scipy:
224+
def f_params(v_flat, requires_grad=True):
225+
v_flat_torch = torch.tensor(v_flat,
226+
requires_grad=requires_grad,
227+
device=self.device,
228+
dtype=self.dtype)
229+
v_torch = v_flat_torch.reshape(N, d)
230+
loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
231+
return loss, v_flat_torch
232+
objective = ObjectiveWrapper(f_params)
228233

229-
def f_params(v_flat, requires_grad=True):
230-
v_flat_torch = torch.tensor(v_flat,
231-
requires_grad=requires_grad,
232-
device=self.device,
233-
dtype=self.dtype)
234-
v_torch = v_flat_torch.reshape(N, d)
235-
loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
236-
return loss, v_flat_torch
237-
objective = ObjectiveWrapper(f_params)
234+
def null_callback(*args, **kwargs):
235+
pass
238236

239-
def null_callback(*args, **kwargs):
240-
pass
237+
if self.verbose or record_V:
238+
if record_V:
239+
self.V_seq = [V_init]
241240

242-
if self.verbose or record_V:
241+
def callback(v_flat, objective):
243242
if record_V:
244-
self.V_seq = [V_init]
245-
246-
def callback(v_flat, objective):
247-
if record_V:
248-
self.V_seq.append(v_flat.reshape(N, d))
249-
if self.verbose:
250-
loss, v_flat_torch = objective.core_computations(v_flat,
251-
requires_grad=False)
252-
v_torch = v_flat_torch.reshape(N, d)
253-
loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
254-
loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
255-
reg_val = ortho_reg_fn(self.ortho_lambda, v_torch)
256-
loss = loss.detach().cpu().numpy()
257-
reg_val = reg_val.detach().cpu().numpy()
258-
PI = -(loss - reg_val)
259-
string = "Loss {}, PI: {} nats, reg: {}"
260-
self._logger.info(string.format(str(np.round(loss, 4)),
261-
str(np.round(PI, 4)),
262-
str(np.round(reg_val, 4))))
263-
264-
callback(V_init, objective)
265-
else:
266-
callback = null_callback
267-
268-
opt = minimize(objective.func, V_init.ravel(), method='L-BFGS-B', jac=objective.grad,
269-
options={'disp': self.verbose, 'ftol': self.tol},
270-
callback=lambda x: callback(x, objective))
271-
v = opt.x.reshape(N, d)
272-
else:
273-
optimizer = torch.optim.LBFGS([v], max_eval=15000, max_iter=15000,
274-
tolerance_change=self.tol, history_size=10,
275-
line_search_fn='strong_wolfe')
276-
277-
def closure():
278-
optimizer.zero_grad()
279-
loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v)
280-
loss.backward()
243+
self.V_seq.append(v_flat.reshape(N, d))
281244
if self.verbose:
282-
reg_val = ortho_reg_fn(self.ortho_lambda, v)
283-
loss_no_reg = loss - reg_val
284-
pi = -loss_no_reg.detach().cpu().numpy()
245+
loss, v_flat_torch = objective.core_computations(v_flat,
246+
requires_grad=False)
247+
v_torch = v_flat_torch.reshape(N, d)
248+
loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
249+
reg_val = ortho_reg_fn(self.ortho_lambda, v_torch)
250+
loss = loss.detach().cpu().numpy()
285251
reg_val = reg_val.detach().cpu().numpy()
286-
print("PI: {} nats, reg: {}".format(str(np.round(pi, 4)),
287-
str(np.round(reg_val, 4))))
288-
return loss
252+
PI = -(loss - reg_val)
253+
string = "Loss {}, PI: {} nats, reg: {}"
254+
self._logger.info(string.format(str(np.round(loss, 4)),
255+
str(np.round(PI, 4)),
256+
str(np.round(reg_val, 4))))
289257

290-
optimizer.step(closure)
291-
v = v.detach().cpu().numpy()
258+
callback(V_init, objective)
259+
else:
260+
callback = null_callback
261+
262+
opt = minimize(objective.func, V_init.ravel(), method='L-BFGS-B', jac=objective.grad,
263+
options={'disp': self.verbose, 'ftol': self.tol},
264+
callback=lambda x: callback(x, objective))
265+
v = opt.x.reshape(N, d)
292266

293267
# Orthonormalize the basis prior to returning it
294268
V_opt = scipy.linalg.orth(v)

tests/test_dca.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ def test_DCA(noise_dataset):
3636
model.score()
3737
model.score(X)
3838

39-
model = DCA(d=3, T=10, use_scipy=False)
40-
model.fit(X)
41-
4239
model = DCA(d=3, T=10, verbose=True)
4340
model.fit(X)
4441

0 commit comments

Comments
 (0)