Skip to content

Commit 35076d3

Browse files
ninatumartinarroyo
andcommitted
Add gradient clipping options to optimizer
Introduces options for clipping gradients by global norm or by value, configurable via `config.opt_enable_grad_global_norm_clipping` and `config.opt_enable_grad_clipping`, as well as `config.max_grad_norm` and `config.max_grad_value`. Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 46cae70 commit 35076d3

14 files changed

Lines changed: 46 additions & 1 deletion

src/maxdiffusion/configs/base14.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
206206
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
207207
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
208208
adam_weight_decay: 1.e-2 # AdamW Weight decay
209+
opt_enable_grad_clipping: False
210+
max_grad_value: 1.0
211+
opt_enable_grad_global_norm_clipping: False
209212
max_grad_norm: 1.0
210213

211214
enable_profiler: False

src/maxdiffusion/configs/base21.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
211211
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
212212
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
213213
adam_weight_decay: 1.e-2 # AdamW Weight decay
214+
opt_enable_grad_clipping: False
215+
max_grad_value: 1.0
216+
opt_enable_grad_global_norm_clipping: False
214217
max_grad_norm: 1.0
215218

216219
enable_profiler: False

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
221221
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
222222
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
223223
adam_weight_decay: 1.e-2 # AdamW Weight decay
224+
opt_enable_grad_clipping: False
225+
max_grad_value: 1.0
226+
opt_enable_grad_global_norm_clipping: False
224227
max_grad_norm: 1.0
225228

226229
enable_profiler: False

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
245245
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
246246
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
247247
adam_weight_decay: 0 # AdamW Weight decay
248+
opt_enable_grad_clipping: False
249+
max_grad_value: 1.0
250+
opt_enable_grad_global_norm_clipping: False
248251
max_grad_norm: 1.0
249252

250253
enable_profiler: False

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
232232
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
233233
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
234234
adam_weight_decay: 1.e-2 # AdamW Weight decay
235+
opt_enable_grad_clipping: False
236+
max_grad_value: 1.0
237+
opt_enable_grad_global_norm_clipping: False
235238
max_grad_norm: 1.0
236239

237240
enable_profiler: False

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
240240
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
241241
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
242242
adam_weight_decay: 1.e-2 # AdamW Weight decay
243+
opt_enable_grad_clipping: False
244+
max_grad_value: 1.0
245+
opt_enable_grad_global_norm_clipping: False
243246
max_grad_norm: 1.0
244247

245248
enable_profiler: False

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
302302
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
303303
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
304304
adam_weight_decay: 0.0 # AdamW Weight decay
305+
opt_enable_grad_clipping: False
306+
max_grad_value: 1.0
307+
opt_enable_grad_global_norm_clipping: False
305308
max_grad_norm: 1.0
306309

307310
enable_profiler: False

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
258258
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
259259
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
260260
adam_weight_decay: 0.0 # AdamW Weight decay
261+
opt_enable_grad_clipping: False
262+
max_grad_value: 1.0
263+
opt_enable_grad_global_norm_clipping: False
261264
max_grad_norm: 1.0
262265

263266
enable_profiler: False

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
268268
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
269269
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
270270
adam_weight_decay: 0.0 # AdamW Weight decay
271+
opt_enable_grad_clipping: False
272+
max_grad_value: 1.0
273+
opt_enable_grad_global_norm_clipping: False
271274
max_grad_norm: 1.0
272275

273276
enable_profiler: False

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradient
263263
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
264264
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
265265
adam_weight_decay: 0.0 # AdamW Weight decay
266+
opt_enable_grad_clipping: False
267+
max_grad_value: 1.0
268+
opt_enable_grad_global_norm_clipping: False
266269
max_grad_norm: 1.0
267270

268271
enable_profiler: False

0 commit comments

Comments
 (0)