@@ -175,11 +175,12 @@ def __init__(
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
                 (
-                    "embed",
                     None,
                     "mlp",
+                    "embed",
                 ),
             ),
+            bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
         )

     def __call__(self, x: jax.Array) -> jax.Array:
@@ -217,7 +218,6 @@ def __init__(
             raise NotImplementedError(f"{activation_fn} is not implemented.")

         self.drop_out = nnx.Dropout(dropout)
-
         self.proj_out = nnx.Linear(
             rngs=rngs,
             in_features=inner_dim,
@@ -229,9 +229,9 @@ def __init__(
             kernel_init=nnx.with_partitioning(
                 nnx.initializers.xavier_uniform(),
                 (
+                    None,
                     "embed",
                     "mlp",
-                    None,
                 ),
             ),
         )
@@ -319,8 +319,7 @@ def __init__(

         key = rngs.params()
         self.adaln_scale_shift_table = nnx.Param(
-            jax.random.normal(key, (1, 6, dim)) / dim**0.5,
-            sharding=("embed",))
+            jax.random.normal(key, (1, 6, dim)) / dim**0.5,)

     def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
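For context on what the reordered partition specs in this diff do: nnx.with_partitioning wraps an initializer so the parameter it creates carries sharding metadata, with one tuple entry per kernel axis mapping that axis to a named mesh axis (or None for replication). Below is a minimal sketch, not part of the commit, assuming a plain 2-D nnx.Linear kernel and mesh axes named "embed" and "mlp"; the layers touched here may use higher-rank kernels, which would explain their 3-entry specs.

# Minimal sketch of nnx.with_partitioning (assumes a standard 2-D Linear
# kernel; mesh axis names "embed"/"mlp" are taken from the diff).
from flax import nnx

layer = nnx.Linear(
    rngs=nnx.Rngs(0),
    in_features=8,
    out_features=32,
    # Axis 0 (in_features) is sharded over the "embed" mesh axis,
    # axis 1 (out_features) over the "mlp" mesh axis.
    kernel_init=nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
        ("embed", "mlp"),
    ),
    # The bias of a plain Linear is 1-D, so its spec has a single entry.
    bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)),
)

# The attached metadata can be read back as jax.sharding.PartitionSpecs.
print(nnx.get_partition_spec(nnx.state(layer)))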