diff --git a/docs/advanced_features/quantization.md b/docs/advanced_features/quantization.md index 8a30d5084660..e8180fb716b6 100644 --- a/docs/advanced_features/quantization.md +++ b/docs/advanced_features/quantization.md @@ -71,11 +71,12 @@ Backend selection is supported only for **blockwise FP8** and **NVFP4** GEMM. Wh | Backend | Hardware | Description | |---------|----------|-------------| | `auto` | SM100/120 | Auto-selects: `flashinfer_cudnn` on SM120; `flashinfer_cutlass` on SM100 | +| `cutlass` | SM100/120 | SGLang CUTLASS kernel | | `flashinfer_cutlass` | SM100/120 | FlashInfer CUTLASS backend | | `flashinfer_cudnn` | SM100/120 (CUDA 13+, cuDNN 9.15+) | FlashInfer cuDNN backend; used on SM120 for performance | | `flashinfer_trtllm` | SM100 | FlashInfer TensorRT-LLM backend | -When FlashInfer is unavailable for NVFP4, sgl-kernel CUTLASS is used as an automatic fallback. +When FlashInfer is unavailable for NVFP4, the SGLang CUTLASS kernel is used as an automatic fallback. ## Offline Quantization diff --git a/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py b/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py index f7af1e7c5100..4278f0348260 100644 --- a/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py +++ b/python/sglang/jit_kernel/benchmark/bench_nvfp4_scaled_mm.py @@ -7,7 +7,7 @@ from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant -from sglang.srt.utils import is_sm100_supported +from sglang.srt.utils import is_sm100_supported, is_sm120_supported from sglang.test.ci.ci_register import register_cuda_ci register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large") @@ -15,7 +15,7 @@ FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max BLOCK_SIZE = 16 -_NVFP4_SUPPORTED = is_sm100_supported() +_NVFP4_SUPPORTED = is_sm100_supported() or is_sm120_supported() K_E2M1_TO_FLOAT = [ 0.0, @@ -178,7 +178,7 @@ def benchmark(m, n, k, provider): if __name__ == "__main__": if not _NVFP4_SUPPORTED: - print("[skip] NVFP4 scaled_mm benchmark requires sm100+ with CUDA 12.8+.") + print("[skip] NVFP4 scaled_mm benchmark requires sm100/sm120 with CUDA 12.8+.") sys.exit(0) if not _AOT_SCALED_MM_AVAILABLE: print( diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_common.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_common.cuh new file mode 100644 index 000000000000..f5ebca05b37c --- /dev/null +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_common.cuh @@ -0,0 +1,66 @@ +/* Copyright 2026 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include + +using namespace host; + +// clang-format off +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/packed_stride.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + RuntimeCheck(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +using namespace cute; + +inline uint32_t next_pow_2(uint32_t x) noexcept { + if (x <= 1) return 1; + return 1u << (32 - __builtin_clz(x - 1)); +} + +inline auto alloc_workspace_tensor(size_t required_bytes, DLDevice device) -> tvm::ffi::Tensor { + if (required_bytes == 0) return {}; + DLDataType u8 = {kDLUInt, 8, 1}; + int64_t shape[] = {static_cast(required_bytes)}; + return ffi::empty(tvm::ffi::ShapeView(shape, 1), u8, device); +} + +inline int getSMVersion(int device_id) { + int sm_major = 0; + int sm_minor = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); + return sm_major * 10 + sm_minor; +} diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh index 9cc309f14b55..8c5cfefd7956 100644 --- a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_kernels.cuh @@ -1,4 +1,4 @@ -/* Copyright 2025 SGLang Team. All Rights Reserved. +/* Copyright 2026 SGLang Team. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,593 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include - -#include -#include - -#include -#include -#include -#include - -using namespace host; - -// clang-format off -#include "cutlass/cutlass.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/util/packed_stride.hpp" -// clang-format on - -/** - * Helper function for checking CUTLASS errors - */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - RuntimeCheck(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ - } - -using namespace cute; - -// Helper function for next power of 2 -inline uint32_t next_pow_2(uint32_t x) { - if (x == 0) return 1; - x--; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - x |= x >> 8; - x |= x >> 16; - return x + 1; -} - -struct WorkspaceKey { - int device_id; - uintptr_t stream; - auto operator==(const WorkspaceKey&) const -> bool = default; -}; - -struct WorkspaceKeyHash { - auto operator()(const WorkspaceKey& key) const -> size_t { - size_t h1 = std::hash{}(key.device_id); - size_t h2 = std::hash{}(key.stream); - return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2)); - } -}; - -struct WorkspaceState { - void* ptr = nullptr; - size_t bytes = 0; -}; - -inline auto get_cached_workspace(size_t required_bytes, int device_id, cudaStream_t stream) -> void* { - if (required_bytes == 0) { - return nullptr; - } - - thread_local std::unordered_map cache; - WorkspaceKey key{device_id, reinterpret_cast(stream)}; - auto& ws = cache[key]; - - if (ws.ptr != nullptr && ws.bytes >= required_bytes) { - return ws.ptr; - } - - RuntimeDeviceCheck(cudaSetDevice(device_id)); - if (ws.ptr != nullptr) { - RuntimeDeviceCheck(cudaFreeAsync(ws.ptr, stream)); - ws.ptr = nullptr; - ws.bytes = 0; - } - RuntimeDeviceCheck(cudaMallocAsync(&ws.ptr, required_bytes, stream)); - ws.bytes = required_bytes; - return ws.ptr; -} - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || \ - defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) -// Config(half_t/bfloat16_t) for M <= 128 -template -struct KernelConfigM128 { - using OutputType = T; - using MmaTileShape = Shape<_128, _256, _256>; - using ClusterShape = Shape; - using EpilogueTile = Shape<_128, _64>; // Avoid register spilling - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -template -const dim3 KernelConfigM128::preferred_cluster(1, 4, 1); -template -const dim3 KernelConfigM128::fallback_cluster(1, 2, 1); - -// Config(half_t/bfloat16_t) for M <= 256 -template -struct KernelConfigM256 { - using OutputType = T; - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape; - using EpilogueTile = Shape<_128, _64>; // Avoid register spilling - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -template -const dim3 KernelConfigM256::preferred_cluster(2, 4, 1); -template -const dim3 KernelConfigM256::fallback_cluster(2, 1, 1); - -// Default config(half_t/bfloat16_t) for M > 256 -template -struct KernelConfigDefault { - using OutputType = T; - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape; - using EpilogueTile = Shape<_128, _64>; // Avoid register spilling - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -template -const dim3 KernelConfigDefault::preferred_cluster(4, 4, 1); -template -const dim3 KernelConfigDefault::fallback_cluster(2, 1, 1); - -struct KernelConfigFp32 { - using OutputType = float; - using MmaTileShape = Shape<_128, _128, _256>; - using ClusterShape = Shape; - using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; - const static dim3 preferred_cluster; - const static dim3 fallback_cluster; -}; -const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1); -const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1); - -// SM120 specific configurations -struct sm120_fp4_config_M256 { - using ClusterShape = Shape<_1, _1, _1>; - using MmaTileShape = Shape<_128, _128, _128>; - using PerSmTileShape_MNK = Shape<_128, _128, _128>; -}; - -struct sm120_fp4_config_default { - using ClusterShape = Shape<_1, _1, _1>; - using MmaTileShape = Shape<_256, _128, _128>; - using PerSmTileShape_MNK = Shape<_256, _128, _128>; -}; - -template -struct Fp4GemmSm100 { - using Config = KernelConfig; // For generating args - using OutputType = typename KernelConfig::OutputType; - // A matrix configuration - using ElementA = cutlass::nv_float4_t; - using LayoutATag = cutlass::layout::RowMajor; - static constexpr int AlignmentA = 32; - - // B matrix configuration - using ElementB = cutlass::nv_float4_t; - using LayoutBTag = cutlass::layout::ColumnMajor; - static constexpr int AlignmentB = 32; - - // C/D matrix configuration - using ElementD = OutputType; - using ElementC = OutputType; - using LayoutCTag = cutlass::layout::RowMajor; - using LayoutDTag = cutlass::layout::RowMajor; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - // Kernel functional config - using ElementAccumulator = float; - using ArchTag = cutlass::arch::Sm100; - using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; - - // Kernel Perf config - using MmaTileShape = typename KernelConfig::MmaTileShape; - using ClusterShape = typename KernelConfig::ClusterShape; - using EpilogueTile = typename KernelConfig::EpilogueTile; - using EpilogueSchedule = typename KernelConfig::EpilogueSchedule; - using MainloopSchedule = typename KernelConfig::MainloopSchedule; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - MmaTileShape, - ClusterShape, - EpilogueTile, - ElementAccumulator, - ElementAccumulator, - void, - LayoutCTag, - AlignmentC, - ElementD, - LayoutDTag, - AlignmentD, - EpilogueSchedule, - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementA, - LayoutATag, - AlignmentA, - ElementB, - LayoutBTag, - AlignmentB, - ElementAccumulator, - MmaTileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainloopSchedule>::CollectiveOp; - - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using StrideA = typename Gemm::GemmKernel::StrideA; - using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); - using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); - using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); - using StrideD = typename Gemm::GemmKernel::StrideD; - using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); -}; - -// SM120 specific GEMM template -template -struct Fp4GemmSm120 { - using ElementA = cutlass::nv_float4_t; - using LayoutATag = cutlass::layout::RowMajor; - static constexpr int AlignmentA = 32; - - using ElementB = cutlass::nv_float4_t; - using LayoutBTag = cutlass::layout::ColumnMajor; - static constexpr int AlignmentB = 32; - - using ElementD = OutType; - using ElementC = OutType; - using LayoutCTag = cutlass::layout::RowMajor; - using LayoutDTag = cutlass::layout::RowMajor; - static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - - using ElementAccumulator = float; - using ArchTag = cutlass::arch::Sm120; - using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; - - using MmaTileShape = typename Config::MmaTileShape; - using ClusterShape = typename Config::ClusterShape; - using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - PerSmTileShape_MNK, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, - ElementAccumulator, - ElementC, - LayoutCTag, - AlignmentC, - ElementD, - LayoutDTag, - AlignmentD, - cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - ElementA, - LayoutATag, - AlignmentA, - ElementB, - LayoutBTag, - AlignmentB, - ElementAccumulator, - MmaTileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; - - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; - -template -typename T::Gemm::Arguments args_from_options( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t M, - int64_t N, - int64_t K) { - using ElementA = typename T::Gemm::ElementA; - using ElementB = typename T::Gemm::ElementB; - using ElementSFA = cutlass::float_ue4m3_t; - using ElementSFB = cutlass::float_ue4m3_t; - using ElementD = typename T::Gemm::ElementD; - using ElementCompute = float; - using StrideA = typename T::StrideA; - using StrideB = typename T::StrideB; - using StrideD = typename T::StrideD; - using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - - int m = static_cast(M); - int n = static_cast(N); - int k = static_cast(K); - auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); - auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); - auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); - - auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); - auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); - - typename T::Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {m, n, k, 1}, - {// Mainloop arguments - static_cast(A.data_ptr()), - stride_A, - static_cast(B.data_ptr()), - stride_B, - static_cast(A_sf.data_ptr()), - layout_SFA, - static_cast(B_sf.data_ptr()), - layout_SFB}, - { // Epilogue arguments - {}, // epilogue.thread - nullptr, - stride_D, - static_cast(D.data_ptr()), - stride_D}}; - auto& fusion_args = arguments.epilogue.thread; - fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); - using KernelConfig = typename T::Config; - arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster; - arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster; - return arguments; -} - -template -void runGemm( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - typename T::Gemm gemm; - auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); - - size_t workspace_size = T::Gemm::get_workspace_size(arguments); - int device_id = A.device().device_id; - void* workspace = get_cached_workspace(workspace_size, device_id, stream); - - CUTLASS_CHECK(gemm.can_implement(arguments)); - - CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); - - CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); -} - -// SM120 specific args_from_options function -template -typename Gemm::Arguments args_from_options_sm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int M, - int N, - int K) { - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementD = typename Gemm::ElementD; - using ElementSFA = cutlass::float_ue4m3_t; - using ElementSFB = cutlass::float_ue4m3_t; - using ElementCompute = float; - - using StrideA = typename Gemm::GemmKernel::StrideA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using StrideD = typename Gemm::GemmKernel::StrideD; - - using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - - auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); - auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); - auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); - - auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); - auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); - - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K, 1}, - {static_cast(A.data_ptr()), - stride_A, - static_cast(B.data_ptr()), - stride_B, - static_cast(A_sf.data_ptr()), - layout_SFA, - static_cast(B_sf.data_ptr()), - layout_SFB}, - {{}, static_cast(D.data_ptr()), stride_D, static_cast(D.data_ptr()), stride_D}}; - auto& fusion_args = arguments.epilogue.thread; - fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); - - return arguments; -} - -// SM120 specific runGemm function -template -void runGemmSm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int M, - int N, - int K, - cudaStream_t stream) { - Gemm gemm; - - auto arguments = args_from_options_sm120(D, A, B, A_sf, B_sf, alpha, M, N, K); - - size_t workspace_size = Gemm::get_workspace_size(arguments); - int device_id = A.device().device_id; - void* workspace = get_cached_workspace(workspace_size, device_id, stream); - - CUTLASS_CHECK(gemm.can_implement(arguments)); - - CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); - - CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); -} - -// Dispatch function to select appropriate config based on M -template -void cutlassFp4GemmDispatch( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - if (m <= 128) { - // m in [1, 128] - runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else if (m <= 256) { - // m in (128, 256] - runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else { - // m in (256, inf) - runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } -} - -// Dispatch function to select appropriate config based on M -template <> -void cutlassFp4GemmDispatch( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - runGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); -} - -// SM120 specific dispatch functions -void cutlass_fp4_bf16_gemm_dispatch_sm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int m, - int n, - int k, - cudaStream_t stream) { - uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); - if (mp2 <= 256) { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } -} - -void cutlass_fp4_f16_gemm_dispatch_sm120( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int m, - int n, - int k, - cudaStream_t stream) { - uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); - if (mp2 <= 256) { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else { - runGemmSm120::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } -} - -#else -template -void cutlassFp4GemmDispatch( - tvm::ffi::TensorView D, - tvm::ffi::TensorView A, - tvm::ffi::TensorView B, - tvm::ffi::TensorView A_sf, - tvm::ffi::TensorView B_sf, - tvm::ffi::TensorView alpha, - int64_t m, - int64_t n, - int64_t k, - cudaStream_t stream) { - RuntimeCheck( - false, - "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " - "a CUTLASS 3.8 source directory to enable support."); -} -#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || - // defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) - -inline int getSMVersion(int device_id) { - int sm_major = 0; - int sm_minor = 0; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device_id)); - RuntimeDeviceCheck(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device_id)); - return sm_major * 10 + sm_minor; -} +#include "nvfp4_scaled_mm_common.cuh" +#include "nvfp4_scaled_mm_sm100.cuh" +#include "nvfp4_scaled_mm_sm120.cuh" void cutlass_scaled_fp4_mm_sm100a_sm120a( tvm::ffi::TensorView D, @@ -718,11 +134,11 @@ void cutlass_scaled_fp4_mm_sm100a_sm120a( } } else { if (host::is_type(D.dtype())) { - cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatchSm100(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else if (host::is_type(D.dtype())) { - cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatchSm100(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else if (host::is_type(D.dtype())) { - cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatchSm100(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { Panic("Unsupported output data type of nvfp4 mm"); } diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh new file mode 100644 index 000000000000..bd5927a23f09 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm100.cuh @@ -0,0 +1,284 @@ +/* Copyright 2026 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "nvfp4_scaled_mm_common.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// Config(half_t/bfloat16_t) for M <= 128 +template +struct KernelConfigM128 { + using OutputType = T; + using MmaTileShape = Shape<_128, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigM128::preferred_cluster(1, 4, 1); +template +const dim3 KernelConfigM128::fallback_cluster(1, 2, 1); + +// Config(half_t/bfloat16_t) for M <= 256 +template +struct KernelConfigM256 { + using OutputType = T; + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigM256::preferred_cluster(2, 4, 1); +template +const dim3 KernelConfigM256::fallback_cluster(2, 1, 1); + +// Default config(half_t/bfloat16_t) for M > 256 +template +struct KernelConfigDefault { + using OutputType = T; + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigDefault::preferred_cluster(4, 4, 1); +template +const dim3 KernelConfigDefault::fallback_cluster(2, 1, 1); + +struct KernelConfigFp32 { + using OutputType = float; + using MmaTileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1); +const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1); + +template +struct Fp4GemmSm100 { + using Config = KernelConfig; + using OutputType = typename KernelConfig::OutputType; + + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + using ElementD = OutputType; + using ElementC = OutputType; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using MmaTileShape = typename KernelConfig::MmaTileShape; + using ClusterShape = typename KernelConfig::ClusterShape; + using EpilogueTile = typename KernelConfig::EpilogueTile; + using EpilogueSchedule = typename KernelConfig::EpilogueSchedule; + using MainloopSchedule = typename KernelConfig::MainloopSchedule; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + EpilogueTile, + ElementAccumulator, + ElementAccumulator, + void, + LayoutCTag, + AlignmentC, + ElementD, + LayoutDTag, + AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + AlignmentA, + ElementB, + LayoutBTag, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::StrideA; + using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{})); + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{})); + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{})); + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); +}; + +template +typename T::Gemm::Arguments args_from_options( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t M, + int64_t N, + int64_t K) { + using ElementA = typename T::Gemm::ElementA; + using ElementB = typename T::Gemm::ElementB; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementD = typename T::Gemm::ElementD; + using ElementCompute = float; + using StrideA = typename T::StrideA; + using StrideB = typename T::StrideB; + using StrideD = typename T::StrideD; + using Sm1xxBlkScaledConfig = typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + int m = static_cast(M); + int n = static_cast(N); + int k = static_cast(K); + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1)); + + typename T::Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {// Mainloop arguments + static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + { // Epilogue arguments + {}, // epilogue.thread + nullptr, + stride_D, + static_cast(D.data_ptr()), + stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + using KernelConfig = typename T::Config; + arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster; + arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster; + return arguments; +} + +template +void runGemm( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + typename T::Gemm gemm; + auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); + + size_t workspace_size = T::Gemm::get_workspace_size(arguments); + auto workspace_tensor = alloc_workspace_tensor(workspace_size, A.device()); + void* workspace = (workspace_size == 0) ? nullptr : workspace_tensor.data_ptr(); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); +} + +template +void cutlassFp4GemmDispatchSm100( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + if (m <= 128) { + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (m <= 256) { + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +template <> +void cutlassFp4GemmDispatchSm100( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + runGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh new file mode 100644 index 000000000000..cdb159061eb9 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/gemm/nvfp4/nvfp4_scaled_mm_sm120.cuh @@ -0,0 +1,228 @@ +/* Copyright 2026 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "nvfp4_scaled_mm_common.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +struct sm120_fp4_config_small_m { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _256>; + using PerSmTileShape_MNK = Shape<_128, _128, _256>; +}; + +struct sm120_fp4_config_M256 { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _128>; + using PerSmTileShape_MNK = Shape<_128, _128, _128>; +}; + +struct sm120_fp4_config_default { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_256, _128, _128>; + using PerSmTileShape_MNK = Shape<_256, _128, _128>; +}; + +template +struct Fp4GemmSm120 { + using ElementA = cutlass::nv_float4_t; + using LayoutATag = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 32; + + using ElementB = cutlass::nv_float4_t; + using LayoutBTag = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 32; + + using ElementD = OutType; + using ElementC = OutType; + using LayoutCTag = cutlass::layout::RowMajor; + using LayoutDTag = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm120; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + + using MmaTileShape = typename Config::MmaTileShape; + using ClusterShape = typename Config::ClusterShape; + using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + PerSmTileShape_MNK, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + void, + LayoutCTag, + AlignmentC, + ElementD, + LayoutDTag, + AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutATag, + AlignmentA, + ElementB, + LayoutBTag, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments args_from_options_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int M, + int N, + int K) { + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementD = typename Gemm::ElementD; + using ElementSFA = cutlass::float_ue4m3_t; + using ElementSFB = cutlass::float_ue4m3_t; + using ElementCompute = float; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {static_cast(A.data_ptr()), + stride_A, + static_cast(B.data_ptr()), + stride_B, + static_cast(A_sf.data_ptr()), + layout_SFA, + static_cast(B_sf.data_ptr()), + layout_SFB}, + {{}, nullptr, stride_D, static_cast(D.data_ptr()), stride_D}}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + + return arguments; +} + +template +void runGemmSm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int M, + int N, + int K, + cudaStream_t stream) { + Gemm gemm; + + auto arguments = args_from_options_sm120(D, A, B, A_sf, B_sf, alpha, M, N, K); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + auto workspace_tensor = alloc_workspace_tensor(workspace_size, A.device()); + void* workspace = (workspace_size == 0) ? nullptr : workspace_tensor.data_ptr(); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace, stream)); + + CUTLASS_CHECK(gemm.run(arguments, workspace, stream)); +} + +void cutlass_fp4_bf16_gemm_dispatch_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int m, + int n, + int k, + cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 32) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (mp2 <= 256) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +void cutlass_fp4_f16_gemm_dispatch_sm120( + tvm::ffi::TensorView D, + tvm::ffi::TensorView A, + tvm::ffi::TensorView B, + tvm::ffi::TensorView A_sf, + tvm::ffi::TensorView B_sf, + tvm::ffi::TensorView alpha, + int m, + int n, + int k, + cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + if (mp2 <= 32) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (mp2 <= 256) { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + runGemmSm120::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 1072dac7b3ca..477339a54fa6 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -150,7 +150,10 @@ def apply_weights( w = layer.weight_packed w_blockscale = layer.weight_scale - if enable_flashinfer_fp4_gemm: + if ( + enable_flashinfer_fp4_gemm + and not get_fp4_gemm_runner_backend().is_cutlass() + ): w = layer.weight_packed.T w_blockscale = layer.weight_scale.T diff --git a/python/sglang/srt/layers/quantization/fp4_utils.py b/python/sglang/srt/layers/quantization/fp4_utils.py index 3e913e137f02..33de4bc0e7ad 100644 --- a/python/sglang/srt/layers/quantization/fp4_utils.py +++ b/python/sglang/srt/layers/quantization/fp4_utils.py @@ -17,6 +17,7 @@ class Fp4GemmRunnerBackend(Enum): """Enum for FP4 GEMM runner backend selection.""" AUTO = "auto" + CUTLASS = "cutlass" FLASHINFER_CUDNN = "flashinfer_cudnn" FLASHINFER_CUTLASS = "flashinfer_cutlass" FLASHINFER_TRTLLM = "flashinfer_trtllm" @@ -24,6 +25,9 @@ class Fp4GemmRunnerBackend(Enum): def is_auto(self) -> bool: return self == Fp4GemmRunnerBackend.AUTO + def is_cutlass(self) -> bool: + return self == Fp4GemmRunnerBackend.CUTLASS + def is_flashinfer_cudnn(self) -> bool: return self == Fp4GemmRunnerBackend.FLASHINFER_CUDNN @@ -33,6 +37,9 @@ def is_flashinfer_cutlass(self) -> bool: def is_flashinfer_trtllm(self) -> bool: return self == Fp4GemmRunnerBackend.FLASHINFER_TRTLLM + def is_flashinfer(self) -> bool: + return self.value.startswith("flashinfer_") + def get_flashinfer_backend(self) -> str: """Get the backend string to pass to FlashInfer's mm_fp4 API. diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index dc4fa71fcd20..0bd87934625c 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -86,13 +86,19 @@ enable_flashinfer_fp4_gemm = True except ImportError: - if is_cuda(): - from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as cutlass_fp4_gemm enable_flashinfer_fp4_gemm = False reorder_rows_for_gated_act_gemm = None shuffle_matrix_a = None shuffle_matrix_sf_a = None +if is_cuda(): + try: + from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as cutlass_fp4_gemm + except ImportError: + cutlass_fp4_gemm = None +else: + cutlass_fp4_gemm = None + try: from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe from flashinfer.fused_moe.core import ActivationType @@ -134,7 +140,15 @@ def fp4_gemm( out_features: int, ) -> torch.Tensor: fp4_backend = get_fp4_gemm_runner_backend() - if enable_flashinfer_fp4_gemm: + if fp4_backend.is_cutlass() and cutlass_fp4_gemm is not None: + # flashinfer.fp4_quantize returns scale factors as uint8 (e4m3fn bits + # stored in uint8 memory). The JIT kernel requires float8_e4m3fn dtype. + if input_sf.dtype != torch.float8_e4m3fn: + input_sf = input_sf.view(torch.float8_e4m3fn) + if weight_sf.dtype != torch.float8_e4m3fn: + weight_sf = weight_sf.view(torch.float8_e4m3fn) + return cutlass_fp4_gemm(input, weight, input_sf, weight_sf, alpha, out_dtype) + elif enable_flashinfer_fp4_gemm: # Use the remapping logic to convert SGLang backend names to FlashInfer API names backend = fp4_backend.get_flashinfer_backend() return flashinfer_fp4_gemm( @@ -1478,7 +1492,10 @@ def apply( w = layer.weight w_scale_interleaved = layer.weight_scale_interleaved - if enable_flashinfer_fp4_gemm: + if ( + enable_flashinfer_fp4_gemm + and not get_fp4_gemm_runner_backend().is_cutlass() + ): w = layer.weight.T w_scale_interleaved = layer.weight_scale_interleaved.T diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a0f0704f3469..a071fdae1810 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -214,6 +214,7 @@ FP4_GEMM_RUNNER_BACKEND_CHOICES = [ "auto", + "cutlass", "flashinfer_cudnn", "flashinfer_cutlass", "flashinfer_trtllm", @@ -4619,7 +4620,8 @@ def add_cli_args(parser: argparse.ArgumentParser): dest="fp4_gemm_runner_backend", help="Choose the runner backend for NVFP4 GEMM operations. " "Options: 'auto' (default; selects flashinfer_cudnn on SM120, flashinfer_cutlass otherwise), " - "'flashinfer_cutlass' (CUTLASS backend), " + "'cutlass' (SGLang CUTLASS kernel), " + "'flashinfer_cutlass' (FlashInfer CUTLASS backend), " "'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), " "'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). " "NOTE: This replaces the deprecated environment variable " diff --git a/sgl-kernel/benchmark/bench_fp4_gemm.py b/sgl-kernel/benchmark/bench_fp4_gemm.py index f8f0bd666a21..0f1023af8fd6 100755 --- a/sgl-kernel/benchmark/bench_fp4_gemm.py +++ b/sgl-kernel/benchmark/bench_fp4_gemm.py @@ -1,13 +1,20 @@ import argparse import csv import os +from functools import partial +from typing import List, Tuple import torch import triton from flashinfer import mm_fp4 +from flashinfer.testing import bench_gpu_time from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant -from sglang.srt.utils import get_device_capability, is_sm100_supported +from sglang.srt.utils import ( + get_device_capability, + is_sm100_supported, + is_sm120_supported, +) from sglang.utils import is_in_ci IS_CI = is_in_ci() @@ -15,30 +22,102 @@ FLOAT4_E2M1_MAX = 6.0 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +DEEPSEEK_R1_MODEL = "deepseek-ai/DeepSeek-R1-0528-FP4" -def get_weight_shapes(args): - models_tps = args.tp_sizes +# Weight shapes are in the format: ([K, N], TP_SPLIT_DIM) +# TP split dim 0 means split K by tp size; dim 1 means split N by tp size. +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "Qwen/Qwen3.5-27B": [ + ([5120, 8192], 1), + ([6144, 5120], 0), + ([5120, 34816], 1), + ([17408, 5120], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} - if models_tps == [4]: - return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]] +DEEPSEEK_R1_WEIGHT_SHAPES = { + 4: [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]], + 8: [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]], +} - if models_tps == [8]: - return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]] - return [ - [1024, 3584], - [7168, 256], - [7168, 2304], - [9216, 3584], - [512, 3584], - [7168, 128], - [7168, 1152], - [4608, 3584], - ] + +def get_weight_shapes(args) -> List[Tuple[int, int, str]]: + shapes: List[Tuple[int, int, str]] = [] + for model in args.models: + if model == DEEPSEEK_R1_MODEL: + for tp_size in args.tp_sizes: + if tp_size in DEEPSEEK_R1_WEIGHT_SHAPES: + selected = DEEPSEEK_R1_WEIGHT_SHAPES[tp_size] + else: + selected = ( + DEEPSEEK_R1_WEIGHT_SHAPES[4] + DEEPSEEK_R1_WEIGHT_SHAPES[8] + ) + for n, packed_k in selected: + shapes.append((n, packed_k, model)) + continue + + if model not in WEIGHT_SHAPES: + raise ValueError(f"Unsupported model: {model}") + for tp_size in args.tp_sizes: + for k_n, tp_split_dim in WEIGHT_SHAPES[model]: + k, n = k_n + if tp_split_dim == 0: + k = k // tp_size + else: + n = n // tp_size + packed_k = k // 2 + shapes.append((n, packed_k, model)) + return shapes -# CI environment uses simplified parameters if IS_CI: - batch_sizes = [1, 8] # Simplified for CI + batch_sizes = [1, 8] else: batch_sizes = [ 1, @@ -60,29 +139,54 @@ def get_weight_shapes(args): ] +def _run_mm_fp4(a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, dtype, res_fi, backend): + return mm_fp4(a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, dtype, res_fi, backend=backend) + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], x_vals=batch_sizes, - # x_vals = [64], x_log=False, line_arg="provider", - line_vals=["sglang_cutlass", "cutlass", "cudnn", "trtllm", "auto"], - line_names=[ - "sglang cutlass fp4", - "flashinfer cutlass fp4", - "cudnn fp4", - "trtllm fp4", - "auto fp4 (cudnn/cutlass)", - ], - styles=[ - ("red", "solid"), - ("orange", "solid"), - ("blue", "solid"), - ("green", "solid"), - ("purple", "solid"), - ], - ylabel="latency (ms)", + line_vals=( + ["sglang_cutlass", "cutlass", "cudnn", "trtllm", "auto"] + if is_sm100_supported() + else ["sglang_cutlass", "cutlass", "cudnn", "auto"] + ), + line_names=( + [ + "sglang cutlass fp4", + "flashinfer cutlass fp4", + "cudnn fp4", + "trtllm fp4", + "auto fp4 (cudnn/cutlass)", + ] + if is_sm100_supported() + else [ + "sglang cutlass fp4", + "flashinfer cutlass fp4", + "cudnn fp4", + "auto fp4", + ] + ), + styles=( + [ + ("red", "solid"), + ("orange", "solid"), + ("blue", "solid"), + ("green", "solid"), + ("purple", "solid"), + ] + if is_sm100_supported() + else [ + ("red", "solid"), + ("orange", "solid"), + ("blue", "solid"), + ("purple", "solid"), + ] + ), + ylabel="bandwidth (GB/s)", plot_name="fp4_gemm_benchmark", args={}, ) @@ -99,87 +203,93 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): b_global_scale = ( (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) ).to(torch.float32) - alpha = 1.0 / (a_global_scale * b_global_scale) a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale) - # print("a_fp4", a_fp4) b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale) + b_fp4_T = b_fp4.T + b_sf_T = b_scale_interleaved.T res_fi = torch.empty((M, N), dtype=dtype, device="cuda") - quantiles = [0.5, 0.2, 0.8] if provider == "sglang_cutlass": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: cutlass_scaled_fp4_mm( - a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype - ), - quantiles=quantiles, - ) - if provider == "cutlass": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + times_ms = bench_gpu_time( + fn=cutlass_scaled_fp4_mm, + input_args=( a_fp4, - b_fp4.T, + b_fp4, a_scale_interleaved, - b_scale_interleaved.T, + b_scale_interleaved, alpha, dtype, - res_fi, - backend="cutlass", ), - quantiles=quantiles, + use_cuda_graph=True, ) - if provider == "cudnn": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + elif provider == "cutlass": + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="cutlass"), + input_args=( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, - backend="cudnn", ), - quantiles=quantiles, + use_cuda_graph=True, ) - if provider == "trtllm": - a_scale_interleaved = a_scale_interleaved.to(torch.uint8) - b_scale_interleaved = b_scale_interleaved.to(torch.uint8) - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + elif provider == "cudnn": + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="cudnn"), + input_args=( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, - backend="trtllm", ), - quantiles=quantiles, + use_cuda_graph=True, ) - if provider == "auto": - ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: mm_fp4( + elif provider == "trtllm": + a_sf_u8 = a_scale_interleaved.to(torch.uint8) + b_sf_u8_T = b_sf_T.to(torch.uint8) + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="trtllm"), + input_args=(a_fp4, b_fp4_T, a_sf_u8, b_sf_u8_T, alpha, dtype, res_fi), + use_cuda_graph=True, + ) + elif provider == "auto": + times_ms = bench_gpu_time( + fn=partial(_run_mm_fp4, backend="auto"), + input_args=( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, ), - quantiles=quantiles, + use_cuda_graph=True, ) + + ms = torch.tensor(times_ms).median().item() + + # A: M×packed_k bytes (fp4 packed), B: N×packed_k bytes, C: M×N×element_size bytes + element_size = torch.finfo(dtype).bits // 8 + total_bytes = M * packed_k + N * packed_k + M * N * element_size + bandwidth_gbs = total_bytes / (ms * 1e-3) / 1e9 + if correctness: res_cutlass = cutlass_scaled_fp4_mm( a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype ) mm_fp4( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, @@ -190,9 +300,9 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): ), "cudnn fp4 doesn't match cutlass fp4" mm_fp4( a_fp4, - b_fp4.T, + b_fp4_T, a_scale_interleaved, - b_scale_interleaved.T, + b_sf_T, alpha, dtype, res_fi, @@ -205,13 +315,20 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): if csv_file: with open(csv_file, "a", newline="") as f: writer = csv.writer(f) - writer.writerow([provider, M, N, K, ms]) + writer.writerow([provider, M, N, K, ms, bandwidth_gbs]) - return ms, min_ms, max_ms + return bandwidth_gbs if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=[DEEPSEEK_R1_MODEL], + help="List of models to benchmark. Supported: Llama 8B/70B, Qwen, Mistral, DeepSeek.", + ) parser.add_argument( "--tp-sizes", nargs="+", @@ -223,7 +340,7 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): "--dtype", type=torch.dtype, default=torch.bfloat16, - help="Data type", + help="Output data type", ) parser.add_argument( "--correctness", @@ -238,34 +355,29 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file): ) args = parser.parse_args() - # Simplify for CI environment if IS_CI: - args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size + args.tp_sizes = [args.tp_sizes[0]] if args.csv: with open(args.csv, "w", newline="") as f: writer = csv.writer(f) - writer.writerow(["provider", "m", "n", "k", "time_ms"]) + writer.writerow(["provider", "m", "n", "k", "time_ms", "bandwidth_gbs"]) - # FP4 operations require Blackwell SM100 support major, minor = get_device_capability() - if not is_sm100_supported(): + if not (is_sm100_supported() or is_sm120_supported()): print("Skipping FP4 GEMM benchmark") if major is not None: - print( - f"FP4 operations require SM100 (Blackwell), but found sm{major}{minor}" - ) + print(f"FP4 operations require sm100+, but found sm{major}{minor}") else: print("Could not determine device capability") else: NKs = get_weight_shapes(args) - # Limit iterations in CI if IS_CI: - NKs = NKs[:2] # Only test first 2 shapes in CI + NKs = NKs[:2] - for N, K in NKs: - print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ") + for N, K, model_name in NKs: + print(f"{model_name} N={N} packed_k={K}: ") benchmark.run( print_data=True, N=N, diff --git a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py b/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py deleted file mode 100644 index eeb5842edec2..000000000000 --- a/sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py +++ /dev/null @@ -1,192 +0,0 @@ -import argparse -import copy -import itertools -import os - -import torch -import triton - -from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant -from sglang.srt.utils import get_device_capability - -# CI environment detection -IS_CI = ( - os.getenv("CI", "false").lower() == "true" - or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" -) - -FLOAT4_E2M1_MAX = 6.0 -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - -# Weight Shapes are in the format -# ([K, N], TP_SPLIT_DIM) -# Example: -# A shape of ([14336, 4096], 0) indicates the following GEMM shape, -# - TP1 : K = 14336, N = 4096 -# - TP2 : K = 7168, N = 4096 -# A shape of ([4096, 6144], 1) indicates the following GEMM shape, -# - TP1 : K = 4096, N = 6144 -# - TP4 : K = 4096, N = 1536 - -# TP1 shapes -WEIGHT_SHAPES = { - "meta-llama/Llama-3.1-8B-Instruct": [ - ([4096, 6144], 1), - ([4096, 4096], 0), - ([4096, 28672], 1), - ([14336, 4096], 0), - ], - "meta-llama/Llama-3.3-70B-Instruct": [ - ([8192, 10240], 1), - ([8192, 8192], 0), - ([8192, 57344], 1), - ([28672, 8192], 0), - ], - "mistralai/Mistral-Large-Instruct-2407": [ - ([12288, 14336], 1), - ([12288, 12288], 0), - ([12288, 57344], 1), - ([28672, 12288], 0), - ], - "Qwen/Qwen2.5-7B-Instruct": [ - ([3584, 4608], 1), - ([3584, 3584], 0), - ([3584, 37888], 1), - ([18944, 3584], 0), - ], - "Qwen/Qwen2.5-32B-Instruct": [ - ([5120, 7168], 1), - ([5120, 5120], 0), - ([5120, 55296], 1), - ([27648, 5120], 0), - ], - "Qwen/Qwen2.5-72B-Instruct": [ - ([8192, 10240], 1), - ([8192, 8192], 0), - ([8192, 59136], 1), - ([29568, 8192], 0), - ], - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ - ([2048, 3072], 1), - ([2048, 4096], 1), - ([2048, 2048], 0), - ([2048, 576], 0), - ([2048, 21888], 1), - ([10944, 2048], 0), - ([2048, 2816], 1), - ([1408, 2048], 0), - ], -} - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size"], - x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], - x_log=False, - line_arg="provider", - line_vals=[ - "sglang-fp4-fp16", - "sglang-fp4-bf16", - ], - line_names=[ - "sglang-fp4-fp16", - "sglang-fp4-bf16", - ], - styles=[("green", "-"), ("blue", "-")], - ylabel="TFLOPS", - plot_name="fp4 block scaled matmul", - args={}, - ) -) -def benchmark(batch_size, provider, N, K): - # M, N, K = batch_size, 4096, 8192 - run_step = 100 - dtype = torch.float16 if "fp16" in provider else torch.bfloat16 - M = batch_size - a = torch.randn((M, K), dtype=dtype, device="cuda") - b = torch.randn((N, K), dtype=dtype, device="cuda") - a_global_scale = ( - (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) - ).to(torch.float32) - b_global_scale = ( - (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b.flatten(), dim=-1) - ).to(torch.float32) - alpha = 1.0 / (a_global_scale * b_global_scale) - a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale) - b_fp4, b_scale_interleaved = scaled_fp4_quant(b, b_global_scale) - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # Bridging the gap between CPU and GPU - for _ in range(25): - c = a @ b.t() - # Warmup - for _ in range(5): - cutlass_scaled_fp4_mm( - a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype - ) - start_event.record() - for _ in range(run_step): - cutlass_scaled_fp4_mm( - a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype - ) - end_event.record() - end_event.synchronize() - torch.cuda.synchronize() - ms = start_event.elapsed_time(end_event) / run_step - - tflops = lambda ms: (2 * M * N * K) * 1e-9 / ms - return tflops(ms) - - -def prepare_shapes(args): - KN_model_names = [] - models_tps = list(itertools.product(args.models, args.tp_sizes)) - for model, tp_size in models_tps: - assert model in WEIGHT_SHAPES - for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): - KN[tp_split_dim] = KN[tp_split_dim] // tp_size - KN.append(model) - KN_model_names.append(KN) - return KN_model_names - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--models", - nargs="+", - type=str, - default=["meta-llama/Llama-3.1-8B-Instruct"], - help="List of models to benchmark", - ) - parser.add_argument( - "--tp-sizes", - nargs="+", - type=int, - default=[1], - help="List of tensor parallel sizes", - ) - args = parser.parse_args() - - # Check architecture compatibility - FP4 operations require sm100a/sm103a - major, minor = get_device_capability() - if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a) - print("Skipping NVIDIA FP4 scaled GEMM benchmark") - if major is not None: - print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}") - else: - print("Could not determine device capability") - else: - KN_model_names = prepare_shapes(args) - - # Limit iterations in CI - if IS_CI: - KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI - - for K, N, model_name in KN_model_names: - print(f"{model_name} N={N} K={K}: ") - benchmark.run(print_data=True, N=N, K=K) - print("Benchmark finished!")