diff --git a/vllm_spyre_next/examples/torch_spyre_inference.py b/vllm_spyre_next/examples/torch_spyre_inference.py index d80e600df..8a0eff936 100644 --- a/vllm_spyre_next/examples/torch_spyre_inference.py +++ b/vllm_spyre_next/examples/torch_spyre_inference.py @@ -24,6 +24,7 @@ def parse_args(): parser.add_argument("--model", type=str, default="ibm-ai-platform/micro-g3.3-8b-instruct-1b") parser.add_argument("--max_model_len", "--max-model-len", type=int, default=2048) parser.add_argument("--max_num_seqs", "--max-num-seqs", type=int, default=2) + parser.add_argument("--max_num_batched_tokens", "--max-num-batched-tokens", type=int, default=2) parser.add_argument("--tp", type=int, default=1) parser.add_argument("--num-prompts", "-n", type=int, default=3) parser.add_argument( @@ -34,6 +35,7 @@ def parse_args(): "This list is repeated until prompts are exhausted.", ) parser.add_argument("--compare-with-cpu", action=argparse.BooleanOptionalAction) + parser.add_argument("--attention_backend", "--attention-backend", type=str, default=None) parser.add_argument( "--enforce_eager", "--enforce-eager", @@ -95,7 +97,11 @@ def main(): "Compose a LinkedIn post about your company's latest product release.", ] - prompts = [template.format(instr) for instr in instructions] + simple_prompt = [ + "What are IBMs main businesses?", + ] + + prompts = simple_prompt + [template.format(instr) for instr in instructions] prompts = prompts * (args.num_prompts // len(prompts) + 1) prompts = prompts[0 : args.num_prompts] @@ -111,6 +117,8 @@ def main(): # lazy import to switch between old an new platform: # platform registration happens at import time from vllm import LLM, SamplingParams + from vllm.config import AttentionConfig + from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.config import CompilationConfig sampling_params = [ @@ -124,10 +132,13 @@ def main(): max_model_len=args.max_model_len, max_num_seqs=max_num_seqs, tensor_parallel_size=args.tp, - 
max_num_batched_tokens=1024, + max_num_batched_tokens=args.max_num_batched_tokens, dtype="float16", enforce_eager=args.enforce_eager, compilation_config=CompilationConfig(custom_ops=args.custom_ops), + attention_config=AttentionConfig(backend=AttentionBackendEnum[args.attention_backend]) + if args.attention_backend is not None + else None, ) # Generate texts from the prompts. The output is a list of RequestOutput objects diff --git a/vllm_spyre_next/tests/test_attention_correctness_stripped.py b/vllm_spyre_next/tests/test_attention_correctness_stripped.py new file mode 100644 index 000000000..7acf2b82a --- /dev/null +++ b/vllm_spyre_next/tests/test_attention_correctness_stripped.py @@ -0,0 +1,560 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 attention backends without GPUModelRunner dependency.""" + +import os +import sys +from functools import partial + +import pytest +import torch +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +# Import attention test utilities from upstream vllm's test suite. 
+import vllm as _vllm + +_vllm_source_root = os.path.dirname(os.path.dirname(_vllm.__file__)) +sys.path.insert(0, os.path.join(_vllm_source_root, "tests", "v1", "attention")) + +from utils import ( # noqa: E402 + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + try_backend_includes_kv_cache_update, + try_get_attention_backend, +) +from vllm.utils.math_utils import cdiv # noqa: E402 +from vllm.utils.torch_utils import ( # noqa: E402 + STR_DTYPE_TO_TORCH_DTYPE, + is_torch_equal_or_newer, + set_random_seed, +) +from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata # noqa: E402 +from vllm.v1.attention.backends.registry import AttentionBackendEnum # noqa: E402 +from vllm.v1.attention.backends.utils import ( # noqa: E402 + set_kv_cache_layout, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec # noqa: E402 + +BACKENDS_TO_TEST = [ + AttentionBackendEnum.CUSTOM, +] + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": BatchSpec(seq_lens=[40], query_lens=[1]), + "small_prefill": BatchSpec(seq_lens=[40], query_lens=[8]), + "mixed_small": BatchSpec(seq_lens=[48], query_lens=[5]), + "medium_decode": BatchSpec( + seq_lens=[1024], + query_lens=[1], + ), + "medium_prefill": BatchSpec(seq_lens=[1024], query_lens=[16]), + "mixed_medium": BatchSpec(seq_lens=[2048], query_lens=[1]), + "large_decode": BatchSpec(seq_lens=[2048], query_lens=[1]), + "large_prefill": BatchSpec(seq_lens=[4096], query_lens=[32]), + "mixed_large": BatchSpec(seq_lens=[4096], 
query_lens=[32]), + "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), + # encoder-only + "small_encoder_prefill": BatchSpec(seq_lens=[32], query_lens=[32]), + "medium_encoder_prefill": BatchSpec(seq_lens=[256], query_lens=[256]), +} + + +def create_and_prepopulate_kv_cache( + k_contexts: list[torch.Tensor], + v_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True, +) -> torch.Tensor: + """Create and prepopulate a KV cache with context data. + + Args: + k_contexts: List of key context tensors for each sequence + v_contexts: List of value context tensors for each sequence + seq_lens: List of sequence lengths + block_size: Size of each block + num_kv_heads: Number of KV heads + head_size: Size of each head + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + block_table: Block table tensor to populate + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + Tuple of (kv_cache, updated_block_table) + """ + batch_size = len(k_contexts) + seq_lens = common_attn_metadata.seq_lens.cpu() + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] - common_attn_metadata.query_start_loc_cpu[:-1] + ) + context_lens = seq_lens - query_lens + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create KV cache + kv_cache = torch.zeros( + 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device + ) + kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + 
k_context, v_context = k_contexts[i], v_contexts[i] + start = start_block_idx * block_size + end = start + k_context.shape[0] + kv_cache_flat[0, start:end, ...] = k_context + kv_cache_flat[1, start:end, ...] = v_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + # Random permutation starting from block 1 + perm = torch.randperm(blocks_end - 1) + 1 + else: + # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 + kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] + + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, block_indices + ] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + 
self._v_scale = torch.tensor(1.0, device=device) + # Add float versions for flashinfer + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + +def run_attention_backend( + backend: AttentionBackendEnum, + kv_cache_spec: FullAttentionSpec, + layer_names: list[str], + vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: AttentionType = AttentionType.DECODER, + sliding_window: int | None = None, +) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + # Handle special case for FLEX_ATTENTION_SLOW + actual_backend = backend + + use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") + if backend == "FLEX_ATTENTION_SLOW": + actual_backend = AttentionBackendEnum.FLEX_ATTENTION + use_direct_block_mask = False + + builder_cls, impl_cls = try_get_attention_backend(actual_backend) + + # Mock flashinfer's get_per_layer_parameters if needed + if actual_backend == AttentionBackendEnum.FLASHINFER: + import unittest.mock + + from vllm.v1.attention.backends.utils import PerLayerParameters + + def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): + # Return mock parameters for a single layer + head_size = vllm_config.model_config.get_head_size() + return { + layer_name: PerLayerParameters( + window_left=-1, # No sliding window + logits_soft_cap=0.0, # No soft cap + sm_scale=1.0 / (head_size**0.5), # Standard scale + ) + for layer_name in layer_names + } + + with unittest.mock.patch( + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + mock_get_per_layer_parameters, + ): + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, 
vllm_config, device) + if actual_backend == AttentionBackendEnum.FLEX_ATTENTION: + builder.direct_build = use_direct_block_mask + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate implementation + num_heads = vllm_config.model_config.get_num_attention_heads(vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=sliding_window, + attn_type=attn_type, + kv_cache_dtype="auto", + ) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + output = torch.empty_like(query) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. + if not try_backend_includes_kv_cache_update(actual_backend): + impl.do_kv_cache_update(mock_layer, key, value, kv_cache, attn_metadata.slot_mapping) + output = impl.forward(mock_layer, query, key, value, kv_cache, attn_metadata, output=output) + + return output + + +def _test_backend_correctness( + batch_spec: BatchSpec, + model: str, + backend_to_test: list[AttentionBackendEnum | str], + mask_mod, + *, + attn_type: AttentionType = AttentionType.DECODER, + block_size: int = 16, + atol: float = 1e-2, + rtol: float = 1e-2, + tensor_parallel_size: int = 1, +): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. 
Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + + Note: When tensor_parallel_size > 1, we simulate the head partitioning + by overriding the model config to use fewer heads, without requiring + multiple GPUs. This tests that backends work correctly with different + head counts. + """ + set_random_seed(42) + + hf_config_override = None + if tensor_parallel_size > 1: + from vllm.config import ModelConfig + + temp_config = ModelConfig(model=model, max_model_len=1) + original_num_heads = temp_config.hf_text_config.num_attention_heads + original_num_kv_heads = getattr(temp_config.hf_text_config, "num_key_value_heads", None) + hf_config_override = { + "num_attention_heads": original_num_heads // tensor_parallel_size, + } + if original_num_kv_heads is not None: + hf_config_override["num_key_value_heads"] = max( + 1, original_num_kv_heads // tensor_parallel_size + ) + + vllm_config = create_vllm_config( + model_name=model, + tensor_parallel_size=1, # Always use TP=1 to avoid multi-GPU requirements + max_model_len=max(batch_spec.seq_lens), + block_size=block_size, + num_gpu_blocks=8192, + hf_config_override=hf_config_override, + ) + # device = torch.device("cuda:0") + device = torch.device("cpu") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. 
Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads(vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + sliding_window = vllm_config.model_config.get_sliding_window() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + scale = 1.0 / (head_size**0.5) + + # 2. Generate data and compute SDPA reference output + all_q_vllm, all_k_vllm, all_v_vllm = [], [], [] + all_sdpa_outputs = [] + k_contexts, v_contexts = [], [] + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate Q, K, V for the whole sequence to be used in SDPA + q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device) + k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + + # SDPA expects (N, H, L, D), so unsqueeze batch and permute + q_sdpa_in = q.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0, ( + f"num_q_heads ({num_q_heads}) must be divisible by num_kv_heads ({num_kv_heads})" + ) + repeats = num_q_heads // num_kv_heads + k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1) + v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1) + + # Create causal mask: query token i attends to positions 0 to + # (context_len + i) + kv_len = s_len + + final_mask_mod = partial(mask_mod, context_len=context_len) + block_mask = create_block_mask( + final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device + ) + sdpa_out_i = flex_attention( + 
q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + block_mask=block_mask, + scale=scale, + enable_gqa=True, + ) + + all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) + + # Inputs for vLLM backends are just the new tokens + all_q_vllm.append(q) + all_k_vllm.append(k_full[context_len:]) + all_v_vllm.append(v_full[context_len:]) + + # Contextual K/V data used to populate the paged cache + k_contexts.append(k_full[:context_len]) + v_contexts.append(v_full[:context_len]) + + query_vllm = torch.cat(all_q_vllm, dim=0) + key_vllm = torch.cat(all_k_vllm, dim=0) + value_vllm = torch.cat(all_v_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device + ) + if attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only, all tokens are prefill tokens + common_attn_metadata.causal = False + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + k_contexts=k_contexts, + v_contexts=v_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True, + ) + + # 4. 
Run vLLM backends and compare + # Note: flex_attention has known Triton kernel compatibility issues + # with test infrastructures + for backend_name in backend_to_test: + # FlashAttentionm + FlexAttention: + # [2, num_blocks, block_size, num_kv_heads, head_size] + # FlashInfer + Triton: + # [num_blocks, 2, block_size, num_kv_heads, head_size] + # Select the appropriate KV cache format for each backend + kv_cache_for_backend = kv_cache + reset_kv_cache_layout = False + if backend_name in ( + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + ): + kv_cache_for_backend = kv_cache.transpose(0, 1) + + if backend_name == AttentionBackendEnum.FLASHINFER: + # For FlashInfer default to HND layout and + kv_cache_for_backend = kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) + set_kv_cache_layout("HND") + reset_kv_cache_layout = True + elif backend_name == AttentionBackendEnum.TRITON_ATTN: + kv_cache_for_backend = kv_cache_for_backend.contiguous() + + try: + backend_output = run_attention_backend( + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + key_vllm, + value_vllm, + kv_cache_for_backend, + sliding_window=sliding_window, + attn_type=attn_type, + ) + finally: + if reset_kv_cache_layout: + set_kv_cache_layout(None) + + # Check shape and dtype consistency + assert backend_output.shape == sdpa_output.shape, ( + f"[{backend_name}] shape {backend_output.shape} != SDPA shape {sdpa_output.shape}" + ) + assert backend_output.dtype == sdpa_output.dtype, ( + f"[{backend_name}] dtype {backend_output.dtype} != SDPA dtype {sdpa_output.dtype}" + ) + + assert torch.isfinite(backend_output).all(), f"[{backend_name}] produced non-finite values" + + # Check numerical similarity + def error_msg(msg: str, backend_name: str): + return f"[{backend_name}] output differs from SDPA baseline. 
{msg}" + + torch.testing.assert_close( + backend_output, + sdpa_output, + rtol=rtol, + atol=atol, + msg=partial(error_msg, backend_name=backend_name), + ) + + +@pytest.mark.parametrize( + "batch_spec_name", + [ + "small_decode", + "small_prefill", + "mixed_small", + "medium_decode", + "medium_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "single_decode", + "single_prefill", + ], +) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_causal_backend_correctness(batch_spec_name: str, model: str, tensor_parallel_size: int): + """Test backend's correctness with causal attention.""" + + def causal_mask_mod( + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + kv_idx: torch.Tensor, + *, + context_len: int, + ): + return (q_idx + context_len) >= kv_idx + + batch_spec = BATCH_SPECS[batch_spec_name] + + _test_backend_correctness( + batch_spec, + model, + BACKENDS_TO_TEST, + causal_mask_mod, + tensor_parallel_size=tensor_parallel_size, + ) + + +if __name__ == "__main__": + import sys + + test_function = sys.argv[1] + + pytest.main( + [ + "-q", + __file__ + f"::{test_function}", + ] + ) diff --git a/vllm_spyre_next/tests/test_spyre_attn.py b/vllm_spyre_next/tests/test_spyre_attn.py new file mode 100644 index 000000000..c32dca2fc --- /dev/null +++ b/vllm_spyre_next/tests/test_spyre_attn.py @@ -0,0 +1,345 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.utils.torch_utils import set_random_seed +from vllm_spyre_next.v1.attention.backends.spyre_attn import ( + SpyreAttentionImpl, + SpyreAttentionMetadata, +) + + +def is_spyre_available(): + try: + test_tensor = torch.randn(1, device=torch.device("spyre")) + del test_tensor + return True + except Exception: + return False + + +SPYRE_AVAILABLE = is_spyre_available() + +pytestmark = pytest.mark.skipif( + not 
SPYRE_AVAILABLE, reason="Spyre device not available - these tests require Spyre hardware" +) + +NUM_HEADS = [(4, 4), (8, 2)] # (num_query_heads, num_kv_heads) +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16] +DTYPES = [torch.float16] +NUM_BLOCKS = [2048, 32768] + + +def ref_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: list[int], + kv_lens: list[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: int | None = None, + soft_cap: float | None = None, +) -> torch.Tensor: + """Reference implementation of attention for validation.""" + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs: list[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx : start_idx + query_len] + q = q * scale # avoid in-place mutation of the input tensor + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size)[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = ( + torch.triu(empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1) + .bool() + .logical_not() + ) + mask |= sliding_window_mask + if soft_cap is not None and soft_cap > 0: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + 
outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize( + "seq_lens", + [ + [(1, 256), (2, 128), (4, 512)], + [(1, 256), (1, 128), (1, 512)], + [(72, 512), (1, 256), (4, 128)], + ], +) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("use_sdpa", [True, False]) +@torch.inference_mode() +def test_spyre_attn( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: int | None, + dtype: torch.dtype, + block_size: int, + soft_cap: float | None, + num_blocks: int, + use_sdpa: bool, +) -> None: + """Validate SpyreAttentionImpl against a reference implementation.""" + torch.set_default_device("cpu") + set_random_seed(0) + + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key = torch.randn(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + value = torch.randn(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + + kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size, dtype=dtype) + key_cache = kv_cache[:, 0] + value_cache = kv_cache[:, 1] + + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, 
num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + # Pre-populate KV cache with historical context + for seq_idx in range(num_seqs): + query_len = query_lens[seq_idx] + kv_len = kv_lens[seq_idx] + historical_len = kv_len - query_len + if historical_len > 0: + historical_keys = torch.randn(historical_len, num_kv_heads, head_size, dtype=dtype) + historical_values = torch.randn(historical_len, num_kv_heads, head_size, dtype=dtype) + for token_idx in range(historical_len): + block_idx = token_idx // block_size + block_offset = token_idx % block_size + actual_block = block_tables[seq_idx, block_idx].item() + key_cache[actual_block, block_offset] = historical_keys[token_idx] + value_cache[actual_block, block_offset] = historical_values[token_idx] + + # Create slot mapping for new query tokens + slot_mapping = [] + for seq_idx in range(num_seqs): + query_len = query_lens[seq_idx] + kv_len = kv_lens[seq_idx] + for token_idx in range(query_len): + pos = kv_len - query_len + token_idx + actual_block = block_tables[seq_idx, pos // block_size].item() + slot_mapping.append(actual_block * block_size + pos % block_size) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64) + + attn_metadata = SpyreAttentionMetadata( + num_actual_tokens=sum(query_lens), + num_seqs=num_seqs, + max_query_len=max_query_len, + max_seq_len=max_kv_len, + seq_lens=kv_lens_tensor, + query_start_loc=cu_query_lens, + block_table=block_tables, + block_size=block_size, + slot_mapping=slot_mapping, + apply_causal_mask=max_query_len > 1, + num_kv_heads=num_kv_heads, + num_heads=num_query_heads, + ) + + attn_impl = SpyreAttentionImpl( + num_heads=num_query_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=sliding_window, + kv_cache_dtype="auto", + logits_soft_cap=soft_cap, + use_sdpa=use_sdpa, + ) + + output = torch.empty_like(query) + attn_impl.forward( + layer=None, + query=query, + key=key, + value=value, + 
kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + + ref_output = ref_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + ) + + if use_sdpa: + atol, rtol = 0.1, 0.1 + elif max_query_len >= 32: + atol, rtol = 0.3, 5.0 # float16 accumulation errors for large prompts + else: + atol, rtol = 0.2, 0.2 + + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("num_heads", [(8, 2)]) +@pytest.mark.parametrize("head_size", [128]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@torch.inference_mode() +def test_spyre_attn_single_sequence( + num_heads: tuple[int, int], + head_size: int, + block_size: int, + dtype: torch.dtype, +) -> None: + """Test single-sequence attention across a range of query/kv lengths.""" + torch.set_default_device("cpu") + set_random_seed(42) + + num_query_heads, num_kv_heads = num_heads + scale = head_size**-0.5 + + test_cases = [ + (1, 128), # single token decode + (32, 256), # exact chunk size + (64, 512), # multi-chunk + (100, 512), # non-aligned query length + ] + + for query_len, kv_len in test_cases: + num_seqs = 1 + num_blocks = 1024 + + query = torch.randn(query_len, num_query_heads, head_size, dtype=dtype) + key = torch.randn(query_len, num_kv_heads, head_size, dtype=dtype) + value = torch.randn(query_len, num_kv_heads, head_size, dtype=dtype) + + kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size, dtype=dtype) + key_cache = kv_cache[:, 0] + value_cache = kv_cache[:, 1] + + max_num_blocks_per_seq = (kv_len + block_size - 1) // block_size + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + historical_len = kv_len - query_len + if historical_len > 0: + historical_keys = 
torch.randn(historical_len, num_kv_heads, head_size, dtype=dtype) + historical_values = torch.randn(historical_len, num_kv_heads, head_size, dtype=dtype) + for token_idx in range(historical_len): + block_idx = token_idx // block_size + block_offset = token_idx % block_size + actual_block = block_tables[0, block_idx].item() + key_cache[actual_block, block_offset] = historical_keys[token_idx] + value_cache[actual_block, block_offset] = historical_values[token_idx] + + slot_mapping = [] + for token_idx in range(query_len): + pos = kv_len - query_len + token_idx + actual_block = block_tables[0, pos // block_size].item() + slot_mapping.append(actual_block * block_size + pos % block_size) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64) + + attn_metadata = SpyreAttentionMetadata( + num_actual_tokens=query_len, + num_seqs=num_seqs, + max_query_len=query_len, + max_seq_len=kv_len, + seq_lens=torch.tensor([kv_len], dtype=torch.int32), + query_start_loc=torch.tensor([0, query_len], dtype=torch.int32), + block_table=block_tables, + block_size=block_size, + slot_mapping=slot_mapping, + apply_causal_mask=query_len > 1, + num_kv_heads=num_kv_heads, + num_heads=num_query_heads, + ) + + attn_impl = SpyreAttentionImpl( + num_heads=num_query_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + ) + + output = torch.empty_like(query) + attn_impl.forward( + layer=None, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + output=output, + ) + + ref_output = ref_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[query_len], + kv_lens=[kv_len], + block_tables=block_tables, + scale=scale, + ) + + if query_len >= 32: + atol, rtol = 0.3, 5.0 # float16 accumulation errors for large prompts + else: + atol, rtol = 0.1, 0.1 + + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) diff --git a/vllm_spyre_next/vllm_spyre_next/platform.py 
b/vllm_spyre_next/vllm_spyre_next/platform.py index 7c000e110..ca71a5c30 100644 --- a/vllm_spyre_next/vllm_spyre_next/platform.py +++ b/vllm_spyre_next/vllm_spyre_next/platform.py @@ -4,6 +4,7 @@ import multiprocessing import importlib.metadata + # When running this plugin on a Mac, we assume it's for local development # purposes. However, due to a compatibility issue with vLLM, which overrides # the Triton module with a placeholder, vLLM may fail to load on macOS. To @@ -17,6 +18,7 @@ from vllm.logger import init_logger from vllm.platforms import PlatformEnum from vllm.platforms.cpu import CpuPlatform +from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend if TYPE_CHECKING: # NB: We can't eagerly import many things from vllm since vllm.config @@ -35,6 +37,12 @@ class TorchSpyrePlatform(CpuPlatform): device_name: str = "cpu" device_type: str = "cpu" + # Register the PyTorch Native Attention implementation as the CUSTOM backend + register_backend( + AttentionBackendEnum.CUSTOM, + "vllm_spyre_next.v1.attention.backends.spyre_attn.SpyreAttentionBackend", + ) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "torch-spyre" @@ -73,6 +81,13 @@ def log_server_boot(cls, vllm_config: VllmConfig) -> None: logger.info(message, version, model_name) + @classmethod + def get_attn_backend_cls(cls, selected_backend, *args, **kwargs) -> str: + if selected_backend == AttentionBackendEnum.CUSTOM: + return AttentionBackendEnum.CUSTOM.get_path() + else: + return super().get_attn_backend_cls(selected_backend, *args, **kwargs) + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cls.log_server_boot(vllm_config) @@ -101,20 +116,5 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: logger.info("Loading scheduler from: %s", scheduler_class) scheduler_config.scheduler_cls = scheduler_class - # ---- attention backend ---- - # A custom attention backend can be registered with 
get_attn_backend_cls() - # see copied code from vllm/platforms/cpu.CpuPlatform illustrating the default - # TorchSDPABackend used for vLLM CPU execution - - # @classmethod - # def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, - # dtype: torch.dtype, kv_cache_dtype: Optional[str], - # block_size: int, use_v1: bool, - # use_mla: bool) -> str: - # if selected_backend and selected_backend != _Backend.TORCH_SDPA: - # logger.info("Cannot use %s backend on CPU.", selected_backend) - # logger.info("Using Torch SDPA backend.") - # return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" - # call CpuPlatform.check_and_update_config() super().check_and_update_config(vllm_config) diff --git a/vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py b/vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py new file mode 100644 index 000000000..cf2922035 --- /dev/null +++ b/vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py @@ -0,0 +1,776 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Contiguous KV-cache implementation of AttentionBackend using torch-spyre. + +This backend aims to implement attention using only PyTorch native operations, +such as matmul, softmax, etc. It supports vLLM's KV cache. 
+""" + +from dataclasses import dataclass +from typing import ClassVar + +import torch + +from vllm_spyre_next.custom_ops.utils import convert + +from vllm.config import VllmConfig +from vllm.config.cache import CacheDType +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionCGSupport, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + CommonAttentionMetadata, + MultipleOf, +) +from vllm.v1.kv_cache_interface import AttentionSpec + + +@dataclass +class SpyreAttentionMetadata(AttentionMetadata): + """Metadata for PyTorch native attention computation on Spyre.""" + + # Batch information + num_actual_tokens: int + num_seqs: int + max_query_len: int + max_seq_len: int + + # Sequence lengths + seq_lens: torch.Tensor # [num_seqs] + query_start_loc: torch.Tensor # [num_seqs + 1] + + # Block table for paged KV cache + block_table: torch.Tensor # [num_seqs, max_num_blocks_per_seq] + block_size: int + + # Slot mapping for KV cache updates + slot_mapping: torch.Tensor # [num_actual_tokens] + + # Whether causal masking is needed (True when max_query_len > 1) + apply_causal_mask: bool = False + + # For grouped-query attention + num_kv_heads: int = 0 + num_heads: int = 0 + + @property + def query_lens(self) -> torch.Tensor: + """Per-sequence query lengths, derived from query_start_loc. 
[num_seqs]""" + return self.query_start_loc[1:] - self.query_start_loc[:-1] + + +class SpyreAttentionMetadataBuilder(AttentionMetadataBuilder[SpyreAttentionMetadata]): + """Builds attention metadata from batch information.""" + + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + self.block_size = kv_cache_spec.block_size + + model_config = vllm_config.model_config + self.num_heads = model_config.get_num_attention_heads(vllm_config.parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> SpyreAttentionMetadata: + """Build attention metadata from common metadata.""" + return SpyreAttentionMetadata( + num_actual_tokens=common_attn_metadata.num_actual_tokens, + num_seqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc=common_attn_metadata.query_start_loc, + block_table=common_attn_metadata.block_table_tensor, + block_size=self.block_size, + slot_mapping=common_attn_metadata.slot_mapping, + apply_causal_mask=common_attn_metadata.causal + and common_attn_metadata.max_query_len > 1, + num_kv_heads=self.num_kv_heads, + num_heads=self.num_heads, + ) + + +class SpyreAttentionBackend(AttentionBackend): + """Pure PyTorch implementation of Attention.""" + + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "float16", + ] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + # 
Support any block size (no kernel-specific constraints) + return [MultipleOf(1)] + + @staticmethod + def get_name() -> str: + return "CUSTOM" + + @staticmethod + def get_impl_cls() -> type["SpyreAttentionImpl"]: + return SpyreAttentionImpl + + @staticmethod + def get_builder_cls() -> type["SpyreAttentionMetadataBuilder"]: + return SpyreAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + """KV cache shape: [num_blocks, 2, block_size, num_kv_heads, head_size]""" + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + # Spyre stick size is 128 bytes; tensors are transferred as float16 (2 bytes), + # so head_size must be a multiple of 64 (= 128 / 2) to satisfy stick alignment. + return head_size % 64 == 0 + + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: + if kv_cache_dtype is None: + return True + return kv_cache_dtype in cls.supported_kv_cache_dtypes + + +class SpyreAttentionImpl(AttentionImpl[SpyreAttentionMetadata]): + """PyTorch native implementation of attention with paged KV cache on Spyre.""" + + # TODO: Make these hyperparameters configurable + # KV length alignment: KV tensors are padded to the next multiple of this value. + # Because torch.compile treats shapes as static constants, every distinct kv_len + # triggers a full recompile. Aligning to 256 buckets sequence lengths into tiers + # (256, 512, 768, ...) so only the first request at each tier pays compilation cost, + # rather than recompiling on every decode step. 
# SpyreAttentionImpl: tuning constants, compiled kernel, and the forward entry point.

    # TODO: Make these hyperparameters configurable
    # KV length alignment: KV tensors are padded to the next multiple of this value.
    # Because torch.compile treats shapes as static constants, every distinct kv_len
    # triggers a full recompile. Aligning to 256 buckets sequence lengths into tiers
    # (256, 512, 768, ...) so only the first request at each tier pays compilation cost,
    # rather than recompiling on every decode step.
    KV_LENGTH_ALIGNMENT = 256

    # Query chunk size for padding - ensures consistent tensor sizes for Spyre compilation
    QUERY_CHUNK_SIZE = 32

    @staticmethod
    def _attn_transposed(qt, k, vt, sm_scale, mask_values):
        """Transposed attention for Spyre: handles all heads at once.

        Args:
            qt: Query transposed [head_size, num_heads * query_len_padded]
            k: Key [num_heads * kv_len, head_size]
            vt: Value transposed [head_size, num_heads * kv_len]
            sm_scale: Scale factor (1D tensor) [num_heads * query_len_padded]
            mask_values: Mask values tensor [num_heads * kv_len, num_heads * query_len_padded]
                Pre-computed on CPU: 0.0 for valid, -65504.0 for masked/padded
        """
        kq = k @ qt  # [num_heads * kv_len, num_heads * query_len_padded]
        kq = kq * sm_scale

        # Add pre-computed mask values
        # Valid positions have 0.0, masked/padded positions have -65504.0
        kq = kq + mask_values

        # Softmax along dim 0 (the KV axis in this transposed layout).
        p = kq.softmax(dim=0)
        return vt @ p  # [head_size, num_heads * query_len_padded]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        use_sdpa: bool = False,
    ) -> None:
        # NOTE: kv_sharing_target_layer_name is accepted for interface
        # compatibility but is not used by this implementation.
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.attn_type = attn_type

        # Target device/dtype for compiled attention kernels
        self._target_device = torch.device("spyre")
        self._target_dtype = torch.float16

        # When True, use torch.nn.functional.scaled_dot_product_attention.
        # Otherwise, use the transposed matmul kernel (_attn_transposed).
        self.use_sdpa = use_sdpa

        if self.use_sdpa:
            self.attn_op = torch.nn.functional.scaled_dot_product_attention
        else:
            self.attn_op = self._attn_transposed

        # Compile the attention function once for reuse.
        # dynamic=False forces static shapes, required by the Spyre compiler.
        self.attn_op = torch.compile(self.attn_op, dynamic=False)

        # Simplified implementation: don't support these features initially
        if alibi_slopes is not None:
            raise NotImplementedError("ALiBi slopes not supported yet")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window not supported yet")
        if logits_soft_cap is not None:
            raise NotImplementedError("Logits soft cap not supported yet")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
        key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
        value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
        kv_cache: torch.Tensor,  # [num_blocks, 2, block_size, num_kv_heads, head_size]
        attn_metadata: SpyreAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute attention output using PyTorch native operations.

        Note: `layer`, `output_scale` and `output_block_scale` are accepted for
        interface compatibility but are not used by this implementation.
        """

        assert output is not None, "Output tensor must be provided"

        # No metadata (e.g. a dummy/profiling run): nothing to compute.
        if attn_metadata is None:
            return output

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Step 1: Update KV cache (CPU)
        self._write_to_kv_cache(
            key[:num_actual_tokens],
            value[:num_actual_tokens],
            kv_cache,
            attn_metadata.slot_mapping,
            attn_metadata.block_size,
        )

        # Step 2: Gather compact KV cache (CPU)
        # compact_k/v: [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
        # where aligned_max_seq_len is max_seq_len rounded up to KV_LENGTH_ALIGNMENT.
        compact_k, compact_v = self._gather_compact_kv_cache(
            kv_cache,
            attn_metadata.block_table,
            attn_metadata.seq_lens,
            attn_metadata.block_size,
            attn_metadata.max_seq_len,
            query.device,
        )

        # Step 3: Reshape query to per-sequence format (CPU)
        # query_per_seq: [num_seqs, max_query_len, num_heads, head_size]
        query_per_seq = self._reshape_query_to_sequences(
            query[:num_actual_tokens],
            attn_metadata.query_start_loc,
            attn_metadata.num_seqs,
            attn_metadata.max_query_len,
            query.device,
        )

        # Step 4: Build per-sequence attention mask (CPU)
        # mask: [num_seqs, 1, max_query_len, aligned_max_seq_len] (True = masked out);
        # the KV dim matches the aligned length produced in Step 2.
        mask = self._build_attention_mask(
            attn_metadata.seq_lens,
            attn_metadata.query_start_loc,
            attn_metadata.apply_causal_mask,
            attn_metadata.max_seq_len,
            attn_metadata.max_query_len,
            query.device,
        )

        # Step 5: Compute batched per-sequence attention (CPU, Spyre)
        # attn_output: [num_seqs, max_query_len, num_heads, head_size]
        attn_output = self._compute_attention(
            query_per_seq, compact_k, compact_v, mask, query.device, query.dtype
        )

        # Step 6: Extract only the actual query tokens (strip padding) (CPU)
        # [num_actual_tokens, num_heads, head_size]
        attn_output_flat = self._extract_relevant_output(attn_output, attn_metadata.query_start_loc)

        output[:num_actual_tokens].copy_(attn_output_flat)
        return output
# SpyreAttentionImpl helpers: KV cache write/gather and mask construction.

    def _write_to_kv_cache(
        self,
        key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
        value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,  # [num_tokens]
        block_size: int,
    ) -> None:
        """Write keys and values to paged KV cache using vectorized scatter.

        Assumes slot_mapping contains one distinct slot per token (the usual
        contract for paged KV caches).
        """
        block_indices = slot_mapping // block_size
        block_offsets = slot_mapping % block_size

        kv_cache[block_indices, 0, block_offsets] = key
        kv_cache[block_indices, 1, block_offsets] = value

    def _gather_compact_kv_cache(
        self,
        kv_cache: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: torch.Tensor,
        block_size: int,
        max_seq_len: int,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gather only the relevant KV cache entries into compact tensors with alignment.

        Args:
            kv_cache: [num_blocks, 2, block_size, num_kv_heads, head_size]
            block_table: [num_seqs, max_num_blocks_per_seq]
            seq_lens: [num_seqs]
            block_size: int
            max_seq_len: pre-computed max of seq_lens (avoids a device sync)

        Returns:
            compact_k: [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
            compact_v: [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]

        Note: aligned_max_seq_len is max_seq_len rounded up to KV_LENGTH_ALIGNMENT
        """
        num_seqs = block_table.shape[0]
        max_blocks_per_seq = block_table.shape[1]

        # Align max_seq_len to KV_LENGTH_ALIGNMENT
        aligned_max_seq_len = (
            (max_seq_len + self.KV_LENGTH_ALIGNMENT - 1)
            // self.KV_LENGTH_ALIGNMENT
            * self.KV_LENGTH_ALIGNMENT
        )

        key_cache = kv_cache[:, 0]  # [num_blocks, block_size, num_kv_heads, head_size]
        value_cache = kv_cache[:, 1]
        num_kv_heads = key_cache.shape[2]
        head_size = key_cache.shape[3]

        # [num_seqs, max_seq_len] - only gather up to actual max_seq_len
        position_indices = (
            torch.arange(max_seq_len, device=device).unsqueeze(0).expand(num_seqs, -1)
        )

        block_indices = position_indices // block_size
        offset_in_block = position_indices % block_size

        # Clamp to valid range
        block_indices_clamped = torch.clamp(block_indices, 0, max_blocks_per_seq - 1)
        physical_blocks = block_table.gather(1, block_indices_clamped)

        # Zero out physical blocks for padding positions.
        # Padding positions thus read (arbitrary) data from block 0; those
        # entries are never attended to because _build_attention_mask masks
        # every KV position >= seq_len.
        valid_mask = position_indices < seq_lens.unsqueeze(1)  # [num_seqs, max_seq_len]
        physical_blocks = physical_blocks * valid_mask

        # Gather: [num_seqs * max_seq_len, num_kv_heads, head_size]
        flat_blocks = physical_blocks.reshape(-1)
        flat_offsets = offset_in_block.reshape(-1)
        gathered_k = key_cache[flat_blocks, flat_offsets]
        gathered_v = value_cache[flat_blocks, flat_offsets]

        # Reshape to [num_seqs, max_seq_len, num_kv_heads, head_size]
        gathered_k = gathered_k.reshape(num_seqs, max_seq_len, num_kv_heads, head_size)
        gathered_v = gathered_v.reshape(num_seqs, max_seq_len, num_kv_heads, head_size)

        # Pad to aligned length if needed
        if aligned_max_seq_len > max_seq_len:
            padding_size = aligned_max_seq_len - max_seq_len
            gathered_k = torch.nn.functional.pad(
                gathered_k,
                (0, 0, 0, 0, 0, padding_size),  # pad seq_len dimension
                mode="constant",
                value=0.0,
            )
            gathered_v = torch.nn.functional.pad(
                gathered_v, (0, 0, 0, 0, 0, padding_size), mode="constant", value=0.0
            )

        return gathered_k, gathered_v

    def _build_attention_mask(
        self,
        seq_lens: torch.Tensor,  # [num_seqs]
        query_start_loc: torch.Tensor,  # [num_seqs + 1]
        apply_causal_mask: bool,
        max_seq_len: int,
        max_query_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Build a per-sequence attention mask with aligned KV length.

        Args:
            max_seq_len: pre-computed max of seq_lens (avoids a device sync)
            max_query_len: pre-computed max of query_lens (avoids a device sync)

        Returns:
            mask: [num_seqs, 1, max_query_len, aligned_max_seq_len]
            True = masked out (don't attend), False = attend
        """
        query_lens = query_start_loc[1:] - query_start_loc[:-1]  # [num_seqs]

        # Align max_seq_len to KV_LENGTH_ALIGNMENT (must match the KV gather above)
        aligned_max_seq_len = (
            (max_seq_len + self.KV_LENGTH_ALIGNMENT - 1)
            // self.KV_LENGTH_ALIGNMENT
            * self.KV_LENGTH_ALIGNMENT
        )

        # Positions along query and KV dimensions
        q_pos = torch.arange(max_query_len, device=device)  # [max_query_len]
        kv_pos = torch.arange(aligned_max_seq_len, device=device)  # [aligned_max_seq_len]

        # Validity: which (seq, q, kv) positions are real (not padding)?
        # [num_seqs, max_query_len]
        q_valid = q_pos.unsqueeze(0) < query_lens.unsqueeze(1)
        # [num_seqs, aligned_max_seq_len] - only positions < seq_len are valid
        kv_valid = kv_pos.unsqueeze(0) < seq_lens.unsqueeze(1)

        # [num_seqs, max_query_len, aligned_max_seq_len]
        attend = q_valid.unsqueeze(2) & kv_valid.unsqueeze(1)

        if apply_causal_mask:
            # query token q_i (0-indexed) can attend to KV positions 0 .. context_len + q_i
            context_lens = seq_lens - query_lens  # [num_seqs]
            # [num_seqs, max_query_len, 1]
            causal_limit = (context_lens.unsqueeze(1) + q_pos.unsqueeze(0)).unsqueeze(2)
            # [num_seqs, 1, aligned_max_seq_len]
            kv_pos_exp = kv_pos.unsqueeze(0).unsqueeze(0)
            causal_ok = kv_pos_exp <= causal_limit  # [num_seqs, max_query_len, aligned_max_seq_len]
            attend = attend & causal_ok

        # [num_seqs, 1, max_query_len, aligned_max_seq_len] True = masked out
        return ~attend.unsqueeze(1)
# SpyreAttentionImpl helpers: query reshaping and attention dispatch.

    def _reshape_query_to_sequences(
        self,
        query: torch.Tensor,  # [num_actual_tokens, num_heads, head_size]
        query_start_loc: torch.Tensor,  # [num_seqs + 1]
        num_seqs: int,
        max_query_len: int,
        device: torch.device,
    ) -> torch.Tensor:
        """
        Reshape flat query tokens into a padded per-sequence tensor.

        Returns:
            [num_seqs, max_query_len, num_heads, head_size]
        """

        query_lens = query_start_loc[1:] - query_start_loc[:-1]  # [num_seqs]

        # [num_seqs, max_query_len]
        positions = torch.arange(max_query_len, device=device).unsqueeze(0).expand(num_seqs, -1)
        global_indices = query_start_loc[:-1].unsqueeze(1) + positions

        # Clamp so gather doesn't go OOB; invalid positions are masked in attention
        global_indices_clamped = torch.clamp(global_indices, 0, query.shape[0] - 1)

        # [num_seqs, max_query_len, num_heads, head_size]
        query_per_seq = query[global_indices_clamped]

        # Zero out padding positions
        valid_mask = positions < query_lens.unsqueeze(1)
        query_per_seq = query_per_seq * valid_mask.unsqueeze(-1).unsqueeze(-1)

        return query_per_seq

    def _compute_attention(
        self,
        query: torch.Tensor,  # [num_seqs, max_query_len, num_heads, head_size]
        key: torch.Tensor,  # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
        value: torch.Tensor,  # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
        mask: torch.Tensor,  # [num_seqs, 1, max_query_len, aligned_max_seq_len] True=masked
        device: torch.device,  # device for intermediate allocations
        dtype: torch.dtype,  # dtype for intermediate allocations
    ) -> torch.Tensor:
        """Dispatch attention: SDPA path or per-sequence chunked Spyre path.

        Returns:
            [num_seqs, max_query_len, num_heads, head_size]
        """
        num_seqs = query.shape[0]

        # SDPA path, selected via use_sdpa in __init__.
        if self.use_sdpa:
            return self._compute_attention_sdpa(query, key, value, mask)

        # Allocate output tensor for all sequences
        output_all_seqs = torch.zeros_like(query)

        # Process each sequence separately
        for seq_idx in range(num_seqs):
            # Extract single sequence (slice keeps the batch dim of size 1)
            query_seq = query[seq_idx : seq_idx + 1]  # [1, max_query_len, num_heads, head_size]
            key_seq = key[seq_idx : seq_idx + 1]  # [1, max_seq_len, num_kv_heads, head_size]
            value_seq = value[seq_idx : seq_idx + 1]  # [1, max_seq_len, num_kv_heads, head_size]
            mask_seq = (
                mask[seq_idx : seq_idx + 1] if mask is not None else None
            )  # [1, 1, max_query_len, max_seq_len]

            # Compute attention for this sequence
            output_seq = self._compute_attention_single_seq(
                query_seq, key_seq, value_seq, mask_seq, device, dtype
            )

            # Store result
            output_all_seqs[seq_idx] = output_seq.squeeze(0)

        return output_all_seqs

    def _compute_attention_sdpa(
        self,
        query: torch.Tensor,  # [num_seqs, max_query_len, num_heads, head_size]
        key: torch.Tensor,  # [num_seqs, max_seq_len, num_kv_heads, head_size]
        value: torch.Tensor,  # [num_seqs, max_seq_len, num_kv_heads, head_size]
        mask: torch.Tensor,  # [num_seqs, 1, max_query_len, max_seq_len] True=masked
    ) -> torch.Tensor:
        """SDPA path: runs compiled scaled_dot_product_attention.

        Note: Currently runs on CPU. TODO: Transfer to Spyre when supported.
        Currently not supported because
        - GQA
        - Non-square attention

        NOTE(review): rows for fully-padded query positions have every KV
        position masked; SDPA may emit NaNs for such rows — they are dropped
        by _extract_relevant_output, but worth confirming.
        """
        # SDPA expects [batch, heads, seq, head_size] and attend=True masks,
        # hence the transposes and the mask inversion.
        out = self.attn_op(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            attn_mask=~mask,
            scale=self.scale,
            enable_gqa=True,
        )
        return out.transpose(1, 2)
# SpyreAttentionImpl helpers: chunked per-sequence Spyre kernel and output extraction.

    def _compute_attention_single_seq(
        self,
        query: torch.Tensor,  # [1, max_query_len, num_heads, head_size]
        key: torch.Tensor,  # [1, max_seq_len, num_kv_heads, head_size]
        value: torch.Tensor,  # [1, max_seq_len, num_kv_heads, head_size]
        mask: torch.Tensor | None,  # [1, 1, max_query_len, max_seq_len]
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Compute attention for a single sequence using Spyre.

        Processes queries in fixed-size chunks of QUERY_CHUNK_SIZE tokens.
        """

        _, _, num_heads, head_size = query.shape
        _, kv_len, num_kv_heads, _ = key.shape

        # Handle grouped-query attention by repeating KV heads so that
        # key/value match the query head count below.
        if self.num_queries_per_kv > 1:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=2)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=2)

        # Squeeze batch dimension
        query_squeezed = query.squeeze(0)  # [query_len, num_heads, head_size]
        key_squeezed = key.squeeze(0)  # [kv_len, num_heads, head_size]
        value_squeezed = value.squeeze(0)  # [kv_len, num_heads, head_size]

        # Calculate number of chunks needed (ceil division)
        actual_query_len = query_squeezed.shape[0]
        num_chunks = (actual_query_len + self.QUERY_CHUNK_SIZE - 1) // self.QUERY_CHUNK_SIZE

        output_full = torch.empty(
            actual_query_len,
            num_heads,
            head_size,
            dtype=dtype,
            device=device,
        )

        # Process each chunk
        for chunk_idx in range(num_chunks):
            chunk_start = chunk_idx * self.QUERY_CHUNK_SIZE
            chunk_end = min(chunk_start + self.QUERY_CHUNK_SIZE, actual_query_len)
            chunk_len = chunk_end - chunk_start

            # Extract query chunk
            query_chunk = query_squeezed[chunk_start:chunk_end]

            # Pad chunk if needed (last chunk may be short); padding keeps the
            # compiled kernel's input shapes constant.
            if chunk_len < self.QUERY_CHUNK_SIZE:
                padding_size = self.QUERY_CHUNK_SIZE - chunk_len
                query_chunk_padded = torch.nn.functional.pad(
                    query_chunk, (0, 0, 0, 0, 0, padding_size), mode="constant", value=0.0
                )
            else:
                query_chunk_padded = query_chunk

            # Extract corresponding mask for this chunk
            if mask is not None:
                mask_chunk = mask[:, :, chunk_start:chunk_end, :]  # [1, 1, chunk_len, kv_len]
            else:
                mask_chunk = None

            # Compute attention for this chunk
            chunk_output = self._compute_attention_chunk(
                query_chunk_padded,
                key_squeezed,
                value_squeezed,
                mask_chunk,
                chunk_len,
                num_heads,
                head_size,
                kv_len,
                device,
                dtype,
            )

            # Store chunk output (only valid positions)
            output_full[chunk_start:chunk_end] = chunk_output[:chunk_len]

        return output_full.unsqueeze(0)  # [1, query_len, num_heads, head_size]

    def _compute_attention_chunk(
        self,
        query_chunk_padded: torch.Tensor,  # [QUERY_CHUNK_SIZE, num_heads, head_size]
        key_squeezed: torch.Tensor,  # [kv_len, num_heads, head_size]
        value_squeezed: torch.Tensor,  # [kv_len, num_heads, head_size]
        mask_chunk: torch.Tensor | None,  # [1, 1, chunk_len, kv_len]
        chunk_len: int,
        num_heads: int,
        head_size: int,
        kv_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Compute attention for a single query chunk on Spyre.

        Prepares tensors on CPU (reshape, stickify, build mask), transfers to
        Spyre for the compiled matmul kernel, then transfers the result back.

        Returns:
            [QUERY_CHUNK_SIZE, num_heads, head_size] — attention output (padded)
        """
        padded_query_len = self.QUERY_CHUNK_SIZE

        # Reshape query to flatten heads into query dimension
        query_reordered = query_chunk_padded.transpose(
            0, 1
        ).contiguous()  # [num_heads, QUERY_CHUNK_SIZE, head_size]
        query_flat = query_reordered.reshape(num_heads * padded_query_len, head_size)

        # Key and value: also flatten across heads
        key_reordered = key_squeezed.transpose(0, 1).contiguous()  # [num_heads, kv_len, head_size]
        value_reordered = value_squeezed.transpose(
            0, 1
        ).contiguous()  # [num_heads, kv_len, head_size]

        key_flat = key_reordered.reshape(num_heads * kv_len, head_size)
        value_flat = value_reordered.reshape(num_heads * kv_len, head_size)

        # Transpose for attention computation
        qt = query_flat.T.contiguous()  # [head_size, num_heads * QUERY_CHUNK_SIZE]
        vt = value_flat.T.contiguous()  # [head_size, num_heads * kv_len]
        k = key_flat  # [num_heads * kv_len, head_size]

        # Stickification: force Spyre-friendly memory layout.
        # Transposed tensors need double transpose-contiguous; standard tensors just contiguous.
        qt_stickified = qt.transpose(0, 1).contiguous().transpose(0, 1).contiguous()
        vt_stickified = vt.transpose(0, 1).contiguous().transpose(0, 1).contiguous()
        k_stickified = k.contiguous()

        # Scale factor: 1D tensor replicated per head × query position
        sm_scale_1d = torch.tensor(self.scale, dtype=dtype, device=device).repeat(
            num_heads * padded_query_len
        )  # [num_heads * QUERY_CHUNK_SIZE]

        # --- Build block-diagonal additive mask ---
        # The transposed kernel flattens all heads into one matmul, so the mask
        # must be block-diagonal: each head's causal/padding mask sits on the
        # diagonal, off-diagonal blocks are masked (-65504).
        if mask_chunk is not None:
            mask_all_heads = mask_chunk[0, 0]  # [chunk_len, kv_len]

            # Pad query dimension to QUERY_CHUNK_SIZE if this is the last chunk
            if chunk_len < self.QUERY_CHUNK_SIZE:
                padding_size = self.QUERY_CHUNK_SIZE - chunk_len
                mask_padding = torch.ones((padding_size, kv_len), dtype=torch.bool, device=device)
                mask_all_heads = torch.cat([mask_all_heads, mask_padding], dim=0)

            head_mask_t = mask_all_heads.T  # [kv_len, QUERY_CHUNK_SIZE], True = masked
            # block_diag places the per-head "attend" block on the diagonal;
            # its zero (False) fill marks cross-head positions, inverted to masked.
            mask_bool = ~torch.block_diag(*([~head_mask_t] * num_heads))
        else:
            # No causal/padding mask: only cross-head positions are masked.
            ones_block = torch.ones(kv_len, padded_query_len, dtype=torch.bool, device=device)
            mask_bool = ~torch.block_diag(*([ones_block] * num_heads))

        # Convert boolean mask to additive: True → -65504.0, False → 0.0
        # (-65504 is the most negative finite float16 value)
        mask_values = torch.where(
            mask_bool,
            torch.tensor(-65504.0, dtype=dtype, device=device),
            torch.tensor(0.0, dtype=dtype, device=device),
        ).contiguous()

        # --- Transfer to Spyre, compute, transfer back ---
        qt_spyre = convert(qt_stickified, self._target_device, self._target_dtype)
        k_spyre = convert(k_stickified, self._target_device, self._target_dtype)
        vt_spyre = convert(vt_stickified, self._target_device, self._target_dtype)
        sm_scale_spyre = convert(sm_scale_1d, self._target_device, self._target_dtype)
        mask_spyre = convert(mask_values, self._target_device, self._target_dtype)

        # Compiled attention on Spyre
        output_spyre_t = self.attn_op(qt_spyre, k_spyre, vt_spyre, sm_scale_spyre, mask_spyre)

        # Transfer back to CPU
        output_flat = convert(
            output_spyre_t, device, dtype
        ).contiguous()  # [head_size, num_heads * QUERY_CHUNK_SIZE]

        # Reshape: [head_size, N*Q] → [N, Q, head_size] → [Q, N, head_size]
        output_transposed = output_flat.T  # [num_heads * QUERY_CHUNK_SIZE, head_size]
        output_reshaped = output_transposed.reshape(num_heads, padded_query_len, head_size)

        # [QUERY_CHUNK_SIZE, num_heads, head_size]
        return output_reshaped.transpose(0, 1).contiguous()

    def _extract_relevant_output(
        self,
        attn_output: torch.Tensor,  # [num_seqs, max_query_len, num_heads, head_size]
        query_start_loc: torch.Tensor,  # [num_seqs + 1]
    ) -> torch.Tensor:
        """
        Extract actual query tokens from padded per-sequence output.

        Returns:
            [num_actual_tokens, num_heads, head_size]
        """
        max_query_len = attn_output.shape[1]
        device = attn_output.device

        query_lens = query_start_loc[1:] - query_start_loc[:-1]  # [num_seqs]

        # Boolean index into [num_seqs, max_query_len]
        positions = torch.arange(max_query_len, device=device).unsqueeze(0)
        valid = positions < query_lens.unsqueeze(1)  # [num_seqs, max_query_len]

        # Boolean indexing flattens the first two dims and keeps the rest
        return attn_output[valid]  # [num_actual_tokens, num_heads, head_size]