[CI] TRTLLM Gen-full Attn Test Coverage #34986

Merged
MatthewBonanni merged 1 commit into vllm-project:main from ojhaanshika:trtllm-gen-full-attn-test-coverage
Mar 3, 2026

Conversation

@ojhaanshika
Contributor

@ojhaanshika ojhaanshika commented Feb 20, 2026

Purpose

Add test coverage for the TRTLLM gen-full attention pipeline in vLLM. The existing kernel-level tests (test_flashinfer_trtllm_attention.py) call the FlashInfer C++ kernels directly, bypassing the vLLM integration layer entirely. This leaves the decision logic, metadata construction, and forward dispatch at 0% coverage.

This PR adds two new test files:

  1. Unit tests for attention decision functions (tests/kernels/attention/test_use_trtllm_attention.py) — 22 tests covering all branches of use_trtllm_attention(), can_use_trtllm_attention(), and supports_trtllm_attention() in vllm/utils/flashinfer.py. These validate the gatekeeper logic that decides whether to use TRTLLM or fall back to native FlashInfer (force on/off, DCP fallback, head incompatibility, speculative decoding, FP8 query, attention sinks, auto-detection). A sketch of the test pattern appears after this list.

  2. Integration tests for the full pipeline (tests/v1/attention/test_trtllm_attention_integration.py) — 3 tests that exercise FlashInferMetadataBuilder.build() -> FlashInferImpl.forward() with TRTLLM kernels on Blackwell, comparing output against an SDPA reference. Covers decode-only, prefill-only, and mixed prefill+decode batches.
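
The unit tests follow a "defaults plus one override per test" pattern around a small _call() helper. The sketch below is illustrative only: _should_use_trtllm is a simplified stand-in for the real use_trtllm_attention() (whose full signature is not reproduced here), and the default values are assumptions chosen to resemble a Llama-3-70B-style layout.

def _should_use_trtllm(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    is_prefill: bool,
    kv_cache_dtype: str,
    force_use_trtllm=None,
) -> bool:
    # Simplified stand-in for use_trtllm_attention(); it only encodes the
    # checks discussed in this PR (head compatibility, force on/off, and the
    # bf16 decode batch-size heuristic).
    if num_kv_heads == 0 or num_qo_heads % num_kv_heads != 0:
        return False  # incompatible head layout is never eligible, even if forced
    if force_use_trtllm is not None:
        return force_use_trtllm  # explicit force on/off
    if not is_prefill and kv_cache_dtype == "auto":
        return num_tokens <= 256  # auto-detection heuristic for bf16 decode
    return True


# Defaults roughly matching a Llama-3-70B-style GQA layout (assumed values).
_DEFAULTS = dict(
    num_qo_heads=64,
    num_kv_heads=8,
    num_tokens=128,
    is_prefill=False,
    kv_cache_dtype="auto",
)


def _call(**overrides) -> bool:
    # Each test overrides exactly one knob so a failure points at one branch.
    return _should_use_trtllm(**{**_DEFAULTS, **overrides})


def test_force_off_wins():
    assert _call(force_use_trtllm=False) is False


def test_incompatible_heads_beat_force_on():
    # 40 query heads cannot be evenly grouped onto 6 KV heads.
    assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False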

Test Plan

Unit tests (run on any platform):
python -m pytest tests/kernels/attention/test_use_trtllm_attention.py -v

Integration tests (require a Blackwell SM100 GPU):
python -m pytest tests/v1/attention/test_trtllm_attention_integration.py -v
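
For reference, these integration tests compare the backend output against a plain SDPA baseline. A generic sketch of a GQA-aware SDPA reference for a single decode query is below; tensor shapes and tolerances are assumptions, not the exact helper from the test file.

import torch
import torch.nn.functional as F


def sdpa_reference(
    q: torch.Tensor,  # [num_qo_heads, 1, head_dim], one decode query token
    k: torch.Tensor,  # [num_kv_heads, seq_len, head_dim], cached keys
    v: torch.Tensor,  # [num_kv_heads, seq_len, head_dim], cached values
) -> torch.Tensor:
    # Expand KV heads so each query head attends to its group's KV head (GQA).
    group = q.shape[0] // k.shape[0]
    k = k.repeat_interleave(group, dim=0)
    v = v.repeat_interleave(group, dim=0)
    # A single decode token attends over the whole cache, so no causal mask.
    return F.scaled_dot_product_attention(q, k, v)


# Typical comparison against the backend output (tolerances are placeholders):
# torch.testing.assert_close(trtllm_out, sdpa_reference(q, k, v), atol=2e-2, rtol=2e-2)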

Test Result

Unit tests: 22 passed
Integration tests: 3 passed

Coverage improvement for vllm/utils/flashinfer.py: 31% -> 47%
Coverage improvement for vllm/v1/attention/backends/flashinfer.py: 0% -> 22%

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which executes a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces comprehensive test coverage for the new TRT-LLM attention backend. The changes include two new test files: one for unit testing the logic that determines when to use TRT-LLM attention, and another for integration testing its correctness against a reference implementation.

The unit tests in tests/kernels/attention/test_use_trtllm_attention.py are well-structured, using extensive mocking to cover a wide range of scenarios and edge cases. The integration tests in tests/v1/attention/test_trtllm_attention_integration.py are also robust, validating the end-to-end functionality for different batch types (decode, prefill, and mixed) and ensuring numerical correctness.

Overall, the tests are of high quality, well-written, and significantly improve confidence in the new TRT-LLM attention feature. I have no major concerns with this contribution.

@robertgshaw2-redhat
Collaborator

nice job

@ojhaanshika ojhaanshika force-pushed the trtllm-gen-full-attn-test-coverage branch 2 times, most recently from a8ff90b to 2e0648e on February 23, 2026 20:44
@ojhaanshika ojhaanshika changed the title from "Trtllm gen full attn test coverage" to "TRTLLM gen-full attn Test Coverage" on Feb 23, 2026
@ojhaanshika ojhaanshika marked this pull request as ready for review on February 23, 2026 20:50
Collaborator

@pavanimajety pavanimajety left a comment

Thanks for the PR @ojhaanshika. I have added some feedback. The integration test looks good to me.

Comment on lines +131 to +135

@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads_force_on_still_false(_mock):
    assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False

Collaborator

IMO we should re-think the whole test in terms of model config rather than arbitrary numbers that would and wouldn't use trtllm. For example, something like _call(get_config("Llama-3")). That would also help us quickly see which models TRTLLM supports vs which ones don't. We could have a dictionary that maps each model to its individual parameters like query heads, kv heads, whether it uses sliding-window attention, etc.

Contributor Author

I've added a MODEL_CONFIGS dictionary with get_config() as suggested. The default test config now uses get_config("Llama-3-70B"). For the incompatible heads case, I kept bare numbers where num_qo_heads % num_kv_heads != 0. Let me know if this is what you had in mind, and which specific models I should add to the dictionary.
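
Roughly the shape of the dictionary (head counts are taken from the models' public configs; the field names and exact model list here are illustrative, not a copy of the test file):

MODEL_CONFIGS = {
    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8),
    "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8),
    "Qwen2.5-72B": dict(num_qo_heads=64, num_kv_heads=8),
}


def get_config(name: str) -> dict:
    # Look up one model's attention layout to use as test parameters.
    return MODEL_CONFIGS[name]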

Collaborator

This approach sounds good to me. I did a quick internet search and I don't see many models that actually have the scenario where num_qo_heads % num_kv_heads != 0. There were some test models but no production models.

Comment on lines +166 to +173
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_small_batch(_mock):
    assert _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto") is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_large_batch(_mock):
    assert _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto") is False
Collaborator

Could we take this opportunity to actually measure if the bounds set are still true?

Contributor Author

Sure! The test values (128 and 512) were chosen to land on either side of the num_tokens <= 256 threshold on line 386 of flashinfer.py. They verify the branching logic works correctly. Did you want me to benchmark to confirm 256 is still the optimal crossover point between TRTLLM and FlashInfer native decode?
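
Paraphrasing the heuristic those two values straddle (this is a paraphrase of the described behavior, not the literal code at that line):

def _auto_decode_uses_trtllm(num_tokens: int) -> bool:
    # bf16 ("auto") decode batches of at most 256 tokens go to TRTLLM;
    # larger batches fall back to native FlashInfer.
    return num_tokens <= 256


assert _auto_decode_uses_trtllm(128) is True   # test_use_auto_decode_small_batch
assert _auto_decode_uses_trtllm(512) is False  # test_use_auto_decode_large_batch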

Collaborator

Yes, exactly. It's been a while since we ran the benchmark to see where the perf flips in favor of Flashinfer native FA2 kernels.

Contributor Author

I benchmarked on GB200 using benchmarks/kernels/benchmark_trtllm_decode_attention.py to confirm the num_tokens <= 256 threshold for bf16 decode (kv_cache_dtype="auto").

I found that at batch_size=256, TRTLLM wins across all sequence lengths (1K-128K), confirming the current threshold is safe. The crossover point where FlashInfer overtakes TRTLLM depends on sequence length.

For shorter sequences (1K-8K), TRTLLM is faster even up to batch_size=2048. For longer sequences the crossover shifts lower: 896 at 32K, 640 at 64K, and 512-640 at 128K.

The worst-case crossover is 512, so the threshold could potentially be raised to 384-512? Let me know if you'd like me to raise the threshold (e.g., to 384 or 512) in this PR, or if we should keep it at 256 for now and address it separately.

Here are the benchmark CSVs:
Full benchmark (all quant configs, batch sizes 1-256)
Crossover sweep (bf16 only, batch sizes 128-2048)

Collaborator

@pavanimajety pavanimajety left a comment

LGTM, just one more comment.

Comment on lines +111 to +113
def test_can_use_platform_unsupported(_sup, _force):
    cfg = get_config("Llama-3-70B")
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False
Collaborator

Could we add more models here and parameterize the test?

Contributor Author

Added!
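
Roughly what the parametrized version looks like (the MODEL_CONFIGS entries and the single patch shown here are illustrative; the real test stacks the same mocks as the snippet quoted above):

import pytest
from unittest.mock import patch

from vllm.utils.flashinfer import can_use_trtllm_attention

# Same layouts as the MODEL_CONFIGS sketch earlier in this thread (assumed values).
MODEL_CONFIGS = {
    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8),
    "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8),
    "Qwen2.5-72B": dict(num_qo_heads=64, num_kv_heads=8),
}


@pytest.mark.parametrize("model", sorted(MODEL_CONFIGS))
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_can_use_platform_unsupported(_sup, model):
    cfg = MODEL_CONFIGS[model]
    # On an unsupported platform the capability check must fail for every model.
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False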

@pavanimajety pavanimajety added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 27, 2026
@pavanimajety
Collaborator

@ojhaanshika could you please fix your DCO?

@ojhaanshika ojhaanshika force-pushed the trtllm-gen-full-attn-test-coverage branch 2 times, most recently from 9000420 to 6920dc1 on March 2, 2026 20:48
…line

Signed-off-by: Anshika Ojha <anshikao@nvidia.com>
@ojhaanshika ojhaanshika force-pushed the trtllm-gen-full-attn-test-coverage branch from 6920dc1 to 5eaa442 on March 2, 2026 22:19
@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Mar 3, 2026
Collaborator

@MatthewBonanni MatthewBonanni left a comment

LGTM, thanks for doing this!

@MatthewBonanni MatthewBonanni merged commit e05cb3b into vllm-project:main Mar 3, 2026
21 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Mar 3, 2026
@ojhaanshika ojhaanshika changed the title from "TRTLLM gen-full attn Test Coverage" to "[CI] TRTLLM Gen-full Attn Test Coverage" on Mar 9, 2026
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
Signed-off-by: Anshika Ojha <anshikao@nvidia.com>
Co-authored-by: Anshika Ojha <anshikao@gb-nvl-059-compute09.nvidia.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Mar 12, 2026
Signed-off-by: Anshika Ojha <anshikao@nvidia.com>
Co-authored-by: Anshika Ojha <anshikao@gb-nvl-059-compute09.nvidia.com>
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
Signed-off-by: Anshika Ojha <anshikao@nvidia.com>
Co-authored-by: Anshika Ojha <anshikao@gb-nvl-059-compute09.nvidia.com>

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done

4 participants