Skip to content

Commit 979295c

Browse files
committed
rectified flow scheduler test added
1 parent a3c35ce commit 979295c

1 file changed

Lines changed: 4 additions & 14 deletions

File tree

tests/schedulers/test_scheduler_rf.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,25 @@
55
import torch
66
import unittest
77
from absl.testing import absltest
8-
from absl import flags # Import absl.flags
8+
from absl import flags
99
import numpy as np
10-
import torch
1110

12-
# Define a command-line flag for models_dir
1311
FLAGS = flags.FLAGS
1412
flags.DEFINE_string('models_dir', None, 'Directory to load scheduler config.')
1513
flags.mark_flag_as_required('models_dir')
1614

1715

18-
1916
class rfTest(unittest.TestCase):
2017

2118
def test_rf_steps(self):
22-
# --- Configuration Parameters for the Scheduler ---
23-
# You can modify these parameters to test different scheduler behaviors
24-
2519
# --- Simulation Parameters ---
2620
latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width)
2721
inference_steps_count = 5 # Number of steps for the denoising process
2822

2923
# --- Run the Simulation ---
3024
# Use the value from the command-line flag
3125
models_dir = FLAGS.models_dir
32-
26+
3327
# Ensure the directory exists before downloading
3428
os.makedirs(models_dir, exist_ok=True)
3529

@@ -39,7 +33,7 @@ def test_rf_steps(self):
3933
local_dir=models_dir,
4034
repo_type="model",
4135
)
42-
print(f"\n--- Simulating RectifiedFlowMultistepScheduler ---")
36+
print("\n--- Simulating RectifiedFlowMultistepScheduler ---")
4337

4438
seed = 42
4539
device = 'cpu'
@@ -87,9 +81,6 @@ def test_rf_steps(self):
8781
base_dir = os.path.dirname(__file__)
8882
ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref")
8983
ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy")
90-
# Ensure the reference directory exists for tests that might write to it,
91-
# or handle its absence if it's meant to be pre-existing.
92-
# For this example, assuming 'rf_scheduler_test_ref' exists with pre-saved .npy files.
9384
if os.path.exists(ref_filename):
9485
pt_sample = np.load(ref_filename)
9586
torch.testing.assert_close(np.array(sample), pt_sample)
@@ -105,5 +96,4 @@ def test_rf_steps(self):
10596

10697

10798
if __name__ == "__main__":
108-
# absltest.main() automatically parses flags defined by absl.flags
109-
absltest.main()
99+
absltest.main()

0 commit comments

Comments
 (0)