Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/w8a8/int8/scaled_quant.cu"
"csrc/quantization/w8a8/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
Expand Down Expand Up @@ -340,8 +341,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu")
"csrc/cutlass_extensions/common.cpp")

set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
Expand Down
2 changes: 0 additions & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,11 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
std::optional<torch::Tensor> residual,
int64_t group_size, bool is_scale_transposed);

#ifndef USE_ROCM
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed);
#endif

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
Expand Down
2 changes: 1 addition & 1 deletion csrc/quantization/fused_kernels/quant_conversions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "libtorch_stable/quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/w8a8/fp8/common.cuh"
#include "../w8a8/fp8/common.cuh"

namespace vllm {

Expand Down
2 changes: 1 addition & 1 deletion csrc/quantization/w8a8/fp8/common.cuh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "libtorch_stable/quantization/vectorization.cuh"
#include "quantization/utils.cuh"
#include "../../utils.cuh"

#include <cmath>

Expand Down
23 changes: 12 additions & 11 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);

ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

Expand Down Expand Up @@ -232,17 +244,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Quantization ops
#ifndef USE_ROCM
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
// DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
ops.def(
"dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
Expand Down
20 changes: 10 additions & 10 deletions tests/kernels/core/test_fused_silu_mul_block_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from vllm.platforms import current_platform

DTYPES = [torch.float16, torch.bfloat16]
QUANT_DTYPES = [torch.float8_e4m3fn, torch.int8]
QUANT_DTYPES = [current_platform.fp8_dtype(), torch.int8]
VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
NUM_TOKENS_HIDDEN_SIZES = [
*[(1, i) for i in [64, *VEC_HIDDEN_SIZES, 2048, 5120]],
Expand All @@ -28,9 +28,7 @@
GROUP_SIZES = [64, 128]
IS_SCALE_TRANSPOSED = [False, True]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
CUDA_DEVICES = [i for i in range(1 if torch.accelerator.device_count() == 1 else 2)]


def ref_silu_and_mul_per_block_quant(
Expand Down Expand Up @@ -60,7 +58,7 @@ def ref_silu_and_mul_per_block_quant(
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.parametrize("is_scale_transposed", IS_SCALE_TRANSPOSED)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device_idx", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul_per_block_quant(
default_vllm_config,
Expand All @@ -72,9 +70,11 @@ def test_silu_and_mul_per_block_quant(
group_size: int,
is_scale_transposed: bool,
seed: int,
device: str,
device_idx: str,
) -> None:
"""Test SiLU+Mul+Block Quantization kernel correctness."""
torch.accelerator.set_device_index(device_idx)
device = f"cuda:{device_idx}"
torch.random.manual_seed(seed)
torch.set_default_device(device)

Expand Down Expand Up @@ -147,7 +147,7 @@ def test_silu_block_quant_shapes(
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=torch.float8_e4m3fn,
quant_dtype=current_platform.fp8_dtype(),
is_scale_transposed=False,
)
assert out.shape == (num_tokens, hidden_size)
Expand All @@ -157,7 +157,7 @@ def test_silu_block_quant_shapes(
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=group_size,
quant_dtype=torch.float8_e4m3fn,
quant_dtype=current_platform.fp8_dtype(),
is_scale_transposed=True,
)
assert out.shape == (num_tokens, hidden_size)
Expand All @@ -177,12 +177,12 @@ def test_silu_block_quant_edge_cases(
out, scales = ops.silu_and_mul_per_block_quant(
x,
group_size=128,
quant_dtype=torch.float8_e4m3fn,
quant_dtype=current_platform.fp8_dtype(),
is_scale_transposed=False,
)

assert out.shape == (batch_size, hidden_size)
assert out.dtype == torch.float8_e4m3fn
assert out.dtype == current_platform.fp8_dtype()
assert scales.dtype == torch.float32
assert not torch.isnan(out.float()).any()
assert not torch.isnan(scales).any()
Expand Down
4 changes: 2 additions & 2 deletions vllm/compilation/passes/fusion/act_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501

if current_platform.is_cuda():
if current_platform.is_cuda_alike():
FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default

Expand Down Expand Up @@ -301,7 +301,7 @@ def __init__(self, config: VllmConfig) -> None:
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
pattern_silu_mul_nvfp4.register(self.patterns)

if current_platform.is_cuda():
if current_platform.is_cuda_alike():
for quant_key in [kFp8Dynamic128Sym, kFp8Dynamic64Sym]:
for is_scale_transposed in [False, True]:
for is_e8m0 in [True, False]:
Expand Down
Loading