Skip to content

Commit 7c84ec2

Browse files
Add LocalMask to check multihost.
1 parent 2a48490 commit 7c84ec2

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -207,6 +207,11 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
207207
return splash_kernel
208208

209209
mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
210+
mask &= splash_attention_mask.LocalMask(
211+
shape=(query.shape[2], key.shape[2]),
212+
window_size=(query.shape[2], key.shape[2]),
213+
offset=0
214+
)
210215
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
211216
splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
212217
segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)

0 commit comments

Comments
 (0)