Skip to content

Commit a94d80a

Browse files
authored
Merge pull request #61 from JesseLivezey/synth_edits
more careful rng in synth
2 parents a4d21d3 + ece7024 commit a94d80a

1 file changed

Lines changed: 8 additions & 7 deletions

File tree

dca/synth_data.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,15 @@ def embed_gp(T, N, d, kernel, noise_cov, T_pi, num_to_concat=1):
204204
return X, Y, U, full_pi, embedding_pi, high_d_cross_cov_mats
205205

206206

207-
def gen_lorenz_system(T, integration_dt=0.005):
207+
def gen_lorenz_system(T, seed, integration_dt=0.005):
208208
"""
209209
Period ~ 1 unit of time (total time is T)
210210
So make sure integration_dt << 1
211211
212212
Known-to-be-good chaotic parameters
213213
See sussillo LFADS paper
214214
"""
215+
rng = np.random.RandomState(seed)
215216
rho = 28.0
216217
sigma = 10.0
217218
beta = 8 / 3.
@@ -223,23 +224,23 @@ def dx_dt(state, t):
223224
z_dot = x * y - beta * z
224225
return (x_dot, y_dot, z_dot)
225226

226-
x_0 = np.ones(3)
227+
x_0 = rng.randn(3)
227228
t = np.arange(0, T, integration_dt)
228229
X = scipy.integrate.odeint(dx_dt, x_0, t)
229230
return X
230231

231232

232-
def gen_lorenz_data(num_samples, normalize=True):
233+
def gen_lorenz_data(num_samples, normalize=True, seed=20210610):
233234
integration_dt = 0.005
234235
data_dt = 0.025
235236
skipped_samples = 1000
236-
T = (num_samples + skipped_samples) * data_dt
237-
X = gen_lorenz_system(T, integration_dt)
237+
T = (num_samples + 2 * skipped_samples) * data_dt
238+
X = gen_lorenz_system(T, seed, integration_dt)
238239
if normalize:
239240
X -= X.mean(axis=0)
240241
X /= X.std(axis=0)
241-
X_dwn = resample(X, num_samples + skipped_samples, axis=0)
242-
X_dwn = X_dwn[skipped_samples:, :]
242+
X_dwn = resample(X, num_samples + 2 * skipped_samples, axis=0)
243+
X_dwn = X_dwn[skipped_samples:-skipped_samples]
243244
return X_dwn
244245

245246

0 commit comments

Comments
 (0)