Skip to content

[CUDA] Remove TensorRT fused causal attention kernels#29143

Open
tianleiwu wants to merge 1 commit into
mainfrom
tlwu/remove_causal_fmha_cubin
Open

[CUDA] Remove TensorRT fused causal attention kernels#29143
tianleiwu wants to merge 1 commit into
mainfrom
tlwu/remove_causal_fmha_cubin

Conversation

@tianleiwu

Copy link
Copy Markdown
Contributor

Description

This PR removes the TensorRT fused causal attention kernels (the fmha_v2_*_Causal_* and fmha_v2_flash_attention_*_Causal_* cubins) and all of the code paths that selected them from the CUDA Attention operator.

These causal fused kernels were disabled by default (since #14732) and were only reachable via the opt-in ORT_ENABLE_FUSED_CAUSAL_ATTENTION environment variable / TRT_CAUSAL_ATTENTION backend bit. They used fp16 accumulation, which can cause accuracy drops, and have been superseded by flash attention, memory-efficient attention, and cuDNN SDPA. Removing them deletes ~1.27M lines of generated cubin source and simplifies the attention dispatch logic.

Key Changes

Motivation and Context

The fused causal kernels were off by default, carried potential fp16-accumulation accuracy risk, and added a large amount of generated cubin source to the repo. Causal attention is already well covered by flash attention, memory-efficient attention, and cuDNN SDPA, so these kernels can be safely removed to reduce binary size and simplify maintenance.

Testing

  • Build the CUDA EP and run the attention contrib op tests (ContribOpAttentionTest.*, including Causal_EmptyPastState).
  • Run AttentionKernelOptionsTest.* to verify the kernel-option parsing no longer references the causal backend.

@tianleiwu tianleiwu changed the title Remove TensorRT fused causal attention kernels [CUDA] Remove TensorRT fused causal attention kernels Jun 18, 2026
@tianleiwu tianleiwu requested a review from Copilot June 18, 2026 05:57

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Not ready to approve

There are a couple of remaining inconsistencies/redundant test paths after causal-kernel removal that should be cleaned up to match the PR’s stated behavior and keep tests deterministic.

Pull request overview

This PR removes the TensorRT fused causal attention kernel assets and eliminates the corresponding CUDA Attention dispatch/option plumbing, consolidating causal attention execution onto existing paths (flash attention, memory-efficient attention, cuDNN SDPA, or unfused).

Changes:

  • Removed TRT fused causal attention selection/metadata and simplified fused-runner logic to the non-causal (BERT) case.
  • Deleted causal-only QKV formats and runner APIs (dropped causal parameters, removed causal kernel meta entries).
  • Updated tests and Python benchmarking helpers to stop referencing the removed ORT_ENABLE_FUSED_CAUSAL_ATTENTION behavior.
File summaries
File Description
onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc Drops causal-backend assertions/env usage from kernel-option tests.
onnxruntime/test/contrib_ops/attention_op_test.cc Removes fused-causal env usage in causal Attention test scaffolding.
onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py Removes legacy workaround that set ORT_ENABLE_FUSED_CAUSAL_ATTENTION.
onnxruntime/python/tools/transformers/benchmark_helper.py Removes ORT_ENABLE_FUSED_CAUSAL_ATTENTION from env var enumeration.
onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h Removes causal from TRT fused runner APIs/state.
onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu Deletes causal runner setup and causal-mask routing; updates IsSupported/Create/Run.
onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention.h Removes unused causal-mask parameter from kernel run interface (cleanup).
onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h Removes causal cubin externs and causal kernel meta entries; simplifies hashing/signatures.
onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc Updates debug-info API calls for fused runner (no causal arg).
onnxruntime/contrib_ops/cuda/bert/packed_attention.cc Updates TRT fused runner IsSupported/Create call sites and debug-info API.
onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc Updates TRT fused runner calls and debug-info API (no causal arg).
onnxruntime/contrib_ops/cuda/bert/attention.h Removes member storing fused-causal enablement.
onnxruntime/contrib_ops/cuda/bert/attention.cc Removes fused-causal selection branch; fused runner now only for non-unidirectional path.
onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu Removes fused-causal QKV format handling and bias-to-gemm-buffer path.
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h Removes causal option surface and updates debug-info API signature.
onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc Stops parsing/printing TRT causal option; updates debug-info classification.
onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Removes fused-causal execution path and tightens QKV format assertions.
onnxruntime/contrib_ops/cpu/bert/attention_common.h Removes fused-causal env var constant and QKV format enum entry (CPU-side shared defs).

Copilot's findings

  • Files reviewed: 18/88 changed files
  • Comments generated: 3

Note

Your feedback helps us improve the quality of this feature.
Please use 👍 or 👎 to tell us whether this assessment is correct.

Comment on lines 121 to 125
// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";

// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled).
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";

// Environment variable to enable or disable cuDNN flash attention.
constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION";
Comment on lines +889 to 893
// TRT flash attention disabled, fused self attention enabled
{
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
Comment on lines 878 to 883
// Unfused kernel
{
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}}};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants