41 changes: 35 additions & 6 deletions vllm_ascend/spec_decode/eagle_proposer.py
@@ -111,7 +111,24 @@ def load_model(self, model: nn.Module) -> None:
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
if hasattr(model.model, "embed_tokens"):
target_embed_tokens = model.model.embed_tokens
elif hasattr(model.model, "embedding"):
target_embed_tokens = model.model.embedding
else:
raise AttributeError(
"Target model does not have 'embed_tokens' or 'embedding' attribute"
)
# share embed_tokens with the target model if needed
if not self.model.has_own_embed_tokens:
logger.info("Draft model embed_tokens are uninitialized. "
"Sharing vocab embedding with the target model.")
del self.model.model.embed_tokens
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
"Draft model embed_tokens are already initialized. "
"Keeping separate vocab embedding from the target model.")
else:
logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
@@ -121,12 +138,24 @@ def load_model(self, model: nn.Module) -> None:
# share lm_head with the target model if needed
        # some model definitions do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.")
if supports_multimodal(model):
self.model.lm_head = model.get_language_model().lm_head
else:
if self.name == SpecDcodeType.EAGLE:
if hasattr(model, "lm_head"):
logger.info(
"Loading EAGLE LM head weights from the target model.")
if supports_multimodal(model):
self.model.lm_head = model.get_language_model().lm_head
else:
self.model.lm_head = model.lm_head
else:
if (hasattr(model, "lm_head") and hasattr(self.model, "lm_head")
and not self.model.has_own_lm_head):
logger.info("Draft model lm_head is uninitialized. "
"Sharing lm_head with the target model.")
del self.model.lm_head
self.model.lm_head = model.lm_head
else:
logger.info("Draft model lm_head is already initialized. "
"Keeping separate lm_head from the target model.")

@torch.inference_mode()
def dummy_run(self,
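
For orientation, here is a minimal, self-contained sketch of the two sharing rules the hunks above introduce. It is not the PR's code and not vllm-ascend API: the flags has_own_embed_tokens / has_own_lm_head mirror the draft-model attributes used in the diff, while the helper names and the toy modules are hypothetical.

import torch.nn as nn


def resolve_target_embedding(target: nn.Module) -> nn.Module:
    # The diff accepts either attribute name because some target model
    # definitions expose the vocab embedding as `embedding` rather than
    # `embed_tokens`.
    if hasattr(target, "embed_tokens"):
        return target.embed_tokens
    if hasattr(target, "embedding"):
        return target.embedding
    raise AttributeError(
        "Target model does not have 'embed_tokens' or 'embedding' attribute")


def maybe_share_embedding(draft: nn.Module, target: nn.Module,
                          has_own_embed_tokens: bool) -> None:
    # Share the target's vocab embedding only when the draft checkpoint did
    # not ship its own weights; otherwise keep the draft's embedding separate.
    if not has_own_embed_tokens:
        draft.embed_tokens = resolve_target_embedding(target)


def maybe_share_lm_head(draft: nn.Module, target: nn.Module,
                        has_own_lm_head: bool) -> None:
    # Same rule for the output projection: reuse the target's lm_head only
    # when both models expose one and the draft's is uninitialized.
    if (hasattr(target, "lm_head") and hasattr(draft, "lm_head")
            and not has_own_lm_head):
        draft.lm_head = target.lm_head


# Toy usage: a target whose vocab embedding is named `embedding`.
target = nn.Module()
target.embedding = nn.Embedding(1000, 64)
target.lm_head = nn.Linear(64, 1000, bias=False)

draft = nn.Module()
draft.embed_tokens = nn.Embedding(1000, 64)  # placeholder weights
draft.lm_head = nn.Linear(64, 1000, bias=False)

maybe_share_embedding(draft, target, has_own_embed_tokens=False)
maybe_share_lm_head(draft, target, has_own_lm_head=False)
assert draft.embed_tokens is target.embedding
assert draft.lm_head is target.lm_head

In the PR itself the vocab embedding lives on the draft's inner module (self.model.model) and the existing attribute is deleted before being re-pointed at the target's; the sketch flattens that detail for brevity.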