diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 885fffea868a..fb724fbe2f1e 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -30,7 +30,6 @@ MHCPreOp, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -50,6 +49,7 @@ DeepseekV4Indexer, DeepseekV4MLA, ) +from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.import_utils import has_tilelang @@ -314,25 +314,12 @@ def __init__( self.rope_parameters = config.rope_scaling # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) - rope_parameters = config.rope_parameters - rope_parameters["rope_theta"] = ( - config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta - ) - if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = ( - "deepseek_yarn" - if config.rope_parameters.get("apply_yarn_scaling", True) - else "deepseek_llama_scaling" - ) - rope_parameters["mscale"] = 0 # Disable mscale - rope_parameters["mscale_all_dim"] = 0 # Disable mscale - rope_parameters["is_deepseek_v4"] = True - rope_parameters["rope_dim"] = self.rope_head_dim - self.rotary_emb = get_rope( - self.head_dim, - max_position=self.max_position_embeddings, - rope_parameters=rope_parameters, - is_neox_style=False, + self.rotary_emb = build_deepseek_v4_rope( + config, + head_dim=self.head_dim, + rope_head_dim=self.rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + compress_ratio=self.compress_ratio, ) self.indexer = None diff --git a/vllm/models/deepseek_v4/common/rope.py b/vllm/models/deepseek_v4/common/rope.py new file mode 100644 index 000000000000..44ae3286eb21 --- /dev/null +++ b/vllm/models/deepseek_v4/common/rope.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""DeepseekV4 rotary embedding initialization.""" + +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding + + +def build_deepseek_v4_rope( + config, + *, + head_dim: int, + rope_head_dim: int, + max_position_embeddings: int, + compress_ratio: int, +) -> RotaryEmbedding: + rope_parameters = config.rope_parameters + rope_parameters["rope_theta"] = ( + config.compress_rope_theta if compress_ratio > 1 else config.rope_theta + ) + if rope_parameters["rope_type"] != "default": + rope_parameters["rope_type"] = ( + "deepseek_yarn" + if rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + rope_parameters["mscale"] = 0 # Disable mscale + rope_parameters["mscale_all_dim"] = 0 # Disable mscale + rope_parameters["is_deepseek_v4"] = True + rope_parameters["rope_dim"] = rope_head_dim + return get_rope( + head_dim, + max_position=max_position_embeddings, + rope_parameters=rope_parameters, + is_neox_style=False, + ) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index e26d9e593bec..13e58360c8b2 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -35,7 +35,6 @@ ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -56,6 +55,7 @@ DeepseekV4Indexer, DeepseekV4MLA, ) +from vllm.models.deepseek_v4.common.rope import build_deepseek_v4_rope from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.sequence import IntermediateTensors @@ -697,25 +697,12 @@ def __init__( self.rope_parameters = config.rope_scaling # Initialize rotary embedding BEFORE DeepseekV4MLA (which needs it) - rope_parameters = config.rope_parameters - rope_parameters["rope_theta"] = ( - config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta - ) - if config.rope_parameters["rope_type"] != "default": - config.rope_parameters["rope_type"] = ( - "deepseek_yarn" - if config.rope_parameters.get("apply_yarn_scaling", True) - else "deepseek_llama_scaling" - ) - rope_parameters["mscale"] = 0 # Disable mscale - rope_parameters["mscale_all_dim"] = 0 # Disable mscale - rope_parameters["is_deepseek_v4"] = True - rope_parameters["rope_dim"] = self.rope_head_dim - self.rotary_emb = get_rope( - self.head_dim, - max_position=self.max_position_embeddings, - rope_parameters=rope_parameters, - is_neox_style=False, + self.rotary_emb = build_deepseek_v4_rope( + config, + head_dim=self.head_dim, + rope_head_dim=self.rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + compress_ratio=self.compress_ratio, ) self.indexer = None