Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions vllm/models/deepseek_v4/amd/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
MHCPreOp,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
Expand All @@ -50,6 +49,7 @@
DeepseekV4Indexer,
DeepseekV4MLA,
)
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_tilelang
Expand Down Expand Up @@ -314,25 +314,12 @@ def __init__(
self.rope_parameters = config.rope_scaling

# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
rope_parameters = config.rope_parameters
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
rope_parameters["mscale"] = 0 # Disable mscale
rope_parameters["mscale_all_dim"] = 0 # Disable mscale
rope_parameters["is_deepseek_v4"] = True
rope_parameters["rope_dim"] = self.rope_head_dim
self.rotary_emb = get_rope(
self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
self.rotary_emb = build_deepseek_v4_rope(
config,
head_dim=self.head_dim,
rope_head_dim=self.rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
compress_ratio=self.compress_ratio,
)

self.indexer = None
Expand Down
36 changes: 36 additions & 0 deletions vllm/models/deepseek_v4/common/rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""DeepseekV4 rotary embedding initialization."""

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding


def build_deepseek_v4_rope(
    config,
    *,
    head_dim: int,
    rope_head_dim: int,
    max_position_embeddings: int,
    compress_ratio: int,
) -> RotaryEmbedding:
    """Build the rotary embedding for a DeepseekV4 layer.

    The rope settings are taken from ``config.rope_parameters``, adjusted
    for DeepseekV4, and handed to ``get_rope``.

    NOTE(review): the adjustments are written directly into
    ``config.rope_parameters`` (no copy is made here), so they stay on that
    mapping after the call -- confirm callers are fine with this.

    Args:
        config: Model config; must expose ``rope_parameters``,
            ``rope_theta`` and ``compress_rope_theta``.
        head_dim: Head size forwarded to ``get_rope``.
        rope_head_dim: Value stored under the ``"rope_dim"`` key.
        max_position_embeddings: Forwarded to ``get_rope`` as
            ``max_position``.
        compress_ratio: When greater than 1, ``config.compress_rope_theta``
            is used as the rope theta instead of ``config.rope_theta``.

    Returns:
        The object returned by ``get_rope``.
    """
    params = config.rope_parameters

    # Pick the theta that matches this layer's compression setting.
    if compress_ratio > 1:
        params["rope_theta"] = config.compress_rope_theta
    else:
        params["rope_theta"] = config.rope_theta

    # Any non-default rope type is remapped to one of the DeepSeek variants.
    if params["rope_type"] != "default":
        if params.get("apply_yarn_scaling", True):
            params["rope_type"] = "deepseek_yarn"
        else:
            params["rope_type"] = "deepseek_llama_scaling"

    fixed_entries = (
        ("mscale", 0),  # Disable mscale
        ("mscale_all_dim", 0),  # Disable mscale
        ("is_deepseek_v4", True),
        ("rope_dim", rope_head_dim),
    )
    for key, value in fixed_entries:
        params[key] = value

    return get_rope(
        head_dim,
        max_position=max_position_embeddings,
        rope_parameters=params,
        is_neox_style=False,
    )
27 changes: 7 additions & 20 deletions vllm/models/deepseek_v4/nvidia/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
Expand All @@ -56,6 +55,7 @@
DeepseekV4Indexer,
DeepseekV4MLA,
)
from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.sequence import IntermediateTensors

Expand Down Expand Up @@ -697,25 +697,12 @@ def __init__(
self.rope_parameters = config.rope_scaling

# Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it)
rope_parameters = config.rope_parameters
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
rope_parameters["mscale"] = 0 # Disable mscale
rope_parameters["mscale_all_dim"] = 0 # Disable mscale
rope_parameters["is_deepseek_v4"] = True
rope_parameters["rope_dim"] = self.rope_head_dim
self.rotary_emb = get_rope(
self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
self.rotary_emb = build_deepseek_v4_rope(
config,
head_dim=self.head_dim,
rope_head_dim=self.rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
compress_ratio=self.compress_ratio,
)

self.indexer = None
Expand Down
Loading