From 24833c2b03138bbed8cc50f3c8bf5ea33fc5c71a Mon Sep 17 00:00:00 2001 From: frankwang28 Date: Tue, 3 Feb 2026 03:03:11 -0800 Subject: [PATCH 1/2] Enable TRITON_ATTN for batch invariance Signed-off-by: frankwang28 --- tests/v1/determinism/utils.py | 1 + vllm/model_executor/layers/batch_invariant.py | 9 ++++++--- vllm/v1/attention/ops/triton_unified_attention.py | 6 +++++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py index 5066315762e5..ca3ccab5efff 100644 --- a/tests/v1/determinism/utils.py +++ b/tests/v1/determinism/utils.py @@ -18,6 +18,7 @@ BACKENDS: list[str] = [ "FLASH_ATTN", + "TRITON_ATTN", "TRITON_MLA", ] diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 3f44608ab9ef..fcfadd60f5ce 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -1003,8 +1003,11 @@ def vllm_is_batch_invariant() -> bool: def override_envs_for_invariance( attention_backend: AttentionBackendEnum | None, ): - supported_backends = [ + decode_invariant_backends = [ AttentionBackendEnum.FLASH_ATTN, # best supported backend + AttentionBackendEnum.TRITON_ATTN, + ] + supported_backends = decode_invariant_backends + [ # FlashInfer temporarily disabled due to invariant CTA sizes. # See FlashInfer issue #2424 # AttentionBackendEnum.FLASHINFER, @@ -1025,9 +1028,9 @@ def override_envs_for_invariance( "one of the supported backends before enabling batch_invariant." ) raise RuntimeError(error) - if attention_backend != supported_backends[0]: + if attention_backend not in decode_invariant_backends: warning = ( - "You are using a decode-invariant form of batch invariance. " + "You are using a non-decode-invariant form of batch invariance. " "This will not be invariant between prefill and decode." ) logger.warning_once(warning, scope="local") diff --git a/vllm/v1/attention/ops/triton_unified_attention.py b/vllm/v1/attention/ops/triton_unified_attention.py index 6855233ee942..4ddd47c6dd65 100644 --- a/vllm/v1/attention/ops/triton_unified_attention.py +++ b/vllm/v1/attention/ops/triton_unified_attention.py @@ -10,10 +10,12 @@ import torch from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.platforms import current_platform from vllm.triton_utils import tl, triton logger = init_logger(__name__) +is_batch_invariant = vllm_is_batch_invariant() float8_info = torch.finfo(current_platform.fp8_dtype()) @@ -972,7 +974,8 @@ def unified_attention( # Launch the 2D kernel if # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or # 2. The batch includes at least one prefill request, or - # 3. The number of sequences exceeds the configured threshold + # 3. The number of sequences exceeds the configured threshold, or + # 4. Batch invariance is enabled if ( seq_threshold_3D is None or num_par_softmax_segments is None @@ -981,6 +984,7 @@ def unified_attention( or softmax_segm_expsum is None or max_seqlen_q > 1 or num_seqs > seq_threshold_3D + or is_batch_invariant ): kernel_unified_attention_2d[ ( From 3bdd86e08ae6e69221e7222f0b66964683b96a0f Mon Sep 17 00:00:00 2001 From: frankwang28 Date: Tue, 3 Feb 2026 13:34:48 -0800 Subject: [PATCH 2/2] Add gpt-oss to tested models Signed-off-by: frankwang28 --- docs/features/batch_invariance.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/features/batch_invariance.md b/docs/features/batch_invariance.md index 0144e2f71f0e..72224c96cfdf 100644 --- a/docs/features/batch_invariance.md +++ b/docs/features/batch_invariance.md @@ -108,6 +108,7 @@ Batch invariance has been tested and verified on the following models: - **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct` - **Qwen2.5**: `Qwen/Qwen2.5-0.5B-Instruct`, `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct`, `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct` - **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct` +- **GPT-OSS**: `openai/gpt-oss-20b`, `openai/gpt-oss-120b` Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).