9 changes: 7 additions & 2 deletions tests/ut/spec_decode/test_eagle_proposer.py
@@ -156,14 +156,19 @@ def test_load_model_pp1(self, mock_pp_group, mock_get_model,
"layer3": mock_draft_layer3
}]

weight = torch.zeros(0)

mock_model = MagicMock()
mock_model.supports_multimodal = False
mock_model.model.embed_tokens = MagicMock()
mock_model.lm_head = MagicMock()
mock_model.multimodal_cpu_fields = None
mock_model.merge_by_field_config = None
mock_get_model.return_value = MagicMock()
mock_model.model.embed_tokens = MagicMock()
mock_model.model.embed_tokens.weight = weight

self.proposer.name = SpecDcodeType.EAGLE
mock_get_model.return_value = MagicMock()
mock_get_model.return_value.model.embed_tokens.weight = weight

self.proposer.load_model(mock_model)
mock_get_model.assert_called_once()
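
For reference, the mocking pattern above reduces to a self-contained sketch. Everything here is inferred from this diff alone: load_model compares the draft head's embed_tokens.weight against the target model's, so the test points both mocks at the same torch.zeros(0) tensor to drive the weight-sharing path (mock_draft below is a hypothetical stand-in for the object mock_get_model returns):

    from unittest.mock import MagicMock

    import torch

    # Empty tensor: torch.equal(w, w) is trivially True, which is enough to
    # exercise the sharing branch without allocating a real vocab embedding.
    weight = torch.zeros(0)

    # Stand-in for the target model passed to load_model().
    mock_model = MagicMock()
    mock_model.model.embed_tokens = MagicMock()
    mock_model.model.embed_tokens.weight = weight

    # Stand-in for the draft model that get_model() would return.
    mock_draft = MagicMock()
    mock_draft.model.embed_tokens.weight = weight

    # The condition load_model() evaluates before reusing the target embedding.
    assert torch.equal(mock_model.model.embed_tokens.weight,
                       mock_draft.model.embed_tokens.weight)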
23 changes: 8 additions & 15 deletions vllm_ascend/spec_decode/eagle_proposer.py
@@ -177,28 +177,21 @@ def load_model(self, model: nn.Module) -> None:
         raise AttributeError(
             "Target model does not have 'embed_tokens' or 'embedding' attribute"
         )
-        if self.method == "mtp":
-            if self.vllm_config.model_config.is_deepseek_mla and \
+        # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
+        # check if mtp model use main model's embedding and LMhead
+        if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \
+                torch.equal(self.model.model.embed_tokens.weight,
+                            model.model.embed_tokens.weight):
-            # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
-            # check if mtp model use main model's embedding and LMhead
-            logger.info(
-                "The MTP head shares the same vocab embedding" \
-                " with the target model."
-            )
-            self.model.model.embed_tokens = target_embed_tokens
-        else:
-            logger.info(
-                " The MTP head loaded its own vocab embedding" \
-                " weights instead of sharing them with the target model."
-            )
-        else:
             logger.info(
                 "The EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             self.model.model.embed_tokens = target_embed_tokens
         else:
             logger.info(
                 " The EAGLE head loaded its own vocab embedding" \
                 " weights instead of sharing them with the target model."
             )
     else:
         logger.info(
             "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \