Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/deepseek_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ def forward(
TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]),
STATE_WIDTH=state_width,
COMPRESS_RATIO=self.compress_ratio,
launch_pdl=False,
)

# Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write.
Expand Down Expand Up @@ -378,7 +377,6 @@ def forward(
SCALE_DIM=self._scale_dim,
KV_BLOCK_STRIDE=kv_cache.stride(0),
num_warps=self._num_warps,
launch_pdl=False,
)


Expand Down
354 changes: 349 additions & 5 deletions vllm/model_executor/layers/deepseek_v4_attention.py

Large diffs are not rendered by default.

153 changes: 153 additions & 0 deletions vllm/model_executor/layers/deepseek_v4_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
"""Debug-only dump helpers for DeepSeek-V4 accuracy triage.

All functionality is disabled unless VLLM_DSV4_DUMP_ROOT is set.
"""

from __future__ import annotations

import json
import os
import re
import time
from pathlib import Path
from typing import Any

import torch

# Number of dumps already written in this process, keyed by
# "<side>:<rank>:<name>"; dump_tensor() uses it to enforce the per-name cap.
_COUNTS: dict[str, int] = {}
# Finds a dot-delimited "layers.<digits>" component in a dotted name and
# captures the digits.
_LAYER_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)")


def _truthy(value: str | None) -> bool:
return value not in (None, "", "0", "false", "False", "no", "No")


def enabled() -> bool:
    """Return True when ``VLLM_DSV4_DUMP_ROOT`` holds a truthy value."""
    dump_root = os.environ.get("VLLM_DSV4_DUMP_ROOT")
    return _truthy(dump_root)


def side() -> str:
    """Return the ``VLLM_DSV4_DUMP_SIDE`` label, or ``"fp8"`` when unset."""
    try:
        return os.environ["VLLM_DSV4_DUMP_SIDE"]
    except KeyError:
        return "fp8"


def _rank() -> str:
    """Return the process rank as a string.

    Uses the first non-empty value of ``RANK`` then ``LOCAL_RANK``; falls back
    to ``"0"`` when neither is set to a non-empty string.
    """
    for var in ("RANK", "LOCAL_RANK"):
        value = os.environ.get(var)
        if value:
            return value
    return "0"


def layer_from_name(name: str) -> int | None:
    """Extract the index following a ``layers.`` component of ``name``.

    Returns the captured digits as an int, or ``None`` when ``name`` has no
    ``layers.<digits>`` component.
    """
    match = _LAYER_RE.search(name)
    if match is None:
        return None
    return int(match.group(1))


def _selected_layer(name: str, layer_idx: int | None) -> bool:
spec = os.environ.get("VLLM_DSV4_DUMP_LAYERS", "")
if not spec:
return True
if layer_idx is None:
return True
vals: set[int] = set()
for item in spec.split(","):
item = item.strip()
if not item:
continue
if "-" in item:
lo, hi = item.split("-", 1)
vals.update(range(int(lo), int(hi) + 1))
else:
vals.add(int(item))
return layer_idx in vals


def _safe_float(x: torch.Tensor, fn: str) -> float | None:
try:
if fn == "mean":
return float(x.mean().item())
if fn == "std":
return float(x.std(unbiased=False).item())
if fn == "amax":
return float(x.abs().amax().item())
if fn == "amin":
return float(x.amin().item())
if fn == "max":
return float(x.amax().item())
except Exception:
return None
return None


def _summary(tensor: torch.Tensor) -> dict[str, Any]:
    """Build a JSON-friendly statistical summary of ``tensor``.

    Always present: ``shape``, ``dtype``, ``device``, ``numel``, ``finite``,
    ``nan_count`` and ``inf_count``. For non-empty tensors the float32 view
    also yields ``mean``, ``std``, ``amax_abs``, ``min`` and ``max`` (each
    ``None`` if its reduction fails). When the tensor has at least one
    dimension, a non-empty last dimension and at most 2,000,000 elements, the
    top (up to 8) entries of the final row are added as ``last_row_top_ids``
    and ``last_row_top_vals``.
    """
    src = tensor.detach()
    stats: dict[str, Any] = dict(
        shape=list(src.shape),
        dtype=str(src.dtype),
        device=str(src.device),
        numel=int(src.numel()),
    )
    if src.numel() == 0:
        # Nothing to reduce: an empty tensor is reported as finite.
        stats["finite"] = True
        stats["nan_count"] = 0
        stats["inf_count"] = 0
        return stats

    try:
        as_f32 = src.float()
    except Exception:
        as_f32 = src.to(torch.float32)

    try:
        all_finite = bool(torch.isfinite(as_f32).all().item())
        n_nan = int(torch.isnan(as_f32).sum().item())
        n_inf = int(torch.isinf(as_f32).sum().item())
    except Exception:
        # If any of the three checks fails, all three are reported as unknown.
        all_finite = n_nan = n_inf = None
    stats["finite"] = all_finite
    stats["nan_count"] = n_nan
    stats["inf_count"] = n_inf

    for key, reducer in (
        ("mean", "mean"),
        ("std", "std"),
        ("amax_abs", "amax"),
        ("min", "amin"),
        ("max", "max"),
    ):
        stats[key] = _safe_float(as_f32, reducer)

    if as_f32.ndim < 1 or as_f32.shape[-1] <= 0 or as_f32.numel() > 2_000_000:
        return stats

    try:
        last_row = as_f32.reshape(-1, as_f32.shape[-1])[-1]
        top = torch.topk(last_row, k=min(8, last_row.numel()))
        stats["last_row_top_ids"] = [int(i) for i in top.indices.cpu().tolist()]
        stats["last_row_top_vals"] = [float(v) for v in top.values.cpu().tolist()]
    except Exception:
        # The top-k preview is optional; leave it out if it cannot be computed.
        pass
    return stats


def dump_tensor(
    name: str,
    tensor: torch.Tensor | None,
    *,
    layer_idx: int | None = None,
    note: str | None = None,
    max_writes: int | None = None,
) -> None:
    """Append a summary of ``tensor`` to this rank's JSONL dump file.

    Does nothing unless dumping is enabled (``VLLM_DSV4_DUMP_ROOT``), the
    tensor is not ``None``, the layer passes the ``VLLM_DSV4_DUMP_LAYERS``
    filter, and fewer than ``max_writes`` dumps have been made for this
    (side, rank, name) combination in the current process.

    Args:
        name: Identifier recorded with the dump; also used to derive the
            layer index when ``layer_idx`` is not given.
        tensor: Tensor to summarise; ``None`` is silently ignored.
        layer_idx: Explicit layer index; parsed from ``name`` when omitted.
        note: Optional free-form text stored in the record.
        max_writes: Per-name cap on dumps; defaults to
            ``VLLM_DSV4_DUMP_MAX_WRITES`` (16 when unset).

    Output goes to ``<root>/<side>/rank_<rank>_summary.jsonl``. When
    ``VLLM_DSV4_DUMP_FULL_TENSOR`` is truthy the full tensor is also saved as
    a compressed ``.npz`` file next to it.
    """
    if not enabled() or tensor is None:
        return
    if layer_idx is None:
        layer_idx = layer_from_name(name)
    if not _selected_layer(name, layer_idx):
        return

    # Read the environment-derived labels once so the counter key, the record
    # and the file names are guaranteed to agree.
    side_label = side()
    rank = _rank()

    key = f"{side_label}:{rank}:{name}"
    if max_writes is None:
        max_writes = int(os.environ.get("VLLM_DSV4_DUMP_MAX_WRITES", "16"))
    count = _COUNTS.get(key, 0)
    if count >= max_writes:
        return
    _COUNTS[key] = count + 1

    root = Path(os.environ["VLLM_DSV4_DUMP_ROOT"]) / side_label
    root.mkdir(parents=True, exist_ok=True)
    rec: dict[str, Any] = {
        "ts": time.time(),
        "side": side_label,
        "rank": rank,
        "name": name,
        "layer_idx": layer_idx,
        "write_idx": count,
        "note": note,
    }
    rec.update(_summary(tensor))
    # ensure_ascii=False can emit non-ASCII characters, so pin the encoding
    # instead of relying on the platform/locale default.
    with (root / f"rank_{rank}_summary.jsonl").open("a", encoding="utf-8") as f:
        f.write(json.dumps(rec, ensure_ascii=False, allow_nan=True) + "\n")

    if _truthy(os.environ.get("VLLM_DSV4_DUMP_FULL_TENSOR")):
        import numpy as np

        # Keep only filesystem-friendly characters in the file name.
        safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", name)
        arr = tensor.detach().float().cpu().numpy()
        np.savez_compressed(root / f"rank_{rank}_{count:04d}_{safe}.npz", tensor=arr)
90 changes: 84 additions & 6 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,8 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
_AVAILABLE_BACKENDS = [
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.DEEPGEMM_MXFP4,
# TRITON_UNFUSED has bug with MTP support
# TODO re-enable after kernel is fixed
# TRITON_UNFUSED
Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.TRITON_UNFUSED,
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
]
Expand Down Expand Up @@ -836,14 +835,24 @@ def _interleave_mxfp4_cutlass_sm90(w):
w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2)

# Shuffle weights and scales for AITER CK kernel layout
w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True)
w13_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True).view(
torch.float4_e2m1fn_x2
),
requires_grad=False,
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
num_experts,
True,
)

w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False)
w2_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False).view(
torch.float4_e2m1fn_x2
),
requires_grad=False,
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
num_experts,
Expand Down Expand Up @@ -1159,10 +1168,79 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
w13_bias,
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.AITER:
from vllm._aiter_ops import rocm_aiter_ops

w13_weight = w13_weight.data
w2_weight = w2_weight.data
w13_weight_scale = w13_weight_scale.data
w2_weight_scale = w2_weight_scale.data
if w13_bias is not None:
w13_bias = w13_bias.data.to(torch.float32)
if w2_bias is not None:
w2_bias = w2_bias.data.to(torch.float32)

e, n, k = w13_weight.shape

# De-interleave w13 rows: gate/up pairs -> contiguous gate, up blocks.
w13_weight = (
w13_weight.view(torch.uint8)
.view(e, n // 2, 2, k)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, k)
)
w13_weight_scale = (
w13_weight_scale.view(e, n // 2, 2, -1)
.permute(0, 2, 1, 3)
.contiguous()
.view(e, n, -1)
)

# AITER CK kernels key off torch.float4_e2m1fn_x2, not raw uint8.
w13_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w13_weight.view(torch.float4_e2m1fn_x2), 16, True
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
w2_weight = torch.nn.Parameter(
rocm_aiter_ops.shuffle_weight_a16w4(
w2_weight.view(torch.float4_e2m1fn_x2), 16, False
).view(torch.float4_e2m1fn_x2),
requires_grad=False,
)
shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w13_weight_scale.view(-1, w13_weight_scale.shape[-1]),
num_experts,
True,
)
shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
w2_weight_scale.view(-1, w2_weight_scale.shape[-1]),
num_experts,
False,
)

if w13_bias is not None:
w13_bias = (
w13_bias.view(-1, n // 2, 2)
.permute(0, 2, 1)
.contiguous()
.view(-1, n)
)

return (
w13_weight,
w2_weight,
shuffled_w13_scale,
shuffled_w2_scale,
w13_bias,
w2_bias,
)
else:
raise ValueError(
f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. "
f"Expected TRTLLM or Triton backend."
f"Expected TRTLLM, Triton, AITER, or Marlin backend."
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,56 @@ def vllm_topk_sigmoid(
return topk_weights, topk_indices


def _topk_softplus_sqrt_torch(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
routed_scaling_factor: float,
e_score_correction_bias: torch.Tensor | None,
input_tokens: torch.Tensor | None,
hash_indices_table: torch.Tensor | None,
) -> None:
x_f32 = gating_output.to(torch.float32)
weights_base = torch.sqrt(F.softplus(x_f32, beta=1.0, threshold=20.0))
topk = topk_weights.shape[-1]

if input_tokens is not None and hash_indices_table is not None:
selected_experts = hash_indices_table[input_tokens.to(torch.long)]
selected_weights = torch.gather(weights_base, -1, selected_experts.to(torch.long))
if renormalize:
denom = selected_weights.sum(dim=-1, keepdim=True)
denom = torch.where(denom > 0, denom, torch.ones_like(denom))
selected_weights = selected_weights / denom
selected_weights = selected_weights * routed_scaling_factor
topk_weights.copy_(selected_weights.to(topk_weights.dtype))
topk_indices.copy_(selected_experts.to(topk_indices.dtype))
return

ranking = weights_base
if e_score_correction_bias is not None:
ranking = ranking + e_score_correction_bias.to(torch.float32)
_, topk_ids = torch.topk(ranking, topk, dim=-1)
out_weights = torch.gather(weights_base, -1, topk_ids)
if renormalize:
denom = out_weights.sum(dim=-1, keepdim=True)
denom = torch.where(denom > 0, denom, torch.ones_like(denom))
out_weights = out_weights / denom
out_weights = out_weights * routed_scaling_factor
topk_weights.copy_(out_weights.to(topk_weights.dtype))
topk_indices.copy_(topk_ids.to(topk_indices.dtype))

arange_t = torch.arange(
gating_output.shape[0], device=gating_output.device,
dtype=token_expert_indices.dtype,
).unsqueeze(-1)
arange_k = torch.arange(
topk, device=gating_output.device, dtype=token_expert_indices.dtype,
).unsqueeze(0)
token_expert_indices.copy_(arange_k * gating_output.shape[0] + arange_t)


def vllm_topk_softplus_sqrt(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
Expand All @@ -68,17 +118,32 @@ def vllm_topk_softplus_sqrt(
hash_indices_table: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, ...]:
ops.topk_hash_softplus_sqrt(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
input_tokens,
hash_indices_table,
)
from vllm.platforms import current_platform

if current_platform.is_rocm():
_topk_softplus_sqrt_torch(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
input_tokens,
hash_indices_table,
)
else:
ops.topk_hash_softplus_sqrt(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
routed_scaling_factor,
e_score_correction_bias,
input_tokens,
hash_indices_table,
)

return topk_weights, topk_indices

Expand Down
Loading