12 changes: 6 additions & 6 deletions CMakeLists.txt
@@ -307,12 +307,12 @@ set(VLLM_EXT_SRC
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/torch_bindings.cpp")
"csrc/torch_bindings.cpp"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu"
"csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu")
"csrc/minimax_reduce_rms_kernel.cu")

SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

@@ -1047,13 +1047,13 @@ endif()
set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/topk_softplus_sqrt_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/topk_softplus_sqrt_kernels.cu")
"csrc/moe/grouped_topk_kernels.cu")
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")
48 changes: 46 additions & 2 deletions csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu
@@ -29,7 +29,11 @@
*/

#include <cmath>
#include <cuda_fp8.h>
#ifndef USE_ROCM
#include <cuda_fp8.h>
#else
#include <hip/hip_fp8.h>
#endif
#include <cuda_runtime.h>
#include <type_traits>

@@ -42,7 +46,27 @@
#include "type_convert.cuh"

#ifndef FINAL_MASK
#define FINAL_MASK 0xffffffffu
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#else
#define FINAL_MASK 0xffffffffu
#endif
#endif

#ifdef USE_ROCM
// ROCm-compatible FP8 conversion helpers
__device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) {
// HIP defines HIP_FP8_TYPE_OCP based on HIP version, not GPU arch. On gfx942
// mfma only supports FNUZ fp8, and the rest of vLLM's gfx942 path (Triton
// indexer / current_platform.fp8_dtype()) uses FNUZ. Gate OCP on __gfx950__
// so the K cache encoding matches what the reader expects.
#if defined(HIP_FP8_TYPE_OCP) && defined(__gfx950__)
__hip_fp8_e4m3 fp8_val(val);
#else
__hip_fp8_e4m3_fnuz fp8_val(val);
#endif
return reinterpret_cast<uint8_t&>(fp8_val);
}
#endif

namespace vllm {
@@ -65,7 +89,13 @@ constexpr int kQuantBlock = 64;
constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7
constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad)
constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576
// Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942
// (max 240), OCP on gfx950 (max 448).
#if defined(USE_ROCM) && (!defined(HIP_FP8_TYPE_OCP) || !defined(__gfx950__))
constexpr float kFp8Max = 240.0f;
#else
constexpr float kFp8Max = 448.0f;
#endif

// Per-warp layout: 32 lanes × 16 elems/lane = 512 elems = HEAD_DIM.
constexpr int kNumLanes = 32;
@@ -314,9 +344,13 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel(
for (int i = 0; i < kElemsPerLane; i++) {
float scaled = elements[i] * inv_scale;
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
#ifndef USE_ROCM
__nv_fp8_storage_t s =
__nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3);
out_bytes[i] = static_cast<uint8_t>(s);
#else
out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled);
#endif
}
// One 16-byte STG per lane.
*reinterpret_cast<uint4*>(token_fp8_ptr + dim_base) =
@@ -384,6 +418,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
// PDL: enable programmatic stream serialization whenever the hardware
// supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable,
// so leave numAttrs = 0 and launch as a regular kernel.
#ifndef USE_ROCM
static int const sm_version = getSMVersion();
// Host-side guard: the device kernel body is compiled as a no-op for
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
@@ -410,6 +445,15 @@
q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps,
num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
kv_block_stride);
#else
// ROCm: use standard kernel launch syntax (no PDL/stream serialization)
// clang-format off
fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel<scalar_t_in>
<<<grid, kBlockSize, 0, stream>>>(
q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache,
eps, num_tokens_full, num_tokens_insert, num_heads_q,
cache_block_size, kv_block_stride);
#endif
}

} // namespace deepseek_v4_fused_ops
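The FNUZ/OCP split above matters because the two E4M3 encodings have different finite ranges (240 vs. 448), and the per-block UE8M0 scale has to be chosen against whichever maximum the K-cache encoding uses. Below is a minimal CPU-side sketch of that quantization step, for reference only: the kernel does this per 64-element quant block with the backend-specific conversion intrinsics shown above, and the helper name `quantize_block` plus the exact scale rounding are assumptions, not code from this PR.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Largest finite values of the two E4M3 flavours referenced in the kernel.
constexpr float kFp8MaxFnuz = 240.0f;  // gfx942 FNUZ E4M3
constexpr float kFp8MaxOcp = 448.0f;   // gfx950 / CUDA OCP E4M3

struct QuantizedBlock {
  std::vector<float> clamped;  // values ready for FP8 conversion
  float scale;                 // power-of-two (UE8M0-style) per-block scale
};

// Pick a power-of-two scale so the block's max magnitude fits under fp8_max,
// then clamp each scaled element into the representable range.
QuantizedBlock quantize_block(const std::vector<float>& block, float fp8_max) {
  float amax = 1e-10f;
  for (float v : block) amax = std::max(amax, std::fabs(v));
  float scale = std::exp2(std::ceil(std::log2(amax / fp8_max)));
  float inv_scale = 1.0f / scale;
  QuantizedBlock out{{}, scale};
  for (float v : block) {
    float scaled = v * inv_scale;
    out.clamped.push_back(std::min(std::max(scaled, -fp8_max), fp8_max));
  }
  return out;
}
```

The clamp is what keeps the FNUZ path from emitting values a 240-max decode cannot represent, which is why kFp8Max is gated on the same condition as the conversion helper.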
53 changes: 32 additions & 21 deletions csrc/moe/topk_softplus_sqrt_kernels.cu
@@ -60,15 +60,6 @@ __device__ __forceinline__ float toFloat(T value) {
}
}

#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}

// ====================== TopK softplus_sqrt things
// ===============================

@@ -272,8 +263,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}
// Compute per-thread scale (using warp reduction when renormalizing).
// THREADS_PER_ROW-parameterized butterfly works for both warp sizes (32
// on CUDA, 64 on ROCm CDNA) and any THREADS_PER_ROW the dispatch picks.
if (renormalize) {
selected_sum = warpReduceSum(selected_sum);
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
selected_sum +=
VLLM_SHFL_XOR_SYNC_WIDTH(selected_sum, mask, THREADS_PER_ROW);
}
}
float scale = static_cast<float>(routed_scaling_factor);
if (renormalize) {
@@ -544,14 +541,26 @@ void topkGatingSoftplusSqrtKernelLauncher(
const IndType* tid2eid, cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
// elements can be loaded by a warp
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>)
? 4
: 8;
// Narrower LDG (ELTS_PER_LDG=1) used by 192/320/448/576 on ROCm WARP_SIZE=64
// where ELTS_PER_LDG=2 fails the EXPERTS%(ELTS_PER_LDG*WARP_SIZE)==0 check.
// On CUDA WARP_SIZE=32 the wider LDG already aligns, so the alias collapses
// back to BYTES_PER_LDG_MULTIPLE_64 — no behavioral change for CUDA.
#ifdef USE_ROCM
static constexpr int BYTES_PER_LDG_MULTIPLE_64_NARROW =
(std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>)
? 2
: 4;
#else
static constexpr int BYTES_PER_LDG_MULTIPLE_64_NARROW =
BYTES_PER_LDG_MULTIPLE_64;
#endif
switch (num_experts) {
case 1:
@@ -584,27 +593,29 @@
case 512:
LAUNCH_SOFTPLUS_SQRT(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of
// num_experts, alternatively we can test 4 bytes loading and enable it in
// future.
#ifndef USE_ROCM
// Multiples of 64 that are not powers of 2. The kernel requires
// EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0. With ELTS_PER_LDG=2
// (BYTES_PER_LDG_MULTIPLE_64), this holds for all five values on CUDA
// WARP_SIZE=32 but only for 384 on ROCm WARP_SIZE=64. The other four
// use BYTES_PER_LDG_MULTIPLE_64_NARROW (ELTS_PER_LDG=1), which
// satisfies the assertion for any multiple of 64 on either backend;
// on CUDA the narrow alias collapses back to the wider load, so CUDA
// behavior is unchanged.
case 192:
LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
LAUNCH_SOFTPLUS_SQRT(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW);
break;
case 320:
LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
LAUNCH_SOFTPLUS_SQRT(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW);
break;
case 384:
LAUNCH_SOFTPLUS_SQRT(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 448:
LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
LAUNCH_SOFTPLUS_SQRT(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW);
break;
case 576:
LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
LAUNCH_SOFTPLUS_SQRT(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64_NARROW);
break;
#endif
default: {
TORCH_CHECK(false, "Unsupported expert number: ", num_experts);
}
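For readers unfamiliar with the butterfly pattern that replaces the deleted warpReduceSum, here is a CUDA-only sketch. Starting the XOR stride at THREADS_PER_ROW / 2 confines the reduction to the lanes that actually share a row, so the same loop is correct on 32-lane CUDA warps and 64-lane ROCm CDNA warps; the kernel itself goes through the VLLM_SHFL_XOR_SYNC_WIDTH portability macro rather than calling __shfl_xor_sync directly, and `rowReduceSum` is an illustrative name, not a function in this PR.

```cpp
#include <cuda_runtime.h>

// Width-parameterized butterfly sum: each group of THREADS_PER_ROW lanes ends
// up holding the sum of that group's values. With THREADS_PER_ROW == 32 this
// matches the old warpReduceSum; narrower widths reduce sub-groups of a wider
// warp without touching neighbouring rows.
template <int THREADS_PER_ROW, typename T>
__device__ __forceinline__ T rowReduceSum(T val) {
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    val += __shfl_xor_sync(0xffffffffu, val, mask, THREADS_PER_ROW);
  }
  return val;
}
```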
3 changes: 1 addition & 2 deletions csrc/moe/torch_bindings.cpp
@@ -16,14 +16,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"bias) -> ()");
m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid);

#ifndef USE_ROCM
m.def(
"topk_softplus_sqrt(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize, float "
"routed_scaling_factor, Tensor? "
"bias, Tensor? input_ids, Tensor? tid2eid) -> ()");
m.impl("topk_softplus_sqrt", torch::kCUDA, &topk_softplus_sqrt);
#endif

// Calculate the result of moe by summing up the partial results
// from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
2 changes: 0 additions & 2 deletions csrc/torch_bindings.cpp
@@ -183,7 +183,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int forced_token_heads_per_warp=-1) -> ()");
ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);

#ifndef USE_ROCM
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch.
@@ -194,7 +193,6 @@
"float eps, int cache_block_size) -> ()");
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA,
&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert);
#endif

// Apply repetition penalties to logits in-place
ops.def(
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
@@ -216,7 +216,7 @@ configuration.
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 512, 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | 1, 64 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
3 changes: 3 additions & 0 deletions requirements/rocm.txt
@@ -21,3 +21,6 @@ timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
# tilelang is required so that the mhc module
# can be imported correctly.
tilelang==0.1.9
6 changes: 4 additions & 2 deletions tests/kernels/moe/test_topk_softplus_sqrt.py
@@ -70,7 +70,8 @@ def test_sqrtsoftplus_bias_uses_deepseek_v4_routing_method():


@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
not current_platform.is_cuda_alike(),
reason="This test is skipped on non-CUDA-alike platforms.",
)
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@@ -125,7 +126,8 @@ def test_fused_topk_softplus_sqrt(


@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
not current_platform.is_cuda_alike(),
reason="This test is skipped on non-CUDA-alike platforms.",
)
@pytest.mark.parametrize("num_tokens", [1, 33, 128])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
2 changes: 2 additions & 0 deletions vllm/config/kernel.py
@@ -115,6 +115,7 @@ def with_default(
"flashinfer_cutlass",
"flashinfer_cutedsl",
"marlin",
"triton_unfused",
"aiter",
"emulation",
]
@@ -145,6 +146,7 @@ class KernelConfig:
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
- "marlin": Use Marlin kernels (weight-only quantization)
- "triton_unfused": Use Triton unfused MoE kernels
- "aiter": Use AMD AITer kernels (ROCm only)
- "emulation": use BF16/FP16 GEMM, dequantizing weights and
running QDQ on activations.
15 changes: 15 additions & 0 deletions vllm/model_executor/kernels/linear/scaled_mm/aiter.py
@@ -312,6 +312,21 @@ def apply_block_scaled_mm(
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
if As.dtype != Bs.dtype:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_upcast_e8m0_to_fp32,
)

if As.dtype == torch.float8_e8m0fnu:
As = _upcast_e8m0_to_fp32(As).contiguous()
else:
As = As.to(torch.float32)

if Bs.dtype == torch.float8_e8m0fnu:
Bs = _upcast_e8m0_to_fp32(Bs).contiguous()
else:
Bs = Bs.to(torch.float32)

out_dtype = self.config.out_dtype
if self.use_triton:
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
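The dtype normalization added to apply_block_scaled_mm hinges on what an e8m0 scale is: a single unsigned exponent byte with bias 127 and no sign or mantissa bits, so upcasting it to fp32 is just 2^(byte - 127), with 0xFF reserved for NaN in the OCP microscaling spec. The standalone C++ sketch below shows that decode for reference; it is an assumption about what `_upcast_e8m0_to_fp32` has to produce so both scale tensors reach the blockscale GEMM as plain fp32, not vLLM's actual implementation.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Decode one torch.float8_e8m0fnu byte: exponent-only, bias 127, 0xFF = NaN.
inline float upcast_e8m0_to_fp32(uint8_t e) {
  if (e == 0xFF) return std::numeric_limits<float>::quiet_NaN();
  return std::ldexp(1.0f, static_cast<int>(e) - 127);  // 2^(e - 127)
}
```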
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/activation.py
@@ -169,7 +169,9 @@ class SiluAndMulWithClamp(CustomOp):
def __init__(self, swiglu_limit: float, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
self.swiglu_limit = float(swiglu_limit)
if current_platform.is_cuda_alike() or current_platform.is_xpu():
if current_platform.is_rocm():
self._forward_method = self.forward_native
elif current_platform.is_cuda_alike() or current_platform.is_xpu():
self.op = torch.ops._C.silu_and_mul_with_clamp
elif current_platform.is_cpu():
self._forward_method = self.forward_native