Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
2edb0c2
[jit_kernel] Add downcast_fp8 JIT kernel (Phase 1)
Johnsonms Feb 21, 2026
7f4999a
[jit_kernel] Migrate downcast_fp8 from sgl-kernel AOT to JIT (Phase 2)
Johnsonms Feb 21, 2026
d431152
style: apply code formatting to cast JIT kernel
Johnsonms Feb 21, 2026
7fcf3cb
[jit_kernel] Optimize cast kernel: coalesced read and write with fixe…
Johnsonms Feb 22, 2026
cc80b9e
[jit_kernel] Rename input_sl to input_num_tokens in cast kernel
Johnsonms Feb 22, 2026
9054279
[jit_kernel] Unify FP8 cast helpers into dtype_trait system
Johnsonms Feb 24, 2026
0e876fa
[jit_kernel] Fix benchmark crash and rename variables in cast kernel
Johnsonms Feb 24, 2026
dc05df7
[jit_kernel] Add __restrict__ to cast kernel and fix bandwidth benchmark
Johnsonms Feb 24, 2026
8da281b
[jit_kernel] Address review comments on cast kernel
Johnsonms Feb 28, 2026
4053820
Merge branch 'main' into jit-kernel-cast
Johnsonms Feb 28, 2026
37c3c0a
Merge branch 'main' into jit-kernel-cast
Johnsonms Mar 11, 2026
b16524d
Merge branch 'main' into jit-kernel-cast
Johnsonms Mar 12, 2026
41547cb
Merge branch 'main' into jit-kernel-cast
Johnsonms Mar 12, 2026
fe1ad51
Merge branch 'main' into jit-kernel-cast
Johnsonms Mar 13, 2026
4b8d668
Merge branch 'main' into jit-kernel-cast
Johnsonms Mar 13, 2026
4be9d9b
[jit_kernel] Fix cast.cu build error due to conflicting headers
Johnsonms Mar 14, 2026
20f0d36
[jit_kernel] Fix lint issues: reorder imports per isort rules
Johnsonms Mar 14, 2026
583780b
[jit_kernel] Revert unrelated lint changes from cast.cu fix commit
Johnsonms Mar 14, 2026
cd210ee
Merge branch 'main' into jit-kernel-cast
Johnsonms Mar 14, 2026
754ed28
Merge branch 'main' into jit-kernel-cast
Johnsonms Mar 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import torch
import triton
import triton.testing
from sgl_kernel import downcast_fp8 as downcast_fp8_aot

from sglang.jit_kernel.benchmark.utils import (
DEFAULT_DEVICE,
get_benchmark_range,
run_benchmark,
)
from sglang.jit_kernel.cast import downcast_fp8 as downcast_fp8_jit

DEVICE = DEFAULT_DEVICE
DTYPE = torch.bfloat16


# ── Config ranges ──────────────────────────────────────────────────────────────

# Sequence lengths to sweep; the CI subset keeps benchmark runtime short.
SL_LIST = get_benchmark_range(
    full_range=[4, 16, 64, 256, 512, 1024, 2048],
    ci_range=[4, 64],
)

# (num_heads, head_dim) pairs to sweep.
HEAD_DIM_LIST = get_benchmark_range(
    full_range=[(8, 128), (32, 128), (8, 256), (32, 256)],
    ci_range=[(8, 128)],
)

# One config per (input_sl, head, dim, out_sl); the output sequence dimension
# is sized at 2x the input so `loc` indices always fit.
CONFIGS = [(sl, h, d, sl * 2) for sl in SL_LIST for h, d in HEAD_DIM_LIST]

# Plot series: one line per provider, matched by position across the 3 lists.
LINE_VALS = ["aot", "jit"]
LINE_NAMES = ["AOT (sgl-kernel)", "JIT (cast.cuh, 256 threads, 2D grid)"]
STYLES = [("blue", "--"), ("orange", "-")]


# ── Perf report ────────────────────────────────────────────────────────────────


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["input_sl", "head", "dim", "out_sl"],
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="downcast-fp8-aot-vs-jit",
        args={},
    )
)
def benchmark(input_sl, head, dim, out_sl, provider):
    """Time one downcast_fp8 call for a single (shape, provider) config."""
    in_shape = (input_sl, head, dim)
    out_shape = (out_sl, head, dim)
    k = torch.randn(in_shape, dtype=DTYPE, device=DEVICE)
    v = torch.randn(in_shape, dtype=DTYPE, device=DEVICE)
    k_out = torch.zeros(out_shape, dtype=torch.uint8, device=DEVICE)
    v_out = torch.zeros(out_shape, dtype=torch.uint8, device=DEVICE)
    k_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    v_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    loc = torch.arange(input_sl, dtype=torch.int64, device=DEVICE)

    # Select the implementation under test; both share the same call signature.
    impl = downcast_fp8_aot if provider == "aot" else downcast_fp8_jit

    def fn():
        return impl(k, v, k_out, v_out, k_scale, v_scale, loc)

    return run_benchmark(fn)


# ── Bandwidth analysis ─────────────────────────────────────────────────────────


def _report_bandwidth(input_sl, head, dim, dtype):
    """Print latency and effective bandwidth for AOT vs JIT at one config."""
    # Traffic per element: read k and v (elem_bytes each) plus write one
    # fp8 byte each to k_out and v_out.
    elem_bytes = torch.finfo(dtype).bits // 8
    total_bytes = input_sl * head * dim * (2 * elem_bytes + 2)

    in_shape = (input_sl, head, dim)
    out_shape = (input_sl * 2, head, dim)
    k = torch.randn(in_shape, dtype=dtype, device=DEVICE)
    v = torch.randn(in_shape, dtype=dtype, device=DEVICE)
    k_out = torch.zeros(out_shape, dtype=torch.uint8, device=DEVICE)
    v_out = torch.zeros(out_shape, dtype=torch.uint8, device=DEVICE)
    k_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    v_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    loc = torch.arange(input_sl, dtype=torch.int64, device=DEVICE)

    def bench(impl):
        # do_bench with these quantiles returns (median, p20, p80) in ms.
        median_ms, _, _ = triton.testing.do_bench(
            lambda: impl(k, v, k_out, v_out, k_scale, v_scale, loc),
            quantiles=[0.5, 0.2, 0.8],
        )
        return median_ms

    aot_ms = bench(downcast_fp8_aot)
    jit_ms = bench(downcast_fp8_jit)

    def fmt(ms):
        return f"{ms*1000:6.2f}us {total_bytes/(ms*1e-3)/1e9:6.0f}GB/s"

    print(
        f" sl={input_sl:5d} h={head:2d} d={dim:4d}"
        f" | aot {fmt(aot_ms)}"
        f" | jit {fmt(jit_ms)}"
        f" | speedup {aot_ms/jit_ms:.2f}x"
    )


def report_bandwidth():
    """Sweep a grid of KV shapes and print an AOT-vs-JIT bandwidth table."""
    rule = "=" * 95
    print(f"\n{rule}")
    print(" AOT (sgl-kernel) vs JIT (cast.cuh, 256 threads, 2D grid)")
    print(f" dtype={DTYPE}, device={DEVICE}")
    print(rule)
    head_dim_pairs = [(8, 128), (32, 128), (8, 256), (32, 256)]
    for seq_len in (64, 256, 1024, 2048):
        for num_heads, head_dim in head_dim_pairs:
            _report_bandwidth(seq_len, num_heads, head_dim, DTYPE)
        # Blank line between sequence-length groups for readability.
        print()


# Script entry: run the perf-report sweep, then the bandwidth table.
if __name__ == "__main__":
    benchmark.run(print_data=True)
    report_bandwidth()
52 changes: 52 additions & 0 deletions python/sglang/jit_kernel/cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args

if TYPE_CHECKING:
from tvm_ffi.module import Module


@cache_once
def _jit_cast_module(dtype: torch.dtype) -> Module:
    """Build (once per dtype, via @cache_once) the JIT-compiled cast module."""
    cpp_args = make_cpp_args(dtype)
    return load_jit(
        "cast",
        *cpp_args,
        cuda_files=["elementwise/cast.cuh"],
        cuda_wrappers=[("downcast_fp8", f"downcast_fp8<{cpp_args}>")],
    )


def downcast_fp8(
    k: torch.Tensor,
    v: torch.Tensor,
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    loc: torch.Tensor,
    mult: int = 1,
    offset: int = 0,
) -> None:
    """Fused downcast of paired KV tensors from bf16/fp16 to fp8 (E4M3).

    Each value is multiplied by the inverse of its tensor's scale, clamped
    to the fp8-representable range (±448 on CUDA; platform-dependent on
    ROCm), and written as fp8 bytes at output row ``loc[i] * mult + offset``.

    Args:
        k: [input_sl, head, dim] bf16/fp16 CUDA tensor
        v: [input_sl, head, dim] bf16/fp16 CUDA tensor
        k_out: [out_sl, head, dim] uint8 CUDA tensor (fp8 storage)
        v_out: [out_sl, head, dim] uint8 CUDA tensor (fp8 storage)
        k_scale: [1] float32 CUDA tensor, scale for k
        v_scale: [1] float32 CUDA tensor, scale for v
        loc: [input_sl] int64 CUDA tensor, destination sequence indices
        mult: stride multiplier for output index (default 1)
        offset: offset added to output index (default 0)
    """
    # The module is compiled and cached per input dtype.
    _jit_cast_module(k.dtype).downcast_fp8(
        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset
    )
137 changes: 137 additions & 0 deletions python/sglang/jit_kernel/csrc/elementwise/cast.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#pragma once

// Optimized cast kernel: fixed 256 threads, scaled out via 2D grid.
// Each thread handles exactly one float4 (kVecSize fp16/bf16 elements).
// No per-thread loop — pure grid scaling for any head*dim.

#include <sgl_kernel/tensor.h>
#include <sgl_kernel/utils.h>

#include <sgl_kernel/type.cuh> // For dtype_trait fp8 specialization
#include <sgl_kernel/utils.cuh> // For LaunchKernel
#include <sgl_kernel/vec.cuh> // For AlignedVector

#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>

#include <cstdint>

namespace {

// Threads per block; each thread handles exactly one 16-byte vector.
constexpr int kBlockSize = 256;

// Downcast one token's K and V data from T (fp16/bf16) to fp8 e4m3.
// Grid layout: blockIdx.x = token index, blockIdx.y tiles the head*dim
// vectors; each thread converts exactly one 16-byte vector of K and of V.
template <typename T>
__global__ void fused_downcast_kernel(
    const T* __restrict__ cache_k,
    const T* __restrict__ cache_v,
    const float* __restrict__ k_scale,   // [1] per-tensor scale for K
    const float* __restrict__ v_scale,   // [1] per-tensor scale for V
    fp8_e4m3_t* __restrict__ output_k,
    fp8_e4m3_t* __restrict__ output_v,
    const int input_num_tokens,
    const int head,
    const int dim,
    const T max_fp8,                     // clamp upper bound (see kFP8E4M3Max)
    const T min_fp8,                     // clamp lower bound
    const int64_t mult,                  // output-row stride multiplier
    const int64_t offset,                // output-row offset
    const int64_t* __restrict__ loc) {   // [input_num_tokens] destination rows
  using namespace device;

  // kVecSize elements per 16-byte vector (8 for fp16/bf16).
  constexpr int kVecSize = 16 / sizeof(T);
  using vec_t = AlignedVector<T, kVecSize>;
  using out_vec_t = AlignedVector<fp8_e4m3_t, kVecSize>;

  const int token_idx = blockIdx.x;
  const int vec_idx = blockIdx.y * kBlockSize + threadIdx.x;
  // NOTE(review): assumes head*dim is divisible by kVecSize; any tail
  // elements past the last full vector would be skipped — confirm callers
  // guarantee this (the host wrapper does not check it).
  const int num_vecs = head * dim / kVecSize;

  if (token_idx >= input_num_tokens || vec_idx >= num_vecs) return;

  // Inverse scales, computed per thread in T precision.
  T k_scale_inv = static_cast<T>(1.f) / cast<T>(k_scale[0]);
  T v_scale_inv = static_cast<T>(1.f) / cast<T>(v_scale[0]);

  // Clamp to the fp8-representable range before conversion.
  auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };

  // Destination row comes from loc, remapped by mult/offset.
  const int out_seq_idx = loc[token_idx];
  const T* in_k_base = cache_k + token_idx * head * dim;
  const T* in_v_base = cache_v + token_idx * head * dim;
  fp8_e4m3_t* out_k_base = output_k + (out_seq_idx * mult + offset) * head * dim;
  fp8_e4m3_t* out_v_base = output_v + (out_seq_idx * mult + offset) * head * dim;

  // Coalesced vector load of both K and V for this thread's slot.
  vec_t k_vec, v_vec;
  k_vec.load(in_k_base, vec_idx);
  v_vec.load(in_v_base, vec_idx);

  // Scale, clamp, and convert each element to fp8.
  out_vec_t out_k, out_v;
#pragma unroll
  for (int j = 0; j < kVecSize; j++) {
    out_k[j] = cast<fp8_e4m3_t>(clamp(k_vec[j] * k_scale_inv));
    out_v[j] = cast<fp8_e4m3_t>(clamp(v_vec[j] * v_scale_inv));
  }

  // Coalesced vector store of the converted results.
  out_k.store(out_k_base, vec_idx);
  out_v.store(out_v_base, vec_idx);
}

// Host wrapper: validate tensor shapes/dtypes/devices, compute the launch
// configuration, and dispatch fused_downcast_kernel<T>.
template <typename T>
void downcast_fp8(
    tvm::ffi::TensorView k,
    tvm::ffi::TensorView v,
    tvm::ffi::TensorView k_out,
    tvm::ffi::TensorView v_out,
    tvm::ffi::TensorView k_scale,
    tvm::ffi::TensorView v_scale,
    tvm::ffi::TensorView loc,
    int64_t mult,
    int64_t offset) {
  using namespace host;

  // Symbolic sizes: matching the same symbol across tensors enforces that
  // their corresponding dimensions agree.
  auto input_num_tokens = SymbolicSize{"input_num_tokens"};
  auto head = SymbolicSize{"head"};
  auto dim = SymbolicSize{"dim"};
  auto output_num_tokens = SymbolicSize{"out_sl"};
  auto device = SymbolicDevice{};
  device.set_options<kDLCUDA>();

  // k/v: [tokens, head, dim] in T; outputs: [out_sl, head, dim] as uint8
  // fp8 storage; scales: single-element float; loc: int64 row indices.
  TensorMatcher({input_num_tokens, head, dim}).with_dtype<T>().with_device(device).verify(k);
  TensorMatcher({input_num_tokens, head, dim}).with_dtype<T>().with_device(device).verify(v);
  TensorMatcher({output_num_tokens, head, dim}).with_dtype<uint8_t>().with_device(device).verify(k_out);
  TensorMatcher({output_num_tokens, head, dim}).with_dtype<uint8_t>().with_device(device).verify(v_out);
  TensorMatcher({1}).with_dtype<float>().with_device(device).verify(k_scale);
  TensorMatcher({1}).with_dtype<float>().with_device(device).verify(v_scale);
  TensorMatcher({input_num_tokens}).with_dtype<int64_t>().with_device(device).verify(loc);

  const int num_tokens = static_cast<int>(input_num_tokens.unwrap());
  const int h = static_cast<int>(head.unwrap());
  const int d = static_cast<int>(dim.unwrap());

  // One thread per 16-byte vector; grid.y covers all vectors of one token.
  // NOTE(review): h*d % kVecSize != 0 is not rejected here — the truncating
  // division below would silently drop tail elements; confirm whether a
  // divisibility check should be added.
  constexpr int kVecSize = 16 / sizeof(T);
  const int num_vecs = h * d / kVecSize;
  const int grid_y = (num_vecs + kBlockSize - 1) / kBlockSize;

  dim3 grid(num_tokens, grid_y);
  dim3 block(kBlockSize);

  // Clamp bounds passed as T so the kernel compares without re-converting.
  const T max_fp8 = static_cast<T>(kFP8E4M3Max);
  const T min_fp8 = static_cast<T>(-kFP8E4M3Max);

  LaunchKernel(grid, block, device.unwrap())(
      fused_downcast_kernel<T>,
      static_cast<const T*>(k.data_ptr()),
      static_cast<const T*>(v.data_ptr()),
      static_cast<const float*>(k_scale.data_ptr()),
      static_cast<const float*>(v_scale.data_ptr()),
      static_cast<fp8_e4m3_t*>(k_out.data_ptr()),
      static_cast<fp8_e4m3_t*>(v_out.data_ptr()),
      num_tokens,
      h,
      d,
      max_fp8,
      min_fp8,
      mult,
      offset,
      static_cast<const int64_t*>(loc.data_ptr()));
}

} // namespace
20 changes: 20 additions & 0 deletions python/sglang/jit_kernel/include/sgl_kernel/type.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ SGL_REGISTER_DTYPE_TRAIT(
SGL_REGISTER_DTYPE_TRAIT(
bf16x2_t, void, SGL_REGISTER_TYPE_END; SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22bfloat162_rn););

#ifndef USE_ROCM
// Register fp8 e4m3 (with its paired type fp8x2_e4m3_t) in the dtype_trait
// system so device::cast<fp8_e4m3_t>(...) works in the cast kernels.
// NOTE(review): guarded CUDA-only — confirm ROCm builds get an equivalent
// registration elsewhere.
SGL_REGISTER_DTYPE_TRAIT(fp8_e4m3_t, fp8x2_e4m3_t);
#endif

#undef SGL_REGISTER_DTYPE_TRAIT
#undef SGL_REGISTER_FROM_FUNCTION

Expand All @@ -98,3 +102,19 @@ SGL_DEVICE To cast(const From& value) {
}

} // namespace device

// ---------------------------------------------------------------------------
// FP8 max clamp value — platform-dependent
// CUDA (e4m3fn): 448.0f
// AMD FNUZ (e4m3fnuz): 224.0f
// AMD E4M3 (e4m3fn): 448.0f
// Used by the cast kernels (e.g. elementwise/cast.cuh) as the symmetric
// clamp bound [-kFP8E4M3Max, kFP8E4M3Max] before fp8 conversion.
// ---------------------------------------------------------------------------
#ifndef USE_ROCM
constexpr float kFP8E4M3Max = 448.0f;
#else  // USE_ROCM
#if HIP_FP8_TYPE_FNUZ
constexpr float kFP8E4M3Max = 224.0f;
#else  // HIP_FP8_TYPE_E4M3
constexpr float kFP8E4M3Max = 448.0f;
#endif  // HIP_FP8_TYPE_FNUZ
#endif  // USE_ROCM
Loading
Loading