diff --git a/python/sglang/jit_kernel/benchmark/bench_store_kv_cache.py b/python/sglang/jit_kernel/benchmark/bench_store_kv_cache.py
new file mode 100644
index 000000000000..09108f6175a7
--- /dev/null
+++ b/python/sglang/jit_kernel/benchmark/bench_store_kv_cache.py
@@ -0,0 +1,126 @@
+"""Benchmark the AOT and JIT store_kv_cache kernels under CUDA graphs."""
+
+import itertools
+from typing import Tuple
+
+# Imported only for its side effect: presumably this is what registers
+# torch.ops.sgl_kernel.* -- TODO confirm.
+import sgl_kernel  # noqa: F401
+import torch
+import triton
+import triton.testing
+
+from sglang.jit_kernel.benchmark.utils import (
+    DEFAULT_DEVICE,
+    DEFAULT_DTYPE,
+    DEFAULT_QUANTILES,
+    get_benchmark_range,
+)
+from sglang.jit_kernel.store import store_kv_cache
+
+
+def sglang_aot_store_kv_cache(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    indices: torch.Tensor,
+) -> None:
+    """Store ``k``/``v`` into the caches with the ahead-of-time compiled op."""
+    # Call the registered op directly: sgl_kernel.set_kv_buffer_kernel prefers
+    # the JIT kernel whenever sglang.jit_kernel.store imports, which would make
+    # the "aot" line measure the same kernel as the "jit" line.
+    torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, indices, k, v)
+
+
+def sglang_jit_store_kv_cache(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    indices: torch.Tensor,
+) -> None:
+    """Store ``k``/``v`` into the caches with the JIT-compiled kernel."""
+    store_kv_cache(k_cache, v_cache, indices, k, v)
+
+
+# Each timed call issues one store per layer; results are reported per layer.
+NUM_LAYERS = 8
+CACHE_SIZE = 2 * 1024 * 1024 // NUM_LAYERS
+
+BS_RANGE = get_benchmark_range(
+    full_range=[2**n for n in range(0, 15)],
+    ci_range=[16],
+)
+ITEM_SIZE = get_benchmark_range(
+    full_range=[64, 128, 256, 512, 1024],
+    ci_range=[1024],
+)
+
+LINE_VALS = ["aot", "jit"]
+LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel"]
+STYLES = [("orange", "-"), ("blue", "--")]
+X_NAMES = ["item_size", "batch_size"]
+CONFIGS = list(itertools.product(ITEM_SIZE, BS_RANGE))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=X_NAMES,
+        x_vals=CONFIGS,
+        line_arg="provider",
+        line_vals=LINE_VALS,
+        line_names=LINE_NAMES,
+        styles=STYLES,
+        ylabel="us",
+        plot_name="store-kv-cache-performance",
+        args={},
+    )
+)
+def benchmark(
+    batch_size: int, item_size: int, provider: str
+) -> Tuple[float, float, float]:
+    """Time one provider for one (item_size, batch_size) configuration.
+
+    Returns:
+        Per-layer latency in microseconds as ``(value, min, max)``, the order
+        in which ``triton.testing.perf_report`` unpacks a 3-tuple result.
+    """
+    k = torch.randn(
+        (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
+    )
+    v = torch.randn(
+        (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
+    )
+    k_cache = torch.randn(
+        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
+    )
+    v_cache = torch.randn(
+        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
+    )
+    indices = torch.randperm(CACHE_SIZE, device=DEFAULT_DEVICE)[:batch_size]
+    torch.cuda.synchronize()
+
+    FN_MAP = {
+        "aot": sglang_aot_store_kv_cache,
+        "jit": sglang_jit_store_kv_cache,
+    }
+
+    def fn():
+        impl = FN_MAP[provider]
+        for i in range(NUM_LAYERS):
+            impl(k[i], v[i], k_cache[i], v_cache[i], indices)
+
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+        fn, quantiles=DEFAULT_QUANTILES
+    )
+    # Scale ms to us and average over the NUM_LAYERS stores issued per call.
+    return (
+        1000 * ms / NUM_LAYERS,
+        1000 * min_ms / NUM_LAYERS,
+        1000 * max_ms / NUM_LAYERS,
+    )
+
+
+if __name__ == "__main__":
+    benchmark.run(print_data=True)
diff --git a/python/sglang/jit_kernel/csrc/memory/store.cuh b/python/sglang/jit_kernel/csrc/memory/store.cuh
new file mode 100644
index 000000000000..992cfaeabe71
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/memory/store.cuh
@@ -0,0 +1,207 @@
+// Adapted from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/memory/store.cu
+#pragma once
+
+// NOTE(review): the header names of the six #include directives below are
+// missing from this patch text -- restore them from the original commit.
+#include
+
+#include
+
+#include
+#include
+
+#include
+#include
+
+namespace {
+
+using std::size_t;
+using std::uint64_t;
+
+// Copies one k row and one v row per warp as 8-byte words.
+// Each warp will process 256 bytes of each row per loop iteration: all 32
+// lanes copy one word of k and then one word of v; num_items is the iteration
+// count. out_loc[w] is the cache row that receives input row w; both strides
+// are row strides counted in 8-byte words.
+template <typename T>
+__global__ void store_kv_cache_256x1(
+    uint64_t* __restrict__ k_cache,
+    uint64_t* __restrict__ v_cache,
+    const T* __restrict__ out_loc,
+    const size_t length,
+    const uint64_t* __restrict__ k,
+    const uint64_t* __restrict__ v,
+    const size_t kv_cache_stride,
+    const size_t kv_input_stride,
+    const size_t num_items) {
+  const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto warp_id = idx / 32;
+  const auto lane_id = idx % 32;
+  if (warp_id >= length) return;
+  const auto offset = out_loc[warp_id];
+  const auto k_dst = k_cache + offset * kv_cache_stride;
+  const auto v_dst = v_cache + offset * kv_cache_stride;
+  const auto k_src = k + warp_id * kv_input_stride;
+  const auto v_src = v + warp_id * kv_input_stride;
+  for (size_t i = 0; i < num_items; ++i) {
+    k_dst[lane_id + i * 32] = k_src[lane_id + i * 32];
+    v_dst[lane_id + i * 32] = v_src[lane_id + i * 32];
+  }
+}
+
+// Copies one k row and one v row per warp as 8-byte words.
+// Each warp will process 128 bytes of each row per loop iteration: lanes 0-15
+// copy k while lanes 16-31 copy v; num_items is the iteration count.
+// out_loc[w] is the cache row that receives input row w; both strides are
+// row strides counted in 8-byte words.
+template <typename T>
+__global__ void store_kv_cache_128x2(
+    uint64_t* __restrict__ k_cache,
+    uint64_t* __restrict__ v_cache,
+    const T* __restrict__ out_loc,
+    const size_t length,
+    const uint64_t* __restrict__ k,
+    const uint64_t* __restrict__ v,
+    const size_t kv_cache_stride,
+    const size_t kv_input_stride,
+    const size_t num_items) {
+  const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const auto warp_id = idx / 32;
+  const auto lane_id = idx % 32;
+  if (warp_id >= length) return;
+  const auto offset = out_loc[warp_id];
+  const auto copy_k = lane_id < 16;
+  const auto copy_id = lane_id % 16;
+  const auto cache = copy_k ? k_cache : v_cache;
+  const auto input = copy_k ? k : v;
+  const auto dst = cache + offset * kv_cache_stride;
+  const auto src = input + warp_id * kv_input_stride;
+  for (size_t i = 0; i < num_items; ++i) {
+    dst[copy_id + i * 16] = src[copy_id + i * 16];
+  }
+}
+
+// Launches the kernel variant that matches the row size in bytes: multiples of
+// 256 use store_kv_cache_256x1, other multiples of 128 use
+// store_kv_cache_128x2, and any other size panics.
+template <typename T>
+void dispatch_store_kv_cache(
+    uint64_t* k_cache_ptr,
+    uint64_t* v_cache_ptr,
+    const T* out_loc_ptr,
+    const size_t length,
+    const uint64_t* k_ptr,
+    const uint64_t* v_ptr,
+    const size_t kv_cache_stride,
+    const size_t kv_input_stride,
+    const int64_t size_bytes,
+    const int num_blocks,
+    const int num_threads,
+    cudaStream_t stream) {
+  if (size_bytes % 256 == 0) {
+    const size_t items_per_warp = static_cast<size_t>(size_bytes / 256);
+    store_kv_cache_256x1<<<num_blocks, num_threads, 0, stream>>>(
+        k_cache_ptr, v_cache_ptr, out_loc_ptr, length, k_ptr, v_ptr, kv_cache_stride, kv_input_stride, items_per_warp);
+  } else if (size_bytes % 128 == 0) {
+    const size_t items_per_warp = static_cast<size_t>(size_bytes / 128);
+    store_kv_cache_128x2<<<num_blocks, num_threads, 0, stream>>>(
+        k_cache_ptr, v_cache_ptr, out_loc_ptr, length, k_ptr, v_ptr, kv_cache_stride, kv_input_stride, items_per_warp);
+  } else {
+    host::Panic("Last dim size bytes of k/v must be divisible by 128, got: {}", size_bytes);
+  }
+}
+
+// Copies the rows of k/v into k_cache/v_cache at the row indices in out_loc.
+// Expects 2D inputs: k_cache/v_cache shape (max_tokens, head_dim),
+// k/v shape (num_tokens, head_dim), out_loc shape (num_tokens,).
+// Rows are moved as raw 8-byte words and must be a multiple of 128 bytes long.
+void store_kv_cache(
+    tvm::ffi::TensorView k_cache,
+    tvm::ffi::TensorView v_cache,
+    tvm::ffi::TensorView out_loc,
+    tvm::ffi::TensorView k,
+    tvm::ffi::TensorView v) {
+  using namespace host;
+
+  RuntimeCheck(k_cache.dim() == 2, "k_cache must be 2D");
+  RuntimeCheck(v_cache.dim() == 2, "v_cache must be 2D");
+  RuntimeCheck(k.dim() == 2, "k must be 2D");
+  RuntimeCheck(v.dim() == 2, "v must be 2D");
+  RuntimeCheck(out_loc.dim() == 1 && out_loc.is_contiguous(), "out_loc must be 1D contiguous");
+  RuntimeCheck(k_cache.size(1) == v_cache.size(1), "k_cache and v_cache must have the same head dim");
+  RuntimeCheck(k.size(1) == v.size(1), "k and v must have the same head dim");
+  RuntimeCheck(k.size(1) == k_cache.size(1), "k and k_cache must have the same head dim");
+  RuntimeCheck(k.stride(1) == 1 && k_cache.stride(1) == 1, "k and k_cache must be contiguous in head dim");
+  // One input row is read per out_loc entry.
+  RuntimeCheck(k.size(0) == out_loc.size(0), "k and out_loc must have the same number of rows");
+  RuntimeCheck(v.size(0) == out_loc.size(0), "v and out_loc must have the same number of rows");
+  // v and v_cache are addressed with the strides read from k and k_cache, so
+  // they must share that layout. A row stride is unused when there is <= 1 row.
+  RuntimeCheck(v.stride(1) == 1 && v_cache.stride(1) == 1, "v and v_cache must be contiguous in head dim");
+  RuntimeCheck(k.size(0) <= 1 || k.stride(0) == v.stride(0), "k and v must have the same row stride");
+  RuntimeCheck(k_cache.size(0) == v_cache.size(0), "k_cache and v_cache must have the same number of rows");
+  RuntimeCheck(
+      k_cache.size(0) <= 1 || k_cache.stride(0) == v_cache.stride(0),
+      "k_cache and v_cache must have the same row stride");
+  // size_bytes and both word strides below are derived from k's element size.
+  const auto elem_bits = k.dtype().bits;
+  RuntimeCheck(elem_bits % 8 == 0, "k element size must be a whole number of bytes");
+  RuntimeCheck(
+      v.dtype().bits == elem_bits && k_cache.dtype().bits == elem_bits && v_cache.dtype().bits == elem_bits,
+      "k, v, k_cache and v_cache must have the same element size");
+  static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes");
+
+  const size_t length = static_cast<size_t>(out_loc.size(0));
+  // Nothing to copy; a launch with zero blocks is an invalid configuration.
+  if (length == 0) return;
+  const int64_t elem_size = k.dtype().bits / 8;
+  const int64_t size_bytes = elem_size * k.size(1);
+  const size_t kv_cache_stride = static_cast<size_t>(elem_size * k_cache.stride(0) / 8);
+  const size_t kv_input_stride = static_cast<size_t>(elem_size * k.stride(0) / 8);
+
+  const auto k_cache_ptr = static_cast<uint64_t*>(k_cache.data_ptr());
+  const auto v_cache_ptr = static_cast<uint64_t*>(v_cache.data_ptr());
+  const auto k_ptr = static_cast<const uint64_t*>(k.data_ptr());
+  const auto v_ptr = static_cast<const uint64_t*>(v.data_ptr());
+
+  constexpr int num_threads = 256;
+  constexpr int num_warps = num_threads / 32;
+  const int num_blocks = static_cast<int>((length + num_warps - 1) / num_warps);
+
+  const auto device = k_cache.device();
+  const auto stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
+
+  if (host::is_type<int32_t>(out_loc.dtype())) {
+    dispatch_store_kv_cache(
+        k_cache_ptr,
+        v_cache_ptr,
+        static_cast<const int32_t*>(out_loc.data_ptr()),
+        length,
+        k_ptr,
+        v_ptr,
+        kv_cache_stride,
+        kv_input_stride,
+        size_bytes,
+        num_blocks,
+        num_threads,
+        stream);
+  } else if (host::is_type<int64_t>(out_loc.dtype())) {
+    dispatch_store_kv_cache(
+        k_cache_ptr,
+        v_cache_ptr,
+        static_cast<const int64_t*>(out_loc.data_ptr()),
+        length,
+        k_ptr,
+        v_ptr,
+        kv_cache_stride,
+        kv_input_stride,
+        size_bytes,
+        num_blocks,
+        num_threads,
+        stream);
+  } else {
+    RuntimeCheck(false, "out_loc must be int32 or int64");
+  }
+}
+
+}  // namespace
diff --git a/python/sglang/jit_kernel/store.py b/python/sglang/jit_kernel/store.py
new file mode 100644
index 000000000000..62c9ad99c8e8
--- /dev/null
+++ b/python/sglang/jit_kernel/store.py
@@ -0,0 +1,55 @@
+"""Python entry point for the JIT-compiled store_kv_cache CUDA kernel."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.jit_kernel.utils import cache_once, load_jit
+
+if TYPE_CHECKING:
+    from tvm_ffi.module import Module
+
+
+@cache_once
+def _jit_store_module() -> Module:
+    """Load the module built from ``memory/store.cuh`` via ``load_jit``."""
+    return load_jit(
+        "store",
+        cuda_files=["memory/store.cuh"],
+        cuda_wrappers=[("store_kv_cache", "store_kv_cache")],
+    )
+
+
+def store_kv_cache(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    out_loc: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+) -> None:
+    """Store key and value tensors into KV cache at specified indices.
+
+    The cache and input tensors are viewed as 2D (first dim, everything else)
+    before they are handed to the kernel. An empty ``out_loc`` is a no-op.
+
+    Args:
+        k_cache: Key cache tensor, first dim is max_tokens.
+        v_cache: Value cache tensor, first dim is max_tokens.
+        out_loc: Token indices, shape (num_tokens,), dtype int32 or int64.
+        k: Key tensor, first dim is num_tokens.
+        v: Value tensor, first dim is num_tokens.
+    """
+    max_tokens = k_cache.size(0)
+    num_tokens = out_loc.size(0)
+    if num_tokens == 0:
+        # Nothing to store, and view(0, -1) below cannot infer the -1 dimension.
+        return
+    module = _jit_store_module()
+    module.store_kv_cache(
+        k_cache.view(max_tokens, -1),
+        v_cache.view(max_tokens, -1),
+        out_loc,
+        k.view(num_tokens, -1),
+        v.view(num_tokens, -1),
+    )
diff --git a/python/sglang/jit_kernel/tests/test_store_kv_cache.py b/python/sglang/jit_kernel/tests/test_store_kv_cache.py
new file mode 100644
index 000000000000..5149ebeac0e3
--- /dev/null
+++ b/python/sglang/jit_kernel/tests/test_store_kv_cache.py
@@ -0,0 +1,67 @@
+"""
+Correctness tests for the store JIT kernel.
+
+Tests the JIT-compiled store_kv_cache against direct tensor indexing.
+"""
+
+import itertools
+
+import pytest
+import torch
+
+from sglang.jit_kernel.store import store_kv_cache
+
+CACHE_SIZE = 1024
+DTYPES = [torch.float16, torch.bfloat16, torch.float32]
+INDEX_DTYPES = [torch.int32, torch.int64]
+BATCH_SIZES = [1, 4, 16, 64, 128]
+HEAD_DIMS = [64, 128, 256, 512]
+DEVICE = "cuda"
+
+
+@pytest.mark.parametrize(
+    "batch_size,head_dim,dtype,index_dtype",
+    list(itertools.product(BATCH_SIZES, HEAD_DIMS, DTYPES, INDEX_DTYPES)),
+)
+def test_store_kv_cache(
+    batch_size: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    index_dtype: torch.dtype,
+) -> None:
+    """The caches must equal a torch scatter, including the untouched rows."""
+    k = torch.randn((batch_size, head_dim), dtype=dtype, device=DEVICE)
+    v = torch.randn((batch_size, head_dim), dtype=dtype, device=DEVICE)
+    k_cache = torch.zeros((CACHE_SIZE, head_dim), dtype=dtype, device=DEVICE)
+    v_cache = torch.zeros((CACHE_SIZE, head_dim), dtype=dtype, device=DEVICE)
+    indices = torch.randperm(CACHE_SIZE, device=DEVICE)[:batch_size].to(index_dtype)
+
+    # Reference result built with plain indexing on separate zero caches.
+    k_expected = torch.zeros_like(k_cache)
+    v_expected = torch.zeros_like(v_cache)
+    k_expected[indices.long()] = k
+    v_expected[indices.long()] = v
+
+    store_kv_cache(k_cache, v_cache, indices, k, v)
+
+    assert torch.equal(k_cache, k_expected), "k mismatch"
+    assert torch.equal(v_cache, v_expected), "v mismatch"
+
+
+@pytest.mark.parametrize("index_dtype", INDEX_DTYPES)
+def test_store_kv_cache_empty(index_dtype: torch.dtype) -> None:
+    """An empty batch is a no-op that leaves both caches unchanged."""
+    k = torch.empty((0, 128), dtype=torch.float16, device=DEVICE)
+    v = torch.empty((0, 128), dtype=torch.float16, device=DEVICE)
+    k_cache = torch.zeros((CACHE_SIZE, 128), dtype=torch.float16, device=DEVICE)
+    v_cache = torch.zeros((CACHE_SIZE, 128), dtype=torch.float16, device=DEVICE)
+    indices = torch.empty((0,), dtype=index_dtype, device=DEVICE)
+
+    store_kv_cache(k_cache, v_cache, indices, k, v)
+
+    assert not k_cache.any(), "k_cache changed"
+    assert not v_cache.any(), "v_cache changed"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/sgl-kernel/python/sgl_kernel/memory.py b/sgl-kernel/python/sgl_kernel/memory.py
index 9ff4f957d704..ee0c85c5792b 100644
--- a/sgl-kernel/python/sgl_kernel/memory.py
+++ b/sgl-kernel/python/sgl_kernel/memory.py
@@ -1,5 +1,11 @@
 import torch
 
+_jit_store = None
+try:
+    import sglang.jit_kernel.store as _jit_store
+except Exception:
+    pass
+
 
 def set_kv_buffer_kernel(
     k_cache: torch.Tensor,
@@ -12,7 +18,10 @@ def set_kv_buffer_kernel(
     try:
         if fallback:
             raise RuntimeError("Fallback to torch implementation")
-        torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
+        if _jit_store is not None:
+            _jit_store.store_kv_cache(k_cache, v_cache, loc, k, v)
+        else:
+            torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
     except RuntimeError:  # ok, fallback to torch implementation
         k_cache[loc] = k
         v_cache[loc] = v