Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_store_kv_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import itertools
from typing import Tuple

import torch
import triton
import triton.testing
from sgl_kernel import set_kv_buffer_kernel

from sglang.jit_kernel.benchmark.utils import (
DEFAULT_DEVICE,
DEFAULT_DTYPE,
DEFAULT_QUANTILES,
get_benchmark_range,
)
from sglang.jit_kernel.store import store_kv_cache


def sglang_aot_store_kv_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """Store k/v into the caches via the AOT-compiled sgl_kernel op.

    Adapts the benchmark's (k, v, caches, indices) argument order to the
    kernel's (caches, indices, k, v) order.
    """
    set_kv_buffer_kernel(k_cache, v_cache, indices, k, v)


def sglang_jit_store_kv_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """Store k/v into the caches via the JIT-compiled store_kv_cache kernel.

    Adapts the benchmark's (k, v, caches, indices) argument order to the
    kernel's (caches, indices, k, v) order.
    """
    store_kv_cache(k_cache, v_cache, indices, k, v)


# Number of layers stored per timed iteration; latencies are normalized back
# to a single layer when reported.
NUM_LAYERS = 8
# Per-layer cache slot count; total slot budget (2M) split across the layers.
CACHE_SIZE = 2 * 1024 * 1024 // NUM_LAYERS

# Batch sizes: powers of two up to 2**14 in full runs, a single value in CI.
BS_RANGE = get_benchmark_range(
    full_range=[2**n for n in range(0, 15)],
    ci_range=[16],
)
# Flattened per-token sizes (last dim of k/v) to sweep.
ITEM_SIZE = get_benchmark_range(
    full_range=[64, 128, 256, 512, 1024],
    ci_range=[1024],
)

LINE_VALS = ["aot", "jit"]
LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel"]
STYLES = [("orange", "-"), ("blue", "--")]
X_NAMES = ["item_size", "batch_size"]
# Cartesian product in the same order as X_NAMES: (item_size, batch_size).
CONFIGS = list(itertools.product(ITEM_SIZE, BS_RANGE))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=X_NAMES,
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="store-kv-cache-performance",
        args={},
    )
)
def benchmark(
    batch_size: int, item_size: int, provider: str
) -> Tuple[float, float, float]:
    """Time store-kv-cache at one (batch_size, item_size) point.

    Allocates NUM_LAYERS layers of k/v inputs and caches, scatters
    ``batch_size`` tokens into distinct random cache slots per layer under a
    CUDA graph, and reports per-layer latencies.

    Args:
        batch_size: Number of tokens stored per layer.
        item_size: Flattened per-token size (last dim of k/v).
        provider: Implementation to benchmark, "aot" or "jit".

    Returns:
        Per-layer latencies in microseconds, in the tuple order expected by
        triton's perf_report (quantiles come from DEFAULT_QUANTILES).
    """
    k = torch.randn(
        (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
    )
    v = torch.randn(
        (NUM_LAYERS, batch_size, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
    )
    k_cache = torch.randn(
        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
    )
    v_cache = torch.randn(
        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE
    )
    # randperm guarantees distinct destination slots (no write conflicts).
    indices = torch.randperm(CACHE_SIZE, device=DEFAULT_DEVICE)[:batch_size]
    torch.cuda.synchronize()

    fn_map = {
        "aot": sglang_aot_store_kv_cache,
        "jit": sglang_jit_store_kv_cache,
    }
    # Resolve the provider once, outside the timed closure, so the dict
    # lookup is not part of the captured/timed work.
    impl = fn_map[provider]

    def fn() -> None:
        for layer in range(NUM_LAYERS):
            impl(k[layer], v[layer], k_cache[layer], v_cache[layer], indices)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=DEFAULT_QUANTILES
    )
    # Convert ms -> us and normalize to a single layer.
    return (
        1000 * ms / NUM_LAYERS,
        1000 * max_ms / NUM_LAYERS,
        1000 * min_ms / NUM_LAYERS,
    )


if __name__ == "__main__":
    # Run the full sweep and print the results table to stdout.
    benchmark.run(print_data=True)
173 changes: 173 additions & 0 deletions python/sglang/jit_kernel/csrc/memory/store.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Adapted from https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/memory/store.cu
#pragma once

#include <sgl_kernel/tensor.h>

#include <sgl_kernel/utils.cuh>

#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>

#include <cstddef>
#include <cstdint>

namespace {

using std::size_t;
using std::uint64_t;

// Each warp will process 256 bytes per loop iteration
// One warp copies one (k, v) row pair into the caches at the row index given
// by out_loc[warp_id].  Per iteration all 32 lanes each move one 8-byte word
// from k and one from v (32 * 8 = 256 bytes per tensor).  All pointers and
// strides are expressed in uint64 words.
template <typename T>
__global__ void store_kv_cache_256x1(
    uint64_t* __restrict__ k_cache,
    uint64_t* __restrict__ v_cache,
    const T* __restrict__ out_loc,   // destination cache row per input row
    const size_t length,             // number of rows (tokens) to copy
    const uint64_t* __restrict__ k,
    const uint64_t* __restrict__ v,
    const size_t kv_cache_stride,    // cache row stride, in uint64 words
    const size_t kv_input_stride,    // input row stride, in uint64 words
    const size_t num_items) {        // iterations per row (row bytes / 256)
  const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
  const auto warp_id = idx / 32;  // one warp per token row
  const auto lane_id = idx % 32;
  if (warp_id >= length) return;  // trailing warps past the last row idle
  const auto offset = out_loc[warp_id];
  const auto k_dst = k_cache + offset * kv_cache_stride;
  const auto v_dst = v_cache + offset * kv_cache_stride;
  const auto k_src = k + warp_id * kv_input_stride;
  const auto v_src = v + warp_id * kv_input_stride;
  for (size_t i = 0; i < num_items; ++i) {
    // Consecutive lanes touch consecutive words: fully coalesced accesses.
    k_dst[lane_id + i * 32] = k_src[lane_id + i * 32];
    v_dst[lane_id + i * 32] = v_src[lane_id + i * 32];
  }
}

// Each warp will process 128 bytes per loop iteration
// Variant for rows whose byte size divides by 128 but not 256: the warp is
// split in half — lanes 0-15 copy k while lanes 16-31 copy v — so each tensor
// advances 16 * 8 = 128 bytes per iteration.  Pointers/strides are in uint64
// words, as in store_kv_cache_256x1.
template <typename T>
__global__ void store_kv_cache_128x2(
    uint64_t* __restrict__ k_cache,
    uint64_t* __restrict__ v_cache,
    const T* __restrict__ out_loc,   // destination cache row per input row
    const size_t length,             // number of rows (tokens) to copy
    const uint64_t* __restrict__ k,
    const uint64_t* __restrict__ v,
    const size_t kv_cache_stride,    // cache row stride, in uint64 words
    const size_t kv_input_stride,    // input row stride, in uint64 words
    const size_t num_items) {        // iterations per row (row bytes / 128)
  const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
  const auto warp_id = idx / 32;  // one warp per token row
  const auto lane_id = idx % 32;
  if (warp_id >= length) return;
  const auto offset = out_loc[warp_id];
  const auto copy_k = lane_id < 16;   // lower half-warp handles k, upper v
  const auto copy_id = lane_id % 16;  // word index within the half-warp
  const auto cache = copy_k ? k_cache : v_cache;
  const auto input = copy_k ? k : v;
  const auto dst = cache + offset * kv_cache_stride;
  const auto src = input + warp_id * kv_input_stride;
  for (size_t i = 0; i < num_items; ++i) {
    dst[copy_id + i * 16] = src[copy_id + i * 16];
  }
}

// Launches the widest kernel the row size allows: 256-byte chunks per warp
// iteration when size_bytes divides by 256, otherwise 128-byte chunks.
// Panics if the row byte size is not a multiple of 128.
template <typename T>
void dispatch_store_kv_cache(
    uint64_t* k_cache_ptr,
    uint64_t* v_cache_ptr,
    const T* out_loc_ptr,
    const size_t length,
    const uint64_t* k_ptr,
    const uint64_t* v_ptr,
    const size_t kv_cache_stride,   // in uint64 words
    const size_t kv_input_stride,   // in uint64 words
    const int64_t size_bytes,       // bytes per k/v row
    const int num_blocks,
    const int num_threads,
    cudaStream_t stream) {
  if (size_bytes % 256 == 0) {
    const size_t items_per_warp = static_cast<size_t>(size_bytes / 256);
    store_kv_cache_256x1<<<num_blocks, num_threads, 0, stream>>>(
        k_cache_ptr, v_cache_ptr, out_loc_ptr, length, k_ptr, v_ptr, kv_cache_stride, kv_input_stride, items_per_warp);
  } else if (size_bytes % 128 == 0) {
    const size_t items_per_warp = static_cast<size_t>(size_bytes / 128);
    store_kv_cache_128x2<<<num_blocks, num_threads, 0, stream>>>(
        k_cache_ptr, v_cache_ptr, out_loc_ptr, length, k_ptr, v_ptr, kv_cache_stride, kv_input_stride, items_per_warp);
  } else {
    host::Panic("Last dim size bytes of k/v must be divisible by 128, got: {}", size_bytes);
  }
}

// Expects 2D inputs: k_cache/v_cache shape (max_tokens, head_dim),
// k/v shape (num_tokens, head_dim), out_loc shape (num_tokens,).
// Validates layouts, converts strides to uint64 words, and dispatches to the
// warp-per-row copy kernels on the device's current stream.
void store_kv_cache(
    tvm::ffi::TensorView k_cache,
    tvm::ffi::TensorView v_cache,
    tvm::ffi::TensorView out_loc,
    tvm::ffi::TensorView k,
    tvm::ffi::TensorView v) {
  using namespace host;

  RuntimeCheck(k_cache.dim() == 2, "k_cache must be 2D");
  RuntimeCheck(v_cache.dim() == 2, "v_cache must be 2D");
  RuntimeCheck(k.dim() == 2, "k must be 2D");
  RuntimeCheck(v.dim() == 2, "v must be 2D");
  RuntimeCheck(out_loc.dim() == 1 && out_loc.is_contiguous(), "out_loc must be 1D contiguous");
  RuntimeCheck(k_cache.size(1) == v_cache.size(1), "k_cache and v_cache must have the same head dim");
  RuntimeCheck(k.size(1) == v.size(1), "k and v must have the same head dim");
  RuntimeCheck(k.size(1) == k_cache.size(1), "k and k_cache must have the same head dim");
  RuntimeCheck(k.stride(1) == 1 && k_cache.stride(1) == 1, "k and k_cache must be contiguous in head dim");
  // The kernels reuse k's strides for v and k_cache's strides for v_cache,
  // so the v-side tensors must match the k-side layout exactly.
  RuntimeCheck(v.stride(1) == 1 && v_cache.stride(1) == 1, "v and v_cache must be contiguous in head dim");
  RuntimeCheck(k.stride(0) == v.stride(0), "k and v must have the same row stride");
  RuntimeCheck(k_cache.stride(0) == v_cache.stride(0), "k_cache and v_cache must have the same row stride");
  RuntimeCheck(k.size(0) == v.size(0), "k and v must have the same number of tokens");
  RuntimeCheck(k_cache.size(0) == v_cache.size(0), "k_cache and v_cache must have the same number of slots");
  // Each out_loc entry selects the source row of the same index in k/v.
  RuntimeCheck(out_loc.size(0) <= k.size(0), "out_loc cannot index more rows than k/v provide");
  static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes");

  const size_t length = static_cast<size_t>(out_loc.size(0));
  const int64_t elem_size = k.dtype().bits / 8;      // bytes per element
  const int64_t size_bytes = elem_size * k.size(1);  // bytes per row
  // Row strides are converted from elements to uint64 words; they must be
  // word-aligned for the uint64-based kernels (size_bytes % 128 is checked in
  // dispatch, but a padded row stride could still break alignment).
  RuntimeCheck((elem_size * k_cache.stride(0)) % 8 == 0, "k_cache row stride must be 8-byte aligned");
  RuntimeCheck((elem_size * k.stride(0)) % 8 == 0, "k row stride must be 8-byte aligned");
  const size_t kv_cache_stride = static_cast<size_t>(elem_size * k_cache.stride(0) / 8);
  const size_t kv_input_stride = static_cast<size_t>(elem_size * k.stride(0) / 8);

  const auto k_cache_ptr = static_cast<uint64_t*>(k_cache.data_ptr());
  const auto v_cache_ptr = static_cast<uint64_t*>(v_cache.data_ptr());
  const auto k_ptr = static_cast<const uint64_t*>(k.data_ptr());
  const auto v_ptr = static_cast<const uint64_t*>(v.data_ptr());

  // One warp per row: 256 threads/block = 8 warps/block.
  constexpr int num_threads = 256;
  constexpr int num_warps = num_threads / 32;
  const int num_blocks = static_cast<int>((length + num_warps - 1) / num_warps);

  const auto device = k_cache.device();
  const auto stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));

  if (host::is_type<int32_t>(out_loc.dtype())) {
    dispatch_store_kv_cache<int32_t>(
        k_cache_ptr,
        v_cache_ptr,
        static_cast<const int32_t*>(out_loc.data_ptr()),
        length,
        k_ptr,
        v_ptr,
        kv_cache_stride,
        kv_input_stride,
        size_bytes,
        num_blocks,
        num_threads,
        stream);
  } else if (host::is_type<int64_t>(out_loc.dtype())) {
    dispatch_store_kv_cache<int64_t>(
        k_cache_ptr,
        v_cache_ptr,
        static_cast<const int64_t*>(out_loc.data_ptr()),
        length,
        k_ptr,
        v_ptr,
        kv_cache_stride,
        kv_input_stride,
        size_bytes,
        num_blocks,
        num_threads,
        stream);
  } else {
    RuntimeCheck(false, "out_loc must be int32 or int64");
  }
}

} // namespace
47 changes: 47 additions & 0 deletions python/sglang/jit_kernel/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit

if TYPE_CHECKING:
from tvm_ffi.module import Module


@cache_once
def _jit_store_module() -> Module:
    # Compile (once per process, via cache_once) the CUDA store kernels and
    # expose the `store_kv_cache` entry point defined in memory/store.cuh.
    return load_jit(
        "store",
        cuda_files=["memory/store.cuh"],
        cuda_wrappers=[("store_kv_cache", "store_kv_cache")],
    )


def store_kv_cache(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out_loc: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
) -> None:
    """Scatter key/value tensors into the KV cache at the given slot indices.

    Args:
        k_cache: Key cache tensor; leading dim is the number of cache slots.
        v_cache: Value cache tensor; leading dim is the number of cache slots.
        out_loc: Destination slot indices, shape (num_tokens,), int32/int64.
        k: Key tensor to store; leading dim is num_tokens.
        v: Value tensor to store; leading dim is num_tokens.
    """
    num_slots = k_cache.size(0)
    num_tokens = out_loc.size(0)
    # The kernel works on 2D tensors: flatten any trailing dims into one.
    _jit_store_module().store_kv_cache(
        k_cache.view(num_slots, -1),
        v_cache.view(num_slots, -1),
        out_loc,
        k.view(num_tokens, -1),
        v.view(num_tokens, -1),
    )
45 changes: 45 additions & 0 deletions python/sglang/jit_kernel/tests/test_store_kv_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
Correctness tests for the store JIT kernel.

Tests the JIT-compiled store_kv_cache against direct tensor indexing.
"""

import itertools

import pytest
import torch

from sglang.jit_kernel.store import store_kv_cache

# Number of slots in each cache; destination indices are drawn from this range.
CACHE_SIZE = 1024
# Value dtypes, index dtypes, and shapes swept by the parametrized test.
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
INDEX_DTYPES = [torch.int32, torch.int64]
BATCH_SIZES = [1, 4, 16, 64, 128]
HEAD_DIMS = [64, 128, 256, 512]
DEVICE = "cuda"


@pytest.mark.parametrize(
    "batch_size,head_dim,dtype,index_dtype",
    list(itertools.product(BATCH_SIZES, HEAD_DIMS, DTYPES, INDEX_DTYPES)),
)
def test_store_kv_cache(
    batch_size: int,
    head_dim: int,
    dtype: torch.dtype,
    index_dtype: torch.dtype,
) -> None:
    """JIT store_kv_cache must match plain advanced-index assignment."""
    shape = (batch_size, head_dim)
    k = torch.randn(shape, dtype=dtype, device=DEVICE)
    v = torch.randn(shape, dtype=dtype, device=DEVICE)
    k_cache = torch.zeros((CACHE_SIZE, head_dim), dtype=dtype, device=DEVICE)
    v_cache = torch.zeros_like(k_cache)
    # Distinct destination slots, so the expected cache state is unambiguous.
    indices = torch.randperm(CACHE_SIZE, device=DEVICE)[:batch_size].to(index_dtype)

    store_kv_cache(k_cache, v_cache, indices, k, v)

    # Gathering back at the same indices must reproduce the inputs exactly.
    assert torch.equal(k_cache[indices], k), "k mismatch"
    assert torch.equal(v_cache[indices], v), "v mismatch"


if __name__ == "__main__":
    # Allow running this file directly without an external pytest invocation.
    pytest.main([__file__, "-v"])
Loading
Loading