[Feature] Enable TRITON_ATTN for Batch Invariance#33688

Merged
DarkLight1337 merged 8 commits intovllm-project:mainfrom
frankwang28:bat-inv-triton-attn
Feb 4, 2026
Conversation

@frankwang28 (Contributor) commented Feb 3, 2026

Purpose

This PR adds TRITON_ATTN support for batch invariance.

Related / parent issue: #27433

Test Plan

Run the tests below with and without the `or is_batch_invariant` check in `triton_unified_attention`'s `unified_attention` method.
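To make the gating idea concrete, here is a minimal sketch of the kind of check described above. The function and heuristic names are hypothetical stand-ins, not vLLM's actual internals: the real logic lives in `triton_unified_attention`'s `unified_attention`, where the split-KV (3D) kernel is skipped in favor of the deterministic 2D kernel whenever batch invariance is requested.

```python
# Hypothetical sketch of the dispatch gate; names and the heuristic are
# placeholders and do NOT match vLLM's real code.

def is_batch_invariant() -> bool:
    # In vLLM this would consult the batch-invariance setting;
    # hardcoded here for illustration.
    return True

def use_2d_kernel(max_seqlen_q: int, total_tokens: int) -> bool:
    """Decide whether to run the deterministic 2D kernel.

    The split-KV (3D) kernel parallelizes the reduction over the KV
    dimension, so its floating-point accumulation order can change with
    batch composition. Forcing the 2D kernel when batch invariance is
    enabled keeps the reduction order fixed regardless of batching.
    """
    prefer_2d = max_seqlen_q > 1 or total_tokens < 1024  # placeholder heuristic
    return prefer_2d or is_batch_invariant()
```

The `or is_batch_invariant()` clause is the one-line change the test plan toggles: without it, pure-decode batches fall through to the split-KV path and lose invariance.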

Test Result

Tests were run on a B200 (I do not have access to a Hopper GPU to validate there 🙁).

Without:

CUDA_VISIBLE_DEVICES=2 VLLM_TEST_SEED=12345 pytest tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[TRITON_ATTN] -s

...

FAILED tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[TRITON_ATTN] - Failed: Batch invariance violated in 128/128 prompts. See output above for details.
===================================================================== 1 failed, 3 warnings in 80.20s (0:01:20) =====================================================================
CUDA_VISIBLE_DEVICES=2 VLLM_TEST_SEED=12345 VLLM_TEST_MODEL=openai/gpt-oss-120b pytest tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[TRITON_ATTN] -s

...

FAILED tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[TRITON_ATTN] - Failed: Batch invariance violated in 128/128 prompts. See output above for details.
==================================================================== 1 failed, 3 warnings in 268.72s (0:04:28) =====================================================================

With:

CUDA_VISIBLE_DEVICES=2 VLLM_TEST_SEED=12345 pytest tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[TRITON_ATTN] -s

...

========================================================================== 1 passed, 3 warnings in 33.91s ==========================================================================
CUDA_VISIBLE_DEVICES=2 VLLM_TEST_SEED=12345 VLLM_TEST_MODEL=openai/gpt-oss-120b pytest tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[TRITON_ATTN] -s

...

===================================================================== 1 passed, 3 warnings in 65.82s (0:01:05) =====================================================================

Doing some more testing with my own test suite, TRITON_ATTN also appears to be decode-invariant: prefilling part of a previously decoded sequence and decoding the remainder produces identical logprobs.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot left a comment
Code Review

This pull request effectively enables batch invariance for the TRITON_ATTN backend. The changes are well-structured and logical. By adding TRITON_ATTN to the list of decode-invariant backends and forcing the use of the deterministic 2D Triton kernel when batch invariance is enabled, the PR successfully addresses the non-determinism issue. The accompanying test updates ensure that this new capability is properly verified. The code is clean and the changes are correct. Excellent work!

@yewentao256 (Member) left a comment

LGTM, thanks for the work!
Since you have tested gpt-oss, could you also add it to the doc? https://docs.vllm.ai/en/latest/features/batch_invariance/#tested-models

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 3, 2026

mergify bot commented Feb 3, 2026

Documentation preview: https://vllm--33688.org.readthedocs.build/en/33688/

@mergify mergify bot added the documentation Improvements or additions to documentation label Feb 3, 2026
@DarkLight1337 DarkLight1337 merged commit 45f8fd6 into vllm-project:main Feb 4, 2026
49 of 50 checks passed
gameofdimension pushed a commit to gameofdimension/vllm that referenced this pull request Feb 5, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026

Labels

documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed v1
