diff --git a/vllm_spyre_next/examples/Offline_demo.py b/vllm_spyre_next/examples/Offline_demo.py new file mode 100644 index 000000000..fadea332a --- /dev/null +++ b/vllm_spyre_next/examples/Offline_demo.py @@ -0,0 +1,64 @@ +### TEST 1 - Disable prefix caching + +from vllm import LLM, SamplingParams +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.config import AttentionConfig + + +def print_outputs(outputs, engine): + print("-" * 50) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print("-" * 50) + for m in engine.llm_engine.get_metrics(): + if "cache" in m.name: + print(m.name, m.value) + + +def main(): + # Configuration + # MODEL = "ibm-granite/granite-3.0-8b-base" # Tiny model + MODEL = "ibm-granite/granite-3.3-8b-instruct" # Instruct model + # MODEL = "ibm-granite/granite-4.0-tiny-preview" # Granite 3 + # MODEL = "ibm-granite/granite-4.0-h-small" # Granite 4 + # MODEL = "ibm-granite/granite-4.0-h-tiny" # Granite 4 + # MODEL = "facebook/opt-125m" # Small model + + # Sampling parameter for the inference process + sampling_params = SamplingParams( + max_tokens=5, # Maximum number of tokens to produce + ) + + # Prompts to use for inference + prompts = [ + "What are IBMs main businesses?", + ] + + engine = LLM( + model=MODEL, # Model to use for inference. + # By increasing utilization, you can provide more KV cache space. + gpu_memory_utilization=0.9, + # Flag determining whether prefix caching is enabled or disabled. + enable_prefix_caching=True, + # # Flag determining whether eager mode or torch.compile should be used. + # enforce_eager=True, + # # Datatype of the mamba cache (if any). + # mamba_ssm_cache_dtype="float32", + # # Datatype of the model. 
+ # dtype="float32", + # # Maximum number of tokens for a prefill before being chunked + # max_num_batched_tokens=8192, + # # complicates logic with mamba + # disable_cascade_attn=True, + disable_log_stats=False, ## stats + attention_config=AttentionConfig(backend=AttentionBackendEnum.CUSTOM), + ) + + # Generate response for prompt 0 + outputs = engine.generate(prompts[0], sampling_params) + print_outputs(outputs, engine) + + +if __name__ == "__main__": + main() diff --git a/vllm_spyre_next/tests/attention_test_utils.py b/vllm_spyre_next/tests/attention_test_utils.py new file mode 100644 index 000000000..207712ad7 --- /dev/null +++ b/vllm_spyre_next/tests/attention_test_utils.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility functions for attention-related v1 tests.""" + +from dataclasses import dataclass + +import pytest +import torch + +from vllm.config import ( + CacheConfig, + CompilationConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.config.model import ModelDType +from vllm.v1.attention.backend import ( + AttentionImpl, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.kv_cache_interface import FullAttentionSpec + + +@dataclass +class BatchSpec: + """Specification for a batch configuration (workload shape only).""" + + seq_lens: list[int] + query_lens: list[int] + + name: str = "unnamed" + + @property + def batch_size(self): + return len(self.seq_lens) + + def __post_init__(self): + assert len(self.seq_lens) == len(self.query_lens) + + def compute_num_tokens(self): + return sum(self.query_lens) + + +def create_common_attn_metadata( + batch_spec: BatchSpec, + block_size: int, + device: torch.device, + max_block_idx: int = 1000, + arrange_block_indices: bool = False, +) -> CommonAttentionMetadata: + """Create 
CommonAttentionMetadata from a BatchSpec and ModelParams.""" + # Create query start locations + query_start_loc = torch.zeros(batch_spec.batch_size + 1, dtype=torch.int32, device=device) + query_start_loc[1:] = torch.tensor( + batch_spec.query_lens, dtype=torch.int32, device=device + ).cumsum(0) + query_start_loc_cpu = query_start_loc.cpu() + num_tokens = batch_spec.compute_num_tokens() + + # Create sequence lengths + seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device) + seq_lens_cpu = seq_lens.cpu() + max_seq_len = int(seq_lens_cpu.max()) + + # Create computed tokens (context length for each sequence) + context_lens = [ + batch_spec.seq_lens[i] - batch_spec.query_lens[i] for i in range(batch_spec.batch_size) + ] + num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) + + # Create block table and slot mapping + max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size + if arrange_block_indices: + num_blocks = batch_spec.batch_size * max_blocks + block_table_tensor = torch.arange(num_blocks, dtype=torch.int32, device=device).view( + batch_spec.batch_size, max_blocks + ) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device).view(num_tokens) + else: + block_table_tensor = torch.randint( + 0, + max_block_idx, + (batch_spec.batch_size, max_blocks), + dtype=torch.int32, + device=device, + ) + slot_mapping = torch.randint( + 0, max_block_idx, (num_tokens,), dtype=torch.int64, device=device + ) + + # Calculate max query length + max_query_len = max(batch_spec.query_lens) + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=batch_spec.batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + causal=True, + ) 
+ + +def try_get_attention_backend( + backend: AttentionBackendEnum, +) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: + """Try to get the attention backend class, skipping test if not found.""" + try: + backend_class = backend.get_class() + return backend_class.get_builder_cls(), backend_class.get_impl_cls() + except ImportError as e: + pytest.skip(f"{backend.name} not available: {e}") + raise AssertionError("unreachable") from None + + +def try_backend_includes_kv_cache_update( + backend: AttentionBackendEnum, +) -> bool: + """Return whether the backend's forward pass already includes the KV cache update, skipping test if the backend is not found.""" + try: + backend_class = backend.get_class() + return backend_class.forward_includes_kv_cache_update + except ImportError as e: + pytest.skip(f"{backend.name} not available: {e}") + raise AssertionError("unreachable") from None + + +def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec: + """Create a FullAttentionSpec from the given VllmConfig.""" + return FullAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config), + head_size=vllm_config.model_config.get_head_size(), + dtype=vllm_config.model_config.dtype, + sliding_window=vllm_config.model_config.get_sliding_window(), + ) + + +def create_vllm_config( + model_name: str = "meta-llama/Meta-Llama-3-8B", + tensor_parallel_size: int = 1, + max_model_len: int = 1024, + dtype: ModelDType | torch.dtype = "auto", + num_gpu_blocks: int = 1000, + block_size: int = 16, + max_num_seqs: int = 256, + max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, + add_mock_model_methods: bool = True, + hf_config_override: dict | None = None, +) -> VllmConfig: + """Create a VllmConfig for testing with reasonable defaults.""" + + model_config = ModelConfig( + model=model_name, + tokenizer=model_name, + trust_remote_code=False, + dtype=dtype, + seed=0, + max_model_len=max_model_len, + ) + + 
cache_config = CacheConfig( + block_size=block_size, + cache_dtype="auto", + swap_space=0, + ) + # Set cache blocks for testing + # (these may be set during initialization normally) + cache_config.num_gpu_blocks = num_gpu_blocks + cache_config.num_cpu_blocks = 0 + + parallel_config = ParallelConfig( + tensor_parallel_size=tensor_parallel_size, + ) + + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ) + + device_config = DeviceConfig() + load_config = LoadConfig() + compilation_config = CompilationConfig() + + if add_mock_model_methods: + # Add mock methods to satisfy backends that need them + # This is a workaround because tests don't build full, real models, + # but some backends expect to query the model for layer-specific + # parameters + import types + + model_config.get_num_layers = types.MethodType(lambda self: 1, model_config) + model_config.get_sliding_window_for_layer = types.MethodType( + lambda self, i: None, model_config + ) + model_config.get_logits_soft_cap_for_layer = types.MethodType( + lambda self, i: 0.0, model_config + ) + model_config.get_sm_scale_for_layer = types.MethodType( + lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config + ) + + if hf_config_override: + model_config.hf_config.update(hf_config_override) + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + compilation_config=compilation_config, + ) + + +def create_dummy_kv_cache( + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int = 100, +) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = 
torch.randn( + num_blocks, + 2, # K and V + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device=device, + ) + return kv_cache + + +@dataclass +class BackendConfig: + name: str + attention_config: dict + comp_config: dict + specific_gpu_arch: tuple | None = None + + +# Define all backend configurations of full cudagraph to be tested +full_cg_backend_configs = { + # FA3 on Hopper + "FA3": BackendConfig( + name="FA3", + attention_config={ + "backend": "FLASH_ATTN", + "flash_attn_version": 3, + "flash_attn_max_num_splits_for_cuda_graph": 16, + }, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0), + ), + # FlashMLA on Hopper + "FlashMLA": BackendConfig( + name="FlashMLA", + attention_config={"backend": "FLASHMLA"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0), + ), + # Cutlass MLA on Blackwell + "CutlassMLA": BackendConfig( + name="CutlassMLA", + attention_config={"backend": "CUTLASS_MLA"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(10, 0), + ), + # FlashInfer MLA on Blackwell + "FlashInferMLA": BackendConfig( + name="FlashInferMLA", + attention_config={"backend": "FLASHINFER_MLA"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(10, 0), + ), + # FlashAttention MLA on Hopper + "FlashAttentionMLA": BackendConfig( + name="FlashAttentionMLA", + attention_config={ + "backend": "FLASH_ATTN_MLA", + "flash_attn_max_num_splits_for_cuda_graph": 16, + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0), + ), + # FA2 + "FA2": BackendConfig( + name="FA2", + attention_config={ + "backend": "FLASH_ATTN", + "flash_attn_version": 2, + "flash_attn_max_num_splits_for_cuda_graph": 16, + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), + # Triton Attention + "TritonAttn": BackendConfig( + name="TritonAttn", + attention_config={"backend": "TRITON_ATTN"}, + comp_config={ + 
"cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), + # FlashInfer + "FlashInfer": BackendConfig( + name="FlashInfer", + attention_config={"backend": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), + "RocmAttn": BackendConfig( + name="RocmAttn", + attention_config={ + "backend": "ROCM_ATTN", + "use_prefill_decode_attention": True, + }, + comp_config={ + "cudagraph_mode": "FULL", + }, + ), +} diff --git a/vllm_spyre_next/tests/conftest.py b/vllm_spyre_next/tests/conftest.py index 38916b016..3f34f1c24 100644 --- a/vllm_spyre_next/tests/conftest.py +++ b/vllm_spyre_next/tests/conftest.py @@ -289,6 +289,9 @@ def pytest_collection_modifyitems(config, items): item._nodeid = f"{vllm_prefix}::{test_part}" else: item._nodeid = vllm_prefix + else: + # Add spyre mark to our own tests + item.add_marker(pytest.mark.spyre) if marked_count > 0: _log(f"[vllm-upstream] Marked {marked_count} tests as 'upstream'") diff --git a/vllm_spyre_next/tests/test_attention_correctness_stripped.py b/vllm_spyre_next/tests/test_attention_correctness_stripped.py new file mode 100644 index 000000000..917fe1e7e --- /dev/null +++ b/vllm_spyre_next/tests/test_attention_correctness_stripped.py @@ -0,0 +1,565 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 attention backends without GPUModelRunner dependency.""" + +from functools import partial + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +import sys +import os + +sys.path.append( + os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir, os.pardir) +) + +from attention_test_utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + try_backend_includes_kv_cache_update, + try_get_attention_backend, +) +from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import ( + 
STR_DTYPE_TO_TORCH_DTYPE, + is_torch_equal_or_newer, + set_random_seed, +) +from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + AttentionBackendEnum.CUSTOM, +] + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": BatchSpec( + seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1], + ), + "medium_prefill": BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7] + ), + "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "mixed_large": BatchSpec( + seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32] + ), + "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), + # encoder-only + "small_encoder_prefill": BatchSpec(seq_lens=[32, 64, 128, 256], query_lens=[32, 64, 128, 256]), + "medium_encoder_prefill": BatchSpec( + 
seq_lens=[256, 512, 1024, 2048], query_lens=[256, 512, 1024, 2048] + ), +} + + +def create_and_prepopulate_kv_cache( + k_contexts: list[torch.Tensor], + v_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True, +) -> torch.Tensor: + """Create and prepopulate a KV cache with context data. + + Args: + k_contexts: List of key context tensors for each sequence + v_contexts: List of value context tensors for each sequence + block_size: Size of each block + num_kv_heads: Number of KV heads + head_size: Size of each head + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + common_attn_metadata: Metadata whose block table and slot mapping + are populated in place + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + The prepopulated KV cache tensor; the metadata's block table and + slot mapping are updated in place. + """ + batch_size = len(k_contexts) + seq_lens = common_attn_metadata.seq_lens.cpu() + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] - common_attn_metadata.query_start_loc_cpu[:-1] + ) + context_lens = seq_lens - query_lens + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create KV cache + kv_cache = torch.zeros( + 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device + ) + kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + k_context, v_context = k_contexts[i], v_contexts[i] + start = start_block_idx * block_size + end = start + k_context.shape[0] + kv_cache_flat[0, start:end, ...] = k_context + kv_cache_flat[1, start:end, ...] 
= v_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + # Random permutation starting from block 1 + perm = torch.randperm(blocks_end - 1) + 1 + else: + # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 + kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] + + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, block_indices + ] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + # Add float versions for flashinfer + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + +def run_attention_backend( + 
backend: AttentionBackendEnum, + kv_cache_spec: FullAttentionSpec, + layer_names: list[str], + vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: AttentionType = AttentionType.DECODER, + sliding_window: int | None = None, +) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + # Handle special case for FLEX_ATTENTION_SLOW + actual_backend = backend + + use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") + if backend == "FLEX_ATTENTION_SLOW": + actual_backend = AttentionBackendEnum.FLEX_ATTENTION + use_direct_block_mask = False + + builder_cls, impl_cls = try_get_attention_backend(actual_backend) + + # Mock flashinfer's get_per_layer_parameters if needed + if actual_backend == AttentionBackendEnum.FLASHINFER: + import unittest.mock + + from vllm.v1.attention.backends.utils import PerLayerParameters + + def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): + # Return mock parameters for a single layer + head_size = vllm_config.model_config.get_head_size() + return { + layer_name: PerLayerParameters( + window_left=-1, # No sliding window + logits_soft_cap=0.0, # No soft cap + sm_scale=1.0 / (head_size**0.5), # Standard scale + ) + for layer_name in layer_names + } + + with unittest.mock.patch( + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + mock_get_per_layer_parameters, + ): + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + if actual_backend == AttentionBackendEnum.FLEX_ATTENTION: + builder.direct_build = use_direct_block_mask + attn_metadata = builder.build( + common_prefix_len=0, + 
common_attn_metadata=common_attn_metadata, + ) + + # Instantiate implementation + num_heads = vllm_config.model_config.get_num_attention_heads(vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=sliding_window, + attn_type=attn_type, + kv_cache_dtype="auto", + ) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + output = torch.empty_like(query) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. + if not try_backend_includes_kv_cache_update(actual_backend): + impl.do_kv_cache_update(mock_layer, key, value, kv_cache, attn_metadata.slot_mapping) + output = impl.forward(mock_layer, query, key, value, kv_cache, attn_metadata, output=output) + + return output + + +def _test_backend_correctness( + batch_spec: BatchSpec, + model: str, + backend_to_test: list[AttentionBackendEnum | str], + mask_mod, + *, + attn_type: AttentionType = AttentionType.DECODER, + block_size: int = 16, + atol: float = 1e-2, + rtol: float = 1e-2, + tensor_parallel_size: int = 1, +): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. 
Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + + Note: When tensor_parallel_size > 1, we simulate the head partitioning + by overriding the model config to use fewer heads, without requiring + multiple GPUs. This tests that backends work correctly with different + head counts. + """ + set_random_seed(42) + + hf_config_override = None + if tensor_parallel_size > 1: + from vllm.config import ModelConfig + + temp_config = ModelConfig(model=model, max_model_len=1) + original_num_heads = temp_config.hf_text_config.num_attention_heads + original_num_kv_heads = getattr(temp_config.hf_text_config, "num_key_value_heads", None) + hf_config_override = { + "num_attention_heads": original_num_heads // tensor_parallel_size, + } + if original_num_kv_heads is not None: + hf_config_override["num_key_value_heads"] = max( + 1, original_num_kv_heads // tensor_parallel_size + ) + + vllm_config = create_vllm_config( + model_name=model, + tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements + max_model_len=max(batch_spec.seq_lens), + block_size=block_size, + num_gpu_blocks=8192, + hf_config_override=hf_config_override, + ) + # device = torch.device("cuda:0") + device = torch.device("cpu") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads(vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + sliding_window = vllm_config.model_config.get_sliding_window() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + scale = 1.0 / (head_size**0.5) + + # 2. 
Generate data and compute SDPA reference output + all_q_vllm, all_k_vllm, all_v_vllm = [], [], [] + all_sdpa_outputs = [] + k_contexts, v_contexts = [], [] + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate Q, K, V for the whole sequence to be used in SDPA + q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device) + k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + + # SDPA expects (N, H, L, D), so unsqueeze batch and permute + q_sdpa_in = q.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0, ( + f"num_q_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + ) + repeats = num_q_heads // num_kv_heads + k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1) + v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1) + + # Create causal mask: query token i attends to positions 0 to + # (context_len + i) + kv_len = s_len + + final_mask_mod = partial(mask_mod, context_len=context_len) + block_mask = create_block_mask( + final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device + ) + sdpa_out_i = flex_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + block_mask=block_mask, + scale=scale, + enable_gqa=True, + ) + + all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) + + # Inputs for vLLM backends are just the new tokens + all_q_vllm.append(q) + all_k_vllm.append(k_full[context_len:]) + all_v_vllm.append(v_full[context_len:]) + + # Contextual K/V data used to populate the paged cache + k_contexts.append(k_full[:context_len]) + v_contexts.append(v_full[:context_len]) + + query_vllm = torch.cat(all_q_vllm, dim=0) + key_vllm = torch.cat(all_k_vllm, dim=0) + value_vllm = 
torch.cat(all_v_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device + ) + if attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only, all tokens are prefill tokens + common_attn_metadata.causal = False + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + k_contexts=k_contexts, + v_contexts=v_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True, + ) + + # 4. Run vLLM backends and compare + # Note: flex_attention has known Triton kernel compatibility issues + # with test infrastructures + for backend_name in backend_to_test: + # FlashAttention + FlexAttention: + # [2, num_blocks, block_size, num_kv_heads, head_size] + # FlashInfer + Triton: + # [num_blocks, 2, block_size, num_kv_heads, head_size] + # Select the appropriate KV cache format for each backend + kv_cache_for_backend = kv_cache + reset_kv_cache_layout = False + if backend_name in ( + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + ): + kv_cache_for_backend = kv_cache.transpose(0, 1) + + if backend_name == AttentionBackendEnum.FLASHINFER: + # For FlashInfer, default to HND layout + kv_cache_for_backend = kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) + set_kv_cache_layout("HND") + reset_kv_cache_layout = True + elif backend_name == AttentionBackendEnum.TRITON_ATTN: + kv_cache_for_backend = kv_cache_for_backend.contiguous() + + try: + backend_output = run_attention_backend( + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + key_vllm, + value_vllm, + kv_cache_for_backend, + sliding_window=sliding_window, + 
attn_type=attn_type, + ) + finally: + if reset_kv_cache_layout: + set_kv_cache_layout(None) + + # Check shape and dtype consistency + assert backend_output.shape == sdpa_output.shape, ( + f"[{backend_name}] shape {backend_output.shape} != SDPA shape {sdpa_output.shape}" + ) + assert backend_output.dtype == sdpa_output.dtype, ( + f"[{backend_name}] dtype {backend_output.dtype} != SDPA dtype {sdpa_output.dtype}" + ) + + assert torch.isfinite(backend_output).all(), f"[{backend_name}] produced non-finite values" + + # Check numerical similarity + def error_msg(msg: str, backend_name: str): + return f"[{backend_name}] output differs from SDPA baseline. {msg}" + + torch.testing.assert_close( + backend_output, + sdpa_output, + rtol=rtol, + atol=atol, + msg=partial(error_msg, backend_name=backend_name), + ) + + +@pytest.mark.parametrize( + "batch_spec_name", + [ + "small_decode", + "small_prefill", + "mixed_small", + "medium_decode", + "medium_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "single_decode", + "single_prefill", + ], +) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +def test_causal_backend_correctness(batch_spec_name: str, model: str, tensor_parallel_size: int): + """Test backend's correctness with causal attention.""" + + def causal_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + ): + return (q_idx + context_len) >= kv_idx + + batch_spec = BATCH_SPECS[batch_spec_name] + + _test_backend_correctness( + batch_spec, + model, + BACKENDS_TO_TEST, + causal_mask_mod, + tensor_parallel_size=tensor_parallel_size, + ) + + +if __name__ == "__main__": + import sys + + test_function = sys.argv[1] + + pytest.main( + [ + "-q", + __file__ + f"::{test_function}", + ] + ) diff --git a/vllm_spyre_next/tests/test_vllm_spyre_next.py b/vllm_spyre_next/tests/test_vllm_spyre_next.py index 884bb74b6..17292938f 
100644 --- a/vllm_spyre_next/tests/test_vllm_spyre_next.py +++ b/vllm_spyre_next/tests/test_vllm_spyre_next.py @@ -1,9 +1,6 @@ from vllm import LLM, RequestOutput, SamplingParams -import pytest - -@pytest.mark.spyre def test_basic_model_load(): model = LLM("ibm-ai-platform/micro-g3.3-8b-instruct-1b", max_model_len=128, max_num_seqs=2) diff --git a/vllm_spyre_next/vllm_spyre_next/compat_utils.py b/vllm_spyre_next/vllm_spyre_next/compat_utils.py new file mode 100644 index 000000000..edc2f5540 --- /dev/null +++ b/vllm_spyre_next/vllm_spyre_next/compat_utils.py @@ -0,0 +1,24 @@ +import inspect +from dataclasses import fields +from functools import lru_cache +from typing import Callable + + +def dataclass_fields(cls: type) -> list[str]: + return [f.name for f in fields(cls)] + + +@lru_cache +def has_argument(func: Callable, param_name: str) -> bool: + # Checks the signature of a method and returns true iff the method accepts + # a parameter named `$param_name`. + # `lru_cache` is used because inspect + for looping is pretty slow. This + # should not be invoked in the critical path. + signature = inspect.signature(func) + for param in signature.parameters.values(): + if ( + param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + and param.name == param_name + ): + return True + return False diff --git a/vllm_spyre_next/vllm_spyre_next/platform.py b/vllm_spyre_next/vllm_spyre_next/platform.py index 85046173b..96666a366 100644 --- a/vllm_spyre_next/vllm_spyre_next/platform.py +++ b/vllm_spyre_next/vllm_spyre_next/platform.py @@ -3,6 +3,8 @@ from string import Template import multiprocessing +from vllm_spyre_next.compat_utils import has_argument + # When running this plugin on a Mac, we assume it's for local development # purposes. However, due to a compatibility issue with vLLM, which overrides # the Triton module with a placeholder, vLLM may fail to load on macOS. 
@classmethod
def get_attn_backend_cls(
    cls,
    selected_backend,
    attn_selector_config,
    num_heads=None,
) -> str:
    """Resolve the import path of the attention backend class.

    The CUSTOM enum entry maps to our registered Spyre paged-attention
    backend; any other selection is delegated to the parent platform.
    """
    if selected_backend == AttentionBackendEnum.CUSTOM:
        return AttentionBackendEnum.CUSTOM.get_path()

    # `num_heads` exists in vllm:main but not in 0.16.0 or the release
    # prep for 0.17.0, so forward it only when the parent accepts it.
    extra_kwargs = {}
    if has_argument(super().get_attn_backend_cls, "num_heads"):
        extra_kwargs["num_heads"] = num_heads
    return super().get_attn_backend_cls(
        selected_backend, attn_selector_config, **extra_kwargs
    )
+ +This backend aims to implement PagedAttention using only PyTorch native operations, +such as matmul, softmax, etc. It supports vLLM's KV cache. +""" + +from dataclasses import dataclass +from typing import ClassVar + +import torch + +from vllm.config import VllmConfig +from vllm.config.cache import CacheDType +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionCGSupport, + AttentionImpl, + AttentionMetadataBuilder, + AttentionType, + CommonAttentionMetadata, + MultipleOf, +) +from vllm.v1.kv_cache_interface import AttentionSpec + + +@dataclass +class SpyreAttentionPagedMetadata: + """Metadata for PyTorch native attention computation.""" + + # Batch information + num_actual_tokens: int + num_seqs: int + max_query_len: int + max_seq_len: int + + # Sequence lengths + seq_lens: torch.Tensor # [num_seqs] + query_start_loc: torch.Tensor # [num_seqs + 1] + + # Block table for paged KV cache + block_table: torch.Tensor # [num_seqs, max_num_blocks_per_seq] + block_size: int + + # Slot mapping for KV cache updates + slot_mapping: torch.Tensor # [num_actual_tokens] + + # Whether causal masking is needed (True when max_query_len > 1) + causal_mask: torch.Tensor | None = None + + # For grouped-query attention + num_kv_heads: int = 0 + num_heads: int = 0 + + +class SpyreAttentionPagedMetadataBuilder(AttentionMetadataBuilder[SpyreAttentionPagedMetadata]): + """Builds attention metadata from batch information.""" + + _cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size + + model_config = vllm_config.model_config + self.num_heads = model_config.get_num_attention_heads(vllm_config.parallel_config) + self.num_kv_heads = 
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: "CommonAttentionMetadata",
    fast_build: bool = False,
) -> "SpyreAttentionPagedMetadata":
    """Translate vLLM's common attention metadata into the backend-specific
    SpyreAttentionPagedMetadata consumed by SpyreAttentionPagedImpl."""
    meta = common_attn_metadata

    # Only causal multi-token batches need a mask tensor; pure decode
    # (query length 1) is trivially causal.
    mask = None
    if meta.causal and meta.max_query_len > 1:
        mask = self._create_causal_mask(meta.max_query_len, meta.max_seq_len, self.device)

    return SpyreAttentionPagedMetadata(
        num_actual_tokens=meta.num_actual_tokens,
        num_seqs=meta.num_reqs,
        max_query_len=meta.max_query_len,
        max_seq_len=meta.max_seq_len,
        seq_lens=meta.seq_lens,
        query_start_loc=meta.query_start_loc,
        block_table=meta.block_table_tensor,
        block_size=self.block_size,
        slot_mapping=meta.slot_mapping,
        causal_mask=mask,
        num_kv_heads=self.num_kv_heads,
        num_heads=self.num_heads,
    )


def _create_causal_mask(
    self,
    query_len: int,
    kv_len: int,
    device: torch.device,
) -> torch.Tensor:
    """Boolean [query_len, kv_len] mask where True means "masked out"."""
    rows = torch.arange(query_len, device=device).unsqueeze(1)
    cols = torch.arange(kv_len, device=device).unsqueeze(0)
    # Row i may attend to kv columns 0..i; everything past it is masked.
    return rows < cols
torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "bfloat16", + ] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + # Support any block size (no kernel-specific constraints) + return [MultipleOf(1)] + + @staticmethod + def get_name() -> str: + return "CUSTOM" + + @staticmethod + def get_impl_cls() -> type["SpyreAttentionPagedImpl"]: + return SpyreAttentionPagedImpl + + @staticmethod + def get_builder_cls() -> type["SpyreAttentionPagedMetadataBuilder"]: + return SpyreAttentionPagedMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + """KV cache shape: [2, num_blocks, block_size, num_kv_heads, head_size]""" + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + # Support any head size + return True + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + return kv_cache_dtype in cls.supported_kv_cache_dtypes + + +class SpyreAttentionPagedImpl(AttentionImpl[SpyreAttentionPagedMetadata]): + """PyTorch native implementation of attention with paged KV cache.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + use_sdpa: bool = False, + alibi_slopes: list[float] | None = None, + sliding_window: int | None = None, + kv_cache_dtype: str = "auto", + logits_soft_cap: float | None = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + working_precision=None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.num_queries_per_kv = num_heads // num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + self.attn_type = 
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,  # [num_tokens, num_heads, head_size]
    key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    kv_cache: torch.Tensor,  # [2, num_blocks, block_size, num_kv_heads, head_size]
    attn_metadata: "SpyreAttentionPagedMetadata",
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run paged attention for one layer, writing into `output` in place."""
    assert output is not None, "Output tensor must be provided"

    # No metadata (e.g. a profiling/dummy run): hand back a zeroed buffer.
    if attn_metadata is None:
        return output.fill_(0)

    n_tokens = attn_metadata.num_actual_tokens

    # Persist this step's K/V into the paged cache.
    self._write_to_kv_cache(
        key[:n_tokens],
        value[:n_tokens],
        kv_cache,
        attn_metadata.slot_mapping,
        attn_metadata.block_size,
    )

    # Densify the cache into [num_seqs, max_seq_len, num_kv_heads, head_size].
    compact_k, compact_v = self._gather_compact_kv_cache(
        kv_cache,
        attn_metadata.block_table,
        attn_metadata.seq_lens,
        attn_metadata.block_size,
    )

    # Pad flat queries into [num_seqs, max_query_len, num_heads, head_size].
    padded_q = self._reshape_query_to_sequences(
        query[:n_tokens],
        attn_metadata.query_start_loc,
        attn_metadata.num_seqs,
        attn_metadata.max_query_len,
    )

    # [num_seqs, 1, max_query_len, max_seq_len]; True = do not attend.
    mask = self._build_attention_mask(
        attn_metadata.seq_lens,
        attn_metadata.query_start_loc,
        attn_metadata.causal_mask,
        compact_k.device,
    )

    # [num_seqs, max_query_len, num_heads, head_size]
    attended = self._compute_attention(padded_q, compact_k, compact_v, mask)

    # Strip padding rows and copy the flat result into the caller's buffer.
    flat = self._extract_relevant_output(attended, attn_metadata.query_start_loc)
    output[:n_tokens].copy_(flat)
    return output
def _write_to_kv_cache(
    self,
    key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    kv_cache: torch.Tensor,  # [2, num_blocks, block_size, num_kv_heads, head_size]
    slot_mapping: torch.Tensor,  # [num_tokens], flat slot id per token
    block_size: int,
) -> None:
    """Scatter this step's keys/values into the paged cache in one shot."""
    # A flat slot id encodes (block, offset) as block * block_size + offset.
    blocks = torch.div(slot_mapping, block_size, rounding_mode="floor")
    offsets = torch.remainder(slot_mapping, block_size)

    key_cache, value_cache = kv_cache[0], kv_cache[1]
    key_cache[blocks, offsets] = key
    value_cache[blocks, offsets] = value
def _gather_compact_kv_cache(
    self,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Densify the paged KV cache into contiguous per-sequence tensors.

    Args:
        kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size]
        block_table: [num_seqs, max_num_blocks_per_seq]
        seq_lens: [num_seqs]
        block_size: tokens per cache block

    Returns:
        (compact_k, compact_v), each [num_seqs, max_seq_len, num_kv_heads, head_size]
    """
    num_seqs, max_blocks = block_table.shape
    max_len = int(seq_lens.max())
    device = kv_cache.device

    key_cache, value_cache = kv_cache[0], kv_cache[1]

    # Logical token position for every (seq, pos) pair: [num_seqs, max_len]
    positions = torch.arange(max_len, device=device).expand(num_seqs, max_len)

    logical_blocks = positions // block_size
    offsets = positions % block_size

    # Clamp keeps the gather index in range; padding positions are then
    # redirected to physical block 0 so no stale blocks are ever read.
    lookup = block_table.gather(1, logical_blocks.clamp(0, max_blocks - 1))
    valid = positions < seq_lens.unsqueeze(1)  # [num_seqs, max_len]
    lookup = lookup * valid

    flat_blocks = lookup.reshape(-1)
    flat_offsets = offsets.reshape(-1)

    out_shape = (num_seqs, max_len, key_cache.shape[2], key_cache.shape[3])
    compact_k = key_cache[flat_blocks, flat_offsets].reshape(out_shape)
    compact_v = value_cache[flat_blocks, flat_offsets].reshape(out_shape)

    # Padding rows are left as-is: the attention mask drives their scores
    # to -inf, so their values never reach the output.
    return compact_k, compact_v
+ + return compact_k, compact_v + + def _build_attention_mask( + self, + seq_lens: torch.Tensor, # [num_seqs] + query_start_loc: torch.Tensor, # [num_seqs + 1] + causal_mask: torch.Tensor | None, + device: torch.device, + ) -> torch.Tensor: + """ + Build a per-sequence attention mask. + + Returns: + mask: [num_seqs, 1, max_query_len, max_seq_len] + True = masked out (don't attend), False = attend + """ + # num_seqs = seq_lens.shape[0] + query_lens = query_start_loc[1:] - query_start_loc[:-1] # [num_seqs] + max_query_len = query_lens.max() + max_seq_len = seq_lens.max() + + # Positions along query and KV dimensions + q_pos = torch.arange(max_query_len, device=device) # [max_query_len] + kv_pos = torch.arange(max_seq_len, device=device) # [max_seq_len] + + # Validity: which (seq, q, kv) positions are real (not padding)? + # [num_seqs, max_query_len] + q_valid = q_pos.unsqueeze(0) < query_lens.unsqueeze(1) + # [num_seqs, max_seq_len] + kv_valid = kv_pos.unsqueeze(0) < seq_lens.unsqueeze(1) + + # [num_seqs, max_query_len, max_seq_len] + attend = q_valid.unsqueeze(2) & kv_valid.unsqueeze(1) + + if causal_mask is not None: + # context_len[s] = seq_len[s] - query_len[s] + # query token q_i (0-indexed) can attend to KV positions 0 .. 
def _reshape_query_to_sequences(
    self,
    query: torch.Tensor,  # [num_actual_tokens, num_heads, head_size]
    query_start_loc: torch.Tensor,  # [num_seqs + 1]
    num_seqs: int,
    max_query_len: int,
) -> torch.Tensor:
    """Scatter flat query tokens into a zero-padded per-sequence layout.

    Returns [num_seqs, max_query_len, num_heads, head_size]; padding
    positions are zeroed.
    """
    device = query.device
    query_lens = query_start_loc.diff()  # [num_seqs]

    # Flat token index for every padded slot: [num_seqs, max_query_len]
    slots = torch.arange(max_query_len, device=device).expand(num_seqs, max_query_len)
    flat_idx = query_start_loc[:-1].unsqueeze(1) + slots

    # Clamp keeps the gather in-bounds; out-of-range rows are zeroed below.
    gathered = query[flat_idx.clamp(0, query.shape[0] - 1)]

    in_range = slots < query_lens.unsqueeze(1)
    return gathered * in_range.unsqueeze(-1).unsqueeze(-1)
def _compute_attention(
    self,
    query: torch.Tensor,  # [num_seqs, max_query_len, num_heads, head_size]
    key: torch.Tensor,  # [num_seqs, max_seq_len, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_seqs, max_seq_len, num_kv_heads, head_size]
    mask: torch.Tensor,  # [num_seqs, 1, max_query_len, max_seq_len] True=masked
) -> torch.Tensor:
    """Batched per-sequence attention with GQA via reshape + broadcast.

    KV heads are broadcast over the query-group dimension instead of being
    expanded with repeat_interleave (profiling showed the expansion cost
    ~17 ms/layer vs ~0.18 ms for the actual matmuls on 30 seqs x 300 tok):

        Q [B, Hkv, Gq, q_len, D] x K [B, Hkv, 1, kv_len, D]^T
            -> scores [B, Hkv, Gq, q_len, kv_len]

    Returns [num_seqs, max_query_len, num_heads, head_size].
    """
    batch, q_len = query.shape[0], query.shape[1]
    kv_heads = self.num_kv_heads
    group = self.num_queries_per_kv  # query heads per KV head
    dim = self.head_size

    if self.use_sdpa:
        # Fused kernel path; `enable_gqa` requires PyTorch >= 2.5.
        # Layout change: [B, tokens, heads, D] -> [B, heads, tokens, D].
        fused = torch.nn.functional.scaled_dot_product_attention(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            attn_mask=~mask,  # SDPA wants True = attend
            scale=self.scale,
            enable_gqa=True,
        )
        return fused.transpose(1, 2)

    # --- Manual matmul/softmax path ---
    # Q: [B, H, q_len, D] -> [B, Hkv, Gq, q_len, D] to expose GQA groups.
    q = query.transpose(1, 2).reshape(batch, kv_heads, group, q_len, dim)
    # K/V get a singleton group dim and broadcast — never copied Gq times.
    k = key.transpose(1, 2).unsqueeze(2)  # [B, Hkv, 1, kv_len, D]
    v = value.transpose(1, 2).unsqueeze(2)  # [B, Hkv, 1, kv_len, D]

    wp = self.working_precision
    # Scores: [B, Hkv, Gq, q_len, kv_len]
    scores = torch.matmul(q.to(wp), k.to(wp).transpose(-2, -1)) * self.scale

    # mask [B, 1, q, kv] -> [B, 1, 1, q, kv] broadcasts over Hkv and Gq.
    scores = scores.masked_fill(mask.unsqueeze(1), -float("inf"))
    weights = torch.softmax(scores, dim=-1)

    # Fully-masked padding rows softmax to NaN; on some BLAS/CPU backends
    # those NaNs can bleed into valid rows in the following matmul. Zero
    # them — the padding output is discarded downstream anyway.
    weights = weights.nan_to_num(nan=0.0)

    out = torch.matmul(weights, v.to(wp)).to(query.dtype)
    out = out.reshape(batch, kv_heads * group, q_len, dim)
    # [B, H, q_len, D] -> [B, q_len, H, D]
    return out.transpose(1, 2)
def _extract_relevant_output(
    self,
    attn_output: torch.Tensor,  # [num_seqs, max_query_len, num_heads, head_size]
    query_start_loc: torch.Tensor,  # [num_seqs + 1]
) -> torch.Tensor:
    """Strip per-sequence padding from the attention output.

    Returns [num_actual_tokens, num_heads, head_size].
    """
    max_q = attn_output.shape[1]
    query_lens = query_start_loc.diff()  # [num_seqs]

    # [num_seqs, max_q] — True where the row holds a real query token.
    keep = (
        torch.arange(max_q, device=attn_output.device).unsqueeze(0)
        < query_lens.unsqueeze(1)
    )

    # Boolean indexing flattens (seq, pos) into one token axis in order.
    return attn_output[keep]