Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/sglang/jit_kernel/.clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,11 @@ PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name

IncludeCategories:
- Regex: '^<sgl_kernel/.*>$'
- Regex: '^<sgl_kernel/.*\.h>$'
Priority: 0
- Regex: '^<sgl_kernel/impl/.*>$'
Priority: 2
- Regex: '^<sgl_kernel/.*\.cuh>$'
Priority: 1
- Regex: '^<.*/.*>$'
Priority: 3
31 changes: 11 additions & 20 deletions python/sglang/jit_kernel/benchmark/bench_per_tensor_quant_fp8.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import os
from typing import Optional, Tuple

Expand Down Expand Up @@ -57,7 +56,7 @@ def sglang_scaled_fp8_quant(

def calculate_diff(batch_size: int, seq_len: int):
device = torch.device("cuda")
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
x = torch.rand((batch_size, seq_len), dtype=torch.bfloat16, device=device)

if not VLLM_AVAILABLE:
print("vLLM not available, skipping comparison")
Expand All @@ -66,25 +65,17 @@ def calculate_diff(batch_size: int, seq_len: int):
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

scale_diff = torch.abs(vllm_scale - sglang_scale).item()
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
vllm_out = vllm_out.to(torch.float32)
sglang_out = sglang_out.to(torch.float32)

if torch.allclose(
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
print("All implementations match")
else:
print("Implementations differ")
triton.testing.assert_close(vllm_out, sglang_out, rtol=1e-3, atol=1e-3)
triton.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)


if IS_CI:
batch_size_range = [16]
seq_len_range = [64]
element_range = [16384]
else:
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]

configs = list(itertools.product(batch_size_range, seq_len_range))
element_range = [2**n for n in range(10, 20)]


if VLLM_AVAILABLE:
Expand All @@ -99,8 +90,8 @@ def calculate_diff(batch_size: int, seq_len: int):

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=configs,
x_names=["element_count"],
x_vals=element_range,
line_arg="provider",
line_vals=line_vals,
line_names=line_names,
Expand All @@ -110,11 +101,11 @@ def calculate_diff(batch_size: int, seq_len: int):
args={},
)
)
def benchmark(batch_size, seq_len, provider):
def benchmark(element_count, provider):
dtype = torch.float16
device = torch.device("cuda")

x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
x = torch.randn(element_count, 4096, device=device, dtype=dtype)

quantiles = [0.5, 0.2, 0.8]

Expand Down
96 changes: 96 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_rmsnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import itertools
import os

import torch
import triton
import triton.testing
from flashinfer import rmsnorm as fi_rmsnorm
from sgl_kernel import rmsnorm

from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm

# True when running under a CI system (generic CI flag or GitHub Actions);
# used below to shrink the benchmark sweep so CI stays fast.
IS_CI = any(
    os.getenv(flag, "false").lower() == "true"
    for flag in ("CI", "GITHUB_ACTIONS")
)


def sglang_aot_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
) -> None:
    """Run the ahead-of-time compiled `sgl_kernel` RMSNorm in place on *input*.

    Writing the result back into *input* (via ``out=input``) keeps the
    benchmarked call allocation-free.
    """
    rmsnorm(input, weight, out=input)


def sglang_jit_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
) -> None:
    """Run the JIT-compiled sglang RMSNorm in place on *input*.

    NOTE(review): this kernel takes ``output=`` where the AOT and FlashInfer
    wrappers take ``out=`` — the keyword difference is per-library API, not a typo.
    """
    jit_rmsnorm(input, weight, output=input)


def flashinfer_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
) -> None:
    """Run FlashInfer's RMSNorm in place on *input* (``out=input``)."""
    fi_rmsnorm(input, weight, out=input)


@torch.compile()
def torch_impl_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
) -> None:
mean = input.float().pow(2).mean(dim=-1, keepdim=True)
norm = (mean + eps).rsqrt()
input.copy_(input.float() * norm * weight.float())


DTYPE = torch.bfloat16
DEVICE = "cuda"

# Small sweep in CI to keep runtime short; full sweep locally.
if IS_CI:
    BS_LIST = [16]
    HIDDEN_SIZE_LIST = [512, 2048]
else:
    BS_LIST = [1 << n for n in range(14)]
    HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192]

LINE_VALS = ["aot", "jit", "fi", "torch"]
LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "FlashInfer", "PyTorch"]
STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")]

# One benchmark config per (hidden_size, batch_size) pair.
configs = [(h, b) for h in HIDDEN_SIZE_LIST for b in BS_LIST]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["hidden_size", "batch_size"],
        x_vals=configs,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="rmsnorm-performance",
        args={},
    )
)
def benchmark(hidden_size: int, batch_size: int, provider: str):
    """Time one RMSNorm implementation (chosen by *provider*) and return
    (median, max, min) latency in microseconds."""
    providers = {
        "aot": sglang_aot_rmsnorm,
        "jit": sglang_jit_rmsnorm,
        "fi": flashinfer_rmsnorm,
        "torch": torch_impl_rmsnorm,
    }
    impl = providers[provider]
    x = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE)
    w = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE)

    def run() -> None:
        # Clone per call so every invocation normalizes fresh data
        # (all implementations mutate their input in place).
        impl(x.clone(), w)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        run, quantiles=[0.5, 0.2, 0.8]
    )  # type: ignore
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


# Entry point: run the full (hidden_size, batch_size) sweep and print a table.
if __name__ == "__main__":
    benchmark.run(print_data=True)
5 changes: 3 additions & 2 deletions python/sglang/jit_kernel/csrc/add_constant.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <sgl_kernel/tensor.h> // For TensorMatcher, SymbolicSize, SymbolicDevice
#include <sgl_kernel/tensor.h> // For TensorMatcher, SymbolicSize, SymbolicDevice
#include <sgl_kernel/utils.h> // For div_ceil, RuntimeCheck

#include <sgl_kernel/utils.cuh> // For LaunchKernel
#include <sgl_kernel/utils.h> // For div_ceil, RuntimeCheck

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/jit_kernel/csrc/cuda_wait_value.cuh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <sgl_kernel/tensor.h>
#include <sgl_kernel/utils.cuh>

#include <cuda_runtime_api.h>
#include <sgl_kernel/utils.cuh>

#include <cstdint>
#include <cuda_runtime_api.h>

namespace {

Expand Down
50 changes: 34 additions & 16 deletions python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include <sgl_kernel/tensor.h>
#include <sgl_kernel/utils.cuh>
#include <sgl_kernel/utils.h>

#include <sgl_kernel/tile.cuh>
#include <sgl_kernel/utils.cuh>
#include <sgl_kernel/vec.cuh>
#include <sgl_kernel/warp.cuh>

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>
Expand All @@ -27,8 +28,17 @@ struct StoreKVCacheParams {
constexpr uint32_t kNumWarps = 4;
constexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads;

/**
* \brief Use a single warp to copy key and value data from source to destination.
* Each thread in the warp copies a portion of the data in a coalesced manner.
* \tparam kElementBytes The size of each key/value element in bytes.
* \param k_src Pointer to the source key data.
* \param v_src Pointer to the source value data.
* \param k_dst Pointer to the destination key data.
* \param v_dst Pointer to the destination value data.
*/
template <int64_t kElementBytes>
__device__ void copy_impl(
SGL_DEVICE void copy_kv_warp(
const void* __restrict__ k_src,
const void* __restrict__ v_src,
void* __restrict__ k_dst,
Expand All @@ -42,31 +52,39 @@ __device__ void copy_impl(

static_assert(kAlignment > 0, "Element size must be multiple of 4 bytes");

using vec_t = aligned_vector<uint32_t, kAlignment / 4>;
using vec_t = AlignedStorage<uint32_t, kAlignment / 4>;
constexpr auto kLoopBytes = sizeof(vec_t) * kWarpThreads;
constexpr auto kLoopCount = kElementBytes / kLoopBytes;

const auto gmem = tile::Memory<vec_t>::warp();

#pragma unroll kLoopCount
for (int64_t i = 0; i < kLoopCount; ++i) {
const auto k = warp::load<vec_t>(pointer::offset(k_src, i * kLoopBytes));
const auto v = warp::load<vec_t>(pointer::offset(v_src, i * kLoopBytes));
warp::store(pointer::offset(k_dst, i * kLoopBytes), k);
warp::store(pointer::offset(v_dst, i * kLoopBytes), v);
const auto k = gmem.load(k_src, i);
const auto v = gmem.load(v_src, i);
gmem.store(k_dst, k, i);
gmem.store(v_dst, v, i);
}

// handle the epilogue if any
if constexpr (kLoopCount * kLoopBytes < kElementBytes) {
constexpr auto kOffset = kLoopCount * kLoopBytes;
if ((threadIdx.x % kWarpThreads) * sizeof(vec_t) < kElementBytes - kOffset) {
const auto k = warp::load<vec_t>(pointer::offset(k_src, kOffset));
const auto v = warp::load<vec_t>(pointer::offset(v_src, kOffset));
warp::store(pointer::offset(k_dst, kOffset), k);
warp::store(pointer::offset(v_dst, kOffset), v);
if (gmem.in_bound(kElementBytes / sizeof(vec_t), kLoopCount)) {
const auto k = gmem.load(k_src, kLoopCount);
const auto v = gmem.load(v_src, kLoopCount);
gmem.store(k_dst, k, kLoopCount);
gmem.store(v_dst, v, kLoopCount);
}
}
}

// Each warp handles one item
/**
* \brief Kernel to store key-value pairs into the KV cache.
* Each element is split into multiple parts to allow parallel memory copy.
* \tparam kElementBytes The size of each key/value element in bytes.
* \tparam kSplit The number of warps that handle each element.
* \tparam kUsePDL Whether to use PDL feature.
* \tparam T The data type of the indices (`int32_t` or `int64_t`).
*/
template <int64_t kElementBytes, int kSplit, bool kUsePDL, typename T>
__global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) {
using namespace device;
Expand All @@ -89,7 +107,7 @@ __global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params)
const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize);
const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize);

copy_impl<kSplitSize>(k_src, v_src, k_dst, v_dst);
copy_kv_warp<kSplitSize>(k_src, v_src, k_dst, v_dst);
PDLTriggerSecondary<kUsePDL>();
}

Expand Down
Loading
Loading