vllm_ascend/spec_decode/eagle_proposer.py (50 changes: 40 additions & 10 deletions)

@@ -200,19 +200,49 @@ def load_model(self, model: nn.Module) -> None:
         )
         # If pp>1, the weights of mtp and the main model's embedding are not on the same device.
         # check if mtp model use main model's embedding and LMhead
-        if hasattr(model, "model") and hasattr(model.model, "embed_tokens") and \
-            torch.equal(self.model.model.embed_tokens.weight,
-                        model.model.embed_tokens.weight):
-            logger.info(
-                "The EAGLE head shares the same vocab embedding" \
-                " with the target model."
-            )
-            self.model.model.embed_tokens = target_embed_tokens
+        share_embeddings = False
+        if hasattr(self.model, "has_own_embed_tokens"):
+            # EAGLE model
+            if not self.model.has_own_embed_tokens:
+                share_embeddings = True
+                logger.info(
+                    "Detected EAGLE model without its own embed_tokens in the"
+                    " checkpoint. Sharing target model embedding weights with the"
+                    " draft model."
+                )
+            elif (
+                isinstance(target_embed_tokens.weight, torch.Tensor)
+                and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
+                # TODO: Offload to CPU for comparison to avoid extra NPU memory
+                # usage in CI testing environments with limited NPU memory
+                and torch.equal(
+                    target_embed_tokens.weight.cpu(),
+                    self.model.model.embed_tokens.weight.cpu(),
+                )
+            ):
+                share_embeddings = True
+                logger.info(
+                    "Detected EAGLE model with embed_tokens identical to the target"
+                    " model. Sharing target model embedding weights with the draft"
+                    " model."
+                )
+            else:
+                logger.info(
+                    "Detected EAGLE model with distinct embed_tokens weights. "
+                    "Keeping separate embedding weights from the target model."
+                )
         else:
+            # MTP model
+            share_embeddings = True

Comment thread (resolved):
Contributor: Please follow the original logic.
Contributor: Oops, it's my mistake.
Contributor: LGTM.

             logger.info(
-                " The EAGLE head loaded its own vocab embedding" \
-                " weights instead of sharing them with the target model."
+                "Detected MTP model. "
+                "Sharing target model embedding weights with the draft model."
             )

+        if share_embeddings:
+            if hasattr(self.model.model, "embed_tokens"):
+                del self.model.model.embed_tokens
+            self.model.model.embed_tokens = target_embed_tokens
+        else:
+            logger.info(
+                "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
(hunk truncated; remaining lines collapsed in the diff view)
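
For readers skimming the diff, the new logic boils down to: decide whether the draft (EAGLE/MTP) model can reuse the target model's vocab embedding, and if so, alias the target's module instead of keeping a second copy. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; `Target` and `Draft` are hypothetical stand-ins, not vllm-ascend classes, and the CPU-side comparison mirrors the TODO in the diff.

import torch
import torch.nn as nn

class Target(nn.Module):
    def __init__(self, vocab=32, dim=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)

class Draft(nn.Module):
    def __init__(self, vocab=32, dim=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)

target, draft = Target(), Draft()
# Pretend the draft checkpoint shipped the same embedding weights.
draft.embed_tokens.weight.data.copy_(target.embed_tokens.weight.data)

# Compare on CPU so the equality check allocates no extra device (NPU)
# memory, as the TODO in the diff suggests; .cpu() makes host-side copies.
same = torch.equal(target.embed_tokens.weight.cpu(),
                   draft.embed_tokens.weight.cpu())

if same:
    # Drop the draft's private copy and alias the target's module, so both
    # models read from one set of weights.
    del draft.embed_tokens
    draft.embed_tokens = target.embed_tokens

assert draft.embed_tokens is target.embed_tokens

Aliasing rather than copying means both models reference the same Parameter, so the vocab embedding occupies device memory only once.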