diff --git a/CMakeLists.txt b/CMakeLists.txt index 3db7ff0bbda2..1695d5ab4955 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -340,7 +340,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/awq/gemm_kernels.cu" - "csrc/cutlass_extensions/common.cpp") + "csrc/cutlass_extensions/common.cpp" + "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" diff --git a/benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py b/benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py new file mode 100644 index 000000000000..4e8d787bf9c7 --- /dev/null +++ b/benchmarks/fused_kernels/silu_mul_block_quant_benchmark.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from itertools import product + +import torch +import torch.nn.functional as F +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from tqdm import tqdm + +import vllm._custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) + + +@dataclass +class bench_params_t: + num_tokens: int + hidden_size: int + dtype: torch.dtype + group_size: int # Changed from list[int] to int + + def description(self): + return ( + f"N {self.num_tokens} " + f"x D {self.hidden_size} " + f"x DT {self.dtype} " + f"x GS {self.group_size}" + ) + + +def get_bench_params() -> list[bench_params_t]: + """Test configurations covering common model sizes.""" + NUM_TOKENS = [16, 128, 512, 2048] + HIDDEN_SIZES = [1024, 2048, 4096, 5120, 14336] # Common FFN sizes + DTYPES = [torch.float16, torch.bfloat16] + GROUP_SIZES = [64, 128] # Changed from [[1, 64], [1, 128]] + + combinations = product(NUM_TOKENS, HIDDEN_SIZES, DTYPES, GROUP_SIZES) + bench_params = list( + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations) + ) + return bench_params + + +# Reference implementations +def unfused_fp8_impl( + x: torch.Tensor, + quant_dtype: torch.dtype, + group_size: int, # Changed from list[int] +): + """Unfused: SiLU+Mul then per-tensor quantize.""" + hidden = x.shape[-1] // 2 + gate, up = x.split(hidden, dim=-1) + + # SiLU(gate) * up + silu_out = F.silu(gate) * up + + # Per-tensor quantize (no group_size used here) + silu_out, _ = ops.scaled_fp8_quant(silu_out) + + +def unfused_groupwise_fp8_impl( + x: torch.Tensor, + quant_dtype: torch.dtype, + group_size: int, # Changed from list[int] +): + """Unfused: SiLU+Mul then group-wise quantize.""" + hidden = x.shape[-1] // 2 + gate, up = x.split(hidden, dim=-1) + + # SiLU(gate) * up + silu_out = F.silu(gate) * up + + # Group quantize - use group_size directly + silu_out, _ = per_token_group_quant_fp8( + silu_out, group_size=group_size, use_ue8m0=False + ) + + +def fused_impl( + x: torch.Tensor, + quant_dtype: torch.dtype, + group_size: int, +): + """Fused: SiLU+Mul+Block Quantization in single kernel.""" + out, _ = ops.silu_and_mul_per_block_quant( + x, + group_size=group_size, + quant_dtype=quant_dtype, + is_scale_transposed=False, + ) + + +# Bench functions +def bench_fn( + x: torch.Tensor, + quant_dtype: torch.dtype, + group_size: int, + label: str, + sub_label: str, + fn: Callable, + description: str, +) -> TMeasurement: + min_run_time = 1 + + globals = { + "x": x, + "quant_dtype": quant_dtype, + "group_size": group_size, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(x, quant_dtype, group_size)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]: + """Run benchmarks for all implementations.""" + # Make inputs: [num_tokens, hidden_size * 2] for [gate || up] + scale = 1 / params.hidden_size + x = ( + torch.randn( + params.num_tokens, + params.hidden_size * 2, + dtype=params.dtype, + device="cuda", + ) + * scale + ) + + timers = [] + + # Unfused per-tensor FP8 + timers.append( + bench_fn( + x, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_fp8_impl, + "unfused_fp8_impl", + ) + ) + + # Unfused group-wise FP8 + timers.append( + bench_fn( + x, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + unfused_groupwise_fp8_impl, + "unfused_groupwise_fp8_impl", + ) + ) + + # Fused group-wise FP8 + timers.append( + bench_fn( + x, + torch.float8_e4m3fn, + params.group_size, + label, + sub_label, + fused_impl, + "fused_groupwise_fp8_impl", + ) + ) + + return timers + + +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def main(): + torch.set_default_device("cuda") + bench_params = get_bench_params() + + print(f"Running {len(bench_params)} benchmark configurations...") + print( + f"This will take approximately {len(bench_params) * 3} seconds (1s per variant)" + ) + print() + + timers = [] + for bp in tqdm(bench_params): + result_timers = bench(bp, "silu-mul-block-quant", bp.description()) + timers.extend(result_timers) + + print("\n" + "=" * 80) + print("FINAL COMPARISON - ALL RESULTS") + print("=" * 80) + print_timers(timers) + + +if __name__ == "__main__": + main() diff --git a/csrc/ops.h b/csrc/ops.h index 1fdd77f73179..a3f4c8678671 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -143,6 +143,12 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, std::optional residual, int64_t group_size, bool is_scale_transposed); +void silu_and_mul_per_block_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor& scales, int64_t group_size, + std::optional scale_ub, + bool is_scale_transposed); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu b/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu new file mode 100644 index 000000000000..993ee641b5d6 --- /dev/null +++ b/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +#include +#include + +#include "../../dispatch_utils.h" +#include "quant_conversions.cuh" +#include "../w8a8/fp8/common.cuh" + +namespace vllm { + +// Logic: one thread block per (token, group) pair + +template +__global__ void silu_and_mul_per_block_quant_kernel( + scalar_out_t* __restrict__ out, // Output: [num_tokens, hidden_size] in + // FP8/INT8 + float* __restrict__ scales, // Output: [num_tokens, hidden_size / + // group_size] or [hidden_size / group_size, + // num_tokens] + scalar_t const* __restrict__ input, // Input: [num_tokens, hidden_size * 2] + float const* scale_ub, // Optional scale upper bound + int32_t const hidden_size // Output hidden size (input is 2x this) +) { + static_assert((group_size & (group_size - 1)) == 0, + "group_size must be a power of 2 for correct reduction"); + + // Grid: (num_tokens, num_groups) + int const token_idx = blockIdx.x; + int const group_idx = blockIdx.y; + int const tid = threadIdx.x; // tid in [0, group_size) + int const num_tokens = gridDim.x; + + // Input layout: [gate || up] concatenated along last dimension + int const input_stride = hidden_size * 2; + int const group_start = group_idx * group_size; + + // Pointers to this token's data + scalar_t const* token_input_gate = + input + token_idx * input_stride + group_start; + scalar_t const* token_input_up = token_input_gate + hidden_size; + scalar_out_t* token_output = out + token_idx * hidden_size + group_start; + + // Scale pointer for this group + int const num_groups = gridDim.y; + float* group_scale_ptr = is_scale_transposed + ? scales + group_idx * num_tokens + token_idx + : scales + token_idx * num_groups + group_idx; + + // Shared memory for reduction (compile-time sized) + __shared__ float shared_max[group_size]; + + // Step 1: Each thread loads one element, computes SiLU, stores in register + float gate = static_cast(token_input_gate[tid]); + float up = static_cast(token_input_up[tid]); + + // Compute SiLU(gate) * up + float sigmoid_gate = 1.0f / (1.0f + expf(-gate)); + float silu_gate = gate * sigmoid_gate; + float result = silu_gate * up; // Keep in register + + // Step 2: Reduce to find group max + shared_max[tid] = fabsf(result); + __syncthreads(); + +// Power-of-2 reduction (group_size guaranteed to be power of 2) +#pragma unroll + for (int stride = group_size / 2; stride > 0; stride >>= 1) { + if (tid < stride) { + shared_max[tid] = fmaxf(shared_max[tid], shared_max[tid + stride]); + } + __syncthreads(); + } + + // Step 3: Compute scale (thread 0), broadcast via shared memory + if (tid == 0) { + float group_max = shared_max[0]; + + float const quant_range = quant_type_max_v; + float group_scale = group_max / quant_range; + + // Apply scale upper bound if provided + if (scale_ub != nullptr) { + group_scale = fminf(group_scale, *scale_ub); + } + + // Use minimum safe scaling factor + group_scale = fmaxf(group_scale, min_scaling_factor::val()); + + // Store scale to global memory + *group_scale_ptr = group_scale; + + // Reuse shared_max[0] to broadcast scale + shared_max[0] = group_scale; + } + __syncthreads(); + + float group_scale = shared_max[0]; + + // Step 4: Quantize and write output + token_output[tid] = + vllm::ScaledQuant::quant_fn(result, group_scale); +} + +} // namespace vllm + +void silu_and_mul_per_block_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor& scales, int64_t group_size, + std::optional scale_ub, + bool is_scale_transposed) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; + + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + TORCH_CHECK( + input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16, + "Input must be FP16 or BF16"); + TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32"); + TORCH_CHECK(group_size == 128 || group_size == 64, + "Unsupported group size: ", group_size); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == kFp8Type); + } + + int32_t hidden_size = out.size(-1); + auto num_tokens = input.size(0); + int32_t num_groups = hidden_size / group_size; + + TORCH_CHECK(input.size(-1) == hidden_size * 2, + "input last dim must be 2x output hidden_size"); + TORCH_CHECK(hidden_size % group_size == 0, + "hidden_size must be divisible by group_size"); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(num_tokens, num_groups); + dim3 block(group_size); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "silu_and_mul_per_block_quant", [&] { + using scalar_in_t = scalar_t; + + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "silu_and_mul_per_block_quant", [&] { + using scalar_out_t = scalar_t; + + VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] { + VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] { + vllm::silu_and_mul_per_block_quant_kernel< + scalar_in_t, scalar_out_t, transpose_scale, gs> + <<>>( + out.data_ptr(), + scales.data_ptr(), + input.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() + : nullptr, + hidden_size); + }); + }); + }); + }); +} \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 85605458f1ae..9475ee668232 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -2,7 +2,6 @@ #include "cuda_utils.h" #include "ops.h" #include "core/registration.h" - #include #include @@ -109,6 +108,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); + // Fused SiLU+Mul + per-block quantization + ops.def( + "silu_and_mul_per_block_quant(" + "Tensor! out, " + "Tensor input, " + "Tensor! scales, " + "int group_size, " + "Tensor? scale_ub=None, " + "bool is_scale_transposed=False) -> ()"); + ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, + &silu_and_mul_per_block_quant); + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); diff --git a/docs/design/fusions.md b/docs/design/fusions.md index 28a29a7f3516..cdc825e8dee2 100644 --- a/docs/design/fusions.md +++ b/docs/design/fusions.md @@ -45,7 +45,7 @@ The table below lists the quantization schemes supported by each fusion on each | `enable_sp` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — | | `fuse_gemm_comms` | FP16/BF16, FP8 static† | FP16/BF16, FP8 static | FP16/BF16† | FP16/BF16† | — | | `fuse_norm_quant` | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | FP8 static, FP8 per-token, FP8 per-group | — | FP8 static, FP8 per-token, FP8 per-group | -| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static | FP8 static | — | FP8 per-group | +| `fuse_act_quant` | FP8 static, NVFP4 | FP8 static, FP8 per-group (128/64) | FP8 static, FP8 per-group (128/64) | — | FP8 per-group | | `fuse_act_padding` | — | — | — | — | FP16/BF16 | \* `fuse_attn_quant` support depends on the attention backend in use; not all backends support @@ -305,6 +305,7 @@ Note that AITER fusions are in a separate pass in `vllm.compilation.passes.fusio Supported quantization scheme/hardware combinations: - FP8 static per-tensor: CUDA & HIP kernel +- FP8 dynamic per-group (128/64): CUDA kernel (sm89+, not active when DeepGemm is used on sm100+) - NVFP4 dynamic: CUDA sm100+ only with FlashInfer - FP8 per-token-group (128): ROCm AITER only @@ -313,6 +314,7 @@ Supported quantization scheme/hardware combinations: - Pass: [`vllm/compilation/passes/fusion/act_quant_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/act_quant_fusion.py) - ROCm AITER pass: [`vllm/compilation/passes/fusion/rocm_aiter_fusion.py`](https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/rocm_aiter_fusion.py) - CUDA/HIP kernels: [`csrc/quantization/`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/) +- Fused SiLU+Mul+BlockQuant kernel: [`csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu`](https://github.com/vllm-project/vllm/blob/main/csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu) ### RMSNorm + Padding (`fuse_act_padding`) diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index 1a5f18cc0d50..b174efd257d0 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -150,9 +150,8 @@ # - post_attn_layernorm + MLP # 2 per MoE layer (remaining) due to MoE wrapping rms_quant_fusion=n_layers * 2 + min(3, n_layers), # add for 3 dense layers - # TODO silu+block quant - # act_quant_fusion=min(3, n_layers), # dense layers only - act_quant_fusion=0, + # silu+block quant + act_quant_fusion=min(3, n_layers), # dense layers only # MLA attn + quant not supported yet: # https://github.com/vllm-project/vllm/issues/35792 attn_quant_fusion=0, diff --git a/tests/compile/passes/test_silu_mul_quant_fusion.py b/tests/compile/passes/test_silu_mul_quant_fusion.py index a77b4e6de7bd..383d59d03a7d 100644 --- a/tests/compile/passes/test_silu_mul_quant_fusion.py +++ b/tests/compile/passes/test_silu_mul_quant_fusion.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +from functools import partial import pytest import torch @@ -34,13 +35,16 @@ ROCmFP8ScaledMMLinearKernel, ) from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, + kFp8Dynamic128Sym, kFp8StaticTensorSym, kNvfp4Dynamic, ) from vllm.platforms import current_platform +from vllm.utils.deep_gemm import is_deep_gemm_supported FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -165,6 +169,48 @@ def ops_in_model_after(self): return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant] +class TestSiluMulBlockQuantModel(torch.nn.Module): + quant_key = kFp8Dynamic128Sym + + def __init__(self, hidden_size: int, is_scale_transposed: bool = False, **kwargs): + super().__init__() + self.silu_and_mul = SiluAndMul() + self.is_scale_transposed = is_scale_transposed + self.quant_fp8 = QuantFP8( + static=False, + group_shape=GroupShape(1, 128), + column_major_scales=is_scale_transposed, + compile_native=False, + ) + + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() + self.enable_quant_fp8_custom_op = self.quant_fp8.enabled() + + def forward(self, x): + y = self.silu_and_mul(x) + out, scale = self.quant_fp8(y) + group_size = self.quant_key.scale.group_shape[1] + scale_expanded = scale.repeat_interleave(group_size, dim=1) + dequant = out.to(dtype=torch.float32) * scale_expanded + return (dequant,) + + def ops_in_model_before(self): + ops = [] + if self.enable_silu_mul_custom_op: + ops.append(SILU_MUL_OP) + # When silu custom op is disabled, aten.mul.Tensor also appears + # in dequant code, so we skip checking it to avoid false positives. + ops.append( + QUANT_OPS[self.quant_key] + if self.enable_quant_fp8_custom_op + else torch.ops.aten.reciprocal.default + ) + return ops + + def ops_in_model_after(self): + return [FUSED_OPS[self.quant_key]] + + ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel] CUDA_KERNELS = [ FlashInferFP8ScaledMMLinearKernel, @@ -200,6 +246,19 @@ def ops_in_model_after(self): not current_platform.is_rocm(), reason="ROCm only" ), ), + # Block quant fusion for per-group FP8 (CUDA only). + *[ + pytest.param( + partial(TestSiluMulBlockQuantModel, is_scale_transposed=transposed), + True, + None, + marks=pytest.mark.skipif( + not current_platform.is_cuda(), reason="CUDA only" + ), + id=f"TestSiluMulBlockQuant-transposed={transposed}", + ) + for transposed in [False, True] + ], ], ) @pytest.mark.skipif( @@ -213,6 +272,7 @@ def test_fusion_silu_and_mul_quant( TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel | TestSiluMulGroupFp8QuantModel + | TestSiluMulBlockQuantModel ], enable_silu_mul_custom_op: bool, enable_quant_fp8_custom_op: bool, @@ -223,6 +283,12 @@ def test_fusion_silu_and_mul_quant( pytest.skip("NVFP4 is not supported on this GPU.") if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND: pytest.skip("AITER is not supported on this GPU.") + if ( + isinstance(model_class, partial) + and model_class.func is TestSiluMulBlockQuantModel + and is_deep_gemm_supported() + ): + pytest.skip("SiluMul+BlockQuant fusion not applicable with DeepGemm") torch.set_default_device("cuda") torch.set_default_dtype(dtype) @@ -269,11 +335,13 @@ def test_fusion_silu_and_mul_quant( result2 = model2(x) # Check that it gives the same answer - if model_class == TestSiluMulFp8QuantModel: + if isinstance(model, TestSiluMulFp8QuantModel): atol, rtol = 1e-3, 1e-3 - elif model_class == TestSiluMulNvfp4QuantModel: + elif isinstance(model, TestSiluMulNvfp4QuantModel): atol, rtol = 1e-1, 1e-1 - elif model_class == TestSiluMulGroupFp8QuantModel: + elif isinstance( + model, (TestSiluMulGroupFp8QuantModel, TestSiluMulBlockQuantModel) + ): atol, rtol = 5e-2, 5e-2 torch.testing.assert_close( diff --git a/tests/kernels/core/test_fused_silu_mul_block_quant.py b/tests/kernels/core/test_fused_silu_mul_block_quant.py new file mode 100644 index 000000000000..1878390ac2f2 --- /dev/null +++ b/tests/kernels/core/test_fused_silu_mul_block_quant.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn.functional as F + +import vllm._custom_ops as ops +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, +) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, +) +from vllm.platforms import current_platform + +DTYPES = [torch.float16, torch.bfloat16] +QUANT_DTYPES = [torch.float8_e4m3fn, torch.int8] +VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029] +NUM_TOKENS_HIDDEN_SIZES = [ + *[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]], + *[(16, i) for i in [64, *VEC_HIDDEN_SIZES, 5120]], + *[(128, i) for i in [64, *VEC_HIDDEN_SIZES]], + *[(512, i) for i in [64, 5120]], +] +SCALE_UBS = [False] +GROUP_SIZES = [64, 128] +IS_SCALE_TRANSPOSED = [False, True] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2) +] + + +def ref_silu_and_mul_per_block_quant( + x: torch.Tensor, + quant_dtype: torch.dtype, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Reference implementation: unfused SiLU+Mul then group quantization.""" + hidden = x.shape[-1] // 2 + gate, up = x.split(hidden, dim=-1) + silu_out = F.silu(gate) * up + + if quant_dtype == current_platform.fp8_dtype(): + return per_token_group_quant_fp8( + silu_out, group_size=group_size, use_ue8m0=False + ) + elif quant_dtype == torch.int8: + return per_token_group_quant_int8(silu_out, group_size=group_size) + else: + raise ValueError(f"Unsupported quant_dtype: {quant_dtype}") + + +@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) +@pytest.mark.parametrize("has_scale_ub", SCALE_UBS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("group_size", GROUP_SIZES) +@pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_silu_and_mul_per_block_quant( + default_vllm_config, + num_tokens: int, + hidden_size: int, + has_scale_ub: bool, + dtype: torch.dtype, + quant_dtype: torch.dtype, + group_size: int, + is_scale_transposed: bool, + seed: int, + device: str, +) -> None: + """Test SiLU+Mul+Block Quantization kernel correctness.""" + torch.random.manual_seed(seed) + torch.set_default_device(device) + + if hidden_size % group_size != 0: + return + + if has_scale_ub: + pytest.skip("Scale upper bound not yet supported") + + scale = 1 / hidden_size + x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device) * scale + + # Reference implementation + ref_out, ref_scales = ref_silu_and_mul_per_block_quant(x, quant_dtype, group_size) + + # Fused kernel implementation + ops_out, ops_scales = ops.silu_and_mul_per_block_quant( + x, group_size, quant_dtype, None, is_scale_transposed + ) + + # Check for NaN/Inf + assert not torch.isnan(ops_out.float()).any(), "Kernel output contains NaN" + assert not torch.isinf(ops_out.float()).any(), "Kernel output contains Inf" + assert not torch.isnan(ops_scales).any(), "Kernel scales contain NaN" + assert not torch.isinf(ops_scales).any(), "Kernel scales contain Inf" + + # Check dtypes + assert ref_out.dtype == quant_dtype + assert ops_out.dtype == quant_dtype + + # Check scales match + torch.testing.assert_close(ref_scales, ops_scales, rtol=1e-5, atol=1e-5) + + # Check output correctness via dequantized values + ref_scales_expanded = ref_scales.repeat_interleave(group_size, dim=1) + ops_scales_expanded = ops_scales.repeat_interleave(group_size, dim=1) + ref_deq = ref_out.to(dtype=torch.float32) * ref_scales_expanded + ops_deq = ops_out.to(dtype=torch.float32) * ops_scales_expanded + torch.testing.assert_close(ref_deq, ops_deq, atol=5e-2, rtol=5e-2) + + # opcheck + output = torch.empty(num_tokens, hidden_size, device=device, dtype=quant_dtype) + num_groups = hidden_size // group_size + if is_scale_transposed: + scales = torch.empty(num_groups, num_tokens, device=device, dtype=torch.float32) + else: + scales = torch.empty(num_tokens, num_groups, device=device, dtype=torch.float32) + opcheck( + torch.ops._C.silu_and_mul_per_block_quant, + (output, x, scales, group_size, None, is_scale_transposed), + ) + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("num_tokens", [128]) +@pytest.mark.parametrize("group_size", [128]) +def test_silu_block_quant_shapes( + default_vllm_config, + dtype: torch.dtype, + hidden_size: int, + num_tokens: int, + group_size: int, +): + """Test that output shapes are correct.""" + torch.set_default_device("cuda") + x = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device="cuda") + + # Row-major scales + out, scales = ops.silu_and_mul_per_block_quant( + x, + group_size=group_size, + quant_dtype=torch.float8_e4m3fn, + is_scale_transposed=False, + ) + assert out.shape == (num_tokens, hidden_size) + assert scales.shape == (num_tokens, hidden_size // group_size) + + # Column-major scales (logical shape same after .t() in _custom_ops) + out, scales = ops.silu_and_mul_per_block_quant( + x, + group_size=group_size, + quant_dtype=torch.float8_e4m3fn, + is_scale_transposed=True, + ) + assert out.shape == (num_tokens, hidden_size) + assert scales.shape == (num_tokens, hidden_size // group_size) + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("batch_size", [1, 16, 256]) +@pytest.mark.parametrize("hidden_size", [1024, 5120, 14336]) +def test_silu_block_quant_edge_cases( + default_vllm_config, dtype: torch.dtype, batch_size: int, hidden_size: int +): + """Test edge cases: single token, large batch, large hidden size.""" + torch.set_default_device("cuda") + x = torch.randn(batch_size, hidden_size * 2, dtype=dtype, device="cuda") + + out, scales = ops.silu_and_mul_per_block_quant( + x, + group_size=128, + quant_dtype=torch.float8_e4m3fn, + is_scale_transposed=False, + ) + + assert out.shape == (batch_size, hidden_size) + assert out.dtype == torch.float8_e4m3fn + assert scales.dtype == torch.float32 + assert not torch.isnan(out.float()).any() + assert not torch.isnan(scales).any() + assert not torch.isinf(scales).any() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ea54aaa95a8b..de8e29756fc0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -572,6 +572,56 @@ def rms_norm_per_block_quant( return output, scales +# fused silu_and_mul + block quant +def silu_and_mul_per_block_quant( + input: torch.Tensor, + group_size: int, # Changed from list[int] + quant_dtype: torch.dtype, + scale_ub: torch.Tensor | None = None, + is_scale_transposed: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + assert input.ndim == 2, f"input must be 2D [batch, hidden*2], got {input.shape}" + assert input.shape[-1] % 2 == 0, ( + f"input last dim must be even (gate||up layout), got {input.shape[-1]}" + ) + + # Output is half the width of input (after silu_and_mul) + num_tokens = input.shape[0] + hidden_size = input.shape[-1] // 2 # Divide by 2 because input is [gate || up] + + # Allocate output tensor (FP8 or INT8) + output = torch.empty( + (num_tokens, hidden_size), device=input.device, dtype=quant_dtype + ) + + # Allocate scales tensor + num_groups = hidden_size // group_size # Directly use group_size + if is_scale_transposed: + scales = torch.empty( + (num_groups, num_tokens), + device=input.device, + dtype=torch.float32, + ).t() + else: + scales = torch.empty( + (num_tokens, num_groups), + device=input.device, + dtype=torch.float32, + ) + + # Call the C++ kernel + torch.ops._C.silu_and_mul_per_block_quant( + output, + input, + scales, + group_size, # Pass directly as int + scale_ub, + is_scale_transposed, + ) + + return output, scales + + # quantization ops # awq def awq_dequantize( diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index 911775f69967..2a1d37a1dae7 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -17,6 +17,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, + kFp8Dynamic64Sym, + kFp8Dynamic128Sym, kFp8StaticTensorSym, kNvfp4Dynamic, ) @@ -43,6 +45,10 @@ if silu_and_mul_nvfp4_quant_supported: FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 +if current_platform.is_cuda(): + FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default + FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default + class ActivationQuantPattern(ABC): """ @@ -174,6 +180,102 @@ def replacement( register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass) +class SiluMulBlockQuantPattern(ActivationQuantPattern): + """ + Fusion for SiluMul+BlockQuant (FP8 dynamic per-group) Pattern. + Supports group_size 128 and 64 via QuantKey. + Parameterized on is_scale_transposed for different scale layouts. + """ + + def __init__( + self, + quant_key: QuantKey, + is_scale_transposed: bool = False, + is_e8m0: bool = False, + is_tma_aligned: bool = False, + ) -> None: + super().__init__(quant_key) + self.quant_matcher = MatcherQuantFP8( + quant_key, + has_col_major_scales=is_scale_transposed, + is_e8m0=is_e8m0, + is_tma_aligned=is_tma_aligned, + ) + self.group_size = quant_key.scale.group_shape[1] + self.is_scale_transposed = is_scale_transposed + self.is_e8m0 = is_e8m0 + self.is_tma_aligned = is_tma_aligned + + def get_inputs(self) -> list[torch.Tensor]: + scale = self.quant_matcher.empty_f32(1, 1) + return self.silu_and_mul_matcher.inputs() + [scale] + + def register(self, pm_pass: PatternMatcherPass) -> None: + is_scale_transposed = self.is_scale_transposed + + def pattern( + input: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + silu_out = self.silu_and_mul_matcher(input) + result = torch.empty( + silu_out.shape, + device=silu_out.device, + dtype=self.quant_dtype, + ) + assert scale is not None + finfo = torch.finfo(self.quant_dtype) + _, result, scale = auto_functionalized( + self.quant_matcher.QUANT_OP, + input=silu_out, + output_q=result, + output_s=scale, + group_size=self.group_size, + eps=1e-10, + fp8_min=finfo.min, + fp8_max=finfo.max, + scale_ue8m0=self.is_e8m0, + dummy_is_scale_transposed=is_scale_transposed, + dummy_is_tma_aligned=self.is_tma_aligned, + ) + return result, scale + + def replacement( + input: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + d = input.shape[-1] // 2 + output_shape = input.shape[:-1] + (d,) + result = torch.empty( + output_shape, device=input.device, dtype=self.quant_dtype + ) + if is_scale_transposed: + scale = torch.empty( + (d // self.group_size, input.shape[0]), + device=input.device, + dtype=torch.float32, + ).permute(-1, -2) + else: + scale = torch.empty( + (input.shape[0], d // self.group_size), + device=input.device, + dtype=torch.float32, + ) + at = auto_functionalized( + self.FUSED_OP, + out=result, + input=input, + scales=scale, + group_size=self.group_size, + scale_ub=None, + is_scale_transposed=is_scale_transposed, + ) + return at[1], at[2] + + inps = self.get_inputs() + register_replacement(pattern, replacement, inps, fwd_only, pm_pass) + + class ActivationQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses a pre-defined set of custom ops into fused ops. @@ -199,6 +301,18 @@ def __init__(self, config: VllmConfig) -> None: pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() pattern_silu_mul_nvfp4.register(self.patterns) + if current_platform.is_cuda(): + for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]: + for is_scale_transposed in [False, True]: + for is_e8m0 in [True, False]: + for is_tma_aligned in [False, True]: + SiluMulBlockQuantPattern( + quant_key, + is_scale_transposed=is_scale_transposed, + is_e8m0=is_e8m0, + is_tma_aligned=is_tma_aligned, + ).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log @@ -212,4 +326,5 @@ def uuid(self) -> str: ActivationQuantPattern, SiluMulFp8StaticQuantPattern, SiluMulNvfp4QuantPattern, + SiluMulBlockQuantPattern, )