[CI] TRTLLM Gen-full Attn Test Coverage #34986
@@ -0,0 +1,196 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import patch

import pytest
import torch

from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    supports_trtllm_attention,
    use_trtllm_attention,
)

MODEL_CONFIGS = {
    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8),
    "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8),
    "Qwen2.5-0.5B": dict(num_qo_heads=14, num_kv_heads=2),
    "Mistral-7B": dict(num_qo_heads=32, num_kv_heads=8),
    "Gemma-2-9B": dict(num_qo_heads=8, num_kv_heads=4),
    "Falcon-40B": dict(num_qo_heads=128, num_kv_heads=8),
}


def get_config(model: str) -> dict:
    """Return the attention config for a model."""
    return MODEL_CONFIGS[model]


DEFAULT_KWARGS = dict(
    **get_config("Llama-3-70B"),
    num_tokens=128,
    max_seq_len=4096,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_dtype=torch.bfloat16,
    is_prefill=False,
    force_use_trtllm=None,
    has_sinks=False,
    has_spec=False,
)


def _call(**overrides) -> bool:
    kwargs = {**DEFAULT_KWARGS, **overrides}
    return use_trtllm_attention(**kwargs)


@pytest.fixture(autouse=True)
def _clear_supports_cache():
    """Clear functools.cache to ensure each test runs independently."""
    supports_trtllm_attention.cache_clear()


# supports_trtllm_attention


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True)
def test_supports_batch_invariant_disables(_mock):
    assert supports_trtllm_attention() is False


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch(
    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
    return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
def test_supports_sm100_with_artifactory(_art, _cap, _bi):
    assert supports_trtllm_attention() is True


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch(
    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
    return_value=False,
)
def test_supports_non_sm100_platform(_cap, _bi):
    assert supports_trtllm_attention() is False


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch(
    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
    return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
def test_supports_sm100_without_artifactory(_art, _cap, _bi):
    assert supports_trtllm_attention() is False


# can_use_trtllm_attention


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=False)
def test_can_use_force_disabled(_mock):
    cfg = get_config("Llama-3-70B")
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_can_use_compatible_heads(_sup, _force):
    cfg = get_config("Llama-3-70B")
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is True


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_can_use_incompatible_heads(_sup, _force):
    assert can_use_trtllm_attention(40, 6) is False


@pytest.mark.parametrize("model", list(MODEL_CONFIGS.keys()))
@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_can_use_platform_unsupported(_sup, _force, model):
    cfg = get_config(model)
    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False


# use_trtllm_attention


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_force_off(_mock):
    assert _call(force_use_trtllm=False) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_dcp_fallback(_mock):
    assert _call(dcp_world_size=2) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported(_mock):
    assert _call() is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported_force_on_still_false(_mock):
    assert _call(force_use_trtllm=True) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads(_mock):
    assert _call(num_qo_heads=40, num_kv_heads=6) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads_force_on_still_false(_mock):
    assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_spec_decode_enables(_mock):
    assert _call(has_spec=True, is_prefill=False) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
@patch(
    "vllm.utils.flashinfer.current_platform.fp8_dtype",
    return_value=torch.float8_e4m3fn,
)
def test_use_fp8_query_forces_trtllm(_fp8, _sup):
    assert _call(q_dtype=torch.float8_e4m3fn) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_sinks_force_trtllm(_mock):
    assert _call(has_sinks=True) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_prefill_kv_auto(_mock):
    assert _call(is_prefill=True, kv_cache_dtype="auto") is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_prefill_kv_fp8(_mock):
    assert _call(is_prefill=True, kv_cache_dtype="fp8") is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_small_batch(_mock):
    assert _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto") is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_large_batch(_mock):
    assert _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto") is False
```
Comment on lines +184 to +191

Collaborator: Could we take this opportunity to actually measure whether the bounds set here are still true?

Contributor (Author): Sure! The test values (…)

Collaborator: Yes, exactly. It's been a while since we ran the benchmark to see where the perf flips in favor of FlashInfer's native FA2 kernels.

Contributor (Author): I benchmarked on GB200 using …. I found that at …. For shorter sequences (1K-8K), TRTLLM is faster even up to batch_size=2048. For longer sequences the crossover shifts lower: …. The worst-case crossover is …. Here are the benchmark CSVs: …
```python
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_force_on(_mock):
    assert _call(force_use_trtllm=True) is True
```
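As a companion to the crossover discussion above, a small sweep can show where the current `use_trtllm_attention` heuristic flips between TRTLLM and the fallback path as the decode batch grows. This is a minimal sketch, not part of the PR: it assumes only the keyword arguments the tests already pass through `_call()`, and it patches `supports_trtllm_attention` so the sweep also runs on machines where the platform check would return False.

```python
# Sketch: sweep the decode heuristic to see where use_trtllm_attention flips.
# Uses the same keyword arguments the tests above pass via _call(); the patch
# forces platform support so the sweep runs on any machine.
from unittest.mock import patch

import torch

from vllm.utils.flashinfer import use_trtllm_attention


def sweep_decode_heuristic():
    base = dict(
        num_qo_heads=64,
        num_kv_heads=8,
        dcp_world_size=1,
        kv_cache_dtype="auto",
        q_dtype=torch.bfloat16,
        is_prefill=False,
        force_use_trtllm=None,
        has_sinks=False,
        has_spec=False,
    )
    with patch(
        "vllm.utils.flashinfer.supports_trtllm_attention", return_value=True
    ):
        for max_seq_len in (1024, 4096, 16384, 65536):
            prev = None
            for num_tokens in (64, 128, 256, 512, 1024, 2048):
                used = use_trtllm_attention(
                    num_tokens=num_tokens, max_seq_len=max_seq_len, **base
                )
                if prev is not None and used != prev:
                    print(f"max_seq_len={max_seq_len}: flips at num_tokens={num_tokens}")
                prev = used


if __name__ == "__main__":
    sweep_decode_heuristic()
```

Note that this only reports what the heuristic decides today; the actual performance crossover still has to come from a kernel benchmark like the GB200 run mentioned in the thread.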
Collaborator: IMO we should re-think the whole test in terms of model config rather than arbitrary numbers that would and wouldn't use TRTLLM. For example, something like `_call(get_config("Llama-3"))`. That would also help us quickly see which models TRTLLM supports and which it doesn't. We could have a dictionary that maps each model to its individual parameters, like query heads, KV heads, whether it uses sliding-window attention, etc.

Contributor (Author): I've added a `MODEL_CONFIGS` dictionary with `get_config()` as suggested. The default test config now uses `get_config("Llama-3-70B")`. For the incompatible-heads case, I kept bare numbers where `num_qo_heads % num_kv_heads != 0`. Let me know if this is what you were thinking, or which specific models I should add to the dictionary.

Collaborator: This approach sounds good to me. I did a quick internet search and I don't see many models that actually have the scenario where `num_qo_heads % num_kv_heads != 0`. There were some test models but no production models.
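If the dictionary is later extended the way this thread suggests, extra per-model attributes can ride along without disturbing the existing head-count lookups. A rough sketch under that assumption; `has_sliding_window` and `get_attn_heads` are hypothetical names, not part of this PR:

```python
# Hypothetical extension of MODEL_CONFIGS: extra per-model attributes such as
# has_sliding_window are illustrative only and not consumed by the current tests.
MODEL_CONFIGS = {
    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8, has_sliding_window=False),
    "Gemma-2-9B": dict(num_qo_heads=8, num_kv_heads=4, has_sliding_window=True),
}


def get_attn_heads(model: str) -> tuple[int, int]:
    """Return only the head counts, so callers that predate the extra fields
    (e.g. can_use_trtllm_attention) keep working unchanged."""
    cfg = MODEL_CONFIGS[model]
    return cfg["num_qo_heads"], cfg["num_kv_heads"]
```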