diff --git a/CMakeLists.txt b/CMakeLists.txt index 0000b6d32be6..20ee377a761a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -657,7 +657,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" + "csrc/quantization/fp4/rmsnorm_nvfp4_quant_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") @@ -683,7 +684,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" + "csrc/quantization/fp4/rmsnorm_nvfp4_quant_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index fb3329975cee..91c3faed5926 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -13,10 +13,25 @@ from tqdm import tqdm import vllm._custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +# FP4 constants +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +# Check if NVFP4 fused ops are available +rms_norm_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( + torch.ops._C, "rms_norm_nvfp4_quant" +) +fused_add_rms_norm_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( + torch.ops._C, "fused_add_rms_norm_nvfp4_quant" +) @dataclass @@ -59,6 +74,7 @@ def unfused_int8_impl( residual: torch.Tensor | None, quant_dtype: torch.dtype, group_size: list[int], + **kwargs, ): # Norm torch_out = None @@ -77,6 +93,7 @@ def unfused_fp8_impl( residual: torch.Tensor | None, quant_dtype: torch.dtype, group_size: list[int], + **kwargs, ): # Norm torch_out = None @@ -95,6 +112,7 @@ def unfused_groupwise_fp8_impl( residual: torch.Tensor | None, quant_dtype: torch.dtype, group_size: list[int], + **kwargs, ): # Norm torch_out = None @@ -115,6 +133,7 @@ def fused_impl( residual: torch.Tensor | None, quant_dtype: torch.dtype, group_size: list[int], + **kwargs, ): out, _ = ops.rms_norm_dynamic_per_token_quant( x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual @@ -127,6 +146,7 @@ def fused_groupwise_impl( residual: torch.Tensor | None, quant_dtype: torch.dtype, group_size: list[int], + **kwargs, ): out, _ = ops.rms_norm_per_block_quant( x, @@ -139,6 +159,93 @@ def fused_groupwise_impl( ) +def get_fp4_output_tensors(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Allocate output tensors for FP4 quantization.""" + m, n = x.shape + block_size = 16 + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=x.device, dtype=torch.uint8) + # Swizzled scale layout for 128x4 tiles + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=x.device, dtype=torch.int32 + ) + return output, output_scale + + +def unfused_nvfp4_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], + global_scale: torch.Tensor | None = None, + **kwargs, +): + """Unfused RMSNorm + NVFP4 quantization implementation.""" + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Use pre-computed global_scale if provided (simulates real inference) + if global_scale is None: + global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + torch_out + ).max().to(torch.float32) + output_quant, output_scale = ops.scaled_fp4_quant(torch_out, global_scale) + + +def fused_nvfp4_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], + global_scale: torch.Tensor | None = None, + **kwargs, +): + """Fused RMSNorm + NVFP4 quantization implementation.""" + # Use pre-computed global_scale if provided (simulates real inference) + # In practice, global_scale is computed once during calibration + if global_scale is None: + torch_out_ref = rms_norm_layer.forward_cuda(x, None) + global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + torch_out_ref + ).max().to(torch.float32) + + # Allocate output tensors each time (matching unfused_nvfp4_impl behavior) + output_quant, output_scale = get_fp4_output_tensors(x) + + if residual is None: + # rms_norm_nvfp4_quant + torch.ops._C.rms_norm_nvfp4_quant( + output_quant, + output_scale, + x, + rms_norm_layer.weight, + global_scale, + 1e-6, + ) + else: + # fused_add_rms_norm_nvfp4_quant + # Note: residual is modified in-place, but for benchmark we accept this + torch.ops._C.fused_add_rms_norm_nvfp4_quant( + output_quant, + output_scale, + x, + residual, + rms_norm_layer.weight, + global_scale, + 1e-6, + ) + + # Bench functions def bench_fn( rms_norm_layer: RMSNorm, @@ -150,6 +257,7 @@ def bench_fn( sub_label: str, fn: Callable, description: str, + global_scale: torch.Tensor | None = None, ) -> TMeasurement: min_run_time = 1 @@ -159,10 +267,12 @@ def bench_fn( "residual": residual, "quant_dtype": quant_dtype, "group_size": group_size, + "global_scale": global_scale, "fn": fn, } return TBenchmark.Timer( - stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)", + stmt="fn(rms_norm_layer, x, residual, quant_dtype," + " group_size, global_scale=global_scale)", globals=globals, label=label, sub_label=sub_label, @@ -279,6 +389,53 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu ) ) + # NVFP4 benchmarks (only if supported, hidden_size is multiple of 16, + # and dtype is fp16/bf16 - NVFP4 does not support float32) + if params.hidden_size % 16 == 0 and params.dtype in (torch.float16, torch.bfloat16): + # Pre-compute global_scale ONCE before benchmark loop + # This simulates real inference where global_scale is calibrated offline + with torch.no_grad(): + torch_out_ref = layer.forward_cuda(x, None) + nvfp4_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + torch_out_ref + ).max().to(torch.float32) + + if rms_norm_nvfp4_quant_supported or fused_add_rms_norm_nvfp4_quant_supported: + # unfused nvfp4 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.uint8, # FP4 is packed as uint8 + params.group_size, + label, + sub_label, + unfused_nvfp4_impl, + "unfused_nvfp4_impl", + global_scale=nvfp4_global_scale, + ) + ) + + if (not params.add_residual and rms_norm_nvfp4_quant_supported) or ( + params.add_residual and fused_add_rms_norm_nvfp4_quant_supported + ): + # fused nvfp4 impl + timers.append( + bench_fn( + layer, + x, + residual, + torch.uint8, # FP4 is packed as uint8 + params.group_size, + label, + sub_label, + fused_nvfp4_impl, + "fused_nvfp4_impl", + global_scale=nvfp4_global_scale, + ) + ) + print_timers(timers) return timers @@ -296,8 +453,11 @@ def main(): bench_params = get_bench_params() timers = [] - for bp in tqdm(bench_params): - timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + with set_current_vllm_config(VllmConfig()): + for bp in tqdm(bench_params): + timers.extend( + bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()) + ) print_timers(timers) # pickle all the results diff --git a/csrc/ops.h b/csrc/ops.h index 9ee6bda31f74..f61e5d8b928a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -139,6 +139,19 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, std::optional residual, int64_t group_size, bool is_scale_transposed); +#ifndef USE_ROCM +void rms_norm_nvfp4_quant(torch::Tensor& out, torch::Tensor& output_scale, + torch::Tensor& input, torch::Tensor& weight, + torch::Tensor& input_scale, double epsilon); + +void fused_add_rms_norm_nvfp4_quant(torch::Tensor& out, + torch::Tensor& output_scale, + torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, + torch::Tensor& input_scale, double epsilon); +#endif + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/cuda_type_utils.cuh b/csrc/quantization/cuda_type_utils.cuh new file mode 100644 index 000000000000..15d26fde55df --- /dev/null +++ b/csrc/quantization/cuda_type_utils.cuh @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include + +// Conditional compilation for FP4 element packing size +#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ + defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) + #define ELTS_PER_THREAD 16 +constexpr int CVT_FP4_ELTS_PER_THREAD = 16; +constexpr bool CVT_FP4_PACK16 = true; +#else + #define ELTS_PER_THREAD 8 +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr bool CVT_FP4_PACK16 = false; +#endif + +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +namespace vllm { + +// Convert PyTorch cpp type to CUDA type +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// Get type2 from type or vice versa (half <-> half2, bfloat16 <-> bfloat162) +template +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ + defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) +// Define a 32 bytes packed data type. +template +struct alignas(32) PackedVec { + typename TypeConverter::Type elts[8]; +}; +#else +// Define a 16 bytes packed data type. +template +struct alignas(16) PackedVec { + typename TypeConverter::Type elts[4]; +}; +#endif + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +} // namespace vllm diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 650b9da8a499..b42b5be559ae 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -51,6 +51,22 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( torch::Tensor const& output_scale_offset_by_experts); #endif +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void rms_norm_nvfp4_quant_sm1xxa(torch::Tensor& output, + torch::Tensor& output_scale, + torch::Tensor& input, torch::Tensor& weight, + torch::Tensor& input_scale, double epsilon); +#endif + +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void fused_add_rms_norm_nvfp4_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& input, + torch::Tensor& residual, torch::Tensor& weight, torch::Tensor& input_scale, + double epsilon); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf, bool is_sf_swizzled_layout) { @@ -101,3 +117,28 @@ void silu_and_mul_scaled_fp4_experts_quant( TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled silu_and_mul nvfp4 experts quantization kernel"); } + +void rms_norm_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor& input, torch::Tensor& weight, + torch::Tensor& input_scale, double epsilon) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return rms_norm_nvfp4_quant_sm1xxa(output, output_scale, input, weight, + input_scale, epsilon); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled rms_norm nvfp4 quantization kernel"); +} + +void fused_add_rms_norm_nvfp4_quant( + torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& input, + torch::Tensor& residual, torch::Tensor& weight, torch::Tensor& input_scale, + double epsilon) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + return fused_add_rms_norm_nvfp4_quant_sm1xxa( + output, output_scale, input, residual, weight, input_scale, epsilon); +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled fused_add_rms_norm nvfp4 quantization kernel"); +} diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 3e7adb9e2931..59158c64918e 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -18,84 +18,10 @@ #include #include - -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) - #define ELTS_PER_THREAD 16 -constexpr int CVT_FP4_ELTS_PER_THREAD = 16; -constexpr bool CVT_FP4_PACK16 = true; -#else - #define ELTS_PER_THREAD 8 -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr bool CVT_FP4_PACK16 = false; -#endif - -constexpr int CVT_FP4_SF_VEC_SIZE = 16; +#include "../cuda_type_utils.cuh" namespace vllm { -// Convert PyTorch cpp type to CUDA type -template -struct CUDATypeConverter { - using Type = T; -}; - -template <> -struct CUDATypeConverter { - using Type = half; -}; - -template <> -struct CUDATypeConverter { - using Type = __nv_bfloat16; -}; - -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) -// Define a 32 bytes packed data type. -template -struct alignas(32) PackedVec { - typename TypeConverter::Type elts[8]; -}; -#else -// Define a 16 bytes packed data type. -template -struct alignas(16) PackedVec { - typename TypeConverter::Type elts[4]; -}; -#endif - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - template __host__ __device__ inline Int round_up(Int x, Int y) { static_assert(std::is_integral_v, diff --git a/csrc/quantization/fp4/rmsnorm_nvfp4_quant_kernels.cu b/csrc/quantization/fp4/rmsnorm_nvfp4_quant_kernels.cu new file mode 100644 index 000000000000..999822d6e11c --- /dev/null +++ b/csrc/quantization/fp4/rmsnorm_nvfp4_quant_kernels.cu @@ -0,0 +1,394 @@ +#include + +#include +#include + +#include +#include + +#include +#include +#include "dispatch_utils.h" +#include "cub_helpers.h" + +#include "cuda_utils.h" +#include "launch_bounds_utils.h" + +// Define before including nvfp4_utils.cuh so the header +// can use this macro during compilation. +#define NVFP4_ENABLE_ELTS16 1 +#include "nvfp4_utils.cuh" +#include "../fused_kernels/layernorm_utils.cuh" + +namespace vllm { + +// Use UE4M3 by default. +template +__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) + rms_norm_cvt_fp16_to_fp4(const int num_tokens, const int hidden_size, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const* __restrict__ scale, + const float epsilon, uint32_t* __restrict__ output, + uint32_t* __restrict__ output_scale) { + using PackedVecT = vllm::PackedVec; + + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + + __shared__ float s_rms_inv; + + // SF count along K, padded to a multiple of 4 (swizzle tile requirement). + const int sf_k_unpadded = hidden_size / CVT_FP4_SF_VEC_SIZE; + static constexpr int SF_KTILE_SIZE = 4; + const int sf_k_padded = + ((sf_k_unpadded + SF_KTILE_SIZE - 1) / SF_KTILE_SIZE) * SF_KTILE_SIZE; + const int32_t num_k_tiles = sf_k_padded / SF_KTILE_SIZE; + const float global_scale = (scale == nullptr) ? 1.0f : scale[0]; + const int vecs_per_row = hidden_size / CVT_FP4_ELTS_PER_THREAD; + + // SF layout requires rows padded to 128 and SF cols padded to 4. + const int sf_rows = (num_tokens + 127) / 128 * 128; + const int vecs_per_row_padded = + sf_k_padded * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + + for (int row_idx = blockIdx.x; row_idx < sf_rows; row_idx += gridDim.x) { + const bool valid_row = row_idx < num_tokens; + const scalar_t* row_input = input + row_idx * hidden_size; + + float variance = 0.0f; + for (int col_idx = threadIdx.x; col_idx < vecs_per_row_padded; + col_idx += blockDim.x) { + const int elem_idx = col_idx * CVT_FP4_ELTS_PER_THREAD; + PackedVecT vec{}; + + bool valid = valid_row && (elem_idx < hidden_size); + if constexpr (CVT_FP4_PACK16) { + ld256_or_zero_cg_u32( + vec, &reinterpret_cast(row_input)[col_idx * 8], + valid); + } else { + ld128_or_zero_cg_u32( + vec, &reinterpret_cast(row_input)[col_idx * 4], + valid); + } + + if (valid) { + variance += compute_packed_sum_squares(vec); + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduce_storage; + variance = + BlockReduce(reduce_storage).Reduce(variance, CubAddOp{}, blockDim.x); + + if (threadIdx.x == 0) { + s_rms_inv = rsqrtf(variance / static_cast(hidden_size) + epsilon); + } + __syncthreads(); + + const float rms_inv = s_rms_inv; + uint32_t* row_out = output + row_idx * vecs_per_row; + + for (int col_idx = threadIdx.x; col_idx < vecs_per_row_padded; + col_idx += blockDim.x) { + const int elem_idx = col_idx * CVT_FP4_ELTS_PER_THREAD; + const bool valid_col = elem_idx < hidden_size; + const bool valid = valid_row && valid_col; + + PackedVecT in_vec{}, w_vec{}; + + if constexpr (CVT_FP4_PACK16) { + ld256_or_zero_cg_u32( + in_vec, &reinterpret_cast(row_input)[col_idx * 8], + valid); + ld256_or_zero_cg_u32( + w_vec, &reinterpret_cast(weight)[col_idx * 8], + valid); + } else { + ld128_or_zero_cg_u32( + in_vec, &reinterpret_cast(row_input)[col_idx * 4], + valid); + ld128_or_zero_cg_u32( + w_vec, &reinterpret_cast(weight)[col_idx * 4], + valid); + } + + PackedVecT norm_vec = compute_rms_norm(in_vec, w_vec, rms_inv); + + uint8_t* sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + row_idx, col_idx, num_k_tiles, output_scale); + + auto fp4_packed = + cvt_warp_fp16_to_fp4( + norm_vec, global_scale, sf_out); + + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t out_offset = row_idx * (hidden_size / 8) + col_idx * 2; + uint64_t packed64 = + (uint64_t(fp4_packed.hi) << 32) | uint64_t(fp4_packed.lo); + reinterpret_cast(output)[out_offset >> 1] = packed64; + } else { + row_out[col_idx] = fp4_packed; + } + } + } + + __syncthreads(); + } +} + +// Fused Add + RMSNorm + FP4 quantization kernel +// Performs: residual = input + residual, then RMSNorm(residual) -> FP4 quant +// Addition is done in high precision (BF16/FP16) before normalization. +template +__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) + fused_add_rms_norm_cvt_fp16_to_fp4( + const int num_tokens, const int hidden_size, + scalar_t const* __restrict__ input, scalar_t* __restrict__ residual, + scalar_t const* __restrict__ weight, float const* __restrict__ scale, + const float epsilon, uint32_t* __restrict__ output, + uint32_t* __restrict__ output_scale) { + using PackedVecT = vllm::PackedVec; + + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + + __shared__ float s_rms_inv; + + // SF count along K, padded to a multiple of 4 (swizzle tile requirement). + const int sf_k_unpadded = hidden_size / CVT_FP4_SF_VEC_SIZE; + static constexpr int SF_KTILE_SIZE = 4; + const int sf_k_padded = + ((sf_k_unpadded + SF_KTILE_SIZE - 1) / SF_KTILE_SIZE) * SF_KTILE_SIZE; + const int32_t num_k_tiles = sf_k_padded / SF_KTILE_SIZE; + const float global_scale = (scale == nullptr) ? 1.0f : scale[0]; + const int vecs_per_row = hidden_size / CVT_FP4_ELTS_PER_THREAD; + + // SF layout requires rows padded to 128 and SF cols padded to 4. + const int sf_rows = (num_tokens + 127) / 128 * 128; + const int vecs_per_row_padded = + sf_k_padded * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + + for (int row_idx = blockIdx.x; row_idx < sf_rows; row_idx += gridDim.x) { + const bool valid_row = row_idx < num_tokens; + const scalar_t* row_input = input + row_idx * hidden_size; + scalar_t* row_residual = residual + row_idx * hidden_size; + + // First pass: compute x = input + residual, update residual, compute + // variance + float variance = 0.0f; + for (int col_idx = threadIdx.x; col_idx < vecs_per_row_padded; + col_idx += blockDim.x) { + const int elem_idx = col_idx * CVT_FP4_ELTS_PER_THREAD; + PackedVecT in_vec{}, res_vec{}, added_vec{}; + + bool valid = valid_row && (elem_idx < hidden_size); + if constexpr (CVT_FP4_PACK16) { + ld256_or_zero_cg_u32( + in_vec, &reinterpret_cast(row_input)[col_idx * 8], + valid); + ld256_or_zero_cg_u32( + res_vec, + &reinterpret_cast(row_residual)[col_idx * 8], + valid); + } else { + ld128_or_zero_cg_u32( + in_vec, &reinterpret_cast(row_input)[col_idx * 4], + valid); + ld128_or_zero_cg_u32( + res_vec, + &reinterpret_cast(row_residual)[col_idx * 4], + valid); + } + + if (valid) { + // Compute fused add and sum squares in one pass + variance += + compute_packed_fused_add_sum_squares(in_vec, res_vec, added_vec); + + // Write back updated residual + if constexpr (CVT_FP4_PACK16) { + // 32 bytes = 2 x uint4 (128-bit) + *(reinterpret_cast(row_residual + + col_idx * CVT_FP4_ELTS_PER_THREAD)) = + *(reinterpret_cast(&added_vec)); + *(reinterpret_cast(row_residual + + col_idx * CVT_FP4_ELTS_PER_THREAD + 8)) = + *(reinterpret_cast(&added_vec) + 1); + } else { + // 16 bytes = 1 x uint4 (128-bit) + *(reinterpret_cast(row_residual + + col_idx * CVT_FP4_ELTS_PER_THREAD)) = + *(reinterpret_cast(&added_vec)); + } + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduce_storage; + variance = + BlockReduce(reduce_storage).Reduce(variance, CubAddOp{}, blockDim.x); + + if (threadIdx.x == 0) { + s_rms_inv = rsqrtf(variance / static_cast(hidden_size) + epsilon); + } + __syncthreads(); + + const float rms_inv = s_rms_inv; + uint32_t* row_out = output + row_idx * vecs_per_row; + + // Second pass: read updated residual, apply RMSNorm and quantize + for (int col_idx = threadIdx.x; col_idx < vecs_per_row_padded; + col_idx += blockDim.x) { + const int elem_idx = col_idx * CVT_FP4_ELTS_PER_THREAD; + const bool valid_col = elem_idx < hidden_size; + const bool valid = valid_row && valid_col; + + PackedVecT in_vec{}, w_vec{}; + + if constexpr (CVT_FP4_PACK16) { + // Read from updated residual + ld256_or_zero_cg_u32( + in_vec, + &reinterpret_cast(row_residual)[col_idx * 8], + valid); + ld256_or_zero_cg_u32( + w_vec, &reinterpret_cast(weight)[col_idx * 8], + valid); + } else { + ld128_or_zero_cg_u32( + in_vec, + &reinterpret_cast(row_residual)[col_idx * 4], + valid); + ld128_or_zero_cg_u32( + w_vec, &reinterpret_cast(weight)[col_idx * 4], + valid); + } + + PackedVecT norm_vec = compute_rms_norm(in_vec, w_vec, rms_inv); + + uint8_t* sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + row_idx, col_idx, num_k_tiles, output_scale); + + auto fp4_packed = + cvt_warp_fp16_to_fp4( + norm_vec, global_scale, sf_out); + + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t out_offset = row_idx * (hidden_size / 8) + col_idx * 2; + uint64_t packed64 = + (uint64_t(fp4_packed.hi) << 32) | uint64_t(fp4_packed.lo); + reinterpret_cast(output)[out_offset >> 1] = packed64; + } else { + row_out[col_idx] = fp4_packed; + } + } + } + + __syncthreads(); + } +} + +} // namespace vllm + +void rms_norm_nvfp4_quant_sm1xxa( + torch::Tensor& output, // [..., hidden_size/2] uint8 (packed FP4) + torch::Tensor& output_scale, // block scale, int32 (swizzled layout) + torch::Tensor& input, // [..., hidden_size] BF16/FP16 + torch::Tensor& weight, // [hidden_size] BF16/FP16 + torch::Tensor& scale, // [1] float32 (global scale) + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + TORCH_CHECK(hidden_size % 16 == 0, "The hidden_size must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for rms_norm_nvfp4_quant."); + + int multi_processor_count = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Use same block size convention as fused_fp8 kernels for consistency + // Kernel internally handles the case where blockDim.x > vecs_per_row + dim3 block(std::min(hidden_size, 1024)); + int const num_blocks_per_sm = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + // SF layout pads rows to 128, so we need to process those padded rows too + int effective_rows = (num_tokens + 127) / 128 * 128; + dim3 grid( + std::min(effective_rows, multi_processor_count * num_blocks_per_sm)); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "rms_norm_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::rms_norm_cvt_fp16_to_fp4<<>>( + num_tokens, hidden_size, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(weight.data_ptr()), + scale.data_ptr(), static_cast(epsilon), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_scale.data_ptr())); + }); +} + +void fused_add_rms_norm_nvfp4_quant_sm1xxa( + torch::Tensor& output, // [..., hidden_size/2] uint8 (packed FP4) + torch::Tensor& output_scale, // block scale, int32 (swizzled layout) + torch::Tensor& input, // [..., hidden_size] BF16/FP16 + torch::Tensor& residual, // [..., hidden_size] BF16/FP16 (in-place updated) + torch::Tensor& weight, // [hidden_size] BF16/FP16 + torch::Tensor& scale, // [1] float32 (global scale) + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + TORCH_CHECK(hidden_size % 16 == 0, "The hidden_size must be multiple of 16."); + TORCH_CHECK( + input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for fused_add_rms_norm_nvfp4_quant."); + TORCH_CHECK(input.scalar_type() == residual.scalar_type(), + "Input and residual must have the same dtype."); + + int multi_processor_count = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Use same block size convention as fused_fp8 kernels for consistency + // Kernel internally handles the case where blockDim.x > vecs_per_row + dim3 block(std::min(hidden_size, 1024)); + int const num_blocks_per_sm = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + // SF layout pads rows to 128, so we need to process those padded rows too + int effective_rows = (num_tokens + 127) / 128 * 128; + dim3 grid( + std::min(effective_rows, multi_processor_count * num_blocks_per_sm)); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "fused_add_rms_norm_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::fused_add_rms_norm_cvt_fp16_to_fp4 + <<>>( + num_tokens, hidden_size, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(residual.data_ptr()), + reinterpret_cast(weight.data_ptr()), + scale.data_ptr(), static_cast(epsilon), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_scale.data_ptr())); + }); +} diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index cb7adc312573..4169fbd50b6e 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -6,6 +6,7 @@ #include "quantization/vectorization.cuh" #include "quantization/utils.cuh" +#include "quantization/cuda_type_utils.cuh" #include "quant_conversions.cuh" #include "../../cub_helpers.h" @@ -44,9 +45,9 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, *rms = s_rms; } -__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, - int64_t thread_in_warp, - int64_t reduced_elems) { +__device__ __forceinline__ float warpReduceMaxSpecialized( + volatile float* val, int64_t tid, int64_t thread_in_warp, + int64_t reduced_elems) { static_assert(WARP_SIZE == 32 || WARP_SIZE == 64); if constexpr (WARP_SIZE == 64) { if (thread_in_warp + 64 < reduced_elems) @@ -536,4 +537,95 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, } // namespace vectorized +// Compute sum of squares for a PackedVec (CVT_FP4_ELTS_PER_THREAD elements). +// Used in RMSNorm variance calculation. +template +__device__ __forceinline__ float compute_packed_sum_squares( + const PackedVec& vec) { + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + float2 fp2; + if constexpr (std::is_same_v) { + fp2 = __half22float2(vec.elts[i]); + } else { + fp2 = __bfloat1622float2(vec.elts[i]); + } + sum += fp2.x * fp2.x + fp2.y * fp2.y; + } + return sum; +} + +// RMSNorm: output = input * rms_inv * weight +// rms_inv = rsqrt(mean(x^2) + epsilon) +// Match Python reference: cast to target dtype after rms_inv multiplication, +// then multiply with weight in target dtype precision. +template +__device__ __forceinline__ PackedVec compute_rms_norm( + const PackedVec& in_vec, const PackedVec& w_vec, + float rms_inv) { + PackedVec result{}; +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + float2 in_fp2, w_fp2; + if constexpr (std::is_same_v) { + in_fp2 = __half22float2(in_vec.elts[i]); + w_fp2 = __half22float2(w_vec.elts[i]); + // Cast to half after rms_inv, then multiply with weight in half + half2 normalized = __float22half2_rn( + make_float2(in_fp2.x * rms_inv, in_fp2.y * rms_inv)); + result.elts[i] = __hmul2(normalized, w_vec.elts[i]); + } else { + in_fp2 = __bfloat1622float2(in_vec.elts[i]); + w_fp2 = __bfloat1622float2(w_vec.elts[i]); + // Cast to bfloat16 after rms_inv, then multiply with weight in bfloat16 + __nv_bfloat162 normalized = __float22bfloat162_rn( + make_float2(in_fp2.x * rms_inv, in_fp2.y * rms_inv)); + result.elts[i] = __hmul2(normalized, w_vec.elts[i]); + } + } + return result; +} + +// Compute fused add for packed vectors: in_vec + res_vec +// Both addition and result are in high precision (BF16/FP16). +// __hadd2 is overloaded for both half2 and __nv_bfloat162 (SM80+). +template +__device__ __forceinline__ PackedVec compute_packed_fused_add( + const PackedVec& in_vec, const PackedVec& res_vec) { + PackedVec result{}; +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + result.elts[i] = __hadd2(in_vec.elts[i], res_vec.elts[i]); + } + return result; +} + +// Compute sum of squares for a PackedVec after fused add: (in_vec + res_vec) +// Returns sum of squares and updates result with the added values. +template +__device__ __forceinline__ float compute_packed_fused_add_sum_squares( + const PackedVec& in_vec, const PackedVec& res_vec, + PackedVec& result) { + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + float2 in_fp2, res_fp2; + if constexpr (std::is_same_v) { + in_fp2 = __half22float2(in_vec.elts[i]); + res_fp2 = __half22float2(res_vec.elts[i]); + float2 added = make_float2(in_fp2.x + res_fp2.x, in_fp2.y + res_fp2.y); + result.elts[i] = __float22half2_rn(added); + sum += added.x * added.x + added.y * added.y; + } else { + in_fp2 = __bfloat1622float2(in_vec.elts[i]); + res_fp2 = __bfloat1622float2(res_vec.elts[i]); + float2 added = make_float2(in_fp2.x + res_fp2.x, in_fp2.y + res_fp2.y); + result.elts[i] = __float22bfloat162_rn(added); + sum += added.x * added.x + added.y * added.y; + } + } + return sum; +} + } // namespace vllm diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 97c0e80e7676..12e3c128cb24 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -223,6 +223,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool is_scale_transposed) -> ()"); ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant); +#ifndef USE_ROCM + // Fused RMSNorm + NVFP4 quantization + ops.def( + "rms_norm_nvfp4_quant(Tensor! result, Tensor! result_scale, " + "Tensor input, Tensor weight, Tensor input_scale, float epsilon) -> ()"); + ops.impl("rms_norm_nvfp4_quant", torch::kCUDA, &rms_norm_nvfp4_quant); + + // Fused Add + RMSNorm + NVFP4 quantization + ops.def( + "fused_add_rms_norm_nvfp4_quant(Tensor! result, Tensor! result_scale, " + "Tensor input, Tensor! residual, Tensor weight, Tensor input_scale, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm_nvfp4_quant", torch::kCUDA, + &fused_add_rms_norm_nvfp4_quant); +#endif + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( diff --git a/tests/compile/passes/test_fusion.py b/tests/compile/passes/test_fusion.py index a2128150f701..7b877b47fa3c 100644 --- a/tests/compile/passes/test_fusion.py +++ b/tests/compile/passes/test_fusion.py @@ -10,11 +10,14 @@ from tests.compile.backend import TestBackend from tests.utils import TestBlockFP8Layer, TestFP8Layer from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS from vllm.compilation.passes.fusion.rms_quant_fusion import ( FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass, + fused_add_rms_norm_nvfp4_quant_supported, + rms_norm_nvfp4_quant_supported, ) from vllm.compilation.passes.fx_utils import find_op_nodes from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass @@ -48,6 +51,7 @@ GroupShape, QuantKey, ScaleDesc, + kNvfp4Dynamic, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, @@ -58,6 +62,7 @@ ) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 RMS_OP = torch.ops._C.rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default @@ -439,3 +444,251 @@ def test_aiter_fusion_rmsnorm_quant( _run_fusion_test( model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens ) + + +def is_nvfp4_supported(): + return current_platform.has_device_capability(100) + + +def quant_nvfp4_tensor(a: torch.Tensor): + from vllm.scalar_type import scalar_types + + FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() + FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to( + torch.float32 + ) + a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale) + return a_quant, a_block_scale, a_global_scale + + +class TestRMSNormNvfp4QuantModel(torch.nn.Module): + def __init__(self, hidden_size: int, eps: float, x: torch.Tensor): + super().__init__() + self.norm = RMSNorm(hidden_size, eps) + self.enable_rms_norm_custom_op = self.norm.enabled() + + w = torch.rand((hidden_size, hidden_size)) + self.w, self.w_block_scale, self.w_global_scale = quant_nvfp4_tensor(w) + + y = self.norm(x) + _, _, self.y_global_scale = quant_nvfp4_tensor(y) + self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale) + + def forward(self, x): + y = self.norm(x) + y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale) + out = cutlass_scaled_fp4_mm( + a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.w_block_scale, + alpha=self.alpha, + out_dtype=y.dtype, + ) + return out + + def ops_in_model_before(self): + return [ + RMS_OP if self.enable_rms_norm_custom_op else torch.ops.aten.rsqrt, + QUANT_OPS[kNvfp4Dynamic], + ] + + def ops_in_model_after(self): + return [FUSED_OPS[FusedRMSQuantKey(kNvfp4Dynamic, False)]] + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("num_tokens", [32, 64]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test on CUDA") +def test_fusion_rmsnorm_nvfp4_quant( + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + eps: float, + enable_rms_norm_custom_op: bool, +): + if not is_nvfp4_supported(): + pytest.skip("NVFP4 is not supported on this GPU.") + + if not rms_norm_nvfp4_quant_supported: + pytest.skip("rms_norm_nvfp4_quant op is not available.") + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(42) + + x = torch.rand(num_tokens, hidden_size) + + custom_ops = ["none"] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + backend="eager", + pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config): + fusion_pass = RMSNormQuantFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + + model = TestRMSNormNvfp4QuantModel(hidden_size, eps, x) + + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + atol, rtol = 1e-1, 1e-1 + torch.testing.assert_close(result, result2, atol=atol, rtol=rtol) + + assert fusion_pass.matched_count == 1 + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) + + +class TestFusedAddRMSNormNvfp4QuantModel(torch.nn.Module): + """Test model for fused_add_rms_norm + nvfp4 quant fusion.""" + + def __init__(self, hidden_size: int, eps: float, x: torch.Tensor): + super().__init__() + self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] + self.enable_rms_norm_custom_op = self.norm[0].enabled() + + w = torch.rand((hidden_size, hidden_size)) + self.w, self.w_block_scale, self.w_global_scale = quant_nvfp4_tensor(w) + + # Compute global scale from a sample forward pass + resid = torch.zeros_like(x) + y, _ = self.norm[0](x, resid) + _, _, self.y_global_scale = quant_nvfp4_tensor(y) + self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale) + + def forward(self, x): + # Avoid having graph input be an arg to a pattern directly + x = torch.relu(x) + resid = torch.tanh(x) + + # First: fused_add_rms_norm + nvfp4 quant + y, resid = self.norm[0](x, resid) + y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale) + out1 = cutlass_scaled_fp4_mm( + a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.w_block_scale, + alpha=self.alpha, + out_dtype=y.dtype, + ) + + # Second: fused_add_rms_norm + nvfp4 quant + y2, resid = self.norm[1](out1, resid) + y2_quant, y2_block_scale = scaled_fp4_quant(y2, self.y_global_scale) + out2 = cutlass_scaled_fp4_mm( + a=y2_quant, + b=self.w, + block_scale_a=y2_block_scale, + block_scale_b=self.w_block_scale, + alpha=self.alpha, + out_dtype=y2.dtype, + ) + + # Third: fused_add_rms_norm + nvfp4 quant + y3, resid = self.norm[2](out2, resid) + y3_quant, y3_block_scale = scaled_fp4_quant(y3, self.y_global_scale) + out3 = cutlass_scaled_fp4_mm( + a=y3_quant, + b=self.w, + block_scale_a=y3_block_scale, + block_scale_b=self.w_block_scale, + alpha=self.alpha, + out_dtype=y3.dtype, + ) + + return out3 + + def ops_in_model_before(self): + return [ + RMS_ADD_OP if self.enable_rms_norm_custom_op else torch.ops.aten.rsqrt, + QUANT_OPS[kNvfp4Dynamic], + ] + + def ops_in_model_after(self): + return [FUSED_OPS[FusedRMSQuantKey(kNvfp4Dynamic, True)]] + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("num_tokens", [32, 64]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test on CUDA") +def test_fusion_fused_add_rmsnorm_nvfp4_quant( + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + eps: float, + enable_rms_norm_custom_op: bool, +): + if not is_nvfp4_supported(): + pytest.skip("NVFP4 is not supported on this GPU.") + + if not fused_add_rms_norm_nvfp4_quant_supported: + pytest.skip("fused_add_rms_norm_nvfp4_quant op is not available.") + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(42) + + x = torch.rand(num_tokens, hidden_size) + + custom_ops = ["none"] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + backend="eager", + pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config): + fusion_pass = RMSNormQuantFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + + model = TestFusedAddRMSNormNvfp4QuantModel(hidden_size, eps, x) + + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + atol, rtol = 1e-1, 1e-1 + torch.testing.assert_close(result, result2, atol=atol, rtol=rtol) + + assert fusion_pass.matched_count == 2 + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/kernels/quantization/test_fused_add_rmsnorm_nvfp4_quant.py b/tests/kernels/quantization/test_fused_add_rmsnorm_nvfp4_quant.py new file mode 100644 index 000000000000..27214d384272 --- /dev/null +++ b/tests/kernels/quantization/test_fused_add_rmsnorm_nvfp4_quant.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for fused Add + RMSNorm + NVFP4 quantization kernel. +""" + +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) +from vllm._custom_ops import scaled_fp4_quant +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +FP4_DTYPE = torch.uint8 +FP8_DTYPE = current_platform.fp8_dtype() + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)] +EPSILON = 1e-6 + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@torch.inference_mode() +def test_fused_add_rms_norm_nvfp4_quant( + default_vllm_config, + dtype: torch.dtype, + shape: tuple[int, int], +) -> None: + """Test fused Add + RMSNorm + NVFP4 quantization kernel.""" + set_random_seed(42) + device = "cuda:0" + torch.set_default_device(device) + + num_tokens, hidden_size = shape + x = torch.randn(shape, dtype=dtype) + residual = torch.randn(shape, dtype=dtype) + weight = torch.randn(hidden_size, dtype=dtype) + + # Reference: x + residual -> RMSNorm -> scaled_fp4_quant + ref_residual = x + residual + rms_norm = RMSNorm(hidden_size, EPSILON).to(dtype=dtype, device=device) + rms_norm.weight.data.copy_(weight) + ref_output = rms_norm.forward_native(ref_residual) + + ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + ref_output + ).max().to(torch.float32) + ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale) + + # Fused op: fused_add_rms_norm_nvfp4_quant + fused_output_quant = torch.empty_like(ref_output_quant) + fused_block_scale = torch.empty_like(ref_block_scale) + fused_residual = residual.clone() # Will be updated in-place + torch.ops._C.fused_add_rms_norm_nvfp4_quant( + fused_output_quant, + fused_block_scale, + x, + fused_residual, + weight, + ref_global_scale, + EPSILON, + ) + + # Check residual is updated in-place + torch.testing.assert_close(fused_residual, ref_residual, atol=1e-5, rtol=1e-5) + + # Check dtype + assert ref_output_quant.dtype == FP4_DTYPE + assert fused_output_quant.dtype == FP4_DTYPE + assert ref_output_quant.shape == fused_output_quant.shape + + assert ref_block_scale.dtype == FP8_DTYPE + assert fused_block_scale.dtype == FP8_DTYPE + assert ref_block_scale.shape == fused_block_scale.shape + + # Check dequantized output + ref_output_dequant = dequantize_nvfp4_to_dtype( + ref_output_quant, ref_block_scale, ref_global_scale, dtype, device + ) + fused_output_dequant = dequantize_nvfp4_to_dtype( + fused_output_quant, fused_block_scale, ref_global_scale, dtype, device + ) + + atol, rtol = 3e-1, 3e-1 + torch.testing.assert_close( + ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol + ) + + +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode() +def test_fused_add_rms_norm_nvfp4_quant_large( + default_vllm_config, + dtype: torch.dtype, +) -> None: + """Test fused Add + RMSNorm + NVFP4 quantization kernel with larger shapes.""" + set_random_seed(42) + device = "cuda:0" + torch.set_default_device(device) + + # Test with larger hidden sizes (typical model dimensions) + shape = (64, 4096) + num_tokens, hidden_size = shape + x = torch.randn(shape, dtype=dtype) + residual = torch.randn(shape, dtype=dtype) + weight = torch.randn(hidden_size, dtype=dtype) + + # Reference: x + residual -> RMSNorm -> scaled_fp4_quant + ref_residual = x + residual + rms_norm = RMSNorm(hidden_size, EPSILON).to(dtype=dtype, device=device) + rms_norm.weight.data.copy_(weight) + ref_output = rms_norm.forward_native(ref_residual) + + ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + ref_output + ).max().to(torch.float32) + ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale) + + # Fused op: fused_add_rms_norm_nvfp4_quant + fused_output_quant = torch.empty_like(ref_output_quant) + fused_block_scale = torch.empty_like(ref_block_scale) + fused_residual = residual.clone() + torch.ops._C.fused_add_rms_norm_nvfp4_quant( + fused_output_quant, + fused_block_scale, + x, + fused_residual, + weight, + ref_global_scale, + EPSILON, + ) + + # Check residual is updated in-place + torch.testing.assert_close(fused_residual, ref_residual, atol=1e-5, rtol=1e-5) + + # Check dequantized output + ref_output_dequant = dequantize_nvfp4_to_dtype( + ref_output_quant, ref_block_scale, ref_global_scale, dtype, device + ) + fused_output_dequant = dequantize_nvfp4_to_dtype( + fused_output_quant, fused_block_scale, ref_global_scale, dtype, device + ) + + atol, rtol = 3e-1, 3e-1 + torch.testing.assert_close( + ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol + ) diff --git a/tests/kernels/quantization/test_rmsnorm_nvfp4_quant.py b/tests/kernels/quantization/test_rmsnorm_nvfp4_quant.py new file mode 100644 index 000000000000..cd7c6ee6c288 --- /dev/null +++ b/tests/kernels/quantization/test_rmsnorm_nvfp4_quant.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests for fused RMSNorm + NVFP4 quantization kernel. +""" + +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) +from vllm._custom_ops import scaled_fp4_quant +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_random_seed + +if not current_platform.has_device_capability(100): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +FP4_DTYPE = torch.uint8 +FP8_DTYPE = current_platform.fp8_dtype() + +DTYPES = [torch.float16, torch.bfloat16] +SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)] +EPSILON = 1e-6 + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", SHAPES) +@torch.inference_mode() +def test_rms_norm_nvfp4_quant( + default_vllm_config, + dtype: torch.dtype, + shape: tuple[int, int], +) -> None: + """Test RMSNorm + NVFP4 quantization fusion.""" + set_random_seed(42) + device = "cuda:0" + torch.set_default_device(device) + + num_tokens, hidden_size = shape + x = torch.randn(shape, dtype=dtype) + weight = torch.randn(hidden_size, dtype=dtype) + + # Reference: RMSNorm -> scaled_fp4_quant + rms_norm = RMSNorm(hidden_size, EPSILON).to(dtype=dtype, device=device) + rms_norm.weight.data.copy_(weight) + ref_output = rms_norm.forward_native(x) + + ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + ref_output + ).max().to(torch.float32) + ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale) + + # Fused op: rms_norm_nvfp4_quant + fused_output_quant = torch.empty_like(ref_output_quant) + fused_block_scale = torch.empty_like(ref_block_scale) + torch.ops._C.rms_norm_nvfp4_quant( + fused_output_quant, fused_block_scale, x, weight, ref_global_scale, EPSILON + ) + + # Check dtype + assert ref_output_quant.dtype == FP4_DTYPE + assert fused_output_quant.dtype == FP4_DTYPE + assert ref_output_quant.shape == fused_output_quant.shape + + assert ref_block_scale.dtype == FP8_DTYPE + assert fused_block_scale.dtype == FP8_DTYPE + assert ref_block_scale.shape == fused_block_scale.shape + + # Check dequantized output + ref_output_dequant = dequantize_nvfp4_to_dtype( + ref_output_quant, ref_block_scale, ref_global_scale, dtype, device + ) + fused_output_dequant = dequantize_nvfp4_to_dtype( + fused_output_quant, fused_block_scale, ref_global_scale, dtype, device + ) + + atol, rtol = 3e-1, 3e-1 + torch.testing.assert_close( + ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol + ) diff --git a/vllm/compilation/passes/fusion/rms_quant_fusion.py b/vllm/compilation/passes/fusion/rms_quant_fusion.py index eac9fea286e0..2c5f5a78ae5b 100644 --- a/vllm/compilation/passes/fusion/rms_quant_fusion.py +++ b/vllm/compilation/passes/fusion/rms_quant_fusion.py @@ -113,6 +113,24 @@ def __str__(self) -> str: ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 } +# Check if NVFP4 fused ops are available +rms_norm_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( + torch.ops._C, "rms_norm_nvfp4_quant" +) +if rms_norm_nvfp4_quant_supported: + FUSED_OPS[FusedRMSQuantKey(kNvfp4Dynamic, False)] = ( + torch.ops._C.rms_norm_nvfp4_quant.default + ) # noqa: E501 + +# Check if fused_add_rms_norm_nvfp4_quant is available +fused_add_rms_norm_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( + torch.ops._C, "fused_add_rms_norm_nvfp4_quant" +) +if fused_add_rms_norm_nvfp4_quant_supported: + FUSED_OPS[FusedRMSQuantKey(kNvfp4Dynamic, True)] = ( + torch.ops._C.fused_add_rms_norm_nvfp4_quant.default + ) # noqa: E501 + class RMSNormQuantPattern: def __init__( @@ -494,6 +512,149 @@ def replacement( ) +class RMSNormNvfp4QuantPattern: + """ + Fusion pattern for RMSNorm + NVFP4 quantization. + """ + + def __init__(self, epsilon: float) -> None: + self.epsilon = epsilon + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.QUANT_OP = QUANT_OPS[kNvfp4Dynamic] + self.FUSED_OP = FUSED_OPS[FusedRMSQuantKey(kNvfp4Dynamic, False)] + + def get_inputs(self) -> list[torch.Tensor]: + # Use rmsnorm_matcher.inputs() to respect model dtype (bf16 or fp16) + rms_inputs = self.rmsnorm_matcher.inputs() # [input, weight] + input_ = rms_inputs[0] # (5, 16) + weight = rms_inputs[1] # (16,) + hidden_size = input_.shape[1] + # FP4 packs 2 values per uint8, so output cols = hidden_size / 2 + result = torch.empty(5, hidden_size // 2, dtype=FP4_DTYPE, device="cuda") + output_scale = empty_i32(128, 1) + input_scale = empty_fp32(1, 1) + return [result, output_scale, input_, weight, input_scale] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + result: torch.Tensor, + output_scale: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + result_rms = self.rmsnorm_matcher(input, weight) + at = auto_functionalized( + self.QUANT_OP, + output=result, + input=result_rms, + output_scale=output_scale, + input_scale=input_scale, + is_sf_swizzled_layout=True, + ) + return at[1], at[2] + + def replacement( + result: torch.Tensor, + output_scale: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input = input.to(dtype=self.model_dtype) + at = auto_functionalized( + self.FUSED_OP, + result=result, + result_scale=output_scale, + input=input, + weight=weight, + input_scale=input_scale, + epsilon=self.epsilon, + ) + return at[1], at[2] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class FusedAddRMSNormNvfp4QuantPattern: + """ + Fusion pattern for FusedAddRMSNorm + NVFP4 quantization. + """ + + def __init__(self, epsilon: float) -> None: + self.epsilon = epsilon + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None + self.fused_add_rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.QUANT_OP = QUANT_OPS[kNvfp4Dynamic] + self.FUSED_OP = FUSED_OPS[FusedRMSQuantKey(kNvfp4Dynamic, True)] + + def get_inputs(self) -> list[torch.Tensor]: + # Use fused_add_rmsnorm_matcher.inputs() for [input, weight, residual] + rms_inputs = self.fused_add_rmsnorm_matcher.inputs() + input_ = rms_inputs[0] # (5, 16) + weight = rms_inputs[1] # (16,) + residual = rms_inputs[2] # (5, 16) + hidden_size = input_.shape[1] + # FP4 packs 2 values per uint8, so output cols = hidden_size / 2 + result = torch.empty(5, hidden_size // 2, dtype=FP4_DTYPE, device="cuda") + output_scale = empty_i32(128, 1) + input_scale = empty_fp32(1, 1) + return [result, output_scale, input_, weight, residual, input_scale] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + result: torch.Tensor, + output_scale: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + input_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + result_rms, residual = self.fused_add_rmsnorm_matcher( + input, weight, residual + ) + at = auto_functionalized( + self.QUANT_OP, + output=result, + input=result_rms, + output_scale=output_scale, + input_scale=input_scale, + is_sf_swizzled_layout=True, + ) + return at[1], at[2], residual + + def replacement( + result: torch.Tensor, + output_scale: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + input_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = input.to(dtype=self.model_dtype) + at = auto_functionalized( + self.FUSED_OP, + result=result, + result_scale=output_scale, + input=input, + residual=residual, + weight=weight, + input_scale=input_scale, + epsilon=self.epsilon, + ) + # result, result_scale, residual + return at[1], at[2], at[3] + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + class RMSNormQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op. @@ -550,6 +711,17 @@ def __init__(self, config: VllmConfig) -> None: is_e8m0=is_e8m0, ).register(self.patterns) + # Register NVFP4 patterns if supported + # Make sure fused_add pattern is registered before simple rms_norm, + # as the latter is a subset of the former in torch ops + if fused_add_rms_norm_nvfp4_quant_supported: + # Fuse fused_add_rms_norm + nvfp4 quant + FusedAddRMSNormNvfp4QuantPattern(epsilon).register(self.patterns) + + if rms_norm_nvfp4_quant_supported: + # Fuse rms_norm + nvfp4 quant + RMSNormNvfp4QuantPattern(epsilon).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log @@ -567,4 +739,6 @@ def uuid(self) -> str: FusedAddRMSNormStaticQuantPattern, FusedAddRMSNormDynamicQuantPattern, FusedAddRMSNormGroupQuantPattern, + RMSNormNvfp4QuantPattern, + FusedAddRMSNormNvfp4QuantPattern, )