From 3871b93d55f765df184f854a46e770a933e801a2 Mon Sep 17 00:00:00 2001 From: Vivian Chen <140748220+xuanzic@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:33:40 -0800 Subject: [PATCH 1/3] add swapab linear gemm binding --- csrc/fp8_blockscale_gemm_sm90_binding.cu | 231 +++++++++++++++ flashinfer/gemm/__init__.py | 8 + flashinfer/gemm/gemm_base.py | 260 +++++++++++++++++ flashinfer/jit/gemm/__init__.py | 2 + flashinfer/jit/gemm/fp8_blockscale.py | 57 ++++ tests/gemm/test_fp8_blockscale_gemm.py | 349 +++++++++++++++++++++++ 6 files changed, 907 insertions(+) create mode 100755 csrc/fp8_blockscale_gemm_sm90_binding.cu create mode 100755 flashinfer/jit/gemm/fp8_blockscale.py create mode 100755 tests/gemm/test_fp8_blockscale_gemm.py diff --git a/csrc/fp8_blockscale_gemm_sm90_binding.cu b/csrc/fp8_blockscale_gemm_sm90_binding.cu new file mode 100755 index 0000000000..1a24ed624c --- /dev/null +++ b/csrc/fp8_blockscale_gemm_sm90_binding.cu @@ -0,0 +1,231 @@ + +#include +#include "tvm_ffi_utils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" + +#include +#include +#include +#include +#include + +namespace kernels = tensorrt_llm::kernels::fp8_blockscale_gemm; + +using tvm::ffi::Function; +using tvm::ffi::Optional; + +/** + * @brief TVM FFI wrapper for TensorRT-LLM's FP8 block-scale GEMM kernel. + * + * This class provides a Python-accessible interface to the CUTLASS-based FP8 GEMM + * implementation with block-wise quantization. It supports two execution modes: + * + * 1. Pre-quantized FP8: Both inputs are FP8 with external scale factors + * 2. Internal quantization: BF16 inputs are quantized to FP8 internally + * + * The kernel automatically selects between normal and swapAB execution based on + * the M dimension for optimal performance. + * + * @note Requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+ + */ +class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { + public: + /** + * @brief Constructor initializes the CUTLASS FP8 GEMM runner. + * + * Template parameters allow the kernel to: + * - Accept FP8 inputs with external scales (pre-quantized path) + * - Accept BF16 inputs for internal quantization to FP8 + * - Produce BF16 output (accumulation happens in FP32 internally) + */ + Fp8BlockScaleGemmRunner() { + runner_ = std::make_unique>(); + } + + ~Fp8BlockScaleGemmRunner() = default; + + const char* type_key() const { + return "flashinfer.Fp8BlockScaleGemmRunner"; + } + + const char* kind() const final { + return "fp8_blockscale_gemm_runner"; + } + + Optional GetFunction(const tvm::ffi::String& name) { + if (name == "gemm") { + return Function::FromTyped( + [this](TensorView input, TensorView weight, TensorView output, + Optional scales_a, Optional scales_b) { + runGemm(input, weight, output, scales_a, scales_b); + }); + } else if (name == "get_workspace_size") { + return Function::FromTyped( + [this](int64_t shape_m, int64_t shape_n, int64_t shape_k) -> int64_t { + return getWorkspaceSize(shape_m, shape_n, shape_k); + }); + } else if (name == "configure_workspace") { + return Function::FromTyped( + [this](TensorView workspace) { + configureWorkspace(workspace); + }); + } + return Function(nullptr); + } + + private: + void runGemm(const TensorView& input, const TensorView& weight, const TensorView& output, + const Optional& scales_a, const Optional& scales_b) { + // Get CUDA stream from TVM runtime + auto stream = get_stream(input.device()); + + // Extract tensor info + auto input_ptr = input.data_ptr(); + auto weight_ptr = weight.data_ptr(); + auto output_ptr = output.data_ptr(); + + // Validate tensor dimensions + TVM_FFI_ICHECK(input.ndim() == 2) << "Input must be 2D (M, K), got " << input.ndim() << "D"; + TVM_FFI_ICHECK(weight.ndim() == 2) << "Weight must be 2D (N, K), got " << weight.ndim() << "D"; + TVM_FFI_ICHECK(output.ndim() == 2) << "Output must be 2D (M, N), got " << output.ndim() << "D"; + + int shape_m = input.size(0); + int shape_k = input.size(1); + int shape_n = weight.size(0); + + TVM_FFI_ICHECK_EQ(weight.size(1), shape_k) + << "Weight K dimension must match input K. Expected " << shape_k + << ", got " << weight.size(1); + TVM_FFI_ICHECK_EQ(output.size(0), shape_m) + << "Output M dimension must match input M. Expected " << shape_m + << ", got " << output.size(0); + TVM_FFI_ICHECK_EQ(output.size(1), shape_n) + << "Output N dimension must match weight N. Expected " << shape_n + << ", got " << output.size(1); + + // Validate K is divisible by block size (128) + constexpr int BLOCK_SIZE = 128; + TVM_FFI_ICHECK_EQ(shape_k % BLOCK_SIZE, 0) + << "K dimension must be divisible by block size (" << BLOCK_SIZE + << "), got K=" << shape_k; + + // Validate workspace is configured + TVM_FFI_ICHECK(workspace_ != nullptr) + << "Workspace not configured. Call configure_workspace() before gemm()"; + + // Get scales if provided + float const* scales_a_ptr = nullptr; + float const* scales_b_ptr = nullptr; + + if (scales_a.has_value()) { + const auto& scale_a_view = scales_a.value(); + TVM_FFI_ICHECK_EQ(scale_a_view.dtype().code, kDLFloat) + << "input_scale must be float32"; + TVM_FFI_ICHECK_EQ(scale_a_view.dtype().bits, 32) + << "input_scale must be float32"; + TVM_FFI_ICHECK_EQ(scale_a_view.ndim(), 2) + << "input_scale must be 2D (M, K//128)"; + TVM_FFI_ICHECK_EQ(scale_a_view.size(0), shape_m) + << "input_scale M dimension mismatch. Expected " << shape_m + << ", got " << scale_a_view.size(0); + TVM_FFI_ICHECK_EQ(scale_a_view.size(1), shape_k / BLOCK_SIZE) + << "input_scale K dimension mismatch. Expected " << (shape_k / BLOCK_SIZE) + << ", got " << scale_a_view.size(1); + scales_a_ptr = reinterpret_cast(scale_a_view.data_ptr()); + } + + if (scales_b.has_value()) { + const auto& scale_b_view = scales_b.value(); + TVM_FFI_ICHECK_EQ(scale_b_view.dtype().code, kDLFloat) + << "weight_scale must be float32"; + TVM_FFI_ICHECK_EQ(scale_b_view.dtype().bits, 32) + << "weight_scale must be float32"; + TVM_FFI_ICHECK_EQ(scale_b_view.ndim(), 2) + << "weight_scale must be 2D (N, K//128)"; + TVM_FFI_ICHECK_EQ(scale_b_view.size(0), shape_n) + << "weight_scale N dimension mismatch. Expected " << shape_n + << ", got " << scale_b_view.size(0); + TVM_FFI_ICHECK_EQ(scale_b_view.size(1), shape_k / BLOCK_SIZE) + << "weight_scale K dimension mismatch. Expected " << (shape_k / BLOCK_SIZE) + << ", got " << scale_b_view.size(1); + scales_b_ptr = reinterpret_cast(scale_b_view.data_ptr()); + } + + // Check input types - FP8 uses special dtype codes + bool input_is_fp8 = (input.dtype().code == kDLFloat8_e4m3fn || + input.dtype().code == kDLFloat8_e5m2); + bool weight_is_fp8 = (weight.dtype().code == kDLFloat8_e4m3fn || + weight.dtype().code == kDLFloat8_e5m2); + + // Dispatch to appropriate kernel path + if (input_is_fp8 && weight_is_fp8) { + // Path 1: Both inputs are FP8 - use pre-quantized FP8 GEMM + TVM_FFI_ICHECK(scales_a_ptr != nullptr && scales_b_ptr != nullptr) + << "FP8 inputs require scale factors. Provide both input_scale and weight_scale."; + + // Validate output dtype is BF16 + TVM_FFI_ICHECK_EQ(output.dtype().code, kDLBfloat) + << "Output must be BF16 for FP8 inputs"; + TVM_FFI_ICHECK_EQ(output.dtype().bits, 16) + << "Output must be BF16 for FP8 inputs"; + + auto input_fp8 = reinterpret_cast<__nv_fp8_e4m3 const*>(input_ptr); + auto weight_fp8 = reinterpret_cast<__nv_fp8_e4m3 const*>(weight_ptr); + auto output_bf16 = reinterpret_cast<__nv_bfloat16*>(output_ptr); + + int ld_a = shape_k; // Leading dimension for row-major input + int ld_b = shape_k; // Leading dimension for row-major weight + int ld_d = shape_n; // Leading dimension for row-major output + + runner_->gemm(input_fp8, ld_a, weight_fp8, ld_b, output_bf16, ld_d, + shape_m, shape_n, shape_k, scales_a_ptr, scales_b_ptr, stream); + } else if (!input_is_fp8 && !weight_is_fp8) { + // Path 2: Both inputs are BF16 - use internal quantization + TVM_FFI_ICHECK(scales_a_ptr == nullptr && scales_b_ptr == nullptr) + << "BF16 inputs use internal quantization. Do not provide scales. " + << "For external scales, use FP8 inputs."; + + // Validate input/weight dtypes are BF16 + TVM_FFI_ICHECK_EQ(input.dtype().code, kDLBfloat) + << "Input must be BF16 for internal quantization path"; + TVM_FFI_ICHECK_EQ(input.dtype().bits, 16) + << "Input must be BF16 for internal quantization path"; + TVM_FFI_ICHECK_EQ(weight.dtype().code, kDLBfloat) + << "Weight must be BF16 for internal quantization path"; + TVM_FFI_ICHECK_EQ(weight.dtype().bits, 16) + << "Weight must be BF16 for internal quantization path"; + + // Call internal quantization path (note: different argument order!) + runner_->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, + stream, scales_a_ptr, scales_b_ptr); + } else { + // Path 3: Mixed dtypes - not supported + TVM_FFI_ICHECK(false) + << "Mixed FP8/BF16 inputs not supported. Both input and weight must be " + << "either FP8 (with scales) or BF16 (internal quantization)."; + } + } + + int64_t getWorkspaceSize(int64_t shape_m, int64_t shape_n, int64_t shape_k) { + // Use getWorkspaceSizeBase to ensure internal state is properly initialized + // This is critical for BF16 internal quantization path + // num_problems=1 for single GEMM, top_k=1 for regular gemm (not MOE) + size_t workspace_size = runner_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, /*num_problems=*/1); + return workspace_size; + } + + void configureWorkspace(const TensorView& workspace) { + auto workspace_ptr = reinterpret_cast(workspace.data_ptr()); + runner_->configureWorkspace(workspace_ptr); + } + + std::unique_ptr runner_; +}; + +tvm::ffi::Module init() { + auto ptr = tvm::ffi::make_object(); + return tvm::ffi::Module(ptr); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init); diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index 15652268ba..545536424a 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -13,6 +13,12 @@ from .gemm_base import gemm_fp8_nt_blockscaled as gemm_fp8_nt_blockscaled from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise +from .gemm_base import ( + get_fp8_blockscale_gemm_runner as get_fp8_blockscale_gemm_runner, +) +from .gemm_base import ( + fp8_blockscale_gemm_swapab as fp8_blockscale_gemm_swapab, +) from .routergemm_dsv3 import ( mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256, @@ -30,5 +36,7 @@ "gemm_fp8_nt_blockscaled", "gemm_fp8_nt_groupwise", "group_gemm_fp8_nt_groupwise", + "get_fp8_blockscale_gemm_runner", + "fp8_blockscale_gemm_swapab", "mm_M1_16_K7168_N256", ] diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ac0fbab4a0..b1a8b764c3 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -54,6 +54,7 @@ from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module from ..jit.gemm import gen_deepgemm_sm100_module +from ..jit.gemm import gen_fp8_blockscale_gemm_sm90_module CUDNN_AVAILABLE = False @@ -3111,3 +3112,262 @@ def batch_deepgemm_fp8_nt_groupwise( ) return out + +@functools.cache +def get_fp8_blockscale_gemm_runner(): + """Get the FP8 block scale GEMM runner module for SM90.""" + module = gen_fp8_blockscale_gemm_sm90_module().build_and_load() + return module.init() + + +def fp8_blockscale_gemm_swapab( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + weight_scale: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Perform FP8 block-scaled GEMM with automatic swapAB optimization. + This function automatically selects between normal and swapAB kernel based on + the M dimension. For small M (< 32), it uses the swapAB kernel for + better performance. + The computation is: output = input @ weight.T with per-block FP8 quantization + and scaling. + + Supported Dtype Combinations + ----------------------------- + - **BF16 + BF16 → BF16**: Both inputs BF16, internal quantization (no scales needed) + - **BF16 + FP8 → BF16**: BF16 input, FP8 weight + + Note: + - FP16 is NOT supported. + - FP8 + BF16 is NOT supported (missing kernel implementation) + + Parameters + ---------- + input : torch.Tensor + Input activation tensor of shape (M, K). + - BF16 (torch.bfloat16) with internal quantization + weight : torch.Tensor + Weight tensor of shape (N, K). Can be: + - FP8 (torch.float8_e4m3fn) with weight_scale required + - BF16 (torch.bfloat16) for internal quantization + input_scale : torch.Tensor, optional + Not used. Input is always BF16 with internal quantization. + weight_scale : torch.Tensor, optional + Scaling factors for weight. Required if weight is FP8. + Supports TWO granularities: + - Per-token (1x128 blocks): shape (N, K // 128) + - Per-block (128x128 blocks): shape (N // 128, K // 128) + out : torch.Tensor, optional + Output tensor of shape (M, N). If None, will be allocated. + out_dtype : torch.dtype, optional + Output data type. Default is torch.bfloat16. + Returns + ------- + torch.Tensor + Output tensor of shape (M, N) with dtype `out_dtype`. + Examples + -------- + >>> import torch + >>> from flashinfer.gemm import fp8_blockscale_gemm_swapab + >>> + >>> M, N, K = 16, 4096, 4096 + >>> device = "cuda" + >>> + >>> # BF16 inputs + >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_bf16) + >>> print(output.shape) # torch.Size([16, 4096]) + >>> + >>> # Mixed: BF16 input + FP8 weight + >>> from flashinfer.testing.utils import per_token_cast_to_fp8 + >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + >>> weight_fp8, weight_scale = per_token_cast_to_fp8(weight_bf16) + >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) + >>> print(output.shape) # torch.Size([16, 4096]) + >>> + >>> # FP8 weight with 128x128 block scales + >>> from flashinfer.testing.utils import per_block_cast_to_fp8 + >>> weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16) + >>> weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16) + >>> # weight_scale has shape (N // 128, K // 128) + >>> input_bf16 = torch.randn(M, K, device=device, dtype=torch.bfloat16) + >>> output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) + >>> print(output.shape) # torch.Size([16, 4096]) + Notes + ----- + - This function requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+ + - SwapAB kernel is automatically used when M < 32 (threshold) + - For FP8 inputs, scaling factors must be provided + - For BF16 inputs, quantization and scaling happen internally + - Weight scales support two granularities: + * Per-token (1x128 blocks): (N, K//128) + * Per-block (128x128 blocks): (N//128, K//128) + - Input scales only support per-token format: (M, K//128) + - The function uses DeepGEMM backend with JIT compilation + """ + # Validate architecture support + if not _match_sm_version(input.device, ["90", "90a"]): + raise ValueError( + "fp8_blockscale_gemm_swapab is only supported on SM90 (Hopper) architecture." + ) + + # Validate tensor dimensions + if input.ndim != 2: + raise ValueError(f"Input must be 2D (M, K), got shape {input.shape}") + if weight.ndim != 2: + raise ValueError(f"Weight must be 2D (N, K), got shape {weight.shape}") + + # Get dimensions + M, K = input.shape + N, K_weight = weight.shape + + if K != K_weight: + raise ValueError( + f"K dimension mismatch: input has K={K}, weight has K={K_weight}" + ) + + # Validate K is divisible by block size (128) + BLOCK_SIZE = 128 + if K % BLOCK_SIZE != 0: + raise ValueError( + f"K dimension must be divisible by block size ({BLOCK_SIZE}), got K={K}" + ) + + # Validate device consistency + if input.device != weight.device: + raise ValueError( + f"Input and weight must be on the same device. " + f"Got input: {input.device}, weight: {weight.device}" + ) + + # Validate dtype combinations + input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + input_is_bf16 = input.dtype == torch.bfloat16 + weight_is_bf16 = weight.dtype == torch.bfloat16 + + # Explicitly reject FP8 input + BF16 weight (missing kernel implementation) + if input_is_fp8 and weight_is_bf16: + raise ValueError( + "FP8 input + BF16 weight is not supported (missing kernel implementation). " + ) + + # Validate scale requirements for FP8 inputs + if input_is_fp8: + if input_scale is None: + raise ValueError( + "input_scale is required when input is FP8. " + "For BF16 inputs, omit input_scale for internal quantization." + ) + # Validate scale shape: (M, K // BLOCK_SIZE) + expected_scale_shape = (M, K // BLOCK_SIZE) + if input_scale.shape != expected_scale_shape: + raise ValueError( + f"input_scale shape mismatch. Expected {expected_scale_shape}, " + f"got {input_scale.shape}" + ) + if input_scale.dtype != torch.float32: + raise ValueError( + f"input_scale must be float32, got {input_scale.dtype}" + ) + if input_scale.device != input.device: + raise ValueError( + f"input_scale device mismatch. Expected {input.device}, " + f"got {input_scale.device}" + ) + else: + if not input_is_bf16: + raise ValueError( + f"Input must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " + f"got {input.dtype}" + ) + if input_scale is not None: + raise ValueError( + "input_scale should not be provided for BF16 inputs. " + "Use FP8 inputs if you want to provide external scales." + ) + + if weight_is_fp8: + if weight_scale is None: + raise ValueError( + "weight_scale is required when weight is FP8. " + "For BF16 weights, omit weight_scale for internal quantization." + ) + # Validate scale shape: supports (N, K // BLOCK_SIZE) or (N // BLOCK_SIZE, K // BLOCK_SIZE) + expected_per_token_shape = (N, K // BLOCK_SIZE) + expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE) + is_per_token = weight_scale.shape == expected_per_token_shape + is_per_block = weight_scale.shape == expected_per_block_shape + + if not (is_per_token or is_per_block): + raise ValueError( + f"weight_scale shape mismatch. Expected either {expected_per_token_shape} " + f"(per-token, 1x128 blocks) or {expected_per_block_shape} " + f"(per-block, 128x128 blocks), got {weight_scale.shape}" + ) + if weight_scale.dtype != torch.float32: + raise ValueError( + f"weight_scale must be float32, got {weight_scale.dtype}" + ) + if weight_scale.device != weight.device: + raise ValueError( + f"weight_scale device mismatch. Expected {weight.device}, " + f"got {weight_scale.device}" + ) + else: + if not weight_is_bf16: + raise ValueError( + f"Weight must be either FP8 (torch.float8_e4m3fn) or BF16 (torch.bfloat16), " + f"got {weight.dtype}" + ) + if weight_scale is not None: + raise ValueError( + "weight_scale should not be provided for BF16 weights. " + "Use FP8 weights if you want to provide external scales." + ) + + # Validate output tensor if provided + if out is not None: + if out.shape != (M, N): + raise ValueError( + f"Output shape mismatch. Expected ({M}, {N}), got {out.shape}" + ) + if out.device != input.device: + raise ValueError( + f"Output device mismatch. Expected {input.device}, got {out.device}" + ) + if out_dtype is not None and out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}" + ) + out_dtype = out.dtype + else: + # Allocate output + out_dtype = out_dtype or torch.bfloat16 + if out_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Output dtype must be torch.bfloat16 or torch.float16, got {out_dtype}" + ) + out = torch.empty(M, N, dtype=out_dtype, device=input.device) + + # Get the runner + runner = get_fp8_blockscale_gemm_runner() + + # Allocate workspace + workspace_size = runner.get_workspace_size(M, N, K) + workspace = None + if workspace_size > 0: + workspace = torch.empty( + workspace_size, dtype=torch.uint8, device=input.device + ) + runner.configure_workspace(workspace) + + runner.gemm(input, weight, out, input_scale, weight_scale) + + return out diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index f1681d3bf5..e81d51e15f 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -27,6 +27,7 @@ gen_gemm_sm90_module, ) from .deepgemm import gen_deepgemm_sm100_module +from .fp8_blockscale import gen_fp8_blockscale_gemm_sm90_module __all__ = [ "gen_gemm_module", @@ -40,4 +41,5 @@ "gen_tgv_gemm_sm10x_module", "gen_gemm_sm90_module", "gen_deepgemm_sm100_module", + "gen_fp8_blockscale_gemm_sm90_module", ] diff --git a/flashinfer/jit/gemm/fp8_blockscale.py b/flashinfer/jit/gemm/fp8_blockscale.py new file mode 100755 index 0000000000..70546c93fd --- /dev/null +++ b/flashinfer/jit/gemm/fp8_blockscale.py @@ -0,0 +1,57 @@ +from typing import List + +from .. import env as jit_env +from ..core import ( + JitSpec, + gen_jit_spec, + sm90a_nvcc_flags, +) +from ..cpp_ext import is_cuda_version_at_least + + +def gen_fp8_blockscale_gemm_sm90_module(use_fast_build: bool = False) -> JitSpec: + """Generate JIT spec for FP8 block scale GEMM on SM90 (Hopper).""" + nvcc_flags = sm90a_nvcc_flags + [ + "-DCOMPILE_HOPPER_TMA_GEMMS", + "-DENABLE_BF16", + "-DENABLE_FP8", + "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "", + ] + + return gen_jit_spec( + "fp8_blockscale_gemm_90", + [ + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu", + jit_env.FLASHINFER_CSRC_DIR / "fp8_blockscale_gemm_sm90_binding.cu", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu", + ], + extra_cuda_cflags=nvcc_flags, + extra_cflags=["-DFAST_BUILD"] if use_fast_build else [], + extra_ldflags=["-lnvrtc", "-lcuda"], + extra_include_paths=[ + jit_env.FLASHINFER_CSRC_DIR / "nv_internal", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "cutlass_extensions" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels" + / "include", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "cutlass_kernels", + ], + ) + diff --git a/tests/gemm/test_fp8_blockscale_gemm.py b/tests/gemm/test_fp8_blockscale_gemm.py new file mode 100755 index 0000000000..49ae459fa2 --- /dev/null +++ b/tests/gemm/test_fp8_blockscale_gemm.py @@ -0,0 +1,349 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +import torch +import torch.nn.functional as F + +import flashinfer +from flashinfer.gemm import fp8_blockscale_gemm_swapab +from flashinfer.testing.utils import per_token_cast_to_fp8, per_block_cast_to_fp8 +from flashinfer.utils import ( + get_compute_capability, + has_flashinfer_jit_cache, + is_sm90a_supported, +) +from flashinfer.jit.gemm import gen_fp8_blockscale_gemm_sm90_module + + +@pytest.fixture( + autouse=not has_flashinfer_jit_cache(), + scope="module", +) +def warmup_jit(): + """Warm up JIT compilation for FP8 block-scale GEMM if not cached.""" + if is_sm90a_supported(torch.device("cuda:0")): + jit_specs = [gen_fp8_blockscale_gemm_sm90_module()] + flashinfer.jit.build_jit_specs(jit_specs, verbose=False) + yield + + +@pytest.mark.parametrize("m", [1, 16, 32, 64, 128]) +@pytest.mark.parametrize("n", [128, 256, 512, 1024, 4096]) +@pytest.mark.parametrize("k", [256, 512, 1024, 4096]) +@pytest.mark.parametrize("input_dtype", [torch.bfloat16]) +@pytest.mark.parametrize("weight_dtype", [torch.bfloat16]) +def test_fp8_blockscale_gemm_swapab(m, n, k, input_dtype, weight_dtype): + """Test FP8 block-scale GEMM with swapAB optimization. + + This test focuses on the usage: BF16 inputs with internal quantization. + The kernel automatically handles FP8 quantization with proper block-scale computation. + """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + # K must be divisible by 128 (block size requirement) + if k % 128 != 0: + pytest.skip("K must be divisible by 128 for block-scale GEMM") + + device = "cuda" + torch.manual_seed(42) + + # Create BF16 inputs + input = torch.randn(m, k, device=device, dtype=input_dtype) + weight = torch.randn(n, k, device=device, dtype=weight_dtype) + + # Compute reference result + reference = torch.matmul(input, weight.T) + + # Run FP8 block-scale GEMM + output = fp8_blockscale_gemm_swapab(input, weight) + + # Verify output shape + assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}" + assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" + + # Check correctness using cosine similarity + cos_sim = F.cosine_similarity( + reference.flatten().float(), + output.flatten().float(), + dim=0 + ) + assert cos_sim > 0.99, f"Cosine similarity {cos_sim} is too low (expected > 0.99)" + + +@pytest.mark.parametrize("m", [1, 32, 128]) +@pytest.mark.parametrize("n", [1024, 4096]) +@pytest.mark.parametrize("k", [512, 4096]) +@pytest.mark.parametrize( + "input_dtype,weight_dtype", + [ + (torch.bfloat16, torch.bfloat16), # Both BF16 (for testing internal quantization) + (torch.bfloat16, torch.float8_e4m3fn), # BF16 input + FP8 weight + ] +) +def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype): + """Test the 2 recommended dtype combinations with proper FP8 quantization. + + Uses quantization from flashinfer.testing.utils: + - per_token_cast_to_fp8: 1x128 block quantization (for both input and weight) + + Note: Both input and weight use per_token (1x128 blocks). + The API expects scale shape (N, K//128), which per_token provides. + + These utilities return scales in the correct format (reciprocals) that + match TRT-LLM's kernel expectations. For kernel reference, + see csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh + """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + if k % 128 != 0: + pytest.skip("K must be divisible by 128 for block-scale GEMM") + + device = "cuda" + torch.manual_seed(42) + + # Create BF16 data for reference + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + # Quantize input + if input_dtype == torch.float8_e4m3fn: + input_tensor, input_scale = per_token_cast_to_fp8(input_bf16) + else: + input_tensor, input_scale = input_bf16, None + + # Quantize weight + # Note: Use per_token_cast_to_fp8 for weights too (1x128 blocks) + # This gives the correct scale shape (N, K//128) + if weight_dtype == torch.float8_e4m3fn: + weight_tensor, weight_scale = per_token_cast_to_fp8(weight_bf16) + else: + weight_tensor, weight_scale = weight_bf16, None + + # Compute reference + reference = torch.matmul(input_bf16, weight_bf16.T) + + # Run FP8 block-scale GEMM + output = fp8_blockscale_gemm_swapab( + input_tensor, weight_tensor, input_scale, weight_scale + ) + + # Verify output properties + assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}" + assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" + + # Check correctness using cosine similarity + cos_sim = F.cosine_similarity( + reference.flatten().float(), + output.flatten().float(), + dim=0 + ) + + if input_dtype == torch.bfloat16 and weight_dtype == torch.bfloat16: + threshold = 0.99 + else: + # BF16+FP8: BF16 input quantized internally, FP8 weight pre-quantized + # TODO: check threshold + threshold = 0.967 + + assert cos_sim > threshold, ( + f"Cosine similarity {cos_sim:.4f} too low for " + f"{input_dtype} + {weight_dtype} (expected > {threshold})" + ) + + +def test_fp8_blockscale_gemm_per_block_weight_scales(): + """Test BF16+FP8 GEMM with per-block (128x128) weight scales. + + This test demonstrates using 128x128 block quantization for weights with BF16 input, + """ + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + device = "cuda" + m, n, k = 16, 512, 512 # N and K must be divisible by 128 for per-block + torch.manual_seed(42) + + # Create BF16 data + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + # Quantize weight with per-block (128x128) blocks -> (N//128, K//128) + weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16) + + # Verify scale shape + assert weight_scale.shape == (n // 128, k // 128), f"Expected weight scale shape ({n // 128}, {k // 128}), got {weight_scale.shape}" + assert weight_scale.min() > 0, "Weight scale should be positive (reciprocal format)" + + # Run GEMM: BF16 input (internal quant) + FP8 weight (per-block scales) + output = fp8_blockscale_gemm_swapab(input_bf16, weight_fp8, None, weight_scale) + + # Compare to BF16 reference + reference = torch.matmul(input_bf16, weight_bf16.T) + + cos_sim = F.cosine_similarity( + reference.flatten().float(), + output.flatten().float(), + dim=0 + ) + # TODO: check threshold + assert cos_sim > 0.967, f"Per-block weight scale accuracy too low: {cos_sim:.4f}" + + print(f"✓ Per-block weight scales: cosine similarity = {cos_sim:.4f}") + + +@pytest.mark.parametrize("m,n,k", [ + (1, 4096, 4096), + (8, 4096, 4096), + (128, 4096, 4096), + (16, 8192, 8192), + (32, 2048, 4096), +]) +def test_fp8_blockscale_gemm_shapes(m, n, k): + """Test various common shapes used in LLM inference.""" + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + if k % 128 != 0: + pytest.skip("K must be divisible by 128") + + device = "cuda" + torch.manual_seed(42) + + input = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + reference = torch.matmul(input, weight.T) + output = fp8_blockscale_gemm_swapab(input, weight) + + cos_sim = F.cosine_similarity( + reference.flatten().float(), + output.flatten().float(), + dim=0 + ) + assert cos_sim > 0.99, f"Shape ({m}, {n}, {k}): cosine similarity {cos_sim} too low" + + +def test_fp8_blockscale_gemm_error_handling(): + """Test that proper errors are raised for invalid inputs.""" + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + device = "cuda" + m, n, k = 16, 256, 256 + + # Test: K not divisible by 128 + input = torch.randn(m, 127, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, 127, device=device, dtype=torch.bfloat16) + with pytest.raises(ValueError, match="divisible by block size"): + fp8_blockscale_gemm_swapab(input, weight) + + # Test: FP16 not supported + input = torch.randn(m, k, device=device, dtype=torch.float16) + weight = torch.randn(n, k, device=device, dtype=torch.float16) + with pytest.raises(ValueError, match="FP8.*or BF16"): + fp8_blockscale_gemm_swapab(input, weight) + + # Test: FP8 weight without scale (naive conversion) + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) + weight_fp8_naive = weight_bf16.to(torch.float8_e4m3fn) + with pytest.raises(ValueError, match="weight_scale is required when weight is FP8"): + fp8_blockscale_gemm_swapab(input_bf16, weight_fp8_naive, None, None) + + # Test: BF16 input with scale (should raise error) + input = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + fake_scale = torch.ones(m, k // 128, device=device, dtype=torch.float32) + with pytest.raises(ValueError, match="input_scale should not be provided for BF16"): + fp8_blockscale_gemm_swapab(input, weight, input_scale=fake_scale) + + # Test: Wrong scale shape for FP8 input + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + input_fp8, _ = per_token_cast_to_fp8(input_bf16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + wrong_scale = torch.ones(m, k // 64, device=device, dtype=torch.float32) + with pytest.raises(ValueError): + fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale=wrong_scale) + + # Test: FP8 input + BF16 weight is NOT supported (missing kernel implementation) + input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) + input_fp8, input_scale = per_token_cast_to_fp8(input_bf16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + with pytest.raises(ValueError, match="FP8 input.*BF16 weight.*not supported"): + fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale, None) + + +def test_fp8_blockscale_gemm_output_buffer(): + """Test providing pre-allocated output buffer.""" + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] < 9: + pytest.skip("FP8 block-scale GEMM requires SM90 (Hopper) or later") + + if not is_sm90a_supported(torch.device("cuda")): + pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") + + device = "cuda" + m, n, k = 16, 256, 256 + torch.manual_seed(42) + + input = torch.randn(m, k, device=device, dtype=torch.bfloat16) + weight = torch.randn(n, k, device=device, dtype=torch.bfloat16) + + # Pre-allocate output + output = torch.empty(m, n, device=device, dtype=torch.bfloat16) + + # Run GEMM with pre-allocated output + result = fp8_blockscale_gemm_swapab(input, weight, out=output) + + # Verify result is the same buffer + assert result is output + + # Verify correctness + reference = torch.matmul(input, weight.T) + cos_sim = F.cosine_similarity( + reference.flatten().float(), + output.flatten().float(), + dim=0 + ) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + From 1eeba0fa739187be296fe609c983d7a27c7622d6 Mon Sep 17 00:00:00 2001 From: Vivian Chen Date: Fri, 14 Nov 2025 21:46:22 -0800 Subject: [PATCH 2/3] fix binding --- csrc/fp8_blockscale_gemm_sm90_binding.cu | 235 ++++++++--------------- 1 file changed, 83 insertions(+), 152 deletions(-) diff --git a/csrc/fp8_blockscale_gemm_sm90_binding.cu b/csrc/fp8_blockscale_gemm_sm90_binding.cu index 1a24ed624c..fe3be954f0 100755 --- a/csrc/fp8_blockscale_gemm_sm90_binding.cu +++ b/csrc/fp8_blockscale_gemm_sm90_binding.cu @@ -13,45 +13,40 @@ namespace kernels = tensorrt_llm::kernels::fp8_blockscale_gemm; using tvm::ffi::Function; using tvm::ffi::Optional; +using tvm::ffi::TensorView; + +#ifdef FLASHINFER_ENABLE_FP8_E4M3 +inline bool is_fp8_e4m3fn(DLDataType dtype) { + return encode_dlpack_dtype(dtype) == float8_e4m3fn_code; +} +#else +inline bool is_fp8_e4m3fn(DLDataType) { return false; } +#endif /** - * @brief TVM FFI wrapper for TensorRT-LLM's FP8 block-scale GEMM kernel. - * - * This class provides a Python-accessible interface to the CUTLASS-based FP8 GEMM - * implementation with block-wise quantization. It supports two execution modes: + * @brief FP8 Block-Scale GEMM binding for SM90 * - * 1. Pre-quantized FP8: Both inputs are FP8 with external scale factors - * 2. Internal quantization: BF16 inputs are quantized to FP8 internally + * Supports: + * - BF16 + BF16 → BF16 + * - BF16 + FP8 → BF16 * - * The kernel automatically selects between normal and swapAB execution based on - * the M dimension for optimal performance. - * - * @note Requires NVIDIA Hopper (SM90) architecture and CUDA 12.8+ + * @note Output is BF16 */ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { public: - /** - * @brief Constructor initializes the CUTLASS FP8 GEMM runner. - * - * Template parameters allow the kernel to: - * - Accept FP8 inputs with external scales (pre-quantized path) - * - Accept BF16 inputs for internal quantization to FP8 - * - Produce BF16 output (accumulation happens in FP32 internally) - */ Fp8BlockScaleGemmRunner() { - runner_ = std::make_unique>(); + + runner_bf16_fp8_ = std::make_unique>(); } ~Fp8BlockScaleGemmRunner() = default; - const char* type_key() const { - return "flashinfer.Fp8BlockScaleGemmRunner"; - } - - const char* kind() const final { - return "fp8_blockscale_gemm_runner"; - } + const char* type_key() const { return "flashinfer.Fp8BlockScaleGemmRunner"; } + const char* kind() const final { return "fp8_blockscale_gemm_runner"; } Optional GetFunction(const tvm::ffi::String& name) { if (name == "gemm") { @@ -66,8 +61,7 @@ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { return getWorkspaceSize(shape_m, shape_n, shape_k); }); } else if (name == "configure_workspace") { - return Function::FromTyped( - [this](TensorView workspace) { + return Function::FromTyped([this](TensorView workspace) { configureWorkspace(workspace); }); } @@ -75,152 +69,89 @@ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { } private: + /** + * @brief Runtime dtype dispatch + */ + kernels::CutlassFp8BlockScaleGemmRunnerInterface* selectRunner( + bool input_is_fp8, bool weight_is_fp8) { + + if (!input_is_fp8 && !weight_is_fp8) { + return runner_bf16_bf16_.get(); + } else if (!input_is_fp8 && weight_is_fp8) { + return runner_bf16_fp8_.get(); + } else { + return nullptr; + } + } + void runGemm(const TensorView& input, const TensorView& weight, const TensorView& output, const Optional& scales_a, const Optional& scales_b) { - // Get CUDA stream from TVM runtime auto stream = get_stream(input.device()); // Extract tensor info auto input_ptr = input.data_ptr(); auto weight_ptr = weight.data_ptr(); auto output_ptr = output.data_ptr(); - - // Validate tensor dimensions - TVM_FFI_ICHECK(input.ndim() == 2) << "Input must be 2D (M, K), got " << input.ndim() << "D"; - TVM_FFI_ICHECK(weight.ndim() == 2) << "Weight must be 2D (N, K), got " << weight.ndim() << "D"; - TVM_FFI_ICHECK(output.ndim() == 2) << "Output must be 2D (M, N), got " << output.ndim() << "D"; - + int shape_m = input.size(0); int shape_k = input.size(1); int shape_n = weight.size(0); - - TVM_FFI_ICHECK_EQ(weight.size(1), shape_k) - << "Weight K dimension must match input K. Expected " << shape_k - << ", got " << weight.size(1); - TVM_FFI_ICHECK_EQ(output.size(0), shape_m) - << "Output M dimension must match input M. Expected " << shape_m - << ", got " << output.size(0); - TVM_FFI_ICHECK_EQ(output.size(1), shape_n) - << "Output N dimension must match weight N. Expected " << shape_n - << ", got " << output.size(1); - - // Validate K is divisible by block size (128) - constexpr int BLOCK_SIZE = 128; - TVM_FFI_ICHECK_EQ(shape_k % BLOCK_SIZE, 0) - << "K dimension must be divisible by block size (" << BLOCK_SIZE - << "), got K=" << shape_k; - - // Validate workspace is configured - TVM_FFI_ICHECK(workspace_ != nullptr) - << "Workspace not configured. Call configure_workspace() before gemm()"; - - // Get scales if provided - float const* scales_a_ptr = nullptr; - float const* scales_b_ptr = nullptr; - if (scales_a.has_value()) { - const auto& scale_a_view = scales_a.value(); - TVM_FFI_ICHECK_EQ(scale_a_view.dtype().code, kDLFloat) - << "input_scale must be float32"; - TVM_FFI_ICHECK_EQ(scale_a_view.dtype().bits, 32) - << "input_scale must be float32"; - TVM_FFI_ICHECK_EQ(scale_a_view.ndim(), 2) - << "input_scale must be 2D (M, K//128)"; - TVM_FFI_ICHECK_EQ(scale_a_view.size(0), shape_m) - << "input_scale M dimension mismatch. Expected " << shape_m - << ", got " << scale_a_view.size(0); - TVM_FFI_ICHECK_EQ(scale_a_view.size(1), shape_k / BLOCK_SIZE) - << "input_scale K dimension mismatch. Expected " << (shape_k / BLOCK_SIZE) - << ", got " << scale_a_view.size(1); - scales_a_ptr = reinterpret_cast(scale_a_view.data_ptr()); - } + // Sanity checks (defense against Python bugs) + TVM_FFI_ICHECK(input_ptr != nullptr) << "input is null"; + TVM_FFI_ICHECK(weight_ptr != nullptr) << "weight is null"; + TVM_FFI_ICHECK(output_ptr != nullptr) << "output is null"; + TVM_FFI_ICHECK(shape_k == weight.size(1)) << "K dimension mismatch"; - if (scales_b.has_value()) { - const auto& scale_b_view = scales_b.value(); - TVM_FFI_ICHECK_EQ(scale_b_view.dtype().code, kDLFloat) - << "weight_scale must be float32"; - TVM_FFI_ICHECK_EQ(scale_b_view.dtype().bits, 32) - << "weight_scale must be float32"; - TVM_FFI_ICHECK_EQ(scale_b_view.ndim(), 2) - << "weight_scale must be 2D (N, K//128)"; - TVM_FFI_ICHECK_EQ(scale_b_view.size(0), shape_n) - << "weight_scale N dimension mismatch. Expected " << shape_n - << ", got " << scale_b_view.size(0); - TVM_FFI_ICHECK_EQ(scale_b_view.size(1), shape_k / BLOCK_SIZE) - << "weight_scale K dimension mismatch. Expected " << (shape_k / BLOCK_SIZE) - << ", got " << scale_b_view.size(1); - scales_b_ptr = reinterpret_cast(scale_b_view.data_ptr()); - } - - // Check input types - FP8 uses special dtype codes - bool input_is_fp8 = (input.dtype().code == kDLFloat8_e4m3fn || - input.dtype().code == kDLFloat8_e5m2); - bool weight_is_fp8 = (weight.dtype().code == kDLFloat8_e4m3fn || - weight.dtype().code == kDLFloat8_e5m2); + // Determine dtypes for runner selection + bool input_is_fp8 = is_fp8_e4m3fn(input.dtype()); + bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype()); - // Dispatch to appropriate kernel path - if (input_is_fp8 && weight_is_fp8) { - // Path 1: Both inputs are FP8 - use pre-quantized FP8 GEMM - TVM_FFI_ICHECK(scales_a_ptr != nullptr && scales_b_ptr != nullptr) - << "FP8 inputs require scale factors. Provide both input_scale and weight_scale."; - - // Validate output dtype is BF16 - TVM_FFI_ICHECK_EQ(output.dtype().code, kDLBfloat) - << "Output must be BF16 for FP8 inputs"; - TVM_FFI_ICHECK_EQ(output.dtype().bits, 16) - << "Output must be BF16 for FP8 inputs"; - - auto input_fp8 = reinterpret_cast<__nv_fp8_e4m3 const*>(input_ptr); - auto weight_fp8 = reinterpret_cast<__nv_fp8_e4m3 const*>(weight_ptr); - auto output_bf16 = reinterpret_cast<__nv_bfloat16*>(output_ptr); - - int ld_a = shape_k; // Leading dimension for row-major input - int ld_b = shape_k; // Leading dimension for row-major weight - int ld_d = shape_n; // Leading dimension for row-major output - - runner_->gemm(input_fp8, ld_a, weight_fp8, ld_b, output_bf16, ld_d, - shape_m, shape_n, shape_k, scales_a_ptr, scales_b_ptr, stream); - } else if (!input_is_fp8 && !weight_is_fp8) { - // Path 2: Both inputs are BF16 - use internal quantization - TVM_FFI_ICHECK(scales_a_ptr == nullptr && scales_b_ptr == nullptr) - << "BF16 inputs use internal quantization. Do not provide scales. " - << "For external scales, use FP8 inputs."; - - // Validate input/weight dtypes are BF16 - TVM_FFI_ICHECK_EQ(input.dtype().code, kDLBfloat) - << "Input must be BF16 for internal quantization path"; - TVM_FFI_ICHECK_EQ(input.dtype().bits, 16) - << "Input must be BF16 for internal quantization path"; - TVM_FFI_ICHECK_EQ(weight.dtype().code, kDLBfloat) - << "Weight must be BF16 for internal quantization path"; - TVM_FFI_ICHECK_EQ(weight.dtype().bits, 16) - << "Weight must be BF16 for internal quantization path"; - - // Call internal quantization path (note: different argument order!) - runner_->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, - stream, scales_a_ptr, scales_b_ptr); - } else { - // Path 3: Mixed dtypes - not supported - TVM_FFI_ICHECK(false) - << "Mixed FP8/BF16 inputs not supported. Both input and weight must be " - << "either FP8 (with scales) or BF16 (internal quantization)."; - } + // Extract scale pointers (nullptr if not provided) + float const* scales_a_ptr = scales_a.has_value() + ? reinterpret_cast(scales_a.value().data_ptr()) + : nullptr; + float const* scales_b_ptr = scales_b.has_value() + ? reinterpret_cast(scales_b.value().data_ptr()) + : nullptr; + + // Select appropriate runner + auto* runner = selectRunner(input_is_fp8, weight_is_fp8); + TVM_FFI_ICHECK(runner != nullptr) << "Unsupported dtype combination"; + + // Ensure workspace is configured (defensive check) + TVM_FFI_ICHECK(workspace_ != nullptr) << "Workspace not configured. Call configure_workspace first."; + + // Call kernel + runner->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, + stream, scales_a_ptr, scales_b_ptr); } int64_t getWorkspaceSize(int64_t shape_m, int64_t shape_n, int64_t shape_k) { - // Use getWorkspaceSizeBase to ensure internal state is properly initialized - // This is critical for BF16 internal quantization path - // num_problems=1 for single GEMM, top_k=1 for regular gemm (not MOE) - size_t workspace_size = runner_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, /*num_problems=*/1); - return workspace_size; + size_t max_size = 0; + + max_size = std::max(max_size, + runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1)); + max_size = std::max(max_size, + runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1)); + + return max_size; } void configureWorkspace(const TensorView& workspace) { auto workspace_ptr = reinterpret_cast(workspace.data_ptr()); - runner_->configureWorkspace(workspace_ptr); + workspace_ = workspace_ptr; + + runner_bf16_bf16_->configureWorkspace(workspace_ptr); + runner_bf16_fp8_->configureWorkspace(workspace_ptr); } - std::unique_ptr runner_; + std::unique_ptr> runner_bf16_bf16_; + std::unique_ptr> runner_bf16_fp8_; + + char* workspace_ = nullptr; }; tvm::ffi::Module init() { @@ -228,4 +159,4 @@ tvm::ffi::Module init() { return tvm::ffi::Module(ptr); } -TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init); \ No newline at end of file From 364ee70039482702957be29415b2e876deb440d8 Mon Sep 17 00:00:00 2001 From: Vivian Chen Date: Mon, 17 Nov 2025 15:27:14 -0800 Subject: [PATCH 3/3] rename function for SM90 --- csrc/fp8_blockscale_gemm_sm90_binding.cu | 7 +---- flashinfer/gemm/__init__.py | 4 +-- flashinfer/gemm/gemm_base.py | 34 +++--------------------- tests/gemm/test_fp8_blockscale_gemm.py | 14 +++++----- 4 files changed, 12 insertions(+), 47 deletions(-) diff --git a/csrc/fp8_blockscale_gemm_sm90_binding.cu b/csrc/fp8_blockscale_gemm_sm90_binding.cu index fe3be954f0..426100f7e4 100755 --- a/csrc/fp8_blockscale_gemm_sm90_binding.cu +++ b/csrc/fp8_blockscale_gemm_sm90_binding.cu @@ -88,7 +88,6 @@ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { const Optional& scales_a, const Optional& scales_b) { auto stream = get_stream(input.device()); - // Extract tensor info auto input_ptr = input.data_ptr(); auto weight_ptr = weight.data_ptr(); auto output_ptr = output.data_ptr(); @@ -97,7 +96,6 @@ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { int shape_k = input.size(1); int shape_n = weight.size(0); - // Sanity checks (defense against Python bugs) TVM_FFI_ICHECK(input_ptr != nullptr) << "input is null"; TVM_FFI_ICHECK(weight_ptr != nullptr) << "weight is null"; TVM_FFI_ICHECK(output_ptr != nullptr) << "output is null"; @@ -107,7 +105,7 @@ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { bool input_is_fp8 = is_fp8_e4m3fn(input.dtype()); bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype()); - // Extract scale pointers (nullptr if not provided) + // Extract scale pointers float const* scales_a_ptr = scales_a.has_value() ? reinterpret_cast(scales_a.value().data_ptr()) : nullptr; @@ -118,11 +116,8 @@ class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { // Select appropriate runner auto* runner = selectRunner(input_is_fp8, weight_is_fp8); TVM_FFI_ICHECK(runner != nullptr) << "Unsupported dtype combination"; - - // Ensure workspace is configured (defensive check) TVM_FFI_ICHECK(workspace_ != nullptr) << "Workspace not configured. Call configure_workspace first."; - // Call kernel runner->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, stream, scales_a_ptr, scales_b_ptr); } diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index 545536424a..7389eb5b1b 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -14,7 +14,7 @@ from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise from .gemm_base import ( - get_fp8_blockscale_gemm_runner as get_fp8_blockscale_gemm_runner, + get_fp8_blockscale_gemm_runner_sm90 as get_fp8_blockscale_gemm_runner_sm90, ) from .gemm_base import ( fp8_blockscale_gemm_swapab as fp8_blockscale_gemm_swapab, @@ -36,7 +36,7 @@ "gemm_fp8_nt_blockscaled", "gemm_fp8_nt_groupwise", "group_gemm_fp8_nt_groupwise", - "get_fp8_blockscale_gemm_runner", + "get_fp8_blockscale_gemm_runner_sm90", "fp8_blockscale_gemm_swapab", "mm_M1_16_K7168_N256", ] diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index b1a8b764c3..75ac7684b8 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -3114,10 +3114,9 @@ def batch_deepgemm_fp8_nt_groupwise( return out @functools.cache -def get_fp8_blockscale_gemm_runner(): +def get_fp8_blockscale_gemm_runner_sm90(): """Get the FP8 block scale GEMM runner module for SM90.""" - module = gen_fp8_blockscale_gemm_sm90_module().build_and_load() - return module.init() + return gen_fp8_blockscale_gemm_sm90_module().build_and_load().init() def fp8_blockscale_gemm_swapab( @@ -3133,18 +3132,12 @@ def fp8_blockscale_gemm_swapab( This function automatically selects between normal and swapAB kernel based on the M dimension. For small M (< 32), it uses the swapAB kernel for better performance. - The computation is: output = input @ weight.T with per-block FP8 quantization - and scaling. Supported Dtype Combinations ----------------------------- - **BF16 + BF16 → BF16**: Both inputs BF16, internal quantization (no scales needed) - **BF16 + FP8 → BF16**: BF16 input, FP8 weight - Note: - - FP16 is NOT supported. - - FP8 + BF16 is NOT supported (missing kernel implementation) - Parameters ---------- input : torch.Tensor @@ -3155,12 +3148,8 @@ def fp8_blockscale_gemm_swapab( - FP8 (torch.float8_e4m3fn) with weight_scale required - BF16 (torch.bfloat16) for internal quantization input_scale : torch.Tensor, optional - Not used. Input is always BF16 with internal quantization. weight_scale : torch.Tensor, optional Scaling factors for weight. Required if weight is FP8. - Supports TWO granularities: - - Per-token (1x128 blocks): shape (N, K // 128) - - Per-block (128x128 blocks): shape (N // 128, K // 128) out : torch.Tensor, optional Output tensor of shape (M, N). If None, will be allocated. out_dtype : torch.dtype, optional @@ -3223,7 +3212,6 @@ def fp8_blockscale_gemm_swapab( if weight.ndim != 2: raise ValueError(f"Weight must be 2D (N, K), got shape {weight.shape}") - # Get dimensions M, K = input.shape N, K_weight = weight.shape @@ -3239,13 +3227,6 @@ def fp8_blockscale_gemm_swapab( f"K dimension must be divisible by block size ({BLOCK_SIZE}), got K={K}" ) - # Validate device consistency - if input.device != weight.device: - raise ValueError( - f"Input and weight must be on the same device. " - f"Got input: {input.device}, weight: {weight.device}" - ) - # Validate dtype combinations input_is_fp8 = input.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] weight_is_fp8 = weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] @@ -3263,9 +3244,7 @@ def fp8_blockscale_gemm_swapab( if input_scale is None: raise ValueError( "input_scale is required when input is FP8. " - "For BF16 inputs, omit input_scale for internal quantization." ) - # Validate scale shape: (M, K // BLOCK_SIZE) expected_scale_shape = (M, K // BLOCK_SIZE) if input_scale.shape != expected_scale_shape: raise ValueError( @@ -3297,9 +3276,7 @@ def fp8_blockscale_gemm_swapab( if weight_scale is None: raise ValueError( "weight_scale is required when weight is FP8. " - "For BF16 weights, omit weight_scale for internal quantization." ) - # Validate scale shape: supports (N, K // BLOCK_SIZE) or (N // BLOCK_SIZE, K // BLOCK_SIZE) expected_per_token_shape = (N, K // BLOCK_SIZE) expected_per_block_shape = ((N + BLOCK_SIZE - 1) // BLOCK_SIZE, K // BLOCK_SIZE) is_per_token = weight_scale.shape == expected_per_token_shape @@ -3315,11 +3292,6 @@ def fp8_blockscale_gemm_swapab( raise ValueError( f"weight_scale must be float32, got {weight_scale.dtype}" ) - if weight_scale.device != weight.device: - raise ValueError( - f"weight_scale device mismatch. Expected {weight.device}, " - f"got {weight_scale.device}" - ) else: if not weight_is_bf16: raise ValueError( @@ -3357,7 +3329,7 @@ def fp8_blockscale_gemm_swapab( out = torch.empty(M, N, dtype=out_dtype, device=input.device) # Get the runner - runner = get_fp8_blockscale_gemm_runner() + runner = get_fp8_blockscale_gemm_runner_sm90() # Allocate workspace workspace_size = runner.get_workspace_size(M, N, K) diff --git a/tests/gemm/test_fp8_blockscale_gemm.py b/tests/gemm/test_fp8_blockscale_gemm.py index 49ae459fa2..e069b15c02 100755 --- a/tests/gemm/test_fp8_blockscale_gemm.py +++ b/tests/gemm/test_fp8_blockscale_gemm.py @@ -80,7 +80,7 @@ def test_fp8_blockscale_gemm_swapab(m, n, k, input_dtype, weight_dtype): assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}" assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" - # Check correctness using cosine similarity + # Check correctness cos_sim = F.cosine_similarity( reference.flatten().float(), output.flatten().float(), @@ -136,8 +136,6 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype): input_tensor, input_scale = input_bf16, None # Quantize weight - # Note: Use per_token_cast_to_fp8 for weights too (1x128 blocks) - # This gives the correct scale shape (N, K//128) if weight_dtype == torch.float8_e4m3fn: weight_tensor, weight_scale = per_token_cast_to_fp8(weight_bf16) else: @@ -155,7 +153,7 @@ def test_fp8_blockscale_gemm_dtypes(m, n, k, input_dtype, weight_dtype): assert output.shape == (m, n), f"Expected shape {(m, n)}, got {output.shape}" assert output.dtype == torch.bfloat16, f"Expected BF16 output, got {output.dtype}" - # Check correctness using cosine similarity + # Check correctness cos_sim = F.cosine_similarity( reference.flatten().float(), output.flatten().float(), @@ -188,14 +186,14 @@ def test_fp8_blockscale_gemm_per_block_weight_scales(): pytest.skip("FP8 block-scale GEMM requires SM90a (Hopper) support") device = "cuda" - m, n, k = 16, 512, 512 # N and K must be divisible by 128 for per-block + m, n, k = 16, 512, 512 torch.manual_seed(42) - # Create BF16 data + # Create inputs input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) weight_bf16 = torch.randn(n, k, device=device, dtype=torch.bfloat16) - # Quantize weight with per-block (128x128) blocks -> (N//128, K//128) + # Quantize weight with per-block (128x128) blocks weight_fp8, weight_scale = per_block_cast_to_fp8(weight_bf16) # Verify scale shape @@ -301,7 +299,7 @@ def test_fp8_blockscale_gemm_error_handling(): with pytest.raises(ValueError): fp8_blockscale_gemm_swapab(input_fp8, weight, input_scale=wrong_scale) - # Test: FP8 input + BF16 weight is NOT supported (missing kernel implementation) + # Test: FP8 input + BF16 weight is NOT supported input_bf16 = torch.randn(m, k, device=device, dtype=torch.bfloat16) input_fp8, input_scale = per_token_cast_to_fp8(input_bf16) weight = torch.randn(n, k, device=device, dtype=torch.bfloat16)