@@ -204,14 +204,15 @@ def embed_gp(T, N, d, kernel, noise_cov, T_pi, num_to_concat=1):
204204 return X, Y, U, full_pi, embedding_pi, high_d_cross_cov_mats
205205
206206
207- def gen_lorenz_system(T, integration_dt=0.005):
207+ def gen_lorenz_system(T, seed, integration_dt=0.005):
208208 """
209209 Period ~ 1 unit of time (total time is T)
210210 So make sure integration_dt << 1
211211
212212 Known-to-be-good chaotic parameters
213213 See sussillo LFADS paper
214214 """
215+ rng = np.random.RandomState(seed)
215216 rho = 28.0
216217 sigma = 10.0
217218 beta = 8 / 3.
@@ -223,23 +224,23 @@ def dx_dt(state, t):
223224 z_dot = x * y - beta * z
224225 return (x_dot, y_dot, z_dot)
225226
226- x_0 = np.ones(3)
227+ x_0 = rng.randn(3)
227228 t = np.arange(0, T, integration_dt)
228229 X = scipy.integrate.odeint(dx_dt, x_0, t)
229230 return X
230231
231232
232- def gen_lorenz_data(num_samples, normalize=True):
233+ def gen_lorenz_data(num_samples, normalize=True, seed=20210610):
233234 integration_dt = 0.005
234235 data_dt = 0.025
235236 skipped_samples = 1000
236- T = (num_samples + skipped_samples) * data_dt
237- X = gen_lorenz_system(T, integration_dt)
237+ T = (num_samples + 2 * skipped_samples) * data_dt
238+ X = gen_lorenz_system(T, seed, integration_dt)
238239 if normalize:
239240 X -= X.mean(axis=0)
240241 X /= X.std(axis=0)
241- X_dwn = resample(X, num_samples + skipped_samples, axis=0)
242- X_dwn = X_dwn[skipped_samples:, :]
242+ X_dwn = resample(X, num_samples + 2 * skipped_samples, axis=0)
243+ X_dwn = X_dwn[skipped_samples:-skipped_samples]
243244 return X_dwn
244245
245246
0 commit comments