196 changes: 196 additions & 0 deletions tests/kernels/attention/test_use_trtllm_attention.py
@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import patch

import pytest
import torch

from vllm.utils.flashinfer import (
    can_use_trtllm_attention,
    supports_trtllm_attention,
    use_trtllm_attention,
)

MODEL_CONFIGS = {
    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8),
    "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8),
    "Qwen2.5-0.5B": dict(num_qo_heads=14, num_kv_heads=2),
    "Mistral-7B": dict(num_qo_heads=32, num_kv_heads=8),
"Gemma-2-9B": dict(num_qo_heads=8, num_kv_heads=4),
"Falcon-40B": dict(num_qo_heads=128, num_kv_heads=8),
}


def get_config(model: str) -> dict:
"""Return the attention config for a model."""
return MODEL_CONFIGS[model]


DEFAULT_KWARGS = dict(
**get_config("Llama-3-70B"),
num_tokens=128,
max_seq_len=4096,
dcp_world_size=1,
kv_cache_dtype="auto",
q_dtype=torch.bfloat16,
is_prefill=False,
force_use_trtllm=None,
has_sinks=False,
has_spec=False,
)


def _call(**overrides) -> bool:
kwargs = {**DEFAULT_KWARGS, **overrides}
return use_trtllm_attention(**kwargs)


@pytest.fixture(autouse=True)
def _clear_supports_cache():
"""Clear functools.cache to ensure each test runs independently."""
supports_trtllm_attention.cache_clear()


# supports_trtllm_attention


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True)
def test_supports_batch_invariant_disables(_mock):
assert supports_trtllm_attention() is False


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
def test_supports_sm100_with_artifactory(_art, _cap, _bi):
assert supports_trtllm_attention() is True


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=False,
)
def test_supports_non_sm100_platform(_cap, _bi):
assert supports_trtllm_attention() is False


@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
@patch(
"vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
def test_supports_sm100_without_artifactory(_art, _cap, _bi):
assert supports_trtllm_attention() is False


# can_use_trtllm_attention


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=False)
def test_can_use_force_disabled(_mock):
cfg = get_config("Llama-3-70B")
assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_can_use_compatible_heads(_sup, _force):
cfg = get_config("Llama-3-70B")
assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is True


@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_can_use_incompatible_heads(_sup, _force):
assert can_use_trtllm_attention(40, 6) is False


@pytest.mark.parametrize("model", list(MODEL_CONFIGS.keys()))
@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_can_use_platform_unsupported(_sup, _force, model):
cfg = get_config(model)
assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False


# use_trtllm_attention


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_force_off(_mock):
assert _call(force_use_trtllm=False) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_dcp_fallback(_mock):
assert _call(dcp_world_size=2) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported(_mock):
assert _call() is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False)
def test_use_platform_unsupported_force_on_still_false(_mock):
assert _call(force_use_trtllm=True) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads(_mock):
assert _call(num_qo_heads=40, num_kv_heads=6) is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_incompatible_heads_force_on_still_false(_mock):
assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False

Comment on lines +149 to +153
Collaborator:
IMO we should re-think the whole test in terms of model config rather than arbitrary numbers that would and wouldn't use trtllm. For example, something like _call(get_config("Llama-3")). That would also help us quickly see which models TRTLLM supports vs. which ones it doesn't. We could have a dictionary that maps each model to its individual parameters, like query heads, kv heads, whether it uses sliding window attention, etc.

Contributor Author:
I've added a MODEL_CONFIGS dictionary with get_config() as suggested. The default test config now uses get_config("Llama-3-70B"). For the incompatible-heads case, I kept bare numbers where num_qo_heads % num_kv_heads != 0. Let me know if this is what you had in mind, or which specific models I should add to the dictionary.
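In case it helps, here is a small self-contained sketch of the divisibility condition those bare numbers are meant to violate. This is just my reading of the gate; the actual check lives in vllm/utils/flashinfer.py and has more conditions:

```python
# Hypothetical illustration, not the real vLLM code: TRTLLM attention is
# assumed to need an integer GQA group size (qo heads evenly divided across
# kv heads), which is why 40/6 is used as the incompatible case.
def heads_compatible(num_qo_heads: int, num_kv_heads: int) -> bool:
    return num_qo_heads % num_kv_heads == 0

assert heads_compatible(64, 8)        # Llama-3-70B: group size 8
assert not heads_compatible(40, 6)    # 40 is not a multiple of 6
```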

Collaborator:
This approach sounds good to me. I did a quick internet search and didn't find many models that actually hit the num_qo_heads % num_kv_heads != 0 scenario; there were some test models, but no production ones.


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_spec_decode_enables(_mock):
assert _call(has_spec=True, is_prefill=False) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
@patch(
"vllm.utils.flashinfer.current_platform.fp8_dtype",
return_value=torch.float8_e4m3fn,
)
def test_use_fp8_query_forces_trtllm(_fp8, _sup):
assert _call(q_dtype=torch.float8_e4m3fn) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_sinks_force_trtllm(_mock):
assert _call(has_sinks=True) is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_prefill_kv_auto(_mock):
assert _call(is_prefill=True, kv_cache_dtype="auto") is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_prefill_kv_fp8(_mock):
assert _call(is_prefill=True, kv_cache_dtype="fp8") is False


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_small_batch(_mock):
assert _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto") is True


@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_auto_decode_large_batch(_mock):
assert _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto") is False
Comment on lines +184 to +191
Collaborator:
Could we take this opportunity to actually measure whether these bounds are still accurate?

Contributor Author:
Sure! The test values (128 and 512) were chosen to land on either side of the num_tokens <= 256 threshold on line 386 of flashinfer.py. They verify the branching logic works correctly. Did you want me to benchmark to confirm 256 is still the optimal crossover point between TRTLLM and FlashInfer native decode?
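For context, this is the shape of the branch those two values straddle. It's a simplified sketch of my reading of the check around line 386 of vllm/utils/flashinfer.py, not the verbatim code:

```python
# Simplified sketch only: the real gate has additional conditions
# (platform support, head compatibility, sinks, spec decode, etc.).
def auto_decode_uses_trtllm(num_tokens: int, kv_cache_dtype: str) -> bool:
    # bf16 ("auto") KV cache decode only routes to TRTLLM for small batches.
    return kv_cache_dtype == "auto" and num_tokens <= 256

assert auto_decode_uses_trtllm(128, "auto") is True   # below the threshold
assert auto_decode_uses_trtllm(512, "auto") is False  # above the threshold
```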

Collaborator:
Yes, exactly. It's been a while since we ran the benchmark to see where the perf flips in favor of Flashinfer native FA2 kernels.

Contributor Author:
I benchmarked on GB200 using benchmarks/kernels/benchmark_trtllm_decode_attention.py to confirm the num_tokens <= 256 threshold for bf16 decode (kv_cache_dtype="auto").

I found that at batch_size=256, TRTLLM wins across all sequence lengths (1K-128K), confirming the current threshold is safe. The crossover point where FlashInfer overtakes TRTLLM depends on sequence length.

For shorter sequences (1K-8K), TRTLLM is faster even up to batch_size=2048. For longer sequences the crossover shifts lower: 896 at 32K, 640 at 64K, and 512-640 at 128K.

The worst-case crossover is 512, so the threshold could potentially be raised to 384-512. Let me know if you'd like me to raise it in this PR (e.g., to 384 or 512), or if we should keep it at 256 for now and address it separately.
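To make the proposal concrete, here is a rough sketch of what the bump would amount to, with the measured crossovers as comments. The constant name and its surrounding conditions are illustrative; the actual gate lives in vllm/utils/flashinfer.py:

```python
# Measured on GB200, bf16 ("auto") KV cache decode: batch size at which
# FlashInfer native overtakes TRTLLM.
#   1K-8K seq len : not reached up to batch_size=2048
#   32K seq len   : ~896
#   64K seq len   : ~640
#   128K seq len  : ~512-640
# A conservative bump keeps a flat threshold below the worst-case crossover:
MAX_TRTLLM_DECODE_NUM_TOKENS = 384  # candidate value; current threshold is 256
```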

Here are the benchmark CSVs:
Full benchmark (all quant configs, batch sizes 1-256)
Crossover sweep (bf16 only, batch sizes 128-2048)



@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
def test_use_force_on(_mock):
assert _call(force_use_trtllm=True) is True