[CI] TRTLLM Gen-full Attn Test Coverage#34986
MatthewBonanni merged 1 commit into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI, which runs only a small and essential subset of tests to quickly catch errors. You can ask your reviewers to trigger select CI tests on top of fastcheck.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces comprehensive test coverage for the new TRT-LLM attention backend. The changes include two new test files: one for unit testing the logic that determines when to use TRT-LLM attention, and another for integration testing its correctness against a reference implementation.
The unit tests in tests/kernels/attention/test_use_trtllm_attention.py are well-structured, using extensive mocking to cover a wide range of scenarios and edge cases. The integration tests in tests/v1/attention/test_trtllm_attention_integration.py are also robust, validating the end-to-end functionality for different batch types (decode, prefill, and mixed) and ensuring numerical correctness.
Overall, the tests are of high quality, well-written, and significantly improve confidence in the new TRT-LLM attention feature. I have no major concerns with this contribution.
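For illustration, the mocking pattern looks roughly like the following — a minimal sketch, assuming a `_call()` helper like the one in the test file that forwards keyword overrides to `use_trtllm_attention()` (the default values below are placeholders, not the test file's actual defaults):

```python
from unittest.mock import patch

from vllm.utils.flashinfer import use_trtllm_attention


def _call(**overrides):
    # Hypothetical defaults standing in for the real helper's; only the
    # keyword names seen in the quoted tests below are assumed here.
    kwargs = dict(num_qo_heads=64, num_kv_heads=8, num_tokens=128,
                  kv_cache_dtype="auto", is_prefill=False,
                  force_use_trtllm=None)
    kwargs.update(overrides)
    return use_trtllm_attention(**kwargs)


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_force_off_wins(_mock):
    # Even with platform support mocked on, an explicit force-off wins.
    assert _call(force_use_trtllm=False) is False
```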
nice job
Force-pushed from a8ff90b to 2e0648e
pavanimajety left a comment
Thanks for the PR @ojhaanshika. I have added some feedback. The integration test looks good to me.
```python
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads_force_on_still_false(_mock):
    assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False
```
IMO we should re-think the whole test in terms of model configs rather than arbitrary numbers that would and wouldn't use TRTLLM. For example, something like `_call(get_config("Llama-3"))`. That would also help us quickly see which models TRTLLM supports vs which ones it doesn't. We could have a dictionary that maps each model to its individual parameters, like query heads, KV heads, whether it uses sliding-window attention, etc.
I've added a `MODEL_CONFIGS` dictionary with `get_config()` as suggested. The default test config now uses `get_config("Llama-3-70B")`. For the incompatible-heads case, I kept bare numbers where `num_qo_heads % num_kv_heads != 0`. Let me know if this is what you had in mind, and which specific models I should add to the dictionary.
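For concreteness, the shape being described might look like this (a sketch; the head counts below are the published configurations for these models, but the entries and fields in the actual test file may differ):

```python
# Sketch of the suggested mapping from model name to attention parameters.
MODEL_CONFIGS = {
    "Llama-3-70B": {"num_qo_heads": 64, "num_kv_heads": 8},
    "Llama-3-8B": {"num_qo_heads": 32, "num_kv_heads": 8},
    "Qwen2.5-72B": {"num_qo_heads": 64, "num_kv_heads": 8},
}


def get_config(name: str) -> dict:
    # Look up a named model's attention-head configuration.
    return MODEL_CONFIGS[name]
```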
This approach sounds good to me. I did a quick internet search and I don't see many models that actually have the scenario where num_qo_heads % num_kv_heads != 0. There were some test models but no production models.
| @patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) | ||
| def test_use_auto_decode_small_batch(_mock): | ||
| assert _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto") is True | ||
|
|
||
|
|
||
| @patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) | ||
| def test_use_auto_decode_large_batch(_mock): | ||
| assert _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto") is False |
Could we take this opportunity to actually measure whether these bounds still hold?
Sure! The test values (128 and 512) were chosen to land on either side of the num_tokens <= 256 threshold on line 386 of flashinfer.py. They verify the branching logic works correctly. Did you want me to benchmark to confirm 256 is still the optimal crossover point between TRTLLM and FlashInfer native decode?
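In other words, the two tests bracket a heuristic of roughly this shape (a simplified sketch, not the verbatim source in `vllm/utils/flashinfer.py`):

```python
def _auto_decode_uses_trtllm(num_tokens: int, kv_cache_dtype: str) -> bool:
    # For bf16 ("auto") decode, TRTLLM is only chosen for small batches;
    # larger batches fall back to native FlashInfer kernels.
    return kv_cache_dtype == "auto" and num_tokens <= 256


assert _auto_decode_uses_trtllm(128, "auto") is True   # small batch -> TRTLLM
assert _auto_decode_uses_trtllm(512, "auto") is False  # large batch -> FlashInfer
```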
Yes, exactly. It's been a while since we ran the benchmark to see where the perf flips in favor of Flashinfer native FA2 kernels.
I benchmarked on GB200 using benchmarks/kernels/benchmark_trtllm_decode_attention.py to confirm the num_tokens <= 256 threshold for bf16 decode (kv_cache_dtype="auto").
I found that at batch_size=256, TRTLLM wins across all sequence lengths (1K-128K), confirming the current threshold is safe and the crossover where FlashInfer overtakes TRTLLM depends on sequence length.
For shorter sequences (1K-8K), TRTLLM is faster even up to batch_size=2048. For longer sequences the crossover shifts lower: 896 at 32K, 640 at 64K, and 512-640 at 128K.
The worst-case crossover is 512, so the threshold could potentially be raised to 384-512. Let me know if you'd like me to raise it (e.g., to 384 or 512) in this PR, or keep it at 256 for now and address the change separately.
Here are the benchmark CSVs:
- Full benchmark (all quant configs, batch sizes 1-256)
- Crossover sweep (bf16 only, batch sizes 128-2048)
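If we did want a sequence-length-aware threshold later, a conservative reading of the crossover data above would be something like this (purely a hypothetical sketch, not code in this PR; the code currently uses a flat `num_tokens <= 256`):

```python
def trtllm_decode_batch_threshold(max_seq_len: int) -> int:
    # Hypothetical: pick the TRTLLM/FlashInfer switchover batch size from
    # the GB200 crossovers measured above (896 @ 32K, 640 @ 64K, ~512 @ 128K).
    if max_seq_len <= 8 * 1024:
        return 2048  # TRTLLM won up to batch 2048 for 1K-8K sequences
    if max_seq_len <= 32 * 1024:
        return 896
    if max_seq_len <= 64 * 1024:
        return 640
    return 512  # worst-case crossover observed at 128K
```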
pavanimajety left a comment

LGTM, just one more comment.
```python
def test_can_use_platform_unsupported(_sup, _force):
    cfg = get_config("Llama-3-70B")
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False
```
Could we add more models here and parameterize the test?
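Something along these lines, perhaps — a sketch assuming the `MODEL_CONFIGS`/`get_config()` helpers from the earlier thread (the second mock from the original test is omitted here for brevity):

```python
import pytest
from unittest.mock import patch

from vllm.utils.flashinfer import can_use_trtllm_attention


@pytest.mark.parametrize("model_name", ["Llama-3-70B", "Llama-3-8B", "Qwen2.5-72B"])
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_can_use_platform_unsupported(_sup, model_name):
    cfg = get_config(model_name)  # get_config as sketched earlier
    # On an unsupported platform, every model config must return False.
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False
```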
@ojhaanshika could you please fix your DCO?
Force-pushed from 9000420 to 6920dc1
…line Signed-off-by: Anshika Ojha <anshikao@nvidia.com>
Force-pushed from 6920dc1 to 5eaa442
MatthewBonanni left a comment
LGTM, thanks for doing this!
Signed-off-by: Anshika Ojha <anshikao@nvidia.com> Co-authored-by: Anshika Ojha <anshikao@gb-nvl-059-compute09.nvidia.com>
Purpose
Add test coverage for the TRTLLM gen-full attention pipeline in vLLM. The existing kernel-level tests (`test_flashinfer_trtllm_attention.py`) call the FlashInfer C++ kernels directly, bypassing the vLLM integration layer entirely. This leaves the decision logic, metadata construction, and forward dispatch at 0% coverage.

This PR adds two new test files:
Unit tests for attention decision functions (`tests/kernels/attention/test_use_trtllm_attention.py`) — 22 tests covering all branches of `use_trtllm_attention()`, `can_use_trtllm_attention()`, and `supports_trtllm_attention()` in `vllm/utils/flashinfer.py`. These validate the gatekeeper logic that decides whether to use TRTLLM or fall back to native FlashInfer (force on/off, DCP fallback, head incompatibility, speculative decoding, FP8 query, attention sinks, auto-detection).
Integration tests for the full pipeline (`tests/v1/attention/test_trtllm_attention_integration.py`) — 3 tests that exercise `FlashInferMetadataBuilder.build()` -> `FlashInferImpl.forward()` with TRTLLM kernels on Blackwell, comparing output against an SDPA reference. Covers decode-only, prefill-only, and mixed prefill+decode batches.
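The comparison pattern against SDPA is the standard one (a generic sketch, not the test file's exact code):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch 1, 8 heads, 128 tokens, head_dim 64.
q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="cuda")

# Reference: PyTorch's scaled dot-product attention with a causal mask.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# out = <output reshaped from FlashInferImpl.forward() with TRTLLM kernels>
# torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
```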
Test Plan

Unit tests (run on any platform):
python -m pytest tests/kernels/attention/test_use_trtllm_attention.py -v
Integration tests (requires Blackwell SM100 GPU):
python -m pytest tests/v1/attention/test_trtllm_attention_integration.py -v
Test Result
Unit tests: 22 passed
Integration tests: 3 passed
Coverage improvement for `vllm/utils/flashinfer.py`: 31% -> 47%
Coverage improvement for `vllm/v1/attention/backends/flashinfer.py`: 0% -> 22%