diff --git a/csrc/flashinfer_norm_binding.cu b/csrc/flashinfer_norm_binding.cu index ddb59f3dc9..816eb2754b 100644 --- a/csrc/flashinfer_norm_binding.cu +++ b/csrc/flashinfer_norm_binding.cu @@ -17,14 +17,14 @@ void rmsnorm(TensorView out, TensorView input, TensorView weight, double eps, bool enable_pdl); -void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps, - bool enable_pdl); +void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, TensorView scale, + double eps, bool enable_pdl); void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps, bool enable_pdl); void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual, - TensorView weight, double scale, double eps, bool enable_pdl); + TensorView weight, TensorView scale, double eps, bool enable_pdl); void gemma_rmsnorm(TensorView out, TensorView input, TensorView weight, double eps, bool enable_pdl); diff --git a/csrc/norm.cu b/csrc/norm.cu index dbbe6d80dd..b3460a87d9 100644 --- a/csrc/norm.cu +++ b/csrc/norm.cu @@ -77,13 +77,15 @@ void rmsnorm(TensorView output, TensorView input, TensorView weight, double eps, } } -void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps, - bool enable_pdl) { +void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, TensorView scale, + double eps, bool enable_pdl) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(output); CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); CHECK_DEVICE(input, weight); + CHECK_DEVICE(input, scale); CHECK_DIM(1, weight); // weight: (hidden_size) + TVM_FFI_ICHECK_EQ(scale.numel(), 1); auto input_ndim = input.ndim(); if (input_ndim == 2) { @@ -103,7 +105,7 @@ void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, doubl cudaError_t status = norm::RMSNormQuant( static_cast(input.data_ptr()), static_cast(weight.data_ptr()), 
static_cast(output.data_ptr()), batch_size, hidden_size, input.stride(0), - output.stride(0), static_cast(scale), eps, enable_pdl, stream); + output.stride(0), static_cast(scale.data_ptr()), eps, enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "RMSNormQuant failed with error code " << cudaGetErrorString(status); return true; @@ -145,7 +147,7 @@ void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, } void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual, - TensorView weight, double scale, double eps, bool enable_pdl) { + TensorView weight, TensorView scale, double eps, bool enable_pdl) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual); CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); @@ -153,6 +155,7 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res CHECK_DEVICE(input, residual); CHECK_DEVICE(input, weight); CHECK_DEVICE(input, output); + CHECK_DEVICE(input, scale); CHECK_DIM(2, input); // input: (batch_size, hidden_size) CHECK_DIM(2, residual); // residual: (batch_size, hidden_size) CHECK_DIM(1, weight); // weight: (hidden_size) @@ -162,6 +165,7 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res TVM_FFI_ICHECK_EQ(residual.size(0), batch_size); TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size); TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size); + TVM_FFI_ICHECK_EQ(scale.numel(), 1); ffi::CUDADeviceGuard device_guard(input.device().device_id); const cudaStream_t stream = get_stream(input.device()); @@ -170,8 +174,8 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res cudaError_t status = norm::FusedAddRMSNormQuant( static_cast(input.data_ptr()), static_cast(residual.data_ptr()), static_cast(weight.data_ptr()), static_cast(output.data_ptr()), - batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0), scale, - eps, enable_pdl, stream); + batch_size, 
hidden_size, input.stride(0), residual.stride(0), output.stride(0), + static_cast(scale.data_ptr()), eps, enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "FusedAddRMSNormQuant failed with error code " << cudaGetErrorString(status); diff --git a/docs/api/norm.rst b/docs/api/norm.rst index 98c0d4b5fa..058940995d 100644 --- a/docs/api/norm.rst +++ b/docs/api/norm.rst @@ -11,7 +11,9 @@ Kernels for normalization layers. :toctree: ../generated rmsnorm + rmsnorm_quant fused_add_rmsnorm + fused_add_rmsnorm_quant gemma_rmsnorm gemma_fused_add_rmsnorm layernorm diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 8fa98adb62..62cabd0a6e 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -112,8 +112,11 @@ from .norm import gemma_rmsnorm as gemma_rmsnorm from .norm import rmsnorm as rmsnorm -from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant -from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant +try: + from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant + from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant +except (ImportError, AttributeError): + pass # nvidia-cutlass-dsl not installed from .page import append_paged_kv_cache as append_paged_kv_cache from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache from .page import get_batch_indices_positions as get_batch_indices_positions diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 940031453d..0d271df2b5 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -54,6 +54,24 @@ AddRMSNormFP4QuantKernel, ) + # Backwards-compatible re-exports from flashinfer.norm.kernels submodule + from ..norm.kernels import ( + # Kernel classes + RMSNormKernel, + QKRMSNormKernel, + RMSNormQuantKernel, + FusedAddRMSNormKernel, + FusedAddRMSNormQuantKernel, + LayerNormKernel, + # Python API functions + rmsnorm_cute, + qk_rmsnorm_cute, + rmsnorm_quant_cute, + fused_add_rmsnorm_cute, + 
fused_add_rmsnorm_quant_cute, + layernorm_cute, + ) + __all__ = [ # Utils (always available) "is_cute_dsl_available", @@ -79,4 +97,17 @@ # Add + RMSNorm + FP4 Quantization "add_rmsnorm_fp4quant", "AddRMSNormFP4QuantKernel", + # Norm kernels (CuTe DSL) - backwards-compatible re-exports + "RMSNormKernel", + "QKRMSNormKernel", + "RMSNormQuantKernel", + "FusedAddRMSNormKernel", + "FusedAddRMSNormQuantKernel", + "LayerNormKernel", + "rmsnorm_cute", + "qk_rmsnorm_cute", + "rmsnorm_quant_cute", + "fused_add_rmsnorm_cute", + "fused_add_rmsnorm_quant_cute", + "layernorm_cute", ] diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py index 9a6bfe55af..f2cdafc9c6 100644 --- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py +++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py @@ -1012,8 +1012,8 @@ def tensor_api( s_tensor, s_unswizzled.contiguous(), global_scale, - Int32(M), - Float32(eps), + M, + eps, ) return tensor_api diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py index cd98a8bcea..f5df9a625f 100644 --- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py +++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py @@ -750,8 +750,8 @@ def tensor_api( y_uint8, s_tensor, global_scale, - Int32(M), - Float32(eps), + M, + eps, ) return tensor_api diff --git a/flashinfer/norm.py b/flashinfer/norm/__init__.py similarity index 64% rename from flashinfer/norm.py rename to flashinfer/norm/__init__.py index de27b12d7a..af16a5d7be 100644 --- a/flashinfer/norm.py +++ b/flashinfer/norm/__init__.py @@ -12,21 +12,74 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +FlashInfer Normalization Kernels +================================ + +This package provides high-performance normalization kernels: + +- RMSNorm: Root Mean Square Normalization +- LayerNorm: Layer Normalization +- Fused Add + RMSNorm: Combined residual add and RMSNorm +- Quantized variants with FP8/FP4 output """ import functools +import os from typing import Optional import torch -from .api_logging import flashinfer_api -from .jit.norm import gen_norm_module -from .utils import device_support_pdl, register_custom_op, register_fake_op +from ..api_logging import flashinfer_api +from ..utils import device_support_pdl, register_custom_op, register_fake_op + +# Always import gen_norm_module for JIT warmup and CUDA fallback +from ..jit.norm import gen_norm_module +# Use CUDA JIT implementation instead of CuTe DSL (for debugging/fallback) +# Also fallback to CUDA JIT if nvidia-cutlass-dsl is not installed +_USE_CUDA_NORM = os.environ.get("FLASHINFER_USE_CUDA_NORM", "0") == "1" -@functools.cache -def get_norm_module(): - return gen_norm_module().build_and_load() +if not _USE_CUDA_NORM: + try: + from .kernels import ( + rmsnorm_cute, + qk_rmsnorm_cute, + rmsnorm_quant_cute, + fused_add_rmsnorm_cute, + fused_add_rmsnorm_quant_cute, + layernorm_cute, + ) + except (ImportError, AttributeError): + # nvidia-cutlass-dsl not installed or incompatible version + _USE_CUDA_NORM = True + +# get_norm_module stays defined for both backends: it is lazy and cached, and +# code outside this module may still import it as before the CuTe DSL switch. +@functools.cache +def get_norm_module(): + return gen_norm_module().build_and_load() + + +def _normalize_scale_tensor( + scale: torch.Tensor, ref_tensor: torch.Tensor +) -> torch.Tensor: + """Normalize quantization scale tensor to 1D shape (1,) on target device.""" + if not isinstance(scale, torch.Tensor): + raise TypeError(f"scale must be torch.Tensor, got {type(scale)}") + if scale.device != ref_tensor.device: + scale = scale.to(ref_tensor.device) + if scale.dtype != torch.float32: + scale = scale.to(torch.float32) + if scale.ndim == 0: + scale = scale.view(1) + elif
scale.ndim == 1 and scale.numel() == 1: + pass + else: + raise ValueError( + f"scale must be a scalar tensor or shape (1,), got shape {tuple(scale.shape)}" + ) + return scale.contiguous() @flashinfer_api @@ -60,16 +113,14 @@ def rmsnorm( output: torch.Tensor Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size). """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, enable_pdl) + _rmsnorm_impl(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",)) -def _rmsnorm( +def _rmsnorm_impl( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -78,11 +129,21 @@ def _rmsnorm( ) -> None: if enable_pdl is None: enable_pdl = device_support_pdl(input.device) - get_norm_module().rmsnorm(out, input, weight, eps, enable_pdl) + if _USE_CUDA_NORM: + get_norm_module().rmsnorm(out, input, weight, eps, enable_pdl) + else: + if input.dim() == 3: + qk_rmsnorm_cute( + input, weight, out, eps, weight_bias=0.0, enable_pdl=enable_pdl + ) + else: + rmsnorm_cute( + input, weight, out, eps, weight_bias=0.0, enable_pdl=enable_pdl + ) @register_fake_op("flashinfer::rmsnorm") -def _rmsnorm_fake( +def _rmsnorm_impl_fake( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -98,13 +159,13 @@ def rmsnorm_quant( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: float, + scale: torch.Tensor, eps: float = 1e-6, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: - r"""Root mean square normalization. +) -> None: + r"""Root mean square normalization + fp8 quantization. - ``out[i] = (input[i] / RMS(input)) * weight[i]`` + ``out[i] = ((input[i] / RMS(input)) * weight[i]).to(fp8)`` Parameters ---------- @@ -114,22 +175,24 @@ def rmsnorm_quant( Input tensor, 2D shape (batch_size, hidden_size). weight: torch.Tensor Weight tensor, shape (hidden_size,). 
- scale: float - Scale factor for quantization. + scale: torch.Tensor + Scale factor for quantization, shape (1,). eps: float Epsilon for numerical stability. enable_pdl: bool Whether to enable `programmatic dependent launch `_ - Returns - ------- - output: torch.Tensor - Normalized tensor, 2D shape (batch_size, hidden_size). """ + scale = _normalize_scale_tensor(scale, input) if enable_pdl is None: enable_pdl = device_support_pdl(input.device) - get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl) + if _USE_CUDA_NORM: + get_norm_module().rmsnorm_quant(out, input, weight, scale, eps, enable_pdl) + else: + rmsnorm_quant_cute( + out, input, weight, scale, eps, weight_bias=0.0, enable_pdl=enable_pdl + ) @register_fake_op("flashinfer::rmsnorm_quant") @@ -137,7 +200,7 @@ def _rmsnorm_quant_fake( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: float, + scale: torch.Tensor, eps: float, enable_pdl: Optional[bool], ) -> None: @@ -177,7 +240,12 @@ def fused_add_rmsnorm( """ if enable_pdl is None: enable_pdl = device_support_pdl(input.device) - get_norm_module().fused_add_rmsnorm(input, residual, weight, eps, enable_pdl) + if _USE_CUDA_NORM: + get_norm_module().fused_add_rmsnorm(input, residual, weight, eps, enable_pdl) + else: + fused_add_rmsnorm_cute( + input, residual, weight, eps, weight_bias=0.0, enable_pdl=enable_pdl + ) @register_fake_op("flashinfer::fused_add_rmsnorm") @@ -200,17 +268,17 @@ def fused_add_rmsnorm_quant( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - scale: float, + scale: torch.Tensor, eps: float = 1e-6, enable_pdl: Optional[bool] = None, ) -> None: - r"""Fused add root mean square normalization. + r"""Fused add root mean square normalization + fp8 quantization. 
Step 1: ``residual[i] += input[i]`` Step 2: - ``input[i] = (residual[i] / RMS(residual)) * weight[i]`` + ``input[i] = ((residual[i] / RMS(residual)) * weight[i]).to(fp8)`` Parameters ---------- @@ -222,19 +290,32 @@ def fused_add_rmsnorm_quant( Residual tensor, shape (batch_size, hidden_size). weight: torch.Tensor Weight tensor, shape (hidden_size,). - scale: float - Scale factor for quantization. + scale: torch.Tensor + Scale factor for quantization, shape (1,). eps: float Epsilon for numerical stability. enable_pdl: bool Whether to enable `programmatic dependent launch `_ """ + scale = _normalize_scale_tensor(scale, input) if enable_pdl is None: enable_pdl = device_support_pdl(input.device) - get_norm_module().fused_add_rmsnorm_quant( - out, input, residual, weight, scale, eps, enable_pdl - ) + if _USE_CUDA_NORM: + get_norm_module().fused_add_rmsnorm_quant( + out, input, residual, weight, scale, eps, enable_pdl + ) + else: + fused_add_rmsnorm_quant_cute( + out, + input, + residual, + weight, + scale, + eps, + weight_bias=0.0, + enable_pdl=enable_pdl, + ) @register_fake_op("flashinfer::fused_add_rmsnorm_quant") @@ -243,7 +324,7 @@ def _fused_add_rmsnorm_quant_fake( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - scale: float, + scale: torch.Tensor, eps: float = 1e-6, enable_pdl: Optional[bool] = None, ) -> None: @@ -281,16 +362,14 @@ def gemma_rmsnorm( output: torch.Tensor Gemma Normalized tensor, shape (batch_size, hidden_size). 
""" - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) if out is None: out = torch.empty_like(input) - _gemma_rmsnorm(out, input, weight, eps, enable_pdl) + _gemma_rmsnorm_impl(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",)) -def _gemma_rmsnorm( +def _gemma_rmsnorm_impl( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -299,11 +378,21 @@ def _gemma_rmsnorm( ) -> None: if enable_pdl is None: enable_pdl = device_support_pdl(input.device) - get_norm_module().gemma_rmsnorm(out, input, weight, eps, enable_pdl) + if _USE_CUDA_NORM: + get_norm_module().gemma_rmsnorm(out, input, weight, eps, enable_pdl) + else: + if input.dim() == 3: + qk_rmsnorm_cute( + input, weight, out, eps, weight_bias=1.0, enable_pdl=enable_pdl + ) + else: + rmsnorm_cute( + input, weight, out, eps, weight_bias=1.0, enable_pdl=enable_pdl + ) @register_fake_op("flashinfer::gemma_rmsnorm") -def _gemma_rmsnorm_fake( +def _gemma_rmsnorm_impl_fake( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -348,7 +437,14 @@ def gemma_fused_add_rmsnorm( """ if enable_pdl is None: enable_pdl = device_support_pdl(input.device) - get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps, enable_pdl) + if _USE_CUDA_NORM: + get_norm_module().gemma_fused_add_rmsnorm( + input, residual, weight, eps, enable_pdl + ) + else: + fused_add_rmsnorm_cute( + input, residual, weight, eps, weight_bias=1.0, enable_pdl=enable_pdl + ) @register_fake_op("flashinfer::gemma_fused_add_rmsnorm") @@ -388,7 +484,10 @@ def layernorm( Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input. 
""" out = torch.empty_like(input) - get_norm_module().layernorm(out, input, gemma, beta, eps) + if _USE_CUDA_NORM: + get_norm_module().layernorm(out, input, gemma, beta, eps) + else: + layernorm_cute(out, input, gemma, beta, eps) return out @@ -404,10 +503,25 @@ def _layernorm_fake( # CuTe-DSL fused RMSNorm + FP4 Quantization kernels -# These require CuTe-DSL to be available and SM100+ (Blackwell) GPUs +# These require SM100+ (Blackwell) GPUs and nvidia-cutlass-dsl try: - from .cute_dsl import rmsnorm_fp4quant, add_rmsnorm_fp4quant + from ..cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant + from ..cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant except ImportError: - # CuTe-DSL not available - rmsnorm_fp4quant = None # type: ignore[misc,assignment] - add_rmsnorm_fp4quant = None # type: ignore[misc,assignment] + # nvidia-cutlass-dsl not installed, these functions will not be available + pass + + +# Public API exports +__all__ = [ + # JIT module generator (always available) + "gen_norm_module", + # Public APIs + "rmsnorm", + "rmsnorm_quant", + "fused_add_rmsnorm", + "fused_add_rmsnorm_quant", + "gemma_rmsnorm", + "gemma_fused_add_rmsnorm", + "layernorm", +] diff --git a/flashinfer/norm/kernels/__init__.py b/flashinfer/norm/kernels/__init__.py new file mode 100644 index 0000000000..20a8b1b7f3 --- /dev/null +++ b/flashinfer/norm/kernels/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +CuTe DSL Norm Kernels +===================== + +Internal kernel implementations using NVIDIA CuTe-DSL. +""" + +from .rmsnorm import ( + RMSNormKernel, + QKRMSNormKernel, + RMSNormQuantKernel, + rmsnorm_cute, + qk_rmsnorm_cute, + rmsnorm_quant_cute, +) +from .fused_add_rmsnorm import ( + FusedAddRMSNormKernel, + FusedAddRMSNormQuantKernel, + fused_add_rmsnorm_cute, + fused_add_rmsnorm_quant_cute, +) +from .layernorm import ( + LayerNormKernel, + layernorm_cute, +) + +__all__ = [ + # RMSNorm + "RMSNormKernel", + "QKRMSNormKernel", + "RMSNormQuantKernel", + "rmsnorm_cute", + "qk_rmsnorm_cute", + "rmsnorm_quant_cute", + # Fused Add + RMSNorm + "FusedAddRMSNormKernel", + "FusedAddRMSNormQuantKernel", + "fused_add_rmsnorm_cute", + "fused_add_rmsnorm_quant_cute", + # LayerNorm + "LayerNormKernel", + "layernorm_cute", +] diff --git a/flashinfer/norm/kernels/fused_add_rmsnorm.py b/flashinfer/norm/kernels/fused_add_rmsnorm.py new file mode 100644 index 0000000000..406b604ac5 --- /dev/null +++ b/flashinfer/norm/kernels/fused_add_rmsnorm.py @@ -0,0 +1,604 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Fused Add + RMSNorm CuTe DSL Kernels +==================================== + +Includes: +- FusedAddRMSNormKernel: Fused residual add + RMSNorm +- FusedAddRMSNormQuantKernel: Fused residual add + RMSNorm + FP8 quantization +""" + +import functools + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 + +from ..utils import ( + FLOAT8_E4M3_MAX, + COPY_BITS, + rcp_approx_ftz, + cvt_and_store_f32_to_e4m3_hw, + cvt_and_store_f32_to_e4m3_sw, + has_hw_fp8_cvt, + get_ptr_as_int64, + row_reduce_sum, + predicate_k, + compute_optimal_vec_size, + compute_threads_per_row, + make_tv_layout, + _torch_dtype_to_str, + get_cutlass_dtype, +) + + +# ============================================================================= +# FusedAddRMSNormKernel +# ============================================================================= + + +class FusedAddRMSNormKernel: + """ + Fused Residual Add + RMSNorm Kernel using CuTe-DSL. + + Computes: + 1. residual = input + residual (in-place update) + 2. 
input = residual / sqrt(mean(residual^2) + eps) * (weight + weight_bias) + """ + + def __init__( + self, + dtype: cutlass.Numeric, + H: int, + weight_bias: float = 0.0, + ): + self.dtype = dtype + self.H = H + self.weight_bias = weight_bias + + # Vectorization parameters: use optimal vec_size for warp utilization + elem_bits = dtype.width + max_vec_size = COPY_BITS // elem_bits + self.vec_size = compute_optimal_vec_size(H, max_vec_size) + self.copy_bits = self.vec_size * elem_bits + + self.threads_per_row = compute_threads_per_row(H, self.vec_size) + self.num_threads = self.threads_per_row + self.num_warps = max(self.threads_per_row // 32, 1) + + self.num_vec_blocks = max( + 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + ) + self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + + def _smem_size_in_bytes(self) -> int: + # Only reduction buffer needed (register-based approach) + return self.num_warps * 4 + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mR: cute.Tensor, + mW: cute.Tensor, + M: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + stream, + ): + tv_shape, tv_stride = make_tv_layout( + self.threads_per_row, + self.vec_size, + self.num_vec_blocks, + ) + tv_layout = cute.make_layout(tv_shape, stride=tv_stride) + tiler_mn = (1, self.cols_per_tile) + + self.kernel(mX, mR, mW, M, eps, enable_pdl, tv_layout, tiler_mn).launch( + grid=[M, 1, 1], + block=[self.num_threads, 1, 1], + smem=self._smem_size_in_bytes(), + stream=stream, + use_pdl=enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mR: cute.Tensor, + mW: cute.Tensor, + M: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # PDL: Wait for previous kernel (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_wait() + + H = self.H + weight_bias = 
self.weight_bias + threads_per_row = tv_layout.shape[0][0] + num_warps = self.num_warps + copy_bits = self.copy_bits + + smem = cutlass.utils.SmemAllocator() + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((num_warps,)), + byte_alignment=4, + ) + + idX = cute.make_identity_tensor(mX.shape) + + gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) + gR = cute.local_tile(mR, tiler_mn, (bidx, 0)) + cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + mW_2d = cute.prepend_ones(mW, up_to_rank=2) + + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + + tiled_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn) + thr_copy = tiled_copy.get_slice(tidx) + + tXgX = thr_copy.partition_S(gX) + tXgR = thr_copy.partition_S(gR) + tXgW = thr_copy.partition_S(mW_2d) + tXcX = thr_copy.partition_S(cX) + tYgX = thr_copy.partition_D(gX) + tYgR = thr_copy.partition_D(gR) + + # Register fragments - initialize to zero for proper handling of out-of-bounds threads + tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) + tXrR = cute.make_rmem_tensor(tXgR.shape, mR.element_type) + tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + tXrR.store(cute.zeros_like(tXrR, dtype=mR.element_type)) + tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + + tXpX = predicate_k(tXcX, limit=H) + + # Phase 1: Load input and residual from global to register + cute.copy(copy_atom, tXgX, tXrX, pred=tXpX) + cute.copy(copy_atom, tXgR, tXrR, pred=tXpX) + + x_in = tXrX.load().to(Float32) + r_in = tXrR.load().to(Float32) + x = x_in + r_in + + # Phase 2: Store x to residual (global) + tXrR_out = x.to(mR.element_type) + tXrR_store = cute.make_rmem_tensor(tYgR.shape, mR.element_type) + tXrR_store.store(tXrR_out) + + cute.copy(copy_atom, tXrR_store, tYgR, pred=tXpX) + + # Phase 3: Compute sum of squares (x is kept in registers) + x_sq = x * x + sum_sq = 
row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + + mean_sq = sum_sq / Float32(H) + rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + + # Phase 4: Load weight from global to register + cute.copy(copy_atom, tXgW, tXrW, pred=tXpX) + + w = tXrW.load().to(Float32) + + # output = x * rstd * (weight + weight_bias) + # x is still in registers from Phase 1 + y = x * rstd * (w + Float32(weight_bias)) + + tYrY = y.to(mX.element_type) + tXrY = cute.make_rmem_tensor(tYgX.shape, mX.element_type) + tXrY.store(tYrY) + + cute.copy(copy_atom, tXrY, tYgX, pred=tXpX) + + # PDL: Signal dependent kernels (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# FusedAddRMSNormQuantKernel +# ============================================================================= + + +class FusedAddRMSNormQuantKernel: + """ + Fused Residual Add + RMSNorm + FP8 Quantization Kernel. + + Computes: + 1. residual = input + residual (in-place update) + 2. 
output = clamp(residual / sqrt(mean(residual^2) + eps) * weight / scale, -448, 448) + """ + + def __init__( + self, + dtype: cutlass.Numeric, + H: int, + weight_bias: float = 0.0, + use_hw_fp8: bool = True, + ): + self.dtype = dtype + self.H = H + self.weight_bias = weight_bias + self.use_hw_fp8 = use_hw_fp8 + + # Vectorization parameters: use optimal vec_size for warp utilization + elem_bits = dtype.width + max_vec_size = COPY_BITS // elem_bits + self.vec_size = compute_optimal_vec_size(H, max_vec_size) + self.copy_bits = self.vec_size * elem_bits + + self.threads_per_row = compute_threads_per_row(H, self.vec_size) + self.num_threads = self.threads_per_row + self.num_warps = max(self.threads_per_row // 32, 1) + + self.num_vec_blocks = max( + 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + ) + self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + + def _smem_size_in_bytes(self) -> int: + # Only reduction buffer needed (register-based approach) + return self.num_warps * 4 + + @cute.jit + def __call__( + self, + mY: cute.Tensor, + mX: cute.Tensor, + mR: cute.Tensor, + mW: cute.Tensor, + M: Int32, + mS: cute.Tensor, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + stream, + ): + tv_shape, tv_stride = make_tv_layout( + self.threads_per_row, + self.vec_size, + self.num_vec_blocks, + ) + tv_layout = cute.make_layout(tv_shape, stride=tv_stride) + tiler_mn = (1, self.cols_per_tile) + + self.kernel(mY, mX, mR, mW, M, mS, eps, enable_pdl, tv_layout, tiler_mn).launch( + grid=[M, 1, 1], + block=[self.num_threads, 1, 1], + smem=self._smem_size_in_bytes(), + stream=stream, + use_pdl=enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mY: cute.Tensor, + mX: cute.Tensor, + mR: cute.Tensor, + mW: cute.Tensor, + M: Int32, + mS: cute.Tensor, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = 
cute.arch.block_idx() + + # PDL: Wait for previous kernel (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_wait() + + H = self.H + weight_bias = self.weight_bias + threads_per_row = tv_layout.shape[0][0] + num_warps = self.num_warps + copy_bits = self.copy_bits + vec_size = self.vec_size + num_vec_blocks = self.num_vec_blocks + + inv_scale = rcp_approx_ftz(mS[0]) + + smem = cutlass.utils.SmemAllocator() + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((num_warps,)), + byte_alignment=4, + ) + + idX = cute.make_identity_tensor(mX.shape) + + gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) + gR = cute.local_tile(mR, tiler_mn, (bidx, 0)) + cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + mW_2d = cute.prepend_ones(mW, up_to_rank=2) + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + + tiled_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) + thr_copy_load = tiled_copy_load.get_slice(tidx) + + tXgX = thr_copy_load.partition_S(gX) + tXgR = thr_copy_load.partition_S(gR) + tXgW = thr_copy_load.partition_S(mW_2d) + tXcX = thr_copy_load.partition_S(cX) + tYgR = thr_copy_load.partition_D(gR) + + # Register fragments - initialize to zero for proper handling of out-of-bounds threads + tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) + tXrR = cute.make_rmem_tensor(tXgR.shape, mR.element_type) + tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + tXrR.store(cute.zeros_like(tXrR, dtype=mR.element_type)) + tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + + tXpX = predicate_k(tXcX, limit=H) + + # Phase 1: Load input and residual from global to register + cute.copy(copy_atom_load, tXgX, tXrX, pred=tXpX) + cute.copy(copy_atom_load, tXgR, tXrR, pred=tXpX) + + x_in = tXrX.load().to(Float32) + r_in = tXrR.load().to(Float32) + x = x_in + r_in + + # Store x to residual (global) + 
tXrR_out = x.to(mR.element_type) + tXrR_store = cute.make_rmem_tensor(tYgR.shape, mR.element_type) + tXrR_store.store(tXrR_out) + cute.copy(copy_atom_load, tXrR_store, tYgR, pred=tXpX) + + # Phase 2: Compute sum of squares (x is kept in registers) + x_sq = x * x + sum_sq = row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + + mean_sq = sum_sq / Float32(H) + rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + + # Phase 3: Load weight from global to register + cute.copy(copy_atom_load, tXgW, tXrW, pred=tXpX) + w = tXrW.load().to(Float32) + + # output = x * rstd * (weight + weight_bias) * inv_scale + # x is still in registers from Phase 1 + y = x * rstd * (w + Float32(weight_bias)) * inv_scale + + # Phase 4: Clamp and store to FP8 output using PTX scalar stores + # (CuTe FP8 conversion requires vectorized ops, so we use PTX for scalar stores) + # Store y to register tensor for element-wise access + tYrY_f32 = cute.make_rmem_tensor(tXgX.shape, Float32) + tYrY_f32.store(y) + + col_offset = tidx * vec_size + for v in cutlass.range_constexpr(num_vec_blocks): + for e in cutlass.range_constexpr(vec_size): + idx = col_offset + v * threads_per_row * vec_size + e + if idx < H: + # Clamp and convert - use flat index for register tensor + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + # Row base uses mY's runtime row stride (out may be strided, not H-contiguous); PTX converts and stores the FP8 byte + out_offset = bidx * mY.stride[0] + idx + out_ptr = get_ptr_as_int64(mY, Int32(out_offset)) + if self.use_hw_fp8: + cvt_and_store_f32_to_e4m3_hw(clamped, out_ptr) + else: + cvt_and_store_f32_to_e4m3_sw(clamped, out_ptr) + + # PDL: Signal dependent kernels (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# Compiled Kernel Getters +# ============================================================================= + + +@functools.cache +def
_get_compiled_fused_add_rmsnorm_kernel( + dtype_str: str, H: int, weight_bias: float, enable_pdl: bool +): + """Get a compiled Fused Add + RMSNorm kernel using TVM-FFI.""" + dtype = get_cutlass_dtype(dtype_str) + kernel_obj = FusedAddRMSNormKernel(dtype, H, weight_bias) + + sym_m = cute.sym_int() + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size) + + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + r_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_r, 1), assumed_align=16 + ) + w_fake = cute.runtime.make_fake_compact_tensor(dtype, (H,), assumed_align=16) + + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + x_fake, + r_fake, + w_fake, + Int32(1), + Float32(1e-6), + enable_pdl, + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel + + +@functools.cache +def _get_compiled_fused_add_rmsnorm_quant_kernel( + dtype_str: str, + out_dtype_str: str, + H: int, + weight_bias: float, + enable_pdl: bool, + use_hw_fp8: bool = True, +): + """Get a compiled Fused Add + RMSNorm + Quant kernel using TVM-FFI.""" + dtype = get_cutlass_dtype(dtype_str) + out_dtype = get_cutlass_dtype(out_dtype_str) + kernel_obj = FusedAddRMSNormQuantKernel( + dtype, H, weight_bias, use_hw_fp8=use_hw_fp8 + ) + + sym_m = cute.sym_int() + sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size) + + y_fake = cute.runtime.make_fake_tensor( + out_dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 + ) + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + r_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_r, 1), 
assumed_align=16 + ) + w_fake = cute.runtime.make_fake_compact_tensor(dtype, (H,), assumed_align=16) + s_fake = cute.runtime.make_fake_compact_tensor(Float32, (1,), assumed_align=4) + + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + y_fake, + x_fake, + r_fake, + w_fake, + Int32(1), + s_fake, + Float32(1e-6), + enable_pdl, + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel + + +# ============================================================================= +# CuTe DSL API Functions +# ============================================================================= + + +def fused_add_rmsnorm_cute( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + weight_bias: float = 0.0, + enable_pdl: bool = False, +) -> None: + """CuTe DSL Fused Add + RMSNorm implementation. + + Supports arbitrary stride - no need to call contiguous(). + Last dimension must be contiguous (stride[-1] == 1). + """ + + shape = input.shape + H = shape[-1] + M = shape[0] + + dtype_str = _torch_dtype_to_str(input.dtype) + kernel = _get_compiled_fused_add_rmsnorm_kernel( + dtype_str, H, weight_bias, enable_pdl + ) + kernel(input, residual, weight, M, eps) + + +def fused_add_rmsnorm_quant_cute( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + eps: float = 1e-6, + weight_bias: float = 0.0, + enable_pdl: bool = False, +) -> None: + """CuTe DSL Fused Add + RMSNorm + FP8 quantization implementation. + + Supports arbitrary stride - no need to call contiguous(). + Last dimension must be contiguous (stride[-1] == 1). 
+ """ + + shape = input.shape + H = shape[-1] + M = shape[0] + + dtype_str = _torch_dtype_to_str(input.dtype) + out_dtype_str = _torch_dtype_to_str(out.dtype) + kernel = _get_compiled_fused_add_rmsnorm_quant_kernel( + dtype_str, + out_dtype_str, + H, + weight_bias, + enable_pdl, + use_hw_fp8=has_hw_fp8_cvt(input.device), + ) + kernel( + out, + input, + residual, + weight, + M, + scale, + eps, + ) + + +__all__ = [ + # Kernel classes + "FusedAddRMSNormKernel", + "FusedAddRMSNormQuantKernel", + # Compiled kernel getters + "_get_compiled_fused_add_rmsnorm_kernel", + "_get_compiled_fused_add_rmsnorm_quant_kernel", + # CuTe DSL APIs + "fused_add_rmsnorm_cute", + "fused_add_rmsnorm_quant_cute", +] diff --git a/flashinfer/norm/kernels/layernorm.py b/flashinfer/norm/kernels/layernorm.py new file mode 100644 index 0000000000..953aed6c95 --- /dev/null +++ b/flashinfer/norm/kernels/layernorm.py @@ -0,0 +1,344 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +LayerNorm CuTe DSL Kernel +========================= + +Traditional LayerNorm with mean and variance normalization. 
+""" + +import functools + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 + +from ..utils import ( + COPY_BITS, + row_reduce_sum, + predicate_k, + compute_optimal_vec_size, + compute_threads_per_row, + make_tv_layout, + _torch_dtype_to_str, + get_cutlass_dtype, +) + + +# ============================================================================= +# LayerNormKernel +# ============================================================================= + + +class LayerNormKernel: + """ + Layer Normalization Kernel using CuTe-DSL. + + Computes: output = (input - mean) / sqrt(variance + eps) * gamma + beta + """ + + def __init__( + self, + dtype: cutlass.Numeric, + H: int, + ): + self.dtype = dtype + self.H = H + + # Vectorization parameters: use optimal vec_size for warp utilization + elem_bits = dtype.width + max_vec_size = COPY_BITS // elem_bits + self.vec_size = compute_optimal_vec_size(H, max_vec_size) + self.copy_bits = self.vec_size * elem_bits + + self.threads_per_row = compute_threads_per_row(H, self.vec_size) + self.num_threads = self.threads_per_row + self.num_warps = max(self.threads_per_row // 32, 1) + + self.num_vec_blocks = max( + 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + ) + self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + + def _smem_size_in_bytes(self) -> int: + # Two reduction buffers (sum and variance), one float32 slot per warp each + return 2 * self.num_warps * 4 + + @cute.jit + def __call__( + self, + mY: cute.Tensor, + mX: cute.Tensor, + mGamma: cute.Tensor, + mBeta: cute.Tensor, + M: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + stream, + ): + # Layout for input (float16/bfloat16) + tv_shape, tv_stride = make_tv_layout( + self.threads_per_row, + self.vec_size, + self.num_vec_blocks, + ) + tv_layout = cute.make_layout(tv_shape, stride=tv_stride) + tiler_mn = (1, self.cols_per_tile) + + self.kernel( + mY, + mX, + 
mGamma, + mBeta, + M, + eps, + enable_pdl, + tv_layout, + tiler_mn, + ).launch( + grid=[M, 1, 1], + block=[self.num_threads, 1, 1], + smem=self._smem_size_in_bytes(), + stream=stream, + use_pdl=enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mY: cute.Tensor, + mX: cute.Tensor, + mGamma: cute.Tensor, + mBeta: cute.Tensor, + M: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # PDL: Wait for previous kernel (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_wait() + + H = self.H + threads_per_row = tv_layout.shape[0][0] + num_warps = self.num_warps + vec_size = self.vec_size + num_vec_blocks = self.num_vec_blocks + copy_bits = self.copy_bits + + smem = cutlass.utils.SmemAllocator() + + # Two reduction buffers: one for sum, one for variance + reduction_buffer_sum = smem.allocate_tensor( + Float32, + cute.make_layout((num_warps,)), + byte_alignment=4, + ) + + reduction_buffer_var = smem.allocate_tensor( + Float32, + cute.make_layout((num_warps,)), + byte_alignment=4, + ) + + idX = cute.make_identity_tensor(mX.shape) + + gY = cute.local_tile(mY, tiler_mn, (bidx, 0)) + gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) + cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + # Copy atom for input (input dtype) - sync load + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + + tiled_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) + + thr_copy_load = tiled_copy_load.get_slice(tidx) + + # Partitions for input + tXgX = thr_copy_load.partition_S(gX) + tXgY = thr_copy_load.partition_D(gY) + tXcX = thr_copy_load.partition_S(cX) + + # Register fragment - initialize to zero for proper handling of out-of-bounds threads + tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) 
+ + tXpX = predicate_k(tXcX, limit=H) + + # Phase 1: Load input from global to register + cute.copy(copy_atom_load, tXgX, tXrX, pred=tXpX) + + x = tXrX.load().to(Float32) + sum_x = row_reduce_sum(x, threads_per_row, reduction_buffer_sum) + + mean = sum_x / Float32(H) + + # Phase 2: Compute variance = E[(x - mean)^2] + # For invalid threads (col >= H), x=0 so diff = -mean, which would incorrectly + # contribute mean^2 to variance. We zero out these positions before reduction. + diff = x - mean + diff_sq = diff * diff + + num_elems = vec_size * num_vec_blocks + diff_sq_reg = cute.make_rmem_tensor(diff_sq.shape, Float32) + diff_sq_reg.store(diff_sq) + + # Zero out invalid positions so they don't contribute to variance + for i in cutlass.range_constexpr(num_elems): + vec_idx = i % vec_size + block_idx = i // vec_size + col = tidx * vec_size + vec_idx + block_idx * vec_size * threads_per_row + if col >= H: + diff_sq_reg[i] = Float32(0.0) + + diff_sq_masked = diff_sq_reg.load() + sum_diff_sq = row_reduce_sum( + diff_sq_masked, threads_per_row, reduction_buffer_var + ) + + variance = sum_diff_sq / Float32(H) + rstd = cute.math.rsqrt(variance + eps, fastmath=True) + + cute.arch.barrier() + + # Phase 3: Load gamma/beta directly from global memory into registers. + # Each thread owns a disjoint range of columns so there is no sharing + # between threads — staging through shared memory is unnecessary. 
+ gamma_reg = cute.make_rmem_tensor(x.shape, Float32) + beta_reg = cute.make_rmem_tensor(x.shape, Float32) + gamma_reg.store(cute.zeros_like(gamma_reg, dtype=Float32)) + beta_reg.store(cute.zeros_like(beta_reg, dtype=Float32)) + + col_offset = tidx * vec_size + for v in cutlass.range_constexpr(num_vec_blocks): + for e in cutlass.range_constexpr(vec_size): + idx = col_offset + v * threads_per_row * vec_size + e + reg_idx = v * vec_size + e + if idx < H: + gamma_reg[reg_idx] = mGamma[idx] + beta_reg[reg_idx] = mBeta[idx] + + gamma = gamma_reg.load() + beta = beta_reg.load() + + # output = (x - mean) * rstd * gamma + beta + y = (x - mean) * rstd * gamma + beta + + tYrY = y.to(mY.element_type) + tXrY = cute.make_rmem_tensor(tXgY.shape, mY.element_type) + tXrY.store(tYrY) + + cute.copy(copy_atom_load, tXrY, tXgY, pred=tXpX) + + # PDL: Signal dependent kernels (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# Compiled Kernel Getter +# ============================================================================= + + +@functools.cache +def _get_compiled_layernorm_kernel( + dtype_str: str, gamma_dtype_str: str, H: int, enable_pdl: bool +): + """Get a compiled LayerNorm kernel using TVM-FFI.""" + dtype = get_cutlass_dtype(dtype_str) + gamma_dtype = get_cutlass_dtype(gamma_dtype_str) + kernel_obj = LayerNormKernel(dtype, H) + + sym_m = cute.sym_int() + sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + + y_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 + ) + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + gamma_fake = cute.runtime.make_fake_compact_tensor( + gamma_dtype, (H,), assumed_align=16 + ) + beta_fake = cute.runtime.make_fake_compact_tensor( + gamma_dtype, (H,), 
assumed_align=16 + ) + + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + y_fake, + x_fake, + gamma_fake, + beta_fake, + Int32(1), + Float32(1e-6), + enable_pdl, + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel + + +# ============================================================================= +# CuTe DSL API Function +# ============================================================================= + + +def layernorm_cute( + out: torch.Tensor, + input: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + eps: float = 1e-6, + enable_pdl: bool = False, +) -> None: + """CuTe DSL LayerNorm implementation. + + Supports arbitrary stride - no need to call contiguous(). + Last dimension must be contiguous (stride[-1] == 1). + """ + + shape = input.shape + H = shape[-1] + M = shape[0] + + dtype_str = _torch_dtype_to_str(input.dtype) + gamma_dtype_str = _torch_dtype_to_str(gamma.dtype) + kernel = _get_compiled_layernorm_kernel(dtype_str, gamma_dtype_str, H, enable_pdl) + kernel(out, input, gamma, beta, M, eps) + + +__all__ = [ + # Kernel class + "LayerNormKernel", + # Compiled kernel getter + "_get_compiled_layernorm_kernel", + # CuTe DSL API + "layernorm_cute", +] diff --git a/flashinfer/norm/kernels/rmsnorm.py b/flashinfer/norm/kernels/rmsnorm.py new file mode 100644 index 0000000000..7d7388fc27 --- /dev/null +++ b/flashinfer/norm/kernels/rmsnorm.py @@ -0,0 +1,910 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. + +RMSNorm CuTe DSL Kernels +======================== + +Includes: +- RMSNormKernel: Basic RMSNorm (also handles Gemma variant with weight_bias=1.0) +- QKRMSNormKernel: RMSNorm for 3D tensors [batch, heads, head_dim] +- RMSNormQuantKernel: RMSNorm + FP8 quantization +""" + +import functools +import operator + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 + +from ..utils import ( + FLOAT8_E4M3_MAX, + COPY_BITS, + rcp_approx_ftz, + cvt_and_store_f32_to_e4m3_hw, + cvt_and_store_f32_to_e4m3_sw, + has_hw_fp8_cvt, + get_ptr_as_int64, + warp_reduce, + row_reduce_sum, + predicate_k, + compute_optimal_vec_size, + compute_threads_per_row, + make_tv_layout, + _torch_dtype_to_str, + get_cutlass_dtype, + get_num_sm, +) + + +# ============================================================================= +# RMSNormKernel +# ============================================================================= + + +class RMSNormKernel: + """ + RMSNorm Kernel using CuTe-DSL. + + Computes: output = input / sqrt(mean(input^2) + eps) * (weight + weight_bias) + + Key optimizations: + 1. 128-bit vectorized loads for input and weight + 2. Two-stage reduction: warp shuffle + cross-warp shared memory + 3. 
All computations in FP32 for numerical stability + """ + + def __init__( + self, + dtype: cutlass.Numeric, + H: int, + weight_bias: float = 0.0, + ): + self.dtype = dtype + self.H = H + self.weight_bias = weight_bias + + # Vectorization parameters: use optimal vec_size for warp utilization + elem_bits = dtype.width + max_vec_size = COPY_BITS // elem_bits # 8 for float16/bfloat16, 4 for float32 + self.vec_size = compute_optimal_vec_size(H, max_vec_size) + self.copy_bits = self.vec_size * elem_bits # Actual bits per copy + + # Thread configuration + self.threads_per_row = compute_threads_per_row(H, self.vec_size) + self.num_threads = self.threads_per_row # One row per block + self.num_warps = max(self.threads_per_row // 32, 1) + + # Vectorization blocks + self.num_vec_blocks = max( + 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + ) + self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + + def _smem_size_in_bytes(self) -> int: + """Calculate shared memory requirement.""" + # Only reduction buffer needed (no shared memory for input/weight) + return self.num_warps * 4 + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mY: cute.Tensor, + M: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + stream, + ): + """Launch the RMSNorm kernel.""" + tv_shape, tv_stride = make_tv_layout( + self.threads_per_row, + self.vec_size, + self.num_vec_blocks, + ) + tv_layout = cute.make_layout(tv_shape, stride=tv_stride) + tiler_mn = (1, self.cols_per_tile) + + self.kernel(mX, mW, mY, M, eps, enable_pdl, tv_layout, tiler_mn).launch( + grid=[M, 1, 1], + block=[self.num_threads, 1, 1], + smem=self._smem_size_in_bytes(), + stream=stream, + use_pdl=enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mY: cute.Tensor, + M: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + """Device kernel 
for RMSNorm.""" + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # PDL: Wait for previous kernel (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_wait() + + H = self.H + weight_bias = self.weight_bias + threads_per_row = tv_layout.shape[0][0] + num_warps = self.num_warps + copy_bits = self.copy_bits + + # Allocate shared memory (only reduction buffer needed) + smem = cutlass.utils.SmemAllocator() + reduction_buffer = smem.allocate_tensor( + Float32, + cute.make_layout((num_warps,)), + byte_alignment=4, + ) + + # Create identity tensor for coordinate tracking + idX = cute.make_identity_tensor(mX.shape) + + # Slice for this row + gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) + gY = cute.local_tile(mY, tiler_mn, (bidx, 0)) + cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + # Expand weight to 2D for consistent tiling + mW_2d = cute.prepend_ones(mW, up_to_rank=2) + + # Create TiledCopy for load and store (both use CopyUniversalOp for sync operations) + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + + tiled_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn) + thr_copy = tiled_copy.get_slice(tidx) + + # Partition tensors + tXgX = thr_copy.partition_S(gX) + tXgW = thr_copy.partition_S(mW_2d) + tXgY = thr_copy.partition_D(gY) + tXcX = thr_copy.partition_S(cX) + + # Register fragments - initialize to zero for proper handling of out-of-bounds threads + tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) + tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + + # Bounds checking (column boundary only, row is always valid since grid=[M,1,1]) + tXpX = predicate_k(tXcX, limit=H) + + # =================================================================== + # Phase 1: Load input from global to register + # 
=================================================================== + cute.copy(copy_atom, tXgX, tXrX, pred=tXpX) + + x = tXrX.load().to(Float32) + x_sq = x * x + sum_sq = row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + + # Compute rstd = 1 / sqrt(mean(x^2) + eps) + mean_sq = sum_sq / Float32(H) + rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + + # =================================================================== + # Phase 2: Load weight from global to register + # =================================================================== + cute.copy(copy_atom, tXgW, tXrW, pred=tXpX) + + w = tXrW.load().to(Float32) + + # output = input * rstd * (weight + weight_bias) + y = x * rstd * (w + Float32(weight_bias)) + + # Store output using cute.copy with predicate + tYrY = y.to(mY.element_type) + tXrY = cute.make_rmem_tensor(tXgY.shape, mY.element_type) + tXrY.store(tYrY) + + cute.copy(copy_atom, tXrY, tXgY, pred=tXpX) + + # PDL: Signal dependent kernels (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# QKRMSNormKernel +# ============================================================================= + + +class QKRMSNormKernel: + """ + QK RMSNorm Kernel using CuTe-DSL for 3D tensors [batch, heads, head_dim]. + + Supports arbitrary stride - no need for contiguous tensors. + Each warp processes one (batch, head) pair independently. + Uses warp-only reduction (no cross-warp shared memory sync needed). 
+ + Computes: output[b,h,:] = input[b,h,:] / sqrt(mean(input[b,h,:]^2) + eps) * (weight + weight_bias) + """ + + def __init__( + self, + dtype: cutlass.Numeric, + head_dim: int, + weight_bias: float = 0.0, + num_warps: int = 4, + ): + self.dtype = dtype + self.head_dim = head_dim + self.weight_bias = weight_bias + self.num_warps = num_warps + + # Vectorization: each warp (32 threads) processes head_dim elements + elem_bits = dtype.width + max_vec_size = COPY_BITS // elem_bits # 8 for float16/bfloat16 + self.vec_size = compute_optimal_vec_size(head_dim, max_vec_size) + self.copy_bits = self.vec_size * elem_bits + + # Threads per warp is always 32 + self.threads_per_warp = 32 + self.num_threads = self.threads_per_warp * num_warps + + # Number of vectorized blocks per warp + self.num_vec_blocks = max( + 1, + (head_dim // self.vec_size + self.threads_per_warp - 1) + // self.threads_per_warp, + ) + self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_warp + + def _smem_size_in_bytes(self) -> int: + # No shared memory needed - warp-only reduction + return 0 + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mY: cute.Tensor, + B: Int32, + N: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + num_blocks: Int32, + stream, + ): + """Launch the QKRMSNorm kernel. + + Args: + mX: Input tensor of shape [B, N, H] with arbitrary stride. + mW: Weight tensor of shape [H]. + mY: Output tensor of shape [B, N, H] with arbitrary stride. + B: Batch size. + N: Number of heads. + eps: Epsilon for numerical stability. + enable_pdl: Enable PDL for SM90+. + num_blocks: Number of blocks to launch. + stream: CUDA stream. 
+ """ + # Use 32 threads per warp for warp-level layout + tv_shape, tv_stride = make_tv_layout(32, self.vec_size, self.num_vec_blocks) + tv_layout = cute.make_layout(tv_shape, stride=tv_stride) + + self.kernel(mX, mW, mY, B, N, eps, enable_pdl, tv_layout).launch( + grid=[num_blocks, 1, 1], + block=[self.num_threads, 1, 1], + smem=self._smem_size_in_bytes(), + stream=stream, + use_pdl=enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mY: cute.Tensor, + B: Int32, + N: Int32, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + tv_layout: cute.Layout, + ): + """Device kernel for QKRMSNorm with 3D tensor support and arbitrary stride.""" + bidx, _, _ = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + # PDL: Wait for previous kernel (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_wait() + + head_dim = self.head_dim + weight_bias = self.weight_bias + num_warps = self.num_warps + copy_bits = self.copy_bits + + # Thread indexing within block + lane_idx = tidx % 32 + warp_idx = tidx // 32 + + # Total workers and jobs + grid_dim_x, _, _ = cute.arch.grid_dim() + num_workers = grid_dim_x * num_warps + worker_idx = bidx * num_warps + warp_idx + + # Total number of rows + M = B * N + + # Create copy atom for vectorized loads/stores + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=copy_bits, + ) + + # Expand weight to 2D for consistent tiling: [1, H] + mW_2d = cute.prepend_ones(mW, up_to_rank=2) + + # Create tiled copy for warp-level access (32 threads) + tiler_2d = (1, self.cols_per_tile) + tiled_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_2d) + thr_copy = tiled_copy.get_slice(lane_idx) + + # Create identity tensor matching tile shape for bounds checking + id2d = cute.make_identity_tensor(tiler_2d) + + # Weight and predicate are the same for all rows - compute once + tXgW = thr_copy.partition_S(mW_2d) + tXcX = thr_copy.partition_S(id2d) + 
tXpX = predicate_k(tXcX, limit=head_dim) + + # Load weight once (same for all rows) + tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) + tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + cute.copy(copy_atom, tXgW, tXrW, pred=tXpX) + w = tXrW.load().to(Float32) + + # Each warp processes multiple rows with grid-stride loop + row_idx = worker_idx + while row_idx < M: + batch_idx = row_idx // N + head_idx = row_idx % N + + # Get 3D tile and collapse first two dims (both size 1) to 2D for tiled_copy + gX = cute.group_modes( + cute.local_tile( + mX, (1, 1, self.cols_per_tile), (batch_idx, head_idx, 0) + ), + 0, + 2, + ) + gY = cute.group_modes( + cute.local_tile( + mY, (1, 1, self.cols_per_tile), (batch_idx, head_idx, 0) + ), + 0, + 2, + ) + + # Partition tensors for this thread + tXgX = thr_copy.partition_S(gX) + tXgY = thr_copy.partition_D(gY) + + # Register fragment for input - initialize to zero + tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + + # Phase 1: Load input and compute sum of squares + cute.copy(copy_atom, tXgX, tXrX, pred=tXpX) + + x = tXrX.load().to(Float32) + x_sq = x * x + + # Reduce within register tensor first + local_sum = x_sq.reduce( + cute.ReductionOp.ADD, init_val=Float32(0.0), reduction_profile=0 + ) + + # Warp reduction for sum_sq + sum_sq = warp_reduce(local_sum, operator.add, width=32) + + # Compute rstd + mean_sq = sum_sq / Float32(head_dim) + rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + + # output = input * rstd * (weight + weight_bias) + # w is already loaded outside the loop + y = x * rstd * (w + Float32(weight_bias)) + + # Store output + tYrY = y.to(mY.element_type) + tXrY = cute.make_rmem_tensor(tXgY.shape, mY.element_type) + tXrY.store(tYrY) + + cute.copy(copy_atom, tXrY, tXgY, pred=tXpX) + + # Next row for this warp + row_idx = row_idx + num_workers + + # PDL: Signal dependent kernels (SM90+ only) + if enable_pdl: + 
cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# RMSNormQuantKernel +# ============================================================================= + + +class RMSNormQuantKernel: + """ + RMSNorm + FP8 Quantization Kernel using CuTe-DSL. + + Computes: output = clamp(input / sqrt(mean(input^2) + eps) * weight / scale, -448, 448) + Then quantizes to FP8 E4M3. + """ + + def __init__( + self, + dtype: cutlass.Numeric, + H: int, + weight_bias: float = 0.0, + use_hw_fp8: bool = True, + ): + self.dtype = dtype + self.H = H + self.weight_bias = weight_bias + self.use_hw_fp8 = use_hw_fp8 + + # Vectorization parameters: use optimal vec_size for warp utilization + elem_bits = dtype.width + max_vec_size_in = COPY_BITS // elem_bits # 8 for fp16/bf16 + self.vec_size = compute_optimal_vec_size(H, max_vec_size_in) + self.copy_bits = self.vec_size * elem_bits + + # For FP8 output: minimum 16 bits = 2 FP8 elements + # Use same vec_size to keep layouts aligned, but ensure copy_bits_out >= 16 + self.vec_size_out = self.vec_size + self.copy_bits_out = max(16, self.vec_size * 8) + + self.threads_per_row = compute_threads_per_row(H, self.vec_size) + self.num_threads = self.threads_per_row + self.num_warps = max(self.threads_per_row // 32, 1) + + self.num_vec_blocks = max( + 1, (H // self.vec_size + self.threads_per_row - 1) // self.threads_per_row + ) + self.cols_per_tile = self.vec_size * self.num_vec_blocks * self.threads_per_row + + def _smem_size_in_bytes(self) -> int: + # Only reduction buffer needed + return self.num_warps * 4 + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mY: cute.Tensor, + M: Int32, + mS: cute.Tensor, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + stream, + ): + tv_shape, tv_stride = make_tv_layout( + self.threads_per_row, self.vec_size, self.num_vec_blocks + ) + tv_layout = cute.make_layout(tv_shape, stride=tv_stride) + tiler_mn = (1, 
self.cols_per_tile) + + self.kernel(mX, mW, mY, M, mS, eps, enable_pdl, tv_layout, tiler_mn).launch( + grid=[M, 1, 1], + block=[self.num_threads, 1, 1], + smem=self._smem_size_in_bytes(), + stream=stream, + use_pdl=enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mY: cute.Tensor, + M: Int32, + mS: cute.Tensor, + eps: Float32, + enable_pdl: cutlass.Constexpr[bool], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # PDL: Wait for previous kernel (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_wait() + + H = self.H + weight_bias = self.weight_bias + threads_per_row = tv_layout.shape[0][0] + num_warps = self.num_warps + copy_bits = self.copy_bits + vec_size = self.vec_size + num_vec_blocks = self.num_vec_blocks + + inv_scale = rcp_approx_ftz(mS[0]) + + smem = cutlass.utils.SmemAllocator() + reduction_buffer = smem.allocate_tensor( + Float32, cute.make_layout((num_warps,)), byte_alignment=4 + ) + + idX = cute.make_identity_tensor(mX.shape) + gX = cute.local_tile(mX, tiler_mn, (bidx, 0)) + cX = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + mW_2d = cute.prepend_ones(mW, up_to_rank=2) + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=copy_bits + ) + + tiled_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn) + thr_copy_load = tiled_copy_load.get_slice(tidx) + + tXgX = thr_copy_load.partition_S(gX) + tXgW = thr_copy_load.partition_S(mW_2d) + tXcX = thr_copy_load.partition_S(cX) + + # Register fragments - initialize to zero for proper handling of out-of-bounds threads + tXrX = cute.make_rmem_tensor(tXgX.shape, mX.element_type) + tXrW = cute.make_rmem_tensor(tXgW.shape, mW.element_type) + tXrX.store(cute.zeros_like(tXrX, dtype=mX.element_type)) + tXrW.store(cute.zeros_like(tXrW, dtype=mW.element_type)) + + tXpX = predicate_k(tXcX, limit=H) + + # Phase 1: Load 
input from global to register + cute.copy(copy_atom_load, tXgX, tXrX, pred=tXpX) + + x = tXrX.load().to(Float32) + x_sq = x * x + sum_sq = row_reduce_sum(x_sq, threads_per_row, reduction_buffer) + + mean_sq = sum_sq / Float32(H) + rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + + # Phase 2: Load weight from global to register + cute.copy(copy_atom_load, tXgW, tXrW, pred=tXpX) + + w = tXrW.load().to(Float32) + y = x * rstd * (w + Float32(weight_bias)) * inv_scale + + # Phase 3: Clamp and store to FP8 output using PTX scalar stores + # (CuTe FP8 conversion requires vectorized ops, so we use PTX for scalar stores) + # Store y to register tensor for element-wise access + tYrY_f32 = cute.make_rmem_tensor(tXgX.shape, Float32) + tYrY_f32.store(y) + + col_offset = tidx * vec_size + for v in cutlass.range_constexpr(num_vec_blocks): + for e in cutlass.range_constexpr(vec_size): + idx = col_offset + v * threads_per_row * vec_size + e + if idx < H: + # Clamp and convert - use flat index for register tensor + flat_idx = v * vec_size + e + clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX)) + clamped = min(clamped, Float32(FLOAT8_E4M3_MAX)) + # Use PTX to convert and store FP8 byte + out_offset = bidx * H + idx + out_ptr = get_ptr_as_int64(mY, Int32(out_offset)) + if self.use_hw_fp8: + cvt_and_store_f32_to_e4m3_hw(clamped, out_ptr) + else: + cvt_and_store_f32_to_e4m3_sw(clamped, out_ptr) + + # PDL: Signal dependent kernels (SM90+ only) + if enable_pdl: + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# Compiled Kernel Getters +# ============================================================================= + + +@functools.cache +def _get_compiled_rmsnorm_kernel( + dtype_str: str, H: int, weight_bias: float, enable_pdl: bool +): + """Get a compiled RMSNorm kernel using TVM-FFI.""" + dtype = get_cutlass_dtype(dtype_str) + kernel_obj = RMSNormKernel(dtype, H, weight_bias) + + # Use 
symbolic size for dynamic M dimension + sym_m = cute.sym_int() + # Use symbolic stride for arbitrary row stride (last dim must be contiguous) + sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + + # Create fake tensors with symbolic stride for arbitrary stride support + x_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16 + ) + w_fake = cute.runtime.make_fake_compact_tensor(dtype, (H,), assumed_align=16) + y_fake = cute.runtime.make_fake_tensor( + dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16 + ) + + # Create fake stream that uses environment stream at runtime + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + # Compile with TVM-FFI enabled + compiled_kernel = cute.compile( + kernel_obj, + x_fake, + w_fake, + y_fake, + Int32(1), # Dummy M + Float32(1e-6), # Dummy eps + enable_pdl, + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel + + +@functools.cache +def _get_compiled_qk_rmsnorm_kernel( + dtype_str: str, head_dim: int, weight_bias: float, num_warps: int, enable_pdl: bool +): + """Get a compiled QKRMSNorm kernel for 3D tensors with arbitrary stride.""" + dtype = get_cutlass_dtype(dtype_str) + kernel_obj = QKRMSNormKernel(dtype, head_dim, weight_bias, num_warps) + + # Use symbolic sizes for B, N dimensions + sym_b = cute.sym_int() + sym_n = cute.sym_int() + + # Use symbolic strides for arbitrary stride support + # stride[-1] must be 1 (contiguous in head_dim), but batch/head strides can be anything + sym_batch_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_head_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_batch_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + sym_head_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size) + + # Create 3D fake tensors with arbitrary stride + x_fake = cute.runtime.make_fake_tensor( + 
@functools.cache
def _get_compiled_rmsnorm_quant_kernel(
    dtype_str: str,
    out_dtype_str: str,
    H: int,
    weight_bias: float,
    enable_pdl: bool,
    use_hw_fp8: bool = True,
):
    """Compile the RMSNorm + quantization kernel for one configuration (TVM-FFI).

    The result is memoized by ``functools.cache`` on the full argument tuple,
    so each distinct configuration is compiled once per process.

    Args:
        dtype_str: Name of the input/weight element type, resolved through
            ``get_cutlass_dtype``.
        out_dtype_str: Name of the output element type, resolved the same way.
        H: Static size of the last (normalized) dimension.
        weight_bias: Constant forwarded to ``RMSNormQuantKernel``.
        enable_pdl: Forwarded to the kernel as a compile-time flag.
        use_hw_fp8: Forwarded to ``RMSNormQuantKernel`` to choose between the
            hardware and software FP8 conversion helpers.

    Returns:
        The object returned by ``cute.compile``.
    """
    in_dtype = get_cutlass_dtype(dtype_str)
    quant_dtype = get_cutlass_dtype(out_dtype_str)
    quant_kernel = RMSNormQuantKernel(in_dtype, H, weight_bias, use_hw_fp8=use_hw_fp8)

    # The row count and both row strides stay symbolic so a single compiled
    # kernel serves any number of rows and any (suitably divisible) row pitch.
    num_rows = cute.sym_int()
    in_row_stride = cute.sym_int(divisibility=quant_kernel.vec_size)
    out_row_stride = cute.sym_int(divisibility=quant_kernel.vec_size_out)

    fake_input = cute.runtime.make_fake_tensor(
        in_dtype, (num_rows, H), (in_row_stride, 1), assumed_align=16
    )
    fake_weight = cute.runtime.make_fake_compact_tensor(
        in_dtype, (H,), assumed_align=16
    )
    fake_output = cute.runtime.make_fake_tensor(
        quant_dtype, (num_rows, H), (out_row_stride, 1), assumed_align=16
    )
    # The quantization scale is passed as a one-element float32 tensor.
    fake_scale = cute.runtime.make_fake_compact_tensor(
        Float32, (1,), assumed_align=4
    )
    fake_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

    # Positional order matches the kernel's call signature:
    # input, weight, output, M, scale, eps, enable_pdl, stream.
    example_args = (
        fake_input,
        fake_weight,
        fake_output,
        Int32(1),  # placeholder M
        fake_scale,
        Float32(1e-6),  # placeholder eps
        enable_pdl,
        fake_stream,
    )
    return cute.compile(quant_kernel, *example_args, options="--enable-tvm-ffi")
def rmsnorm_cute(
    input: torch.Tensor,
    weight: torch.Tensor,
    out: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """Run the CuTe DSL RMSNorm kernel, writing the result into ``out``.

    The last dimension must be contiguous (stride[-1] == 1); the row stride
    of a 2-D tensor may be arbitrary. A 3-D ``input``/``out`` pair is
    flattened to ``(shape[0] * shape[1], H)`` with ``Tensor.view``, so its
    first two dimensions must be collapsible without a copy.

    Args:
        input: Tensor to normalize along its last dimension.
        weight: Weight tensor forwarded to the kernel.
        out: Destination tensor; viewed the same way as ``input``.
        eps: Epsilon forwarded to the kernel at call time.
        weight_bias: Part of the compile key; forwarded to the kernel getter.
        enable_pdl: Part of the compile key; forwarded to the kernel getter.
    """
    dims = input.shape
    hidden = dims[-1]

    # Only the exact 3-D case is flattened; anything else is passed as-is
    # with its leading dimension as the row count.
    if len(dims) != 3:
        rows = dims[0]
        x_rows, y_rows = input, out
    else:
        rows = dims[0] * dims[1]
        x_rows = input.view(rows, hidden)
        y_rows = out.view(rows, hidden)

    compiled = _get_compiled_rmsnorm_kernel(
        _torch_dtype_to_str(input.dtype), hidden, weight_bias, enable_pdl
    )
    compiled(x_rows, weight, y_rows, rows, eps)
def rmsnorm_quant_cute(
    out: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    eps: float = 1e-6,
    weight_bias: float = 0.0,
    enable_pdl: bool = False,
) -> None:
    """CuTe DSL RMSNorm + FP8 quantization implementation.

    The last dimension must be contiguous (stride[-1] == 1); the row stride
    of a 2-D tensor may be arbitrary. A 3-D ``input``/``out`` pair is
    flattened to ``(shape[0] * shape[1], H)`` with ``Tensor.view`` (the same
    handling as ``rmsnorm_cute``), so its first two dimensions must be
    collapsible without a copy.

    Args:
        out: Destination tensor for the quantized result.
        input: Tensor to normalize along its last dimension.
        weight: Weight tensor forwarded to the kernel.
        scale: Tensor holding exactly one element (0-dim or any
            single-element shape). The kernel is compiled against a
            float32 tensor of shape ``(1,)``.
        eps: Epsilon forwarded to the kernel at call time.
        weight_bias: Part of the compile key; forwarded to the kernel getter.
        enable_pdl: Part of the compile key; forwarded to the kernel getter.

    Raises:
        ValueError: If ``scale`` does not contain exactly one element.
    """
    if scale.numel() != 1:
        raise ValueError(
            "scale must contain exactly one element, "
            f"got shape {tuple(scale.shape)}"
        )
    # The compiled kernel expects a compact (1,) tensor. Callers commonly
    # build the scale with torch.tensor(value), which is 0-dim, so present
    # every single-element scale in the compiled shape.
    scale_1d = scale.reshape(1)

    shape = input.shape
    H = shape[-1]

    if len(shape) == 3:
        # The kernel is compiled for 2-D tensors; collapse the leading dims.
        M = shape[0] * shape[1]
        input_2d = input.view(M, H)
        out_2d = out.view(M, H)
    else:
        M = shape[0]
        input_2d = input
        out_2d = out

    dtype_str = _torch_dtype_to_str(input.dtype)
    out_dtype_str = _torch_dtype_to_str(out.dtype)
    kernel = _get_compiled_rmsnorm_quant_kernel(
        dtype_str,
        out_dtype_str,
        H,
        weight_bias,
        enable_pdl,
        use_hw_fp8=has_hw_fp8_cvt(input.device),
    )
    kernel(input_2d, weight, out_2d, M, scale_1d, eps)
@dsl_user_op
def rcp_approx_ftz(a: Float32, *, loc=None, ip=None) -> Float32:
    """Approximate reciprocal of ``a`` via the PTX ``rcp.approx.ftz.f32`` instruction.

    Args:
        a: Value to invert; converted to ``Float32`` before lowering.
        loc: Source location forwarded to ``ir_value``.
        ip: Insertion point forwarded to ``ir_value``.

    Returns:
        The inline-asm result wrapped as ``Float32``.
    """
    result_type = T.f32()
    operand = Float32(a).ir_value(loc=loc, ip=ip)
    # "=f,f": one f32 output register ($0) and one f32 input register ($1).
    reciprocal = llvm.inline_asm(
        result_type,
        [operand],
        "rcp.approx.ftz.f32 $0, $1;",
        "=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Float32(reciprocal)
@dsl_user_op
def cvt_and_store_f32_to_e4m3_sw(val: Float32, addr: Int64, *, loc=None, ip=None):
    """Convert float32 to E4M3 and store single byte — software path (all architectures).

    Uses integer bit manipulation mirroring NVIDIA's __nv_cvt_float_to_fp8 from cuda_fp8.hpp.
    The caller must clamp the value to [-448, 448] before calling this function.

    E4M3 format: 1 sign bit, 4 exponent bits (bias=7), 3 mantissa bits.
    Conversion strategy (operates on f32 bit representation):
    - Normal range (f32 biased exp >= 121): direct exponent/mantissa extraction with RNE
    - Denormal range (f32 biased exp in [118..120]): shift mantissa with implicit bit, RNE
    - Underflow (abs <= 2^-10): flush to zero (2^-10 is the RNE midpoint to min denorm)

    Implementation notes:
    - Both the normal and the denormal candidates are always computed; the
      predicate ``f32_exp <= 120`` selects between them afterwards, so there
      is no branch in the emitted PTX.
    - The normal candidate is capped at 126 (0x7E), the encoding of 448, so a
      mantissa round-up can never reach 0x7F.
    - Inputs with biased exponent 117 and a non-zero mantissa (just above
      2^-10) also go through the denormal candidate: the shift of 4 leaves a
      zero magnitude and the rounding term lifts it to the smallest
      denormal (raw value 1).
    - The flush test compares the absolute bit pattern against 0x3A800000,
      which is 2^-10 in float32.

    Args:
        val: Value to convert (expected to be pre-clamped, see above).
        addr: 64-bit global-memory address of the destination byte, bound to
            the ``l`` constraint and used by ``st.global.b8``.
        loc: Source location forwarded to ``ir_value``.
        ip: Insertion point forwarded to ``ir_value``.
    """
    llvm.inline_asm(
        None,
        [Float32(val).ir_value(loc=loc, ip=ip), Int64(addr).ir_value(loc=loc, ip=ip)],
        """
        {
            .reg .b32 fbits, sign8, abs_bits, f32_exp, f32_mant;
            .reg .b32 e4m3_exp, e4m3_mant, norm_raw;
            .reg .b32 rbit, sticky, odd_bit, radj;
            .reg .b32 shift, dmant4, denorm_raw;
            .reg .b32 dr_bit, dsticky, dadj;
            .reg .b32 e4m3_raw, tmp, tmp2;
            .reg .pred p_zero, p_denorm;

            // Bitcast float to int and extract sign/exponent/mantissa
            mov.b32 fbits, $0;
            shr.u32 sign8, fbits, 24;
            and.b32 sign8, sign8, 128;
            and.b32 abs_bits, fbits, 0x7FFFFFFF;
            shr.u32 f32_exp, abs_bits, 23;
            and.b32 f32_mant, abs_bits, 0x007FFFFF;

            // === Normal path (f32 biased exp >= 121, i.e. e4m3 exp >= 1) ===
            sub.u32 e4m3_exp, f32_exp, 120;
            shr.u32 e4m3_mant, f32_mant, 20;
            shl.b32 norm_raw, e4m3_exp, 3;
            or.b32 norm_raw, norm_raw, e4m3_mant;

            // Round-to-nearest-even: round up if round_bit AND (sticky OR odd)
            shr.u32 rbit, f32_mant, 19;
            and.b32 rbit, rbit, 1;
            and.b32 sticky, f32_mant, 0x0007FFFF;
            and.b32 odd_bit, e4m3_mant, 1;
            or.b32 tmp, sticky, odd_bit;
            min.u32 tmp, tmp, 1;
            and.b32 radj, tmp, rbit;
            add.u32 norm_raw, norm_raw, radj;
            min.u32 norm_raw, norm_raw, 126;

            // === Denormal path (f32 biased exp in {118,119,120}) ===
            sub.u32 shift, 121, f32_exp;
            shr.u32 tmp, f32_mant, 20;
            or.b32 dmant4, tmp, 8;
            shr.u32 denorm_raw, dmant4, shift;

            // RNE rounding for denormals
            sub.u32 tmp, shift, 1;
            shr.u32 dr_bit, dmant4, tmp;
            and.b32 dr_bit, dr_bit, 1;
            shl.b32 tmp2, 1, tmp;
            sub.u32 tmp2, tmp2, 1;
            and.b32 dsticky, dmant4, tmp2;
            and.b32 tmp, f32_mant, 0x000FFFFF;
            or.b32 dsticky, dsticky, tmp;
            and.b32 odd_bit, denorm_raw, 1;
            or.b32 tmp, dsticky, odd_bit;
            min.u32 tmp, tmp, 1;
            and.b32 dadj, tmp, dr_bit;
            add.u32 denorm_raw, denorm_raw, dadj;

            // Select between normal/denormal, then apply zero flush
            setp.le.u32 p_denorm, f32_exp, 120;
            selp.u32 e4m3_raw, denorm_raw, norm_raw, p_denorm;
            setp.le.u32 p_zero, abs_bits, 0x3A800000;
            selp.u32 e4m3_raw, 0, e4m3_raw, p_zero;

            // Apply sign and store single byte
            or.b32 e4m3_raw, e4m3_raw, sign8;
            st.global.b8 [$1], e4m3_raw;
        }
        """,
        # "f,l": f32 input register ($0) and 64-bit address register ($1); no outputs.
        "f,l",
        # The asm writes to memory, so it must not be dropped or reordered as pure.
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
@cute.jit
def block_reduce(
    val: Float32,
    op: Callable,
    reduction_buffer: cute.Tensor,
    init_val: Float32,
) -> Float32:
    """Block reduction across multiple warps using shared memory.

    Lane 0 of every warp writes its ``val`` into ``reduction_buffer`` at the
    warp's index; after a block barrier, each warp reads the per-warp values
    back (one per lane) and combines them with ``warp_reduce``.

    Args:
        val: This thread's contribution. Only the value held by lane 0 of
            each warp is written to the buffer (``row_reduce_sum`` passes a
            value that has already been reduced within the warp).
        op: Binary combiner forwarded to ``warp_reduce``.
        reduction_buffer: Scratch tensor with one slot per warp; its size is
            taken as the number of warps.
        init_val: Value used by lanes whose index is not below the number of
            warps; should be the identity of ``op``.

    Returns:
        The result of ``warp_reduce`` over the per-warp values.

    NOTE(review): the final step is a single ``warp_reduce`` with its default
    width of 32, so this presumably assumes at most 32 warps per block —
    confirm against the launch configuration.
    NOTE(review): there is no barrier after the buffer is read, so a second
    reduction that reuses the same buffer presumably needs its own barrier
    in between — confirm before calling this twice in one kernel.
    """
    lane_idx = cute.arch.lane_idx()
    warp_idx = cute.arch.warp_idx()
    # The buffer length defines how many warps take part.
    num_warps = cute.size(reduction_buffer.shape)

    # Publish one value per warp.
    if lane_idx == 0:
        reduction_buffer[warp_idx] = val
    # All per-warp values must be written before any thread reads them.
    cute.arch.barrier()

    # Lanes beyond the warp count contribute init_val instead of reading
    # past the end of the buffer.
    block_reduce_val = init_val
    if lane_idx < num_warps:
        block_reduce_val = reduction_buffer[lane_idx]
    return warp_reduce(block_reduce_val, op)
def compute_optimal_vec_size(H: int, max_vec_size: int) -> int:
    """Compute vec_size that maximizes warp utilization.

    For small hidden sizes, using max vec_size may result in fewer than 32 threads,
    wasting warp resources. Starting from ``max_vec_size`` and halving each
    step, this function returns the first (largest) candidate that:
    1. Divides H evenly
    2. Results in at least 32 threads (one full warp)

    Candidates are tried all the way down to 1, for any ``max_vec_size``.
    If no candidate qualifies (H is too small to fill a warp even with
    vec_size=1, or max_vec_size < 1), ``gcd(max_vec_size, H)`` is returned so
    that the result still divides H.

    Examples:
        - H=128, max=8: vec_size=8 gives 16 threads, vec_size=4 gives 32 threads -> return 4
        - H=4096, max=8: vec_size=8 gives 512 threads -> return 8
        - H=111, max=8: only vec_size=1 divides H (111 threads) -> return 1
        - H=48, max=16: only vec_size=1 reaches 32 threads (48) -> return 1
        - H=16, max=8: no candidate reaches 32 threads -> gcd(8, 16) = 8

    Args:
        H: Hidden size (number of elements per row).
        max_vec_size: Largest vector width to consider, in elements.

    Returns:
        The chosen number of elements per thread per vector access.
    """
    # Halve until the candidate drops below 1 so vec_size=1 is always tried.
    vec_size = max_vec_size
    while vec_size >= 1:
        if H % vec_size == 0 and H // vec_size >= 32:
            return vec_size
        vec_size //= 2
    # Fallback: use gcd for correctness (keeps the result a divisor of H)
    return math.gcd(max_vec_size, H)
def _torch_dtype_to_str(dtype: torch.dtype) -> str:
    """Return the string name registered for ``dtype`` in ``_TORCH_DTYPE_TO_STR_MAP``.

    Raises:
        KeyError: If ``dtype`` has no entry in the map.
    """
    dtype_name = _TORCH_DTYPE_TO_STR_MAP[dtype]
    return dtype_name
{ + float weight_bias, float* scale, float eps) { const uint32_t bx = blockIdx.x; const uint32_t tx = threadIdx.x, ty = threadIdx.y; constexpr uint32_t warp_size = 32; @@ -158,7 +158,7 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight const uint32_t thread_id = tx + ty * warp_size; const uint32_t num_threads = num_warps * warp_size; const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); - const float scale_inv = 1.0f / scale; + const float scale_inv = 1.0f / scale[0]; extern __shared__ float smem[]; float sum_sq = 0.f; @@ -228,7 +228,7 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight template cudaError_t RMSNormQuant(T* input, T* weight, O* output, uint32_t batch_size, uint32_t d, - uint32_t stride_input, uint32_t stride_output, float scale, + uint32_t stride_input, uint32_t stride_output, float* scale, float eps = 1e-5, bool enable_pdl = false, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -519,7 +519,7 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_ const uint32_t d, const uint32_t stride_input, const uint32_t stride_residual, const uint32_t stride_output, float weight_bias, - float scale, float eps) { + float* scale, float eps) { const uint32_t bx = blockIdx.x; const uint32_t tx = threadIdx.x, ty = threadIdx.y; constexpr uint32_t warp_size = 32; @@ -527,7 +527,7 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_ const uint32_t thread_id = tx + ty * warp_size; const uint32_t num_threads = num_warps * warp_size; const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); - const float scale_inv = 1.0f / scale; + const float scale_inv = 1.0f / scale[0]; extern __shared__ float smem[]; float* smem_x = smem + ceil_div(num_warps, 4) * 4; @@ -613,7 +613,7 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_ template cudaError_t FusedAddRMSNormQuant(T* input, 
T* residual, T* weight, O* output, uint32_t batch_size, uint32_t d, uint32_t stride_input, uint32_t stride_residual, - uint32_t stride_output, float scale, float eps = 1e-5, + uint32_t stride_output, float* scale, float eps = 1e-5, bool enable_pdl = false, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); diff --git a/tests/utils/test_norm.py b/tests/utils/test_norm.py index 5e8b067cbe..1bda5005a2 100644 --- a/tests/utils/test_norm.py +++ b/tests/utils/test_norm.py @@ -149,7 +149,9 @@ def test_norm_quant( y_ref = llama_rms_norm_quant(x, w, quant_scale) y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda") - flashinfer.norm.rmsnorm_quant(y, x, w, quant_scale, enable_pdl=enable_pdl) + flashinfer.norm.rmsnorm_quant( + y, x, w, torch.tensor(quant_scale, device="cuda"), enable_pdl=enable_pdl + ) torch.testing.assert_close(y_ref.float(), y.float(), rtol=1, atol=1) @@ -250,7 +252,13 @@ def test_fused_add_rmsnorm_quant( residual_fused = residual.clone() y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda") flashinfer.norm.fused_add_rmsnorm_quant( - y, x_fused, residual_fused, weight, quant_scale, eps, enable_pdl=enable_pdl + y, + x_fused, + residual_fused, + weight, + torch.tensor(quant_scale, device="cuda"), + eps, + enable_pdl=enable_pdl, ) torch.testing.assert_close(y.float(), x_native.float(), rtol=1, atol=1)