Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/quantization/fp4/rmsnorm_nvfp4_quant_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
Expand All @@ -683,7 +684,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/quantization/fp4/rmsnorm_nvfp4_quant_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
Expand Down
166 changes: 163 additions & 3 deletions benchmarks/fused_kernels/layernorm_rms_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,25 @@
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

# FP4 constants
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# Check if NVFP4 fused ops are available
rms_norm_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "rms_norm_nvfp4_quant"
)
fused_add_rms_norm_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
torch.ops._C, "fused_add_rms_norm_nvfp4_quant"
)


@dataclass
Expand Down Expand Up @@ -59,6 +74,7 @@ def unfused_int8_impl(
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
**kwargs,
):
# Norm
torch_out = None
Expand All @@ -77,6 +93,7 @@ def unfused_fp8_impl(
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
**kwargs,
):
# Norm
torch_out = None
Expand All @@ -95,6 +112,7 @@ def unfused_groupwise_fp8_impl(
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
**kwargs,
):
# Norm
torch_out = None
Expand All @@ -115,6 +133,7 @@ def fused_impl(
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
**kwargs,
):
out, _ = ops.rms_norm_dynamic_per_token_quant(
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
Expand All @@ -127,6 +146,7 @@ def fused_groupwise_impl(
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
**kwargs,
):
out, _ = ops.rms_norm_per_block_quant(
x,
Expand All @@ -139,6 +159,93 @@ def fused_groupwise_impl(
)


def get_fp4_output_tensors(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Allocate output tensors for NVFP4 quantization of a 2-D input.

    Args:
        x: Input tensor of shape (m, n). n must be divisible by 2 (two FP4
           values pack into one uint8) and should be a multiple of the FP4
           block size (16) for a meaningful scale layout.

    Returns:
        A tuple ``(output, output_scale)``:
          * ``output``: (m, n // 2) uint8 — packed FP4 payload.
          * ``output_scale``: (round_up(m, 128), round_up(n // 16, 4) // 4)
            int32 — per-block scales in the swizzled 128x4-tile layout.
    """
    m, n = x.shape
    block_size = 16

    def round_up(value: int, multiple: int) -> int:
        # Ceiling-round `value` to the nearest multiple of `multiple`.
        return (value + multiple - 1) // multiple * multiple

    # Two fp4 values will be packed into an uint8.
    output = torch.empty((m, n // 2), device=x.device, dtype=torch.uint8)
    # Swizzled scale layout for 128x4 tiles: rows padded to 128, scale
    # columns padded to a multiple of 4 and packed four-per-int32.
    rounded_m = round_up(m, 128)
    scale_n = n // block_size
    rounded_n = round_up(scale_n, 4)
    output_scale = torch.empty(
        (rounded_m, rounded_n // 4), device=x.device, dtype=torch.int32
    )
    return output, output_scale


def unfused_nvfp4_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
    global_scale: torch.Tensor | None = None,
    **kwargs,
):
    """Two-step reference path: RMSNorm kernel followed by NVFP4 quant."""
    # Normalization step. With a residual, forward_cuda returns a
    # (normed, residual) pair; without one it returns a single tensor.
    if residual is not None:
        normed, _ = rms_norm_layer.forward_cuda(x, residual)
    else:
        normed = rms_norm_layer.forward_cuda(x, residual)

    # Fall back to an on-the-fly global scale when none was supplied
    # (a pre-computed scale simulates real, calibrated inference).
    if global_scale is None:
        amax = torch.abs(normed).max().to(torch.float32)
        global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax

    # Quantization step; outputs are discarded (benchmark timing only).
    ops.scaled_fp4_quant(normed, global_scale)


def fused_nvfp4_impl(
    rms_norm_layer: RMSNorm,
    x: torch.Tensor,
    residual: torch.Tensor | None,
    quant_dtype: torch.dtype,
    group_size: list[int],
    global_scale: torch.Tensor | None = None,
    **kwargs,
):
    """Single fused kernel: RMSNorm (+ optional residual add) + NVFP4 quant."""
    # A pre-computed global_scale simulates real inference, where it is
    # calibrated once offline; otherwise derive it from a reference norm.
    if global_scale is None:
        reference = rms_norm_layer.forward_cuda(x, None)
        amax = torch.abs(reference).max().to(torch.float32)
        global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax

    # Fresh output allocation on every call, matching the allocation
    # behavior of unfused_nvfp4_impl so the comparison stays fair.
    packed_out, scale_out = get_fp4_output_tensors(x)

    if residual is not None:
        # fused_add_rms_norm_nvfp4_quant: the residual tensor is modified
        # in-place, which is acceptable for benchmarking purposes.
        torch.ops._C.fused_add_rms_norm_nvfp4_quant(
            packed_out,
            scale_out,
            x,
            residual,
            rms_norm_layer.weight,
            global_scale,
            1e-6,
        )
    else:
        # rms_norm_nvfp4_quant
        torch.ops._C.rms_norm_nvfp4_quant(
            packed_out,
            scale_out,
            x,
            rms_norm_layer.weight,
            global_scale,
            1e-6,
        )


# Bench functions
def bench_fn(
rms_norm_layer: RMSNorm,
Expand All @@ -150,6 +257,7 @@ def bench_fn(
sub_label: str,
fn: Callable,
description: str,
global_scale: torch.Tensor | None = None,
) -> TMeasurement:
min_run_time = 1

Expand All @@ -159,10 +267,12 @@ def bench_fn(
"residual": residual,
"quant_dtype": quant_dtype,
"group_size": group_size,
"global_scale": global_scale,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
stmt="fn(rms_norm_layer, x, residual, quant_dtype,"
" group_size, global_scale=global_scale)",
globals=globals,
label=label,
sub_label=sub_label,
Expand Down Expand Up @@ -279,6 +389,53 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
)
)

# NVFP4 benchmarks (only if supported, hidden_size is multiple of 16,
# and dtype is fp16/bf16 - NVFP4 does not support float32)
if params.hidden_size % 16 == 0 and params.dtype in (torch.float16, torch.bfloat16):
# Pre-compute global_scale ONCE before benchmark loop
# This simulates real inference where global_scale is calibrated offline
with torch.no_grad():
torch_out_ref = layer.forward_cuda(x, None)
nvfp4_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(
torch_out_ref
).max().to(torch.float32)

if rms_norm_nvfp4_quant_supported or fused_add_rms_norm_nvfp4_quant_supported:
# unfused nvfp4 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.uint8, # FP4 is packed as uint8
params.group_size,
label,
sub_label,
unfused_nvfp4_impl,
"unfused_nvfp4_impl",
global_scale=nvfp4_global_scale,
)
)

if (not params.add_residual and rms_norm_nvfp4_quant_supported) or (
params.add_residual and fused_add_rms_norm_nvfp4_quant_supported
):
# fused nvfp4 impl
timers.append(
bench_fn(
layer,
x,
residual,
torch.uint8, # FP4 is packed as uint8
params.group_size,
label,
sub_label,
fused_nvfp4_impl,
"fused_nvfp4_impl",
global_scale=nvfp4_global_scale,
)
)

print_timers(timers)

return timers
Expand All @@ -296,8 +453,11 @@ def main():
bench_params = get_bench_params()

timers = []
for bp in tqdm(bench_params):
timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
with set_current_vllm_config(VllmConfig()):
for bp in tqdm(bench_params):
timers.extend(
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())
)
print_timers(timers)

# pickle all the results
Expand Down
13 changes: 13 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,19 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);

#ifndef USE_ROCM
// Fused RMSNorm + NVFP4 quantization (CUDA-only; excluded on ROCm).
// Writes packed FP4 values into `out` and per-block scales into
// `output_scale`; `input_scale` is the global quantization scale and
// `epsilon` the RMSNorm epsilon.
void rms_norm_nvfp4_quant(torch::Tensor& out, torch::Tensor& output_scale,
                          torch::Tensor& input, torch::Tensor& weight,
                          torch::Tensor& input_scale, double epsilon);

// As above, but also fuses the residual addition.
// NOTE(review): `residual` appears to be updated in place (the Python
// benchmark caller notes this) — confirm against the kernel source.
void fused_add_rms_norm_nvfp4_quant(torch::Tensor& out,
                                    torch::Tensor& output_scale,
                                    torch::Tensor& input,
                                    torch::Tensor& residual,
                                    torch::Tensor& weight,
                                    torch::Tensor& input_scale, double epsilon);
#endif

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
Expand Down
86 changes: 86 additions & 0 deletions csrc/quantization/cuda_type_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#pragma once

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <ATen/cuda/CUDAContext.h>

// Conditional compilation for FP4 element packing size.
// The 16-elements-per-thread path is taken only when all three hold:
// the NVFP4_ENABLE_ELTS16 opt-in is defined, CUDA runtime >= 12.9, and
// SM100 NVFP4 support (ENABLE_NVFP4_SM100) is enabled; otherwise each
// thread converts 8 elements.
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
     defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
  #define ELTS_PER_THREAD 16
constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
constexpr bool CVT_FP4_PACK16 = true;
#else
  #define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr bool CVT_FP4_PACK16 = false;
#endif

// Number of FP4 values sharing one scale factor (the NVFP4 block size).
constexpr int CVT_FP4_SF_VEC_SIZE = 16;

namespace vllm {

// Convert PyTorch cpp type to CUDA type: at::Half -> half and
// at::BFloat16 -> __nv_bfloat16; any other type maps to itself.
template <typename T>
struct CUDATypeConverter {
  using Type = T;
};

template <>
struct CUDATypeConverter<at::Half> {
  using Type = half;
};

template <>
struct CUDATypeConverter<at::BFloat16> {
  using Type = __nv_bfloat16;
};

// Get type2 from type or vice versa (half <-> half2,
// bfloat16 <-> bfloat162). The unspecialized default maps to half2.
template <typename T>
struct TypeConverter {
  using Type = half2;
};

template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
     defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
// Define a 32 bytes packed data type: 8 packed-pair elements of the
// paired type (e.g. half2) hold 16 scalar values per thread.
template <class Type>
struct alignas(32) PackedVec {
  typename TypeConverter<Type>::Type elts[8];
};
#else
// Define a 16 bytes packed data type: 4 packed-pair elements hold
// 8 scalar values per thread.
template <class Type>
struct alignas(16) PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};
#endif

// FP8 specialization: 8 x __nv_fp8x2_e4m3 = 16 bytes / 16 fp8 values.
// NOTE(review): unlike the generic template this is not conditional on
// the 16-elt packing mode, so its element count differs from the
// 8-elts-per-thread generic case — confirm this asymmetry is intended.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};

}  // namespace vllm
Loading
Loading