Skip to content

Commit fb25b23

Browse files
ninatumartinarroyo
andcommitted
Update wan configs for training
- Ensure `adam_weight_decay` is a float. - Add `tensorboard_dir` parameter for logging. Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 4afed9f commit fb25b23

5 files changed

Lines changed: 10 additions & 5 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ diffusion_scheduler_config: {
145145
# Output directory
146146
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
147147
base_output_directory: ""
148+
tensorboard_dir: ""
148149

149150
# Hardware
150151
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
@@ -300,7 +301,7 @@ save_optimizer: False
300301
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
301302
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
302303
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
303-
adam_weight_decay: 0 # AdamW Weight decay
304+
adam_weight_decay: 0.0 # AdamW Weight decay
304305
max_grad_norm: 1.0
305306

306307
enable_profiler: False

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ diffusion_scheduler_config: {
122122
# Output directory
123123
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
124124
base_output_directory: ""
125+
tensorboard_dir: ""
125126

126127
# Hardware
127128
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
@@ -256,7 +257,7 @@ save_optimizer: False
256257
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
257258
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
258259
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
259-
adam_weight_decay: 0 # AdamW Weight decay
260+
adam_weight_decay: 0.0 # AdamW Weight decay
260261
max_grad_norm: 1.0
261262

262263
enable_profiler: False

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ diffusion_scheduler_config: {
133133
# Output directory
134134
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
135135
base_output_directory: ""
136+
tensorboard_dir: ""
136137

137138
# Hardware
138139
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
@@ -267,7 +268,7 @@ save_optimizer: False
267268
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
268269
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
269270
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
270-
adam_weight_decay: 0 # AdamW Weight decay
271+
adam_weight_decay: 0.0 # AdamW Weight decay
271272
max_grad_norm: 1.0
272273

273274
enable_profiler: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ diffusion_scheduler_config: {
128128
# Output directory
129129
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
130130
base_output_directory: ""
131+
tensorboard_dir: ""
131132

132133
# Hardware
133134
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
@@ -262,7 +263,7 @@ save_optimizer: False
262263
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
263264
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
264265
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
265-
adam_weight_decay: 0 # AdamW Weight decay
266+
adam_weight_decay: 0.0 # AdamW Weight decay
266267
max_grad_norm: 1.0
267268

268269
enable_profiler: False

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ diffusion_scheduler_config: {
129129
# Output directory
130130
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
131131
base_output_directory: ""
132+
tensorboard_dir: ""
132133

133134
# Hardware
134135
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
@@ -263,7 +264,7 @@ save_optimizer: False
263264
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
264265
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
265266
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
266-
adam_weight_decay: 0 # AdamW Weight decay
267+
adam_weight_decay: 0.0 # AdamW Weight decay
267268
max_grad_norm: 1.0
268269

269270
enable_profiler: False

0 commit comments

Comments
 (0)