Add tile_m/tile_k/tile_n overrides to SwiGLUPrefill #106

Open
albiol2004 wants to merge 1 commit into amd:devel from albiol2004:swiglu-prefill-tile-overrides

@albiol2004

Adds tile_m / tile_k / tile_n kwargs to SwiGLUPrefill, threading them through to the two inner GEMM operators. SwiGLUPrefill currently uses GEMM's default tile triple (64/64/64), which forces min_M = tile_m * 4 = 256. Real prefill batches from decoder-model runtimes (llama.cpp ubatch=32/64/128) fall below that threshold, so the fused SwiGLU path is unreachable in practice for the M range it was designed for. Passing tile_m=16 drops min_M to 64.
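The threshold arithmetic above can be checked with a short sketch. It assumes the factor of 4 is the number of AIE rows (the commit message writes it as `tile_m * num_aie_rows`) and that the only constraint is `M >= min_M`; the real operator may impose further divisibility requirements. The helper names `min_m` and `reachable` are hypothetical and not part of the codebase.

```python
# Assumption: 4 AIE rows, matching "tile_m * 4" in the description.
NUM_AIE_ROWS = 4


def min_m(tile_m: int, num_aie_rows: int = NUM_AIE_ROWS) -> int:
    """Smallest M the GEMM accepts under the assumed rule min_M = tile_m * rows."""
    return tile_m * num_aie_rows


def reachable(ubatch: int, tile_m: int) -> bool:
    """True if a prefill batch of `ubatch` rows meets the assumed minimum."""
    return ubatch >= min_m(tile_m)


if __name__ == "__main__":
    for tile_m in (64, 16):
        ok = [m for m in (32, 64, 128) if reachable(m, tile_m)]
        print(f"tile_m={tile_m}: min_M={min_m(tile_m)}, reachable ubatch sizes={ok}")
```

Under these assumptions, tile_m=64 admits none of the listed ubatch sizes, while tile_m=16 admits 64 and 128. A ubatch of 32 would still sit below min_M = 64.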

Added

  • Optional tile_m / tile_k / tile_n parameters on SwiGLUPrefill.__init__ (default None) that pass through to both inner GEMMs. Both stages receive the same tile triple.
  • New parametrized test case (seq_len=64, embedding_dim=1024, hidden_dim=3584, tile_m=16, tile_k=64, tile_n=64) covering a prefill batch of the size decoder-model runtimes issue, at the Qwen3.5-0.8B FFN shape. The existing (256, 2048, 2048) case is unchanged.

Changed

  • When any of tile_m / tile_k / tile_n is None, the corresponding kwarg is omitted from the GEMM constructor call, preserving the previous behavior for existing callers.
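The omit-when-None pass-through can be illustrated with a minimal sketch. `GemmStub` and `SwiGLUPrefillSketch` are hypothetical stand-ins, not the real operators: the constructor signatures, attribute names, and the per-stage GEMM shapes are assumptions made only to show the pattern.

```python
from typing import Optional


class GemmStub:
    """Hypothetical stand-in for the GEMM operator; only its tile defaults matter here."""

    def __init__(self, M: int, K: int, N: int,
                 tile_m: int = 64, tile_k: int = 64, tile_n: int = 64):
        self.M, self.K, self.N = M, K, N
        self.tile_m, self.tile_k, self.tile_n = tile_m, tile_k, tile_n


class SwiGLUPrefillSketch:
    """Illustrative wrapper that forwards optional tile overrides to two inner GEMMs."""

    def __init__(self, seq_len: int, embedding_dim: int, hidden_dim: int,
                 tile_m: Optional[int] = None,
                 tile_k: Optional[int] = None,
                 tile_n: Optional[int] = None):
        overrides = {"tile_m": tile_m, "tile_k": tile_k, "tile_n": tile_n}
        # Drop unset values so GemmStub's own defaults apply, as before the change.
        tile_kwargs = {k: v for k, v in overrides.items() if v is not None}
        # Assumed stage shapes; both stages receive the same tile triple.
        self.gemm_1 = GemmStub(seq_len, embedding_dim, hidden_dim, **tile_kwargs)
        self.gemm_2 = GemmStub(seq_len, hidden_dim, embedding_dim, **tile_kwargs)


if __name__ == "__main__":
    small = SwiGLUPrefillSketch(64, 1024, 3584, tile_m=16)
    print(small.gemm_1.tile_m, small.gemm_1.tile_k, small.gemm_2.tile_m)
```

Filtering out `None` rather than forwarding it keeps the default values defined in one place (the GEMM constructor), so a later change to GEMM's defaults would not need to be mirrored in the wrapper.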

Removed

PR Merge Checklist

  1. The PR is rebased on the latest devel commit and pointing to devel.
  2. Your PR has been reviewed and approved.
  3. All checks are passing.

SwiGLUPrefill currently uses GEMM's default tile triple (64/64/64),
which forces min_M = tile_m * num_aie_rows = 256. Real-world prefill
batch sizes from decoder-model runtimes (llama.cpp ubatch=32/64/128)
fall well below that threshold, leaving the fused SwiGLU path
unreachable in practice.

Add optional tile_m/tile_k/tile_n kwargs that pass through to both
inner GEMMs. When None (default), each falls back to GEMM's native
default, so existing callers and the existing (256, 2048, 2048) test
are unchanged.

Add a small-M test case (M=64, K=1024, N=3584, tile_m=16) that
exercises the override path at the Qwen3.5-0.8B FFN shape.