Skip to content
1 change: 1 addition & 0 deletions docs/features/batch_invariance.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ Batch invariance has been tested and verified on the following models:
- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
- **Qwen2.5**: `Qwen/Qwen2.5-0.5B-Instruct`, `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct`, `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct`
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
- **GPT-OSS**: `openai/gpt-oss-20b`, `openai/gpt-oss-120b`

Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm/issues/new/choose).

Expand Down
1 change: 1 addition & 0 deletions tests/v1/determinism/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

BACKENDS: list[str] = [
"FLASH_ATTN",
"TRITON_ATTN",
"TRITON_MLA",
]

Expand Down
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/batch_invariant.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,8 +1003,11 @@ def vllm_is_batch_invariant() -> bool:
def override_envs_for_invariance(
attention_backend: AttentionBackendEnum | None,
):
supported_backends = [
decode_invariant_backends = [
AttentionBackendEnum.FLASH_ATTN, # best supported backend
AttentionBackendEnum.TRITON_ATTN,
]
supported_backends = decode_invariant_backends + [
# FlashInfer temporarily disabled due to invariant CTA sizes.
# See FlashInfer issue #2424
# AttentionBackendEnum.FLASHINFER,
Expand All @@ -1025,9 +1028,9 @@ def override_envs_for_invariance(
"one of the supported backends before enabling batch_invariant."
)
raise RuntimeError(error)
if attention_backend != supported_backends[0]:
if attention_backend not in decode_invariant_backends:
warning = (
"You are using a decode-invariant form of batch invariance. "
"You are using a non-decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
logger.warning_once(warning, scope="local")
Expand Down
6 changes: 5 additions & 1 deletion vllm/v1/attention/ops/triton_unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant()
float8_info = torch.finfo(current_platform.fp8_dtype())


Expand Down Expand Up @@ -972,7 +974,8 @@ def unified_attention(
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
# 3. The number of sequences exceeds the configured threshold
# 3. The number of sequences exceeds the configured threshold, or
# 4. Batch invariance is enabled
if (
seq_threshold_3D is None
or num_par_softmax_segments is None
Expand All @@ -981,6 +984,7 @@ def unified_attention(
or softmax_segm_expsum is None
or max_seqlen_q > 1
or num_seqs > seq_threshold_3D
or is_batch_invariant
):
kernel_unified_attention_2d[
(
Expand Down
Loading