import torch
import triton
import triton.testing
from sgl_kernel import downcast_fp8 as downcast_fp8_aot

from sglang.jit_kernel.benchmark.utils import (
    DEFAULT_DEVICE,
    get_benchmark_range,
    run_benchmark,
)
from sglang.jit_kernel.cast import downcast_fp8 as downcast_fp8_jit

DEVICE = DEFAULT_DEVICE
DTYPE = torch.bfloat16


# ── Config ranges ──────────────────────────────────────────────────────────────

SL_LIST = get_benchmark_range(
    full_range=[4, 16, 64, 256, 512, 1024, 2048],
    ci_range=[4, 64],
)

HEAD_DIM_LIST = get_benchmark_range(
    full_range=[(8, 128), (32, 128), (8, 256), (32, 256)],
    ci_range=[(8, 128)],
)

# Each config is (input_sl, head, dim, out_sl) with out_sl = 2 * input_sl.
CONFIGS = [(sl, h, d, sl * 2) for sl in SL_LIST for h, d in HEAD_DIM_LIST]

LINE_VALS = ["aot", "jit"]
LINE_NAMES = ["AOT (sgl-kernel)", "JIT (cast.cuh, 256 threads, 2D grid)"]
STYLES = [("blue", "--"), ("orange", "-")]


def _alloc_case(input_sl, head, dim, out_sl, dtype):
    """Allocate one benchmark case: k/v inputs, uint8 fp8 outputs, unit scales,
    and an identity ``loc`` map.  Allocation order (k, v first) matches the
    original so the RNG stream is identical."""
    k = torch.randn(input_sl, head, dim, dtype=dtype, device=DEVICE)
    v = torch.randn(input_sl, head, dim, dtype=dtype, device=DEVICE)
    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device=DEVICE)
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device=DEVICE)
    k_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    v_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    loc = torch.arange(input_sl, dtype=torch.int64, device=DEVICE)
    return k, v, k_out, v_out, k_scale, v_scale, loc


# ── Perf report ────────────────────────────────────────────────────────────────


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["input_sl", "head", "dim", "out_sl"],
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="downcast-fp8-aot-vs-jit",
        args={},
    )
)
def benchmark(input_sl, head, dim, out_sl, provider):
    """Time one downcast_fp8 call for the selected provider (aot or jit)."""
    tensors = _alloc_case(input_sl, head, dim, out_sl, DTYPE)
    impl = downcast_fp8_aot if provider == "aot" else downcast_fp8_jit
    return run_benchmark(lambda: impl(*tensors))


# ── Bandwidth analysis ─────────────────────────────────────────────────────────


def _report_bandwidth(input_sl, head, dim, dtype):
    """Print median latency and effective bandwidth for both implementations."""
    elem_bytes = torch.finfo(dtype).bits // 8
    # Bytes moved per element: read k and v (elem_bytes each), write the two
    # 1-byte fp8 outputs.
    total_bytes = input_sl * head * dim * (2 * elem_bytes + 2)

    tensors = _alloc_case(input_sl, head, dim, input_sl * 2, dtype)

    aot_ms, _, _ = triton.testing.do_bench(
        lambda: downcast_fp8_aot(*tensors), quantiles=[0.5, 0.2, 0.8]
    )
    jit_ms, _, _ = triton.testing.do_bench(
        lambda: downcast_fp8_jit(*tensors), quantiles=[0.5, 0.2, 0.8]
    )

    def fmt(ms):
        return f"{ms*1000:6.2f}us {total_bytes/(ms*1e-3)/1e9:6.0f}GB/s"

    print(
        f" sl={input_sl:5d} h={head:2d} d={dim:4d}"
        f" | aot {fmt(aot_ms)}"
        f" | jit {fmt(jit_ms)}"
        f" | speedup {aot_ms/jit_ms:.2f}x"
    )


def report_bandwidth():
    """Sweep a grid of shapes and print an AOT-vs-JIT bandwidth comparison."""
    print(f"\n{'='*95}")
    print(" AOT (sgl-kernel) vs JIT (cast.cuh, 256 threads, 2D grid)")
    print(f" dtype={DTYPE}, device={DEVICE}")
    print(f"{'='*95}")
    for sl in [64, 256, 1024, 2048]:
        for h, d in [(8, 128), (32, 128), (8, 256), (32, 256)]:
            _report_bandwidth(sl, h, d, DTYPE)
        # Blank line between seq-len groups (indentation reconstructed —
        # original placement inferred from grouping intent).
        print()


if __name__ == "__main__":
    benchmark.run(print_data=True)
    report_bandwidth()
a/python/sglang/jit_kernel/cast.py b/python/sglang/jit_kernel/cast.py new file mode 100644 index 000000000000..f0201c4aba20 --- /dev/null +++ b/python/sglang/jit_kernel/cast.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@cache_once +def _jit_cast_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + return load_jit( + "cast", + *args, + cuda_files=["elementwise/cast.cuh"], + cuda_wrappers=[("downcast_fp8", f"downcast_fp8<{args}>")], + ) + + +def downcast_fp8( + k: torch.Tensor, + v: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + loc: torch.Tensor, + mult: int = 1, + offset: int = 0, +) -> None: + """Fused downcast of KV cache tensors from bf16/fp16 to fp8 (E4M3). + + Scales each value by the inverse of its per-tensor scale, clamps to the + fp8 representable range [-448, 448], then converts to fp8 storage. 
+ + Args: + k: [input_sl, head, dim] bf16/fp16 CUDA tensor + v: [input_sl, head, dim] bf16/fp16 CUDA tensor + k_out: [out_sl, head, dim] uint8 CUDA tensor (fp8 storage) + v_out: [out_sl, head, dim] uint8 CUDA tensor (fp8 storage) + k_scale: [1] float32 CUDA tensor, scale for k + v_scale: [1] float32 CUDA tensor, scale for v + loc: [input_sl] int64 CUDA tensor, destination sequence indices + mult: stride multiplier for output index (default 1) + offset: offset added to output index (default 0) + """ + module = _jit_cast_module(k.dtype) + module.downcast_fp8(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset) diff --git a/python/sglang/jit_kernel/csrc/elementwise/cast.cuh b/python/sglang/jit_kernel/csrc/elementwise/cast.cuh new file mode 100644 index 000000000000..f537ddc58819 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/cast.cuh @@ -0,0 +1,137 @@ +#pragma once + +// Optimized cast kernel: fixed 256 threads, scaled out via 2D grid. +// Each thread handles exactly one float4 (kVecSize fp16/bf16 elements). +// No per-thread loop — pure grid scaling for any head*dim. 
#pragma once

// Optimized cast kernel: fixed 256 threads, scaled out via a 2D grid.
// Each thread handles exactly one float4 (kVecSize fp16/bf16 elements).
// No per-thread loop — pure grid scaling for any head*dim.
//
// NOTE(review): the contents of the original #include directives and all
// template argument lists were lost in transit (angle-bracket stripping).
// Everything inside <...> below is a best-effort reconstruction — confirm
// each against the repository before merging.

#include <cuda_bf16.h>  // TODO confirm — original include name stripped
#include <cuda_fp16.h>  // TODO confirm — original include name stripped

#include <sgl_kernel/type.cuh>    // For dtype_trait fp8 specialization
#include <sgl_kernel/launch.cuh>  // For LaunchKernel
#include <sgl_kernel/vector.cuh>  // For AlignedVector

#include <tvm/ffi/container/tensor.h>  // TODO confirm — stripped
#include <tvm/ffi/function.h>          // TODO confirm — stripped

#include <cstdint>

namespace {

constexpr int kBlockSize = 256;

// One thread converts one 16-byte vector of k and the matching vector of v:
// multiply by 1/scale, clamp to the fp8 range, convert to fp8, and scatter
// the result to output row loc[token] * mult + offset.
template <typename T>
__global__ void fused_downcast_kernel(
    const T* __restrict__ cache_k,
    const T* __restrict__ cache_v,
    const float* __restrict__ k_scale,
    const float* __restrict__ v_scale,
    fp8_e4m3_t* __restrict__ output_k,
    fp8_e4m3_t* __restrict__ output_v,
    const int input_num_tokens,
    const int head,
    const int dim,
    const T max_fp8,
    const T min_fp8,
    const int64_t mult,
    const int64_t offset,
    const int64_t* __restrict__ loc) {
  using namespace device;

  constexpr int kVecSize = 16 / sizeof(T);
  using vec_t = AlignedVector<T, kVecSize>;           // NOTE(review): template args reconstructed
  using out_vec_t = AlignedVector<fp8_e4m3_t, kVecSize>;

  const int token_idx = blockIdx.x;
  const int vec_idx = blockIdx.y * kBlockSize + threadIdx.x;
  const int num_vecs = head * dim / kVecSize;

  if (token_idx >= input_num_tokens || vec_idx >= num_vecs) return;

  // NOTE(review): cast<> target types reconstructed from the surrounding
  // static_cast<T> usage — verify against the cast<> helper in type.cuh.
  T k_scale_inv = static_cast<T>(1.f) / cast<T>(k_scale[0]);
  T v_scale_inv = static_cast<T>(1.f) / cast<T>(v_scale[0]);

  auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); };

  // FIX: keep the destination row index in 64 bits.  loc is int64 and the
  // original assigned it to `const int`, truncating before the 64-bit
  // (out_seq_idx * mult + offset) arithmetic could help.
  const int64_t out_seq_idx = loc[token_idx];
  const T* in_k_base = cache_k + token_idx * head * dim;
  const T* in_v_base = cache_v + token_idx * head * dim;
  fp8_e4m3_t* out_k_base = output_k + (out_seq_idx * mult + offset) * head * dim;
  fp8_e4m3_t* out_v_base = output_v + (out_seq_idx * mult + offset) * head * dim;

  vec_t k_vec, v_vec;
  k_vec.load(in_k_base, vec_idx);
  v_vec.load(in_v_base, vec_idx);

  out_vec_t out_k, out_v;
#pragma unroll
  for (int j = 0; j < kVecSize; j++) {
    out_k[j] = cast<fp8_e4m3_t>(clamp(k_vec[j] * k_scale_inv));
    out_v[j] = cast<fp8_e4m3_t>(clamp(v_vec[j] * v_scale_inv));
  }

  out_k.store(out_k_base, vec_idx);
  out_v.store(out_v_base, vec_idx);
}

// Host-side wrapper: validates tensor shapes/dtypes/devices symbolically,
// then launches fused_downcast_kernel with a (num_tokens, grid_y) grid.
template <typename T>
void downcast_fp8(
    tvm::ffi::TensorView k,
    tvm::ffi::TensorView v,
    tvm::ffi::TensorView k_out,
    tvm::ffi::TensorView v_out,
    tvm::ffi::TensorView k_scale,
    tvm::ffi::TensorView v_scale,
    tvm::ffi::TensorView loc,
    int64_t mult,
    int64_t offset) {
  using namespace host;

  auto input_num_tokens = SymbolicSize{"input_num_tokens"};
  auto head = SymbolicSize{"head"};
  auto dim = SymbolicSize{"dim"};
  auto output_num_tokens = SymbolicSize{"out_sl"};
  auto device = SymbolicDevice{};
  device.set_options();  // NOTE(review): any template args here were stripped

  // NOTE(review): with_dtype<> arguments reconstructed from the Python-side
  // contract (k/v in T, outputs uint8 fp8 storage, fp32 scales, int64 loc).
  TensorMatcher({input_num_tokens, head, dim}).with_dtype<T>().with_device(device).verify(k);
  TensorMatcher({input_num_tokens, head, dim}).with_dtype<T>().with_device(device).verify(v);
  TensorMatcher({output_num_tokens, head, dim}).with_dtype<uint8_t>().with_device(device).verify(k_out);
  TensorMatcher({output_num_tokens, head, dim}).with_dtype<uint8_t>().with_device(device).verify(v_out);
  TensorMatcher({1}).with_dtype<float>().with_device(device).verify(k_scale);
  TensorMatcher({1}).with_dtype<float>().with_device(device).verify(v_scale);
  TensorMatcher({input_num_tokens}).with_dtype<int64_t>().with_device(device).verify(loc);

  const int num_tokens = static_cast<int>(input_num_tokens.unwrap());
  const int h = static_cast<int>(head.unwrap());
  const int d = static_cast<int>(dim.unwrap());

  constexpr int kVecSize = 16 / sizeof(T);
  const int num_vecs = h * d / kVecSize;
  const int grid_y = (num_vecs + kBlockSize - 1) / kBlockSize;

  dim3 grid(num_tokens, grid_y);
  dim3 block(kBlockSize);

  // kFP8E4M3Max comes from include/sgl_kernel/type.cuh (added in this patch).
  const T max_fp8 = static_cast<T>(kFP8E4M3Max);
  const T min_fp8 = static_cast<T>(-kFP8E4M3Max);

  LaunchKernel(grid, block, device.unwrap())(
      fused_downcast_kernel<T>,
      static_cast<const T*>(k.data_ptr()),
      static_cast<const T*>(v.data_ptr()),
      static_cast<const float*>(k_scale.data_ptr()),
      static_cast<const float*>(v_scale.data_ptr()),
      static_cast<fp8_e4m3_t*>(k_out.data_ptr()),
      static_cast<fp8_e4m3_t*>(v_out.data_ptr()),
      num_tokens,
      h,
      d,
      max_fp8,
      min_fp8,
      mult,
      offset,
      static_cast<const int64_t*>(loc.data_ptr()));
}

}  // namespace

// ---------------------------------------------------------------------------
// Companion additions in include/sgl_kernel/type.cuh (same patch, this span):
//
//   #ifndef USE_ROCM
//   SGL_REGISTER_DTYPE_TRAIT(fp8_e4m3_t, fp8x2_e4m3_t);
//   #endif
//
//   // FP8 max clamp value — platform-dependent:
//   //   CUDA (e4m3fn):        constexpr float kFP8E4M3Max = 448.0f;
//   //   ROCm HIP_FP8_TYPE_FNUZ: 224.0f     (e4m3fnuz)
//   //   ROCm e4m3fn:            448.0f
// ---------------------------------------------------------------------------
import pytest
import torch

from sglang.jit_kernel.cast import downcast_fp8

DTYPES = [torch.bfloat16, torch.float16]

# FP8 E4M3 representable range (matches kFP8E4M3Max in type.cuh)
_FP8_E4M3_MAX = 448.0


def _run(input_sl, head, dim, out_sl, dtype):
    """Allocate a standard case on CUDA, run the kernel, return (k_out, v_out)."""
    dev = "cuda"
    k = torch.randn(input_sl, head, dim, dtype=dtype, device=dev)
    v = torch.randn(input_sl, head, dim, dtype=dtype, device=dev)
    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device=dev)
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device=dev)
    k_scale = torch.tensor([1.0], dtype=torch.float32, device=dev)
    v_scale = torch.tensor([1.0], dtype=torch.float32, device=dev)
    loc = torch.arange(input_sl, dtype=torch.int64, device=dev)
    downcast_fp8(k, v, k_out, v_out, k_scale, v_scale, loc)
    return k_out, v_out


def _ref_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Reference: replicate kernel precision — scale_inv in dtype T, then to fp8.

    Mirrors the kernel logic:
        scale_inv = cast(1.0f) / cast(scale[0])
        out[j]    = cast(clamp(x[j] * scale_inv))
    """
    scale_inv = x.new_ones(1) / scale[0].to(x.dtype)
    clamped = (x * scale_inv).clamp(-_FP8_E4M3_MAX, _FP8_E4M3_MAX)
    return clamped.to(torch.float8_e4m3fn).view(torch.uint8)


def _ref_downcast(
    x: torch.Tensor,
    scale: torch.Tensor,
    loc: torch.Tensor,
    out_sl: int,
    mult: int = 1,
    offset: int = 0,
) -> torch.Tensor:
    """Scatter _ref_fp8 output to the correct output slots via loc/mult/offset."""
    out = torch.zeros(out_sl, x.shape[1], x.shape[2], dtype=torch.uint8, device=x.device)
    fp8 = _ref_fp8(x, scale)
    for src, dst in enumerate(loc.tolist()):
        out[dst * mult + offset] = fp8[src]
    return out


# ---------------------------------------------------------------------------
# Sanity: written slots are populated, untouched slots stay zero.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("input_sl,head,dim,out_sl", [(4, 8, 128, 16)])
def test_downcast_fp8(input_sl, head, dim, out_sl, dtype):
    # _run allocates k/v first, so the RNG stream matches an inline allocation.
    k_out, v_out = _run(input_sl, head, dim, out_sl, dtype)

    # Written slots hold fp8 of random (almost surely non-zero) values.
    assert k_out[:input_sl].any(), "k_out should have non-zero fp8 values"
    assert v_out[:input_sl].any(), "v_out should have non-zero fp8 values"
    # Slots beyond input_sl were never targeted and must remain zero.
    assert not k_out[input_sl:].any(), "k_out slots beyond input_sl should be zero"
    assert not v_out[input_sl:].any(), "v_out slots beyond input_sl should be zero"


# ---------------------------------------------------------------------------
# Numerical correctness: kernel output must match the PyTorch fp8 reference.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("input_sl,head,dim,out_sl", [(4, 8, 128, 16), (1, 4, 64, 8)])
def test_downcast_fp8_matches_reference(input_sl, head, dim, out_sl, dtype):
    torch.manual_seed(42)
    k = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    v = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    k_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    v_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    loc = torch.arange(input_sl, dtype=torch.int64, device="cuda")
    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")

    downcast_fp8(k, v, k_out, v_out, k_scale, v_scale, loc)

    torch.testing.assert_close(
        k_out, _ref_downcast(k, k_scale, loc, out_sl), msg="k: kernel vs reference mismatch"
    )
    torch.testing.assert_close(
        v_out, _ref_downcast(v, v_scale, loc, out_sl), msg="v: kernel vs reference mismatch"
    )


# ---------------------------------------------------------------------------
# Scale: a non-unit scale divides the values before fp8 conversion.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_val", [0.5, 2.0, 0.1])
def test_downcast_fp8_scale(scale_val, dtype):
    torch.manual_seed(0)
    input_sl, head, dim, out_sl = 4, 4, 64, 8

    k = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    v = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    k_scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda")
    v_scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda")
    loc = torch.arange(input_sl, dtype=torch.int64, device="cuda")
    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")

    downcast_fp8(k, v, k_out, v_out, k_scale, v_scale, loc)

    torch.testing.assert_close(
        k_out,
        _ref_downcast(k, k_scale, loc, out_sl),
        msg=f"scale={scale_val}: kernel vs reference mismatch",
    )
    torch.testing.assert_close(
        v_out,
        _ref_downcast(v, v_scale, loc, out_sl),
        msg=f"scale={scale_val}: kernel vs reference mismatch",
    )


# ---------------------------------------------------------------------------
# Clamping: values exceeding ±448 must be saturated to fp8 max/min.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", DTYPES)
def test_downcast_fp8_clamp(dtype):
    input_sl, head, dim, out_sl = 2, 1, 8, 4

    # All values well outside fp8 range so clamping is unavoidable.
    k = torch.full((input_sl, head, dim), 1000.0, dtype=dtype, device="cuda")
    v = torch.full((input_sl, head, dim), -1000.0, dtype=dtype, device="cuda")
    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    loc = torch.arange(input_sl, dtype=torch.int64, device="cuda")
    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")

    downcast_fp8(k, v, k_out, v_out, scale, scale, loc)

    def _byte_of(value):
        # fp8 E4M3 byte pattern of `value` (e.g. 0x7e = 448.0, 0xfe = -448.0).
        t = torch.tensor([value], dtype=dtype, device="cuda")
        return t.to(torch.float8_e4m3fn).view(torch.uint8).item()

    fp8_pos_max = _byte_of(_FP8_E4M3_MAX)
    fp8_neg_max = _byte_of(-_FP8_E4M3_MAX)

    assert (
        k_out[:input_sl] == fp8_pos_max
    ).all(), "large positive values should clamp to fp8 max"
    assert (
        v_out[:input_sl] == fp8_neg_max
    ).all(), "large negative values should clamp to fp8 min"


# ---------------------------------------------------------------------------
# Scatter: loc controls which output rows receive the converted values.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", DTYPES)
def test_downcast_fp8_loc(dtype):
    torch.manual_seed(7)
    input_sl, head, dim, out_sl = 3, 2, 32, 10

    k = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    v = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

    # Write to non-contiguous output positions: 0, 5, 9
    loc = torch.tensor([0, 5, 9], dtype=torch.int64, device="cuda")

    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    downcast_fp8(k, v, k_out, v_out, scale, scale, loc)

    torch.testing.assert_close(
        k_out, _ref_downcast(k, scale, loc, out_sl), msg="loc scatter: kernel vs reference mismatch"
    )
    torch.testing.assert_close(
        v_out, _ref_downcast(v, scale, loc, out_sl), msg="loc scatter: kernel vs reference mismatch"
    )

    # Slots not in loc must remain zero
    written = {0, 5, 9}
    for row in range(out_sl):
        if row in written:
            continue
        assert not k_out[row].any(), f"k_out[{row}] should be zero (not a loc target)"
        assert not v_out[row].any(), f"v_out[{row}] should be zero (not a loc target)"


# ---------------------------------------------------------------------------
# mult/offset: output index = loc[i] * mult + offset
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("mult,offset", [(2, 0), (1, 3), (2, 1)])
def test_downcast_fp8_mult_offset(mult, offset, dtype):
    torch.manual_seed(3)
    input_sl, head, dim = 2, 2, 32
    out_sl = input_sl * mult + offset + 4  # ensure output is large enough

    k = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    v = torch.randn(input_sl, head, dim, dtype=dtype, device="cuda")
    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    loc = torch.arange(input_sl, dtype=torch.int64, device="cuda")

    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    downcast_fp8(k, v, k_out, v_out, scale, scale, loc, mult=mult, offset=offset)

    torch.testing.assert_close(
        k_out,
        _ref_downcast(k, scale, loc, out_sl, mult=mult, offset=offset),
        msg=f"mult={mult},offset={offset}: kernel vs reference mismatch",
    )
    torch.testing.assert_close(
        v_out,
        _ref_downcast(v, scale, loc, out_sl, mult=mult, offset=offset),
        msg=f"mult={mult},offset={offset}: kernel vs reference mismatch",
    )


# ---------------------------------------------------------------------------
# static_cast conversion: verify conversion near and at the fp8 boundary
# matches PyTorch, validating the static_cast fallback bit patterns.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("dtype", DTYPES)
def test_downcast_fp8_static_cast_boundary(dtype):
    """Test conversion accuracy near ±448 fp8 boundary using static_cast path."""
    torch.manual_seed(0)
    # Values specifically chosen to stress the conversion path:
    # - exactly at ±448 (representable fp8 max)
    # - just inside the range
    # - just outside (must saturate)
    # - zero, small, and mid-range values
    boundary_vals = [
        0.0,
        1.0,
        -1.0,
        100.0,
        -100.0,
        447.0,
        -447.0,
        448.0,
        -448.0,
        449.0,
        -449.0,
        1000.0,
        -1000.0,
    ]
    input_sl = len(boundary_vals)
    head, dim, out_sl = 1, 8, input_sl

    base = torch.tensor(boundary_vals, dtype=dtype, device="cuda")
    k = base.unsqueeze(1).unsqueeze(2).expand(input_sl, head, dim).contiguous()
    v = (-base).unsqueeze(1).unsqueeze(2).expand(input_sl, head, dim).contiguous()
    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    loc = torch.arange(input_sl, dtype=torch.int64, device="cuda")

    k_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    v_out = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device="cuda")
    downcast_fp8(k, v, k_out, v_out, scale, scale, loc)

    torch.testing.assert_close(
        k_out,
        _ref_downcast(k, scale, loc, out_sl),
        msg="boundary values: k static_cast vs reference mismatch",
    )
    torch.testing.assert_close(
        v_out,
        _ref_downcast(v, scale, loc, out_sl),
        msg="boundary values: v static_cast vs reference mismatch",
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
+120,8 @@ void downcast_fp8_impl( int vec_size = 8; dim3 block(std::min(int(dim) / vec_size, 1024)); - const T max_fp8 = static_cast(448.0f); - const T min_fp8 = static_cast(-448.0f); + const T max_fp8 = static_cast(FP8_E4M3_MAX); + const T min_fp8 = static_cast(-FP8_E4M3_MAX); fused_downcast_kernel<<>>( static_cast(k.data_ptr()),