Commit 1ca1b1b

revert shardings
1 parent fb12602 commit 1ca1b1b

4 files changed: 22 additions & 16 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -734,7 +734,7 @@ def __init__(
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
     # Dims are [num_blocks, embed, heads]
-    kernel_axes = ("embed", None, "heads")
+    kernel_axes = (None, "embed", "heads")
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
 
     self.query = nnx.Linear(
@@ -748,8 +748,8 @@ def __init__(
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
             (
+                None,
                 "embed",
-                "heads",
             ),
         ),
     )
@@ -765,8 +765,8 @@ def __init__(
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
             (
+                None,
                 "embed",
-                "heads",
             ),
         ),
     )
@@ -782,8 +782,8 @@ def __init__(
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
             (
+                None,
                 "embed",
-                "heads"
             ),
         ),
     )
@@ -792,15 +792,15 @@ def __init__(
         rngs=rngs,
         in_features=self.inner_dim,
         out_features=self.inner_dim,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads", None)),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
         bias_init=nnx.with_partitioning(
             nnx.initializers.zeros,
             (
-                "embed",
-                None
+                None,
+                "heads"
             ),
         ),
     )
```
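For reference, the `None` axis these hunks move to the front is the stacked block dimension introduced by `nnx.vmap`/`nnx.scan`: the per-block kernels are `[num_blocks, embed, heads]`, so the block axis carries no mesh name, and the revert restores specs whose leading entry matches that layout. A minimal standalone sketch (not this repo's code; the feature sizes are illustrative, and the `"embed"`/`"heads"` mesh axis names are taken from the diff) of how `nnx.with_partitioning` attaches these logical axes:

```python
# Sketch: nnx.with_partitioning wraps an initializer so the parameter it
# creates carries logical sharding axes as metadata. For an un-stacked layer
# the kernel is [in_features, out_features]; stacking blocks with nnx.vmap or
# nnx.scan prepends a block dimension, hence the leading None in the diff.
from flax import nnx

layer = nnx.Linear(
    in_features=128,  # illustrative sizes
    out_features=128,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
    rngs=nnx.Rngs(0),
)

# The metadata is later turned into PartitionSpecs when sharding the state:
print(nnx.get_partition_spec(nnx.state(layer)))
```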

src/maxdiffusion/models/gradient_checkpoint.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -39,6 +39,7 @@ class GradientCheckpointType(Enum):
   NONE = auto()
   FULL = auto()
   MATMUL_WITHOUT_BATCH = auto()
+  OFFLOAD_MATMUL_WITHOUT_BATCH = auto()
   ATTN = auto()
 
   @classmethod
@@ -65,10 +66,16 @@ def to_jax_policy(self):
         return SKIP_GRADIENT_CHECKPOINT_KEY
       case GradientCheckpointType.FULL:
         return None
-      case GradientCheckpointType.ATTN:
-        return cp.save_and_offload_only_these_names(
-            names_which_can_be_saved=[], names_which_can_be_offloaded=[], offload_src="device", offload_dst="pinned_host"
+      case GradientCheckpointType.OFFLOAD_MATMUL_WITHOUT_BATCH:
+        return cp.offload_dot_with_no_batch_dims(
+            offload_src="device", offload_dst="pinned_host"
         )
+      case GradientCheckpointType.ATTN:
+        offload_policy = cp.save_and_offload_only_these_names(
+            names_which_can_be_saved=[], names_which_can_be_offloaded=["attn_output"], offload_src="device", offload_dst="pinned_host"
+        )
+        policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
+        return cp.save_from_both_policies(offload_policy, policy)
       case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
         return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
 
```
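The reworked `ATTN` case combines two policies: activations tagged `"attn_output"` are offloaded to pinned host memory, dot products with no batch dimensions are saved on device, and everything else is recomputed in the backward pass. A toy sketch of applying such a combined policy through `jax.checkpoint` (the function `attn_like` and its shapes are made up, and host offload assumes a backend that supports `pinned_host`):

```python
# Toy sketch of the combined remat policy built in the ATTN case above.
from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

offload_policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["attn_output"],
    offload_src="device",
    offload_dst="pinned_host",
)
policy = jax.checkpoint_policies.save_from_both_policies(
    offload_policy,
    jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
)

@partial(jax.checkpoint, policy=policy)
def attn_like(x, w):
    # Tensors tagged with checkpoint_name are matched by the offload policy
    # and saved to host instead of being recomputed.
    out = checkpoint_name(jnp.dot(x, w), "attn_output")
    return jnp.tanh(out)

x, w = jnp.ones((8, 16)), jnp.ones((16, 16))
grads = jax.grad(lambda x, w: attn_like(x, w).sum())(x, w)
```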

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -175,11 +175,12 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
-                "embed",
                 None,
                 "mlp",
+                "embed",
             ),
         ),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
     )
 
   def __call__(self, x: jax.Array) -> jax.Array:
@@ -217,7 +218,6 @@ def __init__(
       raise NotImplementedError(f"{activation_fn} is not implemented.")
 
     self.drop_out = nnx.Dropout(dropout)
-
     self.proj_out = nnx.Linear(
         rngs=rngs,
         in_features=inner_dim,
@@ -229,9 +229,9 @@ def __init__(
         kernel_init=nnx.with_partitioning(
             nnx.initializers.xavier_uniform(),
             (
+                None,
                 "embed",
                 "mlp",
-                None,
             ),
         ),
     )
@@ -319,8 +319,7 @@ def __init__(
 
     key = rngs.params()
     self.adaln_scale_shift_table = nnx.Param(
-        jax.random.normal(key, (1, 6, dim)) / dim**0.5,
-        sharding=("embed",))
+        jax.random.normal(key, (1, 6, dim)) / dim**0.5,)
 
   def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,):
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
```
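On the last hunk: the removed `sharding=("embed",)` was a single-axis spec attached to a 3-D `(1, 6, dim)` value, and the revert leaves the table without sharding metadata. A short illustrative sketch (not this repo's code; `dim` and the rank-matched spec are assumptions) of attaching or omitting such metadata on `nnx.Param`:

```python
# Sketch: nnx.Param accepts metadata keyword arguments; "sharding" is the key
# read by nnx.get_partition_spec. Omitting it, as this commit does, leaves
# the parameter to default (typically replicated) placement.
import jax
from flax import nnx

dim = 64  # illustrative
table = jax.random.normal(jax.random.PRNGKey(0), (1, 6, dim)) / dim**0.5

with_meta = nnx.Param(table, sharding=(None, None, "embed"))  # rank-matched spec
without_meta = nnx.Param(table)  # as after this commit: no sharding metadata
```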

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -149,7 +149,7 @@ def start_training(self):
 
     pipeline = self.load_checkpoint()
     # Generate a sample before training to compare against generated sample after training.
-    #pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     # save some memory.
     del pipeline.vae
```
