112 changes: 93 additions & 19 deletions vllm/model_executor/models/deepseek_v4.py
@@ -10,7 +10,7 @@
import torch.nn.functional as F

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_rank,
@@ -65,6 +65,8 @@
maybe_prefix,
)

_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8")


class DeepseekV4MLP(nn.Module):
def __init__(
@@ -118,16 +120,59 @@ def forward(self, x):


class DeepseekV4FP8Config(Fp8Config):
"""FP8 config that routes MoE layers to MXFP4 quantization.

DeepSeek V4 checkpoints use FP8 for linear/attention layers but
MXFP4 for MoE expert weights. This config inherits standard FP8
behavior and overrides only the MoE dispatch.
"""FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch.

DeepSeek V4 checkpoints always use FP8 block quantization for
linear/attention layers. The MoE expert weights vary by checkpoint:
- ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts
with ue8m0 (e8m0fnu) FP8 linear scales.
- ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block
experts with float32 FP8 linear scales.

The dispatch and the linear scale dtype are both keyed off
``expert_dtype`` from the model's hf_config; missing values default
to ``"fp4"`` so existing FP4 checkpoints stay unchanged.

NOTE: ``expert_dtype`` is resolved lazily because this config is
constructed during VllmConfig setup, before ``set_current_vllm_config``
is active. Reading hf_config eagerly in ``__init__`` would always see
the default ``"fp4"`` and silently misroute Flash-Base checkpoints.
"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``is_scale_e8m0`` is a read-only property that resolves on first
        # read, by which time the current vllm_config has been set.
        self._resolved_expert_dtype: str | None = None

@property
def expert_dtype(self) -> str:
if self._resolved_expert_dtype is None:
try:
hf_config = get_current_vllm_config().model_config.hf_config
except Exception:
# vllm_config not yet set; defer the decision until a
# later call lands inside set_current_vllm_config.
return "fp4"
expert_dtype = getattr(hf_config, "expert_dtype", "fp4")
if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES:
raise ValueError(
f"Unsupported DeepSeek V4 expert_dtype={expert_dtype!r}; "
f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
)
self._resolved_expert_dtype = expert_dtype
from vllm.logger import init_logger

init_logger(__name__).info_once(
"DeepSeek V4 expert_dtype resolved to %r", expert_dtype
)
return self._resolved_expert_dtype

@property
def is_scale_e8m0(self) -> bool:
# FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert
# checkpoints (Flash-Base) store them as float32.
return self.expert_dtype == "fp4"

@classmethod
def get_name(cls) -> QuantizationMethods:
@@ -155,11 +200,14 @@ def get_quant_method(self, layer, prefix):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
        if self.expert_dtype == "fp4":
            return Mxfp4MoEMethod(layer.moe_config)
# expert_dtype == "fp8": fall through to Fp8Config which
# returns Fp8MoEMethod with block-wise float32 scales.
return super().get_quant_method(layer, prefix)

def is_mxfp4_quant(self, prefix, layer):
        return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4"
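
    # Dispatch summary for the two dtypes (restating get_quant_method and
    # is_scale_e8m0 above; descriptive only, not an exhaustive table):
    #   expert_dtype == "fp4" -> Mxfp4MoEMethod for FusedMoE layers and
    #                            e8m0fnu FP8 linear scales (is_scale_e8m0=True)
    #   expert_dtype == "fp8" -> Fp8MoEMethod via super() with block-wise
    #                            float32 scales (is_scale_e8m0=False)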


@triton.jit
@@ -689,6 +737,12 @@ def __init__(
raise NotImplementedError(
"DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only."
)
if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4":
raise NotImplementedError(
"DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype="
f"{config.expert_dtype!r}. Drop --kernel-config moe_backend="
"deep_gemm_mega_moe for this checkpoint."
)

self.gate = GateLinear(
config.hidden_size,
@@ -1410,23 +1464,32 @@ def hc_head(
return y.to(dtype)


def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
if expert_dtype == "fp4":
# MXFP4 experts use Mxfp4MoEMethod, which registers scales as
# ``w{1,2,3}_weight_scale`` (no _inv suffix). FP8 linear and
# shared experts use Fp8LinearMethod's block scales, which
# register as ``weight_scale_inv``.
scale_regex = {
re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale",
re.compile(r"\.scale$"): ".weight_scale_inv",
}
else:
# FP8 experts use Fp8MoEMethod (block_quant=True), which registers
# scales as ``w{13,2}_weight_scale_inv``. Map all ``.scale`` keys
# there.
scale_regex = {
re.compile(r"\.scale$"): ".weight_scale_inv",
}
return WeightsMapper(
orig_to_new_prefix={
"layers.": "model.layers.",
"embed.": "model.embed.",
"norm.": "model.norm.",
"hc_head": "model.hc_head",
"mtp.": "model.mtp.",
},
        orig_to_new_regex=scale_regex,
orig_to_new_suffix={
"head.weight": "lm_head.weight",
"embed.weight": "embed_tokens.weight",
@@ -1438,11 +1501,22 @@ class DeepseekV4ForCausalLM(nn.Module):
},
)
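
A hedged sketch of how the fp4 rule set above rewrites sample checkpoint keys
(the key names are illustrative, and first-match-wins ordering is an
assumption about how ``WeightsMapper`` applies ``orig_to_new_regex``):

import re

_fp4_rules = [
    (re.compile(r"(\.experts\.\d+\.w[123])\.scale$"), r"\1.weight_scale"),
    (re.compile(r"\.scale$"), ".weight_scale_inv"),
]

def _remap_sketch(key: str) -> str:
    for pattern, repl in _fp4_rules:  # first match wins (assumed)
        if pattern.search(key):
            return pattern.sub(repl, key)
    return key

# Routed expert scale -> Mxfp4MoEMethod's ``weight_scale`` parameter.
assert (
    _remap_sketch("layers.3.mlp.experts.7.w1.scale")
    == "layers.3.mlp.experts.7.w1.weight_scale"
)
# FP8 linear / shared-expert scale -> ``weight_scale_inv`` block scale.
assert (
    _remap_sketch("layers.3.mlp.shared_experts.gate_proj.scale")
    == "layers.3.mlp.shared_experts.gate_proj.weight_scale_inv"
)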


class DeepseekV4ForCausalLM(nn.Module):
model_cls = DeepseekV4Model

# Default mapper assumes the original FP4-expert checkpoint layout.
# Overridden per-instance in __init__ when expert_dtype != "fp4".
hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
self.config = config
expert_dtype = getattr(config, "expert_dtype", "fp4")
if expert_dtype != "fp4":
self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype)

self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
22 changes: 18 additions & 4 deletions vllm/model_executor/models/deepseek_v4_mtp.py
@@ -48,9 +48,14 @@

logger = init_logger(__name__)

# MoE expert scales are fused into per-layer w13/w2 tensors. The exact
# parameter suffix depends on which FusedMoE method handles the experts:
# - fp4 experts (Mxfp4MoEMethod) register ``w{1,2,3}_weight_scale``;
# - fp8 experts (Fp8MoEMethod with block_quant=True) register
# ``w{1,2,3}_weight_scale_inv``.
# Other FP8 linear scales (including shared experts) always use
# ``.weight_scale_inv``. Mirrors the per-instance mapper built by
# ``_make_deepseek_v4_weights_mapper`` in deepseek_v4.py.
_EXPERT_SCALE_RE = re.compile(r"\.experts\.\d+\.w[123]\.scale$")
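
A hedged sketch of the rename performed in the loading loop below (key names
are illustrative): ``_EXPERT_SCALE_RE`` decides which ``.scale`` keys get the
expert suffix, and everything else always maps to ``.weight_scale_inv``.

def _scale_suffix_sketch(name: str, expert_dtype: str) -> str:
    expert_scale_suffix = (
        ".weight_scale" if expert_dtype == "fp4" else ".weight_scale_inv"
    )
    suffix = (
        expert_scale_suffix
        if _EXPERT_SCALE_RE.search(name)
        else ".weight_scale_inv"
    )
    return name[: -len(".scale")] + suffix

assert (
    _scale_suffix_sketch("mtp.0.experts.5.w2.scale", "fp4")
    == "mtp.0.experts.5.w2.weight_scale"
)
assert (
    _scale_suffix_sketch("mtp.0.experts.5.w2.scale", "fp8")
    == "mtp.0.experts.5.w2.weight_scale_inv"
)
assert (
    _scale_suffix_sketch("mtp.0.shared_head.proj.scale", "fp4")
    == "mtp.0.shared_head.proj.weight_scale_inv"
)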


@@ -326,6 +331,15 @@ def _find_mtp_layer_idx(name: str) -> int:
num_experts=self.config.n_routed_experts,
)

# FP8 experts register ``..._weight_scale_inv`` (block_quant) while
# FP4/MXFP4 experts register ``..._weight_scale``. Choose the suffix
# for the rename below based on the model's expert dtype.
expert_scale_suffix = (
".weight_scale"
if getattr(self.config, "expert_dtype", "fp4") == "fp4"
else ".weight_scale_inv"
)

for name, loaded_weight in weights:
mtp_layer_idx = _find_mtp_layer_idx(name)
# V4 checkpoints store MTP weights as `mtp.{i}.*`; remap to
@@ -347,7 +361,7 @@ def _find_mtp_layer_idx(name: str) -> int:
continue
if name.endswith(".scale"):
suffix = (
".weight_scale"
expert_scale_suffix
if _EXPERT_SCALE_RE.search(name)
else ".weight_scale_inv"
)