Commit e205aa1
Wan training: Resolve training mode bug with dropout and layer_forward
- Conditionally apply dropout only when rate > 0.
- Use standard list initialization.
- Add rngs parameter to layer_forward (essential for gradient checkpointing with dropout > 0).
Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent fb25b23 commit e205aa1
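The two behaviours named in the commit message can be illustrated with a minimal sketch in plain JAX. This is not the code from this commit: the `dropout`, `layer_forward`, and `loss_fn` bodies below are hypothetical, and an explicit PRNG key stands in for the `rngs` parameter the commit adds. Passing the randomness in as an argument means that the forward pass recomputed by `jax.checkpoint` during the backward sweep draws the same dropout mask as the original forward pass.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate, key):
  """Inverted dropout; returns x untouched when rate == 0, so no key is needed."""
  if rate > 0.0:  # Python-level check: rate is a static hyperparameter.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)
  return x


def layer_forward(w, x, key, rate):
  """Hypothetical layer: linear map, tanh, then optional dropout."""
  return dropout(jnp.tanh(x @ w), rate, key)


# Gradient checkpointing: activations are recomputed in the backward pass.
# `rate` is marked static so the `if rate > 0.0` branch resolves at trace time.
checkpointed_layer = jax.checkpoint(layer_forward, static_argnums=(3,))


def loss_fn(w, x, key, rate):
  return jnp.sum(checkpointed_layer(w, x, key, rate) ** 2)


key = jax.random.PRNGKey(0)
w = jnp.eye(4)
x = jnp.ones((2, 4))

# With dropout enabled, the key is threaded through the checkpointed call.
grads_with_dropout = jax.grad(loss_fn)(w, x, key, 0.5)

# With rate == 0 the dropout branch is skipped, so no key is required.
grads_no_dropout = jax.grad(loss_fn)(w, x, None, 0.0)
```

Because the key is an ordinary input, the checkpointed gradient should match the gradient of the same layer computed without checkpointing.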
3 files changed
Lines changed: 15 additions & 7 deletions
File tree
- src/maxdiffusion/models
- wan/transformers
@@ -1239,7 +1239,10 @@ (1 deletion, 4 additions; changed lines not captured)
Lines changed: 4 additions & 1 deletion
@@ -262,7 +262,10 @@ (1 deletion, 4 additions; changed lines not captured)
Lines changed: 7 additions & 5 deletions
@@ -487,10 +487,10 @@ (2 deletions, 2 additions; changed lines not captured)
@@ -507,14 +507,16 @@ (2 deletions, 4 additions; changed lines not captured)
@@ -530,7 +532,7 @@ (1 deletion, 1 addition; changed lines not captured)