diff --git a/CMakeLists.txt b/CMakeLists.txt index dd6ebce34be0..f06adc76c24c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -299,6 +299,7 @@ set(VLLM_EXT_SRC "csrc/quantization/w8a8/int8/scaled_quant.cu" "csrc/quantization/w8a8/fp8/common.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" + "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" @@ -340,8 +341,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/awq/gemm_kernels.cu" - "csrc/cutlass_extensions/common.cpp" - "csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu") + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" diff --git a/csrc/ops.h b/csrc/ops.h index 20351a3e4dc0..283e8a885197 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -142,13 +142,11 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, std::optional residual, int64_t group_size, bool is_scale_transposed); -#ifndef USE_ROCM void silu_and_mul_per_block_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales, int64_t group_size, std::optional scale_ub, bool is_scale_transposed); -#endif void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index 3711c47edc8c..fc60643777e0 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -6,7 +6,7 @@ #include "libtorch_stable/quantization/vectorization.cuh" // TODO(luka/varun):refactor common.cuh to use this file instead -#include "quantization/w8a8/fp8/common.cuh" +#include "../w8a8/fp8/common.cuh" namespace vllm { diff --git a/csrc/quantization/w8a8/fp8/common.cuh 
b/csrc/quantization/w8a8/fp8/common.cuh index 7a385f5163ae..7576f7179501 100644 --- a/csrc/quantization/w8a8/fp8/common.cuh +++ b/csrc/quantization/w8a8/fp8/common.cuh @@ -1,7 +1,7 @@ #pragma once #include "libtorch_stable/quantization/vectorization.cuh" -#include "quantization/utils.cuh" +#include "../../utils.cuh" #include diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 0354df666c3a..80d83d4c375d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -109,6 +109,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); + // Fused SiLU+Mul + per-block quantization + ops.def( + "silu_and_mul_per_block_quant(" + "Tensor! out, " + "Tensor input, " + "Tensor! scales, " + "int group_size, " + "Tensor? scale_ub=None, " + "bool is_scale_transposed=False) -> ()"); + ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, + &silu_and_mul_per_block_quant); + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -232,17 +244,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization ops #ifndef USE_ROCM - // Fused SiLU+Mul + per-block quantization - ops.def( - "silu_and_mul_per_block_quant(" - "Tensor! out, " - "Tensor input, " - "Tensor! scales, " - "int group_size, " - "Tensor? scale_ub=None, " - "bool is_scale_transposed=False) -> ()"); - ops.impl("silu_and_mul_per_block_quant", torch::kCUDA, - &silu_and_mul_per_block_quant); // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). ops.def( "dsv3_fused_a_gemm(Tensor! 
output, Tensor mat_a, Tensor mat_b) -> ()"); diff --git a/tests/kernels/core/test_fused_silu_mul_block_quant.py b/tests/kernels/core/test_fused_silu_mul_block_quant.py index 1878390ac2f2..37b76056cc21 100644 --- a/tests/kernels/core/test_fused_silu_mul_block_quant.py +++ b/tests/kernels/core/test_fused_silu_mul_block_quant.py @@ -16,7 +16,7 @@ from vllm.platforms import current_platform DTYPES = [torch.float16, torch.bfloat16] -QUANT_DTYPES = [torch.float8_e4m3fn, torch.int8] +QUANT_DTYPES = [current_platform.fp8_dtype(), torch.int8] VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029] NUM_TOKENS_HIDDEN_SIZES = [ *[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]], @@ -28,9 +28,7 @@ GROUP_SIZES = [64, 128] IS_SCALE_TRANSPOSED = [False, True] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2) -] +CUDA_DEVICES = [i for i in range(1 if torch.accelerator.device_count() == 1 else 2)] def ref_silu_and_mul_per_block_quant( @@ -60,7 +58,7 @@ def ref_silu_and_mul_per_block_quant( @pytest.mark.parametrize("group_size", GROUP_SIZES) @pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device_idx", CUDA_DEVICES) @torch.inference_mode() def test_silu_and_mul_per_block_quant( default_vllm_config, @@ -72,9 +70,11 @@ def test_silu_and_mul_per_block_quant( group_size: int, is_scale_transposed: bool, seed: int, - device: str, + device_idx: int, ) -> None: """Test SiLU+Mul+Block Quantization kernel correctness.""" + torch.accelerator.set_device_index(device_idx) + device = f"cuda:{device_idx}" torch.random.manual_seed(seed) torch.set_default_device(device) @@ -147,7 +147,7 @@ def test_silu_block_quant_shapes( out, scales = ops.silu_and_mul_per_block_quant( x, group_size=group_size, - quant_dtype=torch.float8_e4m3fn, + quant_dtype=current_platform.fp8_dtype(), is_scale_transposed=False, ) assert
out.shape == (num_tokens, hidden_size) @@ -157,7 +157,7 @@ def test_silu_block_quant_shapes( out, scales = ops.silu_and_mul_per_block_quant( x, group_size=group_size, - quant_dtype=torch.float8_e4m3fn, + quant_dtype=current_platform.fp8_dtype(), is_scale_transposed=True, ) assert out.shape == (num_tokens, hidden_size) @@ -177,12 +177,12 @@ def test_silu_block_quant_edge_cases( out, scales = ops.silu_and_mul_per_block_quant( x, group_size=128, - quant_dtype=torch.float8_e4m3fn, + quant_dtype=current_platform.fp8_dtype(), is_scale_transposed=False, ) assert out.shape == (batch_size, hidden_size) - assert out.dtype == torch.float8_e4m3fn + assert out.dtype == current_platform.fp8_dtype() assert scales.dtype == torch.float32 assert not torch.isnan(out.float()).any() assert not torch.isnan(scales).any() diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index 2a1d37a1dae7..a712c013ce99 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -45,7 +45,7 @@ if silu_and_mul_nvfp4_quant_supported: FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -if current_platform.is_cuda(): +if current_platform.is_cuda_alike(): FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default @@ -301,7 +301,7 @@ def __init__(self, config: VllmConfig) -> None: pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern() pattern_silu_mul_nvfp4.register(self.patterns) - if current_platform.is_cuda(): + if current_platform.is_cuda_alike(): for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]: for is_scale_transposed in [False, True]: for is_e8m0 in [True, False]: