-
Notifications
You must be signed in to change notification settings - Fork 5.3k
[jit_kernel] Migrate cast (downcast_fp8) from sgl-kernel AOT to JIT #19103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
2edb0c2
[jit_kernel] Add downcast_fp8 JIT kernel (Phase 1)
Johnsonms 7f4999a
[jit_kernel] Migrate downcast_fp8 from sgl-kernel AOT to JIT (Phase 2)
Johnsonms d431152
style: apply code formatting to cast JIT kernel
Johnsonms 7fcf3cb
[jit_kernel] Optimize cast kernel: coalesced read and write with fixe…
Johnsonms cc80b9e
[jit_kernel] Rename input_sl to input_num_tokens in cast kernel
Johnsonms 9054279
[jit_kernel] Unify FP8 cast helpers into dtype_trait system
Johnsonms 0e876fa
[jit_kernel] Fix benchmark crash and rename variables in cast kernel
Johnsonms dc05df7
[jit_kernel] Add __restrict__ to cast kernel and fix bandwidth benchmark
Johnsonms 8da281b
[jit_kernel] Address review comments on cast kernel
Johnsonms 4053820
Merge branch 'main' into jit-kernel-cast
Johnsonms 37c3c0a
Merge branch 'main' into jit-kernel-cast
Johnsonms b16524d
Merge branch 'main' into jit-kernel-cast
Johnsonms 41547cb
Merge branch 'main' into jit-kernel-cast
Johnsonms fe1ad51
Merge branch 'main' into jit-kernel-cast
Johnsonms 4b8d668
Merge branch 'main' into jit-kernel-cast
Johnsonms 4be9d9b
[jit_kernel] Fix cast.cu build error due to conflicting headers
Johnsonms 20f0d36
[jit_kernel] Fix lint issues: reorder imports per isort rules
Johnsonms 583780b
[jit_kernel] Revert unrelated lint changes from cast.cu fix commit
Johnsonms cd210ee
Merge branch 'main' into jit-kernel-cast
Johnsonms 754ed28
Merge branch 'main' into jit-kernel-cast
Johnsonms File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| import torch | ||
| import triton | ||
| import triton.testing | ||
| from sgl_kernel import downcast_fp8 as downcast_fp8_aot | ||
|
|
||
| from sglang.jit_kernel.benchmark.utils import ( | ||
| DEFAULT_DEVICE, | ||
| get_benchmark_range, | ||
| run_benchmark, | ||
| ) | ||
| from sglang.jit_kernel.cast import downcast_fp8 as downcast_fp8_jit | ||
|
|
||
# Device and input dtype used by every benchmark run.
DEVICE = DEFAULT_DEVICE
DTYPE = torch.bfloat16


# ── Config ranges ──────────────────────────────────────────────────────────────

# Sequence lengths to sweep; ci_range is the cheap subset used in CI.
SL_LIST = get_benchmark_range(
    full_range=[4, 16, 64, 256, 512, 1024, 2048],
    ci_range=[4, 64],
)

# (num_heads, head_dim) pairs to sweep.
HEAD_DIM_LIST = get_benchmark_range(
    full_range=[(8, 128), (32, 128), (8, 256), (32, 256)],
    ci_range=[(8, 128)],
)

# Each config is (input_sl, head, dim, out_sl).  out_sl = 2 * input_sl so the
# destination cache is larger than the source, exercising the `loc` scatter.
CONFIGS = [(sl, h, d, sl * 2) for sl in SL_LIST for h, d in HEAD_DIM_LIST]

# Line styling for triton.testing.perf_report (one line per provider).
LINE_VALS = ["aot", "jit"]
LINE_NAMES = ["AOT (sgl-kernel)", "JIT (cast.cuh, 256 threads, 2D grid)"]
STYLES = [("blue", "--"), ("orange", "-")]
|
|
||
|
|
||
# ── Perf report ────────────────────────────────────────────────────────────────


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["input_sl", "head", "dim", "out_sl"],
        x_vals=CONFIGS,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="downcast-fp8-aot-vs-jit",
        args={},
    )
)
def benchmark(input_sl, head, dim, out_sl, provider):
    """Time one downcast_fp8 call (AOT or JIT provider) for a single config."""
    key_cache = torch.randn(input_sl, head, dim, dtype=DTYPE, device=DEVICE)
    value_cache = torch.randn(input_sl, head, dim, dtype=DTYPE, device=DEVICE)
    key_dst = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device=DEVICE)
    value_dst = torch.zeros(out_sl, head, dim, dtype=torch.uint8, device=DEVICE)
    key_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    value_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    dst_loc = torch.arange(input_sl, dtype=torch.int64, device=DEVICE)

    # Select the implementation under test; both share the same call signature.
    impl = downcast_fp8_aot if provider == "aot" else downcast_fp8_jit

    def run():
        impl(key_cache, value_cache, key_dst, value_dst, key_scale, value_scale, dst_loc)

    return run_benchmark(run)
|
|
||
|
|
||
# ── Bandwidth analysis ─────────────────────────────────────────────────────────


def _report_bandwidth(input_sl, head, dim, dtype):
    """Benchmark both providers for one config and print a latency/bandwidth line."""
    # Bytes moved per element: read k and v (one dtype element each) and write
    # the two fp8 outputs (one uint8 byte each).
    bytes_per_elem = torch.finfo(dtype).bits // 8
    total_bytes = input_sl * head * dim * (2 * bytes_per_elem + 2)

    shape = (input_sl, head, dim)
    k = torch.randn(*shape, dtype=dtype, device=DEVICE)
    v = torch.randn(*shape, dtype=dtype, device=DEVICE)
    k_out = torch.zeros(input_sl * 2, head, dim, dtype=torch.uint8, device=DEVICE)
    v_out = torch.zeros(input_sl * 2, head, dim, dtype=torch.uint8, device=DEVICE)
    k_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    v_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    loc = torch.arange(input_sl, dtype=torch.int64, device=DEVICE)

    # Median time in ms for each provider (quantiles: median, p20, p80).
    timings = {}
    for name, impl in (("aot", downcast_fp8_aot), ("jit", downcast_fp8_jit)):
        median_ms, _, _ = triton.testing.do_bench(
            lambda: impl(k, v, k_out, v_out, k_scale, v_scale, loc),
            quantiles=[0.5, 0.2, 0.8],
        )
        timings[name] = median_ms

    def fmt(ms):
        gbps = total_bytes / (ms * 1e-3) / 1e9
        return f"{ms*1000:6.2f}us {gbps:6.0f}GB/s"

    print(
        f" sl={input_sl:5d} h={head:2d} d={dim:4d}"
        f" | aot {fmt(timings['aot'])}"
        f" | jit {fmt(timings['jit'])}"
        f" | speedup {timings['aot']/timings['jit']:.2f}x"
    )
|
|
||
|
|
||
def report_bandwidth():
    """Print a plain-text bandwidth comparison table over a fixed config grid."""
    rule = "=" * 95
    print("\n" + rule)
    print(" AOT (sgl-kernel) vs JIT (cast.cuh, 256 threads, 2D grid)")
    print(f" dtype={DTYPE}, device={DEVICE}")
    print(rule)
    seq_lens = [64, 256, 1024, 2048]
    head_dims = [(8, 128), (32, 128), (8, 256), (32, 256)]
    for seq_len in seq_lens:
        for num_heads, head_dim in head_dims:
            _report_bandwidth(seq_len, num_heads, head_dim, DTYPE)
        # Blank line between sequence-length groups.
        print()
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| benchmark.run(print_data=True) | ||
| report_bandwidth() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args | ||
|
|
||
| if TYPE_CHECKING: | ||
| from tvm_ffi.module import Module | ||
|
|
||
|
|
||
@cache_once
def _jit_cast_module(dtype: torch.dtype) -> Module:
    """Build (once per dtype, via @cache_once) the JIT-compiled cast module."""
    cpp_args = make_cpp_args(dtype)
    module = load_jit(
        "cast",
        *cpp_args,
        cuda_files=["elementwise/cast.cuh"],
        cuda_wrappers=[("downcast_fp8", f"downcast_fp8<{cpp_args}>")],
    )
    return module
|
|
||
|
|
||
def downcast_fp8(
    k: torch.Tensor,
    v: torch.Tensor,
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    loc: torch.Tensor,
    mult: int = 1,
    offset: int = 0,
) -> None:
    """Fused downcast of KV cache tensors from bf16/fp16 to fp8 (E4M3).

    Each value is multiplied by the inverse of its per-tensor scale, clamped
    to the fp8 E4M3 representable range [-448, 448], and stored as fp8.

    Args:
        k: [input_sl, head, dim] bf16/fp16 CUDA tensor
        v: [input_sl, head, dim] bf16/fp16 CUDA tensor
        k_out: [out_sl, head, dim] uint8 CUDA tensor (fp8 storage)
        v_out: [out_sl, head, dim] uint8 CUDA tensor (fp8 storage)
        k_scale: [1] float32 CUDA tensor, scale for k
        v_scale: [1] float32 CUDA tensor, scale for v
        loc: [input_sl] int64 CUDA tensor, destination sequence indices
        mult: stride multiplier for output index (default 1)
        offset: offset added to output index (default 0)
    """
    # The module is compiled lazily and cached per input dtype.
    module = _jit_cast_module(k.dtype)
    module.downcast_fp8(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| #pragma once | ||
|
|
||
| // Optimized cast kernel: fixed 256 threads, scaled out via 2D grid. | ||
| // Each thread handles exactly one float4 (kVecSize fp16/bf16 elements). | ||
| // No per-thread loop — pure grid scaling for any head*dim. | ||
|
|
||
| #include <sgl_kernel/tensor.h> | ||
| #include <sgl_kernel/utils.h> | ||
|
|
||
| #include <sgl_kernel/type.cuh> // For dtype_trait fp8 specialization | ||
| #include <sgl_kernel/utils.cuh> // For LaunchKernel | ||
| #include <sgl_kernel/vec.cuh> // For AlignedVector | ||
|
|
||
| #include <dlpack/dlpack.h> | ||
| #include <tvm/ffi/container/tensor.h> | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| namespace { | ||
|
|
||
| constexpr int kBlockSize = 256; | ||
|
|
||
| template <typename T> | ||
| __global__ void fused_downcast_kernel( | ||
| const T* __restrict__ cache_k, | ||
| const T* __restrict__ cache_v, | ||
| const float* __restrict__ k_scale, | ||
| const float* __restrict__ v_scale, | ||
| fp8_e4m3_t* __restrict__ output_k, | ||
| fp8_e4m3_t* __restrict__ output_v, | ||
| const int input_num_tokens, | ||
| const int head, | ||
| const int dim, | ||
| const T max_fp8, | ||
| const T min_fp8, | ||
| const int64_t mult, | ||
| const int64_t offset, | ||
| const int64_t* __restrict__ loc) { | ||
| using namespace device; | ||
|
|
||
| constexpr int kVecSize = 16 / sizeof(T); | ||
| using vec_t = AlignedVector<T, kVecSize>; | ||
| using out_vec_t = AlignedVector<fp8_e4m3_t, kVecSize>; | ||
|
|
||
| const int token_idx = blockIdx.x; | ||
| const int vec_idx = blockIdx.y * kBlockSize + threadIdx.x; | ||
| const int num_vecs = head * dim / kVecSize; | ||
|
|
||
| if (token_idx >= input_num_tokens || vec_idx >= num_vecs) return; | ||
|
|
||
| T k_scale_inv = static_cast<T>(1.f) / cast<T>(k_scale[0]); | ||
| T v_scale_inv = static_cast<T>(1.f) / cast<T>(v_scale[0]); | ||
|
|
||
| auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); }; | ||
|
|
||
| const int out_seq_idx = loc[token_idx]; | ||
| const T* in_k_base = cache_k + token_idx * head * dim; | ||
| const T* in_v_base = cache_v + token_idx * head * dim; | ||
| fp8_e4m3_t* out_k_base = output_k + (out_seq_idx * mult + offset) * head * dim; | ||
| fp8_e4m3_t* out_v_base = output_v + (out_seq_idx * mult + offset) * head * dim; | ||
|
|
||
| vec_t k_vec, v_vec; | ||
| k_vec.load(in_k_base, vec_idx); | ||
| v_vec.load(in_v_base, vec_idx); | ||
|
|
||
| out_vec_t out_k, out_v; | ||
| #pragma unroll | ||
| for (int j = 0; j < kVecSize; j++) { | ||
| out_k[j] = cast<fp8_e4m3_t>(clamp(k_vec[j] * k_scale_inv)); | ||
| out_v[j] = cast<fp8_e4m3_t>(clamp(v_vec[j] * v_scale_inv)); | ||
| } | ||
|
|
||
| out_k.store(out_k_base, vec_idx); | ||
| out_v.store(out_v_base, vec_idx); | ||
| } | ||
|
|
||
// Host-side entry point exposed to the JIT wrapper as downcast_fp8<T>.
// Validates tensor shapes/dtypes/devices, then launches fused_downcast_kernel
// over a (num_tokens, ceil(num_vecs / kBlockSize)) grid of 256-thread blocks.
template <typename T>
void downcast_fp8(
    tvm::ffi::TensorView k,
    tvm::ffi::TensorView v,
    tvm::ffi::TensorView k_out,
    tvm::ffi::TensorView v_out,
    tvm::ffi::TensorView k_scale,
    tvm::ffi::TensorView v_scale,
    tvm::ffi::TensorView loc,
    int64_t mult,
    int64_t offset) {
  using namespace host;

  // Symbolic sizes are bound on first verify() and must match on later ones,
  // so k/v share one shape and k_out/v_out share another (same head/dim).
  auto input_num_tokens = SymbolicSize{"input_num_tokens"};
  auto head = SymbolicSize{"head"};
  auto dim = SymbolicSize{"dim"};
  auto output_num_tokens = SymbolicSize{"out_sl"};
  auto device = SymbolicDevice{};
  device.set_options<kDLCUDA>();

  TensorMatcher({input_num_tokens, head, dim}).with_dtype<T>().with_device(device).verify(k);
  TensorMatcher({input_num_tokens, head, dim}).with_dtype<T>().with_device(device).verify(v);
  // fp8 outputs are stored as uint8.
  TensorMatcher({output_num_tokens, head, dim}).with_dtype<uint8_t>().with_device(device).verify(k_out);
  TensorMatcher({output_num_tokens, head, dim}).with_dtype<uint8_t>().with_device(device).verify(v_out);
  TensorMatcher({1}).with_dtype<float>().with_device(device).verify(k_scale);
  TensorMatcher({1}).with_dtype<float>().with_device(device).verify(v_scale);
  TensorMatcher({input_num_tokens}).with_dtype<int64_t>().with_device(device).verify(loc);

  const int num_tokens = static_cast<int>(input_num_tokens.unwrap());
  const int h = static_cast<int>(head.unwrap());
  const int d = static_cast<int>(dim.unwrap());

  // Each thread handles one 16-byte vector (kVecSize elements of T).
  // NOTE(review): num_vecs truncates when (h * d) % kVecSize != 0, so any
  // remainder elements would be silently skipped — confirm callers guarantee
  // head * dim is a multiple of kVecSize (128-bit alignment).
  constexpr int kVecSize = 16 / sizeof(T);
  const int num_vecs = h * d / kVecSize;
  const int grid_y = (num_vecs + kBlockSize - 1) / kBlockSize;

  // 2D grid: x = token index, y = vector chunk within the token's row.
  dim3 grid(num_tokens, grid_y);
  dim3 block(kBlockSize);

  // fp8 E4M3 representable range, pre-converted to T for in-kernel clamping.
  const T max_fp8 = static_cast<T>(kFP8E4M3Max);
  const T min_fp8 = static_cast<T>(-kFP8E4M3Max);

  LaunchKernel(grid, block, device.unwrap())(
      fused_downcast_kernel<T>,
      static_cast<const T*>(k.data_ptr()),
      static_cast<const T*>(v.data_ptr()),
      static_cast<const float*>(k_scale.data_ptr()),
      static_cast<const float*>(v_scale.data_ptr()),
      static_cast<fp8_e4m3_t*>(k_out.data_ptr()),
      static_cast<fp8_e4m3_t*>(v_out.data_ptr()),
      num_tokens,
      h,
      d,
      max_fp8,
      min_fp8,
      mult,
      offset,
      static_cast<const int64_t*>(loc.data_ptr()));
}

}  // namespace
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.