Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@
#ifdef USE_ROCM
// ROCm-compatible FP8 conversion helpers
__device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) {
#if defined(HIP_FP8_TYPE_OCP)
// HIP defines HIP_FP8_TYPE_OCP based on HIP version, not GPU arch. On gfx942
// mfma only supports FNUZ fp8, and the rest of vLLM's gfx942 path (Triton
// indexer / current_platform.fp8_dtype()) uses FNUZ. Gate OCP on __gfx950__
// so the K cache encoding matches what the reader expects.
#if defined(HIP_FP8_TYPE_OCP) && defined(__gfx950__)
__hip_fp8_e4m3 fp8_val(val);
#else
__hip_fp8_e4m3_fnuz fp8_val(val);
Expand Down Expand Up @@ -85,7 +89,13 @@ constexpr int kQuantBlock = 64;
constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7
constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad)
constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576
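// kFp8Max selection: ROCm builds that are not OCP-on-gfx950 (e.g. FNUZ on
// gfx942) use 224.0 (not the FNUZ dtype's raw 240.0) to match the rest of
// vLLM's FNUZ pipeline; see fp8_utils.py:412-417. OCP on gfx950 and all
// non-ROCm builds use 448.0.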
#if defined(USE_ROCM) && (!defined(HIP_FP8_TYPE_OCP) || !defined(__gfx950__))
constexpr float kFp8Max = 224.0f;
#else
constexpr float kFp8Max = 448.0f;
#endif

#ifndef USE_ROCM
// When num_tokens is less than this threshold,
Expand Down
20 changes: 15 additions & 5 deletions tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@
dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
from vllm.platforms import current_platform
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is failing

-Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_matches_reference[16-2048] - AssertionError: RoPE portion not exact: 0.0009765625
==== short test summary info=
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_matches_reference[64-2048]
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[16-1-2048] - AssertionError: RoPE portion not exact: 0.0009765625
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[16-5-2048]
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[64-1-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[64-5-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[16-8-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[16-64-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[64-8-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[64-64-2048] - AssertionError: Tensor-likes are not equal!
sys:1: DeprecationWarning: builtin type swigvarlink has no

10 failed, 38 passed, 16 warnings in 16.72s


# ── Constants matching the kernel ────────────────────────────────────────────
HEAD_DIM = 512
ROPE_DIM = 64
NOPE_DIM = HEAD_DIM - ROPE_DIM # 448
QUANT_BLOCK = 64
FP8_MAX = 448.0
# Match the C++ SWA-K encoder: FNUZ on gfx942 (FP8_MAX=224), OCP elsewhere (448).
USE_FNUZ = current_platform.is_fp8_fnuz()
FP8_MAX = 224.0 if USE_FNUZ else 448.0
HEAD_BYTES = NOPE_DIM + ROPE_DIM * 2 + 8 # 448 + 128 + 8 = 584


Expand Down Expand Up @@ -198,7 +201,7 @@ def test_kv_path_matches_reference(num_tokens: int, block_size: int):
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
kv_ref, k_cache_ref, slot_mapping, block_size=block_size, use_fnuz=USE_FNUZ
)

# ── Fused path (dummy q, single head) ──────────────────────────────────
Expand Down Expand Up @@ -229,7 +232,14 @@ def _dequant(k_cache_2d):
# gather_lens arg is None (use seq_lens)
k_cache_3d = k_cache_2d.view(num_blocks, block_size, HEAD_BYTES)
dequantize_and_gather_k_cache(
out, k_cache_3d, seq_lens, None, block_table, block_size, offset=0
out,
k_cache_3d,
seq_lens,
None,
block_table,
block_size,
offset=0,
use_fnuz=USE_FNUZ,
)
return out[0, :num_tokens]

Expand Down Expand Up @@ -292,7 +302,7 @@ def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int):
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
kv_ref, k_cache_ref, slot_mapping, block_size=block_size, use_fnuz=USE_FNUZ
)

# Fused: pass full-sized q/kv/positions, shorter slot_mapping.
Expand Down Expand Up @@ -341,7 +351,7 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int):
num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device
)
quantize_and_insert_k_cache(
kv_ref, k_cache_ref, slot_mapping, block_size=block_size
kv_ref, k_cache_ref, slot_mapping, block_size=block_size, use_fnuz=USE_FNUZ
)

# Fused single call.
Expand Down
25 changes: 22 additions & 3 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,9 +1310,28 @@ def process_fp8_weight_block_strategy(
)

if current_platform.is_fp8_fnuz() and weight.dtype == torch.float8_e4m3fn:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)
if weight_scale.dtype == torch.float8_e8m0fnu:
# UE8M0 scales: e8m0 stores exponent-only values (2^(exp-127)),
# so doubling the dequant scale == incrementing the exponent byte
# by 1. Convert the OCP E4M3 weight bytes to FNUZ in place by
# reinterpreting and patching the NaN sentinel (-128 in int8),
# then double the UE8M0 exponent so the dequantized magnitudes
# match.
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
exp_bytes = weight_scale.view(torch.uint8)
weight_scale = (
(exp_bytes.to(torch.int16) + 1)
.clamp(max=254)
.to(torch.uint8)
.view(torch.float8_e8m0fnu)
)
else:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)

weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale
Expand Down
9 changes: 9 additions & 0 deletions vllm/models/deepseek_v4/amd/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DeepseekV4FlashMLASparseBackend,
DeepseekV4SparseMLAAttentionImpl,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
CommonAttentionMetadata,
Expand Down Expand Up @@ -789,6 +790,11 @@ def _forward_prefill(
kv = workspace_manager.get_simultaneous(
((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)[0]
# TODO: workspace is torch.empty() and only the compressed-K prefix +
# SWA window are written per chunk row; the indexer's topK can land in
# the unwritten holes for short sequences. Proper fix is to mask invalid
# rows in the indexer (score = -inf) or in rocm_sparse_attn_prefill.
kv.zero_()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rather weird: there is no need to do this on MI355X. I would like to avoid it because it incurs overhead. Please look into the sparse indexer logic for gfx942; I believe fixing the logic there would make the kv.zero_() call unnecessary.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rather weird: there is no need to do this on MI355X. I would like to avoid it because it incurs overhead. Please look into the sparse indexer logic for gfx942; I believe fixing the logic there would make the kv.zero_() call unnecessary.

for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE
chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills)
Expand All @@ -797,6 +803,7 @@ def _forward_prefill(
assert attn_metadata is not None
assert compressed_k_cache is not None
block_table = attn_metadata.block_table[num_decodes:]
# compressed_k_cache is OCP on every platform (Triton encoder).
dequantize_and_gather_k_cache(
kv[:chunk_size],
compressed_k_cache,
Expand All @@ -805,6 +812,7 @@ def _forward_prefill(
block_table=block_table[chunk_start:chunk_end],
block_size=attn_metadata.block_size // layer.compress_ratio,
offset=0,
use_fnuz=False,
)

swa_block_table = swa_metadata.block_table[num_decodes:]
Expand All @@ -816,6 +824,7 @@ def _forward_prefill(
block_table=swa_block_table[chunk_start:chunk_end],
block_size=swa_metadata.block_size,
offset=N,
use_fnuz=current_platform.is_fp8_fnuz(),
)

query_start = (
Expand Down
45 changes: 39 additions & 6 deletions vllm/models/deepseek_v4/common/ops/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def quantize_and_insert_k_kernel(
block_stride: tl.constexpr, # total bytes per block (padded)
fp8_max: tl.constexpr,
n_quant_blocks: tl.constexpr, # 8 (7 real + 1 padding)
use_fnuz: tl.constexpr = False,
):
"""
Quantize K tensor and insert into paged K cache.
Expand All @@ -49,6 +50,9 @@ def quantize_and_insert_k_kernel(
- [64*576 + 64*8, block_stride): Padding

One program per token.

``use_fnuz=True`` selects FNUZ (``tl.float8e4b8``); default OCP
(``tl.float8e4nv``) matches every production caller.
"""
pid = tl.program_id(0)

Expand Down Expand Up @@ -112,8 +116,11 @@ def quantize_and_insert_k_kernel(
x_scaled = x / scale
x_clamped = tl.clamp(x_scaled, -fp8_max, fp8_max)

# Convert to fp8, then bitcast to uint8 for storage
x_fp8 = x_clamped.to(tl.float8e4nv)
# Convert to fp8 (FNUZ on gfx942, OCP elsewhere), then bitcast to uint8.
if use_fnuz:
x_fp8 = x_clamped.to(tl.float8e4b8)
else:
x_fp8 = x_clamped.to(tl.float8e4nv)
x_uint8 = x_fp8.to(tl.uint8, bitcast=True)

# Store as uint8 (1 byte each)
Expand Down Expand Up @@ -145,6 +152,7 @@ def quantize_and_insert_k_cache(
slot_mapping: torch.Tensor, # [num_tokens] int64
block_size: int = 64,
is_ue8m0: bool = True,
use_fnuz: bool = False,
):
"""
Quantize K tensor and insert into paged K cache.
Expand All @@ -155,6 +163,9 @@ def quantize_and_insert_k_cache(
- Next 64 * 8 = 512 bytes: Scales
- Each token: 8 bytes (uint8 scales, 7 real + 1 padding)
- Padded to multiple of 576

``use_fnuz=True`` switches to FNUZ E4M3 (FP8_MAX=224); default OCP
(FP8_MAX=448) matches every production caller.
"""
assert k.dim() == 2 and k.shape[1] == 512, (
f"K must be [num_tokens, 512], got {k.shape}"
Expand All @@ -171,7 +182,7 @@ def quantize_and_insert_k_cache(
TOKEN_BF16_DIM = 64
TOKEN_SCALE_DIM = 8
QUANT_BLOCK_SIZE = 64
FP8_MAX = 448.0
FP8_MAX = 224.0 if use_fnuz else 448.0 # FNUZ value matches fp8_utils.py
TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2

grid = (num_tokens,)
Expand All @@ -191,6 +202,7 @@ def quantize_and_insert_k_cache(
block_stride=block_stride,
fp8_max=FP8_MAX,
n_quant_blocks=8,
use_fnuz=use_fnuz,
)


Expand All @@ -216,6 +228,7 @@ def _dequantize_and_gather_k_kernel(
output_dim: tl.constexpr, # 512
fp8_max: tl.constexpr,
n_quant_blocks: tl.constexpr, # 7 real blocks
use_fnuz: tl.constexpr = False,
):
batch_idx = tl.program_id(0)
worker_id = tl.program_id(1)
Expand Down Expand Up @@ -273,8 +286,11 @@ def _dequantize_and_gather_k_kernel(
# Load quantized fp8 values (stored as uint8)
x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0)

# Bitcast uint8 back to fp8
x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
# Bitcast uint8 back to fp8 (FNUZ on gfx942, OCP elsewhere).
if use_fnuz:
x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True)
else:
x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)

# Convert fp8 to float32 for computation
x_float = x_fp8.to(tl.float32)
Expand Down Expand Up @@ -317,6 +333,7 @@ def dequantize_and_gather_k_cache_triton(
block_table: torch.Tensor,
block_size: int,
offset: int,
use_fnuz: bool = False,
) -> None:
TOKEN_FP8_DIM = 448
TOKEN_BF16_DIM = 64
Expand Down Expand Up @@ -347,6 +364,7 @@ def dequantize_and_gather_k_cache_triton(
output_dim=512,
fp8_max=FP8_MAX,
n_quant_blocks=7,
use_fnuz=use_fnuz,
)


Expand All @@ -363,7 +381,15 @@ def dequantize_and_gather_k_cache(
block_table: torch.Tensor,
block_size: int,
offset: int,
use_fnuz: bool = False,
) -> None:
"""Dequantize and gather a paged DSv4 K cache.

``use_fnuz`` MUST match the encoder of the specific cache being read:
``False`` for ``compressed_k_cache`` (Triton encoder is OCP everywhere),
``current_platform.is_fp8_fnuz()`` for ``swa_k_cache`` (C++ encoder
writes FNUZ on gfx942 and OCP on gfx950).
"""
if has_cutedsl():
# lazily import, otherwise some tests fail due to CUDA driver init failure.
from vllm.models.deepseek_v4.nvidia.ops.dequant_gather_k_cutedsl import (
Expand All @@ -376,7 +402,14 @@ def dequantize_and_gather_k_cache(
return

dequantize_and_gather_k_cache_triton(
out, k_cache, seq_lens, gather_lens, block_table, block_size, offset
out,
k_cache,
seq_lens,
gather_lens,
block_table,
block_size,
offset,
use_fnuz=use_fnuz,
)


Expand Down
41 changes: 33 additions & 8 deletions vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,13 @@ def fp8_mqa_logits_torch(
)
mask = mask_lo & mask_hi

score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
# ``score`` is [H, M, N]; ``scale`` is the per-KV-token scale, which
# vLLM callers hand us as ``[N, 1]`` (a ``[N, 4]`` uint8 buffer cast
# to fp32). PyTorch right-aligns dimensions for broadcasting, so a
# naked ``score * scale`` would align ``scale``'s leading dim with
# ``score``'s M dim and raise a shape mismatch. Flatten to ``[N]`` so
# broadcasting lines up with the last dim of ``score``.
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale.reshape(-1)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))

Expand Down Expand Up @@ -557,13 +563,28 @@ def rocm_fp8_mqa_logits(
# path after aiter merge this kernel into main
from vllm._aiter_ops import rocm_aiter_ops

k_fp8, scale = kv

# gfx942: AITER's bundled fp8_mqa_logits launches with BLOCK_KV=128 +
# num_stages=2 (~96 KiB LDS), exceeding MI300X's 64 KiB LDS so it aborts
# with OutOfResources. Route gfx942 to a vendored copy that drops to
# BLOCK_KV=64 + num_stages=1 (~33 KiB) per ROCm/aiter#3257. Remove this
# branch once vLLM bumps AITER to a version that includes that PR.
if _ON_GFX942 and rocm_aiter_ops.is_enabled():
from vllm.v1.attention.ops.triton_fp8_mqa_logits import (
fp8_mqa_logits_gfx942,
)

return fp8_mqa_logits_gfx942(
q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke
)

aiter_mqa_logits_module = None
if rocm_aiter_ops.is_enabled():
aiter_mqa_logits_module = mqa_logits_module()

if aiter_mqa_logits_module is not None:
fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
k_fp8, scale = kv
return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
else:
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
Expand Down Expand Up @@ -1167,7 +1188,10 @@ def _sparse_attn_decode_ragged_kernel(
NOPE_DIM: tl.constexpr,
NOPE_BLOCK: tl.constexpr,
ROPE_DIM: tl.constexpr,
IS_FNUZ: tl.constexpr,
# SWA K-cache (main): C++ encoder writes FNUZ on gfx942, OCP on gfx950.
# Compressed K-cache (extra): Triton encoder writes OCP everywhere.
IS_FNUZ_MAIN: tl.constexpr,
IS_FNUZ_EXTRA: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_K: tl.constexpr,
):
Expand Down Expand Up @@ -1224,8 +1248,8 @@ def _sparse_attn_decode_ragged_kernel(
mask=valid[:, None] & nope_mask[None, :],
other=0,
)
if IS_FNUZ:
x_fp8 = x_uint8.to(tl.float8e4b15, bitcast=True)
if IS_FNUZ_MAIN:
x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True)
else:
x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
encoded_scales = tl.load(
Expand Down Expand Up @@ -1292,8 +1316,8 @@ def _sparse_attn_decode_ragged_kernel(
mask=valid[:, None] & nope_mask[None, :],
other=0,
)
if IS_FNUZ:
x_fp8 = x_uint8.to(tl.float8e4b15, bitcast=True)
if IS_FNUZ_EXTRA:
x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True)
else:
x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
encoded_scales = tl.load(
Expand Down Expand Up @@ -1570,7 +1594,8 @@ def _rocm_sparse_attn_decode_ragged_triton(
NOPE_DIM=nope_head_dim,
NOPE_BLOCK=triton.next_power_of_2(nope_head_dim),
ROPE_DIM=rope_head_dim,
IS_FNUZ=current_platform.is_fp8_fnuz(),
IS_FNUZ_MAIN=current_platform.is_fp8_fnuz(),
IS_FNUZ_EXTRA=False,
BLOCK_H=block_h,
BLOCK_K=block_k,
num_warps=8,
Expand Down
Loading
Loading