Skip to content
69 changes: 69 additions & 0 deletions tests/kernels/attention/test_merge_attn_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,72 @@ def diff(a: torch.Tensor, b: torch.Tensor):
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
):
generate_markdown_table()


@pytest.mark.parametrize("num_tokens", [32, 128, 512])
@pytest.mark.parametrize("num_query_heads", [8, 32])
@pytest.mark.parametrize("head_size", [64, 128])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
@torch.inference_mode()
def test_merge_attn_states_both_lse_neg_inf(
    num_tokens: int,
    num_query_heads: int,
    head_size: int,
    input_dtype: torch.dtype,
):
    """Guard against NaN output when prefix_lse and suffix_lse are both -inf.

    During chunked prefill a request with zero context tokens yields no
    valid attention scores on either side, so both LSE values are -inf.
    The merge kernel is expected to emit finite (zero) values for those
    positions instead of NaN.
    """
    device = "cuda"
    # Attention outputs are [num_tokens, num_heads, head_size]; the LSE
    # tensors are [num_heads, num_tokens].
    out_shape = (num_tokens, num_query_heads, head_size)
    lse_shape = (num_query_heads, num_tokens)

    prefix_output = torch.randn(out_shape, dtype=input_dtype, device=device)
    suffix_output = torch.randn(out_shape, dtype=input_dtype, device=device)
    output = torch.zeros_like(prefix_output)

    prefix_lse = torch.randn(lse_shape, dtype=torch.float32, device=device)
    suffix_lse = torch.randn(lse_shape, dtype=torch.float32, device=device)

    # --- Edge-case injection ---
    # Roughly a quarter of the (head, token) slots get -inf on both sides;
    # this is the combination that used to trigger NaN.
    both_neg_inf_mask = torch.rand(lse_shape) < 0.25
    for lse in (prefix_lse, suffix_lse):
        lse[both_neg_inf_mask] = float("-inf")

    # Roughly a tenth of the slots not already selected get +inf on both
    # sides (FA2 empty-sequence style).
    fa2_mask = (torch.rand(lse_shape) < 0.10) & ~both_neg_inf_mask
    for lse in (prefix_lse, suffix_lse):
        lse[fa2_mask] = float("inf")

    merge_attn_states_triton(
        output,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
    )

    # Check 1: the merged output contains no NaN at all.
    nan_count = torch.isnan(output).sum().item()
    assert nan_count == 0, f"Found {nan_count} NaN elements in output"

    # Check 2: wherever both LSEs were -inf there is nothing to merge, so
    # the output must be exactly zero. The mask is laid out as
    # [num_heads, num_tokens]; transpose it and broadcast over head_size so
    # it lines up with the [num_tokens, num_heads, head_size] output.
    both_inf_idx = both_neg_inf_mask.t()[..., None].expand_as(output)
    zero_region = output[both_inf_idx]
    assert bool((zero_region == 0).all()), (
        f"Expected zero output for both-neg-inf positions, "
        f"got max abs = {zero_region.abs().max().item()}"
    )
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/deepseek_v4_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,7 +1012,8 @@ def _forward_prefill(
M,
N,
)
flash_mla_sparse_fwd(

output_chunk, _, _ = flash_mla_sparse_fwd(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to rethink this. This change also affects the CUDA code path.

q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
Expand Down
42 changes: 30 additions & 12 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,19 +901,27 @@ def _interleave_mxfp4_cutlass_sm90(w):
.view(e, n, -1)
)

# View as native FP4 dtype for AITER shuffle
w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)

# Shuffle weights and scales for AITER CK kernel layout
w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
# AITER CK kernels key off torch.float4_e2m1fn_x2, not raw uint8.
# Return fresh Parameters instead of assigning .data so the dtype and
# Tensor metadata survive replace_parameter() unchanged.
w13_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w13_weight.data.view(torch.float4_e2m1fn_x2), 16, True
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
num_experts,
True,
)

w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
w2_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w2_weight.data.view(torch.float4_e2m1fn_x2), 16, False
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
num_experts,
Expand Down Expand Up @@ -1273,17 +1281,27 @@ def convert_weight_to_mxfp4_moe_kernel_format(
.view(e, n, -1)
)

w13_weight.data = w13_weight.data.view(torch.float4_e2m1fn_x2)
w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)

w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
# AITER CK kernels key off torch.float4_e2m1fn_x2, not raw uint8.
# Return fresh Parameters instead of assigning .data so the dtype and
# Tensor metadata survive replace_parameter() unchanged.
w13_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w13_weight.data.view(torch.float4_e2m1fn_x2), 16, True
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
num_experts,
True,
)

w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
w2_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w2_weight.data.view(torch.float4_e2m1fn_x2), 16, False
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
num_experts,
Expand Down
31 changes: 30 additions & 1 deletion vllm/model_executor/layers/mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,10 @@ def mhc_post(
comb_res_mix.to(torch.float32),
residual.to(torch.float32),
)
post_term = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32)
post_mix = post_layer_mix.to(torch.float32)
if post_mix.ndim == x.ndim:
post_mix = post_mix.unsqueeze(-1)
post_term = post_mix * x.unsqueeze(-2).to(torch.float32)
return (mixed_residual + post_term).to(residual.dtype)
out = torch.empty_like(residual)
mhc_post_tilelang(
Expand Down Expand Up @@ -653,6 +656,32 @@ def mhc_fused_post_pre(
post_layer_mix_flat = post_layer_mix.view(num_tokens, hc_mult)
comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult)

if current_platform.is_rocm():
residual_cur = mhc_post(
x_flat,
residual_flat,
post_layer_mix_flat,
comb_res_mix_flat,
)
post_mix_cur, comb_mix_cur, layer_input_cur = mhc_pre(
residual_cur,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_mult_value,
sinkhorn_repeat,
n_splits,
)
return (
residual_cur.view(*outer_shape, hc_mult, hidden_size),
post_mix_cur.view(*outer_shape, hc_mult, 1),
comb_mix_cur.view(*outer_shape, hc_mult, hc_mult),
layer_input_cur.view(*outer_shape, hidden_size),
)

fma_token_threshold = 16
if num_tokens <= fma_token_threshold:
# TODO(gnovack): investigate autotuning these heuristics
Expand Down
16 changes: 9 additions & 7 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,15 +975,13 @@ def requant_weight_ue8m0_inplace(


def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
"""Upcast E8M0 (exponent-only) scale to float32.
"""Decode E8M0 (exponent-only) scale tensors to float32.

E8M0 stores only the 8-bit biased exponent (bias=127). To convert
to float32 we place those 8 bits into the exponent field of an
IEEE-754 float32 (bits 23-30) with sign=0 and mantissa=0.
E8M0 stores an unsigned exponent with IEEE-754 bias 127. Keep the
conversion in one helper so CUDA DeepGEMM and ROCm fallback paths use
identical scale semantics for checkpoints that store UE8M0 scales.
"""
exp_bits = scale.view(torch.uint8).to(torch.int32)
fp32_bits = exp_bits << 23
return fp32_bits.view(torch.float32)
return torch.exp2(scale.view(torch.uint8).to(torch.float32) - 127)


def deepgemm_post_process_fp8_weight_block(
Expand Down Expand Up @@ -1293,6 +1291,10 @@ def process_fp8_weight_block_strategy(
weight=weight, weight_scale=weight_scale
)

if weight_scale.dtype == torch.float8_e8m0fnu and not is_deep_gemm_e8m0_used():
# ROCm fallback kernels do not accept UE8M0 scale tensors directly.
weight_scale = _upcast_e8m0_to_fp32(weight_scale)

weight = _maybe_pad_fp8_weight(weight)
return weight, weight_scale

Expand Down
12 changes: 10 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,15 @@ def normalize_e4m3fn_to_e4m3fnuz(
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if weight_scale.dtype == torch.float8_e8m0fnu:
weight_scale = weight_scale.view(torch.uint8).to(torch.float32)
weight_scale = torch.exp2(weight_scale - 127) * 2.0
else:
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
if input_scale.dtype == torch.float8_e8m0fnu:
input_scale = input_scale.view(torch.uint8).to(torch.float32)
input_scale = torch.exp2(input_scale - 127) * 2.0
else:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
140 changes: 140 additions & 0 deletions vllm/utils/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,149 @@ def fp8_gemm_nt(*args, **kwargs):
return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)


def _decode_fp8_scale(scale: torch.Tensor) -> torch.Tensor:
if scale.dtype == torch.float8_e8m0fnu:
return torch.exp2(scale.view(torch.uint8).to(torch.float32) - 127.0)
return scale.to(torch.float32)


def _dequant_fp8_block(x: torch.Tensor, scale: torch.Tensor | None) -> torch.Tensor:
if scale is None:
return x.to(torch.float32)

scale_f = _decode_fp8_scale(scale)
flat_scale = scale_f.reshape(-1)
if flat_scale.numel() == 1:
return x.to(torch.float32) * flat_scale[0]

if x.numel() % flat_scale.numel() != 0:
# Keep the fallback best-effort for unusual scale layouts; the real
# DeepGEMM path remains authoritative on CUDA.
return x.to(torch.float32) * flat_scale.mean()

block_size = x.numel() // flat_scale.numel()
expanded_scale = flat_scale.repeat_interleave(block_size)
return (x.reshape(-1).to(torch.float32) * expanded_scale).reshape(x.shape)


def _dequant_fp8_2d_block(
    x: torch.Tensor,
    scale: torch.Tensor,
    rows: int,
    cols: int,
    block_rows: int = 128,
    block_cols: int = 128,
) -> torch.Tensor:
    """Dequantize a 2D FP8 matrix whose scales are tiled by row/column block.

    DeepSeek-V4 wo_a keeps weights as [G * R, D] with scales shaped
    [(G * R) / 128, D / 128]. Flattening the scales and repeating them
    globally (as the generic fallback does) would mix row blocks with
    column blocks, so each scale is expanded along its own matrix axis.

    Args:
        x: Quantized matrix; it is reshaped to ``(rows, cols)``.
        scale: Block scales; the last two dims are (row blocks, col blocks).
        rows: Number of matrix rows.
        cols: Number of matrix columns.
        block_rows: Rows covered by one scale.
        block_cols: Columns covered by one scale.

    Returns:
        The float32 dequantized matrix. Falls back to
        ``_dequant_fp8_block`` when the scale has fewer than two dims or
        its block counts do not match ``rows``/``cols``.
    """
    scale_f = _decode_fp8_scale(scale)

    layout_matches = scale_f.ndim >= 2 and tuple(scale_f.shape[-2:]) == (
        cdiv(rows, block_rows),
        cdiv(cols, block_cols),
    )
    if not layout_matches:
        return _dequant_fp8_block(x, scale)

    # Stretch along rows first, then columns, trimming the padding that the
    # last (possibly partial) block introduces on each axis.
    per_row = scale_f.repeat_interleave(block_rows, dim=-2)[..., :rows, :]
    per_element = per_row.repeat_interleave(block_cols, dim=-1)[..., :cols]
    return x.reshape(rows, cols).to(torch.float32) * per_element


def _reshape_to_subscripts(
tensor: torch.Tensor,
subscripts: str,
dim_map: dict[str, int],
) -> torch.Tensor:
target_ndim = len(subscripts)
if tensor.ndim == target_ndim:
return tensor

known_shape = [dim_map.get(dim) for dim in subscripts]
known_product = 1
unknown_count = 0
for dim in known_shape:
if dim is None:
unknown_count += 1
else:
known_product *= dim

if unknown_count == 0:
return tensor.reshape(known_shape)
if unknown_count == 1 and tensor.numel() % known_product == 0:
unknown_dim = tensor.numel() // known_product
shape = [dim if dim is not None else unknown_dim for dim in known_shape]
return tensor.reshape(shape)

# Let torch.einsum raise the shape error with the original tensor.
return tensor


def _rocm_fp8_einsum_fallback(
    equation: str,
    a_tuple: tuple[torch.Tensor, torch.Tensor | None],
    b_tuple: tuple[torch.Tensor, torch.Tensor | None],
    out: torch.Tensor,
    recipe: tuple[int, ...] | None = None,
) -> None:
    """Emulate an FP8 einsum by dequantizing to float32 and using torch.einsum.

    Args:
        equation: Einsum equation with exactly two comma-separated inputs
            and one ``->`` (it is split on those separators as-is).
        a_tuple: ``(values, scale)`` for the first operand; scale may be None.
        b_tuple: ``(values, scale)`` for the second operand; scale may be None.
        out: Destination tensor; the result is cast to its dtype and copied
            into it in place.
        recipe: Accepted for signature compatibility and ignored.
    """
    del recipe
    a, a_scale = a_tuple
    b, b_scale = b_tuple

    inputs, out_subs = equation.split("->")
    a_subs, b_subs = inputs.split(",")

    # Collect every dimension size that can be read straight off a tensor
    # whose rank already matches its subscripts.
    dim_map: dict[str, int] = {}
    a_f = _dequant_fp8_block(a, a_scale)
    if a_f.ndim == len(a_subs):
        dim_map.update(zip(a_subs, a_f.shape))
    if out.ndim == len(out_subs):
        dim_map.update(zip(out_subs, out.shape))

    # DeepSeek-V4 MLA output projection: the activation is [T, G, D] and
    # wo_a is stored as [G * R, D] with 128x128 block scales. The Triton
    # inverse-RoPE activation path is left alone, but wo_a has to be
    # dequantized with its real 2D scale layout before being reshaped to
    # the einsum subscripts.
    is_wo_a_projection = (
        current_platform.is_rocm()
        and equation == "bhr,hdr->bhd"
        and b_scale is not None
        and b.ndim == 2
        and {"h", "d", "r"}.issubset(dim_map)
    )
    b_f = None
    if is_wo_a_projection:
        h, d, r = dim_map["h"], dim_map["d"], dim_map["r"]
        if b.numel() == h * d * r:
            b_f = _dequant_fp8_2d_block(b, b_scale, h * d, r).reshape(h, d, r)
    if b_f is None:
        b_f = _dequant_fp8_block(b, b_scale)

    if b_f.ndim == len(b_subs):
        dim_map.update(zip(b_subs, b_f.shape))

    # Reshape a first so any size it reveals can help reshape b.
    a_f = _reshape_to_subscripts(a_f, a_subs, dim_map)
    if a_f.ndim == len(a_subs):
        dim_map.update(zip(a_subs, a_f.shape))
    b_f = _reshape_to_subscripts(b_f, b_subs, dim_map)

    out.copy_(torch.einsum(equation, a_f, b_f).to(out.dtype))


def fp8_einsum(*args, **kwargs):
    """Run an FP8 einsum, forwarding all arguments to the selected backend.

    After ``_lazy_init()``, the call goes to ``_fp8_einsum_impl`` when it is
    set. When it is ``None``, ROCm platforms use
    ``_rocm_fp8_einsum_fallback`` and every other platform goes to
    ``_missing``.
    """
    _lazy_init()
    if _fp8_einsum_impl is not None:
        return _fp8_einsum_impl(*args, **kwargs)

    if current_platform.is_rocm():
        handler = _rocm_fp8_einsum_fallback
    else:
        handler = _missing
    return handler(*args, **kwargs)

Expand Down
Loading
Loading