
[jit_kernel] Migrate cast (downcast_fp8) from sgl-kernel AOT to JIT#19103

Merged
BBuf merged 20 commits into sgl-project:main from Johnsonms:jit-kernel-cast
Mar 27, 2026
Conversation


@Johnsonms Johnsonms commented Feb 21, 2026

Motivation

#17865
downcast_fp8 is a fused kernel that casts KV cache tensors from bf16/fp16 to fp8 (E4M3), scaling and
clamping values in a single GPU pass. Migrating it from the AOT sgl-kernel build to the JIT
framework reduces build complexity and aligns with the ongoing effort to move lightweight kernels to
JIT.
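
For intuition, the op's semantics can be sketched in plain numpy. This is a hedged reference only: the helper name `downcast_fp8_ref` and the example values are mine, not the kernel's, and the sketch stops at clamping because numpy has no fp8 dtype (the real kernel emits E4M3 bytes; 448 is the largest finite E4M3 value).

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value representable in fp8 E4M3

def downcast_fp8_ref(kv, scale_inv):
    """Scale, then clamp into the representable E4M3 range.

    The real kernel also rounds to fp8 bytes; numpy has no fp8 dtype,
    so this sketch stops at the clamped float values.
    """
    return np.clip(kv.astype(np.float32) * scale_inv, -E4M3_MAX, E4M3_MAX)

x = np.array([100.0, 1000.0, -5000.0], dtype=np.float16)
out = downcast_fp8_ref(x, 0.5)
assert out.tolist() == [50.0, 448.0, -448.0]  # in-range values scaled, out-of-range clamped
```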

Modifications

  • Add python/sglang/jit_kernel/csrc/elementwise/cast.cuh: CUDA kernel templated on dtype
    (bf16/fp16), input validation via TVM FFI TensorMatcher with pre-configured SymbolicDevice
  • Add python/sglang/jit_kernel/cast.py: JIT Python wrapper, module cached per dtype via @cache_once
  • Add python/sglang/jit_kernel/tests/test_cast.py: correctness test for bf16 and fp16
  • Add python/sglang/jit_kernel/benchmark/bench_cast.py: latency benchmark
  • Remove sgl-kernel/csrc/elementwise/cast.cu and AOT registration from CMakeLists.txt,
    common_extension.cc, sgl_kernel_ops.h, sgl_kernel/elementwise.py, and sgl_kernel/__init__.py
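
As a rough illustration of the per-dtype module caching mentioned above (hedged: `@cache_once` is the project's decorator; `functools.lru_cache` stands in for it here, and `get_cast_module` is a hypothetical name, not the wrapper's actual API):

```python
import functools

# Stand-in for @cache_once: build the JIT module at most once per dtype
# and return the cached instance on every later call.
@functools.lru_cache(maxsize=None)
def get_cast_module(dtype: str):
    # Real code would invoke the JIT compiler here; this sketch just
    # records which dtype the module was built for.
    return {"dtype": dtype, "built": True}

m1 = get_cast_module("bfloat16")
m2 = get_cast_module("bfloat16")
assert m1 is m2                              # same cached object per dtype
assert get_cast_module("float16") is not m1  # distinct module per dtype
```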

Accuracy Tests

python -m pytest python/sglang/jit_kernel/tests/test_cast.py -v
[screenshot: pytest output]

Benchmarking and Profiling

python python/sglang/jit_kernel/benchmark/bench_cast.py
[screenshot: benchmark output]

KV Cache FP8 Downcast Kernel Optimization

Technical Changes

Two independent optimizations applied to the downcast_fp8 kernel (cast.cuh), migrated from AOT to the JIT kernel framework:

Optimization 1 — Vectorized Memory Access (AlignedVector + tile::Memory)

Replaced scalar element-by-element memory access with 128-bit vectorized loads/stores:

  • kVecSize = 16 / sizeof(T) — 8 elements per vector for bf16/fp16
  • Input: LDG.128 (16 bytes/thread) via AlignedVector<T, kVecSize>
  • Output: STG.64 (8 bytes/thread) via AlignedVector<uint8_t, kVecSize>
  • tile::Memory<vec_t>::cta() drives the strided loop; #pragma unroll on the inner 8-element conversion

Before:

```cpp
// 1 element per iteration, LDG.32
for (int i = thread_idx; i < head * dim; i += total_threads)
  output_k[out_idx] = ConvertToFP8(cache_k[in_idx] * scale_inv);
```

After:

```cpp
// 8 elements per iteration, LDG.128
const auto gmem_in = tile::Memory<vec_t>::cta();
for (int i = 0; gmem_in.in_bound(num_vecs, i); i++) {
  vec_t k_vec = gmem_in.load(in_k_base, i);
  // unrolled fp8 conversion ...
  gmem_out.store(out_k_base, out_k, i);
}
```
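
The vector-width arithmetic behind this change can be sanity-checked off-GPU. A sketch only: `k_vec_size` mirrors the `kVecSize = 16 / sizeof(T)` expression, and the numpy reshape stands in for `AlignedVector` loads; the names are mine.

```python
import numpy as np

BYTES_PER_LOAD = 16                       # one LDG.128 transaction

def k_vec_size(itemsize):                 # mirrors kVecSize = 16 / sizeof(T)
    return BYTES_PER_LOAD // itemsize

# bf16/fp16 are 2 bytes wide, so each 128-bit load carries 8 elements.
assert k_vec_size(np.dtype(np.float16).itemsize) == 8

# Emulate the vectorized loop: view the data as 8-element chunks and
# apply the fused scale + clamp chunk by chunk.
data = np.arange(32, dtype=np.float16)
vecs = data.reshape(-1, 8)                # 4 "vector loads" of 8 elements
out = np.clip(vecs.astype(np.float32) * 0.5, -448.0, 448.0)
assert out.shape == (4, 8)
```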
[screenshot]

Optimization 2 — Fixed 256-Thread Blocks + 2D Grid Scaling

Replaced variable small block size with fixed 256 threads and scaled work via grid.y:

  • Before: grid=(input_sl,), block=d/kVecSize (e.g. 16 threads = 0.5 warps for d=128)
  • After: grid=(input_sl, ceil(num_vecs/256)), block=256 (8 warps)
  • Each thread handles exactly one float4 — no loop, pure grid scaling
  • vec_idx = blockIdx.y * 256 + threadIdx.x addresses elements directly

```cpp
// No loop — one load → compute → store per thread
const int vec_idx = blockIdx.y * kBlockSize + threadIdx.x;
if (token_idx >= input_sl || vec_idx >= num_vecs) return;
k_vec.load(in_k_base, vec_idx);
// ... convert ...
out_k.store(out_k_base, vec_idx);
```

Eight warps per block give the warp scheduler enough eligible warps to fully hide HBM latency (~400 cycles on H200).
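
The indexing scheme can be checked exhaustively off-GPU. A hedged Python model: `covered_vec_indices` is my name, not the kernel's, but the 256-thread constant and the bounds guard mirror the snippet above.

```python
import math

K_BLOCK_SIZE = 256  # fixed block size (8 warps)

def covered_vec_indices(num_vecs):
    """Enumerate vec_idx exactly as the 2D grid launch would."""
    grid_y = math.ceil(num_vecs / K_BLOCK_SIZE)   # grid.y scales the work
    seen = []
    for block_y in range(grid_y):
        for thread_x in range(K_BLOCK_SIZE):
            vec_idx = block_y * K_BLOCK_SIZE + thread_x
            if vec_idx >= num_vecs:               # same bounds guard as the kernel
                continue
            seen.append(vec_idx)
    return seen

# Every vector is visited exactly once, with no per-thread loop.
assert covered_vec_indices(1000) == list(range(1000))
```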

[screenshots]

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.


@Johnsonms Johnsonms changed the title [jit_kernel] Migrate downcast_fp8 from sgl-kernel AOT to JIT [jit_kernel] Migrate cast (downcast_fp8) from sgl-kernel AOT to JIT Feb 21, 2026
@Johnsonms Johnsonms force-pushed the jit-kernel-cast branch 3 times, most recently from 14b1bff to bb53759 on February 22, 2026 21:10
@Johnsonms Johnsonms requested a review from BBuf February 23, 2026 04:00
@Johnsonms
Contributor Author

Using restrict as an optimization:
[screenshot]
Conclusion: restrict gives a noticeable improvement at larger shapes with more heads (h=32), where the compiler can now eliminate redundant reloads of k_scale[0]/v_scale[0] across the many threads operating on the same token. Small shapes are too fast to measure a difference (kernel is latency-bound, not bandwidth-bound there). The gain is consistent and real — up to ~17% on large shapes.

Replace pytorch_extension_utils.h with the specific ATen/cuda/CUDAContext.h header to avoid redefinition conflicts (check_shape, pack_u16, is_float8_tensor) between sgl-kernel's utils.h and FlashInfer's pytorch_extension_utils.h.
@BBuf BBuf merged commit 8a56a7b into sgl-project:main Mar 27, 2026
60 of 86 checks passed
@hnyls2002
Collaborator

hnyls2002 commented Mar 27, 2026

This PR (and #19059) introduced test files under python/sglang/jit_kernel/tests/ without CI registry decorators (register_cuda_ci), which breaks run_suite.py collection since #21239 added strict sanity checks for that directory.

Related:

@BBuf @HaiShaw @Johnsonms

@BBuf
Collaborator

BBuf commented Mar 27, 2026

> This PR (and #19059) introduced test files under python/sglang/jit_kernel/tests/ without CI registry decorators (register_cuda_ci), which breaks run_suite.py collection since #21239 added strict sanity checks for that directory.
>
> Related:
>
> * Issue: [[Bug] Unregistered jit_kernel test file blocks CI #21538](https://github.com/sgl-project/sglang/issues/21538)
> * Other missing-registry PR: [[jit_kernel] Add fused_qknorm_rope JIT kernel #19059](https://github.com/sgl-project/sglang/pull/19059)
> * Fix PR: [[CI] Register jit_kernel Test Files to Solve No Registry Found Error #21541](https://github.com/sgl-project/sglang/pull/21541) (though it registers all tests as `disabled`, so they won't actually run)
>
> @BBuf @HaiShaw @Johnsonms

#21547
