diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 99fa42f75e99..758a77795553 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -5,117 +5,11 @@ #include #include "cuda_compat.h" +#include "cuda_vec_utils.cuh" #include "dispatch_utils.h" namespace vllm { -struct alignas(32) u32x8_t { - uint32_t u0, u1, u2, u3, u4, u5, u6, u7; -}; - -__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ - defined(CUDA_VERSION) && CUDA_VERSION >= 12090 - asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" - : "=r"(val.u0), "=r"(val.u1), "=r"(val.u2), "=r"(val.u3), - "=r"(val.u4), "=r"(val.u5), "=r"(val.u6), "=r"(val.u7) - : "l"(ptr)); -#else - const uint4* uint_ptr = reinterpret_cast(ptr); - uint4 top_half = __ldg(&uint_ptr[0]); - uint4 bottom_half = __ldg(&uint_ptr[1]); - val.u0 = top_half.x; - val.u1 = top_half.y; - val.u2 = top_half.z; - val.u3 = top_half.w; - val.u4 = bottom_half.x; - val.u5 = bottom_half.y; - val.u6 = bottom_half.z; - val.u7 = bottom_half.w; -#endif -} - -__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ - defined(CUDA_VERSION) && CUDA_VERSION >= 12090 - asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" - : - : "l"(ptr), "r"(val.u0), "r"(val.u1), "r"(val.u2), "r"(val.u3), - "r"(val.u4), "r"(val.u5), "r"(val.u6), "r"(val.u7) - : "memory"); -#else - uint4* uint_ptr = reinterpret_cast(ptr); - uint_ptr[0] = make_uint4(val.u0, val.u1, val.u2, val.u3); - uint_ptr[1] = make_uint4(val.u4, val.u5, val.u6, val.u7); -#endif -} - -template -struct VecTraits; - -template <> -struct VecTraits { - static constexpr int ARCH_MAX_VEC_SIZE = 32; - using vec_t = u32x8_t; -}; - -template <> -struct VecTraits { - static constexpr int ARCH_MAX_VEC_SIZE = 16; - using vec_t = int4; -}; - -template -struct PackedTraits; - -template <> -struct PackedTraits { - using packed_t = __nv_bfloat162; -}; - -template <> -struct PackedTraits { - using packed_t = __half2; -}; - -template <> -struct PackedTraits { - using packed_t = float2; -}; - -template -__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) { - if constexpr (std::is_same_v) { - return __bfloat1622float2(val); - } else if constexpr (std::is_same_v) { - return __half22float2(val); - } else if constexpr (std::is_same_v) { - return float2(val); - } -} - -template -__device__ __forceinline__ packed_t cast_to_packed(const float2& val) { - if constexpr (std::is_same_v) { - return __float22bfloat162_rn(val); - } else if constexpr (std::is_same_v) { - return __float22half2_rn(val); - } else if constexpr (std::is_same_v) { - return float2(val); - } -} - -template -__device__ __forceinline__ packed_t packed_mul(const packed_t& x, - const packed_t& y) { - if constexpr (std::is_same_v || - std::is_same_v) { - return __hmul2(x, y); - } else if constexpr (std::is_same_v) { - return make_float2(x.x * y.x, x.y * y.y); - } -} - template __device__ __forceinline__ scalar_t compute(const scalar_t& x, @@ -131,16 +25,6 @@ __device__ __forceinline__ packed_t packed_compute(const packed_t& x, : packed_mul(x, PACKED_ACT_FN(y)); } -// Check if all pointers are 16-byte aligned for int4 vectorized access -__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { - return (reinterpret_cast(ptr) & 15) == 0; -} - -// Check if all pointers are 16-byte aligned for longlong4_32a vectorized access -__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) { - return (reinterpret_cast(ptr) & 31) == 0; -} - // Activation and gating kernel template. template ::vec_t; - constexpr int ARCH_MAX_VEC_SIZE = VecTraits::ARCH_MAX_VEC_SIZE; - constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t); + using cuda_t = typename CUDATypeConverter::Type; + using pvec_t = PackedVec; - const vec_t* x_vec = reinterpret_cast(x_ptr); - const vec_t* y_vec = reinterpret_cast(y_ptr); - vec_t* out_vec = reinterpret_cast(out_ptr); - const int num_vecs = d / 2 / VEC_SIZE; + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { - vec_t x, y; + pvec_t x, y; if constexpr (use_256b) { ld256(x, &x_vec[i]); ld256(y, &y_vec[i]); } else { - x = VLLM_LDG(&x_vec[i]); - y = VLLM_LDG(&y_vec[i]); + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); } - auto* xp = reinterpret_cast(&x); - auto* yp = reinterpret_cast(&y); #pragma unroll - for (int j = 0; j < VEC_SIZE; j++) { - xp[j] = - packed_compute(xp[j], yp[j]); + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = packed_compute( + x.elts[j], y.elts[j]); } if constexpr (use_256b) { st256(x, &out_vec[i]); } else { - out_vec[i] = x; + st128(x, &out_vec[i]); } } } else { @@ -272,51 +152,54 @@ packed_gelu_tanh_kernel(const packed_t& val) { // Launch activation and gating kernel. // Use ACT_FIRST (bool) indicating whether to apply the activation function // first. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / at::elementSize(dtype); \ + const bool use_vec = (d % vec_size == 0); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, true><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } else { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ } void silu_and_mul(torch::Tensor& out, // [..., d] @@ -378,35 +261,31 @@ __global__ void act_and_mul_kernel_with_param( scalar_t* out_ptr = out + blockIdx.x * d; if constexpr (use_vec) { - // Fast path: 128-bit/256-bit vectorized loop - using vec_t = typename VecTraits::vec_t; - constexpr int ARCH_MAX_VEC_SIZE = VecTraits::ARCH_MAX_VEC_SIZE; - constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t); + using cuda_t = typename CUDATypeConverter::Type; + using pvec_t = PackedVec; - const vec_t* x_vec = reinterpret_cast(x_ptr); - const vec_t* y_vec = reinterpret_cast(y_ptr); - vec_t* out_vec = reinterpret_cast(out_ptr); - const int num_vecs = d / 2 / VEC_SIZE; + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { - vec_t x, y; + pvec_t x, y; if constexpr (use_256b) { ld256(x, &x_vec[i]); ld256(y, &y_vec[i]); } else { - x = VLLM_LDG(&x_vec[i]); - y = VLLM_LDG(&y_vec[i]); + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); } - auto* xp = reinterpret_cast(&x); - auto* yp = reinterpret_cast(&y); #pragma unroll - for (int j = 0; j < VEC_SIZE; j++) { - xp[j] = packed_mul(PACKED_ACT_FN(xp[j], param), yp[j]); + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]); } if constexpr (use_256b) { st256(x, &out_vec[i]); } else { - out_vec[i] = x; + st128(x, &out_vec[i]); } } } else { @@ -499,21 +378,24 @@ __global__ void swigluoai_and_mul_kernel( } \ dim3 grid(num_tokens); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ int vec_size = support_vec / at::elementSize(dtype); \ const bool use_vec = (d % vec_size == 0); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL< \ - typename vllm::PackedTraits::packed_t>, \ + typename vllm::PackedTypeConverter::Type>, \ true, true><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -522,10 +404,10 @@ __global__ void swigluoai_and_mul_kernel( VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL< \ - typename vllm::PackedTraits::packed_t>, \ + typename vllm::PackedTypeConverter::Type>, \ true, false><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -535,9 +417,9 @@ __global__ void swigluoai_and_mul_kernel( dim3 block(std::min(d, 1024)); \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::packed_t>, \ + PACKED_KERNEL::Type>, \ false><<>>( \ out.data_ptr(), input.data_ptr(), d, PARAM); \ }); \ @@ -629,14 +511,17 @@ __global__ void activation_kernel( } \ dim3 grid(num_tokens); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ int vec_size = support_vec / at::elementSize(dtype); \ const bool use_vec = (d % vec_size == 0); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ vllm::activation_kernel, true, true> \ <<>>(out.data_ptr(), \ diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh new file mode 100644 index 000000000000..82a19f10a70e --- /dev/null +++ b/csrc/cuda_vec_utils.cuh @@ -0,0 +1,334 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +#pragma once + +#include +#include +#include + +#ifdef USE_ROCM + #include +#else + #include + #include + #include +#endif + +// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which +// together enable 256-bit (v8.u32) PTX load/store instructions. +// Use for PTX instruction selection with architecture fallback paths. +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ + defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + #define VLLM_256B_PTX_ENABLED 1 +#else + #define VLLM_256B_PTX_ENABLED 0 +#endif + +namespace vllm { + +// ============================================================ +// Types and traits +// ============================================================ + +// 256-bit (32-byte) aligned vector type: 8 x uint32_t +struct alignas(32) u32x8_t { + uint32_t d[8]; +}; + +// VecTraits — select between 128-bit (int4) and 256-bit +// (u32x8_t) vector types at compile time. +template +struct VecTraits; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 32; + using vec_t = u32x8_t; +}; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 16; + using vec_t = int4; +}; + +// PackedTypeConverter — map between CUDA scalar and packed types +// half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, etc. +template +struct PackedTypeConverter { + static_assert(sizeof(T) == 0, + "PackedTypeConverter is not specialized for this type."); +}; + +template <> +struct PackedTypeConverter { + using Type = half; +}; + +template <> +struct PackedTypeConverter { + using Type = half2; +}; + +template <> +struct PackedTypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct PackedTypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +template <> +struct PackedTypeConverter { + using Type = float2; +}; + +template <> +struct PackedTypeConverter { + using Type = float; +}; + +template <> +struct PackedTypeConverter { + using Type = half2; +}; + +template <> +struct PackedTypeConverter { + using Type = __nv_bfloat162; +}; + +// CUDATypeConverter — map PyTorch scalar types to CUDA scalar +// c10::Half -> half, c10::BFloat16 -> __nv_bfloat16 +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// PackedVec — typed vector container for packed element access. +// Derives alignment and element count from VecTraits. +// Type is the CUDA scalar type (e.g. half, __nv_bfloat16). +template +struct alignas(VecTraits::ARCH_MAX_VEC_SIZE) PackedVec { + static constexpr int NUM_ELTS = + VecTraits::ARCH_MAX_VEC_SIZE / + sizeof(typename PackedTypeConverter::Type); + typename PackedTypeConverter::Type elts[NUM_ELTS]; +}; + +// ============================================================ +// Load / store primitives +// ============================================================ + +// 256-bit load / store — SM100+ only (PTX v8 instructions). +__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { +#if VLLM_256B_PTX_ENABLED + asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(ptr)); +#else + assert(false && "ld256 requires SM100+ with CUDA 12.9+"); +#endif +} + +__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { +#if VLLM_256B_PTX_ENABLED + asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" + : + : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), + "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), + "r"(val.d[7]) + : "memory"); +#else + assert(false && "st256 requires SM100+ with CUDA 12.9+"); +#endif +} + +// Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec). +// Non-template overloads above are preferred for u32x8_t. +template +__device__ __forceinline__ void ld256(T& val, const T* ptr) { + static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type"); + ld256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st256(T& val, T* ptr) { + static_assert(sizeof(T) == 32, "st256 requires a 32-byte type"); + st256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +// 128-bit load / store via __ldg (read-only cache hint). +template +__device__ __forceinline__ void ld128(T& val, const T* ptr) { + static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type"); + *reinterpret_cast(&val) = __ldg(reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st128(T& val, T* ptr) { + static_assert(sizeof(T) == 16, "st128 requires a 16-byte type"); + *reinterpret_cast(ptr) = *reinterpret_cast(&val); +} + +// 256-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { +#if VLLM_256B_PTX_ENABLED + u32x8_t val; + asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(addr)); + return val; +#else + assert(false && "ld256_cs requires SM100+ with CUDA 12.9+"); + return {}; +#endif +} + +__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { +#if VLLM_256B_PTX_ENABLED + asm volatile( + "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr), + "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), + "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); +#else + assert(false && "st256_cs requires SM100+ with CUDA 12.9+"); +#endif +} + +// 32-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ int ld32_cs(const int* addr) { +#if VLLM_256B_PTX_ENABLED + int val; + asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); + return val; +#else + assert(false && "ld32_cs requires SM100+ with CUDA 12.9+"); + return 0; +#endif +} + +__forceinline__ __device__ void st32_cs(int* addr, int val) { +#if VLLM_256B_PTX_ENABLED + asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); +#else + assert(false && "st32_cs requires SM100+ with CUDA 12.9+"); +#endif +} + +// Predicated 256-bit / 128-bit cache-global (.cg) loads. +// Returns zero if pred is false. SM100+ only. +__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, + bool pred) { +#if VLLM_256B_PTX_ENABLED + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %8, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " mov.u32 %4, 0;\n" + " mov.u32 %5, 0;\n" + " mov.u32 %6, 0;\n" + " mov.u32 %7, 0;\n" + " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" + "}\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "r"((int)pred), "l"(ptr)); +#else + assert(false && "ld256_cg_or_zero requires SM100+ with CUDA 12.9+"); +#endif +} + +__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, + bool pred) { +#if VLLM_256B_PTX_ENABLED + uint32_t r0, r1, r2, r3; + + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %4, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" + "}\n" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) + : "r"((int)pred), "l"(ptr)); + + val = uint4{r0, r1, r2, r3}; +#else + assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+"); +#endif +} + +// ============================================================ +// Alignment helpers +// ============================================================ + +__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 15) == 0; +} + +__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 31) == 0; +} + +// ============================================================ +// Packed type conversion and arithmetic +// ============================================================ + +template +__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) { + if constexpr (std::is_same_v) { + return __bfloat1622float2(val); + } else if constexpr (std::is_same_v) { + return __half22float2(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t cast_to_packed(const float2& val) { + if constexpr (std::is_same_v) { + return __float22bfloat162_rn(val); + } else if constexpr (std::is_same_v) { + return __float22half2_rn(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t packed_mul(const packed_t& x, + const packed_t& y) { + if constexpr (std::is_same_v || + std::is_same_v) { + return __hmul2(x, y); + } else if constexpr (std::is_same_v) { + return make_float2(x.x * y.x, x.y * y.y); + } +} + +} // namespace vllm diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 8583b79fd58f..3539096c9feb 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -39,12 +39,12 @@ namespace vllm { template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, - int32_t num_padded_cols, + int32_t num_packed_cols, Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; + using PackedVec = vllm::PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, @@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // Input tensor row/col loops. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < num_padded_cols) { + if (colIdx < num_packed_cols) { PackedVec in_vec; PackedVec in_vec2; int64_t inOffset = @@ -73,19 +73,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) bool valid = (rowIdx < numRows) && (elem_idx < numCols); if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); - ld256_or_zero_cg_u32( - in_vec2, &reinterpret_cast(in)[inOffset2 * 8], - valid); + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); + ld256_cg_or_zero(reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 8], + valid); } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); - ld128_or_zero_cg_u32( - in_vec2, &reinterpret_cast(in)[inOffset2 * 4], - valid); + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); + ld128_cg_or_zero(reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 4], + valid); } // Compute silu and mul @@ -142,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); - int sf_n_unpadded = int(n / CVT_FP4_ELTS_PER_THREAD); + int num_packed_cols = int(n / CVT_FP4_ELTS_PER_THREAD); - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_y = vllm::div_round_up(num_packed_cols, static_cast(block.x)); int grid_x = std::min( int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); @@ -154,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] using cuda_type = vllm::CUDATypeConverter::Type; auto input_ptr = static_cast(input.data_ptr()); vllm::silu_mul_cvt_fp16_to_fp4<<>>( - m, n, sf_n_unpadded, input_ptr, input_sf_ptr, + m, n, num_packed_cols, input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), reinterpret_cast(sf_out)); }); diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 32685c201102..3162b6cdb8a9 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -43,7 +43,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) { - using PackedVec = PackedVec; + using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, @@ -155,7 +155,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) float const* SFScale, uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts, int n_experts) { - using PackedVec = PackedVec; + using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index b521b4707a4d..773047c22500 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -42,7 +42,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; + using PackedVec = vllm::PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -71,13 +71,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // If we are outside valid rows OR outside valid columns -> Use Zeros bool valid = (rowIdx < numRows) && (elem_idx < numCols); if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); } auto sf_out = @@ -114,7 +114,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = PackedVec; + using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -139,13 +139,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // If we are outside valid rows OR outside valid columns -> Use Zeros bool valid = (rowIdx < numRows) && (elem_idx < numCols); if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); } auto sf_out = diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 3e7adb9e2931..c1df1860c1a1 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -19,8 +19,10 @@ #include #include -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) +#include "../../cuda_vec_utils.cuh" + +#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \ + CUDA_VERSION >= 12090 #define ELTS_PER_THREAD 16 constexpr int CVT_FP4_ELTS_PER_THREAD = 16; constexpr bool CVT_FP4_PACK16 = true; @@ -34,68 +36,6 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16; namespace vllm { -// Convert PyTorch cpp type to CUDA type -template -struct CUDATypeConverter { - using Type = T; -}; - -template <> -struct CUDATypeConverter { - using Type = half; -}; - -template <> -struct CUDATypeConverter { - using Type = __nv_bfloat16; -}; - -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) -// Define a 32 bytes packed data type. -template -struct alignas(32) PackedVec { - typename TypeConverter::Type elts[8]; -}; -#else -// Define a 16 bytes packed data type. -template -struct alignas(16) PackedVec { - typename TypeConverter::Type elts[4]; -}; -#endif - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - template __host__ __device__ inline Int round_up(Int x, Int y) { static_assert(std::is_integral_v, @@ -208,56 +148,6 @@ __device__ __forceinline__ float reciprocal_approximate_ftz(float a) { return b; } -template -__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec& out, - const void* ptr, - bool pred) { - uint32_t r0, r1, r2, r3; - - asm volatile( - "{\n" - " .reg .pred pr;\n" - " setp.ne.u32 pr, %4, 0;\n" - " mov.u32 %0, 0;\n" - " mov.u32 %1, 0;\n" - " mov.u32 %2, 0;\n" - " mov.u32 %3, 0;\n" - " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" - "}\n" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "r"((int)pred), "l"(ptr)); - - *reinterpret_cast(&out) = uint4{r0, r1, r2, r3}; -} - -template -__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec& out, - const void* ptr, - bool pred) { - uint32_t r0, r1, r2, r3, r4, r5, r6, r7; - - asm volatile( - "{\n" - " .reg .pred pr;\n" - " setp.ne.u32 pr, %8, 0;\n" - " mov.u32 %0, 0;\n" - " mov.u32 %1, 0;\n" - " mov.u32 %2, 0;\n" - " mov.u32 %3, 0;\n" - " mov.u32 %4, 0;\n" - " mov.u32 %5, 0;\n" - " mov.u32 %6, 0;\n" - " mov.u32 %7, 0;\n" - " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" - "}\n" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6), - "=r"(r7) - : "r"((int)pred), "l"(ptr)); - - reinterpret_cast(&out)[0] = uint4{r0, r1, r2, r3}; - reinterpret_cast(&out)[1] = uint4{r4, r5, r6, r7}; -} - // Compute SF output offset for swizzled tensor core layout. // SF layout: [numMTiles, numKTiles, 32, 4, 4] // Caller must precompute: numKTiles = (numCols + 63) / 64 @@ -315,8 +205,8 @@ __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack, // Quantizes the provided PackedVec into the uint32_t output template -__device__ __forceinline__ fp4_packed_t -cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { +__device__ __forceinline__ fp4_packed_t cvt_warp_fp16_to_fp4( + PackedVec& vec, float SFScaleVal, uint8_t* SFout) { // Get absolute maximum values among the local 8 values. auto localMax = __habs2(vec.elts[0]); @@ -372,11 +262,7 @@ cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { #pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } + fp2Vals[i] = cast_to_float2(vec.elts[i]); fp2Vals[i].x *= outputScale; fp2Vals[i].y *= outputScale; } @@ -395,22 +281,19 @@ __device__ __forceinline__ float2 silu2(float2 x) { } template -__inline__ __device__ PackedVec compute_silu_mul( - const PackedVec& x_vec, const PackedVec& y_vec) { - PackedVec result; +__inline__ __device__ PackedVec compute_silu_mul( + const PackedVec& x_vec, + const PackedVec& y_vec) { + PackedVec result; #pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { // silu_mul in float32 - if constexpr (std::is_same_v) { - float2 silu_vec = silu2(__half22float2(x_vec.elts[i])); - result.elts[i] = __float22half2_rn( - __fmul2_rn(silu_vec, __half22float2(y_vec.elts[i]))); - } else { - float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i])); - result.elts[i] = __float22bfloat162_rn( - __fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i]))); - } + using packed_t = typename PackedTypeConverter::Type; + float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i])); + float2 y_f2 = cast_to_float2(y_vec.elts[i]); + result.elts[i] = cast_to_packed( + make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y)); } return result; }