
Commit e205aa1

ninatu and martinarroyo committed
Wan training: Resolve training mode bug with dropout and layer_forward
- Conditionally apply dropout only when rate > 0.
- Use standard list initialization.
- Add rngs parameter to layer_forward (essential for gradient checkpointing with dropout > 0).

Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent fb25b23 commit e205aa1

3 files changed

Lines changed: 15 additions & 7 deletions
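
The guard introduced in attention_flax.py and transformer_wan.py below is one pattern: call self.drop_out only when its rate is greater than zero, so a zero-rate module never needs a dropout PRNG stream in training mode. The following is a minimal sketch of that pattern in Flax NNX; FeedForward is a hypothetical stand-in, not the actual maxdiffusion module.

import jax.numpy as jnp
from flax import nnx


class FeedForward(nnx.Module):
  """Hypothetical module illustrating the rate > 0 guard from this commit."""

  def __init__(self, dim: int, dropout: float, rngs: nnx.Rngs):
    self.proj = nnx.Linear(dim, dim, rngs=rngs)
    self.drop_out = nnx.Dropout(rate=dropout)

  def __call__(self, hidden_states, deterministic: bool = True, rngs: nnx.Rngs = None):
    hidden_states = self.proj(hidden_states)
    # Skip dropout entirely when the rate is zero; only an active dropout
    # needs a PRNG stream when deterministic=False.
    if self.drop_out.rate > 0:
      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
    return hidden_states


layer = FeedForward(dim=8, dropout=0.1, rngs=nnx.Rngs(0))
x = jnp.ones((2, 8))
y_train = layer(x, deterministic=False, rngs=nnx.Rngs(dropout=1))  # dropout applied
y_eval = layer(x, deterministic=True)                               # dropout skipped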


src/maxdiffusion/models/attention_flax.py

Lines changed: 4 additions & 1 deletion
@@ -1239,7 +1239,10 @@ def __call__(
 
     with jax.named_scope("proj_attn"):
       hidden_states = self.proj_attn(attn_output)
-      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+      if self.drop_out.rate > 0:
+        hidden_states = self.drop_out(
+            hidden_states, deterministic=deterministic, rngs=rngs
+        )
     return hidden_states

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 4 additions & 1 deletion
@@ -262,7 +262,10 @@ def conditional_named_scope(self, name: str):
   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
     hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
     hidden_states = checkpoint_name(hidden_states, "ffn_activation")
-    hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
+    if self.drop_out.rate > 0:
+      hidden_states = self.drop_out(
+          hidden_states, deterministic=deterministic, rngs=rngs
+      )
     with jax.named_scope("proj_out"):
       return self.proj_out(hidden_states)  # output is (4, 75600, 5120)

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 7 additions & 5 deletions
@@ -487,10 +487,10 @@ def __call__(
       raise NotImplementedError("scan_layers is not supported yet")
     else:
       # Prepare VACE hints
-      control_hidden_states_list = nnx.List([])
+      control_hidden_states_list = []
       for i, vace_block in enumerate(self.vace_blocks):
 
-        def layer_forward(hidden_states, control_hidden_states):
+        def layer_forward(hidden_states, control_hidden_states, rngs):
           return vace_block(
               hidden_states=hidden_states,
               encoder_hidden_states=encoder_hidden_states,
@@ -507,14 +507,16 @@ def layer_forward(hidden_states, control_hidden_states):
             self.names_which_can_be_offloaded,
             prevent_cse=not self.scan_layers,
         )
-        conditioning_states, control_hidden_states = rematted_layer_forward(hidden_states, control_hidden_states)
+        conditioning_states, control_hidden_states = rematted_layer_forward(
+            hidden_states, control_hidden_states, rngs
+        )
         control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
 
       control_hidden_states_list = control_hidden_states_list[::-1]
 
       for i, block in enumerate(self.blocks):
 
-        def layer_forward_vace(hidden_states):
+        def layer_forward_vace(hidden_states, rngs):
           return block(
               hidden_states,
               encoder_hidden_states,
@@ -530,7 +532,7 @@ def layer_forward_vace(hidden_states):
             self.names_which_can_be_offloaded,
             prevent_cse=not self.scan_layers,
         )
-        hidden_states = rematted_layer_forward(hidden_states)
+        hidden_states = rematted_layer_forward(hidden_states, rngs)
         if i in self.config.vace_layers:
           control_hint, scale = control_hidden_states_list.pop()
           hidden_states = hidden_states + control_hint * scale
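
The transformer_wan_vace.py hunks thread rngs through layer_forward and layer_forward_vace as an explicit argument rather than a closure, which the commit message notes is essential for gradient checkpointing with dropout > 0. Below is a minimal sketch of that threading, assuming flax.nnx.remat in place of maxdiffusion's own remat helper and offload policy; Block and layer_forward are hypothetical stand-ins for the Wan/VACE blocks.

import jax.numpy as jnp
from flax import nnx


class Block(nnx.Module):
  """Hypothetical block using the same dropout guard as this commit."""

  def __init__(self, dim: int, dropout: float, rngs: nnx.Rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs)
    self.drop_out = nnx.Dropout(rate=dropout)

  def __call__(self, hidden_states, deterministic: bool = True, rngs: nnx.Rngs = None):
    hidden_states = self.linear(hidden_states)
    if self.drop_out.rate > 0:
      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
    return hidden_states


def layer_forward(block, hidden_states, rngs):
  # rngs arrives as an argument of the checkpointed function, so dropout
  # inside the rematted region has a PRNG stream to draw from.
  return block(hidden_states, deterministic=False, rngs=rngs)


block = Block(dim=8, dropout=0.1, rngs=nnx.Rngs(0))
rematted_layer_forward = nnx.remat(layer_forward)
out = rematted_layer_forward(block, jnp.ones((2, 8)), nnx.Rngs(dropout=1))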
