55import torch
66import unittest
77from absl .testing import absltest
8- from absl import flags # Import absl.flags
8+ from absl import flags
99import numpy as np
10- import torch
1110
12- # Define a command-line flag for models_dir
1311FLAGS = flags .FLAGS
1412flags .DEFINE_string ('models_dir' , None , 'Directory to load scheduler config.' )
1513flags .mark_flag_as_required ('models_dir' )
1614
1715
18-
1916class rfTest (unittest .TestCase ):
2017
2118 def test_rf_steps (self ):
22- # --- Configuration Parameters for the Scheduler ---
23- # You can modify these parameters to test different scheduler behaviors
24-
2519 # --- Simulation Parameters ---
2620 latent_tensor_shape = (1 , 256 , 128 ) # Example latent tensor shape (Batch, Channels, Height, Width)
2721 inference_steps_count = 5 # Number of steps for the denoising process
2822
2923 # --- Run the Simulation ---
3024 # Use the value from the command-line flag
3125 models_dir = FLAGS .models_dir
32-
26+
3327 # Ensure the directory exists before downloading
3428 os .makedirs (models_dir , exist_ok = True )
3529
@@ -39,7 +33,7 @@ def test_rf_steps(self):
3933 local_dir = models_dir ,
4034 repo_type = "model" ,
4135 )
42- print (f "\n --- Simulating RectifiedFlowMultistepScheduler ---" )
36+ print ("\n --- Simulating RectifiedFlowMultistepScheduler ---" )
4337
4438 seed = 42
4539 device = 'cpu'
@@ -87,9 +81,6 @@ def test_rf_steps(self):
8781 base_dir = os .path .dirname (__file__ )
8882 ref_dir = os .path .join (base_dir , "rf_scheduler_test_ref" )
8983 ref_filename = os .path .join (ref_dir , f"step_{ i + 1 :02d} .npy" )
90- # Ensure the reference directory exists for tests that might write to it,
91- # or handle its absence if it's meant to be pre-existing.
92- # For this example, assuming 'rf_scheduler_test_ref' exists with pre-saved .npy files.
9384 if os .path .exists (ref_filename ):
9485 pt_sample = np .load (ref_filename )
9586 torch .testing .assert_close (np .array (sample ), pt_sample )
@@ -105,5 +96,4 @@ def test_rf_steps(self):
10596
10697
10798if __name__ == "__main__" :
108- # absltest.main() automatically parses flags defined by absl.flags
109- absltest .main ()
99+ absltest .main ()
0 commit comments