From 8748f11629e16ad6045edf8c6329045f59d6498b Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Thu, 30 Oct 2025 03:27:45 -0400
Subject: [PATCH 01/27] Blockwise quant RMS norm

Signed-off-by: ElizaWszola

---
 csrc/ops.h                                    |   7 +
 ...fused_layernorm_dynamic_per_token_quant.cu | 186 +++++++++++++++++-
 .../fused_kernels/layernorm_utils.cuh         | 111 +++++++----
 csrc/torch_bindings.cpp                       |   7 +
 .../core/test_fused_quant_layernorm.py        |  81 ++++++--
 vllm/_custom_ops.py                           |  24 +++
 6 files changed, 359 insertions(+), 57 deletions(-)

diff --git a/csrc/ops.h b/csrc/ops.h
index eb3d60b77e60..f87482a1e998 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -123,6 +123,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                       std::optional<at::Tensor> scale_ub,
                                       std::optional<at::Tensor> residual);
 
+void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& weight,
+                              torch::Tensor& scales, double const epsilon,
+                              std::optional<at::Tensor> scale_ub,
+                              std::optional<at::Tensor> residual,
+                              int64_t group_size);
+
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       std::optional<torch::Tensor> key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);

diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
index 92d6c2f402a2..33b05fcf3144 100644
--- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
@@ -76,13 +76,70 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
   // RMS Norm + Quant
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
     vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
-        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
+        out, input, weight, rms, &token_scale, hidden_size, true, residual);
   } else {
     // FP8 - Do not invert s_token_scale for exact match with FBGemm
     vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
-        out, input, weight, rms, token_scale, hidden_size, residual);
+        out, input, weight, rms, &token_scale, hidden_size, false, residual);
   }
 }
+
+// RMS norm + quant kernel
+template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
+__global__ void rms_norm_per_block_quant_kernel_1(
+    float* __restrict__ rms,
+    scalar_out_t* __restrict__ out,       // [..., hidden_size]
+    float* __restrict__ scales,  // [num_tokens, hidden_size / group_size]
+    scalar_t const* __restrict__ input,   // [..., hidden_size]
+    scalar_t const* __restrict__ weight,  // [hidden_size]
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
+    scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) {
+  // Compute RMS
+  vllm::compute_rms<scalar_t, has_residual>(rms + blockIdx.x, input,
+                                            hidden_size, var_epsilon, residual);
+}
+
+// RMS norm + quant kernel
+template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
+__global__ void rms_norm_per_block_quant_kernel_2(
+    float* rms, float* token_scale,
+    scalar_out_t* __restrict__ out,       // [..., hidden_size]
+    float* __restrict__ scales,  // [num_tokens, hidden_size / group_size]
+    scalar_t const* __restrict__ input,   // [..., hidden_size]
+    scalar_t const* __restrict__ weight,  // [hidden_size]
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
+    scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) {
+  // Compute Scale
+  vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
+      token_scale + blockIdx.x, scales, input, weight,
+      rms[blockIdx.x / (hidden_size / group_size)], scale_ub, hidden_size,
+      residual, group_size);
+}
+
+// RMS norm + quant kernel
+template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
+__global__ void rms_norm_per_block_quant_kernel_3(
+    float* rms, float* token_scale,
+    scalar_out_t* __restrict__ out,       // [..., hidden_size]
+    float* __restrict__ scales,  // [num_tokens, hidden_size / group_size]
+    scalar_t const* __restrict__ input,   // [..., hidden_size]
+    scalar_t const* __restrict__ weight,  // [hidden_size]
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
+    scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) {
+  // RMS Norm + Quant
+  int token_idx = blockIdx.x * hidden_size / group_size;
+  if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
+    vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
+        out, input, weight, rms[blockIdx.x], token_scale + token_idx,
+        hidden_size, true, residual, group_size);
+  } else {
+    // FP8 - Do not invert s_token_scale for exact match with FBGemm
+    vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
+        out, input, weight, rms[blockIdx.x], token_scale + token_idx,
+        hidden_size, false, residual, group_size);
+  }
+}
+
 }  // namespace vllm
 
 // Residual add + RMS norm + dynamic per token
@@ -157,3 +214,128 @@ void rms_norm_dynamic_per_token_quant(
             out, input, weight, scales, var_epsilon, scale_ub, residual);
       });
 }
+
+// Residual add + RMS norm + dynamic per token
+// TODO think up better names than kernel_1, kernel_2, kernel_3, cleanup args
+// TODO vectorized kernels
+template <typename scalar_in_t>
+void rms_norm_per_block_quant_dispatch(
+    torch::Tensor& out,           // [..., hidden_size]
+    torch::Tensor const& input,   // [..., hidden_size]
+    torch::Tensor const& weight,  // [hidden_size]
+    torch::Tensor& scales,        // [num_tokens, hidden_size / group_size]
+    double const var_epsilon,     // Variance epsilon used in norm calculation
+    std::optional<at::Tensor> const& scale_ub,
+    std::optional<at::Tensor>& residual, int64_t group_size) {
+  int32_t hidden_size = input.size(-1);
+  auto num_tokens = input.numel() / hidden_size;
+
+  dim3 grid13(num_tokens);
+  dim3 block13(std::min(hidden_size, 1024));
+  dim3 grid2(num_tokens * hidden_size / group_size);
+  dim3 block2(std::min(group_size, 1024l));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  auto const fp_options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
+  torch::Tensor rms = torch::zeros({num_tokens}, fp_options);
+  torch::Tensor token_scale =
+      torch::zeros({num_tokens * hidden_size / group_size}, fp_options);
+
+  if (residual.has_value()) {
+    VLLM_DISPATCH_QUANT_TYPES(
+        out.scalar_type(), "rms_norm_per_block_quant_kernel_1", [&] {
+          vllm::rms_norm_per_block_quant_kernel_1<scalar_in_t, scalar_t, true>
+              <<<grid13, block13, 0, stream>>>(
+                  rms.data_ptr<float>(), out.data_ptr<scalar_t>(),
+                  scales.data_ptr<float>(), input.data_ptr<scalar_in_t>(),
+                  weight.data_ptr<scalar_in_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>(),
+                  group_size);
+        });
+    VLLM_DISPATCH_QUANT_TYPES(
+        out.scalar_type(), "rms_norm_per_block_quant_kernel_2", [&] {
+          vllm::rms_norm_per_block_quant_kernel_2<scalar_in_t, scalar_t, true>
+              <<<grid2, block2, 0, stream>>>(
+                  rms.data_ptr<float>(), token_scale.data_ptr<float>(),
+                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
+                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>(),
+                  group_size);
+        });
+    VLLM_DISPATCH_QUANT_TYPES(
+        out.scalar_type(), "rms_norm_per_block_quant_kernel_3", [&] {
+          vllm::rms_norm_per_block_quant_kernel_3<scalar_in_t, scalar_t, true>
+              <<<grid13, block13, 0, stream>>>(
+                  rms.data_ptr<float>(), token_scale.data_ptr<float>(),
+                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
+                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>(),
+                  group_size);
+        });
+  } else {
+    VLLM_DISPATCH_QUANT_TYPES(
+        out.scalar_type(), "rms_norm_per_block_quant_kernel_1", [&] {
+          vllm::rms_norm_per_block_quant_kernel_1<scalar_in_t, scalar_t, false>
+              <<<grid13, block13, 0, stream>>>(
+                  rms.data_ptr<float>(), out.data_ptr<scalar_t>(),
+                  scales.data_ptr<float>(), input.data_ptr<scalar_in_t>(),
+                  weight.data_ptr<scalar_in_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  var_epsilon, hidden_size, nullptr, group_size);
+        });
+    VLLM_DISPATCH_QUANT_TYPES(
+        out.scalar_type(), "rms_norm_per_block_quant_kernel_2", [&] {
+          vllm::rms_norm_per_block_quant_kernel_2<scalar_in_t, scalar_t, false>
+              <<<grid2, block2, 0, stream>>>(
+                  rms.data_ptr<float>(), token_scale.data_ptr<float>(),
+                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
+                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  var_epsilon, hidden_size, nullptr, group_size);
+        });
+    VLLM_DISPATCH_QUANT_TYPES(
+        out.scalar_type(), "rms_norm_per_block_quant_kernel_3", [&] {
+          vllm::rms_norm_per_block_quant_kernel_3<scalar_in_t, scalar_t, false>
+              <<<grid13, block13, 0, stream>>>(
+                  rms.data_ptr<float>(), token_scale.data_ptr<float>(),
+                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
+                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
+                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                  var_epsilon, hidden_size, nullptr, group_size);
+        });
+  }
+}
+
+void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& weight,
+                              torch::Tensor& scales, double const var_epsilon,
+                              std::optional<at::Tensor> scale_ub,
+                              std::optional<at::Tensor> residual,
+                              int64_t group_size) {
+  static c10::ScalarType kFp8Type = is_fp8_ocp()
+                                        ? c10::ScalarType::Float8_e4m3fn
+                                        : c10::ScalarType::Float8_e4m3fnuz;
+  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
+  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
+
+  if (scale_ub.has_value()) {
+    TORCH_CHECK(out.dtype() == kFp8Type);
+  }
+  TORCH_CHECK(weight.dtype() == input.dtype());
+  TORCH_CHECK(scales.dtype() == torch::kFloat32);
+  if (residual) {
+    TORCH_CHECK(residual->scalar_type() == input.scalar_type());
+  }
+
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] {
+        rms_norm_per_block_quant_dispatch<scalar_t>(out, input, weight, scales,
+                                                    var_epsilon, scale_ub,
+                                                    residual, group_size);
+      });
+}
\ No newline at end of file
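Note: the three kernels above split the fused op into its phases: kernel_1 computes one RMS value per token, kernel_2 one scale per group_size-wide block, and kernel_3 normalizes and quantizes. A rough PyTorch sketch of the intended semantics, for orientation only (the helper name, int8 output, and qmax default are illustrative, not part of the patch):

    import torch

    def ref_rms_norm_per_block_quant(x, weight, eps, group_size, qmax=127.0):
        # kernel_1: per-token reciprocal RMS over the hidden dimension
        inv_rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
        y = (x.float() * inv_rms).to(x.dtype) * weight
        # kernel_2: one scale per group_size-wide block of each row
        g = y.float().view(x.shape[0], -1, group_size)
        # the floor mirrors min_scaling_factor in the kernel
        scales = (g.abs().amax(dim=-1) / qmax).clamp_min(1e-10)
        # kernel_3: quantize each block with its own scale
        q = torch.round(g / scales.unsqueeze(-1)).clamp(-qmax - 1, qmax)
        return q.view(x.shape).to(torch.int8), scales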
diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh
index 2d2fd771205c..402c16534c78 100644
--- a/csrc/quantization/fused_kernels/layernorm_utils.cuh
+++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -48,45 +48,83 @@ __device__ void compute_dynamic_per_token_scales(
     float* __restrict__ token_scale, float* __restrict__ all_token_scales,
     scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
     float const rms, float const* __restrict__ scale_ub,
-    int32_t const hidden_size,
-    scalar_t const* __restrict__ residual = nullptr) {
-  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
-  ;
+    int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr,
+    int32_t const group_size = 0) {
+  float block_absmax_val_maybe = 0.0f;
   constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
+  if (group_size > 0) {
+    int64_t const token_block_offset =
+        blockIdx.x * static_cast<int64_t>(group_size);
+    int64_t const hidden_element_offset = token_block_offset % hidden_size;
+    for (auto i = threadIdx.x; i < group_size; i += blockDim.x) {
+      float x = static_cast<float>(input[token_block_offset + i]);
+      if constexpr (has_residual) {
+        x += static_cast<float>(residual[token_block_offset + i]);
+      }
+      x = static_cast<float>(static_cast<scalar_t>(x * rms) *
+                             weight[hidden_element_offset + i]);
+      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
+    }
 
-  float block_absmax_val_maybe = 0.0f;
-  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    float x = static_cast<float>(input[token_offset + i]);
-    if constexpr (has_residual) {
-      x += static_cast<float>(residual[token_offset + i]);
+    using BlockReduce = cub::BlockReduce<float, 1024>;
+    __shared__ typename BlockReduce::TempStorage reduceStore;
+    block_absmax_val_maybe =
+        BlockReduce(reduceStore)
+            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
+
+    __shared__ float s_token_scale;
+    if (threadIdx.x == 0) {
+      float scale = 0.0f;
+      if (scale_ub) {
+        scale = min(block_absmax_val_maybe, *scale_ub);
+      } else {
+        scale = block_absmax_val_maybe;
+      }
+      // token scale computation
+      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
+      s_token_scale = scale;                 // Shared memory store
+      all_token_scales[blockIdx.x] = scale;  // Global output store
     }
+    __syncthreads();
 
-    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
-    block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
-  }
+    *token_scale = s_token_scale;
+  } else {
+    int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
+    ;
 
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  block_absmax_val_maybe =
-      BlockReduce(reduceStore)
-          .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
+    for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      float x = static_cast<float>(input[token_offset + i]);
+      if constexpr (has_residual) {
+        x += static_cast<float>(residual[token_offset + i]);
+      }
 
-  __shared__ float s_token_scale;
-  if (threadIdx.x == 0) {
-    float scale = 0.0f;
-    if (scale_ub) {
-      scale = min(block_absmax_val_maybe, *scale_ub);
-    } else {
-      scale = block_absmax_val_maybe;
+      x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
+      block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
     }
-    // token scale computation
-    scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
-    s_token_scale = scale;                 // Shared memory store
-    all_token_scales[blockIdx.x] = scale;  // Global output store
-  }
-  __syncthreads();
 
-  *token_scale = s_token_scale;
+    using BlockReduce = cub::BlockReduce<float, 1024>;
+    __shared__ typename BlockReduce::TempStorage reduceStore;
+    block_absmax_val_maybe =
+        BlockReduce(reduceStore)
+            .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
+
+    __shared__ float s_token_scale;
+    if (threadIdx.x == 0) {
+      float scale = 0.0f;
+      if (scale_ub) {
+        scale = min(block_absmax_val_maybe, *scale_ub);
+      } else {
+        scale = block_absmax_val_maybe;
+      }
+      // token scale computation
+      scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
+      s_token_scale = scale;                 // Shared memory store
+      all_token_scales[blockIdx.x] = scale;  // Global output store
+    }
+    __syncthreads();
+
+    *token_scale = s_token_scale;
+  }
 }
@@ -95,11 +133,12 @@ __device__ void compute_dynamic_per_token_scales(
 
 template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
           bool has_residual = false>
 __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                                scalar_t const* __restrict__ input,
                                scalar_t const* __restrict__ weight,
-                               float const rms, float const scale,
-                               int32_t const hidden_size,
-                               scalar_t* __restrict__ residual = nullptr) {
+                               float const rms, float* const scale,
+                               int32_t const hidden_size, bool invert_scale,
+                               scalar_t* __restrict__ residual = nullptr,
+                               int32_t const group_size = 0) {
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   ;
 
@@ -109,8 +148,10 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
     // Norm
     x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
     // Quant
+    auto scale_idx = group_size > 0 ? i / group_size : 0;
+    auto scale_val = invert_scale ? 1.0f / scale[scale_idx] : scale[scale_idx];
     output[token_offset + i] =
-        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale);
+        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale_val);
   }
 }
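Note: the indexing rule introduced above is simply that element i of a row reads scale slot i / group_size (integer division), with a single slot in the per-token case. Illustratively, for a hypothetical group size of 128:

    group_size = 128
    # element index -> scale-group index, as in scale[i / group_size]
    assert [i // group_size for i in (0, 127, 128, 255, 256)] == [0, 0, 1, 1, 2]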
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 7e8660349dad..990a83aed03c 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -220,6 +220,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
            &rms_norm_dynamic_per_token_quant);
 
+  // Fused Layernorm + Block quant kernels
+  ops.def(
+      "rms_norm_per_block_quant(Tensor! result, Tensor input, "
+      "Tensor weight, Tensor! scale, float epsilon, "
+      "Tensor? scale_ub, Tensor!? residual, int group_size) -> ()");
+  ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant);
+
   // Rotary embedding
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(
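Note: Tensor! in the schema above marks arguments the op mutates in place (the quantized output and the scales). A minimal sketch of a direct call against this registration, with assumed shapes (T tokens, hidden size H, group size G dividing H):

    import torch

    T, H, G = 4, 256, 128
    x = torch.randn(T, H, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(H, dtype=torch.bfloat16, device="cuda")
    out = torch.empty(T, H, dtype=torch.int8, device="cuda")
    scales = torch.zeros(T, H // G, dtype=torch.float32, device="cuda")
    torch.ops._C.rms_norm_per_block_quant(out, x, w, scales, 1e-6, None, None, G)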
diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py
index 63b5a37d3c77..b8127ad5ac16 100644
--- a/tests/kernels/core/test_fused_quant_layernorm.py
+++ b/tests/kernels/core/test_fused_quant_layernorm.py
@@ -8,6 +8,12 @@
 import vllm._custom_ops as ops
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8,
+)
+from vllm.model_executor.layers.quantization.utils.int8_utils import (
+    per_token_group_quant_int8,
+)
 
 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
@@ -21,6 +27,7 @@
 
 ADD_RESIDUAL = [False, True]
 SCALE_UBS = [True, False]
+GROUP_SIZES = [None, [1, 128]]
 SEEDS = [0]
 CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
@@ -45,12 +52,13 @@ def ref_rms_norm(
     return out, residual
 
 
-def ref_dynamic_per_token_quant(
+def ref_dynamic_per_token_or_block_quant(
     rms_norm_layer: RMSNorm,
     x: torch.Tensor,
     quant_dtype: torch.dtype,
     residual: torch.Tensor | None,
     scale_ub: torch.Tensor | None,
+    group_size: list[int] | None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     if scale_ub is not None:
         assert quant_dtype == torch.float8_e4m3fn
@@ -59,13 +67,24 @@ def ref_dynamic_per_token_or_block_quant(
     torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
 
     # Quant
-    if quant_dtype == torch.float8_e4m3fn:
-        torch_out, scales = ops.scaled_fp8_quant(
-            torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
-        )
+    if group_size is not None:
+        if quant_dtype == torch.float8_e4m3fn:
+            torch_out, scales = per_token_group_quant_fp8(
+                torch_out, group_size=group_size[1], use_ue8m0=False
+            )
+        else:
+            assert quant_dtype == torch.int8
+            torch_out, scales = per_token_group_quant_int8(
+                torch_out, group_size=group_size[1]
+            )
     else:
-        assert quant_dtype == torch.int8
-        torch_out, scales = ops.scaled_int8_quant(torch_out)
+        if quant_dtype == torch.float8_e4m3fn:
+            torch_out, scales = ops.scaled_fp8_quant(
+                torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
+            )
+        else:
+            assert quant_dtype == torch.int8
+            torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
 
     return torch_out, scales, residual
 
@@ -76,24 +95,31 @@ def ref_impl(
     quant_dtype: torch.dtype,
     residual: torch.Tensor | None,
     scale_ub: torch.Tensor | None,
+    group_size: list[int] | None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
-    return ref_dynamic_per_token_quant(
-        rms_norm_layer, x, quant_dtype, residual, scale_ub
+    return ref_dynamic_per_token_or_block_quant(
+        rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size
     )
 
 
-def ops_dynamic_per_token_quant(
+def ops_dynamic_per_token_or_block_quant(
     weight: torch.Tensor,
     x: torch.Tensor,
     quant_dtype: torch.dtype,
     residual: torch.Tensor | None,
     scale_ub: torch.Tensor | None,
+    group_size: list[int] | None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     if residual is not None:
         residual = residual.clone()
-    out, scales = ops.rms_norm_dynamic_per_token_quant(
-        x, weight, EPS, quant_dtype, scale_ub, residual
-    )
+    if group_size is not None:
+        out, scales = ops.rms_norm_per_block_quant(
+            x, weight, EPS, quant_dtype, group_size, scale_ub, residual
+        )
+    else:
+        out, scales = ops.rms_norm_dynamic_per_token_quant(
+            x, weight, EPS, quant_dtype, scale_ub, residual
+        )
     return out, scales, residual
 
 
@@ -103,15 +129,19 @@ def ops_impl(
     quant_dtype: torch.dtype,
     residual: torch.Tensor | None,
     scale_ub: torch.Tensor | None,
+    group_size: list[int] | None,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
-    return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
+    return ops_dynamic_per_token_or_block_quant(
+        weight, x, quant_dtype, residual, scale_ub, group_size
+    )
 
 
 @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
-@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
+@pytest.mark.parametrize("group_size", GROUP_SIZES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
@@ -119,9 +149,10 @@ def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
-    scale_ub: bool,
+    has_scale_ub: bool,
     dtype: torch.dtype,
     quant_dtype: torch.dtype,
+    group_size: list[int] | None,
     seed: int,
     device: str,
 ) -> None:
@@ -130,7 +161,15 @@ def test_rms_norm(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
 
-    if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
+    if group_size is not None and hidden_size % group_size[1] != 0:
+        # skip
+        return
+
+    # blockwise baseline doesn't support scale_ub
+    if group_size is not None and has_scale_ub:
+        return
+
+    if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
         # skip
         return
 
@@ -143,15 +182,17 @@ def test_rms_norm(
     scale = 1 / (hidden_size)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
     residual = torch.randn_like(x) * scale if add_residual else None
-    if scale_ub is not None:
+    if has_scale_ub:
         rms_x, _ = ref_rms_norm(layer, x, residual)
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
+    else:
+        scale_ub = None
 
     ref_out, ref_scales, ref_residual = ref_impl(
-        layer, x, quant_dtype, residual, scale_ub
+        layer, x, quant_dtype, residual, scale_ub, group_size
    )
     ops_out, ops_scales, ops_residual = ops_impl(
-        layer.weight, x, quant_dtype, residual, scale_ub
+        layer.weight, x, quant_dtype, residual, scale_ub, group_size
     )
 
     assert ref_out.dtype == quant_dtype

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index eccb9a1ef26f..e264b0961f42 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -410,6 +410,30 @@ def rms_norm_dynamic_per_token_quant(
     return output, scales
 
 
+# fused quant layer norm ops (blockwise)
+def rms_norm_per_block_quant(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype,
+    group_size: list[int],
+    scale_ub: torch.Tensor | None = None,
+    residual: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert len(group_size) == 2
+    output = torch.empty_like(input, dtype=quant_dtype)
+    scales = torch.zeros(
+        (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]),
+        device=input.device,
+        dtype=torch.float32,
+    )
+
+    torch.ops._C.rms_norm_per_block_quant(
+        output, input, weight, scales, epsilon, scale_ub, residual, group_size[1]
+    )
+    return output, scales
+
+
 # quantization ops
 # awq
 def awq_dequantize(
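Note: a useful invariant for checking grouped scales (the test fallback added later in this series relies on essentially this): dequantizing with each scale repeated across its group should recover the normalized values up to rounding error. Illustrative sketch:

    import torch

    def dequant_blockwise(out: torch.Tensor, scales: torch.Tensor, group_size: int):
        # out: [T, H] quantized values, scales: [T, H // group_size]
        return out.float() * scales.repeat_interleave(group_size, dim=-1)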
From ea9f4dbaf6df6ded29f8bf380657ba36cd754e0a Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Fri, 31 Oct 2025 03:09:46 -0400
Subject: [PATCH 02/27] Cleanup

Signed-off-by: ElizaWszola

---
 .../fused_layernorm_dynamic_per_token_quant.cu      | 13 +++++++++----
 csrc/quantization/fused_kernels/layernorm_utils.cuh |  5 ++---
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
index 33b05fcf3144..1a1d81aea22f 100644
--- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
@@ -75,12 +75,13 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
 
   // RMS Norm + Quant
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
+    token_scale = 1.0f / token_scale;
     vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
-        out, input, weight, rms, &token_scale, hidden_size, true, residual);
+        out, input, weight, rms, &token_scale, hidden_size, residual);
   } else {
     // FP8 - Do not invert s_token_scale for exact match with FBGemm
     vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
-        out, input, weight, rms, &token_scale, hidden_size, false, residual);
+        out, input, weight, rms, &token_scale, hidden_size, residual);
   }
 }
 
@@ -129,14 +130,18 @@ __global__ void rms_norm_per_block_quant_kernel_3(
   // RMS Norm + Quant
   int token_idx = blockIdx.x * hidden_size / group_size;
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
+    for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      auto token_group_idx = token_idx + i / group_size;
+      token_scale[token_group_idx] = 1.0f / token_scale[token_group_idx];
+    }
     vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
         out, input, weight, rms[blockIdx.x], token_scale + token_idx,
-        hidden_size, true, residual, group_size);
+        hidden_size, residual, group_size);
   } else {
     // FP8 - Do not invert s_token_scale for exact match with FBGemm
     vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
         out, input, weight, rms[blockIdx.x], token_scale + token_idx,
-        hidden_size, false, residual, group_size);
+        hidden_size, residual, group_size);
   }
 }
 
diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh
index 402c16534c78..a52f9a8464c9 100644
--- a/csrc/quantization/fused_kernels/layernorm_utils.cuh
+++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -133,7 +133,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                                scalar_t const* __restrict__ input,
                                scalar_t const* __restrict__ weight,
                                float const rms, float* const scale,
-                               int32_t const hidden_size, bool invert_scale,
+                               int32_t const hidden_size,
                                scalar_t* __restrict__ residual = nullptr,
                                int32_t const group_size = 0) {
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
@@ -148,8 +148,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
     // Norm
     x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
     // Quant
-    auto scale_idx = group_size > 0 ? i / group_size : 0;
-    auto scale_val = invert_scale ? 1.0f / scale[scale_idx] : scale[scale_idx];
+    auto scale_val = (group_size > 0 ? scale[i / group_size] : *scale);
     output[token_offset + i] =
         ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale_val);
   }
 }
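Note: context for the int8 change in the next commit. The int8 path quantizes by multiplying with an inverted scale, and with blockwise quantization one scale entry is shared by all threads covering that group, so inverting token_scale in place (as PATCH 02 does in kernel_3 above) lets several threads invert the same entry. The fix is to take the reciprocal where the scale is read. A sketch of the per-element rule, in Python for clarity:

    def quant_row_int8(row, scales, group_size):
        out = []
        for i, x in enumerate(row):
            s = scales[i // group_size]  # shared group scale, left untouched
            out.append(max(-128, min(127, round(x * (1.0 / s)))))  # invert at use
        return out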
From b3a55fdd9c2a7a1f357aba95e8a571f2654ffa61 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Tue, 4 Nov 2025 02:16:20 -0500
Subject: [PATCH 03/27] Apply quant layer norm fixes from #27865, inv scale fix
 for int8

Signed-off-by: ElizaWszola

---
 ...fused_layernorm_dynamic_per_token_quant.cu |  9 +++---
 .../fused_kernels/layernorm_utils.cuh         |  6 +++-
 csrc/quantization/w8a8/int8/scaled_quant.cu   |  3 ++
 .../core/test_fused_quant_layernorm.py        | 29 +++++++++++++++----
 4 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
index 1a1d81aea22f..265c6d9ca51a 100644
--- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
@@ -130,15 +130,14 @@ __global__ void rms_norm_per_block_quant_kernel_3(
   // RMS Norm + Quant
   int token_idx = blockIdx.x * hidden_size / group_size;
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
-    for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-      auto token_group_idx = token_idx + i / group_size;
-      token_scale[token_group_idx] = 1.0f / token_scale[token_group_idx];
-    }
+    // Don't invert token_scale here: do it inside the norm_and_quant kernel
+    // We do it because particular elements of token_scale can be shared
+    // between multiple threads, so this way, we avoid extra synchronization
+    // overhead.
     vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
         out, input, weight, rms[blockIdx.x], token_scale + token_idx,
         hidden_size, residual, group_size);
   } else {
-    // FP8 - Do not invert s_token_scale for exact match with FBGemm
     vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
         out, input, weight, rms[blockIdx.x], token_scale + token_idx,
         hidden_size, residual, group_size);
 
diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh
index a52f9a8464c9..a47b2f940ef9 100644
--- a/csrc/quantization/fused_kernels/layernorm_utils.cuh
+++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -148,7 +148,11 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
     // Norm
     x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
     // Quant
-    auto scale_val = (group_size > 0 ? scale[i / group_size] : *scale);
+    // If groupwise and is_scale_inverted is true, invert the scale here.
+    auto scale_val =
+        (group_size > 0 ? (is_scale_inverted ?
1.0f / scale[i / group_size] + : scale[i / group_size]) + : *scale); output[token_offset + i] = ScaledQuant::quant_fn(x, scale_val); } diff --git a/csrc/quantization/w8a8/int8/scaled_quant.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu index 7fe9e96bfb01..be8ecfeacf8c 100644 --- a/csrc/quantization/w8a8/int8/scaled_quant.cu +++ b/csrc/quantization/w8a8/int8/scaled_quant.cu @@ -1,5 +1,6 @@ #include #include +#include #include @@ -275,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 256)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { @@ -306,6 +308,7 @@ void dynamic_scaled_int8_quant( int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); dim3 const block(std::min(hidden_size, 256)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index b8127ad5ac16..441271ed6ea3 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -17,7 +17,7 @@ DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] -VEC_HIDDEN_SIZES = range(1024, 1030) +VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029] # Avoid combinatorial explosion with full Cartesian product NUM_TOKENS_HIDDEN_SIZES = [ *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]], @@ -165,8 +165,8 @@ def test_rms_norm( # skip return - # blockwise baseline doesn't support scale_ub if group_size is not None and has_scale_ub: + # blockwise baseline doesn't support scale_ub return if has_scale_ub and quant_dtype != torch.float8_e4m3fn: @@ -197,14 +197,31 @@ def test_rms_norm( assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype - assert torch.allclose(ref_scales, ops_scales) if quant_dtype == torch.int8: + assert torch.allclose(ref_scales, ops_scales, atol=1e-6) # big atol to account for round-off errors. assert torch.allclose(ref_out, ops_out, atol=1) else: - assert torch.allclose( - ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) - ) + assert torch.allclose(ref_scales, ops_scales) + a = ref_out.to(dtype=torch.float32) + b = ops_out.to(dtype=torch.float32) + ok = torch.allclose(a, b, atol=1e-6) + if not ok: + # fallback: compare dequantized values with relaxed tolerance + if group_size is None: + a_deq = a * ref_scales.view(-1, 1) + b_deq = b * ops_scales.view(-1, 1) + else: + a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1) + b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1) + # NOTE: It is possible that some future test cases trigger this + # max diff due to precision issues. If such an error is + # encountered, it's recommended to inspect the differences between + # all corresponding elements from each tensor (e.g. by looping over + # them) and checking how many the max diff error shows up on (just + # a few bad elements should still be considered acceptable). 
+ ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2) + assert ok if add_residual: assert torch.allclose(ref_residual, ops_residual) From 1e912ee98ee20750e35e4dfdf6f89851d257d8c9 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 4 Nov 2025 08:12:36 -0500 Subject: [PATCH 04/27] Cleanup Signed-off-by: ElizaWszola --- .../fused_kernels/layernorm_utils.cuh | 68 ++++++------------- 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index a47b2f940ef9..b790257d3c9d 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -65,32 +65,8 @@ __device__ void compute_dynamic_per_token_scales( weight[hidden_element_offset + i]); block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; - } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // Shared memory store - all_token_scales[blockIdx.x] = scale; // Global output store - } - __syncthreads(); - - *token_scale = s_token_scale; } else { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -101,30 +77,30 @@ __device__ void compute_dynamic_per_token_scales( x = static_cast(static_cast(x * rms) * weight[i]); block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } + } - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; - } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // Shared memory store - all_token_scales[blockIdx.x] = scale; // Global output store - } - __syncthreads(); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - *token_scale = s_token_scale; + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store } + __syncthreads(); + + *token_scale = s_token_scale; } template (hidden_size); - ; for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { float x = static_cast(input[token_offset + i]); @@ -230,7 +205,6 @@ __device__ void compute_dynamic_per_token_scales( int32_t const hidden_size, scalar_t const* __restrict__ residual = 
nullptr) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - ; // Vectorized input/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = From 051b451fcd0eaff6f6c0ac628a407a3510b3ba64 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 4 Nov 2025 09:51:27 -0500 Subject: [PATCH 05/27] Vectorize Signed-off-by: ElizaWszola --- ...fused_layernorm_dynamic_per_token_quant.cu | 37 +++++------ .../fused_kernels/layernorm_utils.cuh | 61 +++++++++++++------ 2 files changed, 60 insertions(+), 38 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 265c6d9ca51a..0313032fbb69 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( // RMS Norm + Quant if constexpr (std::is_same_v) { + token_scale = 1.0f / token_scale; vllm::vectorized::norm_and_quant( - out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } else { // FP8 - Do not invert token_scale for exact match with FBGemm vllm::vectorized::norm_and_quant( - out, input, weight, rms, token_scale, hidden_size, residual); + out, input, weight, rms, &token_scale, hidden_size, residual); } } @@ -96,8 +97,9 @@ __global__ void rms_norm_per_block_quant_kernel_1( float const* scale_ub, float const var_epsilon, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // Compute RMS - vllm::compute_rms(rms + blockIdx.x, input, - hidden_size, var_epsilon, residual); + // Always able to vectorize due to constraints on hidden_size + vllm::vectorized::compute_rms( + rms + blockIdx.x, input, hidden_size, var_epsilon, residual); } // RMS norm + quant kernel @@ -111,7 +113,9 @@ __global__ void rms_norm_per_block_quant_kernel_2( float const* scale_ub, float const var_epsilon, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // Compute Scale - vllm::compute_dynamic_per_token_scales( + // Always able to vectorize due to constraints on hidden_size + vllm::vectorized::compute_dynamic_per_token_scales( token_scale + blockIdx.x, scales, input, weight, rms[blockIdx.x / (hidden_size / group_size)], scale_ub, hidden_size, residual, group_size); @@ -128,20 +132,17 @@ __global__ void rms_norm_per_block_quant_kernel_3( float const* scale_ub, float const var_epsilon, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // RMS Norm + Quant + // Always able to vectorize due to constraints on hidden_size int token_idx = blockIdx.x * hidden_size / group_size; - if constexpr (std::is_same_v) { - // Don't invert token_scale here: do it inside the norm_and_quant kernel - // We do it because particular elements of token_scale can be shared - // between multiple threads, so this way, we avoid extra synchronization - // overhead. - vllm::norm_and_quant( - out, input, weight, rms[blockIdx.x], token_scale + token_idx, - hidden_size, residual, group_size); - } else { - vllm::norm_and_quant( - out, input, weight, rms[blockIdx.x], token_scale + token_idx, - hidden_size, residual, group_size); - } + // For int8, don't invert token_scale here: do it inside the norm_and_quant + // kernel. 
We do it because particular elements of token_scale can be shared + // between multiple threads, so this way, we avoid extra synchronization + // overhead. + vllm::vectorized::norm_and_quant, + has_residual>( + out, input, weight, rms[blockIdx.x], token_scale + token_idx, hidden_size, + residual, group_size); } } // namespace vllm diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index b790257d3c9d..d316c00f7cc1 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -202,27 +202,42 @@ __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, - scalar_t const* __restrict__ residual = nullptr) { - int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - - // Vectorized input/weight/residual to better utilize memory bandwidth. - vec4_t const* vec_input = - reinterpret_cast const*>(&input[token_offset]); - vec4_t const* vec_weight = - reinterpret_cast const*>(weight); - vec4_t const* vec_residual = nullptr; - if constexpr (has_residual) { - vec_residual = - reinterpret_cast const*>(&residual[token_offset]); - } - + int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, + int32_t const group_size = 0) { constexpr scalar_out_t qmax{quant_type_max_v}; const int VEC_SIZE = 4; - int32_t const num_vec_elems = hidden_size >> 2; + int32_t const num_vec_elems = + (group_size > 0 ? group_size : hidden_size) >> 2; float block_absmax_val_maybe = 0.0f; + // Vectorized input/weight/residual to better utilize memory bandwidth. + vec4_t const* vec_input = nullptr; + vec4_t const* vec_weight = nullptr; + vec4_t const* vec_residual = nullptr; + + if (group_size > 0) { + int64_t const token_block_offset = + blockIdx.x * static_cast(group_size); + int64_t const hidden_element_offset = token_block_offset % hidden_size; + vec_input = + reinterpret_cast const*>(&input[token_block_offset]); + vec_weight = reinterpret_cast const*>( + &weight[hidden_element_offset]); + if constexpr (has_residual) { + vec_residual = reinterpret_cast const*>( + &residual[token_block_offset]); + } + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + } + #pragma unroll 4 for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { vec4_t in = vec_input[i]; @@ -280,11 +295,11 @@ template (hidden_size); - ; // Vectorized input/output/weight/residual to better utilize memory bandwidth. vec4_t const* vec_input = @@ -329,10 +344,16 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, } q8x4_t out; + + auto scale_val = + (group_size > 0 + ? (is_scale_inverted ? 
1.0f / scale[i * VEC_SIZE / group_size] + : scale[i * VEC_SIZE / group_size]) + : *scale); #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { out.val[j] = ScaledQuant::quant_fn( - static_cast(x.val[j] * rms) * w.val[j], scale); + static_cast(x.val[j] * rms) * w.val[j], scale_val); } vec_output[i] = out; } From 2584f2f05392619ae6a4cddf4eea7007a52acca7 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 5 Nov 2025 18:09:50 -0500 Subject: [PATCH 06/27] Unify kernel shapes to fuse Signed-off-by: ElizaWszola --- ...fused_layernorm_dynamic_per_token_quant.cu | 9 +- .../fused_kernels/layernorm_utils.cuh | 150 ++++++++++++++---- 2 files changed, 125 insertions(+), 34 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 0313032fbb69..bd324a2b66b8 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -114,9 +114,8 @@ __global__ void rms_norm_per_block_quant_kernel_2( scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // Compute Scale // Always able to vectorize due to constraints on hidden_size - vllm::vectorized::compute_dynamic_per_token_scales( - token_scale + blockIdx.x, scales, input, weight, + vllm::compute_dynamic_per_token_scales( + token_scale, scales, input, weight, rms[blockIdx.x / (hidden_size / group_size)], scale_ub, hidden_size, residual, group_size); } @@ -263,7 +262,7 @@ void rms_norm_per_block_quant_dispatch( VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_per_block_quant_kernel_2", [&] { vllm::rms_norm_per_block_quant_kernel_2 - <<>>( + <<>>( rms.data_ptr(), token_scale.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), @@ -296,7 +295,7 @@ void rms_norm_per_block_quant_dispatch( VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_per_block_quant_kernel_2", [&] { vllm::rms_norm_per_block_quant_kernel_2 - <<>>( + <<>>( rms.data_ptr(), token_scale.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index d316c00f7cc1..ff1770064867 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -43,6 +43,22 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, *rms = s_rms; } +// TODO replace 32 with WARP_SIZE +__device__ float warpReduceMax(volatile float* val, int tid) { + val[tid] = fmaxf(val[tid], val[tid + 32]); + // printf("s_max vals red 32: %f (%d)\n", val[tid], tid); + val[tid] = fmaxf(val[tid], val[tid + 16]); + // printf("s_max vals red 16: %f (%d)\n", val[tid], tid); + val[tid] = fmaxf(val[tid], val[tid + 8]); + // printf("s_max vals red 8: %f (%d)\n", val[tid], tid); + val[tid] = fmaxf(val[tid], val[tid + 4]); + // printf("s_max vals red 4: %f (%d)\n", val[tid], tid); + val[tid] = fmaxf(val[tid], val[tid + 2]); + // printf("s_max vals red 2: %f (%d)\n", val[tid], tid); + val[tid] = fmaxf(val[tid], val[tid + 1]); + return val[tid]; +} + template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, @@ -53,18 +69,95 @@ __device__ void compute_dynamic_per_token_scales( float block_absmax_val_maybe = 0.0f; constexpr scalar_out_t qmax{quant_type_max_v}; if 
(group_size > 0) { - int64_t const token_block_offset = - blockIdx.x * static_cast(group_size); - int64_t const hidden_element_offset = token_block_offset % hidden_size; - for (auto i = threadIdx.x; i < group_size; i += blockDim.x) { - float x = static_cast(input[token_block_offset + i]); + // if (threadIdx.x == 0) { + // printf("block size: %d\n", blockDim.x); + // } + + __shared__ float s_max_vals[1024]; + int32_t const token_offset = + blockIdx.x * static_cast(hidden_size); + int32_t num_groups = hidden_size / group_size; + int32_t const threads_per_group = blockDim.x / num_groups; + int32_t const thread_in_group = threadIdx.x % threads_per_group; + int32_t const thread_offset = threadIdx.x / threads_per_group * group_size + + thread_in_group; + // printf("%d %d %d %d\n", threadIdx.x, threads_per_group, thread_in_group, thread_offset); + // int64_t const hidden_element_offset = token_block_offset % hidden_size; + for (auto i = thread_offset; i < thread_offset + group_size; i += threads_per_group) { + float x = static_cast(input[token_offset + i]); if constexpr (has_residual) { - x += static_cast(residual[token_block_offset + i]); + x += static_cast(residual[token_offset + i]); } x = static_cast(static_cast(x * rms) * - weight[hidden_element_offset + i]); + weight[i]); block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + // printf("s_max_vals 0: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); + __syncthreads(); + + int step_size = threads_per_group; + int ctr = 1; + while (step_size > 32 * 2) { + step_size /= 2; + if (thread_in_group < step_size) { + s_max_vals[threadIdx.x] = fmaxf(s_max_vals[threadIdx.x], s_max_vals[threadIdx.x + step_size]); + // printf("s_max_vals %d: %f (%d)\n", ctr, s_max_vals[threadIdx.x], threadIdx.x); + ++ctr; + } + __syncthreads(); + } + float reduced_local = 0.0f; + if (thread_in_group < 32) { + reduced_local = warpReduceMax(s_max_vals, threadIdx.x); + } + if (thread_in_group == 0) { + block_absmax_val_maybe = reduced_local; + // printf("s_max_vals end: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); + } + __syncthreads(); + + if (thread_in_group == 0) { + // printf("block_absmax_val_maybe: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + all_token_scales[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; // Global output store + token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; + __syncthreads(); + } + + // using BlockReduce = cub::BlockReduce; + // __shared__ typename BlockReduce::TempStorage reduceStore; + // block_absmax_val_maybe = + // BlockReduce(reduceStore) + // .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + + // __shared__ float s_token_scale; + // if (threadIdx.x == 0) { + // float scale = 0.0f; + // if (scale_ub) { + // scale = min(block_absmax_val_maybe, *scale_ub); + // } else { + // scale = block_absmax_val_maybe; + // } + // // token scale computation + // scale = max(scale / qmax, min_scaling_factor::val()); + // s_token_scale = scale; // Shared memory store + // all_token_scales[blockIdx.x] = scale; // Global output store + // } + // __syncthreads(); + + // *token_scale = s_token_scale; + + // for each first warp of the group, do the for-loop reduction + // then, call warpReduceMax and sync 
+ } else { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); @@ -77,30 +170,29 @@ __device__ void compute_dynamic_per_token_scales( x = static_cast(static_cast(x * rms) * weight[i]); block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } - } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // Shared memory store - all_token_scales[blockIdx.x] = scale; // Global output store + __syncthreads(); + + *token_scale = s_token_scale; } - __syncthreads(); - - *token_scale = s_token_scale; } template Date: Thu, 6 Nov 2025 04:19:41 -0500 Subject: [PATCH 07/27] Fix Signed-off-by: ElizaWszola --- .../fused_kernels/layernorm_utils.cuh | 188 +++++++++++++----- 1 file changed, 139 insertions(+), 49 deletions(-) diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index ff1770064867..6612d650723a 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -44,18 +44,45 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, } // TODO replace 32 with WARP_SIZE -__device__ float warpReduceMax(volatile float* val, int tid) { - val[tid] = fmaxf(val[tid], val[tid + 32]); +__device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, + int64_t thread_in_warp, + int64_t reduced_elems, + int64_t warp_id) { + if (thread_in_warp + 32 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 32]); + if (warp_id == 0 && thread_in_warp < reduced_elems) { + printf("reduce 32: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); + } // printf("s_max vals red 32: %f (%d)\n", val[tid], tid); - val[tid] = fmaxf(val[tid], val[tid + 16]); + if (thread_in_warp + 16 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 16]); + if (warp_id == 0 && thread_in_warp < reduced_elems) { + printf("reduce 16: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); + } // printf("s_max vals red 16: %f (%d)\n", val[tid], tid); - val[tid] = fmaxf(val[tid], val[tid + 8]); + if (thread_in_warp + 8 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 8]); + if (warp_id == 0 && thread_in_warp < reduced_elems) { + printf("reduce 8: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); + } // printf("s_max vals red 8: %f (%d)\n", val[tid], tid); - val[tid] = fmaxf(val[tid], val[tid + 4]); + if (thread_in_warp + 4 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 4]); + if (warp_id == 0 && 
thread_in_warp < reduced_elems) { + printf("reduce 4: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); + } // printf("s_max vals red 4: %f (%d)\n", val[tid], tid); - val[tid] = fmaxf(val[tid], val[tid + 2]); + if (thread_in_warp + 2 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 2]); + if (warp_id == 0 && thread_in_warp < reduced_elems) { + printf("reduce 2: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); + } // printf("s_max vals red 2: %f (%d)\n", val[tid], tid); - val[tid] = fmaxf(val[tid], val[tid + 1]); + if (thread_in_warp + 1 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 1]); + if (warp_id == 0 && thread_in_warp < reduced_elems) { + printf("reduce 1: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); + } return val[tid]; } @@ -68,57 +95,117 @@ __device__ void compute_dynamic_per_token_scales( int32_t const group_size = 0) { float block_absmax_val_maybe = 0.0f; constexpr scalar_out_t qmax{quant_type_max_v}; + __syncthreads(); if (group_size > 0) { // if (threadIdx.x == 0) { // printf("block size: %d\n", blockDim.x); // } - __shared__ float s_max_vals[1024]; - int32_t const token_offset = - blockIdx.x * static_cast(hidden_size); - int32_t num_groups = hidden_size / group_size; - int32_t const threads_per_group = blockDim.x / num_groups; - int32_t const thread_in_group = threadIdx.x % threads_per_group; - int32_t const thread_offset = threadIdx.x / threads_per_group * group_size + - thread_in_group; - // printf("%d %d %d %d\n", threadIdx.x, threads_per_group, thread_in_group, thread_offset); - // int64_t const hidden_element_offset = token_block_offset % hidden_size; - for (auto i = thread_offset; i < thread_offset + group_size; i += threads_per_group) { + // if (threadIdx.x == 0 && blockIdx.x == 0) { + // for (auto i = 0; i < blockDim.x; ++i) { + // float x = static_cast(input[i]); + // if constexpr (has_residual) { + // x += static_cast(residual[i]); + // } + // x = static_cast(static_cast(x * rms) * weight[i]); + // printf("%f ", x); + // } + // printf("\n"); + // } + // __syncthreads(); + + __shared__ float s_max_vals[1024]; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t num_groups = hidden_size / group_size; // 40 + int64_t const threads_per_group = blockDim.x / num_groups; // 25 + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = threadIdx.x / threads_per_group * group_size; + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = + min(group_offset + group_size, static_cast(hidden_size)); + // printf("%d %d %d %d\n", threadIdx.x, threads_per_group, thread_in_group, + // thread_offset); int64_t const hidden_element_offset = token_block_offset + // % hidden_size; + for (auto i = thread_offset; i < thread_end; i += threads_per_group) { float x = static_cast(input[token_offset + i]); if constexpr (has_residual) { x += static_cast(residual[token_offset + i]); } - x = static_cast(static_cast(x * rms) * - weight[i]); + x = static_cast(static_cast(x * rms) * weight[i]); block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } s_max_vals[threadIdx.x] = block_absmax_val_maybe; - // printf("s_max_vals 0: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); + // printf("s_max_xvals 0: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); __syncthreads(); - int step_size = threads_per_group; - int ctr = 1; - while (step_size > 32 * 2) { - step_size /= 2; - if (thread_in_group < step_size) { - s_max_vals[threadIdx.x] = fmaxf(s_max_vals[threadIdx.x], 
s_max_vals[threadIdx.x + step_size]); - // printf("s_max_vals %d: %f (%d)\n", ctr, s_max_vals[threadIdx.x], threadIdx.x); - ++ctr; + // int step_size = threads_per_group; + // int ctr = 1; + // while (step_size > 32 * 2) { + // step_size /= 2; + // if (thread_in_group < step_size) { + // s_max_vals[threadIdx.x] = + // fmaxf(s_max_vals[threadIdx.x], s_max_vals[threadIdx.x + + // step_size]); + // // printf("s_max_vals %d: %f (%d)\n", ctr, s_max_vals[threadIdx.x], + // // threadIdx.x); + // ++ctr; + // } + // __syncthreads(); + // } + + int64_t const warp_size = 32; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = + (num_groups + num_warps - 1) / num_warps; // 2 + int64_t const absmax_per_warp = groups_per_warp * threads_per_group; + // int64_t const start = warp_id * absmax_per_warp + thread_in_warp; + // int64_t const end = (warp_id + 1) * absmax_per_warp; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); + } + // int64_t const idx = start + i * warp_size; + // int64_t const next_idx = idx + warp_size; + // if (next_idx < blockDim.x && next_idx < end) { + // s_max_vals[idx] = fmaxf(s_max_vals[idx], s_max_vals[next_idx]); + // } + // if (thread_in_warp == 0) { + // printf("do warp reduce for warp %ld, group_id %ld, start %ld, end + // %ld, threads_per_group %ld\n", + // warp_id, group_id, start, warp_end, min(warp_end - warp_start, + // warp_size)); + // } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size), + 1 /*warp_start!=75*/); } - __syncthreads(); - } - float reduced_local = 0.0f; - if (thread_in_group < 32) { - reduced_local = warpReduceMax(s_max_vals, threadIdx.x); - } - if (thread_in_group == 0) { - block_absmax_val_maybe = reduced_local; - // printf("s_max_vals end: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); } __syncthreads(); - if (thread_in_group == 0) { - // printf("block_absmax_val_maybe: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); + // float reduced_local = 0.0f; + // if (thread_in_group < warp_size) { + // reduced_local = warpReduceMax(s_max_vals, threadIdx.x); + // } + // if (thread_in_group == 0) { + // block_absmax_val_maybe = reduced_local; + // // printf("s_max_vals end: %f (%d)\n", block_absmax_val_maybe, + // // threadIdx.x); + // } + // __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + // printf("block_absmax_val_maybe, %f, %d, %ld\n", block_absmax_val_maybe, + // threadIdx.x, threadIdx.x / threads_per_group); float scale = 0.0f; if (scale_ub) { scale = min(block_absmax_val_maybe, *scale_ub); @@ -127,17 +214,20 @@ __device__ void compute_dynamic_per_token_scales( } // token scale computation scale = max(scale / qmax, min_scaling_factor::val()); - all_token_scales[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; // Global output store - token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; - __syncthreads(); + 
all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = + scale; // Global output store + token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = + scale; } + __syncthreads(); // using BlockReduce = cub::BlockReduce; // __shared__ typename BlockReduce::TempStorage reduceStore; // block_absmax_val_maybe = // BlockReduce(reduceStore) // .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - + // __shared__ float s_token_scale; // if (threadIdx.x == 0) { // float scale = 0.0f; @@ -152,11 +242,11 @@ __device__ void compute_dynamic_per_token_scales( // all_token_scales[blockIdx.x] = scale; // Global output store // } // __syncthreads(); - + // *token_scale = s_token_scale; - // for each first warp of the group, do the for-loop reduction - // then, call warpReduceMax and sync + // for each first warp of the group, do the for-loop reduction + // then, call warpReduceMax and sync } else { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); @@ -175,7 +265,7 @@ __device__ void compute_dynamic_per_token_scales( block_absmax_val_maybe = BlockReduce(reduceStore) .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - + __shared__ float s_token_scale; if (threadIdx.x == 0) { float scale = 0.0f; @@ -190,7 +280,7 @@ __device__ void compute_dynamic_per_token_scales( all_token_scales[blockIdx.x] = scale; // Global output store } __syncthreads(); - + *token_scale = s_token_scale; } } From 0fce11108d460daa7fe416e325a2191cd1dcfd3b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 6 Nov 2025 07:30:24 -0500 Subject: [PATCH 08/27] Cleanup Signed-off-by: ElizaWszola --- .../fused_kernels/layernorm_utils.cuh | 126 +----------------- 1 file changed, 6 insertions(+), 120 deletions(-) diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 6612d650723a..7508c7efe3b2 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -46,43 +46,19 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, // TODO replace 32 with WARP_SIZE __device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, int64_t thread_in_warp, - int64_t reduced_elems, - int64_t warp_id) { + int64_t reduced_elems) { if (thread_in_warp + 32 < reduced_elems) val[tid] = fmaxf(val[tid], val[tid + 32]); - if (warp_id == 0 && thread_in_warp < reduced_elems) { - printf("reduce 32: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); - } - // printf("s_max vals red 32: %f (%d)\n", val[tid], tid); if (thread_in_warp + 16 < reduced_elems) val[tid] = fmaxf(val[tid], val[tid + 16]); - if (warp_id == 0 && thread_in_warp < reduced_elems) { - printf("reduce 16: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); - } - // printf("s_max vals red 16: %f (%d)\n", val[tid], tid); if (thread_in_warp + 8 < reduced_elems) val[tid] = fmaxf(val[tid], val[tid + 8]); - if (warp_id == 0 && thread_in_warp < reduced_elems) { - printf("reduce 8: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); - } - // printf("s_max vals red 8: %f (%d)\n", val[tid], tid); if (thread_in_warp + 4 < reduced_elems) val[tid] = fmaxf(val[tid], val[tid + 4]); - if (warp_id == 0 && thread_in_warp < reduced_elems) { - printf("reduce 4: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); - } - // printf("s_max vals red 4: %f (%d)\n", val[tid], tid); if (thread_in_warp + 2 < reduced_elems) val[tid] = fmaxf(val[tid], val[tid + 2]); - if (warp_id == 0 && thread_in_warp < 
reduced_elems) { - printf("reduce 2: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); - } - // printf("s_max vals red 2: %f (%d)\n", val[tid], tid); if (thread_in_warp + 1 < reduced_elems) val[tid] = fmaxf(val[tid], val[tid + 1]); - if (warp_id == 0 && thread_in_warp < reduced_elems) { - printf("reduce 1: %f (%ld, %ld)\n", val[tid], thread_in_warp, tid); - } return val[tid]; } @@ -97,35 +73,15 @@ __device__ void compute_dynamic_per_token_scales( constexpr scalar_out_t qmax{quant_type_max_v}; __syncthreads(); if (group_size > 0) { - // if (threadIdx.x == 0) { - // printf("block size: %d\n", blockDim.x); - // } - - // if (threadIdx.x == 0 && blockIdx.x == 0) { - // for (auto i = 0; i < blockDim.x; ++i) { - // float x = static_cast(input[i]); - // if constexpr (has_residual) { - // x += static_cast(residual[i]); - // } - // x = static_cast(static_cast(x * rms) * weight[i]); - // printf("%f ", x); - // } - // printf("\n"); - // } - // __syncthreads(); - __shared__ float s_max_vals[1024]; int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - int64_t num_groups = hidden_size / group_size; // 40 - int64_t const threads_per_group = blockDim.x / num_groups; // 25 + int64_t num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const group_offset = threadIdx.x / threads_per_group * group_size; int64_t const thread_offset = group_offset + thread_in_group; int64_t const thread_end = min(group_offset + group_size, static_cast(hidden_size)); - // printf("%d %d %d %d\n", threadIdx.x, threads_per_group, thread_in_group, - // thread_offset); int64_t const hidden_element_offset = token_block_offset - // % hidden_size; for (auto i = thread_offset; i < thread_end; i += threads_per_group) { float x = static_cast(input[token_offset + i]); if constexpr (has_residual) { @@ -135,33 +91,14 @@ __device__ void compute_dynamic_per_token_scales( block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); } s_max_vals[threadIdx.x] = block_absmax_val_maybe; - // printf("s_max_xvals 0: %f (%d)\n", block_absmax_val_maybe, threadIdx.x); __syncthreads(); - // int step_size = threads_per_group; - // int ctr = 1; - // while (step_size > 32 * 2) { - // step_size /= 2; - // if (thread_in_group < step_size) { - // s_max_vals[threadIdx.x] = - // fmaxf(s_max_vals[threadIdx.x], s_max_vals[threadIdx.x + - // step_size]); - // // printf("s_max_vals %d: %f (%d)\n", ctr, s_max_vals[threadIdx.x], - // // threadIdx.x); - // ++ctr; - // } - // __syncthreads(); - // } - int64_t const warp_size = 32; int64_t const num_warps = blockDim.x / warp_size; int64_t const warp_id = threadIdx.x / warp_size; int64_t const thread_in_warp = threadIdx.x % warp_size; - int64_t const groups_per_warp = - (num_groups + num_warps - 1) / num_warps; // 2 + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; int64_t const absmax_per_warp = groups_per_warp * threads_per_group; - // int64_t const start = warp_id * absmax_per_warp + thread_in_warp; - // int64_t const end = (warp_id + 1) * absmax_per_warp; for (auto i = 0; i < groups_per_warp; ++i) { int64_t const group_id = i * num_warps + warp_id; if (group_id < num_groups) { @@ -173,39 +110,15 @@ __device__ void compute_dynamic_per_token_scales( s_max_vals[start] = fmaxf(s_max_vals[start], s_max_vals[j + warp_size]); } - // int64_t const idx = start + i * warp_size; - // int64_t const next_idx = idx + warp_size; - // if (next_idx < blockDim.x && next_idx 
< end) { - // s_max_vals[idx] = fmaxf(s_max_vals[idx], s_max_vals[next_idx]); - // } - // if (thread_in_warp == 0) { - // printf("do warp reduce for warp %ld, group_id %ld, start %ld, end - // %ld, threads_per_group %ld\n", - // warp_id, group_id, start, warp_end, min(warp_end - warp_start, - // warp_size)); - // } warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, - min(warp_end - warp_start, warp_size), - 1 /*warp_start!=75*/); + min(warp_end - warp_start, warp_size)); } } __syncthreads(); - // float reduced_local = 0.0f; - // if (thread_in_group < warp_size) { - // reduced_local = warpReduceMax(s_max_vals, threadIdx.x); - // } - // if (thread_in_group == 0) { - // block_absmax_val_maybe = reduced_local; - // // printf("s_max_vals end: %f (%d)\n", block_absmax_val_maybe, - // // threadIdx.x); - // } - // __syncthreads(); - if (thread_in_group == 0 && thread_offset < thread_end) { block_absmax_val_maybe = s_max_vals[threadIdx.x]; - // printf("block_absmax_val_maybe, %f, %d, %ld\n", block_absmax_val_maybe, - // threadIdx.x, threadIdx.x / threads_per_group); + float to_log = block_absmax_val_maybe; float scale = 0.0f; if (scale_ub) { scale = min(block_absmax_val_maybe, *scale_ub); @@ -221,33 +134,6 @@ __device__ void compute_dynamic_per_token_scales( scale; } __syncthreads(); - - // using BlockReduce = cub::BlockReduce; - // __shared__ typename BlockReduce::TempStorage reduceStore; - // block_absmax_val_maybe = - // BlockReduce(reduceStore) - // .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - - // __shared__ float s_token_scale; - // if (threadIdx.x == 0) { - // float scale = 0.0f; - // if (scale_ub) { - // scale = min(block_absmax_val_maybe, *scale_ub); - // } else { - // scale = block_absmax_val_maybe; - // } - // // token scale computation - // scale = max(scale / qmax, min_scaling_factor::val()); - // s_token_scale = scale; // Shared memory store - // all_token_scales[blockIdx.x] = scale; // Global output store - // } - // __syncthreads(); - - // *token_scale = s_token_scale; - - // for each first warp of the group, do the for-loop reduction - // then, call warpReduceMax and sync - } else { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); From 0fac68c6501037762f483af8d7f2d2f7d3953399 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 6 Nov 2025 07:57:02 -0500 Subject: [PATCH 09/27] Scalar scale computation is working again Signed-off-by: ElizaWszola --- ...fused_layernorm_dynamic_per_token_quant.cu | 90 +++---------------- .../fused_kernels/layernorm_utils.cuh | 2 - 2 files changed, 12 insertions(+), 80 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index bd324a2b66b8..e0b2e991fad6 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -88,48 +88,26 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( // RMS norm + quant kernel template -__global__ void rms_norm_per_block_quant_kernel_1( +__global__ void rms_norm_per_block_quant_kernel( float* __restrict__ rms, scalar_out_t* __restrict__ out, // [..., hidden_size] float* __restrict__ scales, // [num_tokens, hidden_size / group_size] scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] - float const* scale_ub, float const var_epsilon, int32_t const hidden_size, + float* __restrict__ 
token_scale, float const* scale_ub, + float const var_epsilon, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // Compute RMS // Always able to vectorize due to constraints on hidden_size vllm::vectorized::compute_rms( rms + blockIdx.x, input, hidden_size, var_epsilon, residual); -} -// RMS norm + quant kernel -template -__global__ void rms_norm_per_block_quant_kernel_2( - float* rms, float* token_scale, - scalar_out_t* __restrict__ out, // [..., hidden_size] - float* __restrict__ scales, // [num_tokens, hidden_size / group_size] - scalar_t const* __restrict__ input, // [..., hidden_size] - scalar_t const* __restrict__ weight, // [hidden_size] - float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // Compute Scale - // Always able to vectorize due to constraints on hidden_size + // TODO: Vectorize this vllm::compute_dynamic_per_token_scales( - token_scale, scales, input, weight, - rms[blockIdx.x / (hidden_size / group_size)], scale_ub, hidden_size, - residual, group_size); -} + token_scale, scales, input, weight, rms[blockIdx.x], scale_ub, + hidden_size, residual, group_size); -// RMS norm + quant kernel -template -__global__ void rms_norm_per_block_quant_kernel_3( - float* rms, float* token_scale, - scalar_out_t* __restrict__ out, // [..., hidden_size] - float* __restrict__ scales, // [num_tokens, hidden_size / group_size] - scalar_t const* __restrict__ input, // [..., hidden_size] - scalar_t const* __restrict__ weight, // [hidden_size] - float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // RMS Norm + Quant // Always able to vectorize due to constraints on hidden_size int token_idx = blockIdx.x * hidden_size / group_size; @@ -236,8 +214,6 @@ void rms_norm_per_block_quant_dispatch( dim3 grid13(num_tokens); dim3 block13(std::min(hidden_size, 1024)); - dim3 grid2(num_tokens * hidden_size / group_size); - dim3 block2(std::min(group_size, 1024l)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -249,66 +225,24 @@ void rms_norm_per_block_quant_dispatch( if (residual.has_value()) { VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel_1", [&] { - vllm::rms_norm_per_block_quant_kernel_1 + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel <<>>( rms.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr(), - group_size); - }); - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel_2", [&] { - vllm::rms_norm_per_block_quant_kernel_2 - <<>>( - rms.data_ptr(), token_scale.data_ptr(), - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr(), - group_size); - }); - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel_3", [&] { - vllm::rms_norm_per_block_quant_kernel_3 - <<>>( - rms.data_ptr(), token_scale.data_ptr(), - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), + weight.data_ptr(), token_scale.data_ptr(), scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, residual->data_ptr(), group_size); }); } else { VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel_1", [&] { - vllm::rms_norm_per_block_quant_kernel_1 + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel <<>>( rms.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr, group_size); - }); - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel_2", [&] { - vllm::rms_norm_per_block_quant_kernel_2 - <<>>( - rms.data_ptr(), token_scale.data_ptr(), - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr, group_size); - }); - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel_3", [&] { - vllm::rms_norm_per_block_quant_kernel_3 - <<>>( - rms.data_ptr(), token_scale.data_ptr(), - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), + weight.data_ptr(), token_scale.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, nullptr, group_size); }); diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 7508c7efe3b2..ca88faa7cd13 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -98,7 +98,6 @@ __device__ void compute_dynamic_per_token_scales( int64_t const warp_id = threadIdx.x / warp_size; int64_t const thread_in_warp = threadIdx.x % warp_size; int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; - int64_t const absmax_per_warp = groups_per_warp * threads_per_group; for (auto i = 0; i < groups_per_warp; ++i) { int64_t const group_id = i * num_warps + warp_id; if (group_id < num_groups) { @@ -118,7 +117,6 @@ __device__ void compute_dynamic_per_token_scales( if (thread_in_group == 0 && thread_offset < thread_end) { block_absmax_val_maybe = s_max_vals[threadIdx.x]; - float to_log = block_absmax_val_maybe; float scale = 0.0f; if (scale_ub) { scale = min(block_absmax_val_maybe, *scale_ub); From 294e884f14c299f9235de65e9473a3f0bf6293bb Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Thu, 6 Nov 2025 09:13:08 -0500 Subject: [PATCH 10/27] Vectorized Signed-off-by: ElizaWszola --- ...fused_layernorm_dynamic_per_token_quant.cu | 5 +- .../fused_kernels/layernorm_utils.cuh | 174 +++++++++++++----- 2 files changed, 129 insertions(+), 50 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index e0b2e991fad6..0abee0fac69f 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -103,8 +103,9 @@ __global__ void rms_norm_per_block_quant_kernel( rms + blockIdx.x, input, hidden_size, var_epsilon, residual); // Compute Scale - // TODO: Vectorize this - vllm::compute_dynamic_per_token_scales( + // Always able to vectorize due to constraints on hidden_size and group_size + vllm::vectorized::compute_dynamic_per_token_scales( token_scale, scales, input, weight, rms[blockIdx.x], scale_ub, hidden_size, residual, group_size); diff --git 
a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index ca88faa7cd13..64df224bc1b5 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -273,8 +273,6 @@ __device__ void compute_dynamic_per_token_scales( constexpr scalar_out_t qmax{quant_type_max_v}; const int VEC_SIZE = 4; - int32_t const num_vec_elems = - (group_size > 0 ? group_size : hidden_size) >> 2; float block_absmax_val_maybe = 0.0f; // Vectorized input/weight/residual to better utilize memory bandwidth. @@ -283,76 +281,156 @@ __device__ void compute_dynamic_per_token_scales( vec4_t const* vec_residual = nullptr; if (group_size > 0) { - int64_t const token_block_offset = - blockIdx.x * static_cast(group_size); - int64_t const hidden_element_offset = token_block_offset % hidden_size; - vec_input = - reinterpret_cast const*>(&input[token_block_offset]); - vec_weight = reinterpret_cast const*>( - &weight[hidden_element_offset]); - if constexpr (has_residual) { - vec_residual = reinterpret_cast const*>( - &residual[token_block_offset]); - } - } else { + __shared__ float s_max_vals[1024]; + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + int64_t num_groups = hidden_size / group_size; + int64_t const threads_per_group = blockDim.x / num_groups; + int64_t const thread_in_group = threadIdx.x % threads_per_group; + int64_t const group_offset = + threadIdx.x / threads_per_group * (group_size >> 2); + int64_t const thread_offset = group_offset + thread_in_group; + int64_t const thread_end = min(group_offset + (group_size >> 2), + static_cast(hidden_size >> 2)); vec_input = reinterpret_cast const*>(&input[token_offset]); vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { vec_residual = reinterpret_cast const*>(&residual[token_offset]); } - } + int32_t const num_vec_elems = thread_end; #pragma unroll 4 - for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { - vec4_t in = vec_input[i]; - vec4_t const w = vec_weight[i]; + for (auto i = thread_offset; i < num_vec_elems; i += threads_per_group) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; - vec4_t x; + vec4_t x; #pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] = static_cast(in.val[j]); + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] = static_cast(in.val[j]); + } + + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } + +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); + } + } + + s_max_vals[threadIdx.x] = block_absmax_val_maybe; + __syncthreads(); + + int64_t const warp_size = 32; + int64_t const num_warps = blockDim.x / warp_size; + int64_t const warp_id = threadIdx.x / warp_size; + int64_t const thread_in_warp = threadIdx.x % warp_size; + int64_t const groups_per_warp = (num_groups + num_warps - 1) / num_warps; + for (auto i = 0; i < groups_per_warp; ++i) { + int64_t const group_id = i * num_warps + warp_id; + if (group_id < num_groups) { + int64_t warp_start = group_id * threads_per_group; + int64_t const start = warp_start + thread_in_warp; + int64_t const warp_end = min(warp_start + threads_per_group, + static_cast(hidden_size)); + for (auto j = start; j + warp_size < warp_end; j += warp_size) { + s_max_vals[start] = + fmaxf(s_max_vals[start], 
s_max_vals[j + warp_size]); + } + warpReduceMaxSpecialized(s_max_vals, start, thread_in_warp, + min(warp_end - warp_start, warp_size)); + } + } + __syncthreads(); + + if (thread_in_group == 0 && thread_offset < thread_end) { + block_absmax_val_maybe = s_max_vals[threadIdx.x]; + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = + scale; // Global output store + token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = + scale; } + __syncthreads(); + } else { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + vec_input = reinterpret_cast const*>(&input[token_offset]); + vec_weight = reinterpret_cast const*>(weight); if constexpr (has_residual) { - vec4_t r = vec_residual[i]; + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = (hidden_size >> 2); + +#pragma unroll 4 + for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { - x.val[j] += static_cast(r.val[j]); + x.val[j] = static_cast(in.val[j]); } - } + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; #pragma unroll - for (int j = 0; j < VEC_SIZE; ++j) { - block_absmax_val_maybe = - fmaxf(block_absmax_val_maybe, - fabs(static_cast(x.val[j] * rms) * w.val[j])); + for (int j = 0; j < VEC_SIZE; ++j) { + x.val[j] += static_cast(r.val[j]); + } + } + +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + block_absmax_val_maybe = + fmaxf(block_absmax_val_maybe, + fabs(static_cast(x.val[j] * rms) * w.val[j])); + } } - } - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - block_absmax_val_maybe = - BlockReduce(reduceStore) - .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x); - __shared__ float s_token_scale; - if (threadIdx.x == 0) { - float scale = 0.0f; - if (scale_ub) { - scale = min(block_absmax_val_maybe, *scale_ub); - } else { - scale = block_absmax_val_maybe; + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor::val()); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store } - // token scale computation - scale = max(scale / qmax, min_scaling_factor::val()); - s_token_scale = scale; // shared memory store - all_token_scales[blockIdx.x] = scale; // global output store - } - __syncthreads(); + __syncthreads(); - *token_scale = s_token_scale; + *token_scale = s_token_scale; + } } // hidden_size must be a multiple of 4 From 54ab82fb825cea3ac503cbdeb13b2c89bc5b306e Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 7 Nov 2025 03:11:01 -0500 Subject: [PATCH 11/27] Test group_size=64, add benchmarks Signed-off-by: ElizaWszola --- .../fused_kernels/layernorm_rms_benchmarks.py | 83 ++++++++++++++++++- 
.../core/test_fused_quant_layernorm.py | 2 +- 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index d809bf1db8cb..8591c53d2691 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -14,6 +14,9 @@ import vllm._custom_ops as ops from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) @dataclass @@ -22,6 +25,7 @@ class bench_params_t: hidden_size: int add_residual: bool dtype: torch.dtype + group_size: list[int] def description(self): return ( @@ -29,6 +33,7 @@ def description(self): f"x D {self.hidden_size} " f"x R {self.add_residual} " f"x DT {self.dtype}" + f"x GS {self.group_size}" ) @@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]: HIDDEN_SIZES = list(range(1024, 8129, 1024)) ADD_RESIDUAL = [True, False] DTYPES = [torch.bfloat16, torch.float] + GROUP_SIZES = [[1, 64], [1, 128]] - combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES) bench_params = list( - map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations) + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations) ) return bench_params @@ -52,6 +58,7 @@ def unfused_int8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -69,6 +76,7 @@ def unfused_fp8_impl( x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): # Norm torch_out = None @@ -81,23 +89,57 @@ def unfused_fp8_impl( torch_out, _ = ops.scaled_fp8_quant(torch_out) +def unfused_groupwise_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = per_token_group_quant_fp8( + torch_out, group_size=group_size[1], use_ue8m0=False + ) + + def fused_impl( rms_norm_layer: RMSNorm, # this stores the weights x: torch.Tensor, residual: torch.Tensor | None, quant_dtype: torch.dtype, + group_size: list[int], ): out, _ = ops.rms_norm_dynamic_per_token_quant( x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual ) +def fused_groupwise_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: torch.Tensor | None, + quant_dtype: torch.dtype, + group_size: list[int], +): + out, _ = ops.rms_norm_per_block_quant( + x, rms_norm_layer.weight, 1e-6, quant_dtype, group_size, residual=residual + ) + + # Bench functions def bench_fn( rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, quant_dtype: torch.dtype, + group_size: list[int], label: str, sub_label: str, fn: Callable, @@ -110,10 +152,11 @@ def bench_fn( "x": x, "residual": residual, "quant_dtype": quant_dtype, + "group_size": group_size, "fn": fn, } return TBenchmark.Timer( - stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)", globals=globals, label=label, sub_label=sub_label, @@ -147,6 +190,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> 
Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, unfused_int8_impl, @@ -161,6 +205,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, unfused_fp8_impl, @@ -175,6 +220,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.int8, + params.group_size, label, sub_label, fused_impl, @@ -189,6 +235,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu x, residual, torch.float8_e4m3fn, + params.group_size, label, sub_label, fused_impl, @@ -196,6 +243,36 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu ) ) + # unfused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_groupwise_fp8_impl, + "unfused_groupwise_fp8_impl", + ) + ) + + # fused groupwise fp8 impl. + timers.append( + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + fused_groupwise_impl, + "fused_groupwise_fp8_impl", + ) + ) + print_timers(timers) return timers diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 441271ed6ea3..e715cffd16ec 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -27,7 +27,7 @@ ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] -GROUP_SIZES = [None, [1, 128]] +GROUP_SIZES = [None, [1, 64], [1, 128]] SEEDS = [0] CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] From e00a6d7138c63218c3ad960c286c80420ffdd8fa Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 7 Nov 2025 14:46:33 -0800 Subject: [PATCH 12/27] optimize Signed-off-by: yewentao256 --- .../fused_layernorm_dynamic_per_token_quant.cu | 6 +++--- vllm/_custom_ops.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 0abee0fac69f..cf86468789bb 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -214,15 +214,15 @@ void rms_norm_per_block_quant_dispatch( auto num_tokens = input.numel() / hidden_size; dim3 grid13(num_tokens); - dim3 block13(std::min(hidden_size, 1024)); + dim3 block13(std::min(hidden_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); auto const fp_options = torch::TensorOptions().dtype(torch::kFloat32).device(input.device()); - torch::Tensor rms = torch::zeros({num_tokens}, fp_options); + torch::Tensor rms = torch::empty({num_tokens}, fp_options); torch::Tensor token_scale = - torch::zeros({num_tokens * hidden_size / group_size}, fp_options); + torch::empty({num_tokens * hidden_size / group_size}, fp_options); if (residual.has_value()) { VLLM_DISPATCH_QUANT_TYPES( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0354aa4bbdb6..9079ce53b160 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -422,7 +422,7 @@ def rms_norm_per_block_quant( ) -> tuple[torch.Tensor, torch.Tensor]: assert len(group_size) == 2 output = torch.empty_like(input, dtype=quant_dtype) - scales = 
torch.zeros( + scales = torch.empty( (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]), device=input.device, dtype=torch.float32, From 77e00789dcc129e6a147ab03be62bda9ed0ddd3d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 14 Nov 2025 00:30:53 -0500 Subject: [PATCH 13/27] Add fusion patterns Signed-off-by: ElizaWszola --- tests/compile/test_fusion.py | 59 ++++++-- vllm/compilation/fusion.py | 129 ++++++++++++++++++ vllm/compilation/matcher_utils.py | 26 ++++ .../layers/quantization/utils/fp8_utils.py | 66 ++++----- .../layers/quantization/utils/quant_utils.py | 3 + 5 files changed, 228 insertions(+), 55 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 286f2276367a..6897ee66daa8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -18,6 +18,9 @@ VllmConfig, ) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + W8A8BlockFp8LinearOp, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -25,6 +28,7 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, + cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, ) @@ -44,7 +48,7 @@ def __init__( self, hidden_size: int, eps: float, - static: bool, + group_shape: GroupShape, cuda_force_torch: bool, *args, **kwargs, @@ -52,8 +56,16 @@ def __init__( super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] - group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + if group_shape == GroupShape(1, 128): + self.wscale = [ + torch.rand( + (hidden_size // 128, hidden_size // 128), dtype=torch.float32 + ) + for _ in range(3) + ] + else: + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + static = group_shape == GroupShape.PER_TENSOR quant_scale = ScaleDesc(torch.float32, static, group_shape) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: @@ -61,18 +73,32 @@ def __init__( else: self.scale = [None for _ in range(3)] self.w = [ - torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(3) + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) ] + if group_shape != GroupShape(1, 128): + self.w = [self.w[0].t() for _ in range(3)] - with override_cutlass_fp8_supported(not cuda_force_torch): - self.fp8_linear = Fp8LinearOp( - act_quant_static=static, + if group_shape == GroupShape(1, 128): + self.fp8_linear = W8A8BlockFp8LinearOp( + weight_group_shape=GroupShape(128, 128), act_quant_group_shape=group_shape, + cutlass_block_fp8_supported=cutlass_block_fp8_supported(), + use_aiter_and_is_supported=False, ) + else: + with override_cutlass_fp8_supported(not cuda_force_torch): + self.fp8_linear = Fp8LinearOp( + act_quant_static=static, + act_quant_group_shape=group_shape, + ) self.enable_rms_norm_custom_op = self.norm[0].enabled() - self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + self.enable_quant_fp8_custom_op = ( + self.fp8_linear.quant_fp8.enabled() + if group_shape != GroupShape(1, 128) + else True + ) + self.group_shape = group_shape def forward(self, x): # avoid having graph input be an arg to a pattern directly @@ -119,11 +145,14 @@ def ops_in_model_before_partial(self): ) +GROUP_SHAPES = 
[GroupShape.PER_TOKEN, GroupShape.PER_TENSOR, GroupShape(1, 128)] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -@pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("group_shape", GROUP_SHAPES) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that @@ -139,7 +168,7 @@ def test_fusion_rmsnorm_quant( hidden_size, num_tokens, eps, - static, + group_shape, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, cuda_force_torch, @@ -149,6 +178,9 @@ def test_fusion_rmsnorm_quant( torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths + if not enable_quant_fp8_custom_op and group_shape == GroupShape(1, 128): + pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") + custom_ops = [] if enable_rms_norm_custom_op: custom_ops.append("+rms_norm") @@ -170,8 +202,7 @@ def test_fusion_rmsnorm_quant( backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) backend2 = TestBackend(noop_pass, cleanup_pass) - model = TestModel(hidden_size, eps, static, cuda_force_torch) - + model = TestModel(hidden_size, eps, group_shape, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8f0ad2d69fbe..c50563547b64 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -15,6 +15,7 @@ GroupShape, QuantKey, ScaleDesc, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, @@ -54,6 +55,8 @@ def empty_i32(*args, **kwargs): } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default +if current_platform.is_cuda(): + QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): @@ -86,6 +89,12 @@ def __str__(self): FusedRMSQuantKey( kFp8DynamicTokenSym, True ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, False + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8Dynamic128Sym, True + ): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501 } @@ -214,6 +223,114 @@ def replacement( ) +class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + self.group_shape = group_shape + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) + return result, residual, scale + + def replacement( + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor + ): + # In case we're matching native rms-norm, conversions might be + # optimized 
out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual, + group_size=self.group_shape[1], + ) + + # result, residual, scale + return at[1], at[3], at[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + +class RMSNormGroupQuantPattern(RMSNormQuantPattern): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + symmetric=True, + ): + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + self.group_shape = group_shape + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass): + def pattern(input: torch.Tensor, weight: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) + result, scale = self.quant_matcher(result_rms) + return result, scale + + def replacement(input: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None, + group_size=self.group_shape[1], + ) + + # result, scale + return at[1], at[2] + + pm.register_replacement( + pattern, + replacement, + self.rmsnorm_matcher.inputs(), + pm.fwd_only, + pm_pass, + ) + + class RMSNormDynamicQuantPattern(RMSNormQuantPattern): def __init__( self, @@ -336,6 +453,16 @@ def __init__(self, config: VllmConfig): # Make sure fused add patterns are before simple rms norm, # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: + # Try it before the dynamic one, maybe refator later to run both + # blockwise and per-token from the same pattern? 
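+            # Both group-quant patterns below match rms_norm (with or
+            # without the fused residual add) followed by
+            # torch.ops._C.per_token_group_fp8_quant and rewrite the pair
+            # into a single rms_norm_per_block_quant call. Sketch of the
+            # rewrite for x of shape (T, H) and group_shape (1, 128),
+            # names illustrative only:
+            #   y = rms_norm(x, w); q, s = group_quant(y)   # two kernels
+            #   q, s = rms_norm_per_block_quant(x, w)       # fused kernel
+            # with q of shape (T, H) in fp8 and s of shape (T, H // 128)
+            # holding fp32 scales.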
+ FusedAddRMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) + ).register(self.patterns) + + RMSNormGroupQuantPattern( + epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) + ).register(self.patterns) + # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns @@ -362,9 +489,11 @@ def __call__(self, graph: fx.Graph): def uuid(self) -> Any: return self.hash_source( self, + RMSNormGroupQuantPattern, RMSNormQuantPattern, RMSNormStaticQuantPattern, RMSNormDynamicQuantPattern, FusedAddRMSNormStaticQuantPattern, FusedAddRMSNormDynamicQuantPattern, + FusedAddRMSNormGroupQuantPattern, ) diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 383fe6033a6d..14121afea73f 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -11,8 +11,10 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, _normalize_quant_group_shape, + kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8StaticTensorSym, @@ -32,6 +34,9 @@ if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 +if current_platform.is_cuda(): + QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 + SILU_MUL_OP = torch.ops._C.silu_and_mul.default @@ -171,6 +176,27 @@ def forward_custom( input.shape, device=input.device, dtype=self.quant_key.dtype ) + if self.quant_key.scale.group_shape == GroupShape(1, 128): + assert scale is None + scale = self.make_scale(input) + + finfo = torch.finfo(self.quant_key.dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + _, result, scale = auto_functionalized( + self.QUANT_OP, + input=input, + output_q=result, + output_s=scale, + group_size=self.quant_key.scale.group_shape[1], + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + scale_ue8m0=False, + ) + return result, scale + if self.quant_key.scale.static: assert scale is not None _, result = auto_functionalized( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index f25148abb619..9ffbbb2d3af4 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -372,6 +372,7 @@ def _run_triton( ) -> torch.Tensor: assert self.input_quant_op is not None q_input, input_scale = self.input_quant_op(input_2d) + assert input_scale is not None return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( q_input, weight, @@ -604,15 +605,11 @@ def per_token_group_quant_fp8( assert out_q is None or out_q.shape == x.shape x_q = out_q if x_q is None: - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + x_q = torch.empty(x.shape, device=x.device, dtype=dtype) - # Allocate the scale tensor in either row- or column-major format. - if column_major_scales: - shape = (x.shape[-1] // group_size,) + x.shape[:-1] - x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) - else: - shape = x.shape[:-1] + (x.shape[-1] // group_size,) - x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + # Allocate the scale tensor in row-major format. 
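+    # For a 2-D input of shape (M, K) this yields x_s of shape
+    # (M, K // group_size): one fp32 scale per token per group of K.
+    # Column-major consumers get a permuted view of this same buffer at
+    # the end of the function, so the kernel itself always writes
+    # row-major and the quant op stays matchable by the RMS norm + quant
+    # fusion pass.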
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) # prefer CUDA kernel if available # TODO(bnell): this causes some fp8 moe test to fail. @@ -629,41 +626,28 @@ def per_token_group_quant_fp8( # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + # Permute the scales only at the end of this function to enable quant + # RMS norm fusion if column_major_scales: - _per_token_group_quant_fp8_colmajor[(M,)]( - x, - x_q, - x_s, - group_size, - x.shape[1], - x.stride(0), - x_s.stride(1), - eps, - fp8_min=fp8_min, - fp8_max=fp8_max, - use_ue8m0=use_ue8m0, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) + return x_q, x_s.permute(-1, -2) else: - _per_token_group_quant_fp8[(M,)]( - x, - x_q, - x_s, - group_size, - x.shape[1], - x.stride(0), - eps, - fp8_min=fp8_min, - fp8_max=fp8_max, - use_ue8m0=use_ue8m0, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) - - return x_q, x_s + return x_q, x_s @triton.jit diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d056d3404385..0e4998f4fa4f 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -115,6 +115,9 @@ def __str__(self): kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) +kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128)) +kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True) + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): From c63bb1b2717c6a25c373804a9462deaf61a815fa Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 18 Nov 2025 02:03:45 -0500 Subject: [PATCH 14/27] Account for transposed scales Signed-off-by: ElizaWszola --- vllm/_custom_ops.py | 2 +- vllm/compilation/fusion.py | 18 ++++- vllm/compilation/matcher_utils.py | 32 ++++++--- .../layers/quantization/utils/fp8_utils.py | 65 ++++++++++++------- vllm/utils/deep_gemm.py | 17 +++++ 5 files changed, 99 insertions(+), 35 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d3568e6e7fe0..3a7667a79801 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -439,7 +439,7 @@ def rms_norm_dynamic_per_token_quant( return output, scales -# fused quant layer norm ops bloked +# fused quant layer norm ops blocked def rms_norm_per_block_quant( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 8fb58f185b02..e480443af8a4 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -22,7 +22,11 @@ kNvfp4Quant, kStaticTensorScale, ) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_block_fp8_supported, +) from vllm.platforms import current_platform +from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear_for_nk from .inductor_pass import enable_fake_mode from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm @@ -109,6 +113,16 @@ def 
__init__(self, epsilon: float, key: FusedRMSQuantKey): config = get_current_vllm_config() self.model_dtype = config.model_config.dtype if config.model_config else None + # groupwise FP8 linear uses col major scales if deepgemm and cutlass + use_col_major_scales = ( + should_use_deepgemm_for_fp8_linear_for_nk( + self.model_dtype, + config.model_config.hf_config.intermediate_size, + config.model_config.hf_config.hidden_size, + ) + or cutlass_block_fp8_supported() + ) + assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] @@ -117,7 +131,9 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey): if not key.fused_add else MatcherFusedAddRMSNorm(epsilon) ) - self.quant_matcher = MatcherQuantFP8(key.quant) + self.quant_matcher = MatcherQuantFP8( + key.quant, use_col_major_scales=use_col_major_scales + ) class RMSNormStaticQuantPattern(RMSNormQuantPattern): diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 909ae1ba5c20..04c83fb2d7db 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -231,12 +231,18 @@ def forward_native( class MatcherQuantFP8(MatcherCustomOp): - def __init__(self, quant_key: QuantKey, enabled: bool | None = None): + def __init__( + self, + quant_key: QuantKey, + enabled: bool | None = None, + use_col_major_scales: bool = False, + ): if enabled is None: enabled = QuantFP8.enabled() super().__init__(enabled) self.quant_key = quant_key + self.use_col_major_scales = use_col_major_scales assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" self.QUANT_OP = QUANT_OPS[quant_key] @@ -257,7 +263,7 @@ def forward_custom( if self.quant_key.scale.group_shape == GroupShape(1, 128): assert scale is None - scale = self.make_scale(input) + scale = self.make_scale(input, transposed=self.use_col_major_scales) finfo = torch.finfo(self.quant_key.dtype) fp8_min = finfo.min @@ -297,16 +303,24 @@ def forward_native( ) -> tuple[torch.Tensor, torch.Tensor]: return self.quant_fp8(input, scale) - def make_scale(self, input: torch.Tensor): + def make_scale(self, input: torch.Tensor, transposed: bool = False): normalized_group_shape = _normalize_quant_group_shape( input, self.quant_key.scale.group_shape ) - scale_shape = ( - input.shape[0] // normalized_group_shape[0], - input.shape[1] // normalized_group_shape[1], - ) - - return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + if transposed: + scale_shape = ( + input.shape[1] // normalized_group_shape[1], + input.shape[0] // normalized_group_shape[0], + ) + return torch.empty( + scale_shape, device=input.device, dtype=torch.float32 + ).permute(-1, -2) + else: + scale_shape = ( + input.shape[0] // normalized_group_shape[0], + input.shape[1] // normalized_group_shape[1], + ) + return torch.empty(scale_shape, device=input.device, dtype=torch.float32) def inputs(self) -> list[torch.Tensor]: input = self.empty(5, 16) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index b68b5cfa0b10..eaf56f0ef417 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -363,7 +363,6 @@ def _run_triton( assert input_scale is None assert self.input_quant_op is not None q_input, input_scale = self.input_quant_op(input_2d) - assert input_scale is not None return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( q_input, weight, @@ -599,9 +598,13 
@@ def per_token_group_quant_fp8( if x_q is None: x_q = torch.empty(x.shape, device=x.device, dtype=dtype) - # Allocate the scale tensor in row-major format. - shape = x.shape[:-1] + (x.shape[-1] // group_size,) - x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + # Allocate the scale tensor in either row- or column-major format. + if column_major_scales: + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size,) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) # prefer CUDA kernel if available # TODO(bnell): this causes some fp8 moe test to fail. @@ -618,28 +621,42 @@ def per_token_group_quant_fp8( # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - _per_token_group_quant_fp8[(M,)]( - x, - x_q, - x_s, - group_size, - x.shape[1], - x.stride(0), - eps, - fp8_min=fp8_min, - fp8_max=fp8_max, - use_ue8m0=use_ue8m0, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) - - # Permute the scales only at the end of this function to enable quant - # RMS norm fusion if column_major_scales: - return x_q, x_s.permute(-1, -2) + x_s = x_s.permute(-1, -2) + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) else: - return x_q, x_s + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s @triton.jit diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index b5ab37534dd7..f2e6aba4be21 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -373,6 +373,22 @@ def should_use_deepgemm_for_fp8_linear( ) +def should_use_deepgemm_for_fp8_linear_for_nk( + output_dtype: torch.dtype, + shape0: int, + shape1: int, + supports_deep_gemm: bool | None = None, +): + if supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return ( + supports_deep_gemm + and output_dtype == torch.bfloat16 + and shape0 % 128 == 0 + and shape1 % 128 == 0 + ) + + __all__ = [ "calc_diff", "fp8_gemm_nt", @@ -386,6 +402,7 @@ def should_use_deepgemm_for_fp8_linear( "is_deep_gemm_supported", "get_num_sms", "should_use_deepgemm_for_fp8_linear", + "should_use_deepgemm_for_fp8_linear_for_nk", "get_col_major_tma_aligned_tensor", "get_mk_alignment_for_contiguous_layout", ] From e8c55638dff79da69cd3da5895226d1949fe3f61 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 18 Nov 2025 02:06:37 -0500 Subject: [PATCH 15/27] Cleanup fallback code Signed-off-by: ElizaWszola --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index eaf56f0ef417..5dc9c1d47b4a 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -622,7 +622,6 @@ def per_token_group_quant_fp8( num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 if column_major_scales: - x_s = x_s.permute(-1, -2) _per_token_group_quant_fp8_colmajor[(M,)]( x, x_q, From 
c745e911d9ef63d9687f1f7edb9a5bb47a1ab382 Mon Sep 17 00:00:00 2001
From: ElizaWszola 
Date: Wed, 19 Nov 2025 08:10:18 -0500
Subject: [PATCH 16/27] Cleanup comments, var names

Signed-off-by: ElizaWszola 

---
 .../fused_layernorm_dynamic_per_token_quant.cu | 8 ++++----
 vllm/compilation/fusion.py                     | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
index cf86468789bb..72b8f84bcd86 100644
--- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
@@ -213,8 +213,8 @@ void rms_norm_per_block_quant_dispatch(
   int32_t hidden_size = input.size(-1);
   auto num_tokens = input.numel() / hidden_size;
 
-  dim3 grid13(num_tokens);
-  dim3 block13(std::min(hidden_size, 512));
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 512));
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -228,7 +228,7 @@ void rms_norm_per_block_quant_dispatch(
     VLLM_DISPATCH_QUANT_TYPES(
         out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
           vllm::rms_norm_per_block_quant_kernel<scalar_in_t, scalar_t, true>
-              <<<grid13, block13, 0, stream>>>(
+              <<<grid, block, 0, stream>>>(
                   rms.data_ptr<float>(), out.data_ptr<scalar_t>(),
                   scales.data_ptr<float>(), input.data_ptr<scalar_in_t>(),
                   weight.data_ptr<scalar_in_t>(), token_scale.data_ptr<float>(),
@@ -240,7 +240,7 @@ void rms_norm_per_block_quant_dispatch(
     VLLM_DISPATCH_QUANT_TYPES(
         out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
           vllm::rms_norm_per_block_quant_kernel<scalar_in_t, scalar_t, false>
-              <<<grid13, block13, 0, stream>>>(
+              <<<grid, block, 0, stream>>>(
                   rms.data_ptr<float>(), out.data_ptr<scalar_t>(),
                   scales.data_ptr<float>(), input.data_ptr<scalar_in_t>(),
                   weight.data_ptr<scalar_in_t>(), token_scale.data_ptr<float>(),
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index e480443af8a4..e25246b6ceba 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -473,12 +473,12 @@ def __init__(self, config: VllmConfig):
         # Make sure fused add patterns are before simple rms norm,
         # as the latter is a subset of the former in torch ops
         for epsilon in [1e-5, 1e-6]:
-            # Try it before the dynamic one, maybe refator later to run both
-            # blockwise and per-token from the same pattern?
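+            # The blockwise (1x128) patterns stay registered ahead of the
+            # per-token and per-tensor ones: all of them start matching from
+            # the same rms_norm node, and registration order decides which
+            # rewrite is tried first (the same reasoning as the fused-add
+            # ordering noted above).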
+ # Fuse fused_add_rms_norm + fp8 group quant FusedAddRMSNormGroupQuantPattern( epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) ).register(self.patterns) + # Fuse rms_norm + fp8 group quant RMSNormGroupQuantPattern( epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128) ).register(self.patterns) From 949db4d71b0c65285e290937b8b0396bfedfd8cb Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 21 Nov 2025 02:27:15 -0500 Subject: [PATCH 17/27] Transpose scales if needed Signed-off-by: ElizaWszola --- csrc/ops.h | 2 +- ...fused_layernorm_dynamic_per_token_quant.cu | 93 +++++++++++++------ .../fused_kernels/layernorm_utils.cuh | 28 ++++-- csrc/torch_bindings.cpp | 3 +- .../core/test_fused_quant_layernorm.py | 2 +- vllm/_custom_ops.py | 11 ++- vllm/compilation/fusion.py | 10 +- 7 files changed, 105 insertions(+), 44 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 93b1a5acf7f5..e92ee122035d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -134,7 +134,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, double const epsilon, std::optional scale_ub, std::optional residual, - int64_t group_size); + int64_t group_size, bool is_scale_transposed); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 72b8f84bcd86..bf5b21e51a23 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -87,7 +87,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( } // RMS norm + quant kernel -template +template __global__ void rms_norm_per_block_quant_kernel( float* __restrict__ rms, scalar_out_t* __restrict__ out, // [..., hidden_size] @@ -104,8 +105,8 @@ __global__ void rms_norm_per_block_quant_kernel( // Compute Scale // Always able to vectorize due to constraints on hidden_size and group_size - vllm::vectorized::compute_dynamic_per_token_scales( + vllm::vectorized::compute_dynamic_per_token_scales< + scalar_t, scalar_out_t, has_residual, is_scale_transposed>( token_scale, scales, input, weight, rms[blockIdx.x], scale_ub, hidden_size, residual, group_size); @@ -206,10 +207,12 @@ void rms_norm_per_block_quant_dispatch( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] torch::Tensor const& weight, // [hidden_size] - torch::Tensor& scales, // [num_tokens, hidden_size / group_size] + torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or + // [hidden_size / group_size, num_tokens] double const var_epsilon, // Variance epsilon used in norm calculation std::optional const& scale_ub, - std::optional& residual, int64_t group_size) { + std::optional& residual, int64_t group_size, + bool is_scale_transposed) { int32_t hidden_size = input.size(-1); auto num_tokens = input.numel() / hidden_size; @@ -225,28 +228,58 @@ void rms_norm_per_block_quant_dispatch( torch::empty({num_tokens * hidden_size / group_size}, fp_options); if (residual.has_value()) { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel - <<>>( - rms.data_ptr(), out.data_ptr(), - scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), token_scale.data_ptr(), - scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr(), - group_size); - }); + if (is_scale_transposed) { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel< + scalar_in_t, scalar_t, true, true><<>>( + rms.data_ptr(), out.data_ptr(), + scales.data_ptr(), input.data_ptr(), + weight.data_ptr(), token_scale.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, hidden_size, residual->data_ptr(), + group_size); + }); + } else { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel< + scalar_in_t, scalar_t, true, false><<>>( + rms.data_ptr(), out.data_ptr(), + scales.data_ptr(), input.data_ptr(), + weight.data_ptr(), token_scale.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, hidden_size, residual->data_ptr(), + group_size); + }); + } } else { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel - <<>>( - rms.data_ptr(), out.data_ptr(), - scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), token_scale.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr, group_size); - }); + if (is_scale_transposed) { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel< + scalar_in_t, scalar_t, false, true><<>>( + rms.data_ptr(), out.data_ptr(), + scales.data_ptr(), input.data_ptr(), + weight.data_ptr(), token_scale.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, hidden_size, nullptr, group_size); + }); + } else { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel + <<>>( + rms.data_ptr(), out.data_ptr(), + scales.data_ptr(), input.data_ptr(), + weight.data_ptr(), + token_scale.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, nullptr, group_size); + }); + } } } @@ -255,7 +288,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, double const var_epsilon, std::optional scale_ub, std::optional residual, - int64_t group_size) { + int64_t group_size, bool is_scale_transposed) { static c10::ScalarType kFp8Type = is_fp8_ocp() ? 
c10::ScalarType::Float8_e4m3fn : c10::ScalarType::Float8_e4m3fnuz; @@ -273,8 +306,8 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] { - rms_norm_per_block_quant_dispatch(out, input, weight, scales, - var_epsilon, scale_ub, - residual, group_size); + rms_norm_per_block_quant_dispatch( + out, input, weight, scales, var_epsilon, scale_ub, residual, + group_size, is_scale_transposed); }); } \ No newline at end of file diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 64df224bc1b5..e92504b8e19b 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -62,7 +62,8 @@ __device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, return val[tid]; } -template +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, @@ -125,9 +126,14 @@ __device__ void compute_dynamic_per_token_scales( } // token scale computation scale = max(scale / qmax, min_scaling_factor::val()); - all_token_scales[blockIdx.x * num_groups + - threadIdx.x / threads_per_group] = - scale; // Global output store + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; } @@ -263,7 +269,8 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, // Vectorized version of vllm::compute_dynamic_per_token_scales // hidden_size must be a multiple of 4 -template +template __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, @@ -362,9 +369,14 @@ __device__ void compute_dynamic_per_token_scales( } // token scale computation scale = max(scale / qmax, min_scaling_factor::val()); - all_token_scales[blockIdx.x * num_groups + - threadIdx.x / threads_per_group] = - scale; // Global output store + // Global output store + if constexpr (is_scale_transposed) { + all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + + blockIdx.x] = scale; + } else { + all_token_scales[blockIdx.x * num_groups + + threadIdx.x / threads_per_group] = scale; + } token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8396fe0c444e..1e2f693b14bd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -232,7 +232,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "rms_norm_per_block_quant(Tensor! result, Tensor input, " "Tensor weight, Tensor! scale, float epsilon, " - "Tensor? scale_ub, Tensor!? residual, int group_size) -> ()"); + "Tensor? scale_ub, Tensor!? 
residual, int group_size, " + "bool is_scale_transposed) -> ()"); ops.impl("rms_norm_per_block_quant", torch::kCUDA, &rms_norm_per_block_quant); // Rotary embedding diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index e715cffd16ec..1dad7692c4d1 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -114,7 +114,7 @@ def ops_dynamic_per_token_or_block_quant( residual = residual.clone() if group_size is not None: out, scales = ops.rms_norm_per_block_quant( - x, weight, EPS, quant_dtype, group_size, scale_ub, residual + x, weight, EPS, quant_dtype, group_size, scale_ub, residual, False ) else: out, scales = ops.rms_norm_dynamic_per_token_quant( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3a7667a79801..dbecbfff2234 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -448,6 +448,7 @@ def rms_norm_per_block_quant( group_size: list[int], scale_ub: torch.Tensor | None = None, residual: torch.Tensor | None = None, + is_scale_transposed: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: assert len(group_size) == 2 output = torch.empty_like(input, dtype=quant_dtype) @@ -458,7 +459,15 @@ def rms_norm_per_block_quant( ) torch.ops._C.rms_norm_per_block_quant( - output, input, weight, scales, epsilon, scale_ub, residual, group_size[1] + output, + input, + weight, + scales, + epsilon, + scale_ub, + residual, + group_size[1], + is_scale_transposed, ) return output, scales diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e25246b6ceba..c9e1db508f94 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -273,7 +273,9 @@ def replacement( input = input.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) - scale = self.quant_matcher.make_scale(input) + scale = self.quant_matcher.make_scale( + input, transposed=self.quant_matcher.use_col_major_scales + ) at = auto_functionalized( self.FUSED_OP, result=result, @@ -284,6 +286,7 @@ def replacement( scale_ub=None, residual=residual, group_size=self.group_shape[1], + is_scale_transposed=self.quant_matcher.use_col_major_scales, ) # result, residual, scale @@ -326,7 +329,9 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): input = input.to(dtype=self.model_dtype) result = torch.empty_like(input, dtype=self.quant_dtype) - scale = self.quant_matcher.make_scale(input) + scale = self.quant_matcher.make_scale( + input, transposed=self.quant_matcher.use_col_major_scales + ) at = auto_functionalized( self.FUSED_OP, result=result, @@ -337,6 +342,7 @@ def replacement(input: torch.Tensor, weight: torch.Tensor): scale_ub=None, residual=None, group_size=self.group_shape[1], + is_scale_transposed=self.quant_matcher.use_col_major_scales, ) # result, scale From e151ea74091478313cf72459c4390fefb21a9778 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 26 Nov 2025 10:48:56 -0500 Subject: [PATCH 18/27] Fix redundant write to scales, write to transposed scales too Signed-off-by: ElizaWszola --- ...fused_layernorm_dynamic_per_token_quant.cu | 28 ++++++++-------- .../fused_kernels/layernorm_utils.cuh | 33 +++++++++++++------ 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index bf5b21e51a23..143d061ae87d 100644 --- 
a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -93,10 +93,12 @@ __global__ void rms_norm_per_block_quant_kernel( float* __restrict__ rms, scalar_out_t* __restrict__ out, // [..., hidden_size] float* __restrict__ scales, // [num_tokens, hidden_size / group_size] + // or + // [hidden_size / group_size, num_tokens] scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] - float* __restrict__ token_scale, float const* scale_ub, - float const var_epsilon, int32_t const hidden_size, + float* __restrict__ token_scale, // unused + float const* scale_ub, float const var_epsilon, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { // Compute RMS // Always able to vectorize due to constraints on hidden_size @@ -107,21 +109,20 @@ __global__ void rms_norm_per_block_quant_kernel( // Always able to vectorize due to constraints on hidden_size and group_size vllm::vectorized::compute_dynamic_per_token_scales< scalar_t, scalar_out_t, has_residual, is_scale_transposed>( - token_scale, scales, input, weight, rms[blockIdx.x], scale_ub, - hidden_size, residual, group_size); + nullptr, scales, input, weight, rms[blockIdx.x], scale_ub, hidden_size, + residual, group_size); // RMS Norm + Quant // Always able to vectorize due to constraints on hidden_size - int token_idx = blockIdx.x * hidden_size / group_size; // For int8, don't invert token_scale here: do it inside the norm_and_quant // kernel. We do it because particular elements of token_scale can be shared // between multiple threads, so this way, we avoid extra synchronization // overhead. vllm::vectorized::norm_and_quant, - has_residual>( - out, input, weight, rms[blockIdx.x], token_scale + token_idx, hidden_size, - residual, group_size); + has_residual, is_scale_transposed>( + out, input, weight, rms[blockIdx.x], scales, hidden_size, residual, + group_size); } } // namespace vllm @@ -224,8 +225,6 @@ void rms_norm_per_block_quant_dispatch( auto const fp_options = torch::TensorOptions().dtype(torch::kFloat32).device(input.device()); torch::Tensor rms = torch::empty({num_tokens}, fp_options); - torch::Tensor token_scale = - torch::empty({num_tokens * hidden_size / group_size}, fp_options); if (residual.has_value()) { if (is_scale_transposed) { @@ -235,7 +234,7 @@ void rms_norm_per_block_quant_dispatch( scalar_in_t, scalar_t, true, true><<>>( rms.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), token_scale.data_ptr(), + weight.data_ptr(), nullptr, scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, residual->data_ptr(), group_size); @@ -247,7 +246,7 @@ void rms_norm_per_block_quant_dispatch( scalar_in_t, scalar_t, true, false><<>>( rms.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), token_scale.data_ptr(), + weight.data_ptr(), nullptr, scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, residual->data_ptr(), group_size); @@ -261,7 +260,7 @@ void rms_norm_per_block_quant_dispatch( scalar_in_t, scalar_t, false, true><<>>( rms.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), token_scale.data_ptr(), + weight.data_ptr(), nullptr, scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, nullptr, group_size); }); @@ -273,8 +272,7 @@ void rms_norm_per_block_quant_dispatch( <<>>( rms.data_ptr(), out.data_ptr(), scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), - token_scale.data_ptr(), + weight.data_ptr(), nullptr, scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, nullptr, group_size); diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index e92504b8e19b..68787d9c731a 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -134,8 +134,6 @@ __device__ void compute_dynamic_per_token_scales( all_token_scales[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; } - token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = - scale; } __syncthreads(); } else { @@ -196,10 +194,18 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, x = static_cast(static_cast(x * rms) * weight[i]); // Quant // If groupwise is_scale_inverted is true, so we invert the scale here. + int64_t scale_idx = 0; + if (group_size > 0) { + if constexpr (is_scale_transposed) { + scale_idx = (i / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * num_groups + i / group_size; + } + } auto scale_val = - (group_size > 0 ? (is_scale_inverted ? 1.0f / scale[i / group_size] - : scale[i / group_size]) - : *scale); + (group_size > 0 + ? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx]) + : *scale); output[token_offset + i] = ScaledQuant::quant_fn(x, scale_val); } @@ -377,8 +383,6 @@ __device__ void compute_dynamic_per_token_scales( all_token_scales[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = scale; } - token_scale[blockIdx.x * num_groups + threadIdx.x / threads_per_group] = - scale; } __syncthreads(); @@ -447,7 +451,7 @@ __device__ void compute_dynamic_per_token_scales( // hidden_size must be a multiple of 4 template + bool has_residual = false, bool is_scale_transposed = false> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, @@ -501,10 +505,19 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, q8x4_t out; + int64_t num_groups = hidden_size / group_size; + int64_t scale_idx = 0; + if (group_size > 0) { + if constexpr (is_scale_transposed) { + scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; + } else { + scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size; + } + } + auto scale_val = (group_size > 0 - ? (is_scale_inverted ? 1.0f / scale[i * VEC_SIZE / group_size] - : scale[i * VEC_SIZE / group_size]) + ? (is_scale_inverted ? 
1.0f / scale[scale_idx] : scale[scale_idx]) : *scale); #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { From 83648871b4c94ded0f99045e662398e9f17292eb Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 26 Nov 2025 12:08:41 -0500 Subject: [PATCH 19/27] Fix build Signed-off-by: ElizaWszola --- csrc/quantization/fused_kernels/layernorm_utils.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 68787d9c731a..c38968696e46 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -174,7 +174,7 @@ __device__ void compute_dynamic_per_token_scales( } template + bool has_residual = false, bool is_scale_transposed = false> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, @@ -199,7 +199,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, if constexpr (is_scale_transposed) { scale_idx = (i / group_size) * gridDim.x + blockIdx.x; } else { - scale_idx = blockIdx.x * num_groups + i / group_size; + scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size; } } auto scale_val = From ee2a354cf64cd1e2393a24b4b227c99e9b8c5f97 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 1 Dec 2025 07:09:51 +0000 Subject: [PATCH 20/27] Keep rms in shared memory in kernels Signed-off-by: ElizaWszola --- ...fused_layernorm_dynamic_per_token_quant.cu | 39 ++++++---------- .../fused_kernels/layernorm_utils.cuh | 44 +++++++++++++------ 2 files changed, 44 insertions(+), 39 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 143d061ae87d..8a4190b7383d 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -90,27 +90,26 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( template __global__ void rms_norm_per_block_quant_kernel( - float* __restrict__ rms, scalar_out_t* __restrict__ out, // [..., hidden_size] float* __restrict__ scales, // [num_tokens, hidden_size / group_size] // or // [hidden_size / group_size, num_tokens] scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] - float* __restrict__ token_scale, // unused float const* scale_ub, float const var_epsilon, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { + __shared__ float s_rms; // Compute RMS // Always able to vectorize due to constraints on hidden_size vllm::vectorized::compute_rms( - rms + blockIdx.x, input, hidden_size, var_epsilon, residual); + &s_rms, input, hidden_size, var_epsilon, residual); // Compute Scale // Always able to vectorize due to constraints on hidden_size and group_size vllm::vectorized::compute_dynamic_per_token_scales< scalar_t, scalar_out_t, has_residual, is_scale_transposed>( - nullptr, scales, input, weight, rms[blockIdx.x], scale_ub, hidden_size, - residual, group_size); + nullptr, scales, input, weight, s_rms, scale_ub, hidden_size, residual, + group_size); // RMS Norm + Quant // Always able to vectorize due to constraints on hidden_size @@ -121,8 +120,7 @@ __global__ void rms_norm_per_block_quant_kernel( vllm::vectorized::norm_and_quant, has_residual, 
is_scale_transposed>( - out, input, weight, rms[blockIdx.x], scales, hidden_size, residual, - group_size); + out, input, weight, s_rms, scales, hidden_size, residual, group_size); } } // namespace vllm @@ -201,8 +199,6 @@ void rms_norm_dynamic_per_token_quant( } // Residual add + RMS norm + dynamic per token -// TODO think up better names than kernel_1, kernel_2, kernel_3, cleanup args -// TODO vectorized kernels template void rms_norm_per_block_quant_dispatch( torch::Tensor& out, // [..., hidden_size] @@ -222,19 +218,14 @@ void rms_norm_per_block_quant_dispatch( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - auto const fp_options = - torch::TensorOptions().dtype(torch::kFloat32).device(input.device()); - torch::Tensor rms = torch::empty({num_tokens}, fp_options); - if (residual.has_value()) { if (is_scale_transposed) { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { vllm::rms_norm_per_block_quant_kernel< scalar_in_t, scalar_t, true, true><<>>( - rms.data_ptr(), out.data_ptr(), - scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), nullptr, + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, residual->data_ptr(), group_size); @@ -244,9 +235,8 @@ void rms_norm_per_block_quant_dispatch( out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { vllm::rms_norm_per_block_quant_kernel< scalar_in_t, scalar_t, true, false><<>>( - rms.data_ptr(), out.data_ptr(), - scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), nullptr, + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, residual->data_ptr(), group_size); @@ -258,9 +248,8 @@ void rms_norm_per_block_quant_dispatch( out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { vllm::rms_norm_per_block_quant_kernel< scalar_in_t, scalar_t, false, true><<>>( - rms.data_ptr(), out.data_ptr(), - scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), nullptr, + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, nullptr, group_size); }); @@ -270,9 +259,9 @@ void rms_norm_per_block_quant_dispatch( vllm::rms_norm_per_block_quant_kernel <<>>( - rms.data_ptr(), out.data_ptr(), - scales.data_ptr(), input.data_ptr(), - weight.data_ptr(), nullptr, + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, var_epsilon, hidden_size, nullptr, group_size); diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index c38968696e46..c8d194149ff6 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -16,7 +16,8 @@ namespace vllm { template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, - scalar_t const* __restrict__ residual = nullptr) { + scalar_t const* __restrict__ residual = nullptr, + int32_t const group_size = 0) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // sum of squares float ss = 0.0f; @@ -34,13 +35,20 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __shared__ typename BlockReduce::TempStorage reduceStore; ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); - __shared__ float s_rms; - if (threadIdx.x == 0) { - s_rms = rsqrtf(ss / hidden_size + epsilon); - } - __syncthreads(); + if (group_size > 0) { + if (threadIdx.x == 0) { + *rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + } else { + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); - *rms = s_rms; + *rms = s_rms; + } } // TODO replace 32 with WARP_SIZE @@ -218,7 +226,8 @@ namespace vectorized { template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, - scalar_t const* __restrict__ residual = nullptr) { + scalar_t const* __restrict__ residual = nullptr, + int32_t const group_size = 0) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output to better utilize memory bandwidth. 
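// For reference, a host-side sketch of what compute_rms produces for one
// token; reference_inv_rms is a hypothetical helper (not part of this patch),
// assuming fp32 input and the optional residual add:
#include <cmath>
inline float reference_inv_rms(const float* x, const float* residual,
                               int hidden_size, float epsilon) {
  float ss = 0.0f;  // sum of squares over the hidden dimension
  for (int i = 0; i < hidden_size; ++i) {
    float v = x[i] + (residual ? residual[i] : 0.0f);
    ss += v * v;
  }
  // The kernel stores 1/rms (via rsqrtf), so normalization is a multiply.
  return 1.0f / std::sqrt(ss / hidden_size + epsilon);
}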
@@ -264,13 +273,20 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __shared__ typename BlockReduce::TempStorage reduceStore; ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); - __shared__ float s_rms; - if (threadIdx.x == 0) { - s_rms = rsqrtf(ss / hidden_size + epsilon); - } - __syncthreads(); + if (group_size > 0) { + if (threadIdx.x == 0) { + *rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + } else { + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); - *rms = s_rms; + *rms = s_rms; + } } // Vectorized version of vllm::compute_dynamic_per_token_scales From e2b82b1acfba7419d313e8c354f5cc00a248b669 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 1 Dec 2025 08:19:45 +0000 Subject: [PATCH 21/27] Constexpr group size Signed-off-by: ElizaWszola --- ...fused_layernorm_dynamic_per_token_quant.cu | 100 +++++++++++------- .../fused_kernels/layernorm_utils.cuh | 39 +++---- 2 files changed, 80 insertions(+), 59 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 8a4190b7383d..cdb5d81dfbac 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -88,7 +88,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( // RMS norm + quant kernel template + bool is_scale_transposed = false, int32_t group_size = 0> __global__ void rms_norm_per_block_quant_kernel( scalar_out_t* __restrict__ out, // [..., hidden_size] float* __restrict__ scales, // [num_tokens, hidden_size / group_size] @@ -97,19 +97,18 @@ __global__ void rms_norm_per_block_quant_kernel( scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) { + scalar_t* __restrict__ residual = nullptr) { __shared__ float s_rms; // Compute RMS // Always able to vectorize due to constraints on hidden_size - vllm::vectorized::compute_rms( + vllm::vectorized::compute_rms( &s_rms, input, hidden_size, var_epsilon, residual); // Compute Scale // Always able to vectorize due to constraints on hidden_size and group_size vllm::vectorized::compute_dynamic_per_token_scales< - scalar_t, scalar_out_t, has_residual, is_scale_transposed>( - nullptr, scales, input, weight, s_rms, scale_ub, hidden_size, residual, - group_size); + scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( + nullptr, scales, input, weight, s_rms, scale_ub, hidden_size, residual); // RMS Norm + Quant // Always able to vectorize due to constraints on hidden_size @@ -117,10 +116,10 @@ __global__ void rms_norm_per_block_quant_kernel( // kernel. We do it because particular elements of token_scale can be shared // between multiple threads, so this way, we avoid extra synchronization // overhead. 
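// A minimal sketch (not part of this patch) of the net effect of the two
// scaling modes selected by is_scale_inverted; apply_scale is a hypothetical
// helper. The int8 path pre-inverts the per-group scale so each element needs
// only a multiply; the fp8 path divides by the stored scale.
inline float apply_scale(float x, float scale_val, bool is_scale_inverted) {
  // Inverted: scale_val already holds 1/s; non-inverted: scale_val holds s.
  return is_scale_inverted ? x * scale_val : x / scale_val;
}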
- vllm::vectorized::norm_and_quant, - has_residual, is_scale_transposed>( - out, input, weight, s_rms, scales, hidden_size, residual, group_size); + vllm::vectorized::norm_and_quant< + scalar_t, scalar_out_t, std::is_same_v, + has_residual, is_scale_transposed, group_size>( + out, input, weight, s_rms, scales, hidden_size, residual); } } // namespace vllm @@ -199,7 +198,7 @@ void rms_norm_dynamic_per_token_quant( } // Residual add + RMS norm + dynamic per token -template +template void rms_norm_per_block_quant_dispatch( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] @@ -208,8 +207,7 @@ void rms_norm_per_block_quant_dispatch( // [hidden_size / group_size, num_tokens] double const var_epsilon, // Variance epsilon used in norm calculation std::optional const& scale_ub, - std::optional& residual, int64_t group_size, - bool is_scale_transposed) { + std::optional& residual, bool is_scale_transposed) { int32_t hidden_size = input.size(-1); auto num_tokens = input.numel() / hidden_size; @@ -222,49 +220,58 @@ void rms_norm_per_block_quant_dispatch( if (is_scale_transposed) { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel< - scalar_in_t, scalar_t, true, true><<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr(), - group_size); + vllm::rms_norm_per_block_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, + residual->data_ptr()); }); } else { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel< - scalar_in_t, scalar_t, true, false><<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr(), - group_size); + vllm::rms_norm_per_block_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, + residual->data_ptr()); }); } } else { if (is_scale_transposed) { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel< - scalar_in_t, scalar_t, false, true><<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr, group_size); + vllm::rms_norm_per_block_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, nullptr); }); } else { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { vllm::rms_norm_per_block_quant_kernel + false, group_size> <<>>( out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr, group_size); + var_epsilon, hidden_size, nullptr); }); } } @@ -291,10 +298,21 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, TORCH_CHECK(residual->scalar_type() == input.scalar_type()); } - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] { - rms_norm_per_block_quant_dispatch( - out, input, weight, scales, var_epsilon, scale_ub, residual, - group_size, is_scale_transposed); - }); + if (group_size == 128) { + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] { + rms_norm_per_block_quant_dispatch( + out, input, weight, scales, var_epsilon, scale_ub, residual, + is_scale_transposed); + }); + } else if (group_size == 64) { + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] { + rms_norm_per_block_quant_dispatch( + out, input, weight, scales, var_epsilon, scale_ub, residual, + is_scale_transposed); + }); + } else { + TORCH_CHECK(false, "Unsupported group size: ", group_size); + } } \ No newline at end of file diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index c8d194149ff6..8136676eb528 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -223,11 +223,10 @@ namespace vectorized { // Compute 1.0/rms(input) // hidden_size must be a multiple of 4 -template +template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, - scalar_t const* __restrict__ residual = nullptr, - int32_t const group_size = 0) { + scalar_t const* __restrict__ residual = nullptr) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output to better utilize memory bandwidth. 
@@ -273,7 +272,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __shared__ typename BlockReduce::TempStorage reduceStore; ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); - if (group_size > 0) { + if constexpr (group_size > 0) { if (threadIdx.x == 0) { *rms = rsqrtf(ss / hidden_size + epsilon); } @@ -292,13 +291,13 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, // Vectorized version of vllm::compute_dynamic_per_token_scales // hidden_size must be a multiple of 4 template + bool is_scale_transposed = false, int32_t group_size = 0> __device__ void compute_dynamic_per_token_scales( float* __restrict__ token_scale, float* __restrict__ all_token_scales, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float const* __restrict__ scale_ub, - int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, - int32_t const group_size = 0) { + int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { constexpr scalar_out_t qmax{quant_type_max_v}; const int VEC_SIZE = 4; @@ -309,11 +308,11 @@ __device__ void compute_dynamic_per_token_scales( vec4_t const* vec_weight = nullptr; vec4_t const* vec_residual = nullptr; - if (group_size > 0) { + if constexpr (group_size > 0) { __shared__ float s_max_vals[1024]; int64_t const token_offset = blockIdx.x * static_cast(hidden_size); - int64_t num_groups = hidden_size / group_size; + int64_t const num_groups = hidden_size / group_size; int64_t const threads_per_group = blockDim.x / num_groups; int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const group_offset = @@ -467,14 +466,14 @@ __device__ void compute_dynamic_per_token_scales( // hidden_size must be a multiple of 4 template + bool has_residual = false, bool is_scale_transposed = false, + int32_t group_size = 0> __device__ void norm_and_quant(scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, float const rms, float* const scale, int32_t const hidden_size, - scalar_t* __restrict__ residual = nullptr, - int32_t const group_size = 0) { + scalar_t* __restrict__ residual = nullptr) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // Vectorized input/output/weight/residual to better utilize memory bandwidth. @@ -521,9 +520,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, q8x4_t out; - int64_t num_groups = hidden_size / group_size; + int64_t const num_groups = hidden_size / group_size; int64_t scale_idx = 0; - if (group_size > 0) { + if constexpr (group_size > 0) { if constexpr (is_scale_transposed) { scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; } else { @@ -531,10 +530,14 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, } } - auto scale_val = - (group_size > 0 - ? (is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx]) - : *scale); + float scale_val; + + if constexpr (group_size > 0) { + scale_val = + is_scale_inverted ? 
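// A minimal sketch (not part of this patch) of the two scale layouts the
// groupwise kernels index into; scale_index is a hypothetical helper,
// assuming one token per thread block (num_tokens == gridDim.x):
#include <cstdint>
inline int64_t scale_index(bool transposed, int64_t token, int64_t group,
                           int64_t num_tokens, int64_t num_groups) {
  // Row-major:    scales[num_tokens][num_groups] -> token * num_groups + group
  // Column-major: scales[num_groups][num_tokens] -> group * num_tokens + token
  return transposed ? group * num_tokens + token : token * num_groups + group;
}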
1.0f / scale[scale_idx] : scale[scale_idx]; + } else { + scale_val = *scale; + } #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { out.val[j] = ScaledQuant::quant_fn( From bf0d3b5780dba6881a6451ebf128c1c0ca317fcb Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 1 Dec 2025 15:14:28 +0000 Subject: [PATCH 22/27] Fix zero division when group_size is 0 Signed-off-by: ElizaWszola --- csrc/quantization/fused_kernels/layernorm_utils.cuh | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 8136676eb528..f697873f9d7b 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -520,19 +520,16 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, q8x4_t out; - int64_t const num_groups = hidden_size / group_size; - int64_t scale_idx = 0; + float scale_val; + if constexpr (group_size > 0) { + int64_t const num_groups = hidden_size / group_size; + int64_t scale_idx = 0; if constexpr (is_scale_transposed) { scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; } else { scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size; } - } - - float scale_val; - - if constexpr (group_size > 0) { scale_val = is_scale_inverted ? 1.0f / scale[scale_idx] : scale[scale_idx]; } else { From 5caf3a708f47421511bf7266824b2ba0a280671b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 3 Dec 2025 14:26:02 +0000 Subject: [PATCH 23/27] A few more improvements Signed-off-by: ElizaWszola --- .../fused_kernels/layernorm_rms_benchmarks.py | 8 +++- ...fused_layernorm_dynamic_per_token_quant.cu | 11 ++--- .../fused_kernels/layernorm_utils.cuh | 41 ++++++------------- .../core/test_fused_quant_layernorm.py | 3 +- vllm/_custom_ops.py | 17 +++++--- 5 files changed, 40 insertions(+), 40 deletions(-) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index 8591c53d2691..fb3329975cee 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -129,7 +129,13 @@ def fused_groupwise_impl( group_size: list[int], ): out, _ = ops.rms_norm_per_block_quant( - x, rms_norm_layer.weight, 1e-6, quant_dtype, group_size, residual=residual + x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + group_size, + residual=residual, + is_scale_transposed=True, ) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index cdb5d81dfbac..fa22efc248b3 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -98,17 +98,17 @@ __global__ void rms_norm_per_block_quant_kernel( scalar_t const* __restrict__ weight, // [hidden_size] float const* scale_ub, float const var_epsilon, int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr) { - __shared__ float s_rms; + float rms; // Compute RMS // Always able to vectorize due to constraints on hidden_size vllm::vectorized::compute_rms( - &s_rms, input, hidden_size, var_epsilon, residual); + &rms, input, hidden_size, var_epsilon, residual); // Compute Scale // Always able to vectorize due to constraints on hidden_size and group_size vllm::vectorized::compute_dynamic_per_token_scales< scalar_t, scalar_out_t, 
has_residual, is_scale_transposed, group_size>( - nullptr, scales, input, weight, s_rms, scale_ub, hidden_size, residual); + nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual); // RMS Norm + Quant // Always able to vectorize due to constraints on hidden_size @@ -119,7 +119,7 @@ __global__ void rms_norm_per_block_quant_kernel( vllm::vectorized::norm_and_quant< scalar_t, scalar_out_t, std::is_same_v, has_residual, is_scale_transposed, group_size>( - out, input, weight, s_rms, scales, hidden_size, residual); + out, input, weight, rms, scales, hidden_size, residual); } } // namespace vllm @@ -212,7 +212,8 @@ void rms_norm_per_block_quant_dispatch( auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 512)); + const int max_block_size = (num_tokens <= 256) ? 512 : 256; + dim3 block(std::min(hidden_size, max_block_size)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index f697873f9d7b..f890fe3e2395 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -16,8 +16,7 @@ namespace vllm { template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, - scalar_t const* __restrict__ residual = nullptr, - int32_t const group_size = 0) { + scalar_t const* __restrict__ residual = nullptr) { int64_t const token_offset = blockIdx.x * static_cast(hidden_size); // sum of squares float ss = 0.0f; @@ -35,20 +34,13 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __shared__ typename BlockReduce::TempStorage reduceStore; ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); - if (group_size > 0) { - if (threadIdx.x == 0) { - *rms = rsqrtf(ss / hidden_size + epsilon); - } - __syncthreads(); - } else { - __shared__ float s_rms; - if (threadIdx.x == 0) { - s_rms = rsqrtf(ss / hidden_size + epsilon); - } - __syncthreads(); - - *rms = s_rms; + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); } + __syncthreads(); + + *rms = s_rms; } // TODO replace 32 with WARP_SIZE @@ -272,20 +264,13 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __shared__ typename BlockReduce::TempStorage reduceStore; ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x); - if constexpr (group_size > 0) { - if (threadIdx.x == 0) { - *rms = rsqrtf(ss / hidden_size + epsilon); - } - __syncthreads(); - } else { - __shared__ float s_rms; - if (threadIdx.x == 0) { - s_rms = rsqrtf(ss / hidden_size + epsilon); - } - __syncthreads(); - - *rms = s_rms; + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); } + __syncthreads(); + + *rms = s_rms; } // Vectorized version of vllm::compute_dynamic_per_token_scales diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 1dad7692c4d1..094073f5d3f9 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -114,8 +114,9 @@ def ops_dynamic_per_token_or_block_quant( residual = residual.clone() if group_size is not None: out, scales = ops.rms_norm_per_block_quant( - x, weight, EPS, quant_dtype, group_size, scale_ub, 
residual, False + x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True ) + scales = scales.contiguous() else: out, scales = ops.rms_norm_dynamic_per_token_quant( x, weight, EPS, quant_dtype, scale_ub, residual diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index fe6cad38439f..b9786c95bd24 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -449,11 +449,18 @@ def rms_norm_per_block_quant( ) -> tuple[torch.Tensor, torch.Tensor]: assert len(group_size) == 2 output = torch.empty_like(input, dtype=quant_dtype) - scales = torch.empty( - (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]), - device=input.device, - dtype=torch.float32, - ) + if is_scale_transposed: + scales = torch.empty( + (input.shape[-1] // group_size[1], input.numel() // input.shape[-1]), + device=input.device, + dtype=torch.float32, + ).transpose(0, 1) + else: + scales = torch.empty( + (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]), + device=input.device, + dtype=torch.float32, + ) torch.ops._C.rms_norm_per_block_quant( output, From 92fd8c9a979abf41cf0ebb8a984c2978b8f06150 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 3 Dec 2025 16:18:42 +0000 Subject: [PATCH 24/27] Cleanup unused template Signed-off-by: ElizaWszola --- .../fused_kernels/fused_layernorm_dynamic_per_token_quant.cu | 2 +- csrc/quantization/fused_kernels/layernorm_utils.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index fa22efc248b3..89dc364666e4 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -101,7 +101,7 @@ __global__ void rms_norm_per_block_quant_kernel( float rms; // Compute RMS // Always able to vectorize due to constraints on hidden_size - vllm::vectorized::compute_rms( + vllm::vectorized::compute_rms( &rms, input, hidden_size, var_epsilon, residual); // Compute Scale diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index f890fe3e2395..5a22a919a648 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -215,7 +215,7 @@ namespace vectorized { // Compute 1.0/rms(input) // hidden_size must be a multiple of 4 -template +template __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, int32_t const hidden_size, float const epsilon, scalar_t const* __restrict__ residual = nullptr) { From 06e3645828dc8ed27b7cf1fa1ad3610b7d293f6d Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 5 Dec 2025 08:55:39 +0000 Subject: [PATCH 25/27] Feedback, add gs==64 to fusion tests Signed-off-by: ElizaWszola --- csrc/dispatch_utils.h | 18 +++ ...fused_layernorm_dynamic_per_token_quant.cu | 131 +++++------------- .../fused_kernels/layernorm_utils.cuh | 11 +- tests/compile/test_fusion.py | 34 +++-- vllm/compilation/fusion.py | 17 +++ vllm/compilation/matcher_utils.py | 22 ++- .../layers/quantization/utils/quant_utils.py | 3 + 7 files changed, 115 insertions(+), 121 deletions(-) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index e1d131e4a785..de0c505b7a62 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -118,6 +118,24 @@ } \ } +#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) 
\ + if (expr) { \ + constexpr bool const_expr = true; \ + __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + __VA_ARGS__(); \ + } + +#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \ + if (group_size == 128) { \ + constexpr int const_group_size = 128; \ + __VA_ARGS__(); \ + } else if (group_size == 64) { \ + constexpr int const_group_size = 64; \ + __VA_ARGS__(); \ + } + #define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \ switch (NUM_DIMS) { \ case 2: { \ diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 89dc364666e4..98135c3a1d46 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -142,30 +142,19 @@ void rms_norm_dynamic_per_token_quant_dispatch( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (residual.has_value()) { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { VLLM_DISPATCH_QUANT_TYPES( out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { vllm::rms_norm_dynamic_per_token_quant_kernel + has_residual> <<>>( out.data_ptr(), scales.data_ptr(), input.data_ptr(), weight.data_ptr(), scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, residual->data_ptr()); + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() : nullptr); }); - - } else { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { - vllm::rms_norm_dynamic_per_token_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - var_epsilon, hidden_size, nullptr); - }); - } + }); } void rms_norm_dynamic_per_token_quant( @@ -198,14 +187,15 @@ void rms_norm_dynamic_per_token_quant( } // Residual add + RMS norm + dynamic per token -template +template void rms_norm_per_block_quant_dispatch( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] torch::Tensor const& weight, // [hidden_size] torch::Tensor& scales, // [num_tokens, hidden_size / group_size] or // [hidden_size / group_size, num_tokens] - double const var_epsilon, // Variance epsilon used in norm calculation + int32_t group_size, + double const var_epsilon, // Variance epsilon used in norm calculation std::optional const& scale_ub, std::optional& residual, bool is_scale_transposed) { int32_t hidden_size = input.size(-1); @@ -217,65 +207,26 @@ void rms_norm_per_block_quant_dispatch( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (residual.has_value()) { - if (is_scale_transposed) { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - var_epsilon, hidden_size, - residual->data_ptr()); - }); - } else { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - scale_ub.has_value() ? 
scale_ub->data_ptr() - : nullptr, - var_epsilon, hidden_size, - residual->data_ptr()); - }); - } - } else { - if (is_scale_transposed) { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - var_epsilon, hidden_size, nullptr); - }); - } else { - VLLM_DISPATCH_QUANT_TYPES( - out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { - vllm::rms_norm_per_block_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() - : nullptr, - var_epsilon, hidden_size, nullptr); - }); - } - } + VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] { + VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] { + vllm::rms_norm_per_block_quant_kernel< + scalar_in_t, scalar_t, has_residual, transpose_scale, gs> + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + var_epsilon, hidden_size, + has_residual ? residual->data_ptr() + : nullptr); + }); + }); + }); + }); } void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, @@ -299,21 +250,13 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, TORCH_CHECK(residual->scalar_type() == input.scalar_type()); } - if (group_size == 128) { - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] { - rms_norm_per_block_quant_dispatch( - out, input, weight, scales, var_epsilon, scale_ub, residual, - is_scale_transposed); - }); - } else if (group_size == 64) { - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] { - rms_norm_per_block_quant_dispatch( - out, input, weight, scales, var_epsilon, scale_ub, residual, - is_scale_transposed); - }); - } else { - TORCH_CHECK(false, "Unsupported group size: ", group_size); - } + TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] { + rms_norm_per_block_quant_dispatch( + out, input, weight, scales, group_size, var_epsilon, scale_ub, + residual, is_scale_transposed); + }); } \ No newline at end of file diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh index 5a22a919a648..cb7adc312573 100644 --- a/csrc/quantization/fused_kernels/layernorm_utils.cuh +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -9,6 +9,7 @@ #include "quant_conversions.cuh" #include "../../cub_helpers.h" +#include "../../cuda_compat.h" namespace vllm { @@ -43,10 +44,14 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, *rms = s_rms; } -// TODO replace 32 with WARP_SIZE __device__ float warpReduceMaxSpecialized(volatile float* val, int64_t tid, int64_t thread_in_warp, int64_t reduced_elems) { + static_assert(WARP_SIZE == 32 || WARP_SIZE == 64); + if constexpr (WARP_SIZE == 64) { + if (thread_in_warp + 64 < reduced_elems) + val[tid] = fmaxf(val[tid], val[tid + 64]); + } if (thread_in_warp + 32 < reduced_elems) val[tid] = fmaxf(val[tid], val[tid + 
32]);
   if (thread_in_warp + 16 < reduced_elems)
@@ -94,7 +99,7 @@ __device__ void compute_dynamic_per_token_scales(
     s_max_vals[threadIdx.x] = block_absmax_val_maybe;
     __syncthreads();
 
-    int64_t const warp_size = 32;
+    int64_t const warp_size = WARP_SIZE;
     int64_t const num_warps = blockDim.x / warp_size;
     int64_t const warp_id = threadIdx.x / warp_size;
     int64_t const thread_in_warp = threadIdx.x % warp_size;
@@ -343,7 +348,7 @@ __device__ void compute_dynamic_per_token_scales(
     s_max_vals[threadIdx.x] = block_absmax_val_maybe;
     __syncthreads();
 
-    int64_t const warp_size = 32;
+    int64_t const warp_size = WARP_SIZE;
     int64_t const num_warps = blockDim.x / warp_size;
     int64_t const warp_id = threadIdx.x / warp_size;
     int64_t const thread_in_warp = threadIdx.x % warp_size;
diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index a1859656fd05..2ad34a79859a 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -33,6 +33,7 @@
     maybe_create_device_identity,
 )
 from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import is_deep_gemm_supported
 
 from ..utils import override_cutlass_fp8_supported
 from .backend import TestBackend
@@ -56,10 +57,11 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.cuda_force_torch = cuda_force_torch
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
-        if group_shape == GroupShape(1, 128):
+        if group_shape.is_per_group():
             self.wscale = [
                 torch.rand(
-                    (hidden_size // 128, hidden_size // 128), dtype=torch.float32
+                    (hidden_size // group_shape[1], hidden_size // group_shape[1]),
+                    dtype=torch.float32,
                 )
                 for _ in range(3)
             ]
@@ -75,29 +77,26 @@ def __init__(
         self.w = [
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
         ]
-        if group_shape != GroupShape(1, 128):
+        if not group_shape.is_per_group():
             self.w = [self.w[0].t() for _ in range(3)]
 
-        if group_shape == GroupShape(1, 128):
+        if group_shape.is_per_group():
             self.fp8_linear = W8A8BlockFp8LinearOp(
-                weight_group_shape=GroupShape(128, 128),
+                weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
                 act_quant_group_shape=group_shape,
                 cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
                 use_aiter_and_is_supported=False,
             )
+            self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled()
         else:
             with override_cutlass_fp8_supported(not cuda_force_torch):
                 self.fp8_linear = Fp8LinearOp(
                     act_quant_static=static,
                     act_quant_group_shape=group_shape,
                 )
+            self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
 
         self.enable_rms_norm_custom_op = self.norm[0].enabled()
-        self.enable_quant_fp8_custom_op = (
-            self.fp8_linear.quant_fp8.enabled()
-            if group_shape != GroupShape(1, 128)
-            else True
-        )
         self.group_shape = group_shape
 
     def forward(self, x):
@@ -145,7 +144,12 @@ def ops_in_model_before_partial(self):
         )
 
 
-GROUP_SHAPES = [GroupShape.PER_TOKEN, GroupShape.PER_TENSOR, GroupShape(1, 128)]
+GROUP_SHAPES = [
+    GroupShape.PER_TOKEN,
+    GroupShape.PER_TENSOR,
+    GroupShape(1, 128),
+    GroupShape(1, 64),
+]
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -178,9 +182,15 @@ def test_fusion_rmsnorm_quant(
     torch.manual_seed(1)
     maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths
 
-    if not enable_quant_fp8_custom_op and group_shape == GroupShape(1, 128):
+    if not enable_quant_fp8_custom_op and group_shape.is_per_group():
         pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
 
+    # Skip group shape (1, 64) when running with CUTLASS or DeepGEMM
+    if group_shape 
+        cutlass_block_fp8_supported() or is_deep_gemm_supported()
+    ):
+        pytest.skip("Unsupported group size 64 for CUTLASS/DeepGEMM")
+
     custom_ops = []
     if enable_rms_norm_custom_op:
         custom_ops.append("+rms_norm")
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index c9e1db508f94..74592f1e3c45 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -15,6 +15,7 @@
     GroupShape,
     QuantKey,
     ScaleDesc,
+    kFp8Dynamic64Sym,
     kFp8Dynamic128Sym,
     kFp8DynamicTensorSym,
     kFp8DynamicTokenSym,
@@ -65,6 +66,7 @@ def empty_i64(*args, **kwargs):
     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default
 if current_platform.is_cuda():
     QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
+    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501


 class FusedRMSQuantKey(NamedTuple):
@@ -103,6 +105,12 @@ def __str__(self):
     FusedRMSQuantKey(
         kFp8Dynamic128Sym, True
     ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
+    FusedRMSQuantKey(
+        kFp8Dynamic64Sym, False
+    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
+    FusedRMSQuantKey(
+        kFp8Dynamic64Sym, True
+    ): torch.ops._C.rms_norm_per_block_quant.default,  # noqa: E501
 }


@@ -489,6 +497,15 @@ def __init__(self, config: VllmConfig):
             epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
         ).register(self.patterns)

+        FusedAddRMSNormGroupQuantPattern(
+            epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+        ).register(self.patterns)
+
+        # Fuse rms_norm + fp8 group quant
+        RMSNormGroupQuantPattern(
+            epsilon, FP8_DTYPE, group_shape=GroupShape(1, 64)
+        ).register(self.patterns)
+
         # Fuse fused_add_rms_norm + static fp8 quant
         FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
             self.patterns
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index c9e6bd76dfda..ddfc44ddc05d 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -11,9 +11,9 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape,
     QuantKey,
     _normalize_quant_group_shape,
+    kFp8Dynamic64Sym,
     kFp8Dynamic128Sym,
     kFp8DynamicTensorSym,
     kFp8DynamicTokenSym,
@@ -39,6 +39,7 @@

 if current_platform.is_cuda():
     QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
+    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501

 SILU_MUL_OP = torch.ops._C.silu_and_mul.default

@@ -259,7 +260,7 @@ def forward_custom(
             input.shape, device=input.device, dtype=self.quant_key.dtype
         )

-        if self.quant_key.scale.group_shape == GroupShape(1, 128):
+        if self.quant_key.scale.group_shape.is_per_group():
             assert scale is None
             scale = self.make_scale(input, transposed=self.use_col_major_scales)

@@ -305,20 +306,17 @@ def make_scale(self, input: torch.Tensor, transposed: bool = False):
         normalized_group_shape = _normalize_quant_group_shape(
             input, self.quant_key.scale.group_shape
         )
+        scale_shape = (
+            input.shape[0] // normalized_group_shape[0],
+            input.shape[1] // normalized_group_shape[1],
+        )
         if transposed:
-            scale_shape = (
-                input.shape[1] // normalized_group_shape[1],
-                input.shape[0] // normalized_group_shape[0],
-            )
+            scale_shape = tuple(reversed(scale_shape))
             return torch.empty(
                 scale_shape, device=input.device, dtype=torch.float32
             ).permute(-1, -2)
-        else:
-            scale_shape = (
-                input.shape[0] // normalized_group_shape[0],
-                input.shape[1] // normalized_group_shape[1],
-            )
-            return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
+
+        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

     def inputs(self) -> list[torch.Tensor]:
         input = self.empty(5, 16)
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 0e4998f4fa4f..92ee8c498e01 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -118,6 +118,9 @@ def __str__(self):
 kDynamic128Scale = ScaleDesc(torch.float32, False, GroupShape(1, 128))
 kFp8Dynamic128Sym = QuantKey(FP8_DTYPE, kDynamic128Scale, symmetric=True)

+kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
+kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
+

 # Normalize the group_shape to the full extent for any dims that are -1
 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):

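[Editor's note, not part of the patch series: the make_scale refactor above is easy to sanity-check. For per-group quantization the scale tensor always has one float32 entry per (row group, column group) tile, and the transposed variant changes only the memory layout, not the logical shape. A minimal torch-only sketch under those assumptions; the function name and the assert values are illustrative, not vLLM API:

import torch

def make_scale_ref(input: torch.Tensor, group_shape: tuple[int, int],
                   transposed: bool = False) -> torch.Tensor:
    # One float32 scale per (group_shape[0] x group_shape[1]) tile.
    rows = input.shape[0] // group_shape[0]
    cols = input.shape[1] // group_shape[1]
    if transposed:
        # Allocate column-major, then view as [rows, cols]: same logical
        # shape, but the column dimension is contiguous, which is the
        # layout the DeepGEMM/CUTLASS block-FP8 kernels consume.
        return torch.empty(cols, rows, dtype=torch.float32).permute(-1, -2)
    return torch.empty(rows, cols, dtype=torch.float32)

s = make_scale_ref(torch.empty(5, 256), (1, 64), transposed=True)
assert s.shape == (5, 4) and s.stride() == (1, 5)

This is why the refactored code can compute scale_shape once and simply reverse it for the transposed allocation.]
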
From 416f173133e81cf6897172ac3c4c415d8ab0b040 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Fri, 5 Dec 2025 09:26:32 +0000
Subject: [PATCH 26/27] Move type dispatch to dispatch function

Signed-off-by: ElizaWszola

---
 ...fused_layernorm_dynamic_per_token_quant.cu | 51 ++++++++++---------
 1 file changed, 26 insertions(+), 25 deletions(-)

diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
index 98135c3a1d46..2080ef3cd39b 100644
--- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
@@ -187,7 +187,6 @@ void rms_norm_dynamic_per_token_quant(
 }

 // Residual add + RMS norm + dynamic per token
-template <typename scalar_in_t>
 void rms_norm_per_block_quant_dispatch(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
@@ -207,26 +206,31 @@ void rms_norm_per_block_quant_dispatch(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
-    VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
-      VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
-        VLLM_DISPATCH_QUANT_TYPES(
-            out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
-              vllm::rms_norm_per_block_quant_kernel<
-                  scalar_in_t, scalar_t, has_residual, transpose_scale, gs>
-                  <<<grid, block, 0, stream>>>(
-                      out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
-                      input.data_ptr<scalar_in_t>(),
-                      weight.data_ptr<scalar_in_t>(),
-                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                           : nullptr,
-                      var_epsilon, hidden_size,
-                      has_residual ? residual->data_ptr<scalar_in_t>()
-                                   : nullptr);
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "rms_norm_per_block_quant_fp_dispatch", [&] {
+        using scalar_in_t = scalar_t;
+        VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
+          VLLM_DISPATCH_BOOL(residual.has_value(), has_residual, [&] {
+            VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
+              VLLM_DISPATCH_QUANT_TYPES(
+                  out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
+                    vllm::rms_norm_per_block_quant_kernel<scalar_in_t,
+                        scalar_t, has_residual, transpose_scale, gs>
+                        <<<grid, block, 0, stream>>>(
+                        out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
+                        input.data_ptr<scalar_in_t>(),
+                        weight.data_ptr<scalar_in_t>(),
+                        scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                             : nullptr,
+                        var_epsilon, hidden_size,
+                        has_residual ? residual->data_ptr<scalar_in_t>()
+                                     : nullptr);
+                  });
             });
+          });
+        });
       });
-    });
-  });
 }

 void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
@@ -253,10 +257,7 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
   TORCH_CHECK(group_size == 128 || group_size == 64,
               "Unsupported group size: ", group_size);

-  VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] {
-        rms_norm_per_block_quant_dispatch<scalar_t>(
-            out, input, weight, scales, group_size, var_epsilon, scale_ub,
-            residual, is_scale_transposed);
-      });
+  rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
+                                    var_epsilon, scale_ub, residual,
+                                    is_scale_transposed);
 }
\ No newline at end of file
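
[Editor's note, not part of the patch series: before the final patch it may help to spell out what the fused op computes end to end. Below is a minimal PyTorch reference sketch of the rms_norm_per_block_quant semantics: RMSNorm (optionally fused with a residual add) followed by dynamic per-token-group FP8 quantization, one float32 scale per group of 64 or 128 channels. This is a sketch under stated assumptions, not the CUDA kernel: names are illustrative, scale_ub clamping is omitted, and numerics will differ slightly from the device code.

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def rms_norm_group_quant_ref(
    x: torch.Tensor,                    # [num_tokens, hidden_size]
    weight: torch.Tensor,               # [hidden_size]
    group_size: int,                    # 64 or 128, must divide hidden_size
    eps: float = 1e-6,
    residual: torch.Tensor | None = None,
    col_major_scales: bool = False,
):
    num_tokens, hidden_size = x.shape
    assert hidden_size % group_size == 0
    if residual is not None:
        x = x + residual                # fused residual-add variant
    x32 = x.to(torch.float32)
    # RMSNorm: normalize by the root-mean-square over the hidden dimension
    rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x32 * rms * weight.to(torch.float32)
    # One dynamic scale per (token, group): group absmax mapped to FP8 max
    groups = normed.view(num_tokens, hidden_size // group_size, group_size)
    scales = groups.abs().amax(dim=-1).clamp(min=1e-12) / FP8_MAX
    q = (groups / scales.unsqueeze(-1)).clamp(-FP8_MAX, FP8_MAX)
    out = q.reshape(num_tokens, hidden_size).to(FP8_DTYPE)
    if col_major_scales:
        # Column-major scale layout, matching make_scale(transposed=True)
        scales = scales.t().contiguous().t()
    new_residual = x if residual is not None else None
    return out, scales, new_residual

For GroupShape(1, 64) this yields scales of shape (num_tokens, hidden_size // 64), which is exactly the [num_tokens, hidden_size / group_size] scales tensor the C++ binding allocates.]
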
From b2e22516377ad3ee87b19699768c65101ca65ad0 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Fri, 5 Dec 2025 16:45:54 +0000
Subject: [PATCH 27/27] Fix DeepGEMM when e8m0 scales are used

Signed-off-by: ElizaWszola

---
 vllm/compilation/fusion.py        | 20 +++++++++++---------
 vllm/compilation/matcher_utils.py |  4 +++-
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index 74592f1e3c45..de083a2e5e3c 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -27,7 +27,10 @@
     cutlass_block_fp8_supported,
 )
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear_for_nk
+from vllm.utils.deep_gemm import (
+    is_deep_gemm_e8m0_used,
+    should_use_deepgemm_for_fp8_linear_for_nk,
+)

 from .inductor_pass import enable_fake_mode
 from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
@@ -122,14 +125,13 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
         self.model_dtype = config.model_config.dtype if config.model_config else None

         # groupwise FP8 linear uses col major scales if deepgemm and cutlass
-        use_col_major_scales = (
-            should_use_deepgemm_for_fp8_linear_for_nk(
-                self.model_dtype,
-                config.model_config.hf_config.intermediate_size,
-                config.model_config.hf_config.hidden_size,
-            )
-            or cutlass_block_fp8_supported()
+        using_deepgemm = should_use_deepgemm_for_fp8_linear_for_nk(
+            self.model_dtype,
+            config.model_config.hf_config.intermediate_size,
+            config.model_config.hf_config.hidden_size,
         )
+        use_col_major_scales = using_deepgemm or cutlass_block_fp8_supported()
+        use_e8m0 = is_deep_gemm_e8m0_used() if using_deepgemm else False

         assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]
@@ -140,7 +142,7 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
             else MatcherFusedAddRMSNorm(epsilon)
         )
         self.quant_matcher = MatcherQuantFP8(
-            key.quant, use_col_major_scales=use_col_major_scales
+            key.quant, use_col_major_scales=use_col_major_scales, use_e8m0=use_e8m0
         )

diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index ddfc44ddc05d..0c0bece9b3fd 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -235,6 +235,7 @@ def __init__(
         quant_key: QuantKey,
         enabled: bool | None = None,
         use_col_major_scales: bool = False,
+        use_e8m0: bool = False,
     ):
         if enabled is None:
             enabled = QuantFP8.enabled()
@@ -242,6 +243,7 @@ def __init__(
         super().__init__(enabled)
         self.quant_key = quant_key
         self.use_col_major_scales = use_col_major_scales
+        self.use_e8m0 = use_e8m0

         assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
         self.QUANT_OP = QUANT_OPS[quant_key]
@@ -277,7 +279,7 @@ def forward_custom(
             eps=1e-10,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
-            scale_ue8m0=False,
+            scale_ue8m0=self.use_e8m0,
         )
         return result, scale
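
[Editor's note, not part of the patch series: the scale_ue8m0 flag wired through in this last patch selects DeepGEMM's e8m0 scale format, which stores only an exponent, so every scale must be a power of two. A sketch of the rounding this implies, assuming float32 scales like those produced by the reference above; the helper name is illustrative:

import torch

def round_scales_e8m0(scales: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the next power of two at or above its value,
    # so quantized values cannot overflow; clamp keeps log2 finite at zero.
    return torch.exp2(torch.ceil(torch.log2(scales.clamp(min=2.0**-127))))

The pattern matcher must use the same flag as the runtime op, otherwise the traced quant graph (scale_ue8m0=False) would not match what the DeepGEMM-backed linear actually emits, which is the mismatch this patch fixes.]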