37 changes: 35 additions & 2 deletions csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu
@@ -29,7 +29,11 @@
*/

#include <cmath>
#include <cuda_fp8.h>
#ifndef USE_ROCM
#include <cuda_fp8.h>
#else
#include <hip/hip_fp8.h>
#endif
#include <cuda_runtime.h>
#include <type_traits>

@@ -42,7 +46,23 @@
#include "type_convert.cuh"

#ifndef FINAL_MASK
#define FINAL_MASK 0xffffffffu
#ifdef USE_ROCM
// AMD wavefronts are 64 lanes wide, so the full-wave mask needs 64 bits on ROCm.
#define FINAL_MASK 0xffffffffffffffffULL
#else
#define FINAL_MASK 0xffffffffu
#endif
#endif

#ifdef USE_ROCM
// ROCm-compatible FP8 conversion helpers
__device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) {
#if defined(HIP_FP8_TYPE_OCP)
__hip_fp8_e4m3 fp8_val(val);
#else
__hip_fp8_e4m3_fnuz fp8_val(val);
#endif
return reinterpret_cast<uint8_t&>(fp8_val);
}
#endif
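For reference, the per-element step this kernel performs (scale, clamp to the FP8 finite range, convert) can be sketched in PyTorch as below. This is illustrative only, not part of the kernel: the dtype names assume a PyTorch build with float8 support, and the clamp bound has to match the target format, since OCP e4m3 (torch.float8_e4m3fn) saturates at +/-448 while the FNUZ variant used by older ROCm devices (torch.float8_e4m3fnuz) tops out at +/-240.

import torch

def quantize_e4m3_sketch(x: torch.Tensor, inv_scale: float, use_fnuz: bool = False) -> torch.Tensor:
    # Mirror of the kernel's loop body: scale, clamp to the format's finite range,
    # then convert to FP8 (the kernel does this via __nv_cvt_float_to_fp8 / __hip_fp8_e4m3).
    fp8_dtype = torch.float8_e4m3fnuz if use_fnuz else torch.float8_e4m3fn
    fp8_max = 240.0 if use_fnuz else 448.0
    scaled = (x * inv_scale).clamp(-fp8_max, fp8_max)
    return scaled.to(fp8_dtype)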

namespace vllm {
@@ -314,9 +334,13 @@ __global__ void fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel(
for (int i = 0; i < kElemsPerLane; i++) {
float scaled = elements[i] * inv_scale;
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
#ifndef USE_ROCM
__nv_fp8_storage_t s =
__nv_cvt_float_to_fp8(scaled, __NV_SATFINITE, __NV_E4M3);
out_bytes[i] = static_cast<uint8_t>(s);
#else
out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled);
#endif
}
// One 16-byte STG per lane.
*reinterpret_cast<uint4*>(token_fp8_ptr + dim_base) =
@@ -384,6 +408,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
// PDL: enable programmatic stream serialization whenever the hardware
// supports it (SM90+). On pre-Hopper GPUs the attribute is unavailable,
// so leave numAttrs = 0 and launch as a regular kernel.
#ifndef USE_ROCM
static int const sm_version = getSMVersion();
// Host-side guard: the device kernel body is compiled as a no-op for
// bf16 on pre-Ampere (sm_70/sm_75) because _typeConvert<BFloat16> is
@@ -410,6 +435,14 @@
q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps,
num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
kv_block_stride);
#else
// ROCm: use standard kernel launch syntax (no PDL/stream serialization)
fusedDeepseekV4QNormRopeKVRopeQuantInsertKernel<scalar_t_in>
<<<grid, kBlockSize, 0, stream>>>(
q_inout, kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, eps,
num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
kv_block_stride);
#endif
}

} // namespace deepseek_v4_fused_ops
6 changes: 5 additions & 1 deletion csrc/moe/topk_softplus_sqrt_kernels.cu
@@ -60,7 +60,11 @@ __device__ __forceinline__ float toFloat(T value) {
}
}

#define FINAL_MASK 0xffffffff
#ifdef USE_ROCM
// 64-lane AMD wavefronts: use a 64-bit full-wave shuffle mask.
#define FINAL_MASK 0xffffffffffffffffULL
#else
#define FINAL_MASK 0xffffffff
#endif
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
6 changes: 5 additions & 1 deletion vllm/_custom_ops.py
@@ -2451,7 +2451,11 @@ def moe_wna16_gemm(

def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
"""bf16 x bf16 -> fp32 GEMM via cuBLAS. weight shape: (N, K)."""
return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight)
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"):
return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight)

# Native fallback for platforms/builds without the custom MoE GEMM op.
return torch.matmul(input.to(torch.float32), weight.to(torch.float32).t())
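A quick usage sketch of the contract (shapes and values here are illustrative): input is (M, K) bf16, weight is (N, K) bf16, and the result is an (M, N) fp32 tensor whether the custom _moe_C op or the matmul fallback runs. On builds without the custom op, or with CPU tensors, the fallback path is taken; note that it upcasts both operands to fp32 before the matmul, so its rounding may differ slightly from the fused bf16 cuBLAS path.

x = torch.randn(8, 256, dtype=torch.bfloat16)   # (M, K) router input
w = torch.randn(32, 256, dtype=torch.bfloat16)  # (N, K) router weight
out = router_gemm_bf16_fp32(x, w)               # (M, N)
assert out.dtype == torch.float32 and out.shape == (8, 32)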


if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"):
2 changes: 2 additions & 0 deletions vllm/config/kernel.py
@@ -115,6 +115,7 @@ def with_default(
"flashinfer_cutlass",
"flashinfer_cutedsl",
"marlin",
"triton_unfused",
"aiter",
"emulation",
]
@@ -145,6 +146,7 @@ class KernelConfig:
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
- "marlin": Use Marlin kernels (weight-only quantization)
- "triton_unfused": Use Triton unfused MoE kernels
- "aiter": Use AMD AITer kernels (ROCm only)
- "emulation": use BF16/FP16 GEMM, dequantizing weights and
running QDQ on activations.
8 changes: 6 additions & 2 deletions vllm/forward_context.py
@@ -257,6 +257,7 @@ def set_forward_context(
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):
"""A context manager that stores the current forward context,
@@ -296,7 +297,7 @@
if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

additional_kwargs = current_platform.set_additional_forward_context(
platform_additional_kwargs = current_platform.set_additional_forward_context(
attn_metadata=attn_metadata,
vllm_config=vllm_config,
dp_metadata=dp_metadata,
@@ -306,6 +307,9 @@
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
)
merged_additional_kwargs = dict(platform_additional_kwargs)
if additional_kwargs:
merged_additional_kwargs.update(additional_kwargs)

forward_context = create_forward_context(
attn_metadata,
@@ -315,7 +319,7 @@
batch_descriptor,
ubatch_slices,
slot_mapping,
additional_kwargs,
merged_additional_kwargs,
skip_compiled,
)
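The merge order above matters: the platform-provided kwargs are built first and the caller's additional_kwargs are applied on top, so caller-supplied values win on key collisions. A toy illustration of that precedence (the key names are made up for the example):

platform_additional_kwargs = {"turbo_mode": True, "graph_pool": "default"}  # from current_platform
additional_kwargs = {"turbo_mode": False}                                    # passed by the caller
merged = dict(platform_additional_kwargs)
if additional_kwargs:
    merged.update(additional_kwargs)
assert merged == {"turbo_mode": False, "graph_pool": "default"}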

15 changes: 15 additions & 0 deletions vllm/model_executor/kernels/linear/scaled_mm/aiter.py
@@ -312,6 +312,21 @@ def apply_block_scaled_mm(
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
if As.dtype != Bs.dtype:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_upcast_e8m0_to_fp32,
)

if As.dtype == torch.float8_e8m0fnu:
As = _upcast_e8m0_to_fp32(As).contiguous()
else:
As = As.to(torch.float32)

if Bs.dtype == torch.float8_e8m0fnu:
Bs = _upcast_e8m0_to_fp32(Bs).contiguous()
else:
Bs = Bs.to(torch.float32)

out_dtype = self.config.out_dtype
if self.use_triton:
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
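For context on the dtype normalization above: an e8m0 block scale stores only an 8-bit exponent (bias 127, no sign or mantissa), so upcasting to fp32 amounts to computing 2**(byte - 127). A rough sketch of that idea follows; it is not necessarily identical to vLLM's _upcast_e8m0_to_fp32 helper (for one, it ignores the NaN encoding).

def upcast_e8m0_to_fp32_sketch(scales: torch.Tensor) -> torch.Tensor:
    # Reinterpret each e8m0 byte as its raw exponent and decode it as 2**(e - 127).
    exp = scales.view(torch.uint8).to(torch.int32)
    return torch.ldexp(torch.ones_like(exp, dtype=torch.float32), exp - 127)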
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/activation.py
@@ -169,7 +169,9 @@ class SiluAndMulWithClamp(CustomOp):
def __init__(self, swiglu_limit: float, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
self.swiglu_limit = float(swiglu_limit)
if current_platform.is_cuda_alike() or current_platform.is_xpu():
if current_platform.is_rocm():
self._forward_method = self.forward_native
elif current_platform.is_cuda_alike() or current_platform.is_xpu():
self.op = torch.ops._C.silu_and_mul_with_clamp
elif current_platform.is_cpu():
self._forward_method = self.forward_native
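The change above routes ROCm to the pure-PyTorch forward_native implementation instead of binding torch.ops._C.silu_and_mul_with_clamp (presumably because the fused op is unavailable or unsupported there). A toy sketch of that dispatch pattern, with illustrative class and method names rather than vLLM's actual code:

from vllm.platforms import current_platform

class ClampedActivationSketch:
    def __init__(self):
        if current_platform.is_rocm() or current_platform.is_cpu():
            # No fused kernel on these platforms: fall back to eager PyTorch.
            self._forward_method = self.forward_native
        elif current_platform.is_cuda_alike() or current_platform.is_xpu():
            # Bind the compiled op on platforms that ship it.
            self._forward_method = self.forward_cuda

    def forward_native(self, x): ...
    def forward_cuda(self, x): ...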