Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/kernels/attention/test_merge_attn_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,72 @@ def diff(a: torch.Tensor, b: torch.Tensor):
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
):
generate_markdown_table()


@pytest.mark.parametrize("num_tokens", [32, 128, 512])
@pytest.mark.parametrize("num_query_heads", [8, 32])
@pytest.mark.parametrize("head_size", [64, 128])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
@torch.inference_mode()
def test_merge_attn_states_both_lse_neg_inf(
num_tokens: int,
num_query_heads: int,
head_size: int,
input_dtype: torch.dtype,
):
"""Regression test for NaN when both prefix_lse and suffix_lse are -inf.

This happens during chunked prefill when a request has zero context
tokens — both sides produce no valid attention scores, so both LSEs
are -inf. The kernel must produce finite (zero) output, not NaN.
"""
device = "cuda"

prefix_output = torch.randn(
(num_tokens, num_query_heads, head_size),
dtype=input_dtype, device=device,
)
suffix_output = torch.randn(
(num_tokens, num_query_heads, head_size),
dtype=input_dtype, device=device,
)
output = torch.zeros_like(prefix_output)

prefix_lse = torch.randn(
num_query_heads, num_tokens, dtype=torch.float32, device=device,
)
suffix_lse = torch.randn(
num_query_heads, num_tokens, dtype=torch.float32, device=device,
)

# --- Inject edge cases ---
# ~25% of (head, token) positions: both LSEs = -inf (the NaN trigger)
both_neg_inf_mask = torch.rand(num_query_heads, num_tokens) < 0.25
prefix_lse[both_neg_inf_mask] = float("-inf")
suffix_lse[both_neg_inf_mask] = float("-inf")

# ~10% of remaining positions: both LSEs = +inf (FA2 empty-sequence style)
fa2_mask = (torch.rand(num_query_heads, num_tokens) < 0.10) & ~both_neg_inf_mask
prefix_lse[fa2_mask] = float("inf")
suffix_lse[fa2_mask] = float("inf")

merge_attn_states_triton(
output, prefix_output, prefix_lse, suffix_output, suffix_lse,
)

# 1) No NaN anywhere in the output.
nan_count = torch.isnan(output).sum().item()
assert nan_count == 0, (
f"Found {nan_count} NaN elements in output"
)

# 2) Positions where both LSEs were -inf must produce zero output
# (no attention scores to merge → zero weight on both sides).
# both_neg_inf_mask is [num_heads, num_tokens], output is
# [num_tokens, num_heads, head_size].
both_inf_idx = both_neg_inf_mask.T.unsqueeze(-1).expand_as(output)
zero_region = output[both_inf_idx]
assert torch.all(zero_region == 0), (
f"Expected zero output for both-neg-inf positions, "
f"got max abs = {zero_region.abs().max().item()}"
)
88 changes: 37 additions & 51 deletions vllm/model_executor/layers/deepseek_v4_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_forward_decode_fallback,
rocm_inv_rope_einsum,
rocm_sparse_attn_prefill,
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum

from aiter.ops.triton.rope.inv_rope_fp8_quant import inv_rope_fp8_quant

if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
Expand Down Expand Up @@ -310,14 +308,33 @@ def forward(

# Keep ROCm on the BF16 reference wo_a path util kernel ready.
if current_platform.is_rocm():
z = rocm_inv_rope_einsum(
self.rotary_emb,
o_fp8, o_scale = inv_rope_fp8_quant(
o,
positions,
self.rope_head_dim,
self.n_local_groups,
self.o_lora_rank,
self.wo_a,
self.rotary_emb.cos_sin_cache,
n_groups=self.n_local_groups,
heads_per_group=self.n_local_heads // self.n_local_groups,
rope_head_dim=self.rope_head_dim,
)
o_fp8 = o_fp8.transpose(0, 1).contiguous()
o_scale = o_scale.transpose(0, 1).contiguous()

wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv

z = torch.empty(
(num_tokens, self.n_local_groups, self.o_lora_rank),
device=o.device,
dtype=torch.bfloat16,
)
torch.ops.vllm.deepseek_v4_fp8_einsum(
o_fp8,
o_scale,
wo_a_fp8,
wo_a_scale,
z,
"bhr,hdr->bhd",
list(self._einsum_recipe),
)
return self.wo_b(z.flatten(1))

Expand Down Expand Up @@ -839,25 +856,6 @@ def _forward_decode(
swa_indices = swa_metadata.decode_swa_indices
swa_lens = swa_metadata.decode_swa_lens

if current_platform.is_rocm():
rocm_forward_decode_fallback(
q=q,
kv_cache=kv_cache,
swa_k_cache=self.swa_cache_layer.kv_cache,
swa_only=swa_only,
topk_indices=topk_indices,
topk_lens=topk_lens,
swa_indices=swa_indices,
swa_lens=swa_lens,
attn_sink=self.attn_sink,
scale=self.scale,
head_dim=self.head_dim,
nope_head_dim=self.nope_head_dim,
rope_head_dim=self.rope_head_dim,
output=output,
)
return

# We treat queries in the same seq as different queries
# and later we only attend by generated indices.
# q arrives pre-padded to self.padded_heads by the outer wrapper.
Expand Down Expand Up @@ -1022,27 +1020,15 @@ def _forward_prefill(
N,
)

if current_platform.is_rocm():
rocm_sparse_attn_prefill(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
topk_length=combined_lens,
scale=self.scale,
head_dim=self.head_dim,
attn_sink=self.attn_sink,
output=output[query_start:query_end],
)
else:
output_chunk, _, _ = flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)
output_chunk, _, _ = flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
sm_scale=self.scale,
attn_sink=self.attn_sink,
topk_length=combined_lens,
out=output[query_start:query_end],
)


class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
Expand Down
42 changes: 30 additions & 12 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,19 +923,27 @@ def _interleave_mxfp4_cutlass_sm90(w):
.view(e, n, -1)
)

# View as native FP4 dtype for AITER shuffle
w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)

# Shuffle weights and scales for AITER CK kernel layout
w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
# AITER CK kernels key off torch.float4_e2m1fn_x2, not raw uint8.
# Return fresh Parameters instead of assigning .data so the dtype and
# Tensor metadata survive replace_parameter() unchanged.
w13_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w13_weight.data.view(torch.float4_e2m1fn_x2), 16, True
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
num_experts,
True,
)

w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
w2_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w2_weight.data.view(torch.float4_e2m1fn_x2), 16, False
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
num_experts,
Expand Down Expand Up @@ -1295,17 +1303,27 @@ def convert_weight_to_mxfp4_moe_kernel_format(
.view(e, n, -1)
)

w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)

w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
# AITER CK kernels key off torch.float4_e2m1fn_x2, not raw uint8.
# Return fresh Parameters instead of assigning .data so the dtype and
# Tensor metadata survive replace_parameter() unchanged.
w13_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w13_weight.data.view(torch.float4_e2m1fn_x2), 16, True
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
num_experts,
True,
)

w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
w2_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w2_weight.data.view(torch.float4_e2m1fn_x2), 16, False
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
num_experts,
Expand Down
16 changes: 9 additions & 7 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,15 +975,13 @@ def requant_weight_ue8m0_inplace(


def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
"""Upcast E8M0 (exponent-only) scale to float32.
"""Decode E8M0 (exponent-only) scale tensors to float32.

E8M0 stores only the 8-bit biased exponent (bias=127). To convert
to float32 we place those 8 bits into the exponent field of an
IEEE-754 float32 (bits 23-30) with sign=0 and mantissa=0.
E8M0 stores an unsigned exponent with IEEE-754 bias 127. Keep the
conversion in one helper so CUDA DeepGEMM and ROCm fallback paths use
identical scale semantics for checkpoints that store UE8M0 scales.
"""
exp_bits = scale.view(torch.uint8).to(torch.int32)
fp32_bits = exp_bits << 23
return fp32_bits.view(torch.float32)
return torch.exp2(scale.view(torch.uint8).to(torch.float32) - 127)


def deepgemm_post_process_fp8_weight_block(
Expand Down Expand Up @@ -1293,6 +1291,10 @@ def process_fp8_weight_block_strategy(
weight=weight, weight_scale=weight_scale
)

if weight_scale.dtype == torch.float8_e8m0fnu and not is_deep_gemm_e8m0_used():
# ROCm fallback kernels do not accept UE8M0 scale tensors directly.
weight_scale = _upcast_e8m0_to_fp32(weight_scale)

weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale

Expand Down
12 changes: 10 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,15 @@ def normalize_e4m3fn_to_e4m3fnuz(
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if weight_scale.dtype == torch.float8_e8m0fnu:
weight_scale = weight_scale.view(torch.uint8).to(torch.float32)
weight_scale = torch.exp2(weight_scale - 127) * 2.0
else:
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
if input_scale.dtype == torch.float8_e8m0fnu:
input_scale = input_scale.view(torch.uint8).to(torch.float32)
input_scale = torch.exp2(input_scale - 127) * 2.0
else:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
Loading