diff --git a/python/sglang/jit_kernel/benchmark/bench_qknorm.py b/python/sglang/jit_kernel/benchmark/bench_qknorm.py index a136131388b9..20ecd3b6eca1 100644 --- a/python/sglang/jit_kernel/benchmark/bench_qknorm.py +++ b/python/sglang/jit_kernel/benchmark/bench_qknorm.py @@ -53,6 +53,7 @@ def flashinfer_qknorm( q_weight: torch.Tensor, k_weight: torch.Tensor, ) -> None: + from flashinfer import rmsnorm rmsnorm(q, q_weight, out=q) rmsnorm(k, k_weight, out=k) diff --git a/python/sglang/jit_kernel/benchmark/bench_store_cache.py b/python/sglang/jit_kernel/benchmark/bench_store_cache.py new file mode 100644 index 000000000000..d7d014f2e8c1 --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_store_cache.py @@ -0,0 +1,133 @@ +import itertools +from typing import Tuple + +import torch +import triton +import triton.testing +from sgl_kernel import set_kv_buffer_kernel + +from sglang.jit_kernel.benchmark.utils import is_in_ci +from sglang.jit_kernel.kvcache import store_cache + +IS_CI = is_in_ci() + + +def sglang_aot_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + set_kv_buffer_kernel(k_cache, v_cache, indices, k, v) + + +def sglang_jit_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + store_cache(k, v, k_cache, v_cache, indices) + + +@torch.compile() +def torch_compile_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + k_cache[indices] = k + v_cache[indices] = v + + +alt_stream = torch.cuda.Stream() + + +def torch_streams_store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +) -> None: + current_stream = torch.cuda.current_stream() + alt_stream.wait_stream(current_stream) + k_cache[indices] = k + with torch.cuda.stream(alt_stream): + v_cache[indices] = v + current_stream.wait_stream(alt_stream) + + +DTYPE = torch.bfloat16 +DEVICE = "cuda" +NUM_LAYERS = 8 +CACHE_SIZE = 2 * 1024 * 1024 // NUM_LAYERS + +if IS_CI: + BS_RANGE = [16] + ITEM_SIZE = [1024] +else: + BS_RANGE = [2**n for n in range(0, 15)] + ITEM_SIZE = [64, 128, 256, 512, 1024] + +LINE_VALS = ["aot", "jit", "torch_compile", "torch_streams"] +LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "PyTorch Compile", "PyTorch 2 Stream"] +STYLES = [("orange", "-"), ("blue", "--"), ("red", ":"), ("green", "-.")] +X_NAMES = ["item_size", "batch_size"] +CONFIGS = list(itertools.product(ITEM_SIZE, BS_RANGE)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=X_NAMES, + x_vals=CONFIGS, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="store-kvcache-performance", + args={}, + ) +) +def benchmark( + batch_size: int, item_size: int, provider: str +) -> Tuple[float, float, float]: + k = torch.randn((NUM_LAYERS, batch_size, item_size), dtype=DTYPE, device=DEVICE) + v = torch.randn((NUM_LAYERS, batch_size, item_size), dtype=DTYPE, device=DEVICE) + k_cache = torch.randn( + (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DTYPE, device=DEVICE + ) + v_cache = torch.randn( + (NUM_LAYERS, CACHE_SIZE, item_size), dtype=DTYPE, device=DEVICE + ) + indices = torch.randperm(CACHE_SIZE, device=DEVICE)[:batch_size] + torch.cuda.synchronize() + + FN_MAP = { + "aot": sglang_aot_store_cache, + "jit": sglang_jit_store_cache, + "torch_compile": torch_compile_store_cache, + "torch_streams": torch_streams_store_cache, + } + + def fn(): + impl = FN_MAP[provider] + for i in range(NUM_LAYERS): + impl(k[i], v[i], k_cache[i], v_cache[i], indices) + + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) # type: ignore + return ( + 1000 * ms / NUM_LAYERS, + 1000 * max_ms / NUM_LAYERS, + 1000 * min_ms / NUM_LAYERS, + ) + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/benchmark/utils.py b/python/sglang/jit_kernel/benchmark/utils.py new file mode 100644 index 000000000000..5055c700fe6b --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/utils.py @@ -0,0 +1,8 @@ +import os + + +def is_in_ci(): + return ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" + ) diff --git a/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh b/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh new file mode 100644 index 000000000000..5a1fa2b6df27 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh @@ -0,0 +1,181 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace { + +struct StoreKVCacheParams { + const void* __restrict__ k; + const void* __restrict__ v; + void* __restrict__ k_cache; + void* __restrict__ v_cache; + const void* __restrict__ indices; + int64_t stride_k_bytes; + int64_t stride_v_bytes; + int64_t stride_cache_bytes; + int64_t stride_indices; + uint32_t batch_size; +}; + +constexpr uint32_t kNumWarps = 4; +constexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads; + +template +__device__ void copy_impl( + const void* __restrict__ k_src, + const void* __restrict__ v_src, + void* __restrict__ k_dst, + void* __restrict__ v_dst) { + using namespace device; + constexpr int64_t kAlignment = (kElementBytes % (16 * kWarpThreads) == 0) ? 16 + : kElementBytes % (8 * kWarpThreads) == 0 ? 8 + : kElementBytes % (4 * kWarpThreads) == 0 ? 4 + : kElementBytes % 4 == 0 ? 4 + : 0; + + static_assert(kAlignment > 0, "Element size must be multiple of 4 bytes"); + + using vec_t = aligned_vector; + constexpr auto kLoopBytes = sizeof(vec_t) * kWarpThreads; + constexpr auto kLoopCount = kElementBytes / kLoopBytes; + +#pragma unroll kLoopCount + for (int64_t i = 0; i < kLoopCount; ++i) { + const auto k = warp::load(pointer::offset(k_src, i * kLoopBytes)); + const auto v = warp::load(pointer::offset(v_src, i * kLoopBytes)); + warp::store(pointer::offset(k_dst, i * kLoopBytes), k); + warp::store(pointer::offset(v_dst, i * kLoopBytes), v); + } + + // handle the epilogue if any + if constexpr (kLoopCount * kLoopBytes < kElementBytes) { + constexpr auto kOffset = kLoopCount * kLoopBytes; + if ((threadIdx.x % kWarpThreads) * sizeof(vec_t) < kElementBytes - kOffset) { + const auto k = warp::load(pointer::offset(k_src, kOffset)); + const auto v = warp::load(pointer::offset(v_src, kOffset)); + warp::store(pointer::offset(k_dst, kOffset), k); + warp::store(pointer::offset(v_dst, kOffset), v); + } + } +} + +// Each warp handles one item +template +__global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) { + using namespace device; + constexpr auto kSplitSize = kElementBytes / kSplit; + const uint32_t warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads; + const uint32_t item_id = warp_id / kSplit; + const uint32_t split_id = warp_id % kSplit; + const auto& [ + k_input, v_input, k_cache, v_cache, indices, // ptr + stride_k, stride_v, stride_cache, stride_indices, batch_size // size + ] = params; + if (item_id >= batch_size) return; + + const auto index_ptr = static_cast(indices) + item_id * stride_indices; + PDLWaitPrimary(); + + const auto index = *index_ptr; + const auto k_src = pointer::offset(k_input, item_id * stride_k, split_id * kSplitSize); + const auto v_src = pointer::offset(v_input, item_id * stride_v, split_id * kSplitSize); + const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize); + const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize); + + copy_impl(k_src, v_src, k_dst, v_dst); + PDLTriggerSecondary(); +} + +template +struct StoreKVCacheKernel { + static_assert(kElementBytes > 0 && kElementBytes % 4 == 0); + + template + static constexpr auto store_kernel = store_kvcache; + + template + static auto get_kernel(const int num_split) { + using namespace host; + // only apply split optimization when element size is aligned + if constexpr (kElementBytes % (4 * 128) == 0) { + if (num_split == 4) return store_kernel<4, T>; + } + if constexpr (kElementBytes % (2 * 128) == 0) { + if (num_split == 2) return store_kernel<2, T>; + } + if (num_split == 1) return store_kernel<1, T>; + Panic("Unsupported num_split {} for element size {}", num_split, kElementBytes); + } + + static void + run(const tvm::ffi::TensorView k, + const tvm::ffi::TensorView v, + const tvm::ffi::TensorView k_cache, + const tvm::ffi::TensorView v_cache, + const tvm::ffi::TensorView indices, + const int num_split) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto D = SymbolicSize{"element_size"}; + auto KS = SymbolicSize{"k_stride"}; + auto VS = SymbolicSize{"v_stride"}; + auto S = SymbolicSize{"cache_stride"}; + auto I = SymbolicSize{"indices_stride"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, D}) // + .with_strides({KS, 1}) + .with_dtype(dtype) + .with_device(device) + .verify(k); + TensorMatcher({B, D}) // + .with_strides({VS, 1}) + .with_dtype(dtype) + .with_device(device) + .verify(v); + TensorMatcher({-1, D}) // + .with_strides({S, 1}) + .with_dtype(dtype) + .with_device(device) + .verify(k_cache) + .verify(v_cache); + TensorMatcher({B}) // + .with_strides({I}) + .with_dtype() + .with_device(device) + .verify(indices); + + const int64_t dtype_size = dtype_bytes(dtype.unwrap()); + const uint32_t num_elements = static_cast(B.unwrap()); + RuntimeCheck(kElementBytes == dtype_size * D.unwrap()); + + const auto params = StoreKVCacheParams{ + .k = k.data_ptr(), + .v = v.data_ptr(), + .k_cache = k_cache.data_ptr(), + .v_cache = v_cache.data_ptr(), + .indices = indices.data_ptr(), + .stride_k_bytes = KS.unwrap() * dtype_size, + .stride_v_bytes = VS.unwrap() * dtype_size, + .stride_cache_bytes = S.unwrap() * dtype_size, + .stride_indices = I.unwrap(), + .batch_size = static_cast(B.unwrap()), + }; + // select kernel and update num_split if needed + const auto kernel = dtype.is_type() ? get_kernel(num_split) : get_kernel(num_split); + const auto num_blocks = div_ceil(num_elements * num_split, kNumWarps); + LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/norm.cuh b/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh similarity index 100% rename from python/sglang/jit_kernel/csrc/norm.cuh rename to python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh diff --git a/python/sglang/jit_kernel/kvcache.py b/python/sglang/jit_kernel/kvcache.py new file mode 100644 index 000000000000..46a14612b6ff --- /dev/null +++ b/python/sglang/jit_kernel/kvcache.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_kvcache_module(row_bytes: int) -> Module: + args = make_cpp_args(row_bytes, is_arch_support_pdl()) + return load_jit( + "kvcache", + *args, + cuda_files=["elementwise/kvcache.cuh"], + cuda_wrappers=[("store_cache", f"StoreKVCacheKernel<{args}>::run")], + ) + + +@cache_once +def can_use_store_cache(size: int) -> bool: + logger = logging.getLogger(__name__) + if size % 4 != 0: + logger.warning( + f"Unsupported row_bytes={size} for JIT KV-Cache kernel:" + " must be multiple of 4" + ) + return False + try: + _jit_kvcache_module(size) + return True + except Exception as e: + logger.warning( + f"Failed to load JIT KV-Cache kernel " f"with row_bytes={size}: {e}" + ) + return False + + +def store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, + *, + row_bytes: int = 0, + num_split: int = 0, # can be tuned for performance +) -> None: + """Store key and value tensors into KV cache at specified indices. + + Args: + k (torch.Tensor): Key tensor of shape (batch_size, H * D). + v (torch.Tensor): Value tensor of shape (batch_size, H * D). + k_cache (torch.Tensor): Key cache tensor of shape (num_pages, H * D). + v_cache (torch.Tensor): Value cache tensor of shape (num_pages, H * D). + indices (torch.Tensor): Indices tensor of shape (batch_size,). + """ + row_bytes = row_bytes or k.shape[-1] * k.element_size() + module = _jit_kvcache_module(row_bytes) + if num_split <= 0: + if row_bytes % 2048 == 0: + num_split = 4 + elif row_bytes % 1024 == 0: + num_split = 2 + else: + num_split = 1 + module.store_cache( + k, + v, + k_cache, + v_cache, + indices, + num_split, + ) diff --git a/python/sglang/jit_kernel/norm.py b/python/sglang/jit_kernel/norm.py index 963e5ed06ef2..e1a05aa2799c 100644 --- a/python/sglang/jit_kernel/norm.py +++ b/python/sglang/jit_kernel/norm.py @@ -17,12 +17,12 @@ @cache_once -def _jit_norm_module(head_dims: int, dtype: torch.dtype) -> Module: - args = make_cpp_args(head_dims, is_arch_support_pdl(), dtype) +def _jit_qknorm_module(head_dim: int, dtype: torch.dtype) -> Module: + args = make_cpp_args(head_dim, is_arch_support_pdl(), dtype) return load_jit( - "norm", + "qknorm", *args, - cuda_files=["norm.cuh"], + cuda_files=["elementwise/qknorm.cuh"], cuda_wrappers=[("qknorm", f"QKNormKernel<{args}>::run")], ) @@ -34,7 +34,7 @@ def can_use_fused_inplace_qknorm(head_dim: int, dtype: torch.dtype) -> bool: logger.warning(f"Unsupported head_dim={head_dim} for JIT QK-Norm kernel") return False try: - _jit_norm_module(head_dim, dtype) + _jit_qknorm_module(head_dim, dtype) return True except Exception as e: logger.warning(f"Failed to load JIT QK-Norm kernel: {e}") @@ -51,5 +51,5 @@ def fused_inplace_qknorm( head_dim: int = 0, ) -> None: head_dim = head_dim or q.size(-1) - module = _jit_norm_module(head_dim, q.dtype) + module = _jit_qknorm_module(head_dim, q.dtype) module.qknorm(q, k, q_weight, k_weight, eps) diff --git a/python/sglang/jit_kernel/tests/test_store_cache.py b/python/sglang/jit_kernel/tests/test_store_cache.py new file mode 100644 index 000000000000..770f257f9de6 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_store_cache.py @@ -0,0 +1,35 @@ +import itertools + +import pytest +import torch + +from sglang.jit_kernel.kvcache import store_cache + +BS_LIST = [2**n for n in range(0, 15)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +HIDDEN_DIMS = [64, 128, 256, 512, 1024, 96, 98, 100] +CACHE_SIZE = 1024 * 1024 +DTYPE = torch.bfloat16 +DEVICE = "cuda" + + +@pytest.mark.parametrize( + "batch_size,element_dim", + list(itertools.product(BS_LIST, HIDDEN_DIMS)), +) +def test_store_cache(batch_size: int, element_dim: int) -> None: + k = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE) + v = torch.randn((batch_size, element_dim), dtype=DTYPE, device=DEVICE) + k_cache = torch.randn((CACHE_SIZE, element_dim), dtype=DTYPE, device=DEVICE) + v_cache = torch.randn((CACHE_SIZE, element_dim), dtype=DTYPE, device=DEVICE) + indices = torch.randperm(CACHE_SIZE, device=DEVICE)[:batch_size] + + # AOT store cache + store_cache(k, v, k_cache, v_cache, indices) + + assert torch.all(k_cache[indices] == k) + assert torch.all(v_cache[indices] == v) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 65d562a277c0..b10fa034d70d 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -15,19 +15,6 @@ from __future__ import annotations -import dataclasses -from dataclasses import dataclass -from typing import List - -from sglang.srt.configs.mamba_utils import BaseLinearStateParams -from sglang.srt.environ import envs -from sglang.srt.layers.attention.nsa import index_buf_accessor -from sglang.srt.layers.attention.nsa.quant_k_cache import ( - quantize_k_cache, - quantize_k_cache_separate, -) -from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter - """ Memory pool. @@ -38,16 +25,26 @@ """ import abc +import dataclasses import logging from contextlib import contextmanager, nullcontext -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import numpy as np import torch import triton import triton.language as tl +from sglang.jit_kernel.kvcache import can_use_store_cache, store_cache +from sglang.srt.configs.mamba_utils import BaseLinearStateParams from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE +from sglang.srt.environ import envs +from sglang.srt.layers.attention.nsa import index_buf_accessor +from sglang.srt.layers.attention.nsa.quant_k_cache import ( + quantize_k_cache, + quantize_k_cache_separate, +) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.utils import ( get_mla_kv_buffer_triton, @@ -56,6 +53,10 @@ set_mla_kv_scale_buffer_triton, ) from sglang.srt.utils import is_cuda, is_npu, next_power_of_2 +from sglang.srt.utils.custom_op import register_custom_op +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter + +store_cache = register_custom_op(store_cache, mutates_args=["k_cache", "v_cache"]) if TYPE_CHECKING: from sglang.srt.managers.cache_controller import LayerDoneCounter @@ -75,6 +76,43 @@ def get_tensor_size_bytes(t: Union[torch.Tensor, List[torch.Tensor]]): return np.prod(t.shape) * t.dtype.itemsize +def _set_kv_buffer_impl( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, + row_dim: int, # head_num * head_dim + store_dtype: torch.dtype, + device_module: Any, + alt_stream: Optional[torch.cuda.Stream] = None, + same_kv_dim: bool = True, +) -> None: + row_bytes = row_dim * store_dtype.itemsize + if _is_cuda and same_kv_dim and can_use_store_cache(row_bytes): + return store_cache( + k.view(-1, row_dim), + v.view(-1, row_dim), + k_cache.view(-1, row_dim), + v_cache.view(-1, row_dim), + indices, + row_bytes=row_bytes, + ) + + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + if get_is_capture_mode() and alt_stream is not None: + current_stream = device_module.current_stream() + alt_stream.wait_stream(current_stream) + k_cache[indices] = k + with device_module.stream(alt_stream): + v_cache[indices] = v + current_stream.wait_stream(alt_stream) + else: # fallback to naive implementation + k_cache[indices] = k + v_cache[indices] = v + + class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" @@ -652,6 +690,10 @@ def __init__( self._finalize_allocation_log(size) + # for store_cache JIT kernel + self.row_dim = self.head_num * self.head_dim + self.same_kv_dim = self.head_dim == self.v_head_dim + def _init_kv_copy_and_warmup(self): # Heuristics for KV copy tiling _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192 @@ -859,8 +901,6 @@ def set_kv_buffer( v_scale: Optional[float] = None, layer_id_override: Optional[int] = None, ): - from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode - if layer_id_override is not None: layer_id = layer_id_override else: @@ -877,17 +917,18 @@ def set_kv_buffer( cache_k = cache_k.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype) - if get_is_capture_mode() and self.alt_stream is not None: - # Overlap the copy of K and V cache for small batch size - current_stream = self.device_module.current_stream() - self.alt_stream.wait_stream(current_stream) - self.k_buffer[layer_id - self.start_layer][loc] = cache_k - with self.device_module.stream(self.alt_stream): - self.v_buffer[layer_id - self.start_layer][loc] = cache_v - current_stream.wait_stream(self.alt_stream) - else: - self.k_buffer[layer_id - self.start_layer][loc] = cache_k - self.v_buffer[layer_id - self.start_layer][loc] = cache_v + _set_kv_buffer_impl( + cache_k, + cache_v, + self.k_buffer[layer_id - self.start_layer], + self.v_buffer[layer_id - self.start_layer], + loc, + row_dim=self.row_dim, + store_dtype=self.store_dtype, + device_module=self.device_module, + alt_stream=self.alt_stream, + same_kv_dim=self.same_kv_dim, + ) def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): if envs.SGLANG_NATIVE_MOVE_KV_CACHE.get():