
Commit 7289218

activation offloading
1 parent 1ca1b1b commit 7289218

3 files changed

Lines changed: 23 additions & 10 deletions


src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
@@ -800,7 +800,7 @@ def __init__(
           nnx.initializers.zeros,
           (
               None,
-              "heads"
+              "heads",
           ),
       ),
   )

src/maxdiffusion/models/gradient_checkpoint.py

Lines changed: 16 additions & 5 deletions
@@ -71,15 +71,25 @@ def to_jax_policy(self):
             offload_src="device", offload_dst="pinned_host"
         )
       case GradientCheckpointType.ATTN:
-        offload_policy = cp.save_and_offload_only_these_names(
-            names_which_can_be_saved=[], names_which_can_be_offloaded=["attn_output"], offload_src="device", offload_dst="pinned_host"
+        policy = cp.save_and_offload_only_these_names(
+            names_which_can_be_saved=[],
+            names_which_can_be_offloaded=[
+                #"attn_output",
+                #"query_proj",
+                #"key_proj",
+                #"value_proj",
+                #"xq_out",
+                #"xk_out",
+                "ffn_activation"
+            ],
+            offload_src="device",
+            offload_dst="pinned_host"
         )
-        policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
-        return cp.save_from_both_policies(offload_policy, policy)
+        return policy
       case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
         return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

-  def apply(self, module: nnx.Module) -> nnx.Module:
+  def apply(self, module: nnx.Module, static_argnums=()) -> nnx.Module:
     """
     Applies a gradient checkpoint policy to a module
     if no policy is needed, it will return the module as is
@@ -97,4 +107,5 @@ def apply(self, module: nnx.Module) -> nnx.Module:
         module,
         prevent_cse=False,
         policy=policy,
+        static_argnums=static_argnums
     )
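
In plain terms, the ATTN checkpoint type now saves nothing on device and offloads only the tensors tagged "ffn_activation" to pinned host memory, where it previously offloaded "attn_output" and combined that offload policy with checkpoint_dots_with_no_batch_dims via save_from_both_policies. Below is a minimal, self-contained sketch of the same mechanism; it assumes a JAX version that provides jax.checkpoint_policies.save_and_offload_only_these_names, and the toy names (toy_ffn, loss, w_in, w_out) are illustrative, not part of maxdiffusion.

# Sketch: offload only tensors named "ffn_activation" to pinned host memory and
# rematerialize everything else on the backward pass.
import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

offload_policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["ffn_activation"],
    offload_src="device",
    offload_dst="pinned_host",
)

@functools.partial(jax.checkpoint, policy=offload_policy, prevent_cse=False)
def toy_ffn(x, w_in, w_out):
  h = jax.nn.gelu(x @ w_in)                  # large intermediate activation
  h = checkpoint_name(h, "ffn_activation")   # tag must match the policy's name list
  return h @ w_out

def loss(params, x):
  w_in, w_out = params
  return jnp.sum(toy_ffn(x, w_in, w_out) ** 2)

x = jnp.ones((8, 512))
params = (jnp.ones((512, 2048)) * 0.02, jnp.ones((2048, 512)) * 0.02)
# Host offloading takes effect under jit on backends that support the
# pinned_host memory space (e.g. TPU).
grads = jax.jit(jax.grad(loss))(params, x)

The accompanying apply() change simply threads static_argnums through to the underlying checkpoint call, so callers can mark non-array arguments as static.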

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 6 additions & 4 deletions
@@ -19,6 +19,7 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec
+from jax.ad_checkpoint import checkpoint_name
 from flax import nnx
 import numpy as np
 from .... import common_types
@@ -42,7 +43,7 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
   t_dim = attention_head_dim - h_dim - w_dim
   freqs = []
   for dim in [t_dim, h_dim, w_dim]:
-    freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False)
+    freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float32, use_real=False)
     freqs.append(freq)
   freqs = jnp.concatenate(freqs, axis=1)
   t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)
@@ -180,7 +181,7 @@ def __init__(
             "embed",
           ),
         ),
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
+        bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)),
     )

   def __call__(self, x: jax.Array) -> jax.Array:
@@ -237,9 +238,10 @@ def __init__(
     )

   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    hidden_states = self.act_fn(hidden_states)
+    hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
+    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
     hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    return self.proj_out(hidden_states)
+    return self.proj_out(hidden_states)  # output is (4, 75600, 5120)


 class WanTransformerBlock(nnx.Module):
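
The shape comments in the diff show why the FFN activation is the offload target: the activation output (4, 75600, 13824) is far larger than the block's (4, 75600, 5120) projection output, so tagging it with checkpoint_name lets the offload policy above move exactly that tensor off the device. A stripped-down sketch of the same tagging pattern in an nnx module follows; ToyFeedForward and its sizes are made up, and only the "ffn_activation" name comes from the commit.

import jax
import jax.numpy as jnp
from flax import nnx
from jax.ad_checkpoint import checkpoint_name

class ToyFeedForward(nnx.Module):
  def __init__(self, dim: int, inner_dim: int, *, rngs: nnx.Rngs):
    self.proj_in = nnx.Linear(dim, inner_dim, rngs=rngs)
    self.drop_out = nnx.Dropout(rate=0.0, rngs=rngs)
    self.proj_out = nnx.Linear(inner_dim, dim, rngs=rngs)

  def __call__(self, hidden_states: jax.Array, deterministic: bool = True) -> jax.Array:
    hidden_states = jax.nn.gelu(self.proj_in(hidden_states))  # the big (batch, seq, inner_dim) tensor
    # Tag the activation: outside a checkpointed region this is an identity,
    # inside one the name is matched against the offload policy.
    hidden_states = checkpoint_name(hidden_states, "ffn_activation")
    hidden_states = self.drop_out(hidden_states, deterministic=deterministic)
    return self.proj_out(hidden_states)

model = ToyFeedForward(dim=64, inner_dim=256, rngs=nnx.Rngs(0))
y = model(jnp.ones((2, 16, 64)))  # -> (2, 16, 64)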
