Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/models/test_deepseek_v4_mega_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
import torch

from vllm.model_executor.models.deepseek_v4 import (
from vllm.models.deepseek_v4.nvidia.deepseek_v4 import (
DeepseekV4MegaMoEExperts,
_stage_deepseek_v4_mega_moe_inputs,
make_deepseek_v4_expert_params_mapping,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
# lazy import to avoid triggering `torch.compile` too early
from vllm.config.quantization import _ONLINE_SHORTHANDS
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from vllm.model_executor.models.deepseek_v4 import DeepseekV4FP8Config
from vllm.models.deepseek_v4 import DeepseekV4FP8Config

from .auto_gptq import AutoGPTQConfig
from .awq import AWQConfig
Expand Down
31 changes: 26 additions & 5 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import importlib
import importlib.util
import json
import os
import pickle
Expand Down Expand Up @@ -97,7 +98,7 @@
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),
"DeepseekV4ForCausalLM": ("vllm.models.deepseek_v4", "DeepseekV4ForCausalLM"),
"Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
"Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
Expand Down Expand Up @@ -611,7 +612,7 @@
"Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"DeepSeekV4MTPModel": ("deepseek_v4_mtp", "DeepSeekV4MTP"),
"DeepSeekV4MTPModel": ("vllm.models.deepseek_v4", "DeepSeekV4MTP"),
"Gemma4MTPModel": ("gemma4_mtp", "Gemma4MTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
Expand Down Expand Up @@ -870,10 +871,21 @@ def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None:

@logtime(logger=logger, msg="Registry inspect model class")
def inspect_model_cls(self) -> _ModelInfo:
model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
# Modules registered with a non-default location (e.g. the
# hardware-isolated ``vllm.models.<name>`` layout) live outside
# ``vllm/model_executor/models``. Resolve the module spec directly
# so the file-hash cache stays warm for them.
if self.module_name.startswith("vllm.model_executor.models."):
model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py"
else:
try:
spec = importlib.util.find_spec(self.module_name)
except (ImportError, ValueError):
spec = None
model_path = Path(spec.origin) if spec is not None and spec.origin else None
module_hash = None

if model_path.exists():
if model_path is not None and model_path.exists():
with open(model_path, "rb") as f:
module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()

Expand Down Expand Up @@ -1328,10 +1340,19 @@ def is_transcription_only_model(
return model_cls.supports_transcription_only


def _resolve_module_name(mod_relname: str) -> str:
# Allow registry entries to point at fully-qualified module paths (e.g.
# ``vllm.models.deepseek_v4``) for models that live outside the legacy
# ``vllm.model_executor.models`` flat layout.
if mod_relname.startswith("vllm."):
return mod_relname
return f"vllm.model_executor.models.{mod_relname}"


ModelRegistry = _ModelRegistry(
{
model_arch: _LazyRegisteredModel(
module_name=f"vllm.model_executor.models.{mod_relname}",
module_name=_resolve_module_name(mod_relname),
class_name=cls_name,
)
for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
Expand Down
2 changes: 2 additions & 0 deletions vllm/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
31 changes: 31 additions & 0 deletions vllm/models/deepseek_v4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""DeepSeek V4 model — hardware-isolated entry point.

The actual implementation lives under ``nvidia/`` and ``amd/``; this module
picks the right one for the current platform and re-exports the public
classes used by the model registry and quantization config lookup.
"""

from typing import TYPE_CHECKING

from vllm.platforms import current_platform

# Pick the per-platform implementation. The NVIDIA branch is the static
# default that mypy sees; the ROCm branch rebinds the same names at runtime,
# with ``# type: ignore[assignment]`` suppressing mypy's assignment check.
if TYPE_CHECKING or not current_platform.is_rocm():
from .nvidia.deepseek_v4 import DeepseekV4ForCausalLM, DeepseekV4FP8Config
from .nvidia.deepseek_v4_mtp import DeepSeekV4MTP
else:
from .amd.deepseek_v4 import ( # type: ignore[assignment]
DeepseekV4ForCausalLM,
DeepseekV4FP8Config,
)
from .amd.deepseek_v4_mtp import DeepSeekV4MTP # type: ignore[assignment]
Comment on lines +19 to +24

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can abstract this away once more models are written this way, right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Will do it later.


__all__ = [
"DeepSeekV4MTP",
"DeepseekV4FP8Config",
"DeepseekV4ForCausalLM",
]
2 changes: 2 additions & 0 deletions vllm/models/deepseek_v4/amd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
1 change: 1 addition & 0 deletions vllm/models/deepseek_v4/amd/deepseek_v4.py
1 change: 1 addition & 0 deletions vllm/models/deepseek_v4/amd/deepseek_v4_mtp.py
2 changes: 2 additions & 0 deletions vllm/models/deepseek_v4/nvidia/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,7 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op

from .utils import (
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
Expand All @@ -75,6 +69,11 @@
make_layers,
maybe_prefix,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op

_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8")

Expand Down Expand Up @@ -1535,7 +1534,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
):
loaded_weight = loaded_weight.view(torch.uint8)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
param_name, weight_name, expert_id, shard_id = mapping # type: ignore[assignment]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`# type: ignore[assignment]` — is this flagged by pre-commit? Why didn't we need this before?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because vllm/model_executor/models has historically been excluded from the mypy check. Now we apply mypy to vllm/models, which is a nice improvement.

EXCLUDE = [
"vllm/model_executor/models",

if weight_name not in name:
continue
name_mapped = name.replace(weight_name, param_name)
Expand Down Expand Up @@ -1673,7 +1672,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.make_empty_intermediate_tensors = ( # type: ignore[method-assign]
self.model.make_empty_intermediate_tensors
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_mtp import SharedHead
from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weight_name
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .deepseek_mtp import SharedHead
from .deepseek_v2 import get_spec_layer_idx_from_weight_name
from .deepseek_v4 import (
DeepseekV4DecoderLayer,
make_deepseek_v4_expert_params_mapping,
)
from .utils import maybe_prefix

logger = init_logger(__name__)

Expand All @@ -68,7 +68,7 @@ def __init__(
) -> None:
super().__init__()

config = vllm_config.speculative_config.draft_model_config.hf_config
config = vllm_config.speculative_config.draft_model_config.hf_config # type: ignore[union-attr]
Comment thread
WoosukKwon marked this conversation as resolved.
Outdated
self.config = config
quant_config = vllm_config.quant_config
self.rms_norm_eps = config.rms_norm_eps
Expand Down Expand Up @@ -407,7 +407,7 @@ def _find_mtp_layer_idx(name: str) -> int:
):
loaded_weight = loaded_weight.view(torch.uint8)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
param_name, weight_name, expert_id, shard_id = mapping # type: ignore[assignment]
if weight_name not in name:
continue
name_mapped = name.replace(weight_name, param_name)
Expand Down
Loading