1+ import jax .numpy as jnp
2+ from maxdiffusion .schedulers .scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler
3+ import os
4+ from huggingface_hub import hf_hub_download
5+ import torch
6+ import unittest
7+ from absl .testing import absltest
8+ from absl import flags # Import absl.flags
9+ import numpy as np
10+ import torch
11+
# Command-line flag for the directory where the scheduler checkpoint/config
# is cached. absltest parses this automatically at startup.
FLAGS = flags.FLAGS
flags.DEFINE_string('models_dir', None, 'Directory to load scheduler config.')
flags.mark_flag_as_required('models_dir')
16+
17+
18+
class rfTest(unittest.TestCase):
    """End-to-end check of ``FlaxRectifiedFlowMultistepScheduler``.

    Runs a short deterministic denoising loop (torch-seeded noise for
    bit-exact parity with the PyTorch reference implementation) and compares
    each intermediate sample against pre-saved reference outputs in
    ``rf_scheduler_test_ref/step_XX.npy`` when those files are present.
    """

    def _download_checkpoint(self, models_dir):
        """Fetch the LTX-Video checkpoint used to build the scheduler config.

        Args:
            models_dir: Local directory (from ``--models_dir``) used as the
                download cache; created if it does not exist.

        Returns:
            Filesystem path of the downloaded safetensors checkpoint.
        """
        os.makedirs(models_dir, exist_ok=True)
        return hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename="ltxv-13b-0.9.7-dev.safetensors",
            local_dir=models_dir,
            repo_type="model",
        )

    def _assert_matches_reference(self, step_index, sample):
        """Compare ``sample`` against the saved PyTorch reference for one step.

        Keeps the original best-effort behavior: if the reference ``.npy``
        file is missing, a warning is printed instead of failing the test.

        Args:
            step_index: 1-based denoising step number (used in the filename).
            sample: JAX array produced by the scheduler at this step.
        """
        ref_dir = os.path.join(os.path.dirname(__file__), "rf_scheduler_test_ref")
        ref_filename = os.path.join(ref_dir, f"step_{step_index:02d}.npy")
        if os.path.exists(ref_filename):
            pt_sample = np.load(ref_filename)
            # Convert to numpy so torch.testing can compare against the
            # reference array with its default dtype-aware tolerances.
            torch.testing.assert_close(np.array(sample), pt_sample)
        else:
            print(f"Warning: Reference file not found: {ref_filename}")

    def test_rf_steps(self):
        """Simulate a 5-step rectified-flow denoising loop and verify each step."""
        # --- Simulation parameters ---
        # NOTE(review): shape is 3-D — presumably (batch, tokens, channels)
        # for a video-latent token sequence; the original comment claimed
        # (Batch, Channels, Height, Width), which does not match. Confirm.
        latent_tensor_shape = (1, 256, 128)
        inference_steps_count = 5

        # Download the checkpoint into the flag-provided directory.
        models_dir = FLAGS.models_dir
        ltxv_model_path = self._download_checkpoint(models_dir)

        print("\n--- Simulating RectifiedFlowMultistepScheduler ---")

        seed = 42
        device = 'cpu'
        print(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}")

        # Torch generator so the noise matches the PyTorch reference run.
        generator = torch.Generator(device=device).manual_seed(seed)

        # 1. Instantiate the scheduler from the downloaded checkpoint.
        flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_pretrained_jax(ltxv_model_path)

        # 2. Create and configure the scheduler state for this run.
        flax_state = flax_scheduler.create_state()
        flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape)
        print("\nScheduler initialized.")
        print(f"  flax_state timesteps shape: {flax_state.timesteps.shape}")

        # 3. Initial noisy latent sample (would be N(0,1) noise in real use),
        #    generated with torch and handed to JAX.
        sample = jnp.array(
            torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()
        )
        print(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}")

        # 4. Denoising loop: one scheduler step per timestep, checking each
        #    intermediate sample against the stored PyTorch reference.
        print("\nStarting denoising loop:")
        for i, t in enumerate(flax_state.timesteps):
            print(f"  Step {i + 1}/{inference_steps_count}, Timestep: {t.item()}")

            # Stand-in for a model (e.g. UNet/transformer) noise prediction.
            model_output = jnp.array(
                torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()
            )

            scheduler_output = flax_scheduler.step(
                state=flax_state,
                model_output=model_output,
                timestep=t,  # current timestep from the scheduler's sequence
                sample=sample,
                return_dict=True,  # return a SchedulerOutput dataclass
            )

            sample = scheduler_output.prev_sample  # updated latent for next step
            flax_state = scheduler_output.state  # updated scheduler state

            self._assert_matches_reference(i + 1, sample)

        print("\nDenoising loop completed.")
        print(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}")
        print(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}")

        print("\nSimulation of RectifiedMultistepScheduler usage complete.")
105+
106+
if __name__ == "__main__":
    # absltest.main() parses the absl-defined flags (e.g. --models_dir)
    # before dispatching to the unittest runner.
    absltest.main()