[jit_kernel] Migrate cast (downcast_fp8) from sgl-kernel AOT to JIT#19103
Merged
BBuf merged 20 commits into sgl-project:main from Mar 27, 2026
Conversation
BBuf reviewed Feb 22, 2026
BBuf approved these changes Mar 14, 2026
Replaced pytorch_extension_utils.h with the specific ATen/cuda/CUDAContext.h header to avoid redefinition conflicts (check_shape, pack_u16, is_float8_tensor) between sgl-kernel's utils.h and FlashInfer's pytorch_extension_utils.h.
Collaborator
This PR (and #19059) introduced test files under Related:
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026

Motivation
#17865
downcast_fp8 is a fused kernel that casts KV cache tensors from bf16/fp16 to fp8 (E4M3), scaling and clamping values in a single GPU pass. Migrating it from the AOT sgl-kernel build to the JIT framework reduces build complexity and aligns with the ongoing effort to move lightweight kernels to JIT.
Modifications
- Input dtypes bf16/fp16; input validation via TVM FFI TensorMatcher with a pre-configured SymbolicDevice
- Touched files: common_extension.cc, sgl_kernel_ops.h, sgl_kernel/elementwise.py, and sgl_kernel/init.py
Accuracy Tests
`python -m pytest python/sglang/jit_kernel/tests/test_cast.py -v`
Benchmarking and Profiling
`python python/sglang/jit_kernel/benchmark/bench_cast.py`
KV Cache FP8 Downcast Kernel Optimization
Technical Changes
Two independent optimizations applied to the downcast_fp8 kernel (cast.cuh), migrated from AOT to the JIT kernel framework:
Optimization 1 — Vectorized Memory Access (AlignedVector + tile::Memory)
Replaced scalar element-by-element memory access with 128-bit vectorized loads/stores:
Before:

```cuda
// 1 element per iteration, LDG.32
for (int i = thread_idx; i < head * dim; i += total_threads)
  output_k[out_idx] = ConvertToFP8(cache_k[in_idx] * scale_inv);
```
After:

```cuda
// 8 elements per iteration, LDG.128
const auto gmem_in = tile::Memory<vec_t>::cta();
for (int i = 0; gmem_in.in_bound(num_vecs, i); i++) {
  vec_t k_vec = gmem_in.load(in_k_base, i);
  // unrolled fp8 conversion ...
  gmem_out.store(out_k_base, out_k, i);
}
```
Optimization 2 — Fixed 256-Thread Blocks + 2D Grid Scaling
Replaced the variable small block size with a fixed 256 threads, scaling work via grid.y:

```cuda
// No loop — one load → compute → store per thread
const int vec_idx = blockIdx.y * kBlockSize + threadIdx.x;
if (token_idx >= input_sl || vec_idx >= num_vecs) return;
k_vec.load(in_k_base, vec_idx);
// ... convert ...
out_k.store(out_k_base, vec_idx);
```
8 warps/block gives the warp scheduler sufficient warps to fully hide HBM latency (~400 cycles on H200).
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci