Commit 095502a

Authored by Cristian Garcia, committed by Google-ML-Automation

change defaults for Dropout and BatchNorm
Changes `Dropout.deterministic` and `BatchNorm.use_running_average` to be None by default; users now have to explicitly provide them by either:

1. Passing them to the constructor, e.g. `self.bn = nnx.BatchNorm(..., use_running_average=False)`
2. Passing them to `__call__`, e.g. `self.dropout(x, deterministic=False)`
3. Using `nnx.view` to create a view of the model with specific values: `train_model = nnx.view(model, deterministic=False, use_running_average=False)`

PiperOrigin-RevId: 878147422
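The precedence the new defaults imply can be sketched in plain Python (this is an illustrative stand-in, not the actual Flax NNX implementation): the flag starts as None, a `__call__` argument overrides the constructor value, and if neither is given the layer must fail loudly instead of silently picking a mode.

```python
class DropoutSketch:
    """Hypothetical stand-in illustrating the None-by-default pattern."""

    def __init__(self, rate, deterministic=None):
        self.rate = rate
        # None by default after this change: no implicit train/eval mode.
        self.deterministic = deterministic

    def __call__(self, x, deterministic=None):
        # A __call__ argument takes precedence over the constructor value.
        if deterministic is None:
            deterministic = self.deterministic
        if deterministic is None:
            raise ValueError(
                "No `deterministic` value given: pass it to the constructor "
                "or to __call__."
            )
        if deterministic:
            return x  # eval mode: identity
        # train mode: rescale kept units (dropped-unit masking omitted here)
        return [v / (1.0 - self.rate) for v in x]

layer = DropoutSketch(0.5)
assert layer([2.0], deterministic=True) == [2.0]   # option 2: per-call
assert layer([2.0], deterministic=False) == [4.0]
assert DropoutSketch(0.5, deterministic=True)([3.0]) == [3.0]  # option 1: constructor
```

The maxdiffusion changes below take option 1, pinning `deterministic=False` at construction so the layers keep their previous training-mode behavior.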
1 parent 68e0696 commit 095502a

2 files changed: 2 additions & 2 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
@@ -1011,7 +1011,7 @@ def __init__(
        ),
      )

-    self.drop_out = nnx.Dropout(dropout)
+    self.drop_out = nnx.Dropout(dropout, deterministic=False)

      self.norm_q = nnx.data(None)
      self.norm_k = nnx.data(None)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ def __init__(
      else:
        raise NotImplementedError(f"{activation_fn} is not implemented.")

-    self.drop_out = nnx.Dropout(dropout)
+    self.drop_out = nnx.Dropout(dropout, deterministic=False)

      self.proj_out = nnx.Linear(
          rngs=rngs,
          in_features=inner_dim,

0 commit comments
