1919import jax
2020import jax .numpy as jnp
2121from jax .sharding import PartitionSpec
22+ from jax .ad_checkpoint import checkpoint_name
2223from flax import nnx
2324import numpy as np
2425from .... import common_types
@@ -42,7 +43,7 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
4243 t_dim = attention_head_dim - h_dim - w_dim
4344 freqs = []
4445 for dim in [t_dim , h_dim , w_dim ]:
45- freq = get_1d_rotary_pos_embed (dim , max_seq_len , theta , freqs_dtype = jnp .float64 , use_real = False )
46+ freq = get_1d_rotary_pos_embed (dim , max_seq_len , theta , freqs_dtype = jnp .float32 , use_real = False )
4647 freqs .append (freq )
4748 freqs = jnp .concatenate (freqs , axis = 1 )
4849 t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6 )
@@ -180,7 +181,7 @@ def __init__(
180181 "embed" ,
181182 ),
182183 ),
183- bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
184+ bias_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" , )),
184185 )
185186
186187 def __call__ (self , x : jax .Array ) -> jax .Array :
@@ -237,9 +238,10 @@ def __init__(
237238 )
238239
def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
    """Run the activation, dropout and output projection on ``hidden_states``.

    Args:
        hidden_states: Input array, passed straight to ``self.act_fn``.
        deterministic: Forwarded unchanged to ``self.drop_out``.
        rngs: Forwarded unchanged to ``self.drop_out``; defaults to ``None``.

    Returns:
        The result of ``self.proj_out`` applied to the activated (and
        dropout-processed) hidden states.

    NOTE(review): the original author annotated the activation output as
    (4, 75600, 13824) and the projection output as (4, 75600, 5120). Those
    look like shapes observed for one specific run configuration, not an
    invariant of this method -- confirm before relying on them.
    """
    activated = self.act_fn(hidden_states)
    # Tag the activation with a stable name so that a name-based
    # jax.checkpoint / remat policy can refer to this intermediate.
    activated = checkpoint_name(activated, "ffn_activation")
    dropped = self.drop_out(activated, deterministic=deterministic, rngs=rngs)
    return self.proj_out(dropped)
243245
244246
245247class WanTransformerBlock (nnx .Module ):
0 commit comments