Skip to content

Commit a3c35ce

Browse files
committed
added rf scheduler test
1 parent 8274aca commit a3c35ce

7 files changed

Lines changed: 109 additions & 0 deletions

File tree

128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
128 KB
Binary file not shown.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import jax.numpy as jnp
2+
from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler
3+
import os
4+
from huggingface_hub import hf_hub_download
5+
import torch
6+
import unittest
7+
from absl.testing import absltest
8+
from absl import flags # Import absl.flags
9+
import numpy as np
10+
import torch
11+
12+
# Define a command-line flag for models_dir
13+
FLAGS = flags.FLAGS
14+
flags.DEFINE_string('models_dir', None, 'Directory to load scheduler config.')
15+
flags.mark_flag_as_required('models_dir')
16+
17+
18+
19+
class rfTest(unittest.TestCase):
20+
21+
def test_rf_steps(self):
22+
# --- Configuration Parameters for the Scheduler ---
23+
# You can modify these parameters to test different scheduler behaviors
24+
25+
# --- Simulation Parameters ---
26+
latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width)
27+
inference_steps_count = 5 # Number of steps for the denoising process
28+
29+
# --- Run the Simulation ---
30+
# Use the value from the command-line flag
31+
models_dir = FLAGS.models_dir
32+
33+
# Ensure the directory exists before downloading
34+
os.makedirs(models_dir, exist_ok=True)
35+
36+
ltxv_model_path = hf_hub_download(
37+
repo_id="Lightricks/LTX-Video",
38+
filename="ltxv-13b-0.9.7-dev.safetensors",
39+
local_dir=models_dir,
40+
repo_type="model",
41+
)
42+
print(f"\n--- Simulating RectifiedFlowMultistepScheduler ---")
43+
44+
seed = 42
45+
device = 'cpu'
46+
print(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}")
47+
48+
generator = torch.Generator(device=device).manual_seed(seed)
49+
50+
# 1. Instantiate the scheduler
51+
flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ltxv_model_path)
52+
53+
# 2. Create and set initial state for the scheduler
54+
flax_state = flax_scheduler.create_state()
55+
flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape)
56+
print("\nScheduler initialized.")
57+
print(f" flax_state timesteps shape: {flax_state.timesteps.shape}")
58+
59+
# 3. Prepare the initial noisy latent sample
60+
# In a real scenario, this would typically be pure random noise (e.g., N(0,1))
61+
# For simulation, we'll generate it.
62+
63+
sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())
64+
print(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}")
65+
66+
# 4. Simulate the denoising loop
67+
print("\nStarting denoising loop:")
68+
for i, t in enumerate(flax_state.timesteps):
69+
print(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}")
70+
71+
# Simulate model_output (e.g., noise prediction from a UNet)
72+
model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())
73+
74+
# Call the scheduler's step function
75+
scheduler_output = flax_scheduler.step(
76+
state=flax_state,
77+
model_output=model_output,
78+
timestep=t, # Pass the current timestep from the scheduler's sequence
79+
sample=sample,
80+
return_dict=True # Return a SchedulerOutput dataclass
81+
)
82+
83+
sample = scheduler_output.prev_sample # Update the sample for the next step
84+
flax_state = scheduler_output.state # Update the state for the next step
85+
86+
# Compare with pytorch implementation
87+
base_dir = os.path.dirname(__file__)
88+
ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref")
89+
ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy")
90+
# Ensure the reference directory exists for tests that might write to it,
91+
# or handle its absence if it's meant to be pre-existing.
92+
# For this example, assuming 'rf_scheduler_test_ref' exists with pre-saved .npy files.
93+
if os.path.exists(ref_filename):
94+
pt_sample = np.load(ref_filename)
95+
torch.testing.assert_close(np.array(sample), pt_sample)
96+
else:
97+
print(f"Warning: Reference file not found: {ref_filename}")
98+
99+
100+
print("\nDenoising loop completed.")
101+
print(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}")
102+
print(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}")
103+
104+
print("\nSimulation of RectifiedMultistepScheduler usage complete.")
105+
106+
107+
if __name__ == "__main__":
108+
# absltest.main() automatically parses flags defined by absl.flags
109+
absltest.main()

0 commit comments

Comments
 (0)