diff --git a/python/sglang/jit_kernel/.clang-format b/python/sglang/jit_kernel/.clang-format index 75fe1387c84a..56acfb8b8f5c 100644 --- a/python/sglang/jit_kernel/.clang-format +++ b/python/sglang/jit_kernel/.clang-format @@ -15,5 +15,11 @@ PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name IncludeCategories: - - Regex: '^$' + - Regex: '^$' Priority: 0 + - Regex: '^$' + Priority: 2 + - Regex: '^$' + Priority: 1 + - Regex: '^<.*/.*>$' + Priority: 3 diff --git a/python/sglang/jit_kernel/benchmark/bench_per_tensor_quant_fp8.py b/python/sglang/jit_kernel/benchmark/bench_per_tensor_quant_fp8.py index 870057a22c37..8c19cb4b7eed 100644 --- a/python/sglang/jit_kernel/benchmark/bench_per_tensor_quant_fp8.py +++ b/python/sglang/jit_kernel/benchmark/bench_per_tensor_quant_fp8.py @@ -1,4 +1,3 @@ -import itertools import os from typing import Optional, Tuple @@ -57,7 +56,7 @@ def sglang_scaled_fp8_quant( def calculate_diff(batch_size: int, seq_len: int): device = torch.device("cuda") - x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device) + x = torch.rand((batch_size, seq_len), dtype=torch.bfloat16, device=device) if not VLLM_AVAILABLE: print("vLLM not available, skipping comparison") @@ -66,25 +65,17 @@ def calculate_diff(batch_size: int, seq_len: int): vllm_out, vllm_scale = vllm_scaled_fp8_quant(x) sglang_out, sglang_scale = sglang_scaled_fp8_quant(x) - scale_diff = torch.abs(vllm_scale - sglang_scale).item() - output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item() + vllm_out = vllm_out.to(torch.float32) + sglang_out = sglang_out.to(torch.float32) - if torch.allclose( - vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5 - ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5): - print("All implementations match") - else: - print("Implementations differ") + 
triton.testing.assert_close(vllm_out, sglang_out, rtol=1e-3, atol=1e-3) + triton.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3) if IS_CI: - batch_size_range = [16] - seq_len_range = [64] + element_range = [16384] else: - batch_size_range = [16, 32, 64, 128] - seq_len_range = [64, 128, 256, 512, 1024, 2048] - -configs = list(itertools.product(batch_size_range, seq_len_range)) + element_range = [2**n for n in range(10, 20)] if VLLM_AVAILABLE: @@ -99,8 +90,8 @@ def calculate_diff(batch_size: int, seq_len: int): @triton.testing.perf_report( triton.testing.Benchmark( - x_names=["batch_size", "seq_len"], - x_vals=configs, + x_names=["element_count"], + x_vals=element_range, line_arg="provider", line_vals=line_vals, line_names=line_names, @@ -110,11 +101,11 @@ def calculate_diff(batch_size: int, seq_len: int): args={}, ) ) -def benchmark(batch_size, seq_len, provider): +def benchmark(element_count, provider): dtype = torch.float16 device = torch.device("cuda") - x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + x = torch.randn(element_count, 4096, device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] diff --git a/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py b/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py new file mode 100644 index 000000000000..6d2a1482493a --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py @@ -0,0 +1,96 @@ +import itertools +import os + +import torch +import triton +import triton.testing +from flashinfer import rmsnorm as fi_rmsnorm +from sgl_kernel import rmsnorm + +from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + + +def sglang_aot_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, +) -> None: + rmsnorm(input, weight, out=input) + + +def sglang_jit_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, +) -> None: + jit_rmsnorm(input, 
weight, output=input) + + +def flashinfer_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, +) -> None: + fi_rmsnorm(input, weight, out=input) + + +@torch.compile() +def torch_impl_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, +) -> None: + mean = input.float().pow(2).mean(dim=-1, keepdim=True) + norm = (mean + eps).rsqrt() + input.copy_(input.float() * norm * weight.float()) + + +DTYPE = torch.bfloat16 +DEVICE = "cuda" + +if IS_CI: + BS_LIST = [16] + HIDDEN_SIZE_LIST = [512, 2048] +else: + BS_LIST = [2**n for n in range(0, 14)] + HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192] + +LINE_VALS = ["aot", "jit", "fi", "torch"] +LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "FlashInfer", "PyTorch"] +STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] + +configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="rmsnorm-performance", + args={}, + ) +) +def benchmark(hidden_size: int, batch_size: int, provider: str): + input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) + weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) + FN_MAP = { + "aot": sglang_aot_rmsnorm, + "jit": sglang_jit_rmsnorm, + "fi": flashinfer_rmsnorm, + "torch": torch_impl_rmsnorm, + } + fn = lambda: FN_MAP[provider](input.clone(), weight) + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) # type: ignore + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/csrc/add_constant.cuh b/python/sglang/jit_kernel/csrc/add_constant.cuh index d4e5bed67a6f..6c723a761db2 100644 --- a/python/sglang/jit_kernel/csrc/add_constant.cuh 
+++ b/python/sglang/jit_kernel/csrc/add_constant.cuh @@ -1,6 +1,7 @@ -#include // For TensorMatcher, SymbolicSize, SymbolicDevice +#include // For TensorMatcher, SymbolicSize, SymbolicDevice +#include // For div_ceil, RuntimeCheck + #include // For LaunchKernel -#include // For div_ceil, RuntimeCheck #include #include diff --git a/python/sglang/jit_kernel/csrc/cuda_wait_value.cuh b/python/sglang/jit_kernel/csrc/cuda_wait_value.cuh index 5874ce1efbf4..49b274fb17d7 100644 --- a/python/sglang/jit_kernel/csrc/cuda_wait_value.cuh +++ b/python/sglang/jit_kernel/csrc/cuda_wait_value.cuh @@ -1,9 +1,9 @@ #include -#include -#include +#include #include +#include namespace { diff --git a/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh b/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh index 6ef424891164..e5a1336b0b66 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh @@ -1,8 +1,9 @@ #include -#include #include + +#include +#include #include -#include #include #include @@ -27,8 +28,17 @@ struct StoreKVCacheParams { constexpr uint32_t kNumWarps = 4; constexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads; +/** + * \brief Use a single warp to copy key and value data from source to destination. + * Each thread in the warp copies a portion of the data in a coalesced manner. + * \tparam kElementBytes The size of each key/value element in bytes. + * \param k_src Pointer to the source key data. + * \param v_src Pointer to the source value data. + * \param k_dst Pointer to the destination key data. + * \param v_dst Pointer to the destination value data. 
+ */ template -__device__ void copy_impl( +SGL_DEVICE void copy_kv_warp( const void* __restrict__ k_src, const void* __restrict__ v_src, void* __restrict__ k_dst, @@ -42,31 +52,39 @@ __device__ void copy_impl( static_assert(kAlignment > 0, "Element size must be multiple of 4 bytes"); - using vec_t = aligned_vector; + using vec_t = AlignedStorage; constexpr auto kLoopBytes = sizeof(vec_t) * kWarpThreads; constexpr auto kLoopCount = kElementBytes / kLoopBytes; + const auto gmem = tile::Memory::warp(); + #pragma unroll kLoopCount for (int64_t i = 0; i < kLoopCount; ++i) { - const auto k = warp::load(pointer::offset(k_src, i * kLoopBytes)); - const auto v = warp::load(pointer::offset(v_src, i * kLoopBytes)); - warp::store(pointer::offset(k_dst, i * kLoopBytes), k); - warp::store(pointer::offset(v_dst, i * kLoopBytes), v); + const auto k = gmem.load(k_src, i); + const auto v = gmem.load(v_src, i); + gmem.store(k_dst, k, i); + gmem.store(v_dst, v, i); } // handle the epilogue if any if constexpr (kLoopCount * kLoopBytes < kElementBytes) { - constexpr auto kOffset = kLoopCount * kLoopBytes; - if ((threadIdx.x % kWarpThreads) * sizeof(vec_t) < kElementBytes - kOffset) { - const auto k = warp::load(pointer::offset(k_src, kOffset)); - const auto v = warp::load(pointer::offset(v_src, kOffset)); - warp::store(pointer::offset(k_dst, kOffset), k); - warp::store(pointer::offset(v_dst, kOffset), v); + if (gmem.in_bound(kElementBytes / sizeof(vec_t), kLoopCount)) { + const auto k = gmem.load(k_src, kLoopCount); + const auto v = gmem.load(v_src, kLoopCount); + gmem.store(k_dst, k, kLoopCount); + gmem.store(v_dst, v, kLoopCount); } } } -// Each warp handles one item +/** + * \brief Kernel to store key-value pairs into the KV cache. + * Each element is split into multiple parts to allow parallel memory copy. + * \tparam kElementBytes The size of each key/value element in bytes. + * \tparam kSplit The number of warps that handle each element. 
+ * \tparam kUsePDL Whether to use PDL feature. + * \tparam T The data type of the indices (`int32_t` or `int64_t`). + */ template __global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) { using namespace device; @@ -89,7 +107,7 @@ __global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize); const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize); - copy_impl(k_src, v_src, k_dst, v_dst); + copy_kv_warp(k_src, v_src, k_dst, v_dst); PDLTriggerSecondary(); } diff --git a/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh b/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh index a44ecba4ca74..9e6cec8d2524 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/qknorm.cuh @@ -1,42 +1,22 @@ -#include #include -#include #include -#include -#include -#include -#include +#include +#include +#include + +#include + #include #include -#include #include +#include +#include #include namespace { -[[maybe_unused]] -__device__ auto to_float2(nv_bfloat162 x) -> float2 { - return __bfloat1622float2(x); -} - -[[maybe_unused]] -__device__ auto to_float2(half2 x) -> float2 { - return __half22float2(x); -} - -template -__device__ auto from_float2(float2 x) -> T { - if constexpr (std::is_same_v) { - return __float22bfloat162_rn(x); - } else if constexpr (std::is_same_v) { - return __float22half2_rn(x); - } else { - static_assert(sizeof(T) == 0, "Unsupported type"); - } -} - struct QKNormParams { void* __restrict__ q; void* __restrict__ k; // k is offset by (-num_qo_heads * head_dim) elements @@ -50,52 +30,15 @@ struct QKNormParams { uint32_t num_tokens; }; -template -__always_inline __device__ void apply_norm(void* __restrict__ input, const void* __restrict__ weight, float eps) { - using namespace device; - - constexpr std::size_t kLoopCount = kHeadDim / 
(kWarpThreads * 2); - static_assert(kHeadDim % (kWarpThreads * 2) == 0); - - float sum_of_squares = 0.0f; - - using vec_t = aligned_vector; - const auto input_vec = warp::load(input); - -#pragma unroll - for (auto i = 0u; i < kLoopCount; ++i) { - const auto fp16_input = input_vec[i]; - const auto fp32_input = to_float2(fp16_input); - sum_of_squares += fp32_input.x * fp32_input.x; - sum_of_squares += fp32_input.y * fp32_input.y; - } - - sum_of_squares = warp::reduce_sum(sum_of_squares); - const auto norm_factor = rsqrtf(sum_of_squares / kHeadDim + eps); - const auto weight_vec = warp::load(weight); - vec_t output_vec; - -#pragma unroll - for (auto i = 0u; i < kLoopCount; ++i) { - const auto fp32_input = to_float2(input_vec[i]); - const auto fp32_weight = to_float2(weight_vec[i]); - output_vec[i] = from_float2({ - fp32_input.x * norm_factor * fp32_weight.x, - fp32_input.y * norm_factor * fp32_weight.y, - }); - } - - warp::store(input, output_vec); -} - constexpr uint32_t kWarpsPerBlock = 4; constexpr uint32_t kThreadsPerBlock = kWarpsPerBlock * device::kWarpThreads; -template +template __global__ void fused_qknorm(const QKNormParams __grid_constant__ params) { using namespace device; + using Storage = norm::StorageType; - static_assert(sizeof(Float) == 2 && sizeof(PackedFloat) == 4, "Only support FP16/BF16"); + static_assert(sizeof(Float) == 2, "Only support FP16/BF16"); const auto& [q, k, q_stride, k_stride, num_qo_heads, num_kv_heads, eps, q_weight, k_weight, num_tokens] = params; const auto num_blks = gridDim.x; @@ -103,6 +46,7 @@ __global__ void fused_qknorm(const QKNormParams __grid_constant__ params) { const auto num_q_and_k_heads = num_qo_heads + num_kv_heads; const auto num_works = num_q_and_k_heads * num_tokens; const auto start_worker_id = blockIdx.x * kWarpsPerBlock + threadIdx.x / kWarpThreads; + const auto gmem = tile::Memory::warp(); PDLWaitPrimary(); // wait for primary kernel @@ -113,7 +57,10 @@ __global__ void fused_qknorm(const QKNormParams 
__grid_constant__ params) { const auto input = load_q ? pointer::offset(q, 2 * (token_id * q_stride + head_id * kHeadDim)) : pointer::offset(k, 2 * (token_id * k_stride + head_id * kHeadDim)); const auto weight = load_q ? q_weight : k_weight; - apply_norm(input, weight, eps); + const auto input_vec = gmem.load(input); + const auto weight_vec = gmem.load(weight); + const auto output_vec = norm::apply_norm_warp(input_vec, weight_vec, eps); + gmem.store(input, output_vec); } PDLTriggerSecondary(); // launch secondary kernel @@ -121,13 +68,9 @@ __global__ void fused_qknorm(const QKNormParams __grid_constant__ params) { template struct QKNormKernel { - static_assert( - std::is_same_v || std::is_same_v, - "Unsupported DType: QKNormKernel only supports half and nv_bfloat16."); - using DType2 = host::PackedDType::type; - - // only initialize once (static variable) to avoid overhead - static constexpr auto kernel = fused_qknorm; + static_assert(std::is_same_v || std::is_same_v); + static_assert(!host::norm::should_use_cta(), "Head dim too large for QKNorm"); + static constexpr auto kernel = fused_qknorm; static void run(const tvm::ffi::TensorView q, @@ -143,38 +86,29 @@ struct QKNormKernel { auto D = SymbolicSize{"head_dim"}; auto Sq = SymbolicSize{"q_stride"}; auto Sk = SymbolicSize{"k_stride"}; - auto dtype = SymbolicDType{}; auto device = SymbolicDevice{}; + D.set_value(kHeadDim); + device.set_options(); - /* - * We need the .template disambiguator here because this call happens in a dependent context. - * After switching to with_dtype(...) (where DType is a template parameter), the chained expression becomes - * dependent. In C++, when calling a member function template via ./-> on a dependent expression, the compiler may - * otherwise parse as the < operator instead of template arguments. Adding .template forces correct - * parsing and fixes compilation errors (often seen with NVCC/clang). 
Ref: - * https://en.cppreference.com/w/cpp/language/dependent_name - */ TensorMatcher({N, Q, D}) // q input .with_strides({Sq, D, 1}) - .with_dtype(dtype) - .template with_device(device) + .with_dtype() + .with_device(device) .verify(q); TensorMatcher({N, K, D}) // k input .with_strides({Sk, D, 1}) - .with_dtype(dtype) - .template with_device(device) + .with_dtype() + .with_device(device) .verify(k); TensorMatcher({D}) // weight - .with_dtype(dtype) - .template with_device(device) + .with_dtype() + .with_device(device) .verify(q_weight) .verify(k_weight); const auto num_tokens = static_cast(N.unwrap()); const auto num_qo_heads = static_cast(Q.unwrap()); const auto num_kv_heads = static_cast(K.unwrap()); - const auto head_dim = D.unwrap(); - RuntimeCheck(head_dim == kHeadDim, "Wrong head_dim: ", head_dim, ". Expected:", kHeadDim); // NOTE: we offset the k here to reduce computation cost in the kernel const auto params = QKNormParams{ diff --git a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh new file mode 100644 index 000000000000..aadcc495f51e --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh @@ -0,0 +1,109 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +namespace { + +struct RMSNormParams { + const void* input; + const void* __restrict__ weight; + void* output; + int64_t input_stride; + int64_t output_stride; + uint32_t num_tokens; + float eps; +}; + +template +__global__ void rmsnorm_cta(const RMSNormParams __grid_constant__ params) { + using namespace device; + using Storage = norm::StorageType; + + constexpr auto kNumThreads = host::norm::get_cta_threads(); + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + + const auto& [input, weight_ptr, output, input_stride, output_stride, num_tokens, eps] = params; + const auto gmem = tile::Memory::cta(kNumThreads); + __shared__ float smem[norm::kSmemBufferSize]; + + PDLWaitPrimary(); // wait 
for primary kernel + + void* output_ptr = nullptr; + Storage output_vec; + for (uint32_t i = blockIdx.x; i < num_tokens; i += gridDim.x) { + const auto input_ptr = pointer::offset(input, i * input_stride); + const auto input_vec = gmem.load(input_ptr); + const auto weight_vec = gmem.load(weight_ptr); + if (output_ptr != nullptr) { + gmem.store(output_ptr, output_vec); + } + output_ptr = pointer::offset(output, i * output_stride); + output_vec = norm::apply_norm_cta(input_vec, weight_vec, eps, smem, kNumWarps); + } + gmem.store(output_ptr, output_vec); + + PDLTriggerSecondary(); // launch secondary kernel +} + +template +struct RMSNormKernel { + static_assert(host::norm::should_use_cta(), "Hidden size invalid for RMSNorm"); + static constexpr auto kernel = rmsnorm_cta; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView output, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto SI = SymbolicSize{"input_stride"}; + auto SO = SymbolicSize{"output_stride"}; + auto device = SymbolicDevice{}; + D.set_value(kDim); + device.set_options(); + + TensorMatcher({N, D}) // input + .with_strides({SI, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({D}) // weight + .with_dtype() + .with_device(device) + .verify(weight); + TensorMatcher({N, D}) // output + .with_strides({SO, 1}) + .with_dtype() + .with_device(device) + .verify(output); + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = RMSNormParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .output = output.data_ptr(), + .input_stride = SI.unwrap(), + .output_stride = SO.unwrap(), + .num_tokens = num_tokens, + .eps = eps, + }; + + static constexpr auto kNumThreads = norm::get_cta_threads(); + static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kNumThreads); + static const uint32_t kNumSM = 
runtime::get_sm_count(device.unwrap().device_id); + const auto num_blocks = std::min(num_tokens, max_occupancy * kNumSM); + LaunchKernel(num_blocks, kNumThreads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh b/python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh index 3a0d677f849f..344648ffc8af 100644 --- a/python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh +++ b/python/sglang/jit_kernel/csrc/gemm/per_tensor_quant_fp8.cuh @@ -1,54 +1,62 @@ -#include #include -#include #include +#include +#include +#include +#include +#include +#include +#include + #include #include -#include -#include namespace { -using device::atomicMaxFloat; -using device::blockReduceMax; -using device::FP8_E4M3_MAX; +constexpr size_t kBlockSize = 256; +// each warp will handle 512B data template __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) { - float max_value = 0.0f; - unsigned int tid = threadIdx.x; - unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x; - const int grid_size = blockDim.x * gridDim.x; - - constexpr uint32_t vec_size = 16 / sizeof(T); - using vec_t = flashinfer::vec_t; - - const int32_t num_vec_elems = num_elements / vec_size; + using namespace device; + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); - for (int32_t i = gid; i < num_vec_elems; i += grid_size) { - vec_t input_vec; - input_vec.cast_load(input + i * vec_size); + const int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + float max_value = 0.0f; + if (gid * VEC_SIZE + VEC_SIZE <= num_elements) { + using vec_t = AlignedVector; + const auto gmem_in = tile::Memory::thread(); + const auto input_vec = gmem_in.load(input, gid); #pragma unroll - for (uint32_t j = 0; j < vec_size; ++j) { - float val = static_cast(input_vec[j]); - max_value = fmaxf(max_value, fabsf(val)); + for (uint32_t i = 0; i < VEC_SIZE; ++i) { 
+ const float value = static_cast(input_vec[i]); + max_value = math::max(max_value, math::abs(value)); + } + } else if (gid * VEC_SIZE < num_elements) { + [[unlikely]]; // poorly aligned case, do not optimize + const auto remainder = num_elements - gid * VEC_SIZE; + for (uint32_t i = 0; i < remainder; ++i) { + const float value = static_cast(input[gid * VEC_SIZE + i]); + max_value = math::max(max_value, math::abs(value)); } } - const int32_t remaining_start = num_vec_elems * vec_size; - for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) { - float val = static_cast(input[idx]); - max_value = fmaxf(max_value, fabsf(val)); + // reduce within block and then atomic reduce between blocks + __shared__ float smem[kWarpThreads]; + cta::reduce_max(max_value, smem); + if (threadIdx.x == 0) { + const auto max_value = smem[0]; + atomic::max(output_s, max_value / math::FP8_E4M3_MAX); } +} - max_value = blockReduceMax(max_value); - - if (tid == 0) { - atomicMaxFloat(output_s, max_value / FP8_E4M3_MAX); - } +[[maybe_unused]] +SGL_DEVICE float fp8_e4m3_clip(float val) { + namespace math = device::math; + return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX); } template @@ -57,123 +65,75 @@ __global__ void per_tensor_quant_fp8_kernel( DST_DTYPE* __restrict__ output, const float* __restrict__ scale, const int64_t num_elements) { - const int gid = blockIdx.x * blockDim.x + threadIdx.x; - const int grid_size = blockDim.x * gridDim.x; - const float scale_val = 1.0f / (*scale); - - const uint32_t VEC_SIZE = 16; - using vec_t = flashinfer::vec_t; + using namespace device; + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); - const int32_t num_vec_elems = num_elements / VEC_SIZE; - - for (int32_t i = gid; i < num_vec_elems; i += grid_size) { - vec_t input_vec; - input_vec.cast_load(input + i * VEC_SIZE); + const int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + const float scale_val = 1.0f / (*scale); - DST_DTYPE output_arr[VEC_SIZE]; + if 
(gid * VEC_SIZE + VEC_SIZE <= num_elements) { + using input_vec_t = AlignedVector; + using output_vec_t = AlignedVector; + const auto gmem_in = tile::Memory::thread(); + const auto gmem_out = tile::Memory::thread(); + const auto input_vec = gmem_in.load(input, gid); + output_vec_t output_vec; #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - float val = fmax(fmin(static_cast(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX); -#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3) - output_arr[j] = static_cast(val); -#else - output_arr[j] = c10::Float8_e4m3fnuz( - __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), - c10::Float8_e4m3fnuz::from_bits()); -#endif + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + const float value = fp8_e4m3_clip(static_cast(input_vec[i]) * scale_val); + output_vec[i] = static_cast(value); + } + gmem_out.store(output, output_vec, gid); + } else if (gid * VEC_SIZE < num_elements) { + [[unlikely]]; // poorly aligned case, do not optimize + const auto remainder = num_elements - gid * VEC_SIZE; + for (uint32_t i = 0; i < remainder; ++i) { + const float value = fp8_e4m3_clip(static_cast(input[gid * VEC_SIZE + i]) * scale_val); + output[gid * VEC_SIZE + i] = static_cast(value); } - *(uint4*)(output + i * VEC_SIZE) = *(uint4*)output_arr; - } - - const int32_t remaining_start = num_vec_elems * VEC_SIZE; - for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) { - float val = fmax(-FP8_E4M3_MAX, fmin(static_cast(input[idx]) * scale_val, FP8_E4M3_MAX)); -#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3) - output[idx] = static_cast(val); -#else - output[idx] = c10::Float8_e4m3fnuz( - __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), - c10::Float8_e4m3fnuz::from_bits()); -#endif } } -constexpr size_t kBlockSize = 256; - -template +template void per_tensor_quant_fp8(tvm::ffi::TensorView input, 
tvm::ffi::TensorView output_q, tvm::ffi::TensorView output_s) { using namespace host; - const DLDevice device = input.device(); - RuntimeCheck(device.device_type == kDLCUDA, "input must be on CUDA"); - RuntimeCheck(input.is_contiguous(), "input must be contiguous"); - - const int64_t ndim = input.dim(); - RuntimeCheck(ndim >= 1, "input.ndim must be >= 1, but got ", ndim); - - RuntimeCheck(output_q.device() == device, "output_q must be on the same device as input"); - RuntimeCheck(output_q.is_contiguous(), "output_q must be contiguous"); - RuntimeCheck(output_q.dim() == ndim, "output_q.ndim must match input.ndim"); - for (int64_t i = 0; i < ndim; ++i) { - RuntimeCheck( - output_q.size(i) == input.size(i), - "output_q.shape mismatch at dim ", - i, - ": expected ", - input.size(i), - " but got ", - output_q.size(i)); - } - + auto device = SymbolicDevice{}; + auto N = SymbolicSize{"num_elements"}; + device.set_options(); + + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(output_q); TensorMatcher({1}) // .with_dtype() - .with_device() + .with_device(device) .verify(output_s); - RuntimeCheck(output_s.device() == device, "output_s must be on the same device as input"); - - const DLDataType in_dtype = input.dtype(); - const bool in_ok = (in_dtype.code == kDLFloat && in_dtype.bits == 32) || - (in_dtype.code == kDLFloat && in_dtype.bits == 16) || - (in_dtype.code == kDLBfloat && in_dtype.bits == 16); - RuntimeCheck(in_ok, "input dtype must be fp32/fp16/bf16, but got ", in_dtype); - - const DLDataType out_dtype = output_q.dtype(); - RuntimeCheck( - out_dtype.code == kDLFloat8_e4m3fn && out_dtype.bits == 8, - "output_q dtype must be fp8_e4m3fn, but got ", - out_dtype); - - size_t total_elements = 1; - for (const auto s : input.shape()) { - RuntimeCheck(s > 0, "Input tensor must be non-empty"); - total_elements *= static_cast(s); - } - const size_t num_blocks = 
std::min((total_elements + kBlockSize - 1) / kBlockSize, size_t(1024)); - - auto launch_kernels = [&]() { - if constexpr (!kIsStatic) { - LaunchKernel(num_blocks, kBlockSize, device)( - per_tensor_absmax_kernel, - static_cast(input.data_ptr()), - static_cast(output_s.data_ptr()), - static_cast(total_elements)); - } - LaunchKernel(num_blocks, kBlockSize, device)( - per_tensor_quant_fp8_kernel, - static_cast(input.data_ptr()), - static_cast<__nv_fp8_e4m3*>(output_q.data_ptr()), - static_cast(output_s.data_ptr()), - static_cast(total_elements)); - }; - - if (in_dtype.code == kDLFloat && in_dtype.bits == 32) { - launch_kernels.template operator()(); - } else if (in_dtype.code == kDLBfloat && in_dtype.bits == 16) { - launch_kernels.template operator()<__nv_bfloat16>(); - } else if (in_dtype.code == kDLFloat && in_dtype.bits == 16) { - launch_kernels.template operator()<__half>(); + const auto num_elements = N.unwrap(); + + constexpr size_t kElementsPerBlock = kBlockSize * (16 / sizeof(DType)); + const uint32_t num_blocks = div_ceil(num_elements, kElementsPerBlock); + + if constexpr (!kIsStatic) { + LaunchKernel(num_blocks, kBlockSize, device.unwrap())( + per_tensor_absmax_kernel, + static_cast(input.data_ptr()), + static_cast(output_s.data_ptr()), + static_cast(num_elements)); } + + LaunchKernel(num_blocks, kBlockSize, device.unwrap())( + per_tensor_quant_fp8_kernel, + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + static_cast(num_elements)); } } // namespace diff --git a/python/sglang/jit_kernel/csrc/hicache.cuh b/python/sglang/jit_kernel/csrc/hicache.cuh index 230e4d513560..555282ee432d 100644 --- a/python/sglang/jit_kernel/csrc/hicache.cuh +++ b/python/sglang/jit_kernel/csrc/hicache.cuh @@ -1,7 +1,8 @@ #include -#include #include +#include + #include #include diff --git a/python/sglang/jit_kernel/csrc/test_utils.h b/python/sglang/jit_kernel/csrc/test_utils.h deleted file mode 100644 index 
acaf824b2bfa..000000000000 --- a/python/sglang/jit_kernel/csrc/test_utils.h +++ /dev/null @@ -1,22 +0,0 @@ -#include -#include - -#include -#include - -namespace { - -[[maybe_unused]] -void assert_same_shape(tvm::ffi::TensorView a, tvm::ffi::TensorView b) { - using namespace host; - auto N = SymbolicSize{"N"}; - auto D = SymbolicSize{"D"}; - TensorMatcher({N, D}) // - .with_dtype() - .with_device() - .verify(a) - .verify(b); - RuntimeCheck(N.unwrap() > 0 && D.unwrap() > 0); -} - -} // namespace diff --git a/python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh b/python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh new file mode 100644 index 000000000000..574f79b8d82a --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/atomic.cuh @@ -0,0 +1,23 @@ +#pragma once +#include + +namespace device::atomic { + +SGL_DEVICE float max(float* addr, float value) { +#ifndef USE_ROCM + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +#else + int* addr_as_i = (int*)addr; + int old = *addr_as_i, assumed; + do { + assumed = old; + old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +#endif +} + +} // namespace device::atomic diff --git a/python/sglang/jit_kernel/include/sgl_kernel/cta.cuh b/python/sglang/jit_kernel/include/sgl_kernel/cta.cuh new file mode 100644 index 000000000000..28db34f02595 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/cta.cuh @@ -0,0 +1,22 @@ +#pragma once +#include +#include +#include + +namespace device::cta { + +template +SGL_DEVICE void reduce_max(T value, float* smem, float min_value = 0.0f) { + const uint32_t warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = warp::reduce_max(value); + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_value = tx * 
kWarpThreads < blockDim.x ? smem[tx] : min_value; + const auto max_value = warp::reduce_max(local_value); + smem[0] = max_value; + } + // no extra sync; it is caller's responsibility to sync if needed +} + +} // namespace device::cta diff --git a/python/sglang/jit_kernel/include/sgl_kernel/fp8_utils.cuh b/python/sglang/jit_kernel/include/sgl_kernel/fp8_utils.cuh deleted file mode 100644 index 3989b9443118..000000000000 --- a/python/sglang/jit_kernel/include/sgl_kernel/fp8_utils.cuh +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#ifdef __CUDACC__ -#include -#include -#include -#endif - -namespace device { - -inline constexpr float FP8_E4M3_MAX = 448.0f; - -} // namespace device diff --git a/python/sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh b/python/sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh new file mode 100644 index 000000000000..cd024acd4604 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/impl/norm.cuh @@ -0,0 +1,168 @@ +#pragma once +#include +#include +#include +#include +#include + +#include +#include + +namespace host::norm { + +/** + * \brief Check if the given configuration is supported. + * \tparam T Element type (only fp16_t/bf16_t is supported) + * \tparam kDim Dimension size (usually hidden size) + */ +template +inline constexpr bool is_config_supported() { + if (!std::is_same_v && !std::is_same_v) return false; + if (kDim <= 256) { + return (kDim == 64 || kDim == 128 || kDim == 256); + } else { + return (kDim % 256 == 0 && kDim <= 8192); + } +} + +/** + * \brief Determine whether to use cta norm based on dimension size. + * TL;DR: use warp norm for dim <= 256, cta norm otherwise. + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \note This function assumes that the configuration is supported. 
+ * \see `is_config_supported` + */ +template +inline constexpr bool should_use_cta() { + static_assert(is_config_supported(), "Unsupported norm configuration"); + return kDim > 256; +} + +/** + * \brief Get the number of threads per CTA for cta norm. + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size (usually hidden size) + * \return Number of threads per CTA + */ +template +inline constexpr uint32_t get_cta_threads() { + static_assert(should_use_cta()); + return (kDim / 256) * device::kWarpThreads; +} + +} // namespace host::norm + +namespace device::norm { + +namespace details { + +template +SGL_DEVICE AlignedVector apply_norm_impl( + const AlignedVector input, + const AlignedVector weight, + const float eps, + [[maybe_unused]] float* smem_buffer, + [[maybe_unused]] uint32_t num_warps) { + float sum_of_squares = 0.0f; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + sum_of_squares += fp32_input.x * fp32_input.x; + sum_of_squares += fp32_input.y * fp32_input.y; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + float norm_factor; + if constexpr (kUseCTA) { + // need to synchronize across the cta + const auto warp_id = threadIdx.x / kWarpThreads; + smem_buffer[warp_id] = sum_of_squares; + __syncthreads(); + // use the first warp to reduce + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < num_warps ? 
smem_buffer[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem_buffer[32] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + norm_factor = smem_buffer[32]; + } else { + norm_factor = math::rsqrt(sum_of_squares / kDim + eps); + } + + AlignedVector output; + +#pragma unroll + for (auto i = 0u; i < N; ++i) { + const auto fp32_input = cast(input[i]); + const auto fp32_weight = cast(weight[i]); + output[i] = cast({ + fp32_input.x * norm_factor * fp32_weight.x, + fp32_input.y * norm_factor * fp32_weight.y, + }); + } + + return output; +} + +} // namespace details + +/** + * \brief Apply norm using warp-level implementation. + * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_warp(const T& input, const T& weight, float eps) { + static_assert(kDim <= 256, "Warp norm only supports dim <= 256"); + return details::apply_norm_impl(input, weight, eps, nullptr, 0); +} + +/** + * \brief Apply norm using CTA-level implementation. + * \tparam kDim Dimension size + * \tparam T Element type (fp16_t or bf16_t) + * \param input Input vector + * \param weight Weight vector + * \param eps Epsilon value for numerical stability + * \param smem Shared memory buffer + * \param num_warps Number of warps in the CTA + * \return Normalized output vector + */ +template +SGL_DEVICE T apply_norm_cta( + const T& input, const T& weight, float eps, float* smem, uint32_t num_warps = blockDim.x / kWarpThreads) { + static_assert(kDim > 256, "CTA norm only supports dim > 256"); + return details::apply_norm_impl(input, weight, eps, smem, num_warps); +} + +/** + * \brief Storage type for norm operation. + * For warp norm, the storage size depends on kDim. + * For cta norm, the storage size is fixed to 16B. 
+ * We will also pack the input 16-bit floats into 32-bit types + * for faster CUDA core operations. + * + * \tparam T Element type (fp16_t or bf16_t) + * \tparam kDim Dimension size + */ +template +using StorageType = std::conditional_t< // storage type + (kDim > 256), // whether to use cta norm + AlignedVector, 4>, // cta norm storage, fixed to 16B + AlignedVector, kDim / (2 * kWarpThreads)> // warp norm storage + >; + +/** + * \brief Minimum shared memory size for cta norm, counted in floats (33 floats = 132 bytes): + * the CTA reduction stores one float partial per warp and writes the result to smem[32]. + */ +inline constexpr uint32_t kSmemBufferSize = 33; + +} // namespace device::norm diff --git a/python/sglang/jit_kernel/include/sgl_kernel/math.cuh b/python/sglang/jit_kernel/include/sgl_kernel/math.cuh new file mode 100644 index 000000000000..3d4aa7473dc3 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/math.cuh @@ -0,0 +1,36 @@ +#pragma once +#include + +namespace device::math { + +inline constexpr float log2e = 1.44269504088896340736f; +inline constexpr float loge2 = 0.693147180559945309417f; +inline constexpr float FP8_E4M3_MAX = 448.0f; +static_assert(log2e * loge2 == 1.0f, "log2e * loge2 must be 1"); + +template +SGL_DEVICE T max(T a, T b) { +  return dtype_trait::max(a, b); +} + +template +SGL_DEVICE T min(T a, T b) { +  return dtype_trait::min(a, b); +} + +template +SGL_DEVICE T abs(T a) { +  return dtype_trait::abs(a); +} + +template +SGL_DEVICE T sqrt(T a) { +  return dtype_trait::sqrt(a); +} + +template +SGL_DEVICE T rsqrt(T a) { +  return dtype_trait::rsqrt(a); +} + +} // namespace device::math diff --git a/python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh b/python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh index c9ba59be49a5..33ea710946f5 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/runtime.cuh @@ -4,6 +4,7 @@ #include #include +#include namespace host::runtime { diff --git a/python/sglang/jit_kernel/include/sgl_kernel/tensor.h 
b/python/sglang/jit_kernel/include/sgl_kernel/tensor.h index 7117ecf641d5..de7ed8a0c671 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/tensor.h +++ b/python/sglang/jit_kernel/include/sgl_kernel/tensor.h @@ -21,9 +21,7 @@ #include #ifdef __CUDACC__ -#include -#include -#include +#include #endif namespace host { @@ -41,10 +39,10 @@ struct DTypeRef; struct DeviceRef; template -struct dtype_trait {}; +struct _dtype_trait {}; template -struct dtype_trait { +struct _dtype_trait { inline static constexpr DLDataType value = { .code = std::is_signed_v ? DLDataTypeCode::kDLInt : DLDataTypeCode::kDLUInt, .bits = static_cast(sizeof(T) * 8), @@ -52,36 +50,36 @@ struct dtype_trait { }; template -struct dtype_trait { +struct _dtype_trait { inline static constexpr DLDataType value = { .code = DLDataTypeCode::kDLFloat, .bits = static_cast(sizeof(T) * 8), .lanes = 1}; }; #ifdef __CUDACC__ template <> -struct dtype_trait<__half> { +struct _dtype_trait { inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat, .bits = 16, .lanes = 1}; }; template <> -struct dtype_trait<__nv_bfloat16> { +struct _dtype_trait { inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLBfloat, .bits = 16, .lanes = 1}; }; template <> -struct dtype_trait<__nv_fp8_e4m3> { +struct _dtype_trait { inline static constexpr DLDataType value = {.code = DLDataTypeCode::kDLFloat8_e4m3fn, .bits = 8, .lanes = 1}; }; #endif template -struct device_trait { +struct _device_trait { inline static constexpr DLDevice value = {.device_type = Code, .device_id = kAnyDeviceID}; }; template -inline constexpr auto kDTypeList = std::array{dtype_trait::value...}; +inline constexpr auto kDTypeList = std::array{_dtype_trait::value...}; template -inline constexpr auto kDeviceList = std::array{device_trait::value...}; +inline constexpr auto kDeviceList = std::array{_device_trait::value...}; template struct PrintAbleSpan { @@ -155,7 +153,7 @@ inline auto& operator<<(std::ostream& os, 
PrintAbleSpan span) { template inline bool is_type(DLDataType dtype) { - return dtype == details::dtype_trait::value; + return dtype == details::_dtype_trait::value; } struct SymbolicSize { diff --git a/python/sglang/jit_kernel/include/sgl_kernel/tile.cuh b/python/sglang/jit_kernel/include/sgl_kernel/tile.cuh new file mode 100644 index 000000000000..d227a5585bab --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/tile.cuh @@ -0,0 +1,36 @@ +#pragma once +#include + +#include + +namespace device::tile { + +template +struct Memory { + public: + SGL_DEVICE constexpr Memory(uint32_t tid, uint32_t tsize) : tid(tid), tsize(tsize) {} + SGL_DEVICE static constexpr Memory thread() { + return Memory{0, 1}; + } + SGL_DEVICE static Memory warp(int warp_threads = kWarpThreads) { + return Memory{threadIdx.x % warp_threads, warp_threads}; + } + SGL_DEVICE static Memory cta(int cta_threads = blockDim.x) { + return Memory{threadIdx.x, cta_threads}; + } + SGL_DEVICE T load(const void* ptr, int64_t offset = 0) const { + return static_cast(ptr)[tid + offset * tsize]; + } + SGL_DEVICE void store(void* ptr, T val, int64_t offset = 0) const { + static_cast(ptr)[tid + offset * tsize] = val; + } + SGL_DEVICE bool in_bound(int64_t element_count, int64_t offset = 0) const { + return tid + offset * tsize < element_count; + } + + private: + uint32_t tid; + uint32_t tsize; +}; + +} // namespace device::tile diff --git a/python/sglang/jit_kernel/include/sgl_kernel/type.cuh b/python/sglang/jit_kernel/include/sgl_kernel/type.cuh new file mode 100644 index 000000000000..047e5c24ba3e --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/type.cuh @@ -0,0 +1,72 @@ +#pragma once +#include + +template +struct dtype_trait {}; + +#define SGL_REGISTER_DTYPE_TRAIT(TYPE, PACK2, ...) 
\ + template <> \ + struct dtype_trait { \ + using self_t = TYPE; \ + using packed_t = PACK2; \ + template \ + SGL_DEVICE static self_t from(const S& value) { \ + return static_cast(value); \ + } \ + __VA_ARGS__ \ + } + +#define SGL_REGISTER_TYPE_END static_assert(true) + +#define SGL_REGISTER_FROM_FUNCTION(FROM, FN) \ + SGL_DEVICE static self_t from(const FROM& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_UNARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x) { \ + return FN(x); \ + } \ + static_assert(true) + +#define SGL_REGISTER_BINARY_FUNCTION(NAME, FN) \ + SGL_DEVICE static self_t NAME(const self_t& x, const self_t& y) { \ + return FN(x, y); \ + } \ + static_assert(true) + +SGL_REGISTER_DTYPE_TRAIT(fp32_t, fp32x2_t, SGL_REGISTER_TYPE_END; // + SGL_REGISTER_UNARY_FUNCTION(abs, fabsf); + SGL_REGISTER_UNARY_FUNCTION(sqrt, sqrtf); + SGL_REGISTER_UNARY_FUNCTION(rsqrt, rsqrtf); + SGL_REGISTER_BINARY_FUNCTION(max, fmaxf); + SGL_REGISTER_BINARY_FUNCTION(min, fminf);); +SGL_REGISTER_DTYPE_TRAIT(fp16_t, fp16x2_t); +SGL_REGISTER_DTYPE_TRAIT(bf16_t, bf16x2_t); + +/// TODO: Add ROCM implementation +SGL_REGISTER_DTYPE_TRAIT(fp32x2_t, fp32x4_t, SGL_REGISTER_TYPE_END; + SGL_REGISTER_FROM_FUNCTION(fp16x2_t, __half22float2); + SGL_REGISTER_FROM_FUNCTION(bf16x2_t, __bfloat1622float2);); + +SGL_REGISTER_DTYPE_TRAIT(fp16x2_t, void, SGL_REGISTER_TYPE_END; + SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22half2_rn);); + +SGL_REGISTER_DTYPE_TRAIT(bf16x2_t, void, SGL_REGISTER_TYPE_END; + SGL_REGISTER_FROM_FUNCTION(fp32x2_t, __float22bfloat162_rn);); + +#undef SGL_REGISTER_DTYPE_TRAIT +#undef SGL_REGISTER_FROM_FUNCTION + +template +using packed_t = typename dtype_trait::packed_t; + +namespace device { + +template +SGL_DEVICE To cast(const From& value) { + return dtype_trait::from(value); +} + +} // namespace device diff --git a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh 
b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh index 0e503a3c004b..01ce21a7a813 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh @@ -2,145 +2,76 @@ #include -#include -#include #include #include #include #include -#include - -namespace device { - -inline constexpr auto kWarpThreads = 32u; -inline constexpr auto kFullMask = 0xffffffffu; +#include +#include +#include +#include -__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { #ifndef USE_ROCM - float old; - old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) - : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); - return old; -#else - int* addr_as_i = (int*)addr; - int old = *addr_as_i, assumed; - do { - assumed = old; - old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); +using fp32_t = float; +using fp16_t = __half; +using bf16_t = __nv_bfloat16; +using fp8_e4m3_t = __nv_fp8_e4m3; +using fp8_e5m2_t = __nv_fp8_e5m2; + +using fp32x2_t = float2; +using fp16x2_t = __half2; +using bf16x2_t = __nv_bfloat162; +using fp8x2_e4m3_t = __nv_fp8x2_e4m3; +using fp8x2_e5m2_t = __nv_fp8x2_e5m2; + +using fp32x4_t = float4; #endif -} - -__device__ __forceinline__ float warpReduceMax(float value) { - value = fmaxf(value, __shfl_xor_sync(kFullMask, value, 16)); - value = fmaxf(value, __shfl_xor_sync(kFullMask, value, 8)); - value = fmaxf(value, __shfl_xor_sync(kFullMask, value, 4)); - value = fmaxf(value, __shfl_xor_sync(kFullMask, value, 2)); - value = fmaxf(value, __shfl_xor_sync(kFullMask, value, 1)); - return value; -} - -__device__ __forceinline__ float blockReduceMax(float value) { - static __shared__ float warpLevelMaxs[kWarpThreads]; - const int laneId = threadIdx.x % kWarpThreads; - const int warpId = threadIdx.x / kWarpThreads; - value = warpReduceMax(value); - - if 
(laneId == 0) warpLevelMaxs[warpId] = value; - __syncthreads(); - - value = (threadIdx.x < blockDim.x / kWarpThreads) ? warpLevelMaxs[laneId] : 0; - if (warpId == 0) value = warpReduceMax(value); - - return value; -} - -namespace pointer { - -// we only allow void * pointer arithmetic for safety - -template -__always_inline __device__ auto offset(T* ptr, U... offset) -> void* { - static_assert(std::is_same_v, "Pointer arithmetic is only allowed for void* pointers"); - return static_cast(ptr) + (... + offset); -} +namespace device { -template -__always_inline __device__ auto offset(const T* ptr, U... offset) -> const void* { - static_assert(std::is_same_v, "Pointer arithmetic is only allowed for void* pointers"); - return static_cast(ptr) + (... + offset); -} +#define SGL_DEVICE __forceinline__ __device__ -} // namespace pointer +inline constexpr auto kWarpThreads = 32u; +inline constexpr auto kFullMask = 0xffffffffu; template -__forceinline__ __device__ void PDLWaitPrimary() { +SGL_DEVICE void PDLWaitPrimary() { #ifndef USE_ROCM if constexpr (kUsePDL) { - asm volatile("griddepcontrol.wait;"); + asm volatile("griddepcontrol.wait;" ::: "memory"); } #endif } template -__forceinline__ __device__ void PDLTriggerSecondary() { +SGL_DEVICE void PDLTriggerSecondary() { #ifndef USE_ROCM if constexpr (kUsePDL) { - asm volatile("griddepcontrol.launch_dependents;"); + asm volatile("griddepcontrol.launch_dependents;" :::); } #endif } -} // namespace device - -namespace host { - -// DType -template -struct PackedDType { - static_assert(dependent_false_v, "Unsupported dtype for Packed"); -}; - -template <> -struct PackedDType { - using type = float2; -}; - -template <> -struct PackedDType { - using type = float4; -}; +namespace pointer { -template <> -struct PackedDType<__half, 2> { - using type = __half2; -}; +// we only allow void * pointer arithmetic for safety -struct alignas(8) half4 { - __half x, y, z, w; -}; +template +SGL_DEVICE auto offset(void* ptr, U... 
offset) -> void* { + return static_cast(ptr) + (... + offset); +} -template <> -struct PackedDType<__half, 4> { - using type = half4; -}; +template +SGL_DEVICE auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... + offset); +} -template <> -struct PackedDType { - using type = nv_bfloat162; -}; +} // namespace pointer -struct alignas(8) bf16_4 { - nv_bfloat16 x, y, z, w; -}; +} // namespace device -template <> -struct PackedDType { - using type = bf16_4; -}; +namespace host { inline void RuntimeDeviceCheck(::cudaError_t error, DebugInfo location = {}) { if (error != ::cudaSuccess) { diff --git a/python/sglang/jit_kernel/include/sgl_kernel/utils.h b/python/sglang/jit_kernel/include/sgl_kernel/utils.h index 106ae14eed0b..78eae19fc83e 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/utils.h +++ b/python/sglang/jit_kernel/include/sgl_kernel/utils.h @@ -121,16 +121,14 @@ namespace pointer { // we only allow void * pointer arithmetic for safety -template -inline auto offset(T* ptr, U... offset) -> void* { - static_assert(std::is_same_v, "Pointer arithmetic is only allowed for void* pointers"); - return static_cast(ptr) + (... + offset); +template +inline auto offset(void* ptr, U... offset) -> void* { + return static_cast(ptr) + (... + offset); } -template -inline auto offset(const T* ptr, U... offset) -> const void* { - static_assert(std::is_same_v, "Pointer arithmetic is only allowed for void* pointers"); - return static_cast(ptr) + (... + offset); +template +inline auto offset(const void* ptr, U... offset) -> const void* { + return static_cast(ptr) + (... 
+ offset); } } // namespace pointer diff --git a/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh b/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh index c0e72c2555b8..5510b44746c9 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh @@ -1,6 +1,5 @@ #pragma once -#include -#include +#include #include #include @@ -38,30 +37,30 @@ using sized_int = typename uint_trait::type; } // namespace details template -struct alignas(sizeof(T) * N) aligned_storage { +struct alignas(sizeof(T) * N) AlignedStorage { T data[N]; }; template -struct aligned_vector { +struct AlignedVector { private: /// NOTE: 1. must be pow of two 2. 16 * 8 = 128 byte, which is the max vector size supported by most devices static_assert((N > 0 && (N & (N - 1)) == 0) && sizeof(T) * N <= 16, "CUDA only support at most 128B vector op"); using element_t = typename details::sized_int; - using storage_t = aligned_storage; + using storage_t = AlignedStorage; public: template - __forceinline__ __device__ void load(const U* ptr, std::size_t offset = 0) { + SGL_DEVICE void load(const U* ptr, std::size_t offset = 0) { static_assert(std::is_same_v || std::is_same_v); m_storage = reinterpret_cast(ptr)[offset]; } template - __forceinline__ __device__ void store(U* ptr, std::size_t offset = 0) const { + SGL_DEVICE void store(U* ptr, std::size_t offset = 0) const { static_assert(std::is_same_v || std::is_same_v); reinterpret_cast(ptr)[offset] = m_storage; } - __forceinline__ __device__ void fill(T value) { + SGL_DEVICE void fill(T value) { const auto store_value = *reinterpret_cast(&value); #pragma unroll for (std::size_t i = 0; i < N; ++i) { @@ -69,16 +68,16 @@ struct aligned_vector { } } - __forceinline__ __device__ auto operator[](std::size_t idx) -> T& { + SGL_DEVICE auto operator[](std::size_t idx) -> T& { return reinterpret_cast(&m_storage)[idx]; } - __forceinline__ __device__ auto operator[](std::size_t idx) const -> T { + SGL_DEVICE 
auto operator[](std::size_t idx) const -> T { return reinterpret_cast(&m_storage)[idx]; } - __forceinline__ __device__ auto data() -> T* { + SGL_DEVICE auto data() -> T* { return reinterpret_cast(&m_storage); } - __forceinline__ __device__ auto data() const -> const T* { + SGL_DEVICE auto data() const -> const T* { return reinterpret_cast(&m_storage); } diff --git a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh index 88804b42bbee..d69526e97f29 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh @@ -1,33 +1,25 @@ #pragma once -#include -#include - -#include +#include // Some warp primitives namespace device::warp { +static constexpr uint32_t kFullMask = 0xffffffffu; + template -__forceinline__ __device__ T reduce_sum(T val, uint32_t active_mask = 0xffffffff) { +SGL_DEVICE T reduce_sum(T value, uint32_t active_mask = kFullMask) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(active_mask, val, mask, 32); - return val; + value = value + __shfl_xor_sync(active_mask, value, mask, 32); + return value; } -template -__forceinline__ __device__ T load(const void* ptr) { - return static_cast(ptr)[threadIdx.x % kWarpThreads]; -} - -template -__forceinline__ __device__ T load(const T* ptr) { - return static_cast(ptr)[threadIdx.x % kWarpThreads]; -} - -template -__forceinline__ __device__ void store(void* ptr, T val) { - static_cast(ptr)[threadIdx.x % kWarpThreads] = val; +template +SGL_DEVICE T reduce_max(T value, uint32_t active_mask = kFullMask) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + value = math::max(value, __shfl_xor_sync(active_mask, value, mask, 32)); + return value; } } // namespace device::warp diff --git a/python/sglang/jit_kernel/norm.py b/python/sglang/jit_kernel/norm.py index e1a05aa2799c..4e8082c34acf 100644 --- a/python/sglang/jit_kernel/norm.py +++ 
b/python/sglang/jit_kernel/norm.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -27,6 +27,17 @@ def _jit_qknorm_module(head_dim: int, dtype: torch.dtype) -> Module: ) +@cache_once +def _jit_rmsnorm_module(hidden_size: int, dtype: torch.dtype) -> Module: + args = make_cpp_args(hidden_size, is_arch_support_pdl(), dtype) + return load_jit( + "rmsnorm", + *args, + cuda_files=["elementwise/rmsnorm.cuh"], + cuda_wrappers=[("rmsnorm", f"RMSNormKernel<{args}>::run")], + ) + + @cache_once def can_use_fused_inplace_qknorm(head_dim: int, dtype: torch.dtype) -> bool: logger = logging.getLogger(__name__) @@ -53,3 +64,15 @@ def fused_inplace_qknorm( head_dim = head_dim or q.size(-1) module = _jit_qknorm_module(head_dim, q.dtype) module.qknorm(q, k, q_weight, k_weight, eps) + + +def rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + output: Optional[torch.Tensor] = None, + eps: float = 1e-6, +) -> None: + output = output if output is not None else input + hidden_size = input.size(-1) + module = _jit_rmsnorm_module(hidden_size, input.dtype) + module.rmsnorm(input, weight, output, eps) diff --git a/python/sglang/jit_kernel/per_tensor_quant_fp8.py b/python/sglang/jit_kernel/per_tensor_quant_fp8.py index bee07a207b5a..9225aa45d1e5 100644 --- a/python/sglang/jit_kernel/per_tensor_quant_fp8.py +++ b/python/sglang/jit_kernel/per_tensor_quant_fp8.py @@ -1,11 +1,8 @@ from __future__ import annotations -import os from typing import TYPE_CHECKING -import flashinfer import torch -from torch.utils.cpp_extension import CUDA_HOME from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args from sglang.srt.utils.custom_op import register_custom_op @@ -15,20 +12,13 @@ @cache_once -def _jit_per_tensor_quant_fp8_module(is_static: bool) -> Module: - args = make_cpp_args(is_static) - - flashinfer_include = os.path.join( - os.path.dirname(flashinfer.__file__), "data", 
"include" - ) - cub_include = os.path.join(CUDA_HOME, "include") - +def _jit_per_tensor_quant_fp8_module(is_static: bool, dtype: torch.dtype) -> Module: + args = make_cpp_args(is_static, dtype) return load_jit( "per_tensor_quant_fp8", *args, cuda_files=["gemm/per_tensor_quant_fp8.cuh"], cuda_wrappers=[("per_tensor_quant_fp8", f"per_tensor_quant_fp8<{args}>")], - extra_include_paths=[flashinfer_include, cub_include], ) @@ -51,10 +41,5 @@ def per_tensor_quant_fp8( output_s: Output scale tensor (float scalar or 1D tensor with 1 element) is_static: If True, assumes scale is pre-computed and skips absmax computation """ - # Ensure output_s has shape [1] instead of being a 0D scalar - # The JIT kernel expects a 1D tensor - if output_s.ndim == 0: - output_s = output_s.reshape(1) - - module = _jit_per_tensor_quant_fp8_module(is_static) - module.per_tensor_quant_fp8(input, output_q, output_s) + module = _jit_per_tensor_quant_fp8_module(is_static, input.dtype) + module.per_tensor_quant_fp8(input.view(-1), output_q.view(-1), output_s.view(-1)) diff --git a/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py b/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py index 23357c57efc3..a08b698f9fb9 100644 --- a/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py +++ b/python/sglang/jit_kernel/tests/test_per_tensor_quant_fp8.py @@ -58,7 +58,7 @@ def test_jit_per_tensor_quant_compare_implementations( ) -@pytest.mark.parametrize("shape", [(4, 8, 64), (2, 16, 128)]) +@pytest.mark.parametrize("shape", [(4, 8, 64), (2, 16, 128), (19260817, 1, 1)]) def test_jit_per_tensor_quant_supports_3d(shape): device = torch.device("cuda") x = torch.rand(shape, dtype=torch.bfloat16, device=device) @@ -74,7 +74,7 @@ def test_jit_per_tensor_quant_supports_3d(shape): torch.testing.assert_close(out.float(), out_ref.float(), rtol=1e-3, atol=1e-3) scale = torch.rand(1, dtype=torch.float32, device=device) - sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale) + sglang_out, _ = 
sglang_scaled_fp8_quant(x, scale) torch_out = torch_scaled_fp8_quant(x, scale) torch.testing.assert_close( diff --git a/python/sglang/jit_kernel/tests/test_qknorm.py b/python/sglang/jit_kernel/tests/test_qknorm.py index 8fb8f1ba7baa..2938341d7937 100644 --- a/python/sglang/jit_kernel/tests/test_qknorm.py +++ b/python/sglang/jit_kernel/tests/test_qknorm.py @@ -1,3 +1,6 @@ +import itertools + +import pytest import torch import triton @@ -56,30 +59,35 @@ def torch_impl_qknorm( k.copy_(k.float() * k_norm * k_weight.float()) +BS_LIST = [2**n for n in range(0, 14)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +N_K_LIST = [2, 4] +N_Q_LIST = [8, 16] +HEAD_DIM_LIST = [64, 128, 256] +DEVICE = "cuda" +DTYPE = torch.bfloat16 + # NOTE(dark): sgl_kernel use flashinfer template, which is bitwise identical to flashinfer impl. # However, sgl-jit-kernel, flashinfer, torch_impl, may have small numerical differences. # so we allow a small rel/abs tolerance in correctness test. -def main(): - N_K = 2 - N_Q = 16 - DEVICE = "cuda" - DTYPE = torch.bfloat16 - BS_LIST = [2**n for n in range(0, 15)] - BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] - for HEAD_DIM in [64, 128, 256]: - for BS in BS_LIST: - q = torch.randn(BS, N_Q, HEAD_DIM, device=DEVICE, dtype=DTYPE) - k = torch.randn(BS, N_K, HEAD_DIM, device=DEVICE, dtype=DTYPE) - q_weight = torch.randn(HEAD_DIM, device=DEVICE, dtype=DTYPE) - k_weight = torch.randn(HEAD_DIM, device=DEVICE, dtype=DTYPE) - q_k_aot = (q.clone(), k.clone()) - q_k_jit = (q.clone(), k.clone()) - sglang_aot_qknorm(q_k_aot[0], q_k_aot[1], q_weight, k_weight) - sglang_jit_qknorm(q_k_jit[0], q_k_jit[1], q_weight, k_weight) - triton.testing.assert_close(q_k_aot[0], q_k_jit[0], atol=1e-2, rtol=1e-2) - triton.testing.assert_close(q_k_aot[1], q_k_jit[1], atol=1e-2, rtol=1e-2) - print(f"HEAD_DIM={HEAD_DIM} correctness test passed.") + + +@pytest.mark.parametrize( + "batch_size,n_k,n_q,head_dim", + list(itertools.product(BS_LIST, N_K_LIST, N_Q_LIST, 
HEAD_DIM_LIST)), +) +def test_qknorm(batch_size: int, n_k: int, n_q: int, head_dim: int) -> None: + q = torch.randn(batch_size, n_q, head_dim, device=DEVICE, dtype=DTYPE) + k = torch.randn(batch_size, n_k, head_dim, device=DEVICE, dtype=DTYPE) + q_weight = torch.randn(head_dim, device=DEVICE, dtype=DTYPE) + k_weight = torch.randn(head_dim, device=DEVICE, dtype=DTYPE) + q_k_aot = (q.clone(), k.clone()) + q_k_jit = (q.clone(), k.clone()) + sglang_aot_qknorm(q_k_aot[0], q_k_aot[1], q_weight, k_weight) + sglang_jit_qknorm(q_k_jit[0], q_k_jit[1], q_weight, k_weight) + triton.testing.assert_close(q_k_aot[0], q_k_jit[0], atol=1e-2, rtol=1e-2) + triton.testing.assert_close(q_k_aot[1], q_k_jit[1], atol=1e-2, rtol=1e-2) if __name__ == "__main__": - main() + pytest.main([__file__]) diff --git a/python/sglang/jit_kernel/tests/test_rmsnorm.py b/python/sglang/jit_kernel/tests/test_rmsnorm.py new file mode 100644 index 000000000000..168a953347f4 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_rmsnorm.py @@ -0,0 +1,41 @@ +import itertools + +import pytest +import torch +import triton + + +def sglang_jit_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: + from sglang.jit_kernel.norm import rmsnorm + + rmsnorm(input, weight, output=input) + + +def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: + from flashinfer.norm import rmsnorm + + rmsnorm(input, weight, out=input) + + +BS_LIST = [2**n for n in range(0, 14)] +BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] +HIDDEN_SIZE_LIST = [512, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 7168, 8192] +DEVICE = "cuda" +DTYPE = torch.bfloat16 + + +@pytest.mark.parametrize( + "batch_size,hidden_size", list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST)) +) +def test_rmsnorm(batch_size: int, hidden_size: int) -> None: + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE) + weight = torch.randn(hidden_size, device=DEVICE, dtype=DTYPE) + input_sglang = input.clone() + input_flashinfer 
= input.clone() + sglang_jit_rmsnorm(input_sglang, weight) + flashinfer_rmsnorm(input_flashinfer, weight) + triton.testing.assert_close(input_sglang, input_flashinfer, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py index fa326db00173..e8358d35d68e 100644 --- a/python/sglang/jit_kernel/utils.py +++ b/python/sglang/jit_kernel/utils.py @@ -42,7 +42,7 @@ def _package_install(): DEFAULT_CFLAGS = ["-std=c++20", "-O3"] DEFAULT_CUDA_CFLAGS = ["-std=c++20", "-O3", "--expt-relaxed-constexpr"] DEFAULT_LDFLAGS = [] -CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, bool] +CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, bool, torch.dtype] class CPPArgList(list[str]): @@ -50,6 +50,13 @@ def __str__(self) -> str: return ", ".join(self) +CPP_DTYPE_MAP = { + torch.float: "fp32_t", + torch.float16: "fp16_t", + torch.bfloat16: "bf16_t", +} + + def make_cpp_args(*args: CPP_TEMPLATE_TYPE) -> CPPArgList: def _convert(arg: CPP_TEMPLATE_TYPE) -> str: if isinstance(arg, bool): @@ -57,13 +64,7 @@ def _convert(arg: CPP_TEMPLATE_TYPE) -> str: if isinstance(arg, (int, float)): return str(arg) if isinstance(arg, torch.dtype): - if arg == torch.float: - return "float" - if arg == torch.float16: - return "__half" - if arg == torch.bfloat16: - return "nv_bfloat16" - raise TypeError(f"Not implement this arg wrapper yet: {arg}") + return CPP_DTYPE_MAP[arg] raise TypeError(f"Unsupported argument type for cpp template: {type(arg)}") return CPPArgList(_convert(arg) for arg in args)