diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 86881376a106..44797874a4c5 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -83,6 +83,7 @@ maybe_remap_kv_scale_name, ) from vllm.model_executor.models.utils import ( + AutoWeightsLoader, extract_layer_index, sequence_parallel_chunk, ) @@ -1254,6 +1255,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.aux_hidden_state_layers = tuple[int, ...]() + # Needed by load_weights + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + self.use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + self.num_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts + ) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -1315,174 +1326,6 @@ def forward( return hidden_states, aux_hidden_states return hidden_states - -class DeepseekV2MixtureOfExperts(MixtureOfExperts): - moe_mlp_layers: list[DeepseekV2MoE] - """ - List of MoE MLP layers in the model. - """ - - def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None): - if example_moe is None: - self.num_moe_layers = 0 - self.num_expert_groups = 0 - self.num_logical_experts = 0 - self.num_physical_experts = 0 - self.num_local_physical_experts = 0 - self.num_routed_experts = 0 - self.num_shared_experts = 0 - self.num_redundant_experts = 0 - logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.") - else: - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = num_physical_experts - self.num_logical_experts - for moe in self.moe_mlp_layers: - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() - - -class DeepseekV2ForCausalLM( - nn.Module, - SupportsPP, - DeepseekV2MixtureOfExperts, - SupportsLoRA, - SupportsEagle, - SupportsEagle3, -): - packed_modules_mapping = { - "gate_up_proj": ["gate_proj", "up_proj"], - } - model_cls = DeepseekV2Model - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - - qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) - qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) - self.use_mha = config.model_type == "deepseek" or all( - dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) - ) - - if self.use_mha: - self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] - - # `packed_modules_mapping` needs to be modified before - # initializing DeepseekV2Model, as it is passed inplace to - # quantization config init and may be used to select the - # quant_method for relevant layers during initialization. - self.fuse_qkv_a_proj = ( - hasattr(config, "q_lora_rank") and config.q_lora_rank is not None - ) - if self.fuse_qkv_a_proj: - self.packed_modules_mapping["fused_qkv_a_proj"] = [ - "q_a_proj", - "kv_a_proj_with_mqa", - ] - - self.model = self.model_cls( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - # Set MoE hyperparameters - self.num_moe_layers = ( - self.config.num_hidden_layers - self.config.first_k_dense_replace - ) - self.set_moe_parameters() - - def set_moe_parameters(self): - self.expert_weights = [] - - self.num_expert_groups = getattr(self.config, "n_group", 1) - - self.moe_layers = [] - self.moe_mlp_layers = [] - example_moe = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): - # Pick last one layer since the first ones may be dense layers. - example_moe = layer.mlp - self.moe_mlp_layers.append(layer.mlp) - self.moe_layers.append(layer.mlp.experts) - - self.extract_moe_parameters(example_moe) - - def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: - self.model.aux_hidden_state_layers = layers - - def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: - num_layers = len(self.model.layers) - return (2, num_layers // 2, num_layers - 3) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - return fused_moe_make_expert_params_mapping( - self, - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, - num_redundant_experts=0, - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: rocm_aiter_moe_shared_expert_enabled = ( rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() @@ -1703,6 +1546,178 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params +class DeepseekV2MixtureOfExperts(MixtureOfExperts): + moe_mlp_layers: list[DeepseekV2MoE] + """ + List of MoE MLP layers in the model. + """ + + def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None): + if example_moe is None: + self.num_moe_layers = 0 + self.num_expert_groups = 0 + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.") + else: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for moe in self.moe_mlp_layers: + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + +class DeepseekV2ForCausalLM( + nn.Module, + SupportsPP, + DeepseekV2MixtureOfExperts, + SupportsLoRA, + SupportsEagle, + SupportsEagle3, +): + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + model_cls = DeepseekV2Model + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + self.use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + if self.use_mha: + self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = self.model_cls( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + # Set MoE hyperparameters + self.num_moe_layers = ( + self.config.num_hidden_layers - self.config.first_k_dense_replace + ) + self.set_moe_parameters() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.num_expert_groups = getattr(self.config, "n_group", 1) + + self.moe_layers = [] + self.moe_mlp_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + # Pick last one layer since the first ones may be dense layers. + example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) + self.moe_layers.append(layer.mlp.experts) + + self.extract_moe_parameters(example_moe) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return fused_moe_make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=0, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + class DeepseekForCausalLM(DeepseekV2ForCausalLM): pass @@ -1726,6 +1741,8 @@ def get_spec_layer_idx_from_weight_name( ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx + i}."): + if weight_name.startswith( + f"model.layers.{layer_idx + i}." + ) or weight_name.startswith(f"layers.{layer_idx + i}."): return layer_idx + i return None