From 1d8dc286cb2a8ad042469f321f848df6d09cb041 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Sat, 21 Jun 2025 04:00:48 +0000 Subject: [PATCH 1/8] [TPU] kv cache update kernel Signed-off-by: Chengji Yao --- tests/v1/tpu/test_kv_cache_update_kernel.py | 59 +++++++++ vllm/attention/ops/pallas_kv_cache_update.py | 109 +++++++++++++++ vllm/v1/attention/backends/pallas.py | 55 +++++++- vllm/v1/worker/tpu_model_runner.py | 131 ++++++++++++++----- 4 files changed, 316 insertions(+), 38 deletions(-) create mode 100644 tests/v1/tpu/test_kv_cache_update_kernel.py create mode 100644 vllm/attention/ops/pallas_kv_cache_update.py diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py new file mode 100644 index 000000000000..e8045d1a2d8b --- /dev/null +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy as np +import pytest +import torch +import torch_xla + +import vllm.v1.attention.backends.pallas # noqa: F401 +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This is a test for TPU only") +def test_kv_cache_update_kernel(): + page_num = 1000 + page_size = 32 + combined_kv_head_num = 16 + head_dim = 128 + kernel_block_size = 16 + padded_num_tokens = 128 + kv_cache_cpu = torch.zeros( + (page_num * page_size, combined_kv_head_num, head_dim), + dtype=torch.bfloat16, + device="cpu") + kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) + new_kv_cpu = torch.randn( + (padded_num_tokens, combined_kv_head_num, head_dim), + dtype=torch.bfloat16, + device="cpu") + new_kv_xla = new_kv_cpu.to(torch_xla.device()) + slice_lens = np.array([7, 32, 32, 1, 1, 1, 9], dtype=np.int32) + kv_cache_start_indices = np.array([57, 64, 96, 104, 213, 345, 488], + dtype=np.int32) + new_kv_cache_indices = np.array([0, 7, 39, 71, 72, 73, 74], dtype=np.int32) + slot_mapping = np.stack( + [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) + slot_mapping = np.pad( + slot_mapping, [[0, kernel_block_size - slot_mapping.shape[0]], [0, 0]], + constant_values=0) + slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu") + slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) + torch_xla.sync() + + torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) + new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( + new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size, + kernel_block_size) + kv_cache_xla.copy_(new_kv_cache_xla) + torch_xla.sync() + + for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, + slice_lens): + kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :] + + assert torch.allclose(kv_cache_xla.cpu(), + kv_cache_cpu, + atol=1e-4, + rtol=1e-4) diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py new file mode 100644 index 000000000000..f4a3ce75025e --- /dev/null +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools + +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +def _kv_cache_update_kernel( + # Prefetch + slices_ref, # [num_slices, 3] + # Input + new_kv_hbm_ref, # [tokens, kv_head_num, head_dim] + kv_cache_hbm_ref, + # Output + _, # [total_num_pages * page_size, kv_head_num, head_dim] + # Scratch + scratch, # [block_size, page_size, kv_head_num, head_dim] + sem, +): + async_copies = [] + block_idx = pl.program_id(0) + block_size = scratch.shape[0] + + # Copy from new_kv_hbm_ref to scratch + for i in range(block_size): + offset_i = i + block_idx * block_size + new_kv_start = slices_ref[offset_i, 1] + length = slices_ref[offset_i, 2] + async_copy = pltpu.make_async_copy( + new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], + scratch.at[i, pl.ds(0, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + + for async_copy in async_copies: + async_copy.wait() + + # Copy from scratch to kv_cache_hbm_ref + async_copies.clear() + for i in range(block_size): + offset_i = i + block_idx * block_size + kv_cache_start = slices_ref[offset_i, 0] + length = slices_ref[offset_i, 2] + async_copy = pltpu.make_async_copy( + scratch.at[i, pl.ds(0, length), ...], + kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + for async_copy in async_copies: + async_copy.wait() + + +@functools.partial( + jax.jit, + static_argnames=["page_size", "block_size"], +) +def kv_cache_update( + new_kv: jax.Array, # [total_num_token, kv_head_num, head_dim] + slices: jax. + Array, # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len) + kv_cache: jax. + Array, # [total_num_pages * page_size, kv_head_num, head_dim] + *, + page_size: int = 32, + block_size: int = 8, +): + assert slices.shape[0] % block_size == 0 + _, kv_head_num, head_dim = new_kv.shape + + in_specs = [ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + + out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] + out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)] + + scalar_prefetches = [slices] + scratch = pltpu.VMEM( + (block_size, page_size, kv_head_num, head_dim), + new_kv.dtype, + ) + + scratch_shapes = [ + scratch, + pltpu.SemaphoreType.DMA, + ] + + kernel = pl.pallas_call( + _kv_cache_update_kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=(slices.shape[0] // block_size, ), + scratch_shapes=scratch_shapes, + ), + out_shape=out_shape, + input_output_aliases={len(scalar_prefetches) + 1: 0}, + ) + + return kernel(*scalar_prefetches, new_kv, kv_cache)[0] diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index ff2862edaa01..aa155246bfbd 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -5,8 +5,12 @@ from typing import Any, Optional import torch -# Required to register custom ops. +import torch_xla.core.xla_builder as xb import torch_xla.experimental.custom_kernel # noqa: F401 +# Required to register custom ops. +from torch.library import impl +from torch_xla._internal.jax_workarounds import requires_jax +from torch_xla.experimental.custom_kernel import XLA_LIB from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -107,6 +111,7 @@ class PallasMetadata: context_lens: torch.Tensor query_start_loc: torch.Tensor num_seqs: torch.Tensor + kv_cache_update_block_size: int class PallasAttentionBackendImpl(AttentionImpl): @@ -212,7 +217,10 @@ def forward( # Write input keys and values to the KV cache. # Skip this if sharing KV cache with an earlier attention layer. slot_mapping = attn_metadata.slot_mapping - write_to_kv_cache(key, value, kv_cache, slot_mapping) + kv_cache_update_block_size = \ + attn_metadata.kv_cache_update_block_size + write_to_kv_cache(key, value, kv_cache, slot_mapping, + kv_cache_update_block_size) output = torch.ops.xla.ragged_paged_attention( query, @@ -244,6 +252,7 @@ def write_to_kv_cache( value: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache_update_block_size: int, ) -> None: """ Write the key and values to the KV cache. @@ -251,9 +260,9 @@ def write_to_kv_cache( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] - + kv_cache_update_block_size: int """ - _, _, num_combined_kv_heads, head_size = kv_cache.shape + _, page_size, num_combined_kv_heads, head_size = kv_cache.shape head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, @@ -262,4 +271,40 @@ def write_to_kv_cache( torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) kv_cache = kv_cache.flatten(0, 1) - kv_cache.index_copy_(0, slot_mapping, kv) + new_kv_cache = torch.ops.xla.kv_cache_update_op( + kv, slot_mapping, kv_cache, page_size, kv_cache_update_block_size) + # NOTE: the in-place copy will be optimized away by XLA compiler. + kv_cache.copy_(new_kv_cache) + + +@requires_jax +def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + block_size: int): + from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), { + "page_size": page_size, + "block_size": block_size + }) + return new_kv_cache + + +XLA_LIB.define( + "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, " + "int page_size, int block_size) -> Tensor", ) + + +@impl(XLA_LIB, "kv_cache_update_op", "XLA") +def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + block_size: int) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, + page_size, block_size) + return new_kv_cache + + +@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") +def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, page_size: int, + block_size: int) -> torch.Tensor: + return kv_cache diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2d80bac3c954..b54bdb244bad 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -53,12 +53,10 @@ logger = init_logger(__name__) -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. -_PAD_SLOT_ID = 1_000_000_000 INVALID_TOKEN_ID = -1 # Smallest output size MIN_NUM_SEQS = 8 +KV_CACHE_UPDATE_BLOCK_SIZE = 8 ######################################################### @@ -526,6 +524,69 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_spec + def _get_slot_mapping_metadata(self, num_reqs, + num_scheduled_tokens_per_req): + """ + Computes metadata for mapping slots to blocks in the key-value (KV) + cache for a batch of requests. + + This function determines, for each request in the batch, how the + scheduled tokens are distributed across memory blocks, and generates + metadata needed to map slices of tokens to their corresponding positions + in the KV cache. + + Args: + num_reqs (int): Number of requests in the current batch. + num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens + to be scheduled for each request. + + Returns: + np.ndarray: A 2D array of shape (total_block_len, 3), where each row + contains: + - kv_cache_start_index (int): The starting index in the KV cache + for the corresponding slice. + - new_kv_start_index (int): The starting index in the new KV + cache for the corresponding slice. + - slice_len (int): The length of the slice. + """ + slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] + slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \ + num_scheduled_tokens_per_req + local_block_start_idx = slices_start // self.block_size + local_block_end_idx = (slices_end - 1) // self.block_size + no_repeat_req_indices = self.arange_np[:num_reqs] + global_block_start_idx = ( + no_repeat_req_indices * self.max_num_blocks_per_req + + local_block_start_idx) + block_lens = local_block_end_idx - local_block_start_idx + 1 + global_block_start_idx = np.repeat(global_block_start_idx, block_lens) + slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) + global_block_indices = global_block_start_idx + slice_arange + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() + total_block_len = np.sum(block_lens) + slot_mapping_slices = np.repeat(np.array([[0, self.block_size]], + dtype=np.int32), + total_block_len, + axis=0) + cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) + np.cumsum(block_lens, out=cu_block_lens[1:]) + for req_idx in range(num_reqs): + slot_mapping_slices[cu_block_lens[req_idx]][ + 0] = slices_start[req_idx] % self.block_size + slot_mapping_slices[ + cu_block_lens[req_idx + 1] - + 1][1] = (slices_end[req_idx] - 1) % self.block_size + 1 + slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] + cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) + np.cumsum(slice_lens, out=cu_slices_lens[1:]) + kv_cache_start_indices = slot_mapping_slices[:, 0] + \ + (block_numbers * self.block_size) + new_kv_start_indices = cu_slices_lens[:-1] + slot_mapping_metadata = np.stack( + [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1) + return slot_mapping_metadata + def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): assert scheduler_output.total_num_scheduled_tokens > 0 @@ -603,26 +664,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.input_batch.block_table[0]. - slot_mapping_np[:total_num_scheduled_tokens]) - # Prepare the attention metadata. self.query_start_loc_np[0] = 0 np.cumsum(num_scheduled_tokens_per_req, @@ -645,12 +686,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", self.position_ids = self.positions_cpu[: padded_total_num_scheduled_tokens].to( self.device) - self.input_batch.block_table[0].slot_mapping_cpu[ - total_num_scheduled_tokens:] = _PAD_SLOT_ID - slot_mapping = ( - self.input_batch.block_table[0]. - slot_mapping_cpu[:padded_total_num_scheduled_tokens].to( - self.device)) if use_max_model_len: block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, : self.max_num_blocks_per_req] @@ -675,6 +710,18 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", self.device) block_tables = block_tables.to(self.device) + slot_mapping_metadata = self._get_slot_mapping_metadata( + num_reqs, num_scheduled_tokens_per_req) + padded_num_slices = _get_padded_num_kv_cache_update_slices( + padded_total_num_scheduled_tokens, self.max_num_reqs, + self.block_size) + slot_mapping_metadata = np.pad( + slot_mapping_metadata, + [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], + constant_values=0) + slot_mapping_metadata = torch.tensor(slot_mapping_metadata, + device=self.device) + if self.lora_config is not None: # We need to respect padding when activating LoRA adapters padded_num_scheduled_tokens_per_req = np.copy( @@ -687,13 +734,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", padded_num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_metadata, block_tables=block_tables, context_lens=seq_lens, query_start_loc=query_start_loc, num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), + kv_cache_update_block_size=KV_CACHE_UPDATE_BLOCK_SIZE, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -1119,10 +1167,13 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, actual_num_reqs = min(num_tokens, num_reqs) position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) - slot_mapping = torch.zeros(num_tokens, - dtype=torch.int64).to(self.device) - block_tables = torch.zeros((num_reqs, num_blocks), + padded_num_slices = _get_padded_num_kv_cache_update_slices( + num_tokens, self.max_num_reqs, self.block_size) + slot_mapping = torch.zeros((padded_num_slices, 3), dtype=torch.int32).to(self.device) + block_tables = torch.zeros( + (num_reqs, num_blocks), + dtype=torch.int32).to(self.device) query_lens = [1] * num_reqs query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), @@ -1138,6 +1189,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, + kv_cache_update_block_size=KV_CACHE_UPDATE_BLOCK_SIZE, ) if self.is_multimodal_model: @@ -1742,6 +1794,19 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] +def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, + page_size: int) -> int: + """ A fixed shape of slot_mapping_metadata tensor is required to avoid + recompilation. + """ + padded_num_slices = 2 * max_num_reqs + num_tokens // page_size + pagged_num_slices = min(padded_num_slices, num_tokens) + pagged_num_slices = ( + pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE - + 1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE + return pagged_num_slices + + def replace_set_lora(model): def _tpu_set_lora( From bd333a3f5a84af91a20a0d90683c0cbc31c95255 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Sat, 21 Jun 2025 05:13:46 +0000 Subject: [PATCH 2/8] add test in ci Signed-off-by: Chengji Yao --- .buildkite/scripts/hardware_ci/run-tpu-v1-test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index a2a5c2a02cbb..90cad506ab1e 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" run_and_track_test 15 "test_spmd_model_weight_loading.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" +run_and_track_test 16 "test_kv_cache_update_kernel.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" # After all tests have been attempted, exit with the overall status. if [ "$overall_script_exit_code" -ne 0 ]; then From 2e4d7408da4f25eb971164251333a1535cc0cc00 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Tue, 24 Jun 2025 00:18:04 +0000 Subject: [PATCH 3/8] fix comments Signed-off-by: Chengji Yao --- vllm/attention/ops/pallas_kv_cache_update.py | 14 +++++++------- vllm/v1/worker/tpu_model_runner.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py index f4a3ce75025e..a916eca30033 100644 --- a/vllm/attention/ops/pallas_kv_cache_update.py +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -12,12 +12,12 @@ def _kv_cache_update_kernel( # Prefetch slices_ref, # [num_slices, 3] # Input - new_kv_hbm_ref, # [tokens, kv_head_num, head_dim] + new_kv_hbm_ref, # [tokens, num_combined_kv_heads, head_dim] kv_cache_hbm_ref, # Output - _, # [total_num_pages * page_size, kv_head_num, head_dim] + _, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] # Scratch - scratch, # [block_size, page_size, kv_head_num, head_dim] + scratch, # [block_size, page_size, num_combined_kv_heads, head_dim] sem, ): async_copies = [] @@ -62,17 +62,17 @@ def _kv_cache_update_kernel( static_argnames=["page_size", "block_size"], ) def kv_cache_update( - new_kv: jax.Array, # [total_num_token, kv_head_num, head_dim] + new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] slices: jax. Array, # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len) kv_cache: jax. - Array, # [total_num_pages * page_size, kv_head_num, head_dim] + Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] *, page_size: int = 32, block_size: int = 8, ): assert slices.shape[0] % block_size == 0 - _, kv_head_num, head_dim = new_kv.shape + _, num_combined_kv_heads, head_dim = new_kv.shape in_specs = [ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -84,7 +84,7 @@ def kv_cache_update( scalar_prefetches = [slices] scratch = pltpu.VMEM( - (block_size, page_size, kv_head_num, head_dim), + (block_size, page_size, num_combined_kv_heads, head_dim), new_kv.dtype, ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index b54bdb244bad..73cfd26806cf 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -56,6 +56,7 @@ INVALID_TOKEN_ID = -1 # Smallest output size MIN_NUM_SEQS = 8 +# Block size used for kv cache updating kernel KV_CACHE_UPDATE_BLOCK_SIZE = 8 @@ -1796,9 +1797,8 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, page_size: int) -> int: - """ A fixed shape of slot_mapping_metadata tensor is required to avoid - recompilation. - """ + """Calculates the padded number of KV cache update slices to avoid + recompilation.""" padded_num_slices = 2 * max_num_reqs + num_tokens // page_size pagged_num_slices = min(padded_num_slices, num_tokens) pagged_num_slices = ( From 0c7ae5633e157568f1c010d1313a126dc4bac9bf Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Tue, 24 Jun 2025 19:36:14 +0000 Subject: [PATCH 4/8] fix comments Signed-off-by: Chengji Yao --- tests/v1/tpu/test_kv_cache_update_kernel.py | 35 +++++++++++++------- vllm/attention/ops/pallas_kv_cache_update.py | 3 ++ vllm/v1/worker/tpu_model_runner.py | 16 ++++----- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py index e8045d1a2d8b..b73a4de2f7a6 100644 --- a/tests/v1/tpu/test_kv_cache_update_kernel.py +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -12,12 +12,13 @@ @pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only") -def test_kv_cache_update_kernel(): +@pytest.mark.parametrize("page_size", [32, 33]) +@pytest.mark.parametrize("combined_kv_head_num", [2, 16]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kernel_block_size", [4, 8]) +def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, + head_dim: int, kernel_block_size: int): page_num = 1000 - page_size = 32 - combined_kv_head_num = 16 - head_dim = 128 - kernel_block_size = 16 padded_num_tokens = 128 kv_cache_cpu = torch.zeros( (page_num * page_size, combined_kv_head_num, head_dim), @@ -29,16 +30,26 @@ def test_kv_cache_update_kernel(): dtype=torch.bfloat16, device="cpu") new_kv_xla = new_kv_cpu.to(torch_xla.device()) - slice_lens = np.array([7, 32, 32, 1, 1, 1, 9], dtype=np.int32) - kv_cache_start_indices = np.array([57, 64, 96, 104, 213, 345, 488], + slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], + dtype=np.int32) + kv_cache_start_indices = np.array([ + page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6, + page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3 + ], dtype=np.int32) - new_kv_cache_indices = np.array([0, 7, 39, 71, 72, 73, 74], dtype=np.int32) + new_kv_cache_indices = np.concatenate( + [np.array([0], dtype=np.int32), + np.cumsum(slice_lens[:-1])]) slot_mapping = np.stack( [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) - slot_mapping = np.pad( - slot_mapping, [[0, kernel_block_size - slot_mapping.shape[0]], [0, 0]], - constant_values=0) - slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu") + padded_size = (slot_mapping.shape[0] + kernel_block_size - + 1) // kernel_block_size * kernel_block_size + slot_mapping = np.pad(slot_mapping, + [[0, padded_size - slot_mapping.shape[0]], [0, 0]], + constant_values=0) + slot_mapping_cpu = torch.tensor(slot_mapping, + device="cpu", + dtype=torch.int32) slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) torch_xla.sync() diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py index a916eca30033..e775c38b1756 100644 --- a/vllm/attention/ops/pallas_kv_cache_update.py +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -73,6 +73,9 @@ def kv_cache_update( ): assert slices.shape[0] % block_size == 0 _, num_combined_kv_heads, head_dim = new_kv.shape + assert kv_cache.shape[1] == num_combined_kv_heads + assert kv_cache.shape[2] == head_dim + assert head_dim % 128 == 0 in_specs = [ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 73cfd26806cf..2d5f3e6fad6b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -57,7 +57,7 @@ # Smallest output size MIN_NUM_SEQS = 8 # Block size used for kv cache updating kernel -KV_CACHE_UPDATE_BLOCK_SIZE = 8 +KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE = 8 ######################################################### @@ -742,7 +742,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), - kv_cache_update_block_size=KV_CACHE_UPDATE_BLOCK_SIZE, + kv_cache_update_block_size=KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -1190,7 +1190,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, - kv_cache_update_block_size=KV_CACHE_UPDATE_BLOCK_SIZE, + kv_cache_update_block_size=KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE, ) if self.is_multimodal_model: @@ -1800,11 +1800,11 @@ def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, """Calculates the padded number of KV cache update slices to avoid recompilation.""" padded_num_slices = 2 * max_num_reqs + num_tokens // page_size - pagged_num_slices = min(padded_num_slices, num_tokens) - pagged_num_slices = ( - pagged_num_slices + KV_CACHE_UPDATE_BLOCK_SIZE - - 1) // KV_CACHE_UPDATE_BLOCK_SIZE * KV_CACHE_UPDATE_BLOCK_SIZE - return pagged_num_slices + padded_num_slices = min(padded_num_slices, num_tokens) + padded_num_slices = ( + padded_num_slices + KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE - 1 + ) // KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE * KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE + return padded_num_slices def replace_set_lora(model): From bb7fc2ded2219927c25c332b7be3c6215058e3e7 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 25 Jun 2025 17:07:21 +0000 Subject: [PATCH 5/8] fix test Signed-off-by: Chengji Yao --- tests/v1/tpu/test_pallas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index 3a9d80847a16..0329869b1a13 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -65,6 +65,7 @@ class FakeAttentionLayer: context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, + kv_cache_update_block_size=8, ) with patch("torch.ops.xla.ragged_paged_attention" From 068e16956cc5e347de1889c0dec271e6da627a18 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 25 Jun 2025 19:18:19 +0000 Subject: [PATCH 6/8] fix comments Signed-off-by: Chengji Yao --- tests/v1/tpu/test_kv_cache_update_kernel.py | 11 ++--- tests/v1/tpu/test_pallas.py | 2 +- vllm/attention/ops/pallas_kv_cache_update.py | 43 +++++++++++--------- vllm/v1/attention/backends/pallas.py | 28 ++++++------- vllm/v1/worker/tpu_model_runner.py | 16 +++++--- 5 files changed, 55 insertions(+), 45 deletions(-) diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py index b73a4de2f7a6..63a1f6777e4d 100644 --- a/tests/v1/tpu/test_kv_cache_update_kernel.py +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -15,9 +15,9 @@ @pytest.mark.parametrize("page_size", [32, 33]) @pytest.mark.parametrize("combined_kv_head_num", [2, 16]) @pytest.mark.parametrize("head_dim", [128, 256]) -@pytest.mark.parametrize("kernel_block_size", [4, 8]) +@pytest.mark.parametrize("num_slices_per_block", [4, 8]) def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, - head_dim: int, kernel_block_size: int): + head_dim: int, num_slices_per_block: int): page_num = 1000 padded_num_tokens = 128 kv_cache_cpu = torch.zeros( @@ -42,11 +42,12 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, np.cumsum(slice_lens[:-1])]) slot_mapping = np.stack( [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) - padded_size = (slot_mapping.shape[0] + kernel_block_size - - 1) // kernel_block_size * kernel_block_size + padded_size = (slot_mapping.shape[0] + num_slices_per_block - + 1) // num_slices_per_block * num_slices_per_block slot_mapping = np.pad(slot_mapping, [[0, padded_size - slot_mapping.shape[0]], [0, 0]], constant_values=0) + slot_mapping = np.transpose(slot_mapping) slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32) @@ -56,7 +57,7 @@ def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size, - kernel_block_size) + num_slices_per_block) kv_cache_xla.copy_(new_kv_cache_xla) torch_xla.sync() diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index 0329869b1a13..3087d720da63 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -65,7 +65,7 @@ class FakeAttentionLayer: context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, - kv_cache_update_block_size=8, + num_slices_per_kv_cache_update_block=8, ) with patch("torch.ops.xla.ragged_paged_attention" diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py index e775c38b1756..1a92b10e4f9c 100644 --- a/vllm/attention/ops/pallas_kv_cache_update.py +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -10,25 +10,28 @@ def _kv_cache_update_kernel( # Prefetch - slices_ref, # [num_slices, 3] + slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start, + # slice_len) # Input - new_kv_hbm_ref, # [tokens, num_combined_kv_heads, head_dim] - kv_cache_hbm_ref, + new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim] + kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads, + # head_dim] # Output _, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] # Scratch - scratch, # [block_size, page_size, num_combined_kv_heads, head_dim] + scratch, # [num_slices_per_block, page_size, num_combined_kv_heads, + # head_dim] sem, ): async_copies = [] block_idx = pl.program_id(0) - block_size = scratch.shape[0] + num_slices_per_block = scratch.shape[0] # Copy from new_kv_hbm_ref to scratch - for i in range(block_size): - offset_i = i + block_idx * block_size - new_kv_start = slices_ref[offset_i, 1] - length = slices_ref[offset_i, 2] + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + new_kv_start = slices_ref[1, offset_i] + length = slices_ref[2, offset_i] async_copy = pltpu.make_async_copy( new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], scratch.at[i, pl.ds(0, length), ...], @@ -42,10 +45,10 @@ def _kv_cache_update_kernel( # Copy from scratch to kv_cache_hbm_ref async_copies.clear() - for i in range(block_size): - offset_i = i + block_idx * block_size - kv_cache_start = slices_ref[offset_i, 0] - length = slices_ref[offset_i, 2] + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + kv_cache_start = slices_ref[0, offset_i] + length = slices_ref[2, offset_i] async_copy = pltpu.make_async_copy( scratch.at[i, pl.ds(0, length), ...], kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], @@ -59,23 +62,25 @@ def _kv_cache_update_kernel( @functools.partial( jax.jit, - static_argnames=["page_size", "block_size"], + static_argnames=["page_size", "num_slices_per_block"], ) def kv_cache_update( new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] slices: jax. - Array, # [num_slices, 3], list of (kv_cache_start, new_kv_start, slice_len) + Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) kv_cache: jax. Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] *, page_size: int = 32, - block_size: int = 8, + num_slices_per_block: int = 8, ): - assert slices.shape[0] % block_size == 0 + assert slices.shape[1] % num_slices_per_block == 0 _, num_combined_kv_heads, head_dim = new_kv.shape assert kv_cache.shape[1] == num_combined_kv_heads assert kv_cache.shape[2] == head_dim assert head_dim % 128 == 0 + # TODO: Add dynamic check to make sure that the all the slice lengths are + # smaller or equal to page_size in_specs = [ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), @@ -87,7 +92,7 @@ def kv_cache_update( scalar_prefetches = [slices] scratch = pltpu.VMEM( - (block_size, page_size, num_combined_kv_heads, head_dim), + (num_slices_per_block, page_size, num_combined_kv_heads, head_dim), new_kv.dtype, ) @@ -102,7 +107,7 @@ def kv_cache_update( num_scalar_prefetch=len(scalar_prefetches), in_specs=in_specs, out_specs=out_specs, - grid=(slices.shape[0] // block_size, ), + grid=(slices.shape[1] // num_slices_per_block, ), scratch_shapes=scratch_shapes, ), out_shape=out_shape, diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index aa155246bfbd..49f0772c62d1 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -111,7 +111,7 @@ class PallasMetadata: context_lens: torch.Tensor query_start_loc: torch.Tensor num_seqs: torch.Tensor - kv_cache_update_block_size: int + num_slices_per_kv_cache_update_block: int class PallasAttentionBackendImpl(AttentionImpl): @@ -217,10 +217,9 @@ def forward( # Write input keys and values to the KV cache. # Skip this if sharing KV cache with an earlier attention layer. slot_mapping = attn_metadata.slot_mapping - kv_cache_update_block_size = \ - attn_metadata.kv_cache_update_block_size - write_to_kv_cache(key, value, kv_cache, slot_mapping, - kv_cache_update_block_size) + write_to_kv_cache( + key, value, kv_cache, slot_mapping, + attn_metadata.num_slices_per_kv_cache_update_block) output = torch.ops.xla.ragged_paged_attention( query, @@ -252,7 +251,7 @@ def write_to_kv_cache( value: torch.Tensor, kv_cache: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache_update_block_size: int, + num_slices_per_kv_cache_update_block: int, ) -> None: """ Write the key and values to the KV cache. @@ -260,7 +259,7 @@ def write_to_kv_cache( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] - kv_cache_update_block_size: int + num_slices_per_kv_cache_update_block: int """ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape head_size = cdiv(head_size, @@ -272,7 +271,8 @@ def write_to_kv_cache( kv_cache = kv_cache.flatten(0, 1) new_kv_cache = torch.ops.xla.kv_cache_update_op( - kv, slot_mapping, kv_cache, page_size, kv_cache_update_block_size) + kv, slot_mapping, kv_cache, page_size, + num_slices_per_kv_cache_update_block) # NOTE: the in-place copy will be optimized away by XLA compiler. kv_cache.copy_(new_kv_cache) @@ -280,31 +280,31 @@ def write_to_kv_cache( @requires_jax def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, kv_cache: torch.Tensor, page_size: int, - block_size: int): + num_slices_per_block: int): from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), { "page_size": page_size, - "block_size": block_size + "num_slices_per_block": num_slices_per_block }) return new_kv_cache XLA_LIB.define( "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, " - "int page_size, int block_size) -> Tensor", ) + "int page_size, int num_slices_per_block) -> Tensor", ) @impl(XLA_LIB, "kv_cache_update_op", "XLA") def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, kv_cache: torch.Tensor, page_size: int, - block_size: int) -> torch.Tensor: + num_slices_per_block: int) -> torch.Tensor: new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - page_size, block_size) + page_size, num_slices_per_block) return new_kv_cache @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, kv_cache: torch.Tensor, page_size: int, - block_size: int) -> torch.Tensor: + num_slices_per_block: int) -> torch.Tensor: return kv_cache diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2d5f3e6fad6b..2ca45c191942 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -57,7 +57,7 @@ # Smallest output size MIN_NUM_SEQS = 8 # Block size used for kv cache updating kernel -KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE = 8 +NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8 ######################################################### @@ -720,6 +720,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", slot_mapping_metadata, [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], constant_values=0) + slot_mapping_metadata = np.transpose(slot_mapping_metadata) slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device) @@ -742,7 +743,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), - kv_cache_update_block_size=KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE, + num_slices_per_kv_cache_update_block= + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -1170,7 +1172,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, dtype=torch.int32).to(self.device) padded_num_slices = _get_padded_num_kv_cache_update_slices( num_tokens, self.max_num_reqs, self.block_size) - slot_mapping = torch.zeros((padded_num_slices, 3), + slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to(self.device) block_tables = torch.zeros( (num_reqs, num_blocks), @@ -1190,7 +1192,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, - kv_cache_update_block_size=KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE, + num_slices_per_kv_cache_update_block= + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, ) if self.is_multimodal_model: @@ -1802,8 +1805,9 @@ def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, padded_num_slices = 2 * max_num_reqs + num_tokens // page_size padded_num_slices = min(padded_num_slices, num_tokens) padded_num_slices = ( - padded_num_slices + KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE - 1 - ) // KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE * KV_CACHE_UPDATE_KERNEL_BLOCK_SIZE + padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1 + ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \ + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK return padded_num_slices From 1dab6505ea7e6209d2fa2eda72d697306d3670ab Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 25 Jun 2025 19:35:46 +0000 Subject: [PATCH 7/8] fix comments Signed-off-by: Chengji Yao --- tests/v1/tpu/test_pallas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index 3087d720da63..e279edfffbc7 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -47,7 +47,7 @@ class FakeAttentionLayer: key = torch.zeros(num_tokens, num_kv_heads * head_size) value = torch.zeros(num_tokens, num_kv_heads * head_size) kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size) - slot_mapping = torch.zeros(num_tokens, dtype=torch.int64) + slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64) max_num_reqs = 8 max_num_blocks_per_req = 8 block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), From c04d8974f1636b5c1b8cca6a79f70e21891fa3ab Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 25 Jun 2025 23:29:29 +0000 Subject: [PATCH 8/8] fix pre-submit Signed-off-by: Chengji Yao --- vllm/v1/worker/tpu_model_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2ca45c191942..bc334419c4ce 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1174,9 +1174,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, num_tokens, self.max_num_reqs, self.block_size) slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to(self.device) - block_tables = torch.zeros( - (num_reqs, num_blocks), - dtype=torch.int32).to(self.device) + block_tables = torch.zeros((num_reqs, num_blocks), + dtype=torch.int32).to(self.device) query_lens = [1] * num_reqs query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32),