diff --git a/vllm/v1/attention/ops/pallas_kv_cache_update.py b/vllm/v1/attention/ops/pallas_kv_cache_update.py
deleted file mode 100644
index 51214b02271a..000000000000
--- a/vllm/v1/attention/ops/pallas_kv_cache_update.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import functools
-
-import jax
-from jax.experimental import pallas as pl
-from jax.experimental.pallas import tpu as pltpu
-
-from vllm.utils.math_utils import cdiv
-
-
-def _kv_cache_update_kernel(
-    # Prefetch
-    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
-    # new_kv_start, slice_len)
-    num_slices_ref,  # [1]
-    # Input
-    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
-    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
-    # head_dim]
-    # Output
-    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
-    # Scratch
-    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
-    # head_dim]
-    sem,
-):
-    async_copies = []
-    block_idx = pl.program_id(0)
-    num_slices_per_block = scratch.shape[0]
-
-    # Copy from new_kv_hbm_ref to scratch
-    for i in range(num_slices_per_block):
-        offset_i = i + block_idx * num_slices_per_block
-        new_kv_start = jax.lax.select(
-            offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0
-        )
-        length = jax.lax.select(
-            offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
-        )
-        async_copy = pltpu.make_async_copy(
-            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
-            scratch.at[i, pl.ds(0, length), ...],
-            sem,
-        )
-        async_copy.start()
-        async_copies.append(async_copy)
-
-    for async_copy in async_copies:
-        async_copy.wait()
-
-    # Copy from scratch to kv_cache_hbm_ref
-    async_copies.clear()
-    for i in range(num_slices_per_block):
-        offset_i = i + block_idx * num_slices_per_block
-        kv_cache_start = jax.lax.select(
-            offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0
-        )
-        length = jax.lax.select(
-            offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
-        )
-        async_copy = pltpu.make_async_copy(
-            scratch.at[i, pl.ds(0, length), ...],
-            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
-            sem,
-        )
-        async_copy.start()
-        async_copies.append(async_copy)
-    for async_copy in async_copies:
-        async_copy.wait()
-
-
-@functools.partial(
-    jax.jit,
-    static_argnames=["page_size", "num_slices_per_block"],
-)
-def kv_cache_update(
-    # [total_num_token, num_combined_kv_heads, head_dim]
-    new_kv: jax.Array,
-    # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
-    slices: jax.Array,
-    # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
-    kv_cache: jax.Array,
-    # [1]
-    num_kv_update_slices: jax.Array,
-    *,
-    page_size: int = 32,
-    num_slices_per_block: int = 8,
-):
-    _, num_combined_kv_heads, head_dim = new_kv.shape
-    assert kv_cache.shape[1] == num_combined_kv_heads
-    assert kv_cache.shape[2] == head_dim
-    assert head_dim % 128 == 0
-    # TODO: Add dynamic check to make sure that the all the slice lengths are
-    # smaller or equal to page_size
-
-    in_specs = [
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
-    ]
-
-    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
-    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
-
-    scalar_prefetches = [slices, num_kv_update_slices]
-    scratch = pltpu.VMEM(
-        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
-        new_kv.dtype,
-    )
-
-    scratch_shapes = [
-        scratch,
-        pltpu.SemaphoreType.DMA,
-    ]
-
-    kernel = pl.pallas_call(
-        _kv_cache_update_kernel,
-        grid_spec=pltpu.PrefetchScalarGridSpec(
-            num_scalar_prefetch=len(scalar_prefetches),
-            in_specs=in_specs,
-            out_specs=out_specs,
-            grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),),
-            scratch_shapes=scratch_shapes,
-        ),
-        out_shape=out_shape,
-        input_output_aliases={len(scalar_prefetches) + 1: 0},
-    )
-
-    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
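
For reviewers tracking this removal: below is a minimal, hypothetical sketch of how the deleted `kv_cache_update` helper was invoked. The shapes, dtypes, and slice values are illustrative assumptions, not taken from this diff; the call signature and the `(kv_cache_start, new_kv_start, slice_len)` row layout of `slices` come from the deleted code above. Note that the kernel aliases `kv_cache` into its output (`input_output_aliases`), so the returned array reuses the cache buffer in place.

```python
import jax.numpy as jnp

# Deleted by this PR; shown only to illustrate the old call site.
from vllm.v1.attention.ops.pallas_kv_cache_update import kv_cache_update

# Illustrative shapes (assumptions, not from the diff).
num_tokens, num_combined_kv_heads, head_dim = 16, 8, 128
total_num_pages, page_size = 4, 32

new_kv = jnp.ones((num_tokens, num_combined_kv_heads, head_dim), jnp.bfloat16)
kv_cache = jnp.zeros(
    (total_num_pages * page_size, num_combined_kv_heads, head_dim), jnp.bfloat16
)

# One update slice: write new_kv[0:16] into kv_cache rows [32:48).
# Rows of `slices` are (kv_cache_start, new_kv_start, slice_len); the second
# dimension is padded out to a multiple of num_slices_per_block.
slices = jnp.zeros((3, 8), dtype=jnp.int32)
slices = slices.at[:, 0].set(jnp.array([32, 0, 16], dtype=jnp.int32))
num_kv_update_slices = jnp.array([1], dtype=jnp.int32)

updated = kv_cache_update(
    new_kv,
    slices,
    kv_cache,
    num_kv_update_slices,
    page_size=page_size,
    num_slices_per_block=8,
)
assert updated.shape == kv_cache.shape
```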