Skip to content

Commit 54820ab

Browse files
committed
args cleanup
1 parent 2291222 commit 54820ab

2 files changed

Lines changed: 42 additions & 45 deletions

File tree

dca/base.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
def ortho_reg_fn(ortho_lambda, *Vs):
1919
"""Regularization term which encourages the basis vectors in the
20-
columns of V to be orthonormal.
20+
columns of the Vs to be orthonormal.
21+
2122
Parameters
2223
----------
23-
V : np.ndarray, shape (N, d)
24-
Matrix whose columns are basis vectors.
24+
Vs : np.ndarrays, shape (N, d)
25+
Matrices whose columns are basis vectors.
2526
ortho_lambda : float
2627
Regularization hyperparameter.
28+
2729
Returns
2830
-------
2931
reg_val : float
@@ -35,12 +37,14 @@ def ortho_reg_fn(ortho_lambda, *Vs):
3537
def _ortho_reg_fn(ortho_lambda, V):
3638
"""Regularization term which encourages the basis vectors in the
3739
columns of V to be orthonormal.
40+
3841
Parameters
3942
----------
4043
V : np.ndarray, shape (N, d)
4144
Matrix whose columns are basis vectors.
4245
ortho_lambda : float
4346
Regularization hyperparameter.
47+
4448
Returns
4549
-------
4650
reg_val : float
@@ -165,17 +169,17 @@ class BaseComponentsAnalysis(object):
165169
Number of samples to skip when estimating cross covariance matrices. Setting stride > 1
166170
will speed up covariance estimation but may reduce the quality of the covariance estimate
167171
for small datasets.
172+
chunk_cov_estimate : None or int
173+
If `None`, cov is estimated from entire time series. If an `int`, cov is estimated
174+
by chunking up time series and averaging covariances from chunks. This can use less memory
175+
and be faster for long timeseries. Requires that the length of the shortest timeseries
176+
in the batch is longer than `T * chunk_cov_estimate`.
168177
tol : float
169178
Tolerance for stopping optimization. Default is 1e-6.
170179
ortho_lambda : float
171180
Coefficient on term that keeps V close to orthonormal.
172181
verbose : bool
173182
Verbosity during optimization.
174-
chunk_cov_estimate : None or int
175-
If `None`, cov is estimated from entire time series. If an `int`, cov is estimated
176-
by chunking up time series and averaging covariances from chucks. This can use less memory
177-
and be faster for long timeseries. Requires that the length of the shortest timeseries
178-
in the batch is longer than `2 * T * chunk_cov_estimate`.
179183
device : str
180184
What device to run the computation on in Pytorch.
181185
dtype : torch.dtype
@@ -190,13 +194,15 @@ class BaseComponentsAnalysis(object):
190194
T_fit : int
191195
T used for last cross covariance estimation.
192196
"""
193-
def __init__(self, T=None, init="random_ortho", n_init=1, stride=1, tol=1e-6,
194-
verbose=False, device="cpu", dtype=torch.float64, rng_or_seed=None):
197+
def __init__(self, T=None, init="random_ortho", n_init=1, stride=1,
198+
chunk_cov_estimate=None, tol=1e-6, verbose=False, device="cpu",
199+
dtype=torch.float64, rng_or_seed=None):
195200
self.T = T
196201
self.T_fit = None
197202
self.init = init
198203
self.n_init = n_init
199204
self.stride = stride
205+
self.chunk_cov_estimate = chunk_cov_estimate
200206
self.tol = tol
201207
self.verbose = verbose
202208
self._logger = logging.getLogger('Model')
@@ -215,12 +221,12 @@ def estimate_data_statistics(self):
215221
raise NotImplementedError
216222

217223
def _fit_projection(self):
218-
"""Fit the projections.
224+
"""Fit a single round of projections.
219225
"""
220226
raise NotImplementedError
221227

222228
def fit_projection(self):
223-
"""Fit the projections.
229+
"""Fit the projections, with `n_init` restarts.
224230
"""
225231
raise NotImplementedError
226232

@@ -250,10 +256,9 @@ def score(self):
250256
class SingleProjectionComponentsAnalysis(BaseComponentsAnalysis):
251257
"""Base class for Components Analysis with 1 projection.
252258
253-
Runs DCA on multidimensional timeseries data X to discover a projection
254-
onto a d-dimensional subspace of an N-dimensional space which maximizes the complexity, as
255-
defined by the Gaussian Predictive Information (PI) of the d-dimensional dynamics over windows
256-
of length T.
259+
Runs a Components Analysis method on multidimensional timeseries data X to discover a projection
260+
onto a d-dimensional subspace of an N-dimensional space which maximizes the score of the
261+
d-dimensional dynamics over windows of length T.
257262
258263
Parameters
259264
----------
@@ -272,22 +277,15 @@ class SingleProjectionComponentsAnalysis(BaseComponentsAnalysis):
272277
Number of samples to skip when estimating cross covariance matrices. Setting stride > 1
273278
will speed up covariance estimation but may reduce the quality of the covariance estimate
274279
for small datasets.
275-
tol : float
276-
Tolerance for stopping optimization. Default is 1e-6.
277-
ortho_lambda : float
278-
Coefficient on term that keeps V close to orthonormal.
279-
verbose : bool
280-
Verbosity during optimization.
281-
use_scipy : bool
282-
Whether to use SciPy or Pytorch L-BFGS-B. Default is True. Pytorch is not well tested.
283-
block_toeplitz : bool
284-
If True, uses the block-Toeplitz logdet algorithm which is typically faster and less
285-
memory intensive on cpu for `T >~ 10` and `d >~ 40`.
286280
chunk_cov_estimate : None or int
287281
If `None`, cov is estimated from entire time series. If an `int`, cov is estimated
288282
by chunking up time series and averaging covariances from chunks. This can use less memory
289283
and be faster for long timeseries. Requires that the length of the shortest timeseries
290-
in the batch is longer than `2 * T * chunk_cov_estimate`.
284+
in the batch is longer than `T * chunk_cov_estimate`.
285+
tol : float
286+
Tolerance for stopping optimization. Default is 1e-6.
287+
verbose : bool
288+
Verbosity during optimization.
291289
device : str
292290
What device to run the computation on in Pytorch.
293291
dtype : torch.dtype
@@ -310,11 +308,13 @@ class SingleProjectionComponentsAnalysis(BaseComponentsAnalysis):
310308
coef_ : ndarray (N, d)
311309
Projection matrix from fit.
312310
"""
313-
def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1, tol=1e-6,
314-
verbose=False, device="cpu", dtype=torch.float64, rng_or_seed=None):
311+
def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1,
312+
chunk_cov_estimate=None, tol=1e-6, verbose=False, device="cpu",
313+
dtype=torch.float64, rng_or_seed=None):
315314

316315
super(SingleProjectionComponentsAnalysis,
317-
self).__init__(T=T, init=init, n_init=n_init, stride=stride, tol=tol,
316+
self).__init__(T=T, init=init, n_init=n_init, stride=stride,
317+
chunk_cov_estimate=chunk_cov_estimate, tol=tol,
318318
verbose=verbose, device=device, dtype=dtype, rng_or_seed=rng_or_seed)
319319

320320
self.d = d

dca/dca.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ class DynamicalComponentsAnalysis(SingleProjectionComponentsAnalysis):
8282
Number of samples to skip when estimating cross covariance matrices. Setting stride > 1
8383
will speed up covariance estimation but may reduce the quality of the covariance estimate
8484
for small datasets.
85+
chunk_cov_estimate : None or int
86+
If `None`, cov is estimated from entire time series. If an `int`, cov is estimated
87+
by chunking up time series and averaging covariances from chunks. This can use less memory
88+
and be faster for long timeseries. Requires that the length of the shortest timeseries
89+
in the batch is longer than `2 * T * chunk_cov_estimate`.
8590
tol : float
8691
Tolerance for stopping optimization. Default is 1e-6.
8792
ortho_lambda : float
@@ -93,11 +98,6 @@ class DynamicalComponentsAnalysis(SingleProjectionComponentsAnalysis):
9398
block_toeplitz : bool
9499
If True, uses the block-Toeplitz logdet algorithm which is typically faster and less
95100
memory intensive on cpu for `T >~ 10` and `d >~ 40`.
96-
chunk_cov_estimate : None or int
97-
If `None`, cov is estimated from entire time series. If an `int`, cov is estimated
98-
by chunking up time series and averaging covariances from chucks. This can use less memory
99-
and be faster for long timeseries. Requires that the length of the shortest timeseries
100-
in the batch is longer than `2 * T * chunk_cov_estimate`.
101101
device : str
102102
What device to run the computation on in Pytorch.
103103
dtype : torch.dtype
@@ -120,18 +120,16 @@ class DynamicalComponentsAnalysis(SingleProjectionComponentsAnalysis):
120120
coef_ : ndarray (N, d)
121121
Projection matrix from fit.
122122
"""
123-
def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1, tol=1e-6,
124-
ortho_lambda=10., verbose=False, block_toeplitz=None,
125-
chunk_cov_estimate=None, device="cpu", dtype=torch.float64, rng_or_seed=None):
123+
def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1,
124+
chunk_cov_estimate=None, tol=1e-6, ortho_lambda=10., verbose=False,
125+
block_toeplitz=None, device="cpu", dtype=torch.float64, rng_or_seed=None):
126126

127127
super(DynamicalComponentsAnalysis,
128-
self).__init__(d=d, T=T, init=init, n_init=n_init, stride=stride, tol=tol,
129-
verbose=verbose, device=device, dtype=dtype, rng_or_seed=rng_or_seed)
128+
self).__init__(d=d, T=T, init=init, n_init=n_init, stride=stride,
129+
chunk_cov_estimate=chunk_cov_estimate, tol=tol, verbose=verbose,
130+
device=device, dtype=dtype, rng_or_seed=rng_or_seed)
130131

131132
self.ortho_lambda = ortho_lambda
132-
self.chunk_cov_estimate = chunk_cov_estimate
133-
self.d = d
134-
self.d_fit = None
135133
if block_toeplitz is None:
136134
try:
137135
if d > 40 and T > 10:
@@ -143,7 +141,6 @@ def __init__(self, d=None, T=None, init="random_ortho", n_init=1, stride=1, tol=
143141
else:
144142
self.block_toeplitz = block_toeplitz
145143
self.cross_covs = None
146-
self.coef_ = None
147144

148145
def estimate_data_statistics(self, X, T=None, regularization=None, reg_ops=None):
149146
"""Estimate the cross covariance matrix from data.

0 commit comments

Comments
 (0)