2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -796,7 +796,7 @@ steps:
# Quantization
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
- pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
- pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
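For local runs, the renamed kernel test can be invoked with the same command the pipeline uses (shown here as a convenience; extra pytest flags are optional):

pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py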
120 changes: 27 additions & 93 deletions csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
@@ -30,109 +30,41 @@

namespace vllm {

template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
PackedVec<Type>& vec2) {
PackedVec<Type> result;
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
if constexpr (std::is_same_v<Type, half>) {
half2 val(0.5f, 0.5f);
half2 t0 = __hmul2(vec.elts[i], val);
half2 t1 = __hfma2(h2tanh(t0), val, val);
half2 t2 = __hmul2(vec.elts[i], t1);
result.elts[i] = __hmul2(t2, vec2.elts[i]);
} else {
__nv_bfloat162 val(0.5f, 0.5f);
__nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
__nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
__nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
result.elts[i] = __hmul2(t2, vec2.elts[i]);
}
}
return result;
// silu in float32
__device__ __forceinline__ float silu(float x) {
return __fdividef(x, (1.f + __expf(-x)));
}

// Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false>
__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
PackedVec<Type>& vec2,
float SFScaleVal,
uint8_t* SFout) {
PackedVec<Type> out_silu = compute_silu(vec, vec2);
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(out_silu.elts[0]);

// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
}

// Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));

// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
if constexpr (UE8M0_SF) {
// Extract the 8 exponent bits from float32.
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
fp8SFVal = tmp & 0xff;
// Convert back to fp32.
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
// Convert back to fp32.
SFValue = float(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;

if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}

// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
PackedVec<Type>& vec2) {
PackedVec<Type> result;
using packed_type = typename TypeConverter<Type>::Type;

#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
// silu_mul in float32
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(out_silu.elts[i]);
float2 silu_vec = silu2(__half22float2(vec.elts[i]));
result.elts[i] =
__float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
} else {
fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
result.elts[i] = __float22bfloat162_rn(
__fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}

// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

// Write the e2m1 values to global memory.
return e2m1Vec;
return result;
}

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(1024, 4)
silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out,
uint32_t* SFout) {
using PackedVec = PackedVec<Type>;
@@ -160,16 +92,18 @@ __global__ void __launch_bounds__(1024, 4)
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
auto& out_pos = out[outOffset];

// Compute silu and mul
PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2);

auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numCols, SFout);

out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
in_vec, in_vec2, SFScaleVal, sf_out);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
sf_out);
}
}
}
@@ -204,7 +138,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
m, n, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
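For context, here is a minimal PyTorch sketch of the math the refactored kernel computes: SiLU applied in float32 to the first half of the last dimension, an elementwise multiply by the second half, then the block-wise NVFP4 scale recipe visible in the removed code above. This is an illustrative reference, not the production path: the names ref_silu_mul_nvfp4 and BLOCK are made up for this sketch, it assumes a recent PyTorch with torch.float8_e4m3fn, and it stops at the float values that would be cast to E2M1, since PyTorch has no FP4 dtype.

import torch

FLOAT4_E2M1_MAX = 6.0    # largest magnitude representable in E2M1
FLOAT8_E4M3_MAX = 448.0  # largest magnitude representable in E4M3
BLOCK = 16               # NVFP4 block size: one E4M3 scale per 16 values

def ref_silu_mul_nvfp4(x: torch.Tensor, global_scale: torch.Tensor):
    """Reference for the fused op: silu(x[..., :d]) * x[..., d:] in float32,
    then per-block NVFP4 scaling. Returns the scaled values (to be cast to
    E2M1) and the per-block E4M3 scale factors."""
    d = x.shape[-1] // 2
    gate, up = x[..., :d].float(), x[..., d:].float()
    y = torch.nn.functional.silu(gate) * up              # silu_mul in float32

    blocks = y.reshape(*y.shape[:-1], d // BLOCK, BLOCK)
    vec_max = blocks.abs().amax(dim=-1, keepdim=True)

    # Scale-factor recipe from the kernel: SF = global_scale * vec_max / 6.0,
    # rounded through E4M3; each element is then multiplied by
    # global_scale / SF before the E2M1 cast.
    sf = (global_scale * vec_max / FLOAT4_E2M1_MAX).to(torch.float8_e4m3fn)
    sf_f32 = sf.float()
    out_scale = torch.where(sf_f32 != 0, global_scale / sf_f32,
                            torch.zeros_like(sf_f32))
    scaled = blocks * out_scale                           # values to cast to E2M1
    return scaled.reshape_as(y), sf.squeeze(-1)

# Global scale as used by the tests below:
#   global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / abs(silu_mul_out).max()

The new test file below checks the same contract end to end by comparing torch.ops._C.silu_and_mul_nvfp4_quant against SiluAndMul().forward_native followed by scaled_fp4_quant.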
13 changes: 7 additions & 6 deletions tests/compile/test_silu_mul_quant_fusion.py
@@ -98,8 +98,9 @@ def ops_in_model_after(self):
return [FUSED_OPS[kNvfp4Quant]]


@pytest.mark.parametrize("num_tokens", [64])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"model_class",
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
@@ -110,13 +111,13 @@ def ops_in_model_after(self):
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
cuda_force_torch):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4")

torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
torch.set_default_dtype(dtype)

x = torch.rand(num_tokens, hidden_size * 2)

@@ -145,8 +146,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1

torch.testing.assert_close(result[0].to(dtype=torch.float16),
result2[0].to(dtype=torch.float16),
torch.testing.assert_close(result[0].to(dtype=dtype),
result2[0].to(dtype=dtype),
atol=atol,
rtol=rtol)

75 changes: 75 additions & 0 deletions tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm._custom_ops import scaled_fp4_quant
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)

FP4_DTYPE = torch.uint8
FP8_DTYPE = current_platform.fp8_dtype()

DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)]
BLOCK_SIZE = 16


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_silu_mul_nvfp4_quant(
dtype: torch.dtype,
shape: tuple[int, int],
) -> None:
current_platform.seed_everything(42)
device = 'cuda:0'
torch.set_default_device(device)

x = torch.randn(shape, dtype=dtype)

# ref op
ref_output = SiluAndMul().forward_native(x)
ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(ref_output).max().to(torch.float32))
ref_output_quant, ref_block_scale = scaled_fp4_quant(
ref_output, ref_global_scale)

# fused op
fused_output_quant = torch.empty_like(ref_output_quant)
fused_block_scale = torch.empty_like(ref_block_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
fused_block_scale, x,
ref_global_scale)

# check dtype
assert ref_output_quant.dtype == FP4_DTYPE
assert fused_output_quant.dtype == FP4_DTYPE
assert ref_output_quant.shape == fused_output_quant.shape

assert ref_block_scale.dtype == FP8_DTYPE
assert fused_block_scale.dtype == FP8_DTYPE
assert ref_block_scale.shape == fused_block_scale.shape

# check dequantized output
ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
ref_block_scale,
ref_global_scale, dtype,
device)
fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
fused_block_scale,
ref_global_scale, dtype,
device)

atol, rtol = 3e-1, 3e-1
torch.testing.assert_close(ref_output_dequant,
fused_output_dequant,
atol=atol,
rtol=rtol)