@@ -121,7 +121,7 @@ class DynamicalComponentsAnalysis(SingleProjectionComponentsAnalysis):
         Projection matrix from fit.
     """
     def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1, tol=1e-6,
-                 ortho_lambda=10., verbose=False, use_scipy=True, block_toeplitz=None,
+                 ortho_lambda=10., verbose=False, block_toeplitz=None,
                  chunk_cov_estimate=None, device="cpu", dtype=torch.float64, rng_or_seed=None):

         super(DynamicalComponentsAnalysis,
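Call sites only lose the keyword; a hedged before/after sketch (constructor arguments as in this diff, values illustrative):

```python
# Before this change, callers could select the optimizer backend:
#     DynamicalComponentsAnalysis(d=3, T=10, use_scipy=True)
# After it, scipy's L-BFGS-B is the only fitting path, so the flag is gone
# and passing use_scipy would raise a TypeError:
dca = DynamicalComponentsAnalysis(d=3, T=10)
```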
@@ -132,7 +132,6 @@ def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1, tol=
         self.chunk_cov_estimate = chunk_cov_estimate
         self.d = d
         self.d_fit = None
-        self.use_scipy = use_scipy
         if block_toeplitz is None:
             try:
                 if d > 40 and T > 10:
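The hunk ends inside the `try:`, so the rest of the default-selection logic is not visible here. A sketch of what it presumably does, with the `except` branch an assumption (`d` and `T` may still be `None` at construction time):

```python
# Presumed tail of the visible block: choose the block-Toeplitz loss only for
# larger problems when the caller did not specify a preference.
if block_toeplitz is None:
    try:
        self.block_toeplitz = bool(d > 40 and T > 10)
    except TypeError:
        # d or T was left as None at construction time (assumption)
        self.block_toeplitz = False
else:
    self.block_toeplitz = block_toeplitz
```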
@@ -218,77 +217,52 @@ def _fit_projection(self, d=None, T=None, record_V=False):
         c = self.cross_covs[:2 * T]
         N = c.shape[1]
         V_init = init_coef(N, d, self.rng, self.init)
-        v = torch.tensor(V_init, requires_grad=True,
-                         device=self.device, dtype=self.dtype)

         if not isinstance(c, torch.Tensor):
             c = torch.tensor(c, device=self.device, dtype=self.dtype)

-        if self.use_scipy:
+        def f_params(v_flat, requires_grad=True):
+            v_flat_torch = torch.tensor(v_flat,
+                                        requires_grad=requires_grad,
+                                        device=self.device,
+                                        dtype=self.dtype)
+            v_torch = v_flat_torch.reshape(N, d)
+            loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
+            return loss, v_flat_torch
+        objective = ObjectiveWrapper(f_params)

-            def f_params(v_flat, requires_grad=True):
-                v_flat_torch = torch.tensor(v_flat,
-                                            requires_grad=requires_grad,
-                                            device=self.device,
-                                            dtype=self.dtype)
-                v_torch = v_flat_torch.reshape(N, d)
-                loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
-                return loss, v_flat_torch
-            objective = ObjectiveWrapper(f_params)
+        def null_callback(*args, **kwargs):
+            pass

-            def null_callback(*args, **kwargs):
-                pass
+        if self.verbose or record_V:
+            if record_V:
+                self.V_seq = [V_init]

-            if self.verbose or record_V:
+            def callback(v_flat, objective):
                 if record_V:
-                    self.V_seq = [V_init]
-
-                def callback(v_flat, objective):
-                    if record_V:
-                        self.V_seq.append(v_flat.reshape(N, d))
-                    if self.verbose:
-                        loss, v_flat_torch = objective.core_computations(v_flat,
-                                                                         requires_grad=False)
-                        v_torch = v_flat_torch.reshape(N, d)
-                        loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
-                        loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
-                        reg_val = ortho_reg_fn(self.ortho_lambda, v_torch)
-                        loss = loss.detach().cpu().numpy()
-                        reg_val = reg_val.detach().cpu().numpy()
-                        PI = -(loss - reg_val)
-                        string = "Loss {}, PI: {} nats, reg: {}"
-                        self._logger.info(string.format(str(np.round(loss, 4)),
-                                                        str(np.round(PI, 4)),
-                                                        str(np.round(reg_val, 4))))
-
-                callback(V_init, objective)
-            else:
-                callback = null_callback
-
-            opt = minimize(objective.func, V_init.ravel(), method='L-BFGS-B', jac=objective.grad,
-                           options={'disp': self.verbose, 'ftol': self.tol},
-                           callback=lambda x: callback(x, objective))
-            v = opt.x.reshape(N, d)
-        else:
-            optimizer = torch.optim.LBFGS([v], max_eval=15000, max_iter=15000,
-                                          tolerance_change=self.tol, history_size=10,
-                                          line_search_fn='strong_wolfe')
-
-            def closure():
-                optimizer.zero_grad()
-                loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v)
-                loss.backward()
+                    self.V_seq.append(v_flat.reshape(N, d))
                 if self.verbose:
-                    reg_val = ortho_reg_fn(self.ortho_lambda, v)
-                    loss_no_reg = loss - reg_val
-                    pi = -loss_no_reg.detach().cpu().numpy()
+                    loss, v_flat_torch = objective.core_computations(v_flat,
+                                                                     requires_grad=False)
+                    v_torch = v_flat_torch.reshape(N, d)
+                    loss = build_loss(c, d, self.ortho_lambda, self.block_toeplitz)(v_torch)
+                    reg_val = ortho_reg_fn(self.ortho_lambda, v_torch)
+                    loss = loss.detach().cpu().numpy()
                     reg_val = reg_val.detach().cpu().numpy()
-                    print("PI: {} nats, reg: {}".format(str(np.round(pi, 4)),
-                                                        str(np.round(reg_val, 4))))
-                    return loss
+                    PI = -(loss - reg_val)
+                    string = "Loss {}, PI: {} nats, reg: {}"
+                    self._logger.info(string.format(str(np.round(loss, 4)),
                                                    str(np.round(PI, 4)),
                                                    str(np.round(reg_val, 4))))

-            optimizer.step(closure)
-            v = v.detach().cpu().numpy()
+            callback(V_init, objective)
+        else:
+            callback = null_callback
+
+        opt = minimize(objective.func, V_init.ravel(), method='L-BFGS-B', jac=objective.grad,
+                       options={'disp': self.verbose, 'ftol': self.tol},
+                       callback=lambda x: callback(x, objective))
+        v = opt.x.reshape(N, d)

         # Orthonormalize the basis prior to returning it
         V_opt = scipy.linalg.orth(v)
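What survives is the `ObjectiveWrapper` pattern: scipy's L-BFGS-B owns the flat float64 parameter vector, and torch is used only to evaluate the loss and its gradient. Below is a self-contained toy of that bridge, with a quadratic standing in for `build_loss(...)` and all names illustrative; the real wrapper exposes `.func`/`.grad` and presumably caches one forward/backward pass per iterate, while this sketch instead returns `(loss, grad)` together via `jac=True`:

```python
import numpy as np
import scipy.optimize
import torch


def make_objective(loss_fn, shape, device="cpu", dtype=torch.float64):
    """Wrap a torch loss so scipy.optimize.minimize can drive it."""
    def func(v_flat):
        # Rebuild a leaf tensor from scipy's iterate each call, as f_params does.
        v = torch.tensor(v_flat, requires_grad=True, device=device, dtype=dtype)
        loss = loss_fn(v.reshape(shape))
        loss.backward()
        # scipy expects plain float64 numpy values, not tensors.
        return loss.item(), v.grad.detach().cpu().numpy().ravel()
    return func


N, d = 5, 2
target = torch.ones(N, d, dtype=torch.float64)
loss_fn = lambda V: ((V - target) ** 2).sum()  # stand-in for the PI loss

opt = scipy.optimize.minimize(make_objective(loss_fn, (N, d)), np.zeros(N * d),
                              method="L-BFGS-B", jac=True)
print(np.allclose(opt.x.reshape(N, d), 1.0, atol=1e-5))  # True
```

Rebuilding the tensor from `v_flat` on every call keeps scipy's vector as the single source of truth, which is presumably what made the torch-LBFGS branch (with its persistent leaf tensor `v`) safe to drop.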