 from scipy.optimize import minimize
 from sklearn.decomposition import FactorAnalysis as FA, PCA
 from sklearn.exceptions import ConvergenceWarning
+from sklearn.utils import check_random_state
 from functools import partial

 import torch
 from torch.nn import functional as F

-from .dca import ortho_reg_fn
+
+from .base import init_coef
+

 __all__ = ['GaussianProcessFactorAnalysis',
            'SlowFeatureAnalysis',
@@ -66,22 +69,26 @@ class ForecastableComponentsAnalysis(object):

     """
     def __init__(self, d, T, init="random_ortho", n_init=1, tol=1e-6,
-                 ortho_lambda=10., verbose=False,
-                 device="cpu", dtype=torch.float64):
+                 verbose=False, device="cpu", dtype=torch.float64,
+                 rng_or_seed=20200818):
         self.d = d
+        if d > 1:
+            raise ValueError
         self.T = T
         self.init = init
         self.n_init = n_init
         self.tol = tol
-        self.ortho_lambda = ortho_lambda
         self.verbose = verbose
         self.device = device
         self.dtype = dtype
         self.cross_covs = None
+        self.rng = check_random_state(rng_or_seed)

     def fit(self, X, d=None, T=None, n_init=None):
         if d is None:
             d = self.d
+        if d > 1:
+            raise ValueError
         if T is None:
             T = self.T
         self.pca = PCA(whiten=True)
@@ -100,21 +107,10 @@ def fit(self, X, d=None, T=None, n_init=None):
     def _fit_projection(self, X, d=None):
         if d is None:
             d = self.d
-
-        N = X.shape[1]
-        if type(self.init) == str:
-            if self.init == "random":
-                V_init = np.random.normal(0, 1, (N, d))
-            elif self.init == "random_ortho":
-                V_init = scipy.stats.ortho_group.rvs(N)[:, :d]
-            elif self.init == "uniform":
-                V_init = np.ones((N, d)) / np.sqrt(N)
-                V_init = V_init + np.random.normal(0, 1e-3, V_init.shape)
-            else:
-                raise ValueError
-        else:
+        if d > 1:
             raise ValueError
-        V_init /= np.linalg.norm(V_init, axis=0, keepdims=True)
+        N = X.shape[1]
+        V_init = init_coef(N, d, self.rng, self.init)

         v = torch.tensor(V_init, requires_grad=True,
                          device=self.device, dtype=self.dtype)
@@ -125,17 +121,15 @@ def _fit_projection(self, X, d=None):

         if self.verbose:
             def callback(v_flat):
-                v_flat_torch = torch.tensor(v_flat,
-                                            requires_grad=True,
-                                            device=self.device,
-                                            dtype=self.dtype)
-                v_torch = v_flat_torch.reshape(N, d)
-                ent = ent_loss_fn(Xt, v_torch, self.T)
-                reg_val = ortho_reg_fn(v_torch, self.ortho_lambda)
-                ent = ent.detach().cpu().numpy()
-                reg_val = reg_val.detach().cpu().numpy()
-                print("Ent: {} bits, reg: {}".format(str(np.round(ent, 4)),
-                                                     str(np.round(reg_val, 4))))
+                with torch.no_grad():
+                    v_flat_torch = torch.tensor(v_flat,
+                                                device=self.device,
+                                                dtype=self.dtype)
+                    v_torch = v_flat_torch.reshape(N, d)
+                    v_torch = v_torch / torch.norm(v_torch, dim=0, keepdim=True)
+                    ent = ent_loss_fn(Xt, v_torch, self.T)
+                    ent = ent.detach().cpu().numpy()
+                    print("Ent: {} bits".format(str(np.round(ent, 4))))
             callback(V_init)
         else:
             callback = None
@@ -146,9 +140,8 @@ def f_df(v_flat):
                                         device=self.device,
                                         dtype=self.dtype)
             v_torch = v_flat_torch.reshape(N, d)
-            ent = ent_loss_fn(Xt, v_torch, self.T)
-            reg_val = ortho_reg_fn(v_torch, self.ortho_lambda)
-            loss = ent + reg_val
+            v_torch = v_torch / torch.norm(v_torch, dim=0, keepdim=True)
+            loss = ent_loss_fn(Xt, v_torch, self.T)
             loss.backward()
             grad = v_flat_torch.grad
             return (loss.detach().cpu().numpy().astype(float),
@@ -161,12 +154,12 @@ def f_df(v_flat):

         # Orthonormalize the basis prior to returning it
         V_opt = scipy.linalg.orth(v)
-        v_flat_torch = torch.tensor(V_opt.ravel(),
-                                    requires_grad=True,
-                                    device=self.device,
-                                    dtype=self.dtype)
-        v_torch = v_flat_torch.reshape(N, d)
-        final_pi = ent_loss_fn(Xt, v_torch, self.T).detach().cpu().numpy()
+        with torch.no_grad():
+            v_flat_torch = torch.tensor(V_opt.ravel(),
+                                        device=self.device,
+                                        dtype=self.dtype)
+            v_torch = v_flat_torch.reshape(N, d)
+            final_pi = ent_loss_fn(Xt, v_torch, self.T).detach().cpu().numpy()
         return V_opt, final_pi

     def transform(self, X):
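A note for context: this commit replaces the inline initialization branches in _fit_projection with a call to init_coef(N, d, self.rng, self.init) imported from .base. Only that call signature is confirmed by the diff; the sketch below is an assumption that reconstructs what such a helper could look like from the branches removed above, not the actual contents of .base.

import numpy as np
import scipy.stats


def init_coef(N, d, rng, init):
    """Sketch: draw an (N, d) initial projection using a seeded RNG.

    Mirrors the "random", "random_ortho", and "uniform" branches removed
    from _fit_projection; the real helper in .base may differ.
    """
    if isinstance(init, str):
        if init == "random":
            V_init = rng.normal(0, 1, (N, d))
        elif init == "random_ortho":
            V_init = scipy.stats.ortho_group.rvs(N, random_state=rng)[:, :d]
        elif init == "uniform":
            V_init = np.ones((N, d)) / np.sqrt(N)
            V_init = V_init + rng.normal(0, 1e-3, V_init.shape)
        else:
            raise ValueError
    else:
        raise ValueError
    # Unit-normalize each column, matching the per-column normalization
    # that f_df now applies to the optimization variable.
    V_init /= np.linalg.norm(V_init, axis=0, keepdims=True)
    return V_init

With the updated constructor, usage would look roughly like ForecastableComponentsAnalysis(d=1, T=10, rng_or_seed=0).fit(X) (T=10 is only illustrative): d must be 1, since larger values now raise ValueError, and rng_or_seed accepts either an integer seed or a NumPy RandomState, resolved through check_random_state.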