Skip to content

Commit 0d1e0f1

Browse files
fix sharding constraint for padded tensor.
1 parent 6e6fb76 commit 0d1e0f1

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1
148148
if kv_size < 128 or seq_len_pad != 0:
149149
npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
150150
padded_tensor = jnp.pad(tensor, npad)
151-
tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "fsdp", "tensor"))
151+
tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None))
152152

153153
return tensor, kv_size, seq_len
154154

0 commit comments

Comments
 (0)