-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_cellular.py
More file actions
89 lines (74 loc) · 2.78 KB
/
run_cellular.py
File metadata and controls
89 lines (74 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import src.patch
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from src.config import SimConfig, ForceConfig
from src.state import ParticleState
from src.flow import flow_cellular, temp_constant
from src.boundary import BoundaryManager
from src.solver import run_simulation_euler
def main():
    """Run a 2-D cellular-flow particle simulation and render it to a video.

    Pipeline:
      1. Build a ``SimConfig`` from Maxey's parameters (``W``, ``A``) with a
         fixed flow scale ``alpha``; ``U_0`` and the particle diameter are
         derived by ``SimConfig.from_maxey`` (see the printed summary).
      2. Seed particles on a regular square grid inside the periodic box
         ``[0, L] x [0, L]`` with ``L = 4*pi*alpha``. Each particle starts with
         the velocity returned by ``flow_cellular`` at its position
         (presumably the local fluid velocity -- confirm in ``src.flow``).
      3. Integrate with ``run_simulation_euler`` on the time grid
         ``t = 0, dt, 2*dt, ...`` up to but excluding ``t_end``.
      4. Render ``cellular_flow.mp4`` with ``JAXVisualizer``.

    Raises:
        ValueError: if the particle count or seeding margin leaves no room
            for a valid initial grid (only possible after editing the
            constants below).
    """
    import math

    # Fix alpha (domain size) and derive U_0 based on Maxey parameters.
    config = SimConfig.from_maxey(
        W=0.5,  # Settling velocity ratio
        A=0.5,  # Inertia parameter
        alpha=1.0,  # Fixed flow scale (the domain size below is tied to it)
        rho_particle=2650.0,  # Sand
        rho_fluid=1.225,  # Air
        mu_fluid=1.81e-5,  # Air
        g=-9.81,
        # Additional flags
        enable_turbulence=False,
        turbulence_intensity=0.2,
    )
    # Computed once and reused; the original recomputed and re-printed it
    # just before the simulation run.
    stokes = config.get_stokes_number()
    print("Maxey Config Generated:")
    print(f" Alpha (Fixed): {config.alpha:.4f} m")
    print(f" U_0 (Derived): {config.U_0:.4f} m/s")
    print(f" d_p (Derived): {config.d_particle:.6e} m")
    print(f" Stk (Calc): {stokes:.4f}")

    force_config = ForceConfig(gravity=True, undisturbed_flow=True, drag=True)

    # Domain & boundaries: a square periodic box whose side scales with alpha.
    L = 4 * np.pi * config.alpha
    bounds = BoundaryManager(x_bounds=(0.0, L), y_bounds=(0.0, L), periodic=True)

    # Initialization: a grid_side x grid_side lattice. If n_particles is not a
    # perfect square, the lattice holds fewer particles (floor of the root).
    n_particles = 10000
    margin = 0.1  # absolute inset from the box edges, same units as L
    grid_side = math.isqrt(n_particles)  # exact integer floor(sqrt(n))
    if grid_side < 1:
        raise ValueError(f"n_particles must be >= 1, got {n_particles}")
    if L <= 2 * margin:
        raise ValueError(
            f"domain side L={L:.4g} is too small for a seeding margin of {margin}"
        )
    axis = np.linspace(margin, L - margin, grid_side)
    mx, my = np.meshgrid(axis, axis)
    pos = jnp.array(np.stack([mx.ravel(), my.ravel()], axis=1))
    # Actual count is grid_side**2, which may differ from n_particles.
    actual_n = pos.shape[0]
    print(f"Initializing {actual_n} particles on a {grid_side}x{grid_side} grid...")
    vel = jax.vmap(lambda p: flow_cellular(p, config))(pos)

    # Every particle starts with the configured initial mass.
    mass = jnp.full((actual_n,), config.m_particle_init)
    initial_state = ParticleState(position=pos, velocity=vel, mass=mass)

    # Time grid: an integer step count avoids np.arange's float-step rounding,
    # which can add or drop a trailing sample. t_end itself is excluded,
    # matching the original np.arange(0.0, t_end, dt).
    t_end = 10.0
    dt = 0.005
    n_steps = int(round(t_end / dt))
    t_eval = jnp.array(np.arange(n_steps) * dt)

    # Fixed seed for reproducibility. PRNGKey(0) yields the legacy raw uint32
    # key ([0, 0] under the default threefry implementation) and, unlike a
    # hand-built array, respects the configured PRNG implementation.
    key = random.PRNGKey(0)

    print("Running Cellular Flow Simulation...")
    history = run_simulation_euler(
        initial_state, t_eval, config, force_config, bounds, flow_cellular, temp_constant, key
    )

    print("Generating Video (JAX Rasterizer)...")
    # Imported lazily so the visualizer is only loaded once the run finishes.
    from src.jax_visualizer import JAXVisualizer

    # NOTE(review): presumably (x_min, x_max, y_min, y_max); the box is square,
    # so the ordering cannot be confirmed from here.
    flat_bounds = (0.0, L, 0.0, L)
    viz = JAXVisualizer(config, history, t_eval, flow_cellular, temp_constant)
    viz.generate_video(
        "cellular_flow.mp4",
        bounds=flat_bounds,
        width=600,
        height=600,
        fps=20,
        slow_mo_factor=1.0,
    )
    print("Done.")


if __name__ == "__main__":
    main()