-
-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[TPU] add kv cache update kernel #19928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import torch | ||
| import torch_xla | ||
|
|
||
| import vllm.v1.attention.backends.pallas # noqa: F401 | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not current_platform.is_tpu(), | ||
| reason="This is a test for TPU only") | ||
| @pytest.mark.parametrize("page_size", [32, 33]) | ||
| @pytest.mark.parametrize("combined_kv_head_num", [2, 16]) | ||
| @pytest.mark.parametrize("head_dim", [128, 256]) | ||
| @pytest.mark.parametrize("num_slices_per_block", [4, 8]) | ||
| def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, | ||
| head_dim: int, num_slices_per_block: int): | ||
| page_num = 1000 | ||
| padded_num_tokens = 128 | ||
| kv_cache_cpu = torch.zeros( | ||
| (page_num * page_size, combined_kv_head_num, head_dim), | ||
| dtype=torch.bfloat16, | ||
| device="cpu") | ||
| kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) | ||
| new_kv_cpu = torch.randn( | ||
| (padded_num_tokens, combined_kv_head_num, head_dim), | ||
| dtype=torch.bfloat16, | ||
| device="cpu") | ||
| new_kv_xla = new_kv_cpu.to(torch_xla.device()) | ||
| slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], | ||
| dtype=np.int32) | ||
| kv_cache_start_indices = np.array([ | ||
| page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6, | ||
| page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3 | ||
| ], | ||
| dtype=np.int32) | ||
| new_kv_cache_indices = np.concatenate( | ||
| [np.array([0], dtype=np.int32), | ||
| np.cumsum(slice_lens[:-1])]) | ||
| slot_mapping = np.stack( | ||
| [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) | ||
| padded_size = (slot_mapping.shape[0] + num_slices_per_block - | ||
| 1) // num_slices_per_block * num_slices_per_block | ||
| slot_mapping = np.pad(slot_mapping, | ||
| [[0, padded_size - slot_mapping.shape[0]], [0, 0]], | ||
| constant_values=0) | ||
| slot_mapping = np.transpose(slot_mapping) | ||
| slot_mapping_cpu = torch.tensor(slot_mapping, | ||
| device="cpu", | ||
| dtype=torch.int32) | ||
| slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) | ||
| torch_xla.sync() | ||
|
|
||
| torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) | ||
| new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( | ||
| new_kv_xla, slot_mapping_xla, kv_cache_xla, page_size, | ||
| num_slices_per_block) | ||
| kv_cache_xla.copy_(new_kv_cache_xla) | ||
| torch_xla.sync() | ||
|
|
||
| for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, | ||
| slice_lens): | ||
| kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :] | ||
|
|
||
| assert torch.allclose(kv_cache_xla.cpu(), | ||
| kv_cache_cpu, | ||
| atol=1e-4, | ||
| rtol=1e-4) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import functools | ||
|
|
||
| import jax | ||
| from jax.experimental import pallas as pl | ||
| from jax.experimental.pallas import tpu as pltpu | ||
|
|
||
|
|
||
| def _kv_cache_update_kernel( | ||
| # Prefetch | ||
| slices_ref, # [3, num_slices], list of (kv_cache_start, new_kv_start, | ||
| # slice_len) | ||
| # Input | ||
| new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim] | ||
| kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads, | ||
| # head_dim] | ||
| # Output | ||
| _, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] | ||
| # Scratch | ||
| scratch, # [num_slices_per_block, page_size, num_combined_kv_heads, | ||
| # head_dim] | ||
| sem, | ||
| ): | ||
| async_copies = [] | ||
| block_idx = pl.program_id(0) | ||
| num_slices_per_block = scratch.shape[0] | ||
|
|
||
| # Copy from new_kv_hbm_ref to scratch | ||
| for i in range(num_slices_per_block): | ||
| offset_i = i + block_idx * num_slices_per_block | ||
| new_kv_start = slices_ref[1, offset_i] | ||
| length = slices_ref[2, offset_i] | ||
| async_copy = pltpu.make_async_copy( | ||
| new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], | ||
| scratch.at[i, pl.ds(0, length), ...], | ||
|
||
| sem, | ||
| ) | ||
| async_copy.start() | ||
| async_copies.append(async_copy) | ||
|
|
||
| for async_copy in async_copies: | ||
| async_copy.wait() | ||
|
|
||
| # Copy from scratch to kv_cache_hbm_ref | ||
| async_copies.clear() | ||
| for i in range(num_slices_per_block): | ||
| offset_i = i + block_idx * num_slices_per_block | ||
| kv_cache_start = slices_ref[0, offset_i] | ||
| length = slices_ref[2, offset_i] | ||
| async_copy = pltpu.make_async_copy( | ||
| scratch.at[i, pl.ds(0, length), ...], | ||
| kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], | ||
| sem, | ||
| ) | ||
| async_copy.start() | ||
| async_copies.append(async_copy) | ||
| for async_copy in async_copies: | ||
| async_copy.wait() | ||
|
|
||
|
|
||
| @functools.partial( | ||
| jax.jit, | ||
| static_argnames=["page_size", "num_slices_per_block"], | ||
| ) | ||
| def kv_cache_update( | ||
| new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] | ||
| slices: jax. | ||
| Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) | ||
| kv_cache: jax. | ||
| Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] | ||
| *, | ||
| page_size: int = 32, | ||
| num_slices_per_block: int = 8, | ||
| ): | ||
| assert slices.shape[1] % num_slices_per_block == 0 | ||
| _, num_combined_kv_heads, head_dim = new_kv.shape | ||
| assert kv_cache.shape[1] == num_combined_kv_heads | ||
| assert kv_cache.shape[2] == head_dim | ||
| assert head_dim % 128 == 0 | ||
| # TODO: Add dynamic check to make sure that the all the slice lengths are | ||
| # smaller or equal to page_size | ||
|
|
||
| in_specs = [ | ||
| pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), | ||
| pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), | ||
| ] | ||
|
|
||
| out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] | ||
| out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)] | ||
|
|
||
| scalar_prefetches = [slices] | ||
| scratch = pltpu.VMEM( | ||
| (num_slices_per_block, page_size, num_combined_kv_heads, head_dim), | ||
| new_kv.dtype, | ||
| ) | ||
|
|
||
| scratch_shapes = [ | ||
| scratch, | ||
| pltpu.SemaphoreType.DMA, | ||
| ] | ||
|
|
||
| kernel = pl.pallas_call( | ||
| _kv_cache_update_kernel, | ||
| grid_spec=pltpu.PrefetchScalarGridSpec( | ||
| num_scalar_prefetch=len(scalar_prefetches), | ||
| in_specs=in_specs, | ||
| out_specs=out_specs, | ||
| grid=(slices.shape[1] // num_slices_per_block, ), | ||
| scratch_shapes=scratch_shapes, | ||
| ), | ||
| out_shape=out_shape, | ||
| input_output_aliases={len(scalar_prefetches) + 1: 0}, | ||
|
||
| ) | ||
|
|
||
| return kernel(*scalar_prefetches, new_kv, kv_cache)[0] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,8 +5,12 @@ | |
| from typing import Any, Optional | ||
|
|
||
| import torch | ||
| # Required to register custom ops. | ||
| import torch_xla.core.xla_builder as xb | ||
| import torch_xla.experimental.custom_kernel # noqa: F401 | ||
| # Required to register custom ops. | ||
| from torch.library import impl | ||
| from torch_xla._internal.jax_workarounds import requires_jax | ||
| from torch_xla.experimental.custom_kernel import XLA_LIB | ||
|
|
||
| from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||
| AttentionLayer, AttentionType) | ||
|
|
@@ -107,6 +111,7 @@ class PallasMetadata: | |
| context_lens: torch.Tensor | ||
| query_start_loc: torch.Tensor | ||
| num_seqs: torch.Tensor | ||
| num_slices_per_kv_cache_update_block: int | ||
|
|
||
|
|
||
| class PallasAttentionBackendImpl(AttentionImpl): | ||
|
|
@@ -212,7 +217,9 @@ def forward( | |
| # Write input keys and values to the KV cache. | ||
| # Skip this if sharing KV cache with an earlier attention layer. | ||
| slot_mapping = attn_metadata.slot_mapping | ||
| write_to_kv_cache(key, value, kv_cache, slot_mapping) | ||
| write_to_kv_cache( | ||
| key, value, kv_cache, slot_mapping, | ||
| attn_metadata.num_slices_per_kv_cache_update_block) | ||
|
|
||
| output = torch.ops.xla.ragged_paged_attention( | ||
| query, | ||
|
|
@@ -244,16 +251,17 @@ def write_to_kv_cache( | |
| value: torch.Tensor, | ||
| kv_cache: torch.Tensor, | ||
| slot_mapping: torch.Tensor, | ||
| num_slices_per_kv_cache_update_block: int, | ||
| ) -> None: | ||
| """ Write the key and values to the KV cache. | ||
|
|
||
| Args: | ||
| key: shape = [num_tokens, num_kv_heads * head_size] | ||
| value: shape = [num_tokens, num_kv_heads * head_size] | ||
| kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] | ||
|
|
||
| num_slices_per_kv_cache_update_block: int | ||
| """ | ||
| _, _, num_combined_kv_heads, head_size = kv_cache.shape | ||
| _, page_size, num_combined_kv_heads, head_size = kv_cache.shape | ||
| head_size = cdiv(head_size, | ||
| TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT | ||
| kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, | ||
|
|
@@ -262,4 +270,41 @@ def write_to_kv_cache( | |
| torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) | ||
|
|
||
| kv_cache = kv_cache.flatten(0, 1) | ||
| kv_cache.index_copy_(0, slot_mapping, kv) | ||
| new_kv_cache = torch.ops.xla.kv_cache_update_op( | ||
| kv, slot_mapping, kv_cache, page_size, | ||
| num_slices_per_kv_cache_update_block) | ||
| # NOTE: the in-place copy will be optimized away by XLA compiler. | ||
| kv_cache.copy_(new_kv_cache) | ||
|
|
||
|
|
||
| @requires_jax | ||
| def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, | ||
| kv_cache: torch.Tensor, page_size: int, | ||
| num_slices_per_block: int): | ||
| from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update | ||
| new_kv_cache = xb.call_jax(kv_cache_update, (kv, slot_mapping, kv_cache), { | ||
|
||
| "page_size": page_size, | ||
| "num_slices_per_block": num_slices_per_block | ||
| }) | ||
| return new_kv_cache | ||
|
|
||
|
|
||
| XLA_LIB.define( | ||
| "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache, " | ||
| "int page_size, int num_slices_per_block) -> Tensor", ) | ||
|
|
||
|
|
||
| @impl(XLA_LIB, "kv_cache_update_op", "XLA") | ||
| def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, | ||
| kv_cache: torch.Tensor, page_size: int, | ||
| num_slices_per_block: int) -> torch.Tensor: | ||
| new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, | ||
| page_size, num_slices_per_block) | ||
| return new_kv_cache | ||
|
|
||
|
|
||
| @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") | ||
| def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, | ||
| kv_cache: torch.Tensor, page_size: int, | ||
| num_slices_per_block: int) -> torch.Tensor: | ||
| return kv_cache | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we want to do this
torch.ops.xla.dynamo_set_buffer_donor_?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because it should be an inplace-update.