Commit 6101386

ninatu and martinarroyo committed
Wan training: Set default dropout to 0.0 in Wan configs
Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 5c6f65f commit 6101386

5 files changed: 5 additions & 5 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ mask_padding_tokens: True
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 # in cross attention q.
 attention_sharding_uniform: True
-dropout: 0.1
+dropout: 0.0

 flash_block_sizes: {
   "block_q" : 512,

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ mask_padding_tokens: True
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 # in cross attention q.
 attention_sharding_uniform: True
-dropout: 0.1
+dropout: 0.0

 flash_block_sizes: {
   "block_q" : 512,

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ mask_padding_tokens: True
 # 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
 # in cross attention q.
 attention_sharding_uniform: True
-dropout: 0.1
+dropout: 0.0

 flash_block_sizes: {
   "block_q" : 512,

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 4096
-dropout: 0.1
+dropout: 0.0

 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
 flash_min_seq_length: 4096
-dropout: 0.1
+dropout: 0.0

 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
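
For context on the effect of this change: maxdiffusion is built on JAX/Flax, and a dropout rate of 0.0 makes a Flax dropout layer an identity, so these edits disable dropout in Wan training by default. The sketch below illustrates that behavior; the Block module and the way the dropout config key reaches flax.linen.Dropout are illustrative assumptions, not code taken from the Wan model.

import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    """Hypothetical stand-in for a Wan transformer sub-layer (illustration only)."""

    dropout: float = 0.0  # mirrors the new config default

    @nn.compact
    def __call__(self, x, deterministic: bool):
        x = nn.Dense(features=x.shape[-1])(x)
        # With rate=0.0, flax.linen.Dropout returns its input unchanged even in
        # training mode (deterministic=False), so dropout is effectively off.
        return nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)


x = jnp.ones((2, 8))
block = Block()  # dropout=0.0, as in the updated configs
params = block.init(jax.random.PRNGKey(0), x, deterministic=True)
y = block.apply(params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)})
# Identity check: with rate 0.0, the "training" and "eval" outputs match.
assert jnp.allclose(y, block.apply(params, x, deterministic=True))

Anyone who wants the previous regularization back can set dropout: 0.1 in their copy of the config.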

0 commit comments
