Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/jit_kernel/benchmark/bench_qknorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def flashinfer_qknorm(
q_weight: torch.Tensor,
k_weight: torch.Tensor,
) -> None:
from flashinfer import rmsnorm

rmsnorm(q, q_weight, out=q)
rmsnorm(k, k_weight, out=k)
Expand Down
133 changes: 133 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_store_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import itertools
from typing import Tuple

import torch
import triton
import triton.testing
from sgl_kernel import set_kv_buffer_kernel

from sglang.jit_kernel.benchmark.utils import is_in_ci
from sglang.jit_kernel.kvcache import store_cache

# Evaluated once at import time; CI runs use a reduced benchmark grid below.
IS_CI = is_in_ci()


def sglang_aot_store_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """Store ``k``/``v`` into the caches at ``indices`` via the AOT sgl_kernel op.

    Note the op takes the caches first, then indices, then the sources.
    """
    set_kv_buffer_kernel(k_cache, v_cache, indices, k, v)


def sglang_jit_store_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """Store ``k``/``v`` into the caches at ``indices`` via the JIT store_cache kernel."""
    store_cache(k, v, k_cache, v_cache, indices)


@torch.compile()
def torch_compile_store_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """Scatter ``k``/``v`` into the caches with advanced indexing under torch.compile."""
    # Kept in exactly this two-statement form on purpose: this idiom is the
    # benchmark subject, and restructuring it would change the traced graph.
    k_cache[indices] = k
    v_cache[indices] = v


# Secondary CUDA stream used to overlap the V-cache write with the K-cache write.
alt_stream = torch.cuda.Stream()


def torch_streams_store_cache(
    k: torch.Tensor,
    v: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """Scatter K on the current stream and V on ``alt_stream`` concurrently."""
    main_stream = torch.cuda.current_stream()
    # Ensure work queued on the main stream (input production) is visible to
    # alt_stream before the V write is issued there.
    alt_stream.wait_stream(main_stream)
    k_cache[indices] = k
    with torch.cuda.stream(alt_stream):
        v_cache[indices] = v
    # Rejoin so later main-stream work observes the completed V write.
    main_stream.wait_stream(alt_stream)


DTYPE = torch.bfloat16
DEVICE = "cuda"
NUM_LAYERS = 8
# 2 Mi cache slots in total, split evenly across layers.
CACHE_SIZE = (1 << 21) // NUM_LAYERS

# CI runs a single tiny configuration; local runs sweep the full grid.
if IS_CI:
    BS_RANGE = [16]
    ITEM_SIZE = [1024]
else:
    BS_RANGE = [1 << n for n in range(15)]
    ITEM_SIZE = [1 << n for n in range(6, 11)]

LINE_VALS = ["aot", "jit", "torch_compile", "torch_streams"]
LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch Compile", "PyTorch 2 Stream"]
STYLES = [("orange", "-"), ("blue", "--"), ("red", ":"), ("green", "-.")]
X_NAMES = ["item_size", "batch_size"]
# One (item_size, batch_size) tuple per benchmarked configuration.
CONFIGS = [(item, bs) for item in ITEM_SIZE for bs in BS_RANGE]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=X_NAMES,
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="store-kvcache-performance",
        args={},
    )
)
def benchmark(
    batch_size: int, item_size: int, provider: str
) -> Tuple[float, float, float]:
    """Time one store-cache ``provider`` at (``batch_size``, ``item_size``).

    Returns (median, max, min) per-layer latency in microseconds, matching
    the (y, y_max, y_min) tuple ``triton.testing.perf_report`` expects.
    Raises ``KeyError`` for an unknown provider.
    """
    k = torch.randn((NUM_LAYERS, batch_size, item_size), dtype=DTYPE, device=DEVICE)
    v = torch.randn((NUM_LAYERS, batch_size, item_size), dtype=DTYPE, device=DEVICE)
    k_cache = torch.randn(
        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DTYPE, device=DEVICE
    )
    v_cache = torch.randn(
        (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DTYPE, device=DEVICE
    )
    # randperm guarantees unique destination slots per layer.
    indices = torch.randperm(CACHE_SIZE, device=DEVICE)[:batch_size]
    torch.cuda.synchronize()

    fn_map = {
        "aot": sglang_aot_store_cache,
        "jit": sglang_jit_store_cache,
        "torch_compile": torch_compile_store_cache,
        "torch_streams": torch_streams_store_cache,
    }
    # Resolve the implementation once, OUTSIDE the timed closure, so the dict
    # lookup overhead is not included in the measured per-call latency.
    impl = fn_map[provider]

    def fn() -> None:
        for layer in range(NUM_LAYERS):
            impl(k[layer], v[layer], k_cache[layer], v_cache[layer], indices)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)  # type: ignore
    # Convert ms -> us and amortize over the per-layer loop inside fn().
    return (
        1000 * ms / NUM_LAYERS,
        1000 * max_ms / NUM_LAYERS,
        1000 * min_ms / NUM_LAYERS,
    )


if __name__ == "__main__":
    # Run the full configuration sweep and print the timing table to stdout.
    benchmark.run(print_data=True)
8 changes: 8 additions & 0 deletions python/sglang/jit_kernel/benchmark/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import os


def is_in_ci() -> bool:
    """Return True when running in CI (``CI`` or ``GITHUB_ACTIONS`` env var is \"true\")."""
    flags = (os.getenv("CI", "false"), os.getenv("GITHUB_ACTIONS", "false"))
    return any(flag.lower() == "true" for flag in flags)
181 changes: 181 additions & 0 deletions python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#include <sgl_kernel/tensor.h>
#include <sgl_kernel/utils.cuh>
#include <sgl_kernel/utils.h>
#include <sgl_kernel/vec.cuh>
#include <sgl_kernel/warp.cuh>

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

#include <cstdint>

namespace {

// Kernel argument pack, passed by value as a __grid_constant__.
// NOTE: field order is load-bearing — store_kvcache unpacks this struct with a
// structured binding in declaration order; do not reorder fields.
struct StoreKVCacheParams {
  const void* __restrict__ k;        // source K rows, 2-D [batch, elem] (verified in run())
  const void* __restrict__ v;        // source V rows, 2-D [batch, elem]
  void* __restrict__ k_cache;        // destination K cache, 2-D [slots, elem]
  void* __restrict__ v_cache;        // destination V cache, 2-D [slots, elem]
  const void* __restrict__ indices;  // destination slot index per item (int32 or int64)
  int64_t stride_k_bytes;            // row stride of k, in bytes
  int64_t stride_v_bytes;            // row stride of v, in bytes
  int64_t stride_cache_bytes;        // row stride of both caches, in bytes
  int64_t stride_indices;            // stride of indices, in elements (not bytes)
  uint32_t batch_size;               // number of (k, v) items to store
};

// Launch shape: 4 warps per block; each warp copies one (item, split) slice.
constexpr uint32_t kNumWarps = 4;
constexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads;

// Warp-cooperative copy of one K slice and one V slice, kElementBytes bytes
// each, from src to dst using the widest vector access the size permits.
template <int64_t kElementBytes>
__device__ void copy_impl(
    const void* __restrict__ k_src,
    const void* __restrict__ v_src,
    void* __restrict__ k_dst,
    void* __restrict__ v_dst) {
  using namespace device;
  // Choose the per-thread vector width (16/8/4 bytes) so that a full-warp
  // iteration (width * kWarpThreads bytes) divides kElementBytes evenly;
  // otherwise fall back to 4-byte accesses, which leaves a partial tail
  // iteration handled by the epilogue below.
  constexpr int64_t kAlignment = (kElementBytes % (16 * kWarpThreads) == 0) ? 16
      : kElementBytes % (8 * kWarpThreads) == 0 ? 8
      : kElementBytes % (4 * kWarpThreads) == 0 ? 4
      : kElementBytes % 4 == 0 ? 4
      : 0;

  static_assert(kAlignment > 0, "Element size must be multiple of 4 bytes");

  using vec_t = aligned_vector<uint32_t, kAlignment / 4>;
  // Bytes moved by the whole warp per iteration, and the number of full
  // warp-wide iterations that fit in kElementBytes.
  constexpr auto kLoopBytes = sizeof(vec_t) * kWarpThreads;
  constexpr auto kLoopCount = kElementBytes / kLoopBytes;

#pragma unroll kLoopCount
  for (int64_t i = 0; i < kLoopCount; ++i) {
    const auto k = warp::load<vec_t>(pointer::offset(k_src, i * kLoopBytes));
    const auto v = warp::load<vec_t>(pointer::offset(v_src, i * kLoopBytes));
    warp::store(pointer::offset(k_dst, i * kLoopBytes), k);
    warp::store(pointer::offset(v_dst, i * kLoopBytes), v);
  }

  // Epilogue: copy the remaining (< kLoopBytes) bytes.  Only the low lanes
  // whose vec_t slot still falls inside the element participate.
  if constexpr (kLoopCount * kLoopBytes < kElementBytes) {
    constexpr auto kOffset = kLoopCount * kLoopBytes;
    if ((threadIdx.x % kWarpThreads) * sizeof(vec_t) < kElementBytes - kOffset) {
      const auto k = warp::load<vec_t>(pointer::offset(k_src, kOffset));
      const auto v = warp::load<vec_t>(pointer::offset(v_src, kOffset));
      warp::store(pointer::offset(k_dst, kOffset), k);
      warp::store(pointer::offset(v_dst, kOffset), v);
    }
  }
}

// Each warp handles one item
template <int64_t kElementBytes, int kSplit, bool kUsePDL, typename T>
__global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) {
using namespace device;
constexpr auto kSplitSize = kElementBytes / kSplit;
const uint32_t warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads;
const uint32_t item_id = warp_id / kSplit;
const uint32_t split_id = warp_id % kSplit;
const auto& [
k_input, v_input, k_cache, v_cache, indices, // ptr
stride_k, stride_v, stride_cache, stride_indices, batch_size // size
] = params;
if (item_id >= batch_size) return;

const auto index_ptr = static_cast<const T*>(indices) + item_id * stride_indices;
PDLWaitPrimary<kUsePDL>();

const auto index = *index_ptr;
const auto k_src = pointer::offset(k_input, item_id * stride_k, split_id * kSplitSize);
const auto v_src = pointer::offset(v_input, item_id * stride_v, split_id * kSplitSize);
const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize);
const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize);

copy_impl<kSplitSize>(k_src, v_src, k_dst, v_dst);
PDLTriggerSecondary<kUsePDL>();
}

// Host-side launcher for store_kvcache, specialized per element size (bytes
// per K/V row) and PDL usage.
template <int64_t kElementBytes, bool kUsePDL>
struct StoreKVCacheKernel {
  static_assert(kElementBytes > 0 && kElementBytes % 4 == 0);

  // Kernel instantiation for a given split factor and index dtype.
  template <int kSplit, typename T>
  static constexpr auto store_kernel = store_kvcache<kElementBytes, kSplit, kUsePDL, T>;

  // Select the kernel instantiation matching num_split.  Split factors > 1
  // are only compiled in when the element size is sufficiently aligned, so
  // an unsupported combination panics instead of silently degrading.
  template <typename T>
  static auto get_kernel(const int num_split) {
    using namespace host;
    // only apply split optimization when element size is aligned
    if constexpr (kElementBytes % (4 * 128) == 0) {
      if (num_split == 4) return store_kernel<4, T>;
    }
    if constexpr (kElementBytes % (2 * 128) == 0) {
      if (num_split == 2) return store_kernel<2, T>;
    }
    if (num_split == 1) return store_kernel<1, T>;
    Panic("Unsupported num_split {} for element size {}", num_split, kElementBytes);
    // NOTE(review): assumes Panic is [[noreturn]]; otherwise this path falls
    // off the end without a return — confirm against sgl_kernel/utils.h.
  }

  // Validate tensor shapes/strides/dtypes/devices, build the parameter pack,
  // and launch the kernel on the tensors' CUDA device.
  static void
  run(const tvm::ffi::TensorView k,
      const tvm::ffi::TensorView v,
      const tvm::ffi::TensorView k_cache,
      const tvm::ffi::TensorView v_cache,
      const tvm::ffi::TensorView indices,
      const int num_split) {
    using namespace host;
    // Symbolic sizes are unified across all matchers below, enforcing that
    // k/v/caches share the last-dim size D and that indices has length B.
    auto B = SymbolicSize{"batch_size"};
    auto D = SymbolicSize{"element_size"};
    auto KS = SymbolicSize{"k_stride"};
    auto VS = SymbolicSize{"v_stride"};
    auto S = SymbolicSize{"cache_stride"};
    auto I = SymbolicSize{"indices_stride"};
    auto dtype = SymbolicDType{};
    auto device = SymbolicDevice{};
    device.set_options<kDLCUDA>();

    TensorMatcher({B, D}) //
        .with_strides({KS, 1})
        .with_dtype(dtype)
        .with_device(device)
        .verify(k);
    TensorMatcher({B, D}) //
        .with_strides({VS, 1})
        .with_dtype(dtype)
        .with_device(device)
        .verify(v);
    // Both caches must share the same row stride S and dtype; their slot
    // count (-1) is unconstrained.
    TensorMatcher({-1, D}) //
        .with_strides({S, 1})
        .with_dtype(dtype)
        .with_device(device)
        .verify(k_cache)
        .verify(v_cache);
    TensorMatcher({B}) //
        .with_strides({I})
        .with_dtype<int32_t, int64_t>()
        .with_device(device)
        .verify(indices);

    const int64_t dtype_size = dtype_bytes(dtype.unwrap());
    const uint32_t num_elements = static_cast<uint32_t>(B.unwrap());
    // The compile-time element size must equal the runtime row size in bytes.
    RuntimeCheck(kElementBytes == dtype_size * D.unwrap());

    const auto params = StoreKVCacheParams{
        .k = k.data_ptr(),
        .v = v.data_ptr(),
        .k_cache = k_cache.data_ptr(),
        .v_cache = v_cache.data_ptr(),
        .indices = indices.data_ptr(),
        .stride_k_bytes = KS.unwrap() * dtype_size,
        .stride_v_bytes = VS.unwrap() * dtype_size,
        .stride_cache_bytes = S.unwrap() * dtype_size,
        .stride_indices = I.unwrap(), // element stride; kernel indexes a typed pointer
        .batch_size = static_cast<uint32_t>(B.unwrap()),
    };
    // Select the kernel specialization for the index dtype and split factor.
    const auto kernel = dtype.is_type<int32_t>() ? get_kernel<int32_t>(num_split) : get_kernel<int64_t>(num_split);
    // One warp per (item, split) pair, kNumWarps warps per block.
    const auto num_blocks = div_ceil(num_elements * num_split, kNumWarps);
    LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap()) //
        .enable_pdl(kUsePDL)(kernel, params);
  }
};

} // namespace
Loading
Loading