From d8b0b12a41243e525d81aa237026776f161f2842 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Mon, 23 Feb 2026 14:47:20 +0000 Subject: [PATCH 1/6] add helper vec instructions Signed-off-by: LopezCastroRoberto --- csrc/activation_kernels.cu | 220 ++--- csrc/cuda_vec_utils.cuh | 316 +++++++ .../activation_nvfp4_quant_fusion_kernels.cu | 287 +++--- csrc/quantization/fp4/nvfp4_experts_quant.cu | 857 +++++++++--------- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 451 ++++----- csrc/quantization/fp4/nvfp4_utils.cuh | 693 ++++++-------- 6 files changed, 1460 insertions(+), 1364 deletions(-) create mode 100644 csrc/cuda_vec_utils.cuh diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index f1d4c137ccd1..97894e80f4c4 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -5,115 +5,11 @@ #include #include "cuda_compat.h" +#include "cuda_vec_utils.cuh" #include "dispatch_utils.h" namespace vllm { -struct alignas(32) u32x8_t { - uint32_t u0, u1, u2, u3, u4, u5, u6, u7; -}; - -__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 - asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" - : "=r"(val.u0), "=r"(val.u1), "=r"(val.u2), "=r"(val.u3), - "=r"(val.u4), "=r"(val.u5), "=r"(val.u6), "=r"(val.u7) - : "l"(ptr)); -#else - const uint4* uint_ptr = reinterpret_cast(ptr); - uint4 top_half = __ldg(&uint_ptr[0]); - uint4 bottom_half = __ldg(&uint_ptr[1]); - val.u0 = top_half.x; - val.u1 = top_half.y; - val.u2 = top_half.z; - val.u3 = top_half.w; - val.u4 = bottom_half.x; - val.u5 = bottom_half.y; - val.u6 = bottom_half.z; - val.u7 = bottom_half.w; -#endif -} - -__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 - asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" - : - : "l"(ptr), "r"(val.u0), "r"(val.u1), "r"(val.u2), "r"(val.u3), - "r"(val.u4), "r"(val.u5), 
"r"(val.u6), "r"(val.u7) - : "memory"); -#else - uint4* uint_ptr = reinterpret_cast(ptr); - uint_ptr[0] = make_uint4(val.u0, val.u1, val.u2, val.u3); - uint_ptr[1] = make_uint4(val.u4, val.u5, val.u6, val.u7); -#endif -} - -template -struct VecTraits; - -template <> -struct VecTraits { - static constexpr int ARCH_MAX_VEC_SIZE = 32; - using vec_t = u32x8_t; -}; - -template <> -struct VecTraits { - static constexpr int ARCH_MAX_VEC_SIZE = 16; - using vec_t = int4; -}; - -template -struct PackedTraits; - -template <> -struct PackedTraits { - using packed_t = __nv_bfloat162; -}; - -template <> -struct PackedTraits { - using packed_t = __half2; -}; - -template <> -struct PackedTraits { - using packed_t = float2; -}; - -template -__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) { - if constexpr (std::is_same_v) { - return __bfloat1622float2(val); - } else if constexpr (std::is_same_v) { - return __half22float2(val); - } else if constexpr (std::is_same_v) { - return float2(val); - } -} - -template -__device__ __forceinline__ packed_t cast_to_packed(const float2& val) { - if constexpr (std::is_same_v) { - return __float22bfloat162_rn(val); - } else if constexpr (std::is_same_v) { - return __float22half2_rn(val); - } else if constexpr (std::is_same_v) { - return float2(val); - } -} - -template -__device__ __forceinline__ packed_t packed_mul(const packed_t& x, - const packed_t& y) { - if constexpr (std::is_same_v || - std::is_same_v) { - return __hmul2(x, y); - } else if constexpr (std::is_same_v) { - return make_float2(x.x * y.x, x.y * y.y); - } -} - template __device__ __forceinline__ scalar_t compute(const scalar_t& x, @@ -129,16 +25,6 @@ __device__ __forceinline__ packed_t packed_compute(const packed_t& x, : packed_mul(x, PACKED_ACT_FN(y)); } -// Check if all pointers are 16-byte aligned for int4 vectorized access -__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { - return (reinterpret_cast(ptr) & 15) == 0; -} - -// Check 
if all pointers are 16-byte aligned for longlong4_32a vectorized access -__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) { - return (reinterpret_cast(ptr) & 31) == 0; -} - // Activation and gating kernel template. template ::vec_t; - constexpr int ARCH_MAX_VEC_SIZE = VecTraits::ARCH_MAX_VEC_SIZE; - constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t); + using cuda_t = typename CUDATypeConverter::Type; + using pvec_t = PackedVec; - const vec_t* x_vec = reinterpret_cast(x_ptr); - const vec_t* y_vec = reinterpret_cast(y_ptr); - vec_t* out_vec = reinterpret_cast(out_ptr); - const int num_vecs = d / 2 / VEC_SIZE; + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { - vec_t x, y; + pvec_t x, y; if constexpr (use_256b) { ld256(x, &x_vec[i]); ld256(y, &y_vec[i]); } else { - x = VLLM_LDG(&x_vec[i]); - y = VLLM_LDG(&y_vec[i]); + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); } - auto* xp = reinterpret_cast(&x); - auto* yp = reinterpret_cast(&y); #pragma unroll - for (int j = 0; j < VEC_SIZE; j++) { - xp[j] = - packed_compute(xp[j], yp[j]); + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = + packed_compute(x.elts[j], + y.elts[j]); } if constexpr (use_256b) { st256(x, &out_vec[i]); } else { - out_vec[i] = x; + st128(x, &out_vec[i]); } } } else { @@ -279,28 +162,31 @@ packed_gelu_tanh_kernel(const packed_t& val) { } \ dim3 grid(num_tokens); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? 
vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ int vec_size = support_vec / at::elementSize(dtype); \ const bool use_vec = (d % vec_size == 0); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::packed_t>, \ + PACKED_KERNEL::Type>, \ ACT_FIRST, true, true><<>>( \ out.data_ptr(), input.data_ptr(), d); \ }); \ } else { \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::packed_t>, \ + PACKED_KERNEL::Type>, \ ACT_FIRST, true, false><<>>( \ out.data_ptr(), input.data_ptr(), d); \ }); \ @@ -309,9 +195,9 @@ packed_gelu_tanh_kernel(const packed_t& val) { dim3 block(std::min(d, 1024)); \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::packed_t>, \ + PACKED_KERNEL::Type>, \ ACT_FIRST, false><<>>( \ out.data_ptr(), input.data_ptr(), d); \ }); \ @@ -376,35 +262,31 @@ __global__ void act_and_mul_kernel_with_param( scalar_t* out_ptr = out + blockIdx.x * d; if constexpr (use_vec) { - // Fast path: 128-bit/256-bit vectorized loop - using vec_t = typename VecTraits::vec_t; - constexpr int ARCH_MAX_VEC_SIZE = VecTraits::ARCH_MAX_VEC_SIZE; - constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t); + using cuda_t = typename 
CUDATypeConverter::Type; + using pvec_t = PackedVec; - const vec_t* x_vec = reinterpret_cast(x_ptr); - const vec_t* y_vec = reinterpret_cast(y_ptr); - vec_t* out_vec = reinterpret_cast(out_ptr); - const int num_vecs = d / 2 / VEC_SIZE; + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { - vec_t x, y; + pvec_t x, y; if constexpr (use_256b) { ld256(x, &x_vec[i]); ld256(y, &y_vec[i]); } else { - x = VLLM_LDG(&x_vec[i]); - y = VLLM_LDG(&y_vec[i]); + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); } - auto* xp = reinterpret_cast(&x); - auto* yp = reinterpret_cast(&y); #pragma unroll - for (int j = 0; j < VEC_SIZE; j++) { - xp[j] = packed_mul(PACKED_ACT_FN(xp[j], param), yp[j]); + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]); } if constexpr (use_256b) { st256(x, &out_vec[i]); } else { - out_vec[i] = x; + st128(x, &out_vec[i]); } } } else { @@ -497,21 +379,24 @@ __global__ void swigluoai_and_mul_kernel( } \ dim3 grid(num_tokens); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? 
vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ int vec_size = support_vec / at::elementSize(dtype); \ const bool use_vec = (d % vec_size == 0); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ PACKED_KERNEL< \ - typename vllm::PackedTraits::packed_t>, \ + typename vllm::TypeConverter::Type>, \ true, true><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -520,10 +405,10 @@ __global__ void swigluoai_and_mul_kernel( VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ PACKED_KERNEL< \ - typename vllm::PackedTraits::packed_t>, \ + typename vllm::TypeConverter::Type>, \ true, false><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -533,9 +418,9 @@ __global__ void swigluoai_and_mul_kernel( dim3 block(std::min(d, 1024)); \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::packed_t>, \ + PACKED_KERNEL::Type>, \ false><<>>( \ out.data_ptr(), input.data_ptr(), d, PARAM); \ }); \ @@ -627,14 +512,17 @@ __global__ void activation_kernel( } \ dim3 grid(num_tokens); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && 
num_tokens > 128) ? 32 : 16; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ int vec_size = support_vec / at::elementSize(dtype); \ const bool use_vec = (d % vec_size == 0); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ vllm::activation_kernel, true, true> \ <<>>(out.data_ptr(), \ diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh new file mode 100644 index 000000000000..d3d14dbf88ee --- /dev/null +++ b/csrc/cuda_vec_utils.cuh @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +#pragma once + +#include +#include +#include +#include +#include + +// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which +// together enable 256-bit (v8.u32) PTX load/store instructions. +// Use for PTX instruction selection with architecture fallback paths. +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ + defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + #define VLLM_256B_PTX_ENABLED 1 +#else + #define VLLM_256B_PTX_ENABLED 0 +#endif + +#ifndef USE_ROCM + +namespace vllm { + +// ============================================================ +// Types and traits +// ============================================================ + +// 256-bit (32-byte) aligned vector type: 8 x uint32_t +struct alignas(32) u32x8_t { + uint32_t d[8]; +}; + +// VecTraits — select between 128-bit (int4) and 256-bit +// (u32x8_t) vector types at compile time. 
+template +struct VecTraits; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 32; + using vec_t = u32x8_t; +}; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 16; + using vec_t = int4; +}; + +// TypeConverter — map between CUDA scalar and packed types +// half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, etc. +template +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +template <> +struct TypeConverter { + using Type = float2; +}; + +template <> +struct TypeConverter { + using Type = float; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter { + using Type = __nv_bfloat162; +}; + +// CUDATypeConverter — map PyTorch scalar types to CUDA scalar +// c10::Half -> half, c10::BFloat16 -> __nv_bfloat16 +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// PackedVec — typed vector container for packed element access. +// Derives alignment and element count from VecTraits. +// Type is the CUDA scalar type (e.g. half, __nv_bfloat16). +template +struct alignas(VecTraits::ARCH_MAX_VEC_SIZE) PackedVec { + static constexpr int NUM_ELTS = + VecTraits::ARCH_MAX_VEC_SIZE / + sizeof(typename TypeConverter::Type); + typename TypeConverter::Type elts[NUM_ELTS]; +}; + +// ============================================================ +// Load / store primitives +// ============================================================ + +// 256-bit load / store with architecture fallback. 
+// SM100+ : PTX v8 instructions (.nc / default hint) +// Older : two uint4 loads via __ldg +__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { +#if VLLM_256B_PTX_ENABLED + asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(ptr)); +#else + const uint4* src = reinterpret_cast(ptr); + uint4* dst = reinterpret_cast(val.d); + dst[0] = __ldg(&src[0]); + dst[1] = __ldg(&src[1]); +#endif +} + +__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { +#if VLLM_256B_PTX_ENABLED + asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" + : + : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), + "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), + "r"(val.d[7]) + : "memory"); +#else + uint4* dst = reinterpret_cast(ptr); + const uint4* src = reinterpret_cast(val.d); + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +// Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec). +// Non-template overloads above are preferred for u32x8_t. +template +__device__ __forceinline__ void ld256(T& val, const T* ptr) { + static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type"); + ld256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st256(T& val, T* ptr) { + static_assert(sizeof(T) == 32, "st256 requires a 32-byte type"); + st256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +// 128-bit load / store via __ldg (read-only cache hint). 
+template +__device__ __forceinline__ void ld128(T& val, const T* ptr) { + static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type"); + *reinterpret_cast(&val) = + __ldg(reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st128(T& val, T* ptr) { + static_assert(sizeof(T) == 16, "st128 requires a 16-byte type"); + *reinterpret_cast(ptr) = *reinterpret_cast(&val); +} + +// 256-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { + u32x8_t val; + asm volatile( + "ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(addr)); + return val; +} + +__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { + asm volatile( + "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" + ::"l"(addr), + "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), + "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); +} + +// 32-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ int ld32_cs(const int* addr) { + int val; + asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); + return val; +} + +__forceinline__ __device__ void st32_cs(int* addr, int val) { + asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); +} + +// Predicated 256-bit / 128-bit cache-global (.cg) loads. +// Returns zero if pred is false. SM100+ only. 
+__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, + bool pred) { + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %8, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " mov.u32 %4, 0;\n" + " mov.u32 %5, 0;\n" + " mov.u32 %6, 0;\n" + " mov.u32 %7, 0;\n" + " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" + "}\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "r"((int)pred), "l"(ptr)); +} + +__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, + bool pred) { + uint32_t r0, r1, r2, r3; + + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %4, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" + "}\n" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) + : "r"((int)pred), "l"(ptr)); + + val = uint4{r0, r1, r2, r3}; +} + +// ============================================================ +// Alignment helpers +// ============================================================ + +__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 15) == 0; +} + +__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 31) == 0; +} + +// ============================================================ +// Packed type conversion and arithmetic +// ============================================================ + +template +__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) { + if constexpr (std::is_same_v) { + return __bfloat1622float2(val); + } else if constexpr (std::is_same_v) { + return __half22float2(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t cast_to_packed(const 
float2& val) { + if constexpr (std::is_same_v) { + return __float22bfloat162_rn(val); + } else if constexpr (std::is_same_v) { + return __float22half2_rn(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t packed_mul(const packed_t& x, + const packed_t& y) { + if constexpr (std::is_same_v || + std::is_same_v) { + return __hmul2(x, y); + } else if constexpr (std::is_same_v) { + return make_float2(x.x * y.x, x.y * y.y); + } +} + +} // namespace vllm + +#endif // !USE_ROCM diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index d0264c4d154c..3c3d2ec12223 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -14,146 +14,147 @@ * limitations under the License. */ -#include - -#include -#include - -#include -#include - -#include -#include "dispatch_utils.h" - -#include "cuda_utils.h" -#include "launch_bounds_utils.h" - -// Define before including nvfp4_utils.cuh so the header -// can use this macro during compilation. -#define NVFP4_ENABLE_ELTS16 1 -#include "nvfp4_utils.cuh" - -namespace vllm { - -// Use UE4M3 by default. -template -__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, - int32_t num_padded_cols, - Type const* __restrict__ in, - float const* __restrict__ SFScale, - uint32_t* __restrict__ out, - uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). 
- int32_t const numKTiles = (numCols + 63) / 64; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = (SFScale == nullptr) ? 1.0f : SFScale[0]; - - int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; - int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; - - // Input tensor row/col loops. - for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < num_padded_cols) { - PackedVec in_vec; - PackedVec in_vec2; - int64_t inOffset = - rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; - int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + - numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; - - bool valid = (rowIdx < numRows) && (elem_idx < numCols); - if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); - ld256_or_zero_cg_u32( - in_vec2, &reinterpret_cast(in)[inOffset2 * 8], - valid); - } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); - ld128_or_zero_cg_u32( - in_vec2, &reinterpret_cast(in)[inOffset2 * 4], - valid); - } - - // Compute silu and mul - PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx, colIdx, numKTiles, SFout); - - auto out_val = - cvt_warp_fp16_to_fp4( - out_silu_mul, SFScaleVal, sf_out); - - if (valid) { - if constexpr (CVT_FP4_PACK16) { - int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; - uint64_t packed64 = - (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); - reinterpret_cast(out)[outOffset >> 1] = packed64; - } else { - out[inOffset] = out_val; - } - } - } - } -} - -} // namespace vllm - -void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] - torch::Tensor& output_sf, - torch::Tensor& input, // [..., 2 * d] - torch::Tensor& input_sf) { - int32_t m = 
input.size(0); - int32_t n = input.size(1) / 2; - - TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || - input.scalar_type() == at::ScalarType::BFloat16, - "Unsupported input data type for quantize_to_fp4."); - - int multiProcessorCount = - get_device_attribute(cudaDevAttrMultiProcessorCount, -1); - - auto input_sf_ptr = static_cast(input_sf.data_ptr()); - auto sf_out = static_cast(output_sf.data_ptr()); - auto output_ptr = static_cast(output.data_ptr()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); - dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); - int const numBlocksPerSM = - vllm_runtime_blocks_per_sm(static_cast(block.x)); - - int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); - - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); - int grid_x = std::min( - int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); - dim3 grid(grid_x, grid_y); - - VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - vllm::silu_mul_cvt_fp16_to_fp4<<>>( - m, n, sf_n_unpadded, input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); -} + #include + + #include + #include + + #include + #include + + #include + #include "dispatch_utils.h" + + #include "cuda_utils.h" + #include "launch_bounds_utils.h" + + // Define before including nvfp4_utils.cuh so the header + // can use this macro during compilation. + #define NVFP4_ENABLE_ELTS16 1 + #include "nvfp4_utils.cuh" + + namespace vllm { + + // Use UE4M3 by default. 
+ template + __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, + int32_t num_padded_cols, + Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, + uint32_t* __restrict__ SFout) { + using PackedVec = vllm::PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; + + // Input tensor row/col loops. 
+ for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + if (colIdx < num_padded_cols) { + PackedVec in_vec; + PackedVec in_vec2; + int64_t inOffset = + rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; + int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + + numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; + + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_cg_or_zero( + reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], valid); + ld256_cg_or_zero( + reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 8], valid); + } else { + ld128_cg_or_zero( + reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], valid); + ld128_cg_or_zero( + reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 4], valid); + } + + // Compute silu and mul + PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numKTiles, SFout); + + auto out_val = + cvt_warp_fp16_to_fp4( + out_silu_mul, SFScaleVal, sf_out); + + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } + } + } + } + + } // namespace vllm + + void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1) / 2; + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + 
auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + + int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + + int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_x = std::min( + int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::silu_mul_cvt_fp16_to_fp4<<>>( + m, n, sf_n_unpadded, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + } + \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 32685c201102..de93adc5aa91 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -14,431 +14,432 @@ * limitations under the License. */ -#include - -#include -#include - -#include -#include - -#include -#include "dispatch_utils.h" - -#include "cuda_utils.h" -#include "nvfp4_utils.cuh" -#include "launch_bounds_utils.h" - -namespace vllm { - -// NVFP4 quantization kernel for experts (low-latency path). -// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses -// SiLU(gate)*up before quantization. -// Use UE4M3 by default. 
-template -__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, - float const* SFScale, uint32_t* out, uint32_t* SFout, - uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts, - bool low_latency) { - using PackedVec = PackedVec; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). - int32_t const numKTiles = (numCols + 63) / 64; - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; - // When fusing SiLU+Mul, input has gate || up layout (doubled width) - int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; - - // Each global thread processes one element - for (int globalIdx = tid; globalIdx < numRows * colsPerRow; - globalIdx += gridDim.x * blockDim.x) { - // Calculate which row and column this global thread should process - int rowIdx = globalIdx / colsPerRow; - int colIdx = globalIdx % colsPerRow; - - // Find index within the experts using different strategies based on expert - // count - int rowIdx_in_expert = 0; - int expert_idx = 0; - - if constexpr (SMALL_NUM_EXPERTS) { - for (int i = 0; i < n_experts; i++) { - uint32_t current_offset = __ldca(&input_offset_by_experts[i]); - uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]); - if (rowIdx >= current_offset && rowIdx < next_offset) { - rowIdx_in_expert = rowIdx - current_offset; - expert_idx = i; - break; - } - } - } else { - // Load input offsets into registers first, then do the computation. - // Local array size set to 17 because of register limit. 
- uint32_t local_offsets[17]; - for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) { - *reinterpret_cast(local_offsets) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start])); - *reinterpret_cast(local_offsets + 4) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start + 4])); - *reinterpret_cast(local_offsets + 8) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start + 8])); - *reinterpret_cast(local_offsets + 12) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start + 12])); - local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); - -// Check against the 16 loaded offsets -#pragma unroll - for (int i = 0; i < 16; i++) { - if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { - rowIdx_in_expert = rowIdx - local_offsets[i]; - expert_idx = chunk_start + i; - break; - } - } - } - } - - // Load input and optionally apply fused SiLU+Mul - int64_t inOffset = rowIdx * inColsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - PackedVec quant_input; - if constexpr (FUSE_SILU_MUL) { - PackedVec in_vec_up = - reinterpret_cast(in)[inOffset + colsPerRow]; - quant_input = compute_silu_mul(in_vec, in_vec_up); - } else { - quant_input = in_vec; - } - - // Get the output tensor offset. - // Same as inOffset because 8 elements are packed into one uint32_t. - int64_t outOffset = rowIdx * colsPerRow + colIdx; - auto& out_pos = out[outOffset]; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = SFScale == nullptr ? 
1.0f : SFScale[expert_idx]; - - uint32_t* SFout_in_expert = - SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - - out_pos = cvt_warp_fp16_to_fp4( - quant_input, SFScaleVal, sf_out); - } -} - -// NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized -// version). When FUSE_SILU_MUL=true, expects input with gate||up layout and -// fuses SiLU(gate)*up before quantization. -template -__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) - cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, - float const* SFScale, uint32_t* out, uint32_t* SFout, - uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts) { - using PackedVec = PackedVec; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). - int32_t const numKTiles = (numCols + 63) / 64; - - extern __shared__ uint32_t shared_input_offsets[]; - - // Load input offsets into shared memory. - // If n_experts is larger than 4, use vectorized int4 to save instructions. - // If n_experts is smaller than 4, read directly. 
- if constexpr (SMALL_NUM_EXPERTS) { - for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) { - shared_input_offsets[i] = input_offset_by_experts[i]; - } - } else { - for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) { - *reinterpret_cast(&shared_input_offsets[i]) = - *reinterpret_cast(&input_offset_by_experts[i]); - } - if (threadIdx.x == 0) { - shared_input_offsets[n_experts] = input_offset_by_experts[n_experts]; - } - } - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; - // When fusing SiLU+Mul, input has gate || up layout (doubled width) - int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; - - // Each global thread processes one element - for (int globalIdx = tid; globalIdx < numRows * colsPerRow; - globalIdx += gridDim.x * blockDim.x) { - // Calculate which row and column this global thread should process - int rowIdx = globalIdx / colsPerRow; - int colIdx = globalIdx % colsPerRow; - - // Find expert using binary search for better performance with large m_topk - int rowIdx_in_expert = 0; - int expert_idx = 0; - - // Binary search through experts using shared memory - int left = 0, right = n_experts - 1; - while (left <= right) { - int mid = (left + right) / 2; - // Get offsets: shared_input_offsets[i] corresponds to - // input_offset_by_experts[i] - uint32_t mid_offset = shared_input_offsets[mid]; - uint32_t next_offset = shared_input_offsets[mid + 1]; - - if (rowIdx >= mid_offset && rowIdx < next_offset) { - rowIdx_in_expert = rowIdx - mid_offset; - expert_idx = mid; - break; - } else if (rowIdx < mid_offset) { - right = mid - 1; - } else { - left = mid + 1; - } - } - - // Load input and optionally apply fused SiLU+Mul - int64_t inOffset = rowIdx * inColsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - PackedVec quant_input; - if constexpr (FUSE_SILU_MUL) { - PackedVec in_vec_up = - reinterpret_cast(in)[inOffset + 
colsPerRow]; - quant_input = compute_silu_mul(in_vec, in_vec_up); - } else { - quant_input = in_vec; - } - - int64_t outOffset = rowIdx * colsPerRow + colIdx; - auto& out_pos = out[outOffset]; - - float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; - - uint32_t* SFout_in_expert = - SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - - out_pos = cvt_warp_fp16_to_fp4( - quant_input, SFScaleVal, sf_out); - } -} - -template -void quant_impl(void* output, void* output_scale, void* input, - void* input_global_scale, void* input_offset_by_experts, - void* output_scale_offset_by_experts, int m_topk, int k, - int n_experts, cudaStream_t stream) { - int multiProcessorCount = - get_device_attribute(cudaDevAttrMultiProcessorCount, -1); - - // Grid, Block size. - // Each thread converts 8 values. - int const workSizePerRow = k / ELTS_PER_THREAD; - int const totalWorkSize = m_topk * workSizePerRow; - dim3 block(std::min(workSizePerRow, 512)); - // Get number of blocks per SM - int const numBlocksPerSM = - vllm_runtime_blocks_per_sm(static_cast(block.x)); - dim3 grid(std::min(static_cast((totalWorkSize + block.x - 1) / block.x), - multiProcessorCount * numBlocksPerSM)); - while (grid.x <= multiProcessorCount && block.x > 64) { - grid.x *= 2; - block.x = (block.x + 1) / 2; - } - - int const blockRepeat = - (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); - if (blockRepeat > 1) { - size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); - if (n_experts >= 4) { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts); - } else { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - 
reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts); - } - } else { - if (n_experts >= 16) { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts, /* bool low_latency */ true); - } else { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts, /* bool low_latency */ true); - } - } -} - -} // namespace vllm - -/*Quantization entry for fp4 experts quantization*/ -#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x, m) \ - TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") -#define CHECK_INPUT(x, m) \ - CHECK_TH_CUDA(x, m); \ - CHECK_CONTIGUOUS(x, m); - -constexpr auto HALF = at::ScalarType::Half; -constexpr auto BF16 = at::ScalarType::BFloat16; -constexpr auto FLOAT = at::ScalarType::Float; -constexpr auto INT = at::ScalarType::Int; -constexpr auto UINT8 = at::ScalarType::Byte; - -// Common validation for fp4 experts quantization entry points. 
-static void validate_fp4_experts_quant_inputs( - torch::Tensor const& output, torch::Tensor const& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk, - int64_t k) { - CHECK_INPUT(output, "output"); - CHECK_INPUT(output_scale, "output_scale"); - CHECK_INPUT(input, "input"); - CHECK_INPUT(input_global_scale, "input_global_scale"); - CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts"); - CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts"); - - TORCH_CHECK(output.dim() == 2); - TORCH_CHECK(output_scale.dim() == 2); - TORCH_CHECK(input.dim() == 2); - TORCH_CHECK(input_global_scale.dim() == 1); - TORCH_CHECK(input_offset_by_experts.dim() == 1); - TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); - - TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); - TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); - TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); - TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); - // output is uint8 (two nvfp4 values are packed into one uint8) - // output_scale is int32 (four fp8 values are packed into one int32) - TORCH_CHECK(output.scalar_type() == UINT8); - TORCH_CHECK(output_scale.scalar_type() == INT); - - const int BLOCK_SIZE = 16; - TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); - auto n_experts = input_global_scale.size(0); - TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); - TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); - TORCH_CHECK(output.size(0) == m_topk); - TORCH_CHECK(output.size(1) == k / 2); - int scales_k = k / BLOCK_SIZE; - // 4 means the swizzle requirement by nvidia nvfp4. 
- int padded_k = (scales_k + (4 - 1)) / 4 * 4; - // 4 means 4 fp8 values are packed into one int32 - TORCH_CHECK(output_scale.size(1) * 4 == padded_k); -} - -void scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { - auto m_topk = input.size(0); - auto k = input.size(1); - - validate_fp4_experts_quant_inputs(output, output_scale, input, - input_global_scale, input_offset_by_experts, - output_scale_offset_by_experts, m_topk, k); - - auto n_experts = input_global_scale.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(input.get_device()); - - VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - vllm::quant_impl( - output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), - input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, - stream); - }); -} - -void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { - auto m_topk = input.size(0); - // Input has gate || up layout, so k = input.size(1) / 2 - auto k_times_2 = input.size(1); - TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); - auto k = k_times_2 / 2; - - validate_fp4_experts_quant_inputs(output, output_scale, input, - input_global_scale, input_offset_by_experts, - output_scale_offset_by_experts, m_topk, k); - - auto n_experts = input_global_scale.size(0); - const at::cuda::OptionalCUDAGuard 
device_guard(device_of(input)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(input.get_device()); - - VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - vllm::quant_impl( - output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), - input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, - stream); - }); -} + #include + + #include + #include + + #include + #include + + #include + #include "dispatch_utils.h" + + #include "cuda_utils.h" + #include "nvfp4_utils.cuh" + #include "launch_bounds_utils.h" + + namespace vllm { + + // NVFP4 quantization kernel for experts (low-latency path). + // When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses + // SiLU(gate)*up before quantization. + // Use UE4M3 by default. + template + __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts, + bool low_latency) { + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // When fusing SiLU+Mul, input has gate || up layout (doubled width) + int inColsPerRow = FUSE_SILU_MUL ? 
colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid; globalIdx < numRows * colsPerRow; + globalIdx += gridDim.x * blockDim.x) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find index within the experts using different strategies based on expert + // count + int rowIdx_in_expert = 0; + int expert_idx = 0; + + if constexpr (SMALL_NUM_EXPERTS) { + for (int i = 0; i < n_experts; i++) { + uint32_t current_offset = __ldca(&input_offset_by_experts[i]); + uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]); + if (rowIdx >= current_offset && rowIdx < next_offset) { + rowIdx_in_expert = rowIdx - current_offset; + expert_idx = i; + break; + } + } + } else { + // Load input offsets into registers first, then do the computation. + // Local array size set to 17 because of register limit. + uint32_t local_offsets[17]; + for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) { + *reinterpret_cast(local_offsets) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start])); + *reinterpret_cast(local_offsets + 4) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start + 4])); + *reinterpret_cast(local_offsets + 8) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start + 8])); + *reinterpret_cast(local_offsets + 12) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start + 12])); + local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); + + // Check against the 16 loaded offsets + #pragma unroll + for (int i = 0; i < 16; i++) { + if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { + rowIdx_in_expert = rowIdx - local_offsets[i]; + expert_idx = chunk_start + i; + break; + } + } + } + } + + // Load input and optionally apply fused SiLU+Mul + int64_t inOffset = rowIdx * inColsPerRow + colIdx; + PackedVec in_vec = 
reinterpret_cast(in)[inOffset]; + PackedVec quant_input; + if constexpr (FUSE_SILU_MUL) { + PackedVec in_vec_up = + reinterpret_cast(in)[inOffset + colsPerRow]; + quant_input = compute_silu_mul(in_vec, in_vec_up); + } else { + quant_input = in_vec; + } + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4( + quant_input, SFScaleVal, sf_out); + } + } + + // NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized + // version). When FUSE_SILU_MUL=true, expects input with gate||up layout and + // fuses SiLU(gate)*up before quantization. + template + __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + extern __shared__ uint32_t shared_input_offsets[]; + + // Load input offsets into shared memory. 
+ // If n_experts is larger than 4, use vectorized int4 to save instructions. + // If n_experts is smaller than 4, read directly. + if constexpr (SMALL_NUM_EXPERTS) { + for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) { + shared_input_offsets[i] = input_offset_by_experts[i]; + } + } else { + for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) { + *reinterpret_cast(&shared_input_offsets[i]) = + *reinterpret_cast(&input_offset_by_experts[i]); + } + if (threadIdx.x == 0) { + shared_input_offsets[n_experts] = input_offset_by_experts[n_experts]; + } + } + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // When fusing SiLU+Mul, input has gate || up layout (doubled width) + int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid; globalIdx < numRows * colsPerRow; + globalIdx += gridDim.x * blockDim.x) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find expert using binary search for better performance with large m_topk + int rowIdx_in_expert = 0; + int expert_idx = 0; + + // Binary search through experts using shared memory + int left = 0, right = n_experts - 1; + while (left <= right) { + int mid = (left + right) / 2; + // Get offsets: shared_input_offsets[i] corresponds to + // input_offset_by_experts[i] + uint32_t mid_offset = shared_input_offsets[mid]; + uint32_t next_offset = shared_input_offsets[mid + 1]; + + if (rowIdx >= mid_offset && rowIdx < next_offset) { + rowIdx_in_expert = rowIdx - mid_offset; + expert_idx = mid; + break; + } else if (rowIdx < mid_offset) { + right = mid - 1; + } else { + left = mid + 1; + } + } + + // Load input and optionally apply fused SiLU+Mul + int64_t inOffset = rowIdx * inColsPerRow + colIdx; + PackedVec in_vec = 
reinterpret_cast(in)[inOffset]; + PackedVec quant_input; + if constexpr (FUSE_SILU_MUL) { + PackedVec in_vec_up = + reinterpret_cast(in)[inOffset + colsPerRow]; + quant_input = compute_silu_mul(in_vec, in_vec_up); + } else { + quant_input = in_vec; + } + + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4( + quant_input, SFScaleVal, sf_out); + } + } + + template + void quant_impl(void* output, void* output_scale, void* input, + void* input_global_scale, void* input_offset_by_experts, + void* output_scale_offset_by_experts, int m_topk, int k, + int n_experts, cudaStream_t stream) { + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + // Grid, Block size. + // Each thread converts 8 values. 
+ int const workSizePerRow = k / ELTS_PER_THREAD; + int const totalWorkSize = m_topk * workSizePerRow; + dim3 block(std::min(workSizePerRow, 512)); + // Get number of blocks per SM + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + dim3 grid(std::min(static_cast((totalWorkSize + block.x - 1) / block.x), + multiProcessorCount * numBlocksPerSM)); + while (grid.x <= multiProcessorCount && block.x > 64) { + grid.x *= 2; + block.x = (block.x + 1) / 2; + } + + int const blockRepeat = + (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); + if (blockRepeat > 1) { + size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); + if (n_experts >= 4) { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts); + } else { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts); + } + } else { + if (n_experts >= 16) { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts, /* bool low_latency */ true); + } else { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts, /* bool low_latency */ true); + } + } + } + + } // namespace vllm + + /*Quantization entry for fp4 experts 
quantization*/ + #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") + #define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") + #define CHECK_INPUT(x, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); + + constexpr auto HALF = at::ScalarType::Half; + constexpr auto BF16 = at::ScalarType::BFloat16; + constexpr auto FLOAT = at::ScalarType::Float; + constexpr auto INT = at::ScalarType::Int; + constexpr auto UINT8 = at::ScalarType::Byte; + + // Common validation for fp4 experts quantization entry points. + static void validate_fp4_experts_quant_inputs( + torch::Tensor const& output, torch::Tensor const& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk, + int64_t k) { + CHECK_INPUT(output, "output"); + CHECK_INPUT(output_scale, "output_scale"); + CHECK_INPUT(input, "input"); + CHECK_INPUT(input_global_scale, "input_global_scale"); + CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts"); + CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 
16; + TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. + int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + } + + void scaled_fp4_experts_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + auto m_topk = input.size(0); + auto k = input.size(1); + + validate_fp4_experts_quant_inputs(output, output_scale, input, + input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts, m_topk, k); + + auto n_experts = input_global_scale.size(0); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); + } + + void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + auto m_topk = input.size(0); + // Input has gate || up layout, 
so k = input.size(1) / 2 + auto k_times_2 = input.size(1); + TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); + auto k = k_times_2 / 2; + + validate_fp4_experts_quant_inputs(output, output_scale, input, + input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts, m_topk, k); + + auto n_experts = input_global_scale.size(0); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); + } + \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index c27fb69d44be..7bd25a309096 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -14,228 +14,229 @@ * limitations under the License. */ -#include - -#include -#include - -#include -#include - -#include -#include "dispatch_utils.h" - -#include "cuda_utils.h" -#include "launch_bounds_utils.h" - -// Define before including nvfp4_utils.cuh so the header -// can use this macro during compilation. -#define NVFP4_ENABLE_ELTS16 1 -#include "nvfp4_utils.cuh" - -namespace vllm { - -// Use UE4M3 by default. 
-template -__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols, - Type const* __restrict__ in, - float const* __restrict__ SFScale, - uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; - - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). - int32_t const numKTiles = (numCols + 63) / 64; - - int sf_m = round_up(numRows, 128); - int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; - int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). - float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; - - // Iterate over all rows and cols including padded ones - - // ensures we visit every single scale factor address to initialize it. - for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) { - if (colIdx < num_padded_cols) { - PackedVec in_vec; - int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - - // If we are outside valid rows OR outside valid columns -> Use Zeros - bool valid = (rowIdx < numRows) && (elem_idx < numCols); - if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); - } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); - } - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx, colIdx, numKTiles, SFout); - - auto out_val = - cvt_warp_fp16_to_fp4( - in_vec, global_scale, sf_out); - - // We do NOT write output for padding because the 'out' tensor is not - // padded. 
- if (valid) { - if constexpr (CVT_FP4_PACK16) { - int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; - uint64_t packed64 = - (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); - reinterpret_cast(out)[outOffset >> 1] = packed64; - } else { - out[inOffset] = out_val; - } - } - } - } -} - -// Use UE4M3 by default. -template -__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, - int32_t sf_n_unpadded, Type const* __restrict__ in, - float const* __restrict__ SFScale, - uint32_t* __restrict__ out, - uint32_t* __restrict__ SFout) { - using PackedVec = PackedVec; - - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; - int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). - float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; - - // Iterate over all rows and cols including padded ones - - // ensures we visit every single scale factor address to initialize it. 
- for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < sf_n_unpadded) { - PackedVec in_vec; - int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - - // If we are outside valid rows OR outside valid columns -> Use Zeros - bool valid = (rowIdx < numRows) && (elem_idx < numCols); - if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); - } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); - } - - auto sf_out = - sf_out_rowmajor_u8(rowIdx, colIdx, sf_n_unpadded, SFout); - - auto out_val = - cvt_warp_fp16_to_fp4( - in_vec, global_scale, sf_out); - - // We do NOT write output for padding because the 'out' tensor is not - // padded. - if (valid) { - if constexpr (CVT_FP4_PACK16) { - int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; - uint64_t packed64 = - (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); - reinterpret_cast(out)[outOffset >> 1] = packed64; - } else { - out[inOffset] = out_val; - } - } - } - } -} - -} // namespace vllm - -void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, - torch::Tensor const& input, - torch::Tensor const& output_sf, - torch::Tensor const& input_sf, - bool is_sf_swizzled_layout) { - int32_t m = input.size(0); - int32_t n = input.size(1); - - TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || - input.scalar_type() == at::ScalarType::BFloat16, - "Unsupported input data type for quantize_to_fp4."); - - int multiProcessorCount = - get_device_attribute(cudaDevAttrMultiProcessorCount, -1); - - auto input_sf_ptr = static_cast(input_sf.data_ptr()); - auto sf_out = static_cast(output_sf.data_ptr()); - auto output_ptr = static_cast(output.data_ptr()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); - - int sf_n_unpadded 
= int(n / CVT_FP4_SF_VEC_SIZE); - - // Grid, Block size. Each thread converts 8 values. - dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); - int const numBlocksPerSM = - vllm_runtime_blocks_per_sm(static_cast(block.x)); - - if (is_sf_swizzled_layout) { - int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4); - int32_t num_padded_cols = - sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; - - int grid_y = vllm::div_round_up(num_padded_cols, static_cast(block.x)); - int grid_x = - std::min(vllm::computeEffectiveRows(m), - std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); - dim3 grid(grid_x, grid_y); - - VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - // NOTE: We don't support e8m0 scales at this moment. - vllm::cvt_fp16_to_fp4<<>>( - m, n, num_padded_cols, input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); - } else { - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); - int grid_x = std::min( - m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); - dim3 grid(grid_x, grid_y); - - VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - // NOTE: We don't support e8m0 scales at this moment. - vllm::cvt_fp16_to_fp4_sf_major - <<>>(m, n, sf_n_unpadded, input_ptr, - input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); - } -} + #include + + #include + #include + + #include + #include + + #include + #include "dispatch_utils.h" + + #include "cuda_utils.h" + #include "launch_bounds_utils.h" + + // Define before including nvfp4_utils.cuh so the header + // can use this macro during compilation. + #define NVFP4_ENABLE_ELTS16 1 + #include "nvfp4_utils.cuh" + + namespace vllm { + + // Use UE4M3 by default. 
+ template + __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols, + Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { + using PackedVec = vllm::PackedVec; + + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + int sf_m = round_up(numRows, 128); + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + // Iterate over all rows and cols including padded ones - + // ensures we visit every single scale factor address to initialize it. 
+ for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) { + if (colIdx < num_padded_cols) { + PackedVec in_vec; + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + + // If we are outside valid rows OR outside valid columns -> Use Zeros + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_cg_or_zero( + reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], valid); + } else { + ld128_cg_or_zero( + reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], valid); + } + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numKTiles, SFout); + + auto out_val = + cvt_warp_fp16_to_fp4( + in_vec, global_scale, sf_out); + + // We do NOT write output for padding because the 'out' tensor is not + // padded. + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } + } + } + } + + // Use UE4M3 by default. + template + __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, + int32_t sf_n_unpadded, Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, + uint32_t* __restrict__ SFout) { + using PackedVec = PackedVec; + + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). 
+ float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + // Iterate over all rows and cols including padded ones - + // ensures we visit every single scale factor address to initialize it. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + if (colIdx < sf_n_unpadded) { + PackedVec in_vec; + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + + // If we are outside valid rows OR outside valid columns -> Use Zeros + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_cg_or_zero( + reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], valid); + } else { + ld128_cg_or_zero( + reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], valid); + } + + auto sf_out = + sf_out_rowmajor_u8(rowIdx, colIdx, sf_n_unpadded, SFout); + + auto out_val = + cvt_warp_fp16_to_fp4( + in_vec, global_scale, sf_out); + + // We do NOT write output for padding because the 'out' tensor is not + // padded. 
+ if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } + } + } + } + + } // namespace vllm + + void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, + torch::Tensor const& input, + torch::Tensor const& output_sf, + torch::Tensor const& input_sf, + bool is_sf_swizzled_layout) { + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + + // Grid, Block size. Each thread converts 8 values. 
+ dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + + if (is_sf_swizzled_layout) { + int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4); + int32_t num_padded_cols = + sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + + int grid_y = vllm::div_round_up(num_padded_cols, static_cast(block.x)); + int grid_x = + std::min(vllm::computeEffectiveRows(m), + std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + // NOTE: We don't support e8m0 scales at this moment. + vllm::cvt_fp16_to_fp4<<>>( + m, n, num_padded_cols, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + } else { + int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_x = std::min( + m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + // NOTE: We don't support e8m0 scales at this moment. + vllm::cvt_fp16_to_fp4_sf_major + <<>>(m, n, sf_n_unpadded, input_ptr, + input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + } + } + \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 3e7adb9e2931..502b86b31270 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -14,405 +14,294 @@ * limitations under the License. 
*/ -#pragma once - -#include -#include - -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) - #define ELTS_PER_THREAD 16 -constexpr int CVT_FP4_ELTS_PER_THREAD = 16; -constexpr bool CVT_FP4_PACK16 = true; -#else - #define ELTS_PER_THREAD 8 -constexpr int CVT_FP4_ELTS_PER_THREAD = 8; -constexpr bool CVT_FP4_PACK16 = false; -#endif - -constexpr int CVT_FP4_SF_VEC_SIZE = 16; - -namespace vllm { - -// Convert PyTorch cpp type to CUDA type -template -struct CUDATypeConverter { - using Type = T; -}; - -template <> -struct CUDATypeConverter { - using Type = half; -}; - -template <> -struct CUDATypeConverter { - using Type = __nv_bfloat16; -}; - -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) -// Define a 32 bytes packed data type. -template -struct alignas(32) PackedVec { - typename TypeConverter::Type elts[8]; -}; -#else -// Define a 16 bytes packed data type. 
-template -struct alignas(16) PackedVec { - typename TypeConverter::Type elts[4]; -}; -#endif - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - -template -__host__ __device__ inline Int round_up(Int x, Int y) { - static_assert(std::is_integral_v, - "round_up argument must be integral type"); - return ((x + y - 1) / y) * y; -} - -template -__host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) { - return (x + y - 1) / y; -} - -// Compute effective rows for grid configuration with swizzled SF layouts. -inline int computeEffectiveRows(int m) { - constexpr int ROW_TILE = 128; - return round_up(m, ROW_TILE); -} - -// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). -inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; -} - -// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
-__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}\n" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; -} - -struct u32x2 { - uint32_t lo, hi; -}; - -using fp4_packed_t = std::conditional_t; - -__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { - u32x2 out; - asm volatile( - "{\n" - ".reg .b8 b0;\n" - ".reg .b8 b1;\n" - ".reg .b8 b2;\n" - ".reg .b8 b3;\n" - ".reg .b8 b4;\n" - ".reg .b8 b5;\n" - ".reg .b8 b6;\n" - ".reg .b8 b7;\n" - "cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;\n" - "cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;\n" - "cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;\n" - "cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;\n" - "cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;\n" - "cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;\n" - "cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;\n" - "cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;\n" - "mov.b32 %0, {b0, b1, b2, b3};\n" - "mov.b32 %1, {b4, b5, b6, b7};\n" - "}\n" - : "=r"(out.lo), "=r"(out.hi) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), - "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y), - "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y)); - return out; -} - -__device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) { - return fp32_vec8_to_e2m1(v); -} - -__device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) { - return fp32_vec16_to_e2m1(v); -} - -// Fast reciprocal. 
-__device__ __forceinline__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a)); - return b; -} - -template -__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec& out, - const void* ptr, - bool pred) { - uint32_t r0, r1, r2, r3; - - asm volatile( - "{\n" - " .reg .pred pr;\n" - " setp.ne.u32 pr, %4, 0;\n" - " mov.u32 %0, 0;\n" - " mov.u32 %1, 0;\n" - " mov.u32 %2, 0;\n" - " mov.u32 %3, 0;\n" - " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" - "}\n" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "r"((int)pred), "l"(ptr)); - - *reinterpret_cast(&out) = uint4{r0, r1, r2, r3}; -} - -template -__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec& out, - const void* ptr, - bool pred) { - uint32_t r0, r1, r2, r3, r4, r5, r6, r7; - - asm volatile( - "{\n" - " .reg .pred pr;\n" - " setp.ne.u32 pr, %8, 0;\n" - " mov.u32 %0, 0;\n" - " mov.u32 %1, 0;\n" - " mov.u32 %2, 0;\n" - " mov.u32 %3, 0;\n" - " mov.u32 %4, 0;\n" - " mov.u32 %5, 0;\n" - " mov.u32 %6, 0;\n" - " mov.u32 %7, 0;\n" - " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" - "}\n" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6), - "=r"(r7) - : "r"((int)pred), "l"(ptr)); - - reinterpret_cast(&out)[0] = uint4{r0, r1, r2, r3}; - reinterpret_cast(&out)[1] = uint4{r4, r5, r6, r7}; -} - -// Compute SF output offset for swizzled tensor core layout. -// SF layout: [numMTiles, numKTiles, 32, 4, 4] -// Caller must precompute: numKTiles = (numCols + 63) / 64 -template -__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset( - int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) { - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? 
- if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) { - return nullptr; - } - - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // Decompose indices using bitwise ops (all divisors are powers of 2). - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - int32_t mTileIdx = mIdx >> 7; // mIdx / 128 - int32_t outerMIdx = mIdx & 31; // mIdx % 32 - int32_t innerMIdx = (mIdx >> 5) & 3; // (mIdx / 32) % 4 - int32_t kTileIdx = kIdx >> 2; // kIdx / 4 - int32_t innerKIdx = kIdx & 3; // kIdx % 4 - - // Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 + - // outerMIdx * 16 + innerMIdx * 4 + innerKIdx - // Use bitwise OR for non-overlapping lower bits. - int64_t SFOffset = (static_cast(mTileIdx) * numKTiles + kTileIdx) - << 9 | - (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx; - - return reinterpret_cast(SFout) + SFOffset; -} - -template -__device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack, - int packs_per_row_sf, - SFType* SFout) { - constexpr int PACK = CVT_FP4_ELTS_PER_THREAD; - constexpr int THREADS_PER_SF = - CVT_FP4_SF_VEC_SIZE / PACK; // 1 if PACK=16, 2 else PACK=8 - - if (threadIdx.x % THREADS_PER_SF != 0) return nullptr; - - int sf_col = - pack / THREADS_PER_SF; // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2 - int64_t off = (int64_t)row * packs_per_row_sf + sf_col; - - return (uint8_t*)SFout + off; -} - -// Quantizes the provided PackedVec into the uint32_t output -template -__device__ __forceinline__ fp4_packed_t -cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. 
-#pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - - if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) { - localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax); - } - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. - uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - - // Write the SF to global memory (STG.8). - if (SFout) *SFout = fp8SFVal; - - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0.0f ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - // Convert the input to float. 
- float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - -#pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. - return pack_fp4(fp2Vals); -} - -// silu in float32 -__device__ __forceinline__ float silu(float x) { - return __fdividef(x, (1.f + __expf(-x))); -} - -__device__ __forceinline__ float2 silu2(float2 x) { - return make_float2(silu(x.x), silu(x.y)); -} - -template -__inline__ __device__ PackedVec compute_silu_mul( - const PackedVec& x_vec, const PackedVec& y_vec) { - PackedVec result; - -#pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { - // silu_mul in float32 - if constexpr (std::is_same_v) { - float2 silu_vec = silu2(__half22float2(x_vec.elts[i])); - result.elts[i] = __float22half2_rn( - __fmul2_rn(silu_vec, __half22float2(y_vec.elts[i]))); - } else { - float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i])); - result.elts[i] = __float22bfloat162_rn( - __fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i]))); - } - } - return result; -} - -} // namespace vllm + #pragma once + + #include + #include + + #include "../../cuda_vec_utils.cuh" + + #if defined(NVFP4_ENABLE_ELTS16) && \ + defined(CUDA_VERSION) && CUDA_VERSION >= 12090 && \ + defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + #define ELTS_PER_THREAD 16 + constexpr int CVT_FP4_ELTS_PER_THREAD = 16; + constexpr bool CVT_FP4_PACK16 = true; + #else + #define ELTS_PER_THREAD 8 + constexpr int CVT_FP4_ELTS_PER_THREAD = 8; + constexpr bool CVT_FP4_PACK16 = false; + #endif + + constexpr int CVT_FP4_SF_VEC_SIZE = 16; + + namespace vllm { + + template + struct PackedVec<__nv_fp8_e4m3, use_256b> { + __nv_fp8x2_e4m3 elts[8]; + }; + + template + __host__ __device__ inline Int round_up(Int x, Int y) { + static_assert(std::is_integral_v, + 
"round_up argument must be integral type"); + return ((x + y - 1) / y) * y; + } + + template + __host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) { + return (x + y - 1) / y; + } + + // Compute effective rows for grid configuration with swizzled SF layouts. + inline int computeEffectiveRows(int m) { + constexpr int ROW_TILE = 128; + return round_up(m, ROW_TILE); + } + + // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). + inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; + } + + // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
+ __device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}\n" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; + } + + struct u32x2 { + uint32_t lo, hi; + }; + + using fp4_packed_t = std::conditional_t; + + __device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { + u32x2 out; + asm volatile( + "{\n" + ".reg .b8 b0;\n" + ".reg .b8 b1;\n" + ".reg .b8 b2;\n" + ".reg .b8 b3;\n" + ".reg .b8 b4;\n" + ".reg .b8 b5;\n" + ".reg .b8 b6;\n" + ".reg .b8 b7;\n" + "cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;\n" + "cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;\n" + "cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;\n" + "cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;\n" + "cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;\n" + "cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;\n" + "cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;\n" + "cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;\n" + "mov.b32 %0, {b0, b1, b2, b3};\n" + "mov.b32 %1, {b4, b5, b6, b7};\n" + "}\n" + : "=r"(out.lo), "=r"(out.hi) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), + "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y), + "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y)); + return out; + } + + __device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) { + return fp32_vec8_to_e2m1(v); + } + + __device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) { + return fp32_vec16_to_e2m1(v); + } + + // Fast 
reciprocal. + __device__ __forceinline__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a)); + return b; + } + + // Compute SF output offset for swizzled tensor core layout. + // SF layout: [numMTiles, numKTiles, 32, 4, 4] + // Caller must precompute: numKTiles = (numCols + 63) / 64 + template + __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset( + int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) { + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) { + return nullptr; + } + + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // Decompose indices using bitwise ops (all divisors are powers of 2). + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + int32_t mTileIdx = mIdx >> 7; // mIdx / 128 + int32_t outerMIdx = mIdx & 31; // mIdx % 32 + int32_t innerMIdx = (mIdx >> 5) & 3; // (mIdx / 32) % 4 + int32_t kTileIdx = kIdx >> 2; // kIdx / 4 + int32_t innerKIdx = kIdx & 3; // kIdx % 4 + + // Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 + + // outerMIdx * 16 + innerMIdx * 4 + innerKIdx + // Use bitwise OR for non-overlapping lower bits. 
+ int64_t SFOffset = (static_cast(mTileIdx) * numKTiles + kTileIdx)
+ << 9 |
+ (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;
+
+ return reinterpret_cast(SFout) + SFOffset;
+ }
+
+ template
+ __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
+ int packs_per_row_sf,
+ SFType* SFout) {
+ constexpr int PACK = CVT_FP4_ELTS_PER_THREAD;
+ constexpr int THREADS_PER_SF =
+ CVT_FP4_SF_VEC_SIZE / PACK; // 1 if PACK=16, 2 if PACK=8
+
+ if (threadIdx.x % THREADS_PER_SF != 0) return nullptr;
+
+ int sf_col =
+ pack / THREADS_PER_SF; // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2
+ int64_t off = (int64_t)row * packs_per_row_sf + sf_col;
+
+ return (uint8_t*)SFout + off;
+ }
+
+ // Quantizes the provided PackedVec into the uint32_t output
+ template
+ __device__ __forceinline__ fp4_packed_t
+ cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) {
+ // Get absolute maximum values among the local 8 values.
+ auto localMax = __habs2(vec.elts[0]);
+
+ // Local maximum value.
+ #pragma unroll
+ for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+ localMax = __hmax2(localMax, __habs2(vec.elts[i]));
+ }
+
+ // Get the absolute maximum among all 16 values (two threads).
+
+ if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) {
+ localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax);
+ }
+ // Get the final absolute maximum values.
+ float vecMax = float(__hmax(localMax.x, localMax.y));
+
+ // Get the SF (max value of the vector / max value of e2m1).
+ // maximum value of e2m1 = 6.0.
+ // TODO: use half as compute data type.
+ float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
+ // 8 bits representation of the SF.
+ uint8_t fp8SFVal;
+ // Write the SF to global memory (STG.8).
+ if constexpr (UE8M0_SF) {
+ // Extract the 8 exponent bits from float32.
+ // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
+ uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + + // Write the SF to global memory (STG.8). + if (SFout) *SFout = fp8SFVal; + + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0.0f ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + fp2Vals[i] = cast_to_float2(vec.elts[i]); + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. 
+ return pack_fp4(fp2Vals); + } + + // silu in float32 + __device__ __forceinline__ float silu(float x) { + return __fdividef(x, (1.f + __expf(-x))); + } + + __device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); + } + + template + __inline__ __device__ PackedVec compute_silu_mul( + const PackedVec& x_vec, const PackedVec& y_vec) { + PackedVec result; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + // silu_mul in float32 + using packed_t = typename TypeConverter::Type; + float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i])); + float2 y_f2 = cast_to_float2(y_vec.elts[i]); + result.elts[i] = cast_to_packed( + make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y)); + } + return result; + } + + } // namespace vllm + \ No newline at end of file From 426ca895eb9346c280fee8bf923807931f8b68d0 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Mon, 23 Feb 2026 15:03:59 +0000 Subject: [PATCH 2/6] pre-commit failures Signed-off-by: LopezCastroRoberto --- csrc/activation_kernels.cu | 101 +-- csrc/cuda_vec_utils.cuh | 40 +- .../activation_nvfp4_quant_fusion_kernels.cu | 287 +++--- csrc/quantization/fp4/nvfp4_experts_quant.cu | 857 +++++++++--------- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 451 +++++---- csrc/quantization/fp4/nvfp4_utils.cuh | 581 ++++++------ 6 files changed, 1153 insertions(+), 1164 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 97894e80f4c4..24b4bd93eb45 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -58,9 +58,8 @@ __global__ void act_and_mul_kernel( } #pragma unroll for (int j = 0; j < pvec_t::NUM_ELTS; j++) { - x.elts[j] = - packed_compute(x.elts[j], - y.elts[j]); + x.elts[j] = packed_compute( + x.elts[j], y.elts[j]); } if constexpr (use_256b) { st256(x, &out_vec[i]); @@ -153,54 +152,54 @@ packed_gelu_tanh_kernel(const packed_t& val) { // Launch activation and gating kernel. 
// Use ACT_FIRST (bool) indicating whether to apply the activation function // first. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = \ - (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ - ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ - : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? 
vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / at::elementSize(dtype); \ + const bool use_vec = (d % vec_size == 0); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ scalar_t, typename vllm::TypeConverter::Type, \ - KERNEL, \ + KERNEL, \ PACKED_KERNEL::Type>, \ - ACT_FIRST, true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ + ACT_FIRST, true, true><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } else { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ scalar_t, typename vllm::TypeConverter::Type, \ - KERNEL, \ + KERNEL, \ PACKED_KERNEL::Type>, \ - ACT_FIRST, true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ + ACT_FIRST, true, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ scalar_t, typename vllm::TypeConverter::Type, \ - KERNEL, \ + KERNEL, \ PACKED_KERNEL::Type>, \ - ACT_FIRST, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ + ACT_FIRST, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ } void silu_and_mul(torch::Tensor& out, // [..., d] @@ -395,8 +394,7 @@ __global__ void swigluoai_and_mul_kernel( 
vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL< \ - typename vllm::TypeConverter::Type>, \ + PACKED_KERNEL::Type>, \ true, true><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -407,8 +405,7 @@ __global__ void swigluoai_and_mul_kernel( vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL< \ - typename vllm::TypeConverter::Type>, \ + PACKED_KERNEL::Type>, \ true, false><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -420,9 +417,9 @@ __global__ void swigluoai_and_mul_kernel( vllm::act_and_mul_kernel_with_param< \ scalar_t, typename vllm::TypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::Type>, \ - false><<>>( \ - out.data_ptr(), input.data_ptr(), d, PARAM); \ + PACKED_KERNEL::Type>, false> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d, PARAM); \ }); \ } diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index d3d14dbf88ee..2b3b23470561 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -118,9 +118,8 @@ struct CUDATypeConverter { // Type is the CUDA scalar type (e.g. half, __nv_bfloat16). 
template struct alignas(VecTraits::ARCH_MAX_VEC_SIZE) PackedVec { - static constexpr int NUM_ELTS = - VecTraits::ARCH_MAX_VEC_SIZE / - sizeof(typename TypeConverter::Type); + static constexpr int NUM_ELTS = VecTraits::ARCH_MAX_VEC_SIZE / + sizeof(typename TypeConverter::Type); typename TypeConverter::Type elts[NUM_ELTS]; }; @@ -132,33 +131,33 @@ struct alignas(VecTraits::ARCH_MAX_VEC_SIZE) PackedVec { // SM100+ : PTX v8 instructions (.nc / default hint) // Older : two uint4 loads via __ldg __device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { -#if VLLM_256B_PTX_ENABLED + #if VLLM_256B_PTX_ENABLED asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) : "l"(ptr)); -#else + #else const uint4* src = reinterpret_cast(ptr); uint4* dst = reinterpret_cast(val.d); dst[0] = __ldg(&src[0]); dst[1] = __ldg(&src[1]); -#endif + #endif } __device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { -#if VLLM_256B_PTX_ENABLED + #if VLLM_256B_PTX_ENABLED asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" : : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]) : "memory"); -#else + #else uint4* dst = reinterpret_cast(ptr); const uint4* src = reinterpret_cast(val.d); dst[0] = src[0]; dst[1] = src[1]; -#endif + #endif } // Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec). 
@@ -179,8 +178,7 @@ __device__ __forceinline__ void st256(T& val, T* ptr) { template __device__ __forceinline__ void ld128(T& val, const T* ptr) { static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type"); - *reinterpret_cast(&val) = - __ldg(reinterpret_cast(ptr)); + *reinterpret_cast(&val) = __ldg(reinterpret_cast(ptr)); } template @@ -192,20 +190,18 @@ __device__ __forceinline__ void st128(T& val, T* ptr) { // 256-bit cache-streaming (.cs) load / store — SM100+ only. __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { u32x8_t val; - asm volatile( - "ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" - : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), - "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) - : "l"(addr)); + asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(addr)); return val; } __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { asm volatile( - "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" - ::"l"(addr), - "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), - "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); + "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr), + "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), + "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); } // 32-bit cache-streaming (.cs) load / store — SM100+ only. @@ -222,7 +218,7 @@ __forceinline__ __device__ void st32_cs(int* addr, int val) { // Predicated 256-bit / 128-bit cache-global (.cg) loads. // Returns zero if pred is false. SM100+ only. 
__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, - bool pred) { + bool pred) { asm volatile( "{\n" " .reg .pred pr;\n" @@ -243,7 +239,7 @@ __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, } __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, - bool pred) { + bool pred) { uint32_t r0, r1, r2, r3; asm volatile( diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 3c3d2ec12223..8c0e82b7651f 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -14,147 +14,146 @@ * limitations under the License. */ - #include - - #include - #include - - #include - #include - - #include - #include "dispatch_utils.h" - - #include "cuda_utils.h" - #include "launch_bounds_utils.h" - - // Define before including nvfp4_utils.cuh so the header - // can use this macro during compilation. - #define NVFP4_ENABLE_ELTS16 1 - #include "nvfp4_utils.cuh" - - namespace vllm { - - // Use UE4M3 by default. - template - __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, - int32_t num_padded_cols, - Type const* __restrict__ in, - float const* __restrict__ SFScale, - uint32_t* __restrict__ out, - uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). - int32_t const numKTiles = (numCols + 63) / 64; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). 
- float const SFScaleVal = (SFScale == nullptr) ? 1.0f : SFScale[0]; - - int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; - int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; - - // Input tensor row/col loops. - for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < num_padded_cols) { - PackedVec in_vec; - PackedVec in_vec2; - int64_t inOffset = - rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; - int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + - numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; - - bool valid = (rowIdx < numRows) && (elem_idx < numCols); - if constexpr (CVT_FP4_PACK16) { - ld256_cg_or_zero( - reinterpret_cast(in_vec), - &reinterpret_cast(in)[inOffset * 8], valid); - ld256_cg_or_zero( - reinterpret_cast(in_vec2), - &reinterpret_cast(in)[inOffset2 * 8], valid); - } else { - ld128_cg_or_zero( - reinterpret_cast(in_vec), - &reinterpret_cast(in)[inOffset * 4], valid); - ld128_cg_or_zero( - reinterpret_cast(in_vec2), - &reinterpret_cast(in)[inOffset2 * 4], valid); - } - - // Compute silu and mul - PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx, colIdx, numKTiles, SFout); - - auto out_val = - cvt_warp_fp16_to_fp4( - out_silu_mul, SFScaleVal, sf_out); - - if (valid) { - if constexpr (CVT_FP4_PACK16) { - int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; - uint64_t packed64 = - (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); - reinterpret_cast(out)[outOffset >> 1] = packed64; - } else { - out[inOffset] = out_val; - } - } - } - } - } - - } // namespace vllm - - void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] - torch::Tensor& output_sf, - torch::Tensor& input, // [..., 2 * d] - torch::Tensor& input_sf) { - int32_t m = input.size(0); - int32_t n = input.size(1) / 2; - - TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); - TORCH_CHECK(input.scalar_type() == 
at::ScalarType::Half || - input.scalar_type() == at::ScalarType::BFloat16, - "Unsupported input data type for quantize_to_fp4."); - - int multiProcessorCount = - get_device_attribute(cudaDevAttrMultiProcessorCount, -1); - - auto input_sf_ptr = static_cast(input_sf.data_ptr()); - auto sf_out = static_cast(output_sf.data_ptr()); - auto output_ptr = static_cast(output.data_ptr()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); - dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); - int const numBlocksPerSM = - vllm_runtime_blocks_per_sm(static_cast(block.x)); - - int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); - - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); - int grid_x = std::min( - int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); - dim3 grid(grid_x, grid_y); - - VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - vllm::silu_mul_cvt_fp16_to_fp4<<>>( - m, n, sf_n_unpadded, input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); - } - \ No newline at end of file +#include + +#include +#include + +#include +#include + +#include +#include "dispatch_utils.h" + +#include "cuda_utils.h" +#include "launch_bounds_utils.h" + +// Define before including nvfp4_utils.cuh so the header +// can use this macro during compilation. +#define NVFP4_ENABLE_ELTS16 1 +#include "nvfp4_utils.cuh" + +namespace vllm { + +// Use UE4M3 by default. 
+template +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, + int32_t num_padded_cols, + Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, + uint32_t* __restrict__ SFout) { + using PackedVec = vllm::PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; + + // Input tensor row/col loops. 
+ for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + if (colIdx < num_padded_cols) { + PackedVec in_vec; + PackedVec in_vec2; + int64_t inOffset = + rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; + int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + + numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; + + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); + ld256_cg_or_zero(reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 8], + valid); + } else { + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); + ld128_cg_or_zero(reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 4], + valid); + } + + // Compute silu and mul + PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numKTiles, SFout); + + auto out_val = + cvt_warp_fp16_to_fp4( + out_silu_mul, SFScaleVal, sf_out); + + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } + } + } +} + +} // namespace vllm + +void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] + torch::Tensor& output_sf, + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& input_sf) { + int32_t m = input.size(0); + int32_t n = input.size(1) / 2; + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + auto 
input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + + int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + + int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_x = std::min( + int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::silu_mul_cvt_fp16_to_fp4<<>>( + m, n, sf_n_unpadded, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); +} diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index de93adc5aa91..3162b6cdb8a9 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -14,432 +14,431 @@ * limitations under the License. */ - #include - - #include - #include - - #include - #include - - #include - #include "dispatch_utils.h" - - #include "cuda_utils.h" - #include "nvfp4_utils.cuh" - #include "launch_bounds_utils.h" - - namespace vllm { - - // NVFP4 quantization kernel for experts (low-latency path). - // When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses - // SiLU(gate)*up before quantization. - // Use UE4M3 by default. 
- template - __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, - float const* SFScale, uint32_t* out, uint32_t* SFout, - uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts, - bool low_latency) { - using PackedVec = PackedVec; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). - int32_t const numKTiles = (numCols + 63) / 64; - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; - // When fusing SiLU+Mul, input has gate || up layout (doubled width) - int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; - - // Each global thread processes one element - for (int globalIdx = tid; globalIdx < numRows * colsPerRow; - globalIdx += gridDim.x * blockDim.x) { - // Calculate which row and column this global thread should process - int rowIdx = globalIdx / colsPerRow; - int colIdx = globalIdx % colsPerRow; - - // Find index within the experts using different strategies based on expert - // count - int rowIdx_in_expert = 0; - int expert_idx = 0; - - if constexpr (SMALL_NUM_EXPERTS) { - for (int i = 0; i < n_experts; i++) { - uint32_t current_offset = __ldca(&input_offset_by_experts[i]); - uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]); - if (rowIdx >= current_offset && rowIdx < next_offset) { - rowIdx_in_expert = rowIdx - current_offset; - expert_idx = i; - break; - } - } - } else { - // Load input offsets into registers first, then do the computation. - // Local array size set to 17 because of register limit. 
- uint32_t local_offsets[17]; - for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) { - *reinterpret_cast(local_offsets) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start])); - *reinterpret_cast(local_offsets + 4) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start + 4])); - *reinterpret_cast(local_offsets + 8) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start + 8])); - *reinterpret_cast(local_offsets + 12) = - __ldca(reinterpret_cast( - &input_offset_by_experts[chunk_start + 12])); - local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); - - // Check against the 16 loaded offsets - #pragma unroll - for (int i = 0; i < 16; i++) { - if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { - rowIdx_in_expert = rowIdx - local_offsets[i]; - expert_idx = chunk_start + i; - break; - } - } - } - } - - // Load input and optionally apply fused SiLU+Mul - int64_t inOffset = rowIdx * inColsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - PackedVec quant_input; - if constexpr (FUSE_SILU_MUL) { - PackedVec in_vec_up = - reinterpret_cast(in)[inOffset + colsPerRow]; - quant_input = compute_silu_mul(in_vec, in_vec_up); - } else { - quant_input = in_vec; - } - - // Get the output tensor offset. - // Same as inOffset because 8 elements are packed into one uint32_t. - int64_t outOffset = rowIdx * colsPerRow + colIdx; - auto& out_pos = out[outOffset]; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). - float const SFScaleVal = SFScale == nullptr ? 
1.0f : SFScale[expert_idx]; - - uint32_t* SFout_in_expert = - SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - - out_pos = cvt_warp_fp16_to_fp4( - quant_input, SFScaleVal, sf_out); - } - } - - // NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized - // version). When FUSE_SILU_MUL=true, expects input with gate||up layout and - // fuses SiLU(gate)*up before quantization. - template - __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) - cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, - float const* SFScale, uint32_t* out, uint32_t* SFout, - uint32_t* input_offset_by_experts, - uint32_t* output_scale_offset_by_experts, int n_experts) { - using PackedVec = PackedVec; - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). - int32_t const numKTiles = (numCols + 63) / 64; - - extern __shared__ uint32_t shared_input_offsets[]; - - // Load input offsets into shared memory. - // If n_experts is larger than 4, use vectorized int4 to save instructions. - // If n_experts is smaller than 4, read directly. 
- if constexpr (SMALL_NUM_EXPERTS) { - for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) { - shared_input_offsets[i] = input_offset_by_experts[i]; - } - } else { - for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) { - *reinterpret_cast(&shared_input_offsets[i]) = - *reinterpret_cast(&input_offset_by_experts[i]); - } - if (threadIdx.x == 0) { - shared_input_offsets[n_experts] = input_offset_by_experts[n_experts]; - } - } - - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; - // When fusing SiLU+Mul, input has gate || up layout (doubled width) - int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; - - // Each global thread processes one element - for (int globalIdx = tid; globalIdx < numRows * colsPerRow; - globalIdx += gridDim.x * blockDim.x) { - // Calculate which row and column this global thread should process - int rowIdx = globalIdx / colsPerRow; - int colIdx = globalIdx % colsPerRow; - - // Find expert using binary search for better performance with large m_topk - int rowIdx_in_expert = 0; - int expert_idx = 0; - - // Binary search through experts using shared memory - int left = 0, right = n_experts - 1; - while (left <= right) { - int mid = (left + right) / 2; - // Get offsets: shared_input_offsets[i] corresponds to - // input_offset_by_experts[i] - uint32_t mid_offset = shared_input_offsets[mid]; - uint32_t next_offset = shared_input_offsets[mid + 1]; - - if (rowIdx >= mid_offset && rowIdx < next_offset) { - rowIdx_in_expert = rowIdx - mid_offset; - expert_idx = mid; - break; - } else if (rowIdx < mid_offset) { - right = mid - 1; - } else { - left = mid + 1; - } - } - - // Load input and optionally apply fused SiLU+Mul - int64_t inOffset = rowIdx * inColsPerRow + colIdx; - PackedVec in_vec = reinterpret_cast(in)[inOffset]; - PackedVec quant_input; - if constexpr (FUSE_SILU_MUL) { - PackedVec in_vec_up = - reinterpret_cast(in)[inOffset + 
colsPerRow]; - quant_input = compute_silu_mul(in_vec, in_vec_up); - } else { - quant_input = in_vec; - } - - int64_t outOffset = rowIdx * colsPerRow + colIdx; - auto& out_pos = out[outOffset]; - - float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; - - uint32_t* SFout_in_expert = - SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); - - out_pos = cvt_warp_fp16_to_fp4( - quant_input, SFScaleVal, sf_out); - } - } - - template - void quant_impl(void* output, void* output_scale, void* input, - void* input_global_scale, void* input_offset_by_experts, - void* output_scale_offset_by_experts, int m_topk, int k, - int n_experts, cudaStream_t stream) { - int multiProcessorCount = - get_device_attribute(cudaDevAttrMultiProcessorCount, -1); - - // Grid, Block size. - // Each thread converts 8 values. - int const workSizePerRow = k / ELTS_PER_THREAD; - int const totalWorkSize = m_topk * workSizePerRow; - dim3 block(std::min(workSizePerRow, 512)); - // Get number of blocks per SM - int const numBlocksPerSM = - vllm_runtime_blocks_per_sm(static_cast(block.x)); - dim3 grid(std::min(static_cast((totalWorkSize + block.x - 1) / block.x), - multiProcessorCount * numBlocksPerSM)); - while (grid.x <= multiProcessorCount && block.x > 64) { - grid.x *= 2; - block.x = (block.x + 1) / 2; - } - - int const blockRepeat = - (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); - if (blockRepeat > 1) { - size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); - if (n_experts >= 4) { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts); - } else { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - 
reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts); - } - } else { - if (n_experts >= 16) { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts, /* bool low_latency */ true); - } else { - cvt_fp16_to_fp4 - <<>>( - m_topk, k, reinterpret_cast(input), - reinterpret_cast(input_global_scale), - reinterpret_cast(output), - reinterpret_cast(output_scale), - reinterpret_cast(input_offset_by_experts), - reinterpret_cast(output_scale_offset_by_experts), - n_experts, /* bool low_latency */ true); - } - } - } - - } // namespace vllm - - /*Quantization entry for fp4 experts quantization*/ - #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") - #define CHECK_CONTIGUOUS(x, m) \ - TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") - #define CHECK_INPUT(x, m) \ - CHECK_TH_CUDA(x, m); \ - CHECK_CONTIGUOUS(x, m); - - constexpr auto HALF = at::ScalarType::Half; - constexpr auto BF16 = at::ScalarType::BFloat16; - constexpr auto FLOAT = at::ScalarType::Float; - constexpr auto INT = at::ScalarType::Int; - constexpr auto UINT8 = at::ScalarType::Byte; - - // Common validation for fp4 experts quantization entry points. 
- static void validate_fp4_experts_quant_inputs( - torch::Tensor const& output, torch::Tensor const& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk, - int64_t k) { - CHECK_INPUT(output, "output"); - CHECK_INPUT(output_scale, "output_scale"); - CHECK_INPUT(input, "input"); - CHECK_INPUT(input_global_scale, "input_global_scale"); - CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts"); - CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts"); - - TORCH_CHECK(output.dim() == 2); - TORCH_CHECK(output_scale.dim() == 2); - TORCH_CHECK(input.dim() == 2); - TORCH_CHECK(input_global_scale.dim() == 1); - TORCH_CHECK(input_offset_by_experts.dim() == 1); - TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); - - TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); - TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); - TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); - TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); - // output is uint8 (two nvfp4 values are packed into one uint8) - // output_scale is int32 (four fp8 values are packed into one int32) - TORCH_CHECK(output.scalar_type() == UINT8); - TORCH_CHECK(output_scale.scalar_type() == INT); - - const int BLOCK_SIZE = 16; - TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); - auto n_experts = input_global_scale.size(0); - TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); - TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); - TORCH_CHECK(output.size(0) == m_topk); - TORCH_CHECK(output.size(1) == k / 2); - int scales_k = k / BLOCK_SIZE; - // 4 means the swizzle requirement by nvidia nvfp4. 
- int padded_k = (scales_k + (4 - 1)) / 4 * 4; - // 4 means 4 fp8 values are packed into one int32 - TORCH_CHECK(output_scale.size(1) * 4 == padded_k); - } - - void scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { - auto m_topk = input.size(0); - auto k = input.size(1); - - validate_fp4_experts_quant_inputs(output, output_scale, input, - input_global_scale, input_offset_by_experts, - output_scale_offset_by_experts, m_topk, k); - - auto n_experts = input_global_scale.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(input.get_device()); - - VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - vllm::quant_impl( - output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), - input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, - stream); - }); - } - - void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { - auto m_topk = input.size(0); - // Input has gate || up layout, so k = input.size(1) / 2 - auto k_times_2 = input.size(1); - TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); - auto k = k_times_2 / 2; - - validate_fp4_experts_quant_inputs(output, output_scale, input, - input_global_scale, input_offset_by_experts, - output_scale_offset_by_experts, m_topk, k); - - auto n_experts = input_global_scale.size(0); - const at::cuda::OptionalCUDAGuard 
device_guard(device_of(input)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(input.get_device()); - - VLLM_DISPATCH_HALF_TYPES( - input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - vllm::quant_impl( - output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), - input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, - stream); - }); - } - \ No newline at end of file +#include + +#include +#include + +#include +#include + +#include +#include "dispatch_utils.h" + +#include "cuda_utils.h" +#include "nvfp4_utils.cuh" +#include "launch_bounds_utils.h" + +namespace vllm { + +// NVFP4 quantization kernel for experts (low-latency path). +// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses +// SiLU(gate)*up before quantization. +// Use UE4M3 by default. +template +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts, + bool low_latency) { + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // When fusing SiLU+Mul, input has gate || up layout (doubled width) + int inColsPerRow = FUSE_SILU_MUL ? 
colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid; globalIdx < numRows * colsPerRow; + globalIdx += gridDim.x * blockDim.x) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find index within the experts using different strategies based on expert + // count + int rowIdx_in_expert = 0; + int expert_idx = 0; + + if constexpr (SMALL_NUM_EXPERTS) { + for (int i = 0; i < n_experts; i++) { + uint32_t current_offset = __ldca(&input_offset_by_experts[i]); + uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]); + if (rowIdx >= current_offset && rowIdx < next_offset) { + rowIdx_in_expert = rowIdx - current_offset; + expert_idx = i; + break; + } + } + } else { + // Load input offsets into registers first, then do the computation. + // Local array size set to 17 because of register limit. + uint32_t local_offsets[17]; + for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) { + *reinterpret_cast(local_offsets) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start])); + *reinterpret_cast(local_offsets + 4) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start + 4])); + *reinterpret_cast(local_offsets + 8) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start + 8])); + *reinterpret_cast(local_offsets + 12) = + __ldca(reinterpret_cast( + &input_offset_by_experts[chunk_start + 12])); + local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); + +// Check against the 16 loaded offsets +#pragma unroll + for (int i = 0; i < 16; i++) { + if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { + rowIdx_in_expert = rowIdx - local_offsets[i]; + expert_idx = chunk_start + i; + break; + } + } + } + } + + // Load input and optionally apply fused SiLU+Mul + int64_t inOffset = rowIdx * inColsPerRow + colIdx; + PackedVec in_vec = 
reinterpret_cast(in)[inOffset]; + PackedVec quant_input; + if constexpr (FUSE_SILU_MUL) { + PackedVec in_vec_up = + reinterpret_cast(in)[inOffset + colsPerRow]; + quant_input = compute_silu_mul(in_vec, in_vec_up); + } else { + quant_input = in_vec; + } + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4( + quant_input, SFScaleVal, sf_out); + } +} + +// NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized +// version). When FUSE_SILU_MUL=true, expects input with gate||up layout and +// fuses SiLU(gate)*up before quantization. +template +__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, + float const* SFScale, uint32_t* out, uint32_t* SFout, + uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + extern __shared__ uint32_t shared_input_offsets[]; + + // Load input offsets into shared memory. 
+ // If n_experts is larger than 4, use vectorized int4 to save instructions. + // If n_experts is smaller than 4, read directly. + if constexpr (SMALL_NUM_EXPERTS) { + for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) { + shared_input_offsets[i] = input_offset_by_experts[i]; + } + } else { + for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) { + *reinterpret_cast(&shared_input_offsets[i]) = + *reinterpret_cast(&input_offset_by_experts[i]); + } + if (threadIdx.x == 0) { + shared_input_offsets[n_experts] = input_offset_by_experts[n_experts]; + } + } + + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // When fusing SiLU+Mul, input has gate || up layout (doubled width) + int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid; globalIdx < numRows * colsPerRow; + globalIdx += gridDim.x * blockDim.x) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find expert using binary search for better performance with large m_topk + int rowIdx_in_expert = 0; + int expert_idx = 0; + + // Binary search through experts using shared memory + int left = 0, right = n_experts - 1; + while (left <= right) { + int mid = (left + right) / 2; + // Get offsets: shared_input_offsets[i] corresponds to + // input_offset_by_experts[i] + uint32_t mid_offset = shared_input_offsets[mid]; + uint32_t next_offset = shared_input_offsets[mid + 1]; + + if (rowIdx >= mid_offset && rowIdx < next_offset) { + rowIdx_in_expert = rowIdx - mid_offset; + expert_idx = mid; + break; + } else if (rowIdx < mid_offset) { + right = mid - 1; + } else { + left = mid + 1; + } + } + + // Load input and optionally apply fused SiLU+Mul + int64_t inOffset = rowIdx * inColsPerRow + colIdx; + PackedVec in_vec = 
reinterpret_cast(in)[inOffset]; + PackedVec quant_input; + if constexpr (FUSE_SILU_MUL) { + PackedVec in_vec_up = + reinterpret_cast(in)[inOffset + colsPerRow]; + quant_input = compute_silu_mul(in_vec, in_vec_up); + } else { + quant_input = in_vec; + } + + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numKTiles; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4( + quant_input, SFScaleVal, sf_out); + } +} + +template +void quant_impl(void* output, void* output_scale, void* input, + void* input_global_scale, void* input_offset_by_experts, + void* output_scale_offset_by_experts, int m_topk, int k, + int n_experts, cudaStream_t stream) { + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + // Grid, Block size. + // Each thread converts 8 values. 
+ int const workSizePerRow = k / ELTS_PER_THREAD; + int const totalWorkSize = m_topk * workSizePerRow; + dim3 block(std::min(workSizePerRow, 512)); + // Get number of blocks per SM + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + dim3 grid(std::min(static_cast((totalWorkSize + block.x - 1) / block.x), + multiProcessorCount * numBlocksPerSM)); + while (grid.x <= multiProcessorCount && block.x > 64) { + grid.x *= 2; + block.x = (block.x + 1) / 2; + } + + int const blockRepeat = + (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); + if (blockRepeat > 1) { + size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); + if (n_experts >= 4) { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts); + } else { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts); + } + } else { + if (n_experts >= 16) { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts, /* bool low_latency */ true); + } else { + cvt_fp16_to_fp4 + <<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), + n_experts, /* bool low_latency */ true); + } + } +} + +} // namespace vllm + +/*Quantization entry for fp4 experts 
quantization*/ +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); + +constexpr auto HALF = at::ScalarType::Half; +constexpr auto BF16 = at::ScalarType::BFloat16; +constexpr auto FLOAT = at::ScalarType::Float; +constexpr auto INT = at::ScalarType::Int; +constexpr auto UINT8 = at::ScalarType::Byte; + +// Common validation for fp4 experts quantization entry points. +static void validate_fp4_experts_quant_inputs( + torch::Tensor const& output, torch::Tensor const& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk, + int64_t k) { + CHECK_INPUT(output, "output"); + CHECK_INPUT(output_scale, "output_scale"); + CHECK_INPUT(input, "input"); + CHECK_INPUT(input_global_scale, "input_global_scale"); + CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts"); + CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 16; + 
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. + int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); +} + +void scaled_fp4_experts_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + auto m_topk = input.size(0); + auto k = input.size(1); + + validate_fp4_experts_quant_inputs(output, output_scale, input, + input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts, m_topk, k); + + auto n_experts = input_global_scale.size(0); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); +} + +void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + auto m_topk = input.size(0); + // Input has gate || up layout, so k = 
input.size(1) / 2 + auto k_times_2 = input.size(1); + TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); + auto k = k_times_2 / 2; + + validate_fp4_experts_quant_inputs(output, output_scale, input, + input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts, m_topk, k); + + auto n_experts = input_global_scale.size(0); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + + VLLM_DISPATCH_HALF_TYPES( + input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + vllm::quant_impl( + output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), + input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, + stream); + }); +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 7bd25a309096..585cbd4e073a 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -14,229 +14,228 @@ * limitations under the License. */ - #include - - #include - #include - - #include - #include - - #include - #include "dispatch_utils.h" - - #include "cuda_utils.h" - #include "launch_bounds_utils.h" - - // Define before including nvfp4_utils.cuh so the header - // can use this macro during compilation. - #define NVFP4_ENABLE_ELTS16 1 - #include "nvfp4_utils.cuh" - - namespace vllm { - - // Use UE4M3 by default. 
- template - __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols, - Type const* __restrict__ in, - float const* __restrict__ SFScale, - uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; - - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - // Precompute SF layout parameter (constant for entire kernel). - int32_t const numKTiles = (numCols + 63) / 64; - - int sf_m = round_up(numRows, 128); - int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; - int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). - float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; - - // Iterate over all rows and cols including padded ones - - // ensures we visit every single scale factor address to initialize it. 
- for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) { - if (colIdx < num_padded_cols) { - PackedVec in_vec; - int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - - // If we are outside valid rows OR outside valid columns -> Use Zeros - bool valid = (rowIdx < numRows) && (elem_idx < numCols); - if constexpr (CVT_FP4_PACK16) { - ld256_cg_or_zero( - reinterpret_cast(in_vec), - &reinterpret_cast(in)[inOffset * 8], valid); - } else { - ld128_cg_or_zero( - reinterpret_cast(in_vec), - &reinterpret_cast(in)[inOffset * 4], valid); - } - - auto sf_out = - cvt_quant_to_fp4_get_sf_out_offset( - rowIdx, colIdx, numKTiles, SFout); - - auto out_val = - cvt_warp_fp16_to_fp4( - in_vec, global_scale, sf_out); - - // We do NOT write output for padding because the 'out' tensor is not - // padded. - if (valid) { - if constexpr (CVT_FP4_PACK16) { - int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; - uint64_t packed64 = - (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); - reinterpret_cast(out)[outOffset >> 1] = packed64; - } else { - out[inOffset] = out_val; - } - } - } - } - } - - // Use UE4M3 by default. - template - __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) - cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, - int32_t sf_n_unpadded, Type const* __restrict__ in, - float const* __restrict__ SFScale, - uint32_t* __restrict__ out, - uint32_t* __restrict__ SFout) { - using PackedVec = PackedVec; - - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = - (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); - static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, - "Vec size is not matched."); - - int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; - int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; - - // Get the global scaling factor, which will be applied to the SF. - // Note SFScale is the same as next GEMM's alpha, which is - // (448.f / (Alpha_A / 6.f)). 
- float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; - - // Iterate over all rows and cols including padded ones - - // ensures we visit every single scale factor address to initialize it. - for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < sf_n_unpadded) { - PackedVec in_vec; - int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; - - // If we are outside valid rows OR outside valid columns -> Use Zeros - bool valid = (rowIdx < numRows) && (elem_idx < numCols); - if constexpr (CVT_FP4_PACK16) { - ld256_cg_or_zero( - reinterpret_cast(in_vec), - &reinterpret_cast(in)[inOffset * 8], valid); - } else { - ld128_cg_or_zero( - reinterpret_cast(in_vec), - &reinterpret_cast(in)[inOffset * 4], valid); - } - - auto sf_out = - sf_out_rowmajor_u8(rowIdx, colIdx, sf_n_unpadded, SFout); - - auto out_val = - cvt_warp_fp16_to_fp4( - in_vec, global_scale, sf_out); - - // We do NOT write output for padding because the 'out' tensor is not - // padded. 
- if (valid) { - if constexpr (CVT_FP4_PACK16) { - int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; - uint64_t packed64 = - (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); - reinterpret_cast(out)[outOffset >> 1] = packed64; - } else { - out[inOffset] = out_val; - } - } - } - } - } - - } // namespace vllm - - void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, - torch::Tensor const& input, - torch::Tensor const& output_sf, - torch::Tensor const& input_sf, - bool is_sf_swizzled_layout) { - int32_t m = input.size(0); - int32_t n = input.size(1); - - TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || - input.scalar_type() == at::ScalarType::BFloat16, - "Unsupported input data type for quantize_to_fp4."); - - int multiProcessorCount = - get_device_attribute(cudaDevAttrMultiProcessorCount, -1); - - auto input_sf_ptr = static_cast(input_sf.data_ptr()); - auto sf_out = static_cast(output_sf.data_ptr()); - auto output_ptr = static_cast(output.data_ptr()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); - - int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); - - // Grid, Block size. Each thread converts 8 values. 
- dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); - int const numBlocksPerSM = - vllm_runtime_blocks_per_sm(static_cast(block.x)); - - if (is_sf_swizzled_layout) { - int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4); - int32_t num_padded_cols = - sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; - - int grid_y = vllm::div_round_up(num_padded_cols, static_cast(block.x)); - int grid_x = - std::min(vllm::computeEffectiveRows(m), - std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); - dim3 grid(grid_x, grid_y); - - VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - // NOTE: We don't support e8m0 scales at this moment. - vllm::cvt_fp16_to_fp4<<>>( - m, n, num_padded_cols, input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); - } else { - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); - int grid_x = std::min( - m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); - dim3 grid(grid_x, grid_y); - - VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - // NOTE: We don't support e8m0 scales at this moment. - vllm::cvt_fp16_to_fp4_sf_major - <<>>(m, n, sf_n_unpadded, input_ptr, - input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); - } - } - \ No newline at end of file +#include + +#include +#include + +#include +#include + +#include +#include "dispatch_utils.h" + +#include "cuda_utils.h" +#include "launch_bounds_utils.h" + +// Define before including nvfp4_utils.cuh so the header +// can use this macro during compilation. +#define NVFP4_ENABLE_ELTS16 1 +#include "nvfp4_utils.cuh" + +namespace vllm { + +// Use UE4M3 by default. 
+template +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols, + Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { + using PackedVec = vllm::PackedVec; + + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Precompute SF layout parameter (constant for entire kernel). + int32_t const numKTiles = (numCols + 63) / 64; + + int sf_m = round_up(numRows, 128); + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + // Iterate over all rows and cols including padded ones - + // ensures we visit every single scale factor address to initialize it. + for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) { + if (colIdx < num_padded_cols) { + PackedVec in_vec; + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + + // If we are outside valid rows OR outside valid columns -> Use Zeros + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); + } else { + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); + } + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx, colIdx, numKTiles, SFout); + + auto out_val = + cvt_warp_fp16_to_fp4( + in_vec, global_scale, sf_out); + + // We do NOT write output for padding because the 'out' tensor is not + // padded. 
+ if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } + } + } +} + +// Use UE4M3 by default. +template +__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) + cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, + int32_t sf_n_unpadded, Type const* __restrict__ in, + float const* __restrict__ SFScale, + uint32_t* __restrict__ out, + uint32_t* __restrict__ SFout) { + using PackedVec = PackedVec; + + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x; + int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0]; + + // Iterate over all rows and cols including padded ones - + // ensures we visit every single scale factor address to initialize it. 
+ for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + if (colIdx < sf_n_unpadded) { + PackedVec in_vec; + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + + // If we are outside valid rows OR outside valid columns -> Use Zeros + bool valid = (rowIdx < numRows) && (elem_idx < numCols); + if constexpr (CVT_FP4_PACK16) { + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); + } else { + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); + } + + auto sf_out = + sf_out_rowmajor_u8(rowIdx, colIdx, sf_n_unpadded, SFout); + + auto out_val = + cvt_warp_fp16_to_fp4( + in_vec, global_scale, sf_out); + + // We do NOT write output for padding because the 'out' tensor is not + // padded. + if (valid) { + if constexpr (CVT_FP4_PACK16) { + int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2; + uint64_t packed64 = + (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); + reinterpret_cast(out)[outOffset >> 1] = packed64; + } else { + out[inOffset] = out_val; + } + } + } + } +} + +} // namespace vllm + +void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, + torch::Tensor const& input, + torch::Tensor const& output_sf, + torch::Tensor const& input_sf, + bool is_sf_swizzled_layout) { + int32_t m = input.size(0); + int32_t n = input.size(1); + + TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || + input.scalar_type() == at::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); + + int multiProcessorCount = + get_device_attribute(cudaDevAttrMultiProcessorCount, -1); + + auto input_sf_ptr = static_cast(input_sf.data_ptr()); + auto sf_out = static_cast(output_sf.data_ptr()); + auto output_ptr = static_cast(output.data_ptr()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + auto stream = 
at::cuda::getCurrentCUDAStream(input.get_device()); + + int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + + // Grid, Block size. Each thread converts 8 values. + dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); + int const numBlocksPerSM = + vllm_runtime_blocks_per_sm(static_cast(block.x)); + + if (is_sf_swizzled_layout) { + int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4); + int32_t num_padded_cols = + sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + + int grid_y = vllm::div_round_up(num_padded_cols, static_cast(block.x)); + int grid_x = + std::min(vllm::computeEffectiveRows(m), + std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + // NOTE: We don't support e8m0 scales at this moment. + vllm::cvt_fp16_to_fp4<<>>( + m, n, num_padded_cols, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + } else { + int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_x = std::min( + m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); + dim3 grid(grid_x, grid_y); + + VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + // NOTE: We don't support e8m0 scales at this moment. + vllm::cvt_fp16_to_fp4_sf_major + <<>>(m, n, sf_n_unpadded, input_ptr, + input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); + } +} diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 502b86b31270..7b7a8488e048 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -14,294 +14,293 @@ * limitations under the License. 
*/ - #pragma once - - #include - #include - - #include "../../cuda_vec_utils.cuh" - - #if defined(NVFP4_ENABLE_ELTS16) && \ - defined(CUDA_VERSION) && CUDA_VERSION >= 12090 && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 - #define ELTS_PER_THREAD 16 - constexpr int CVT_FP4_ELTS_PER_THREAD = 16; - constexpr bool CVT_FP4_PACK16 = true; - #else - #define ELTS_PER_THREAD 8 - constexpr int CVT_FP4_ELTS_PER_THREAD = 8; - constexpr bool CVT_FP4_PACK16 = false; - #endif - - constexpr int CVT_FP4_SF_VEC_SIZE = 16; - - namespace vllm { - - template - struct PackedVec<__nv_fp8_e4m3, use_256b> { - __nv_fp8x2_e4m3 elts[8]; - }; - - template - __host__ __device__ inline Int round_up(Int x, Int y) { - static_assert(std::is_integral_v, - "round_up argument must be integral type"); - return ((x + y - 1) / y) * y; - } - - template - __host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) { - return (x + y - 1) / y; - } - - // Compute effective rows for grid configuration with swizzled SF layouts. - inline int computeEffectiveRows(int m) { - constexpr int ROW_TILE = 128; - return round_up(m, ROW_TILE); - } - - // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). - inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}" - : "=r"(val) - : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), - "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); - return val; - } - - // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
- __device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { - uint32_t val; - asm volatile( - "{\n" - ".reg .b8 byte0;\n" - ".reg .b8 byte1;\n" - ".reg .b8 byte2;\n" - ".reg .b8 byte3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" - "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" - "}\n" - : "=r"(val) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); - return val; - } - - struct u32x2 { - uint32_t lo, hi; - }; - - using fp4_packed_t = std::conditional_t; - - __device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { - u32x2 out; - asm volatile( - "{\n" - ".reg .b8 b0;\n" - ".reg .b8 b1;\n" - ".reg .b8 b2;\n" - ".reg .b8 b3;\n" - ".reg .b8 b4;\n" - ".reg .b8 b5;\n" - ".reg .b8 b6;\n" - ".reg .b8 b7;\n" - "cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;\n" - "cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;\n" - "cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;\n" - "cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;\n" - "cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;\n" - "cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;\n" - "cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;\n" - "cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;\n" - "mov.b32 %0, {b0, b1, b2, b3};\n" - "mov.b32 %1, {b4, b5, b6, b7};\n" - "}\n" - : "=r"(out.lo), "=r"(out.hi) - : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), - "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), - "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y), - "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y)); - return out; - } - - __device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) { - return fp32_vec8_to_e2m1(v); - } - - __device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) { - return fp32_vec16_to_e2m1(v); - } - - // Fast 
reciprocal. - __device__ __forceinline__ float reciprocal_approximate_ftz(float a) { - float b; - asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a)); - return b; - } - - // Compute SF output offset for swizzled tensor core layout. - // SF layout: [numMTiles, numKTiles, 32, 4, 4] - // Caller must precompute: numKTiles = (numCols + 63) / 64 - template - __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset( - int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) { - static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || - CVT_FP4_NUM_THREADS_PER_SF == 2); - - // One pair of threads write one SF to global memory. - // TODO: stage through smem for packed STG.32 - // is it better than STG.8 from 4 threads ? - if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) { - return nullptr; - } - - // SF vector index (16 elements share one SF in the K dimension). - int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; - int32_t mIdx = rowIdx; - - // Decompose indices using bitwise ops (all divisors are powers of 2). - // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] - int32_t mTileIdx = mIdx >> 7; // mIdx / 128 - int32_t outerMIdx = mIdx & 31; // mIdx % 32 - int32_t innerMIdx = (mIdx >> 5) & 3; // (mIdx / 32) % 4 - int32_t kTileIdx = kIdx >> 2; // kIdx / 4 - int32_t innerKIdx = kIdx & 3; // kIdx % 4 - - // Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 + - // outerMIdx * 16 + innerMIdx * 4 + innerKIdx - // Use bitwise OR for non-overlapping lower bits. 
- int64_t SFOffset = (static_cast(mTileIdx) * numKTiles + kTileIdx) - << 9 | - (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx; - - return reinterpret_cast(SFout) + SFOffset; - } - - template - __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack, - int packs_per_row_sf, - SFType* SFout) { - constexpr int PACK = CVT_FP4_ELTS_PER_THREAD; - constexpr int THREADS_PER_SF = - CVT_FP4_SF_VEC_SIZE / PACK; // 1 if PACK=16, 2 else PACK=8 - - if (threadIdx.x % THREADS_PER_SF != 0) return nullptr; - - int sf_col = - pack / THREADS_PER_SF; // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2 - int64_t off = (int64_t)row * packs_per_row_sf + sf_col; - - return (uint8_t*)SFout + off; - } - - // Quantizes the provided PackedVec into the uint32_t output - template - __device__ __forceinline__ fp4_packed_t - cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { - // Get absolute maximum values among the local 8 values. - auto localMax = __habs2(vec.elts[0]); - - // Local maximum value. - #pragma unroll - for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - localMax = __hmax2(localMax, __habs2(vec.elts[i])); - } - - // Get the absolute maximum among all 16 values (two threads). - - if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) { - localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax); - } - // Get the final absolute maximum values. - float vecMax = float(__hmax(localMax.x, localMax.y)); - - // Get the SF (max value of the vector / max value of e2m1). - // maximum value of e2m1 = 6.0. - // TODO: use half as compute data type. - float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); - // 8 bits representation of the SF. - uint8_t fp8SFVal; - // Write the SF to global memory (STG.8). - if constexpr (UE8M0_SF) { - // Extract the 8 exponent bits from float32. - // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. 
- uint32_t tmp = reinterpret_cast(SFValue) >> 23; - fp8SFVal = tmp & 0xff; - // Convert back to fp32. - reinterpret_cast(SFValue) = tmp << 23; - } else { - // Here SFValue is always positive, so E4M3 is the same as UE4M3. - __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); - reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; - // Convert back to fp32. - SFValue = float(tmp); - } - - // Write the SF to global memory (STG.8). - if (SFout) *SFout = fp8SFVal; - - // Get the output scale. - // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * - // reciprocal(SFScaleVal)) - float outputScale = - SFValue != 0.0f ? reciprocal_approximate_ftz( - SFValue * reciprocal_approximate_ftz(SFScaleVal)) - : 0.0f; - - // Convert the input to float. - float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - fp2Vals[i] = cast_to_float2(vec.elts[i]); - fp2Vals[i].x *= outputScale; - fp2Vals[i].y *= outputScale; - } - - // Convert to e2m1 values. 
- return pack_fp4(fp2Vals); - } - - // silu in float32 - __device__ __forceinline__ float silu(float x) { - return __fdividef(x, (1.f + __expf(-x))); - } - - __device__ __forceinline__ float2 silu2(float2 x) { - return make_float2(silu(x.x), silu(x.y)); - } - - template - __inline__ __device__ PackedVec compute_silu_mul( - const PackedVec& x_vec, const PackedVec& y_vec) { - PackedVec result; - - #pragma unroll - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { - // silu_mul in float32 - using packed_t = typename TypeConverter::Type; - float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i])); - float2 y_f2 = cast_to_float2(y_vec.elts[i]); - result.elts[i] = cast_to_packed( - make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y)); - } - return result; - } - - } // namespace vllm - \ No newline at end of file +#pragma once + +#include +#include + +#include "../../cuda_vec_utils.cuh" + +#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \ + CUDA_VERSION >= 12090 && defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + #define ELTS_PER_THREAD 16 +constexpr int CVT_FP4_ELTS_PER_THREAD = 16; +constexpr bool CVT_FP4_PACK16 = true; +#else + #define ELTS_PER_THREAD 8 +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr bool CVT_FP4_PACK16 = false; +#endif + +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +namespace vllm { + +template +struct PackedVec<__nv_fp8_e4m3, use_256b> { + __nv_fp8x2_e4m3 elts[8]; +}; + +template +__host__ __device__ inline Int round_up(Int x, Int y) { + static_assert(std::is_integral_v, + "round_up argument must be integral type"); + return ((x + y - 1) / y) * y; +} + +template +__host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) { + return (x + y - 1) / y; +} + +// Compute effective rows for grid configuration with swizzled SF layouts. +inline int computeEffectiveRows(int m) { + constexpr int ROW_TILE = 128; + return round_up(m, ROW_TILE); +} + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). 
+inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) { + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}\n" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +} + +struct u32x2 { + uint32_t lo, hi; +}; + +using fp4_packed_t = std::conditional_t; + +__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) { + u32x2 out; + asm volatile( + "{\n" + ".reg .b8 b0;\n" + ".reg .b8 b1;\n" + ".reg .b8 b2;\n" + ".reg .b8 b3;\n" + ".reg .b8 b4;\n" + ".reg .b8 b5;\n" + ".reg .b8 b6;\n" + ".reg .b8 b7;\n" + "cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;\n" + "cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;\n" + "cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;\n" + "cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;\n" + "cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;\n" + "cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;\n" + "cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;\n" + 
"cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;\n" + "mov.b32 %0, {b0, b1, b2, b3};\n" + "mov.b32 %1, {b4, b5, b6, b7};\n" + "}\n" + : "=r"(out.lo), "=r"(out.hi) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y), + "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y), + "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y)); + return out; +} + +__device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) { + return fp32_vec8_to_e2m1(v); +} + +__device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) { + return fp32_vec16_to_e2m1(v); +} + +// Fast reciprocal. +__device__ __forceinline__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a)); + return b; +} + +// Compute SF output offset for swizzled tensor core layout. +// SF layout: [numMTiles, numKTiles, 32, 4, 4] +// Caller must precompute: numKTiles = (numCols + 63) / 64 +template +__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset( + int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) { + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) { + return nullptr; + } + + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // Decompose indices using bitwise ops (all divisors are powers of 2). 
+ // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + int32_t mTileIdx = mIdx >> 7; // mIdx / 128 + int32_t outerMIdx = mIdx & 31; // mIdx % 32 + int32_t innerMIdx = (mIdx >> 5) & 3; // (mIdx / 32) % 4 + int32_t kTileIdx = kIdx >> 2; // kIdx / 4 + int32_t innerKIdx = kIdx & 3; // kIdx % 4 + + // Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 + + // outerMIdx * 16 + innerMIdx * 4 + innerKIdx + // Use bitwise OR for non-overlapping lower bits. + int64_t SFOffset = (static_cast(mTileIdx) * numKTiles + kTileIdx) + << 9 | + (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx; + + return reinterpret_cast(SFout) + SFOffset; +} + +template +__device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack, + int packs_per_row_sf, + SFType* SFout) { + constexpr int PACK = CVT_FP4_ELTS_PER_THREAD; + constexpr int THREADS_PER_SF = + CVT_FP4_SF_VEC_SIZE / PACK; // 1 if PACK=16, 2 else PACK=8 + + if (threadIdx.x % THREADS_PER_SF != 0) return nullptr; + + int sf_col = + pack / THREADS_PER_SF; // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2 + int64_t off = (int64_t)row * packs_per_row_sf + sf_col; + + return (uint8_t*)SFout + off; +} + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ __forceinline__ fp4_packed_t cvt_warp_fp16_to_fp4( + PackedVec& vec, float SFScaleVal, uint8_t* SFout) { + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + + // Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + + if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) { + localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax); + } + // Get the final absolute maximum values. 
+ float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + + // Write the SF to global memory (STG.8). + if (SFout) *SFout = fp8SFVal; + + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0.0f ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + fp2Vals[i] = cast_to_float2(vec.elts[i]); + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. 
+ return pack_fp4(fp2Vals); +} + +// silu in float32 +__device__ __forceinline__ float silu(float x) { + return __fdividef(x, (1.f + __expf(-x))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +template +__inline__ __device__ PackedVec compute_silu_mul( + const PackedVec& x_vec, + const PackedVec& y_vec) { + PackedVec result; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { + // silu_mul in float32 + using packed_t = typename TypeConverter::Type; + float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i])); + float2 y_f2 = cast_to_float2(y_vec.elts[i]); + result.elts[i] = cast_to_packed( + make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y)); + } + return result; +} + +} // namespace vllm From 7d55bb87d81588e9275a5c34ee112aba2c270069 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Tue, 24 Feb 2026 21:14:08 +0100 Subject: [PATCH 3/6] refactor and bugfix Signed-off-by: LopezCastroRoberto --- csrc/cuda_vec_utils.cuh | 10 ---------- .../fp4/activation_nvfp4_quant_fusion_kernels.cu | 6 ++++-- csrc/quantization/fp4/nvfp4_quant_kernels.cu | 12 +++++++----- csrc/quantization/fp4/nvfp4_utils.cuh | 2 +- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 2b3b23470561..9e1fa846fd34 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -136,11 +136,6 @@ __device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) : "l"(ptr)); - #else - const uint4* src = reinterpret_cast(ptr); - uint4* dst = reinterpret_cast(val.d); - dst[0] = __ldg(&src[0]); - dst[1] = __ldg(&src[1]); #endif } @@ -152,11 +147,6 @@ __device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]) : "memory"); - #else - 
uint4* dst = reinterpret_cast(ptr); - const uint4* src = reinterpret_cast(val.d); - dst[0] = src[0]; - dst[1] = src[1]; #endif } diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 8c0e82b7651f..948cb7e3ad21 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -107,7 +107,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); reinterpret_cast(out)[outOffset >> 1] = packed64; } else { - out[inOffset] = out_val; + int64_t outOffset = + rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + out[outOffset] = out_val; } } } @@ -140,7 +142,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); - int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + int sf_n_unpadded = int(n / CVT_FP4_ELTS_PER_THREAD); int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); int grid_x = std::min( diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index 585cbd4e073a..773047c22500 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -109,7 +109,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, - int32_t sf_n_unpadded, Type const* __restrict__ in, + int32_t sf_n_unpadded, int32_t num_packed_cols, + Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { @@ -131,7 +132,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // Iterate over all rows and cols including padded ones - // ensures 
we visit every single scale factor address to initialize it. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < sf_n_unpadded) { + if (colIdx < num_packed_cols) { PackedVec in_vec; int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; @@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, reinterpret_cast(sf_out)); }); } else { - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD; + int grid_y = vllm::div_round_up(num_packed_cols, static_cast(block.x)); int grid_x = std::min( m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); @@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, auto input_ptr = static_cast(input.data_ptr()); // NOTE: We don't support e8m0 scales at this moment. vllm::cvt_fp16_to_fp4_sf_major - <<>>(m, n, sf_n_unpadded, input_ptr, - input_sf_ptr, + <<>>(m, n, sf_n_unpadded, num_packed_cols, + input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), reinterpret_cast(sf_out)); }); diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 7b7a8488e048..bbd4e2362d59 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -22,7 +22,7 @@ #include "../../cuda_vec_utils.cuh" #if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \ - CUDA_VERSION >= 12090 && defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + CUDA_VERSION >= 12090 #define ELTS_PER_THREAD 16 constexpr int CVT_FP4_ELTS_PER_THREAD = 16; constexpr bool CVT_FP4_PACK16 = true; From 94a036e6a48160f1bb1b6e1befb02f6f8864f0b1 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Thu, 26 Feb 2026 10:40:59 +0000 Subject: [PATCH 4/6] rename variable Signed-off-by: LopezCastroRoberto --- .../fp4/activation_nvfp4_quant_fusion_kernels.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff 
--git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 948cb7e3ad21..3539096c9feb 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -39,7 +39,7 @@ namespace vllm { template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, - int32_t num_padded_cols, + int32_t num_packed_cols, Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, @@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // Input tensor row/col loops. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < num_padded_cols) { + if (colIdx < num_packed_cols) { PackedVec in_vec; PackedVec in_vec2; int64_t inOffset = @@ -142,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); - int sf_n_unpadded = int(n / CVT_FP4_ELTS_PER_THREAD); + int num_packed_cols = int(n / CVT_FP4_ELTS_PER_THREAD); - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_y = vllm::div_round_up(num_packed_cols, static_cast(block.x)); int grid_x = std::min( int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); @@ -154,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] using cuda_type = vllm::CUDATypeConverter::Type; auto input_ptr = static_cast(input.data_ptr()); vllm::silu_mul_cvt_fp16_to_fp4<<>>( - m, n, sf_n_unpadded, input_ptr, input_sf_ptr, + m, n, num_packed_cols, input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), reinterpret_cast(sf_out)); }); From ac4780570f2759e32f67f86452fb71edbcca1306 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 27 Feb 2026 12:24:08 
+0000 Subject: [PATCH 5/6] review comments Signed-off-by: LopezCastroRoberto --- csrc/activation_kernels.cu | 114 +++++++++++++------------- csrc/cuda_vec_utils.cuh | 90 +++++++++++++------- csrc/quantization/fp4/nvfp4_utils.cuh | 2 +- 3 files changed, 120 insertions(+), 86 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 24b4bd93eb45..758a77795553 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -152,54 +152,54 @@ packed_gelu_tanh_kernel(const packed_t& val) { // Launch activation and gating kernel. // Use ACT_FIRST (bool) indicating whether to apply the activation function // first. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = \ - (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ - ? 
vllm::VecTraits::ARCH_MAX_VEC_SIZE \ - : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::TypeConverter::Type, \ - KERNEL, \ - PACKED_KERNEL::Type>, \ - ACT_FIRST, true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::TypeConverter::Type, \ - KERNEL, \ - PACKED_KERNEL::Type>, \ - ACT_FIRST, true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::TypeConverter::Type, \ - KERNEL, \ - PACKED_KERNEL::Type>, \ - ACT_FIRST, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? 
vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / at::elementSize(dtype); \ + const bool use_vec = (d % vec_size == 0); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, true><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } else { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ } void silu_and_mul(torch::Tensor& out, // [..., d] @@ -392,9 +392,10 @@ __global__ void swigluoai_and_mul_kernel( VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::TypeConverter::Type, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::Type>, \ + PACKED_KERNEL< \ + typename vllm::PackedTypeConverter::Type>, \ true, true><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -403,9 +404,10 @@ __global__ void swigluoai_and_mul_kernel( VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ 
vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::TypeConverter::Type, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::Type>, \ + PACKED_KERNEL< \ + typename vllm::PackedTypeConverter::Type>, \ true, false><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -415,11 +417,11 @@ __global__ void swigluoai_and_mul_kernel( dim3 block(std::min(d, 1024)); \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::TypeConverter::Type, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::Type>, false> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d, PARAM); \ + PACKED_KERNEL::Type>, \ + false><<>>( \ + out.data_ptr(), input.data_ptr(), d, PARAM); \ }); \ } diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 9e1fa846fd34..82a19f10a70e 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -5,22 +5,26 @@ #include #include -#include -#include -#include +#include + +#ifdef USE_ROCM + #include +#else + #include + #include + #include +#endif // Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which // together enable 256-bit (v8.u32) PTX load/store instructions. // Use for PTX instruction selection with architecture fallback paths. -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ defined(CUDA_VERSION) && CUDA_VERSION >= 12090 #define VLLM_256B_PTX_ENABLED 1 #else #define VLLM_256B_PTX_ENABLED 0 #endif -#ifndef USE_ROCM - namespace vllm { // ============================================================ @@ -49,50 +53,51 @@ struct VecTraits { using vec_t = int4; }; -// TypeConverter — map between CUDA scalar and packed types +// PackedTypeConverter — map between CUDA scalar and packed types // half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, etc. 
template -struct TypeConverter { - using Type = half2; +struct PackedTypeConverter { + static_assert(sizeof(T) == 0, + "PackedTypeConverter is not specialized for this type."); }; template <> -struct TypeConverter { +struct PackedTypeConverter { using Type = half; }; template <> -struct TypeConverter { +struct PackedTypeConverter { using Type = half2; }; template <> -struct TypeConverter<__nv_bfloat162> { +struct PackedTypeConverter<__nv_bfloat162> { using Type = __nv_bfloat16; }; template <> -struct TypeConverter<__nv_bfloat16> { +struct PackedTypeConverter<__nv_bfloat16> { using Type = __nv_bfloat162; }; template <> -struct TypeConverter { +struct PackedTypeConverter { using Type = float2; }; template <> -struct TypeConverter { +struct PackedTypeConverter { using Type = float; }; template <> -struct TypeConverter { +struct PackedTypeConverter { using Type = half2; }; template <> -struct TypeConverter { +struct PackedTypeConverter { using Type = __nv_bfloat162; }; @@ -118,36 +123,39 @@ struct CUDATypeConverter { // Type is the CUDA scalar type (e.g. half, __nv_bfloat16). template struct alignas(VecTraits::ARCH_MAX_VEC_SIZE) PackedVec { - static constexpr int NUM_ELTS = VecTraits::ARCH_MAX_VEC_SIZE / - sizeof(typename TypeConverter::Type); - typename TypeConverter::Type elts[NUM_ELTS]; + static constexpr int NUM_ELTS = + VecTraits::ARCH_MAX_VEC_SIZE / + sizeof(typename PackedTypeConverter::Type); + typename PackedTypeConverter::Type elts[NUM_ELTS]; }; // ============================================================ // Load / store primitives // ============================================================ -// 256-bit load / store with architecture fallback. -// SM100+ : PTX v8 instructions (.nc / default hint) -// Older : two uint4 loads via __ldg +// 256-bit load / store — SM100+ only (PTX v8 instructions). 
__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { - #if VLLM_256B_PTX_ENABLED +#if VLLM_256B_PTX_ENABLED asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) : "l"(ptr)); - #endif +#else + assert(false && "ld256 requires SM100+ with CUDA 12.9+"); +#endif } __device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { - #if VLLM_256B_PTX_ENABLED +#if VLLM_256B_PTX_ENABLED asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" : : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]) : "memory"); - #endif +#else + assert(false && "st256 requires SM100+ with CUDA 12.9+"); +#endif } // Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec). @@ -179,36 +187,55 @@ __device__ __forceinline__ void st128(T& val, T* ptr) { // 256-bit cache-streaming (.cs) load / store — SM100+ only. __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { +#if VLLM_256B_PTX_ENABLED u32x8_t val; asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) : "l"(addr)); return val; +#else + assert(false && "ld256_cs requires SM100+ with CUDA 12.9+"); + return {}; +#endif } __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { +#if VLLM_256B_PTX_ENABLED asm volatile( "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); +#else + assert(false && "st256_cs requires SM100+ with CUDA 12.9+"); +#endif } // 32-bit cache-streaming (.cs) load / store — SM100+ only. 
__forceinline__ __device__ int ld32_cs(const int* addr) { +#if VLLM_256B_PTX_ENABLED int val; asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); return val; +#else + assert(false && "ld32_cs requires SM100+ with CUDA 12.9+"); + return 0; +#endif } __forceinline__ __device__ void st32_cs(int* addr, int val) { +#if VLLM_256B_PTX_ENABLED asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); +#else + assert(false && "st32_cs requires SM100+ with CUDA 12.9+"); +#endif } // Predicated 256-bit / 128-bit cache-global (.cg) loads. // Returns zero if pred is false. SM100+ only. __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, bool pred) { +#if VLLM_256B_PTX_ENABLED asm volatile( "{\n" " .reg .pred pr;\n" @@ -226,10 +253,14 @@ __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) : "r"((int)pred), "l"(ptr)); +#else + assert(false && "ld256_cg_or_zero requires SM100+ with CUDA 12.9+"); +#endif } __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, bool pred) { +#if VLLM_256B_PTX_ENABLED uint32_t r0, r1, r2, r3; asm volatile( @@ -246,6 +277,9 @@ __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, : "r"((int)pred), "l"(ptr)); val = uint4{r0, r1, r2, r3}; +#else + assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+"); +#endif } // ============================================================ @@ -298,5 +332,3 @@ __device__ __forceinline__ packed_t packed_mul(const packed_t& x, } } // namespace vllm - -#endif // !USE_ROCM diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index bbd4e2362d59..c7ab9cbb5a85 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -294,7 +294,7 @@ __inline__ __device__ PackedVec compute_silu_mul( 
#pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { // silu_mul in float32 - using packed_t = typename TypeConverter::Type; + using packed_t = typename PackedTypeConverter::Type; float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i])); float2 y_f2 = cast_to_float2(y_vec.elts[i]); result.elts[i] = cast_to_packed( From 090159280c022fe3cd25120bd59ff5b2221afde5 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Fri, 27 Feb 2026 16:39:16 +0000 Subject: [PATCH 6/6] delete dead code Signed-off-by: LopezCastroRoberto --- csrc/quantization/fp4/nvfp4_utils.cuh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index c7ab9cbb5a85..c1df1860c1a1 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -36,11 +36,6 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16; namespace vllm { -template -struct PackedVec<__nv_fp8_e4m3, use_256b> { - __nv_fp8x2_e4m3 elts[8]; -}; - template __host__ __device__ inline Int round_up(Int x, Int y) { static_assert(std::is_integral_v,