From 33e24a0def10308c9a07b6de3e974f8ca3cff17a Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Wed, 6 Aug 2025 19:04:47 -0700 Subject: [PATCH 1/3] feat: support custom set kv buffer kernel --- python/sglang/srt/mem_cache/memory_pool.py | 21 ++- sgl-kernel/CMakeLists.txt | 1 + sgl-kernel/csrc/common_extension.cc | 6 + sgl-kernel/csrc/memory/store.cu | 166 +++++++++++++++++++++ sgl-kernel/include/sgl_kernel_ops.h | 5 + sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/memory.py | 26 ++++ 7 files changed, 215 insertions(+), 11 deletions(-) create mode 100644 sgl-kernel/csrc/memory/store.cu create mode 100644 sgl-kernel/python/sgl_kernel/memory.py diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index cc3faea0a03d..a8128421b9de 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -33,6 +33,7 @@ import torch import triton import triton.language as tl +from sgl_kernel import set_kv_buffer_kernel from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.layers.radix_attention import RadixAttention @@ -394,17 +395,15 @@ def set_kv_buffer( cache_k = cache_k.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype) - if get_is_capture_mode() and self.alt_stream is not None: - # Overlap the copy of K and V cache for small batch size - current_stream = self.device_module.current_stream() - self.alt_stream.wait_stream(current_stream) - self.k_buffer[layer_id - self.start_layer][loc] = cache_k - with self.device_module.stream(self.alt_stream): - self.v_buffer[layer_id - self.start_layer][loc] = cache_v - current_stream.wait_stream(self.alt_stream) - else: - self.k_buffer[layer_id - self.start_layer][loc] = cache_k - self.v_buffer[layer_id - self.start_layer][loc] = cache_v + # use optimized kernel when 1. cuda 2. in capture mode + set_kv_buffer_kernel( + self.k_buffer[layer_id - self.start_layer], + self.v_buffer[layer_id - self.start_layer], + loc, + cache_k, + cache_v, + fallback=(not _is_cuda or not get_is_capture_mode()), + ) def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): copy_all_layer_kv_cache[(len(self.data_ptrs),)]( diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 4fa98e436f3d..28509b60ad55 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -280,6 +280,7 @@ set(SOURCES "csrc/speculative/packbit.cu" "csrc/spatial/greenctx_stream.cu" "csrc/speculative/speculative_sampling.cu" + "csrc/memory/store.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 86ef29f243ae..05473e436c6f 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -413,6 +413,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { */ m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]"); m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value); + + /* + * From csrc/memory + */ + m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()"); + m.impl("store_kv_cache", &store_kv_cache); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/memory/store.cu b/sgl-kernel/csrc/memory/store.cu new file mode 100644 index 000000000000..3ba90fa6f3d4 --- /dev/null +++ b/sgl-kernel/csrc/memory/store.cu @@ -0,0 +1,166 @@ +#include +#include +#include + +#include +#include +#include + +namespace { + +using std::size_t; +using std::uint64_t; + +// Each warp will process 256 bytes per loop iteration +template +__global__ void store_kv_cache_256x1( + uint64_t* __restrict__ k_cache, + uint64_t* __restrict__ v_cache, + const T* __restrict__ out_loc, + const size_t length, + const uint64_t* __restrict__ k, + const uint64_t* __restrict__ v, + const size_t kv_cache_stride, + const size_t kv_input_stride, + const size_t num_items) { + static_assert(std::is_same_v || std::is_same_v, "out_loc must be int32 or int64 type"); + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto warp_id = idx / 32; + const auto lane_id = idx % 32; + if (warp_id >= length) return; + const auto offset = out_loc[warp_id]; + const auto k_dst = k_cache + offset * kv_cache_stride; + const auto v_dst = v_cache + offset * kv_cache_stride; + const auto k_src = k + warp_id * kv_input_stride; + const auto v_src = v + warp_id * kv_input_stride; + for (size_t i = 0; i < num_items; ++i) { + k_dst[lane_id + i * 32] = k_src[lane_id + i * 32]; + v_dst[lane_id + i * 32] = v_src[lane_id + i * 32]; + } +} + +// Each warp will process 128 bytes per loop iteration +template +__global__ void store_kv_cache_128x2( + uint64_t* __restrict__ k_cache, + uint64_t* __restrict__ v_cache, + const T* __restrict__ out_loc, + const size_t length, + const uint64_t* __restrict__ k, + const uint64_t* __restrict__ v, + const size_t kv_cache_stride, + const size_t kv_input_stride, + const size_t num_items) { + static_assert(std::is_same_v || std::is_same_v, "out_loc must be int32 or int64 type"); + + const auto idx = blockIdx.x * blockDim.x + threadIdx.x; + const auto warp_id = idx / 32; + const auto lane_id = idx % 32; + if (warp_id >= length) return; + const auto offset = out_loc[warp_id]; + const auto copy_k = lane_id < 16; + const auto copy_id = lane_id % 16; + const auto cache = copy_k ? k_cache : v_cache; + const auto input = copy_k ? k : v; + const auto dst = cache + offset * kv_cache_stride; + const auto src = input + warp_id * kv_input_stride; + for (size_t i = 0; i < num_items; ++i) { + dst[copy_id + i * 16] = src[copy_id + i * 16]; + } +} + +} // namespace + +auto store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v) -> void { + TORCH_CHECK( + k_cache.is_cuda() && v_cache.is_cuda() && out_loc.is_cuda() && k.is_cuda() && v.is_cuda(), + "All tensors must be CUDA tensors"); + TORCH_CHECK(k_cache.sizes() == v_cache.sizes(), "k_cache and v_cache must have the same size"); + TORCH_CHECK(k_cache.strides() == v_cache.strides(), "k_cache and v_cache must have the same strides"); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must have the same size"); + TORCH_CHECK(k.strides() == v.strides(), "k and v must have the same strides"); + TORCH_CHECK(k.dim() == 2 && k_cache.dim() == 2, "k and k_cache must be 2D tensors"); + TORCH_CHECK(k.stride(-1) == 1 && k_cache.stride(-1) == 1, "k and k_cache must be contiguous in head."); + TORCH_CHECK(k.size(-1) == k_cache.size(-1), "k and k_cache must have the same head size"); + TORCH_CHECK(out_loc.dim() == 1 && out_loc.is_contiguous(), "out_loc must be a 1D contiguous tensor"); + static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes, our code assumes that"); + + const auto length = out_loc.size(0); + const auto elem_size = k.element_size(); + const auto size_bytes = elem_size * k.size(-1); + const auto kv_cache_stride_bytes = elem_size * k_cache.stride(-2); + const auto kv_input_stride_bytes = elem_size * k.stride(-2); + const auto kv_cache_stride = kv_cache_stride_bytes / 8; + const auto kv_input_stride = kv_input_stride_bytes / 8; + + const auto k_cache_ptr = static_cast(k_cache.data_ptr()); + const auto v_cache_ptr = static_cast(v_cache.data_ptr()); + const auto k_ptr = static_cast(k.data_ptr()); + const auto v_ptr = static_cast(v.data_ptr()); + const auto num_threads = 256; + const auto num_warps = num_threads / 32; + const auto num_blocks = (length + num_warps - 1) / num_warps; + const auto stream = at::cuda::getCurrentCUDAStream(); + + if (size_bytes % 256 == 0) { + const auto items_per_warp = size_bytes / 256; + if (out_loc.dtype() == at::kInt) { + store_kv_cache_256x1<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else if (out_loc.dtype() == at::kLong) { + store_kv_cache_256x1<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else { + TORCH_CHECK(false, "out_loc must be a 1D tensor of int32 or int64 type"); + } + } else if (size_bytes % 128 == 0) { + const auto items_per_warp = size_bytes / 128; + if (out_loc.dtype() == at::kInt) { + store_kv_cache_128x2<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else if (out_loc.dtype() == at::kLong) { + store_kv_cache_128x2<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else { + TORCH_CHECK(false, "out_loc must be a 1D tensor of int32 or int64 type"); + } + } else { + TORCH_CHECK( + false, + "The last dimension size bytes of k and v must be" + " divisible by 128 at least, got: ", + size_bytes); + } +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index c007251cdc1f..df7c24095178 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -699,3 +699,8 @@ void qserve_w4a8_per_group_gemm( * From csrc/spatial */ std::vector create_greenctx_stream_by_value(int64_t smA, int64_t smB, int64_t device); + +/* + * From csrc/memory + */ +void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index faeff924076d..2abf0bcf3262 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -67,6 +67,7 @@ awq_marlin_repack, gptq_marlin_repack, ) +from sgl_kernel.memory import set_kv_buffer_kernel from sgl_kernel.moe import ( apply_shuffle_mul_sum, cutlass_fp4_group_mm, diff --git a/sgl-kernel/python/sgl_kernel/memory.py b/sgl-kernel/python/sgl_kernel/memory.py new file mode 100644 index 000000000000..45041e38a177 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/memory.py @@ -0,0 +1,26 @@ +import torch + + +def set_kv_buffer_kernel( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + loc: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + fallback: bool = False, +): + try: + if fallback: + raise RuntimeError("Fallback to torch implementation") + max_tokens = k_cache.shape[0] + num_tokens = loc.shape[0] + k_cache_2d = k_cache.view(max_tokens, -1) + v_cache_2d = v_cache.view(max_tokens, -1) + k_2d = k.view(num_tokens, -1) + v_2d = v.view(num_tokens, -1) + torch.ops.sgl_kernel.store_kv_cache( # type: ignore + k_cache_2d, v_cache_2d, loc, k_2d, v_2d + ) + except RuntimeError: # ok, fallback to torch implementation + k_cache[loc] = k + v_cache[loc] = v From 6b5b5c8691ca6a6d53ae503596bfdf136fd58e92 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Fri, 8 Aug 2025 15:43:59 -0700 Subject: [PATCH 2/3] minor: remove python part and refine C++ wrapper --- python/sglang/srt/mem_cache/memory_pool.py | 21 ++--- sgl-kernel/csrc/memory/store.cu | 99 ++++++++-------------- 2 files changed, 48 insertions(+), 72 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index a8128421b9de..cc3faea0a03d 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -33,7 +33,6 @@ import torch import triton import triton.language as tl -from sgl_kernel import set_kv_buffer_kernel from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.layers.radix_attention import RadixAttention @@ -395,15 +394,17 @@ def set_kv_buffer( cache_k = cache_k.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype) - # use optimized kernel when 1. cuda 2. in capture mode - set_kv_buffer_kernel( - self.k_buffer[layer_id - self.start_layer], - self.v_buffer[layer_id - self.start_layer], - loc, - cache_k, - cache_v, - fallback=(not _is_cuda or not get_is_capture_mode()), - ) + if get_is_capture_mode() and self.alt_stream is not None: + # Overlap the copy of K and V cache for small batch size + current_stream = self.device_module.current_stream() + self.alt_stream.wait_stream(current_stream) + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + with self.device_module.stream(self.alt_stream): + self.v_buffer[layer_id - self.start_layer][loc] = cache_v + current_stream.wait_stream(self.alt_stream) + else: + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + self.v_buffer[layer_id - self.start_layer][loc] = cache_v def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): copy_all_layer_kv_cache[(len(self.data_ptrs),)]( diff --git a/sgl-kernel/csrc/memory/store.cu b/sgl-kernel/csrc/memory/store.cu index 3ba90fa6f3d4..475b5baca4bf 100644 --- a/sgl-kernel/csrc/memory/store.cu +++ b/sgl-kernel/csrc/memory/store.cu @@ -1,3 +1,4 @@ +#include #include #include #include @@ -23,7 +24,6 @@ __global__ void store_kv_cache_256x1( const size_t kv_cache_stride, const size_t kv_input_stride, const size_t num_items) { - static_assert(std::is_same_v || std::is_same_v, "out_loc must be int32 or int64 type"); const auto idx = blockIdx.x * blockDim.x + threadIdx.x; const auto warp_id = idx / 32; const auto lane_id = idx % 32; @@ -51,8 +51,6 @@ __global__ void store_kv_cache_128x2( const size_t kv_cache_stride, const size_t kv_input_stride, const size_t num_items) { - static_assert(std::is_same_v || std::is_same_v, "out_loc must be int32 or int64 type"); - const auto idx = blockIdx.x * blockDim.x + threadIdx.x; const auto warp_id = idx / 32; const auto lane_id = idx % 32; @@ -102,65 +100,42 @@ auto store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, const auto num_blocks = (length + num_warps - 1) / num_warps; const auto stream = at::cuda::getCurrentCUDAStream(); - if (size_bytes % 256 == 0) { - const auto items_per_warp = size_bytes / 256; - if (out_loc.dtype() == at::kInt) { - store_kv_cache_256x1<<>>( - k_cache_ptr, - v_cache_ptr, - out_loc.data_ptr(), - length, - k_ptr, - v_ptr, - kv_cache_stride, - kv_input_stride, - items_per_warp); - } else if (out_loc.dtype() == at::kLong) { - store_kv_cache_256x1<<>>( - k_cache_ptr, - v_cache_ptr, - out_loc.data_ptr(), - length, - k_ptr, - v_ptr, - kv_cache_stride, - kv_input_stride, - items_per_warp); - } else { - TORCH_CHECK(false, "out_loc must be a 1D tensor of int32 or int64 type"); - } - } else if (size_bytes % 128 == 0) { - const auto items_per_warp = size_bytes / 128; - if (out_loc.dtype() == at::kInt) { - store_kv_cache_128x2<<>>( - k_cache_ptr, - v_cache_ptr, - out_loc.data_ptr(), - length, - k_ptr, - v_ptr, - kv_cache_stride, - kv_input_stride, - items_per_warp); - } else if (out_loc.dtype() == at::kLong) { - store_kv_cache_128x2<<>>( - k_cache_ptr, - v_cache_ptr, - out_loc.data_ptr(), - length, - k_ptr, - v_ptr, - kv_cache_stride, - kv_input_stride, - items_per_warp); + AT_DISPATCH_INTEGRAL_TYPES(out_loc.scalar_type(), "store_kv_cache", [&] { + if constexpr (!std::is_same_v && !std::is_same_v) { + // do not instantiate the kernel if out_loc is not int32 or int64 + TORCH_CHECK(false, "out_loc must be of type int32 or int64, got: ", out_loc.scalar_type()); } else { - TORCH_CHECK(false, "out_loc must be a 1D tensor of int32 or int64 type"); + if (size_bytes % 256 == 0) { + const auto items_per_warp = size_bytes / 256; + store_kv_cache_256x1<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else if (size_bytes % 128 == 0) { + const auto items_per_warp = size_bytes / 128; + store_kv_cache_128x2<<>>( + k_cache_ptr, + v_cache_ptr, + out_loc.data_ptr(), + length, + k_ptr, + v_ptr, + kv_cache_stride, + kv_input_stride, + items_per_warp); + } else { + TORCH_CHECK( + false, + "The last dimension size bytes of k and v must be" + " divisible by 128 at least, got: ", + size_bytes); + } } - } else { - TORCH_CHECK( - false, - "The last dimension size bytes of k and v must be" - " divisible by 128 at least, got: ", - size_bytes); - } + }); } From cdc1153abbb500dd2694a231f1290e2260defe2a Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Sat, 9 Aug 2025 23:41:28 -0700 Subject: [PATCH 3/3] chore: move reshape logic to C++ --- sgl-kernel/csrc/memory/store.cu | 8 +++++++- sgl-kernel/python/sgl_kernel/memory.py | 10 +--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sgl-kernel/csrc/memory/store.cu b/sgl-kernel/csrc/memory/store.cu index 475b5baca4bf..c6dd97ebd710 100644 --- a/sgl-kernel/csrc/memory/store.cu +++ b/sgl-kernel/csrc/memory/store.cu @@ -70,6 +70,13 @@ __global__ void store_kv_cache_128x2( } // namespace auto store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v) -> void { + const auto max_tokens = k_cache.size(0); + const auto num_tokens = out_loc.size(0); + k_cache = k_cache.view({max_tokens, -1}); + v_cache = v_cache.view({max_tokens, -1}); + k = k.view({num_tokens, -1}); + v = v.view({num_tokens, -1}); + TORCH_CHECK( k_cache.is_cuda() && v_cache.is_cuda() && out_loc.is_cuda() && k.is_cuda() && v.is_cuda(), "All tensors must be CUDA tensors"); @@ -77,7 +84,6 @@ auto store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, TORCH_CHECK(k_cache.strides() == v_cache.strides(), "k_cache and v_cache must have the same strides"); TORCH_CHECK(k.sizes() == v.sizes(), "k and v must have the same size"); TORCH_CHECK(k.strides() == v.strides(), "k and v must have the same strides"); - TORCH_CHECK(k.dim() == 2 && k_cache.dim() == 2, "k and k_cache must be 2D tensors"); TORCH_CHECK(k.stride(-1) == 1 && k_cache.stride(-1) == 1, "k and k_cache must be contiguous in head."); TORCH_CHECK(k.size(-1) == k_cache.size(-1), "k and k_cache must have the same head size"); TORCH_CHECK(out_loc.dim() == 1 && out_loc.is_contiguous(), "out_loc must be a 1D contiguous tensor"); diff --git a/sgl-kernel/python/sgl_kernel/memory.py b/sgl-kernel/python/sgl_kernel/memory.py index 45041e38a177..eb997db0ccae 100644 --- a/sgl-kernel/python/sgl_kernel/memory.py +++ b/sgl-kernel/python/sgl_kernel/memory.py @@ -12,15 +12,7 @@ def set_kv_buffer_kernel( try: if fallback: raise RuntimeError("Fallback to torch implementation") - max_tokens = k_cache.shape[0] - num_tokens = loc.shape[0] - k_cache_2d = k_cache.view(max_tokens, -1) - v_cache_2d = v_cache.view(max_tokens, -1) - k_2d = k.view(num_tokens, -1) - v_2d = v.view(num_tokens, -1) - torch.ops.sgl_kernel.store_kv_cache( # type: ignore - k_cache_2d, v_cache_2d, loc, k_2d, v_2d - ) + torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v) except RuntimeError: # ok, fallback to torch implementation k_cache[loc] = k v_cache[loc] = v