[CUDA] Remove TensorRT fused causal attention kernels#29143
Open
tianleiwu wants to merge 1 commit into
Open
Conversation
Contributor
There was a problem hiding this comment.
⚠️ Not ready to approve
There are a couple of remaining inconsistencies/redundant test paths after causal-kernel removal that should be cleaned up to match the PR’s stated behavior and keep tests deterministic.
Pull request overview
This PR removes the TensorRT fused causal attention kernel assets and eliminates the corresponding CUDA Attention dispatch/option plumbing, consolidating causal attention execution onto existing paths (flash attention, memory-efficient attention, cuDNN SDPA, or unfused).
Changes:
- Removed TRT fused causal attention selection/metadata and simplified fused-runner logic to the non-causal (BERT) case.
- Deleted causal-only QKV formats and runner APIs (dropped
causalparameters, removed causal kernel meta entries). - Updated tests and Python benchmarking helpers to stop referencing the removed
ORT_ENABLE_FUSED_CAUSAL_ATTENTIONbehavior.
File summaries
| File | Description |
|---|---|
| onnxruntime/test/providers/cuda/test_cases/attention_kernel_options_test.cc | Drops causal-backend assertions/env usage from kernel-option tests. |
| onnxruntime/test/contrib_ops/attention_op_test.cc | Removes fused-causal env usage in causal Attention test scaffolding. |
| onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py | Removes legacy workaround that set ORT_ENABLE_FUSED_CAUSAL_ATTENTION. |
| onnxruntime/python/tools/transformers/benchmark_helper.py | Removes ORT_ENABLE_FUSED_CAUSAL_ATTENTION from env var enumeration. |
| onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h | Removes causal from TRT fused runner APIs/state. |
| onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu | Deletes causal runner setup and causal-mask routing; updates IsSupported/Create/Run. |
| onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention.h | Removes unused causal-mask parameter from kernel run interface (cleanup). |
| onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h | Removes causal cubin externs and causal kernel meta entries; simplifies hashing/signatures. |
| onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc | Updates debug-info API calls for fused runner (no causal arg). |
| onnxruntime/contrib_ops/cuda/bert/packed_attention.cc | Updates TRT fused runner IsSupported/Create call sites and debug-info API. |
| onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc | Updates TRT fused runner calls and debug-info API (no causal arg). |
| onnxruntime/contrib_ops/cuda/bert/attention.h | Removes member storing fused-causal enablement. |
| onnxruntime/contrib_ops/cuda/bert/attention.cc | Removes fused-causal selection branch; fused runner now only for non-unidirectional path. |
| onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu | Removes fused-causal QKV format handling and bias-to-gemm-buffer path. |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.h | Removes causal option surface and updates debug-info API signature. |
| onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc | Stops parsing/printing TRT causal option; updates debug-info classification. |
| onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | Removes fused-causal execution path and tightens QKV format assertions. |
| onnxruntime/contrib_ops/cpu/bert/attention_common.h | Removes fused-causal env var constant and QKV format enum entry (CPU-side shared defs). |
Copilot's findings
- Files reviewed: 18/88 changed files
- Comments generated: 3
Note
Your feedback helps us improve the quality of this feature.
Please use 👍 or 👎 to tell us whether this assessment is correct.
Comment on lines
121
to
125
| // Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled). | ||
| constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION"; | ||
|
|
||
| // Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled). | ||
| // Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. | ||
| constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; | ||
|
|
||
| // Environment variable to enable or disable cuDNN flash attention. | ||
| constexpr const char* kEnableCudnnFlashAttention = "ORT_ENABLE_CUDNN_FLASH_ATTENTION"; |
Comment on lines
+889
to
893
| // TRT flash attention disabled, fused self attention enabled | ||
| { | ||
| ScopedEnvironmentVariables scoped_env_vars{ | ||
| EnvVarMap{ | ||
| {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, |
Comment on lines
878
to
883
| // Unfused kernel | ||
| { | ||
| ScopedEnvironmentVariables scoped_env_vars{ | ||
| EnvVarMap{ | ||
| {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, | ||
| {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, | ||
| {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}}}; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR removes the TensorRT fused causal attention kernels (the
fmha_v2_*_Causal_*andfmha_v2_flash_attention_*_Causal_*cubins) and all of the code paths that selected them from the CUDAAttentionoperator.These causal fused kernels were disabled by default (since #14732) and were only reachable via the opt-in
ORT_ENABLE_FUSED_CAUSAL_ATTENTIONenvironment variable /TRT_CAUSAL_ATTENTIONbackend bit. They used fp16 accumulation, which can cause accuracy drops, and have been superseded by flash attention, memory-efficient attention, and cuDNN SDPA. Removing them deletes ~1.27M lines of generated cubin source and simplifies the attention dispatch logic.Key Changes
causal/fmha_v2_fp16_Causal_*andflash_attention/fmha_v2_flash_attention_fp16_Causal_*generated cubin files (70+ files).is_unidirectional_/ causal fused-runner branch inComputeInternal; the fused runner path now only handles the BERT (non-causal) case.use_trt_causal_attention_,UseTrtCausalAttention(), theTRT_CAUSAL_ATTENTIONdebug print, and thecausalargument ofSetTrtFusedKernel.Q_K_V_BNSH_QKV_BS3NHformat and the fused-causal gemm-buffer-with-bias preparation path.causalparameter fromFusedMHARunnerFP16v2::Create/IsSupportedand removed the causal kernel metadata.ORT_ENABLE_FUSED_CAUSAL_ATTENTION(kEnableFusedCausalAttention) is no longer recognized.ORT_ENABLE_FUSED_CAUSAL_ATTENTIONreferences from the transformers benchmark helper and stable diffusion benchmark.Motivation and Context
The fused causal kernels were off by default, carried potential fp16-accumulation accuracy risk, and added a large amount of generated cubin source to the repo. Causal attention is already well covered by flash attention, memory-efficient attention, and cuDNN SDPA, so these kernels can be safely removed to reduce binary size and simplify maintenance.
Testing
ContribOpAttentionTest.*, includingCausal_EmptyPastState).AttentionKernelOptionsTest.*to verify the kernel-option parsing no longer references the causal backend.