diff --git a/tests/kernels/attention/test_use_trtllm_attention.py b/tests/kernels/attention/test_use_trtllm_attention.py
new file mode 100644
index 000000000000..e24ad1018638
--- /dev/null
+++ b/tests/kernels/attention/test_use_trtllm_attention.py
@@ -0,0 +1,196 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from vllm.utils.flashinfer import (
+    can_use_trtllm_attention,
+    supports_trtllm_attention,
+    use_trtllm_attention,
+)
+
+MODEL_CONFIGS = {
+    "Llama-3-70B": dict(num_qo_heads=64, num_kv_heads=8),
+    "Llama-3-8B": dict(num_qo_heads=32, num_kv_heads=8),
+    "Qwen2.5-0.5B": dict(num_qo_heads=14, num_kv_heads=2),
+    "Mistral-7B": dict(num_qo_heads=32, num_kv_heads=8),
+    "Gemma-2-9B": dict(num_qo_heads=8, num_kv_heads=4),
+    "Falcon-40B": dict(num_qo_heads=128, num_kv_heads=8),
+}
+
+
+def get_config(model: str) -> dict:
+    """Return the attention config for a model."""
+    return MODEL_CONFIGS[model]
+
+
+DEFAULT_KWARGS = dict(
+    **get_config("Llama-3-70B"),
+    num_tokens=128,
+    max_seq_len=4096,
+    dcp_world_size=1,
+    kv_cache_dtype="auto",
+    q_dtype=torch.bfloat16,
+    is_prefill=False,
+    force_use_trtllm=None,
+    has_sinks=False,
+    has_spec=False,
+)
+
+
+def _call(**overrides) -> bool:
+    kwargs = {**DEFAULT_KWARGS, **overrides}
+    return use_trtllm_attention(**kwargs)
+
+
+@pytest.fixture(autouse=True)
+def _clear_supports_cache():
+    """Clear functools.cache to ensure each test runs independently."""
+    supports_trtllm_attention.cache_clear()
+
+
+# supports_trtllm_attention
+
+
+@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True)
+def test_supports_batch_invariant_disables(_mock):
+    assert supports_trtllm_attention() is False
+
+
+@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch(
+    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
+    return_value=True,
+)
+@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
+def test_supports_sm100_with_artifactory(_art, _cap, _bi):
+    assert supports_trtllm_attention() is True
+
+
+@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch(
+    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
+    return_value=False,
+)
+def test_supports_non_sm100_platform(_cap, _bi):
+    assert supports_trtllm_attention() is False
+
+
+@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch(
+    "vllm.utils.flashinfer.current_platform.is_device_capability_family",
+    return_value=True,
+)
+@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
def test_supports_sm100_without_artifactory(_art, _cap, _bi):
+    assert supports_trtllm_attention() is False
+
+
+# can_use_trtllm_attention
+
+
+@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=False)
+def test_can_use_force_disabled(_mock):
+    cfg = get_config("Llama-3-70B")
+    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False
+
+
+@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
+@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True)
+def test_can_use_compatible_heads(_sup, _force):
+    cfg = get_config("Llama-3-70B")
+    assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is True
+
+
+@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None)
+@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_can_use_incompatible_heads(_sup, _force): + assert can_use_trtllm_attention(40, 6) is False + + +@pytest.mark.parametrize("model", list(MODEL_CONFIGS.keys())) +@patch("vllm.utils.flashinfer.force_use_trtllm_attention", return_value=None) +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False) +def test_can_use_platform_unsupported(_sup, _force, model): + cfg = get_config(model) + assert can_use_trtllm_attention(cfg["num_qo_heads"], cfg["num_kv_heads"]) is False + + +# use_trtllm_attention + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_force_off(_mock): + assert _call(force_use_trtllm=False) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_dcp_fallback(_mock): + assert _call(dcp_world_size=2) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False) +def test_use_platform_unsupported(_mock): + assert _call() is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=False) +def test_use_platform_unsupported_force_on_still_false(_mock): + assert _call(force_use_trtllm=True) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_incompatible_heads(_mock): + assert _call(num_qo_heads=40, num_kv_heads=6) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_incompatible_heads_force_on_still_false(_mock): + assert _call(num_qo_heads=40, num_kv_heads=6, force_use_trtllm=True) is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_spec_decode_enables(_mock): + assert _call(has_spec=True, is_prefill=False) is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +@patch( + "vllm.utils.flashinfer.current_platform.fp8_dtype", + return_value=torch.float8_e4m3fn, +) +def test_use_fp8_query_forces_trtllm(_fp8, _sup): + assert _call(q_dtype=torch.float8_e4m3fn) is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_sinks_force_trtllm(_mock): + assert _call(has_sinks=True) is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_prefill_kv_auto(_mock): + assert _call(is_prefill=True, kv_cache_dtype="auto") is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_prefill_kv_fp8(_mock): + assert _call(is_prefill=True, kv_cache_dtype="fp8") is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_decode_small_batch(_mock): + assert _call(is_prefill=False, num_tokens=128, kv_cache_dtype="auto") is True + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_auto_decode_large_batch(_mock): + assert _call(is_prefill=False, num_tokens=512, kv_cache_dtype="auto") is False + + +@patch("vllm.utils.flashinfer.supports_trtllm_attention", return_value=True) +def test_use_force_on(_mock): + assert _call(force_use_trtllm=True) is True diff --git a/tests/v1/attention/test_trtllm_attention_integration.py b/tests/v1/attention/test_trtllm_attention_integration.py new file mode 100644 index 000000000000..50a2c8625313 --- /dev/null +++ b/tests/v1/attention/test_trtllm_attention_integration.py @@ -0,0 +1,360 @@ +# 
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Integration tests for TRTLLM gen-full attention through FlashInfer."""
+
+import unittest.mock
+from functools import partial
+
+import pytest
+import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+from tests.v1.attention.utils import (
+    BatchSpec,
+    create_common_attn_metadata,
+    create_vllm_config,
+)
+from vllm.config import set_current_vllm_config
+from vllm.platforms import current_platform
+from vllm.utils.math_utils import cdiv
+from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.attention.backends.utils import (
+    PerLayerParameters,
+    get_kv_cache_layout,
+    set_kv_cache_layout,
+)
+from vllm.v1.kv_cache_interface import FullAttentionSpec
+
+if not current_platform.is_device_capability_family(100):
+    pytest.skip(
+        "TRTLLM integration tests require NVIDIA Blackwell (SM100).",
+        allow_module_level=True,
+    )
+
+from vllm.v1.attention.backends.flashinfer import (  # noqa: E402
+    FlashInferImpl,
+    FlashInferMetadataBuilder,
+    TRTLLMDecode,
+    TRTLLMPrefill,
+)
+
+
+class MockAttentionLayer:
+    """Minimal mock of an attention layer for testing."""
+
+    def __init__(self, device: torch.device):
+        self._q_scale = torch.tensor(1.0, device=device)
+        self._k_scale = torch.tensor(1.0, device=device)
+        self._v_scale = torch.tensor(1.0, device=device)
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+        self._o_scale_float = None
+
+
+MODEL = "Qwen/Qwen2.5-0.5B"
+BLOCK_SIZE = 16
+NUM_GPU_BLOCKS = 8192
+
+BATCH_SPECS = {
+    "decode_only": BatchSpec(
+        seq_lens=[128, 256, 512],
+        query_lens=[1, 1, 1],
+    ),
+    "prefill_only": BatchSpec(
+        seq_lens=[64, 128, 256],
+        query_lens=[16, 32, 16],
+    ),
+    "mixed": BatchSpec(
+        seq_lens=[128, 256, 512, 128],
+        query_lens=[1, 1, 8, 16],
+    ),
+}
+
+
+def _mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
+    head_size = vllm_config.model_config.get_head_size()
+    return {
+        name: PerLayerParameters(
+            window_left=-1,
+            logits_soft_cap=0.0,
+            sm_scale=1.0 / (head_size**0.5),
+        )
+        for name in layer_names
+    }
+
+
+def _create_hnd_kv_cache(
+    k_contexts,
+    v_contexts,
+    block_size,
+    num_kv_heads,
+    head_size,
+    dtype,
+    device,
+    num_blocks,
+    common_attn_metadata,
+):
+    """Create and populate a KV cache with HND-compatible strides.
+
+    The returned tensor has logical shape
+    (num_blocks, 2, block_size, num_kv_heads, head_size) but is physically
+    laid out as (num_blocks, 2, num_kv_heads, block_size, head_size) so that
+    ``kv_cache.permute(0, 1, 3, 2, 4)`` yields a contiguous HND view.
+    """
+    seq_lens = common_attn_metadata.seq_lens.cpu()
+    query_lens = (
+        common_attn_metadata.query_start_loc_cpu[1:]
+        - common_attn_metadata.query_start_loc_cpu[:-1]
+    )
+    block_table = common_attn_metadata.block_table_tensor
+    slot_mapping = common_attn_metadata.slot_mapping
+    batch_size = len(k_contexts)
+
+    # Build cache in (2, num_blocks, block_size, num_kv_heads, head_size)
+    # then convert to HND format (same approach as test_attention_backends.py).
+    kv_cache_raw = torch.zeros(
+        2,
+        num_blocks,
+        block_size,
+        num_kv_heads,
+        head_size,
+        dtype=dtype,
+        device=device,
+    )
+    kv_cache_flat = kv_cache_raw.view(2, -1, num_kv_heads, head_size)
+
+    start_block_idx = 1
+    for i in range(batch_size):
+        k_ctx, v_ctx = k_contexts[i], v_contexts[i]
+        start = start_block_idx * block_size
+        end = start + k_ctx.shape[0]
+        kv_cache_flat[0, start:end] = k_ctx
+        kv_cache_flat[1, start:end] = v_ctx
+        start_block_idx += cdiv(int(seq_lens[i]), block_size)
+
+    blocks_end = start_block_idx
+
+    # Randomly permute blocks (starting from block 1; block 0 is null).
+    perm = torch.randperm(blocks_end - 1) + 1
+    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
+    inv_perm[1:] = torch.argsort(perm) + 1
+    kv_cache_raw[:, 1:blocks_end] = kv_cache_raw[:, perm]
+
+    # Build block table.
+    start_block_idx = 1
+    for i in range(batch_size):
+        n_blocks = cdiv(int(seq_lens[i]), block_size)
+        block_table[i, :n_blocks] = inv_perm[
+            start_block_idx : start_block_idx + n_blocks
+        ]
+        start_block_idx += n_blocks
+
+    # Build slot mapping that is consistent with the block table.
+    for i in range(batch_size):
+        ctx_len = int(seq_lens[i]) - int(query_lens[i])
+        token_offsets = torch.arange(int(query_lens[i])) + ctx_len
+        block_indices = token_offsets // block_size
+        intra_block_offsets = token_offsets % block_size
+        start = common_attn_metadata.query_start_loc_cpu[i]
+        end = common_attn_metadata.query_start_loc_cpu[i + 1]
+        slot_mapping[start:end] = block_table[
+            i, block_indices
+        ] * block_size + intra_block_offsets.to(device)
+
+    # Transpose to FlashInfer logical shape then make HND-strided.
+    kv_cache = kv_cache_raw.transpose(0, 1)
+    kv_cache = kv_cache.transpose(2, 3).contiguous().transpose(2, 3)
+    return kv_cache
+
+
+def _run_trtllm_integration(batch_spec):
+    """Run TRTLLM attention through the full FlashInfer pipeline
+    and compare against an SDPA reference."""
+    set_random_seed(42)
+    device = torch.device("cuda:0")
+
+    vllm_config = create_vllm_config(
+        model_name=MODEL,
+        max_model_len=max(batch_spec.seq_lens),
+        block_size=BLOCK_SIZE,
+        num_gpu_blocks=NUM_GPU_BLOCKS,
+    )
+    vllm_config.attention_config.use_trtllm_attention = True
+
+    num_q_heads = vllm_config.model_config.get_num_attention_heads(
+        vllm_config.parallel_config
+    )
+    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+        vllm_config.parallel_config
+    )
+    head_size = vllm_config.model_config.get_head_size()
+    dtype = vllm_config.model_config.dtype
+    scale = 1.0 / (head_size**0.5)
+
+    # 1. Generate data and compute SDPA reference
+    all_q, all_k, all_v = [], [], []
+    all_sdpa_out = []
+    k_contexts, v_contexts = [], []
+
+    for i in range(batch_spec.batch_size):
+        s_len = batch_spec.seq_lens[i]
+        q_len = batch_spec.query_lens[i]
+        ctx_len = s_len - q_len
+
+        q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device)
+        k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
+        v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
+
+        # SDPA reference (N=1, H, L, D)
+        q_sdpa = q.unsqueeze(0).transpose(1, 2)
+        k_sdpa = k_full.unsqueeze(0).transpose(1, 2)
+        v_sdpa = v_full.unsqueeze(0).transpose(1, 2)
+
+        if num_q_heads != num_kv_heads:
+            repeats = num_q_heads // num_kv_heads
+            k_sdpa = k_sdpa.repeat_interleave(repeats, dim=1)
+            v_sdpa = v_sdpa.repeat_interleave(repeats, dim=1)
+
+        def causal_mask_mod(b, h, q_idx, kv_idx, *, context_len):
+            return (q_idx + context_len) >= kv_idx
+
+        mask_fn = partial(causal_mask_mod, context_len=ctx_len)
+        block_mask = create_block_mask(
+            mask_fn, B=None, H=None, Q_LEN=q_len, KV_LEN=s_len, device=device
+        )
+        sdpa_out = flex_attention(
+            q_sdpa,
+            k_sdpa,
+            v_sdpa,
+            block_mask=block_mask,
+            scale=scale,
+            enable_gqa=True,
+        )
+        all_sdpa_out.append(sdpa_out.transpose(1, 2).squeeze(0))
+
+        all_q.append(q)
+        all_k.append(k_full[ctx_len:])
+        all_v.append(v_full[ctx_len:])
+        k_contexts.append(k_full[:ctx_len])
+        v_contexts.append(v_full[:ctx_len])
+
+    query_vllm = torch.cat(all_q, dim=0)
+    key_vllm = torch.cat(all_k, dim=0)
+    value_vllm = torch.cat(all_v, dim=0)
+    sdpa_output = torch.cat(all_sdpa_out, dim=0)
+
+    common_attn_metadata = create_common_attn_metadata(batch_spec, BLOCK_SIZE, device)
+
+    # 2. Create HND KV cache
+    kv_cache = _create_hnd_kv_cache(
+        k_contexts,
+        v_contexts,
+        BLOCK_SIZE,
+        num_kv_heads,
+        head_size,
+        dtype,
+        device,
+        NUM_GPU_BLOCKS,
+        common_attn_metadata,
+    )
+
+    # 3. Run through FlashInfer with TRTLLM enabled
+    set_kv_cache_layout("HND")
+    get_kv_cache_layout.cache_clear()
+
+    try:
+        kv_cache_spec = FullAttentionSpec(
+            block_size=BLOCK_SIZE,
+            num_kv_heads=num_kv_heads,
+            head_size=head_size,
+            dtype=dtype,
+        )
+        layer_names = ["test_layer_0"]
+
+        with (
+            set_current_vllm_config(vllm_config),
+            unittest.mock.patch(
+                "vllm.utils.flashinfer.supports_trtllm_attention",
+                return_value=True,
+            ),
+            unittest.mock.patch(
+                "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
+                _mock_get_per_layer_parameters,
+            ),
+        ):
+            builder = FlashInferMetadataBuilder(
+                kv_cache_spec, layer_names, vllm_config, device
+            )
+            attn_metadata = builder.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+
+            # Verify the correct TRTLLM metadata types were produced.
+            has_prefills = any(ql > 1 for ql in batch_spec.query_lens)
+            has_decodes = any(ql == 1 for ql in batch_spec.query_lens)
+
+            if has_prefills:
+                assert isinstance(attn_metadata.prefill, TRTLLMPrefill), (
+                    f"Expected TRTLLMPrefill, got {type(attn_metadata.prefill)}"
+                )
+            if has_decodes:
+                assert isinstance(attn_metadata.decode, TRTLLMDecode), (
+                    f"Expected TRTLLMDecode, got {type(attn_metadata.decode)}"
+                )
+
+            impl = FlashInferImpl(
+                num_heads=num_q_heads,
+                head_size=head_size,
+                scale=scale,
+                num_kv_heads=num_kv_heads,
+                alibi_slopes=None,
+                sliding_window=None,
+                kv_cache_dtype="auto",
+            )
+
+            mock_layer = MockAttentionLayer(device)
+            output = torch.empty_like(query_vllm)
+
+            impl.do_kv_cache_update(
+                mock_layer,
+                key_vllm,
+                value_vllm,
+                kv_cache,
+                attn_metadata.slot_mapping,
+            )
+
+            output = impl.forward(
+                mock_layer,
+                query_vllm,
+                key_vllm,
+                value_vllm,
+                kv_cache,
+                attn_metadata,
+                output=output,
+            )
+
+        # 4. Compare against SDPA reference
+        torch.testing.assert_close(
+            output,
+            sdpa_output,
+            atol=1e-2,
+            rtol=1e-2,
+        )
+
+    finally:
+        set_kv_cache_layout(None)
+        get_kv_cache_layout.cache_clear()
+
+
+@pytest.mark.parametrize(
+    "batch_spec_name",
+    list(BATCH_SPECS.keys()),
+)
+@torch.inference_mode()
+def test_trtllm_gen_full_attention_integration(batch_spec_name: str):
+    """Test TRTLLM gen-full attention through the full FlashInfer
+    MetadataBuilder.build() -> FlashInferImpl.forward() pipeline,
+    with real TRTLLM kernels on Blackwell."""
+    _run_trtllm_integration(BATCH_SPECS[batch_spec_name])