diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 2970f255537..e791031f4dd 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -181,6 +181,7 @@ def test_load_model_pp1(self, mock_pp_group, mock_get_model,
         weight = torch.zeros(0)
 
         mock_model = MagicMock()
+        mock_model.supports_multimodal = False
         mock_model.lm_head = MagicMock()
         mock_model.multimodal_cpu_fields = None
         mock_model.merge_by_field_config = None
@@ -193,7 +194,7 @@ def test_load_model_pp1(self, mock_pp_group, mock_get_model,
 
         self.proposer.load_model(mock_model)
         mock_get_model.assert_called_once()
-        self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
+        self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
 
         self.assertIs(self.proposer.model.model.embed_tokens,
                       mock_model.model.embed_tokens)
@@ -224,7 +225,7 @@ def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
 
         self.assertIsNot(self.proposer.model.model.embed_tokens,
                          mock_model.model.embed_tokens)
-        self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
+        self.assertEqual(self.proposer.attn_layer_names, ["layer2"])
 
     @patch(
         "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -254,7 +255,7 @@ def test_load_model_multimodal(self, mock_supports_multi, mock_pp_group,
         self.proposer.name = SpecDcodeType.EAGLE
         self.proposer.load_model(mock_model)
 
-        mock_model.get_language_model.assert_called_once()
+        self.assertEqual(mock_model.get_language_model.call_count, 2)
 
         self.assertIs(self.proposer.model.lm_head,
                       mock_model.get_language_model.return_value.lm_head)
@@ -435,4 +436,4 @@ def test_prepare_inputs(self):
                                torch.tensor([1, 2, 4]))):
             return_attn, indices = self.proposer.prepare_inputs(
                 mock_attn, num_rejected)
-        self.assertEqual(indices.tolist(), [1, 2, 4])
+        self.assertEqual(indices.tolist(), [1, 2, 4])
\ No newline at end of file
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 37a86e15efc..1d4bb4eb70b 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -197,31 +197,81 @@ def load_model(self, model: nn.Module) -> None:
         draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
         draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
         assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = list(sorted(draft_attn_layer_names))
-        self.attn_layer_names = self.attn_layer_name
+        self.attn_layer_names = list(sorted(draft_attn_layer_names))
+
+        if supports_multimodal(model):
+            # handle multimodality
+            if self.get_model_name(model) in [
+                    "Qwen2_5_VLForConditionalGeneration",
+                    "Qwen3VLForConditionalGeneration",
+            ]:
+                self.model.config.image_token_index = model.config.image_token_id
+            elif self.get_model_name(
+                    model) == "PixtralForConditionalGeneration":
+                self.model.config.image_token_index = (
+                    model.config.vision_config.image_token_id)
+            else:
+                self.model.config.image_token_index = (
+                    model.config.image_token_index)
+            target_language_model = model.get_language_model()
+        else:
+            target_language_model = model
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
-            # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
-            # check if mtp model use main model's embedding and LMhead
-            if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \
-                torch.equal(self.model.model.embed_tokens.weight,
-                            model.model.embed_tokens.weight):
-                logger.info(
-                    "The EAGLE head shares the same vocab embedding" \
-                    " with the target model."
+            if hasattr(target_language_model.model, "embed_tokens"):
+                target_embed_tokens = target_language_model.model.embed_tokens
+            elif hasattr(target_language_model.model, "embedding"):
+                target_embed_tokens = target_language_model.model.embedding
+            else:
+                raise AttributeError(
+                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                 )
-                self.model.model.embed_tokens = model.model.embed_tokens
+
+            share_embeddings = False
+            if hasattr(self.model, "has_own_embed_tokens"):
+                # EAGLE model
+                if not self.model.has_own_embed_tokens:
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model without its own embed_tokens in the"
+                        " checkpoint. Sharing target model embedding weights with the"
+                        " draft model.")
+                elif (isinstance(target_embed_tokens.weight, torch.Tensor)
+                      and isinstance(self.model.model.embed_tokens.weight,
+                                     torch.Tensor)
+                      # TODO: Offload to CPU for comparison to avoid extra NPU memory
+                      # usage in CI testing environments with limited NPU memory
+                      and torch.equal(
+                          target_embed_tokens.weight.cpu(),
+                          self.model.model.embed_tokens.weight.cpu(),
+                      )):
+                    share_embeddings = True
+                    logger.info(
+                        "Detected EAGLE model with embed_tokens identical to the target"
+                        " model. Sharing target model embedding weights with the draft"
+                        " model.")
+                else:
+                    logger.info(
+                        "Detected EAGLE model with distinct embed_tokens weights. "
+                        "Keeping separate embedding weights from the target model."
+                    )
             else:
+                # MTP model
+                share_embeddings = True
                 logger.info(
-                    " The EAGLE head loaded its own vocab embedding" \
-                    " weights instead of sharing them with the target model."
+                    "Detected MTP model. "
+                    "Sharing target model embedding weights with the draft model."
                 )
+
+            if share_embeddings:
+                if hasattr(self.model.model, "embed_tokens"):
+                    del self.model.model.embed_tokens
+                self.model.model.embed_tokens = target_embed_tokens
         else:
             logger.info(
-                "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
-                " weights instead of sharing them with the target model."
-            )
+                "The draft model's vocab embedding will be loaded separately"
+                " from the target model.")
 
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 9a2476f3456..3f126372f63 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -90,7 +90,7 @@ def dummy_run(self,
             attn_metadata_mtp = builder.build_for_graph_capture(
                 common_attn_metadata, attn_state)
             attn_metadata = {}
-            for layer_name in self.attn_layer_name:
+            for layer_name in self.attn_layer_names:
                 attn_metadata[layer_name] = attn_metadata_mtp
         else:
             attn_metadata = None
@@ -302,7 +302,7 @@ def _propose(
         attn_metadata_mtp = builder.build(0, common_attn_metadata,
                                           self.runner.get_model())
         attn_metadata = {}
-        for layer_name in self.attn_layer_name:
+        for layer_name in self.attn_layer_names:
            attn_metadata[layer_name] = attn_metadata_mtp
 
         for step in range(self.num_speculative_tokens):
@@ -331,7 +331,7 @@ def _propose(
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
                     hidden_states)
 
-            for layer_name in self.attn_layer_name:
+            for layer_name in self.attn_layer_names:
                 decode_metadata = getattr(attn_metadata[layer_name], "decode",
                                           None)
                 if self.use_async_scheduling and decode_metadata is not None:
@@ -402,7 +402,7 @@ def _propose(
             if step == self.num_speculative_tokens - 1 or with_prefill:
                 break
 
-            attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
+            attn_metadata_i = attn_metadata[self.attn_layer_names[0]]
 
             if step == 0:
                 positions = target_positions[last_token_indices]