Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/models/test_deepseek_v4_mega_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import torch

from vllm.model_executor.models.deepseek_v4 import (
from vllm.models.deepseek_v4.nvidia.deepseek_v4 import (
DeepseekV4MegaMoEExperts,
_stage_deepseek_v4_mega_moe_inputs,
make_deepseek_v4_expert_params_mapping,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
# lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import _ONLINE_SHORTHANDS
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from vllm.model_executor.models.deepseek_v4 import DeepseekV4FP8Config
from vllm.models.deepseek_v4 import DeepseekV4FP8Config

from .auto_gptq import AutoGPTQConfig
from .awq import AWQConfig
Expand Down
31 changes: 26 additions & 5 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import importlib
import importlib.util
import json
import os
import pickle
Expand Down Expand Up @@ -97,7 +98,7 @@
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),
"DeepseekV4ForCausalLM": ("vllm.models.deepseek_v4", "DeepseekV4ForCausalLM"),
"Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
"Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
Expand Down Expand Up @@ -611,7 +612,7 @@
"Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"DeepSeekV4MTPModel": ("deepseek_v4_mtp", "DeepSeekV4MTP"),
"DeepSeekV4MTPModel": ("vllm.models.deepseek_v4", "DeepSeekV4MTP"),
"Gemma4MTPModel": ("gemma4_mtp", "Gemma4MTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
Expand Down Expand Up @@ -870,10 +871,21 @@ def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:

@logtime(logger=logger, msg="Registry inspect model class")
def inspect_model_cls(self) -> _ModelInfo:
model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
# Modules registered with a non-default location (e.g. the
# hardware-isolated ``vllm.models.<name>`` layout) live outside
# ``vllm/model_executor/models``. Resolve the module spec directly
# so the file-hash cache stays warm for them.
if self.module_name.startswith("vllm.model_executor.models."):
model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
else:
try:
spec = importlib.util.find_spec(self.module_name)
except (ImportError, ValueError):
spec = None
model_path = Path(spec.origin) if spec is not None and spec.origin else None
module_hash = None

if model_path.exists():
if model_path is not None and model_path.exists():
with open(model_path, "rb") as f:
module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

Expand Down Expand Up @@ -1328,10 +1340,19 @@ def is_transcription_only_model(
return model_cls.supports_transcription_only


def _resolve_module_name(mod_relname: str) -> str:
# Allow registry entries to point at fully-qualified module paths (e.g.
# ``vllm.models.deepseek_v4``) for models that live outside the legacy
# ``vllm.model_executor.models`` flat layout.
if mod_relname.startswith("vllm."):
return mod_relname
return f"vllm.model_executor.models.{mod_relname}"


ModelRegistry = _ModelRegistry(
{
model_arch: _LazyRegisteredModel(
module_name=f"vllm.model_executor.models.{mod_relname}",
module_name=_resolve_module_name(mod_relname),
class_name=cls_name,
)
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
Expand Down
2 changes: 2 additions & 0 deletions vllm/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
30 changes: 30 additions & 0 deletions vllm/models/deepseek_v4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""DeepSeek V4 model — hardware-isolated entry point.

The actual implementation lives under ``nvidia/`` and ``amd/``; this module
picks the right one for the current platform and re-exports the public
classes used by the model registry and quantization config lookup.
"""

from typing import TYPE_CHECKING

from vllm.platforms import current_platform

from .quant_config import DeepseekV4FP8Config

# Pick the per-platform implementation. The NVIDIA branch is the static
# default that mypy sees; the ROCm branch rebinds the same names at runtime,
# with ``# type: ignore[assignment]`` suppressing the type checker's
# incompatible-redefinition error on those imports.
if TYPE_CHECKING or not current_platform.is_rocm():
from .nvidia.deepseek_v4 import DeepseekV4ForCausalLM
from .nvidia.deepseek_v4_mtp import DeepSeekV4MTP
else:
from .amd.deepseek_v4 import DeepseekV4ForCausalLM # type: ignore[assignment]
from .amd.deepseek_v4_mtp import DeepSeekV4MTP # type: ignore[assignment]
Comment on lines +19 to +24
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can factor this out once more models are written this way, right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Will do it later.


__all__ = [
"DeepSeekV4MTP",
"DeepseekV4FP8Config",
"DeepseekV4ForCausalLM",
]
2 changes: 2 additions & 0 deletions vllm/models/deepseek_v4/amd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
1 change: 1 addition & 0 deletions vllm/models/deepseek_v4/amd/deepseek_v4.py
1 change: 1 addition & 0 deletions vllm/models/deepseek_v4/amd/deepseek_v4_mtp.py
2 changes: 2 additions & 0 deletions vllm/models/deepseek_v4/nvidia/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig
from vllm.distributed import (
get_ep_group,
get_pp_group,
Expand All @@ -24,7 +24,6 @@
DeepseekV4MultiHeadLatentAttentionWrapper,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
Expand All @@ -44,29 +43,15 @@
MHCPostOp,
MHCPreOp,
)
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op

from .utils import (
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
Expand All @@ -75,8 +60,11 @@
make_layers,
maybe_prefix,
)

_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8")
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op


class DeepseekV4MLP(nn.Module):
Expand Down Expand Up @@ -130,97 +118,6 @@ def forward(self, x):
return x


class DeepseekV4FP8Config(Fp8Config):
    """Fp8Config variant whose MoE handling follows the checkpoint's expert dtype.

    Linear and attention layers of a DeepSeek V4 checkpoint are always FP8
    block-quantized; only the MoE expert weights differ between checkpoints:

    - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts, with
      the FP8 linear scales stored as ue8m0 (e8m0fnu).
    - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block
      experts, with the FP8 linear scales stored as float32.

    ``expert_dtype`` comes from the model's hf_config and drives both the
    MoE method dispatch and the linear scale dtype. A missing value is
    treated as ``"fp4"`` so existing FP4 checkpoints behave as before.

    NOTE: the lookup is deliberately lazy. This config object is built
    during VllmConfig setup, before ``set_current_vllm_config`` is active,
    so reading hf_config in ``__init__`` would always yield the ``"fp4"``
    default and silently misroute Flash-Base checkpoints.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filled in by the ``expert_dtype`` property on its first successful
        # lookup. ``is_scale_e8m0`` is likewise a property, so it is only
        # evaluated on read, by which time the current vllm_config has been
        # set.
        self._resolved_expert_dtype: str | None = None

    @property
    def expert_dtype(self) -> str:
        """Expert weight dtype (``"fp4"`` or ``"fp8"``), resolved on demand.

        Returns ``"fp4"`` without caching while the hf_config cannot be
        reached, so a later read can still pick up the real value. Raises
        ``ValueError`` when the configured value is not a supported dtype.
        """
        cached = self._resolved_expert_dtype
        if cached is not None:
            return cached

        try:
            hf_config = get_current_vllm_config().model_config.hf_config
        except Exception:
            # vllm_config is not set yet; postpone the decision until a
            # later read happens inside set_current_vllm_config.
            return "fp4"

        declared = getattr(hf_config, "expert_dtype", "fp4")
        if declared not in _DEEPSEEK_V4_EXPERT_DTYPES:
            raise ValueError(
                f"Unsupported DeepSeek V4 expert_dtype={declared!r}; "
                f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
            )
        self._resolved_expert_dtype = declared

        from vllm.logger import init_logger

        init_logger(__name__).info_once(
            "DeepSeek V4 expert_dtype resolved to %r", declared
        )
        return declared

    @property
    def is_scale_e8m0(self) -> bool:
        """Whether FP8 linear scales are stored as e8m0fnu.

        True for FP4-expert checkpoints; FP8-expert checkpoints
        (Flash-Base) store the scales as float32 instead.
        """
        uses_fp4_experts = self.expert_dtype == "fp4"
        return uses_fp4_experts

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "deepseek_v4_fp8"

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant, hf_config=None
    ) -> QuantizationMethods | None:
        """Return ``"deepseek_v4_fp8"`` when this config should be used, else None.

        The checkpoint's quant config must be a dict whose ``quant_method``
        is ``"fp8"`` or ``"deepseek_v4_fp8"``; in addition, either the model
        type is ``deepseek_v4`` or the user explicitly asked for
        ``deepseek_v4_fp8``.
        """
        if not isinstance(hf_quant_cfg, dict):
            return None
        if hf_quant_cfg.get("quant_method") not in ("fp8", "deepseek_v4_fp8"):
            return None

        model_type = getattr(hf_config, "model_type", None)
        is_v4_model = model_type == "deepseek_v4"
        if is_v4_model or user_quant == "deepseek_v4_fp8":
            return "deepseek_v4_fp8"
        return None

    def get_quant_method(self, layer, prefix):
        """Pick the quantization method for ``layer``.

        FusedMoE layers get the unquantized method when skipped and the
        MXFP4 method when experts are FP4; everything else (including FP8
        experts) is delegated to ``Fp8Config``.
        """
        is_moe = isinstance(layer, FusedMoE)
        if is_moe and is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        ):
            return UnquantizedFusedMoEMethod(layer.moe_config)
        if is_moe and self.expert_dtype == "fp4":
            return Mxfp4MoEMethod(layer.moe_config)
        # Non-MoE layers and expert_dtype == "fp8" both land here: Fp8Config
        # returns Fp8MoEMethod with block-wise float32 scales for the latter.
        return super().get_quant_method(layer, prefix)

    def is_mxfp4_quant(self, prefix, layer):
        """Return True only for FusedMoE layers when experts are FP4."""
        if not isinstance(layer, FusedMoE):
            return False
        return self.expert_dtype == "fp4"


@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
hidden_states,
Expand Down Expand Up @@ -1535,7 +1432,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
):
loaded_weight = loaded_weight.view(torch.uint8)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
param_name, weight_name, expert_id, expert_shard_id = mapping
if weight_name not in name:
continue
name_mapped = name.replace(weight_name, param_name)
Expand All @@ -1552,7 +1449,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
shard_id=expert_shard_id,
expert_id=expert_id,
return_success=True,
)
Expand Down Expand Up @@ -1673,7 +1570,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.make_empty_intermediate_tensors = ( # type: ignore[method-assign]
self.model.make_empty_intermediate_tensors
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_mtp import SharedHead
from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .deepseek_mtp import SharedHead
from .deepseek_v2 import get_spec_layer_idx_from_weight_name
from .deepseek_v4 import (
DeepseekV4DecoderLayer,
make_deepseek_v4_expert_params_mapping,
)
from .utils import maybe_prefix

logger = init_logger(__name__)

Expand All @@ -68,6 +68,7 @@ def __init__(
) -> None:
super().__init__()

assert vllm_config.speculative_config is not None
config = vllm_config.speculative_config.draft_model_config.hf_config
self.config = config
quant_config = vllm_config.quant_config
Expand Down Expand Up @@ -407,7 +408,7 @@ def _find_mtp_layer_idx(name: str) -> int:
):
loaded_weight = loaded_weight.view(torch.uint8)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
param_name, weight_name, expert_id, expert_shard_id = mapping
if weight_name not in name:
continue
name_mapped = name.replace(weight_name, param_name)
Expand All @@ -422,7 +423,7 @@ def _find_mtp_layer_idx(name: str) -> int:
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
shard_id=expert_shard_id,
expert_id=expert_id,
return_success=True,
)
Expand Down
Loading
Loading