diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 74e2917806b..8d46338ae45 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -111,7 +111,24 @@ def load_model(self, model: nn.Module) -> None:
                 "The EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
-            self.model.model.embed_tokens = model.model.embed_tokens
+            if hasattr(model.model, "embed_tokens"):
+                target_embed_tokens = model.model.embed_tokens
+            elif hasattr(model.model, "embedding"):
+                target_embed_tokens = model.model.embedding
+            else:
+                raise AttributeError(
+                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
+                )
+            # share embed_tokens with the target model if needed
+            if not self.model.has_own_embed_tokens:
+                logger.info("Draft model embed_tokens are uninitialized. "
+                            "Sharing vocab embedding with the target model.")
+                del self.model.model.embed_tokens
+                self.model.model.embed_tokens = target_embed_tokens
+            else:
+                logger.info(
+                    "Draft model embed_tokens are already initialized. "
+                    "Keeping separate vocab embedding from the target model.")
         else:
             logger.info(
                 "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
@@ -121,12 +138,24 @@ def load_model(self, model: nn.Module) -> None:
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
-        if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"):
-            logger.info("Loading EAGLE LM head weights from the target model.")
-            if supports_multimodal(model):
-                self.model.lm_head = model.get_language_model().lm_head
-            else:
+        if self.name == SpecDcodeType.EAGLE:
+            if hasattr(model, "lm_head"):
+                logger.info(
+                    "Loading EAGLE LM head weights from the target model.")
+                if supports_multimodal(model):
+                    self.model.lm_head = model.get_language_model().lm_head
+                else:
+                    self.model.lm_head = model.lm_head
+        else:
+            if (hasattr(model, "lm_head") and hasattr(self.model, "lm_head")
+                    and not self.model.has_own_lm_head):
+                logger.info("Draft model lm_head is uninitialized. "
+                            "Sharing lm_head with the target model.")
+                del self.model.lm_head
                 self.model.lm_head = model.lm_head
+            else:
+                logger.info("Draft model lm_head is already initialized. "
+                            "Keeping separate lm_head from the target model.")
 
     @torch.inference_mode()
     def dummy_run(self,
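
For context, the patch above makes weight sharing conditional: the draft (EAGLE) model reuses the target model's embed_tokens / lm_head only when the draft checkpoint did not ship its own copies. The following is a minimal standalone sketch of that decision, not the vllm_ascend implementation; `TargetModel`, `DraftModel`, `has_own_embed_tokens`, and `has_own_lm_head` are hypothetical stand-ins (in the patch those flags are assumed to be set during weight loading).

```python
import torch.nn as nn


class TargetModel(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)


class DraftModel(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden_size: int = 16,
                 has_own_weights: bool = False):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Hypothetical flags: in the patch they would reflect whether the
        # draft checkpoint actually contained these tensors.
        self.has_own_embed_tokens = has_own_weights
        self.has_own_lm_head = has_own_weights


def share_weights(draft: DraftModel, target: TargetModel) -> None:
    """Reuse the target's modules when the draft has no weights of its own."""
    if not draft.has_own_embed_tokens:
        # Drop the uninitialized submodule first, then register the shared one.
        del draft.embed_tokens
        draft.embed_tokens = target.embed_tokens
    if not draft.has_own_lm_head:
        del draft.lm_head
        draft.lm_head = target.lm_head


target, draft = TargetModel(), DraftModel(has_own_weights=False)
share_weights(draft, target)
assert draft.embed_tokens is target.embed_tokens  # shared module, no copy
assert draft.lm_head is target.lm_head
```

When the draft checkpoint does provide its own embedding or head, the flags are true and the modules are left untouched, which is exactly the "keeping separate" branch logged in the patch.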