Commit 1fe4ce0

ninatu and martinarroyo committed
Wan training: use learning rate from config
Replaces the hardcoded learning rate in the optimizer creation with the value from `config.learning_rate`. Co-authored-by: martinarroyo <martinarroyo@google.com>
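The change itself is small but worth spelling out: the third argument to `_create_optimizer` was the literal `1e-5`, so setting `learning_rate` in the config had no effect on training. A minimal sketch of the before/after behavior (the `TrainConfig` dataclass and `make_optimizer` helper below are hypothetical stand-ins for illustration, not the actual maxdiffusion API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainConfig:
    # Hypothetical stand-in for maxdiffusion's config object.
    learning_rate: float = 2e-4

def make_optimizer(config: TrainConfig, lr_override: Optional[float] = None) -> dict:
    # Mirrors the bug: if a literal is passed instead of the config
    # value, the configured learning rate is silently ignored.
    lr = lr_override if lr_override is not None else config.learning_rate
    return {"learning_rate": lr}

config = TrainConfig(learning_rate=3e-4)

buggy = make_optimizer(config, lr_override=1e-5)  # hardcoded, ignores config
fixed = make_optimizer(config)                    # reads config.learning_rate

assert buggy["learning_rate"] == 1e-5
assert fixed["learning_rate"] == 3e-4
```

This kind of silent override is easy to miss because training still runs; it just converges at the wrong rate.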
1 parent e205aa1 commit 1fe4ce0

1 file changed: 3 additions & 1 deletion

src/maxdiffusion/trainers/wan_trainer.py
```diff
@@ -257,7 +257,9 @@ def start_training(self):
     scheduler, scheduler_state = self.create_scheduler()
     pipeline.scheduler = scheduler
     pipeline.scheduler_state = scheduler_state
-    optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer(pipeline.transformer, self.config, 1e-5)
+    optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer(
+        pipeline.transformer, self.config, self.config.learning_rate
+    )
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args)
```
