diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index 020521611f33..615f67e9f8d8 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -105,6 +105,13 @@ def main():
     outputs = llm.generate(prompt_token_ids=prompt_ids,
                            sampling_params=sampling_params)
 
+    # print the generated text
+    for output in outputs:
+        print("-" * 50)
+        print(f"prompt: {output.prompt}")
+        print(f"generated text: {output.outputs[0].text}")
+        print("-" * 50)
+
     if not hasattr(outputs, "metrics") or outputs.metrics is None:
         return
 
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 76655bd71b15..4e51daa220e4 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -8,6 +8,7 @@
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,11 +53,15 @@ def __init__(
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.config.vocab_size,
-            self.config.hidden_size,
-            prefix=maybe_prefix(prefix, "embed_tokens"),
-        )
+
+        # if PP disabled then draft will share embed with target
+        if get_pp_group().world_size > 1:
+            self.embed_tokens = VocabParallelEmbedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                prefix=maybe_prefix(prefix, "embed_tokens"),
+            )
+
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 self.config,
@@ -109,6 +114,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+
+                # if PP disabled then draft will share embed with target
+                if get_pp_group().world_size == 1 and \
+                        "embed_tokens." in name:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -142,8 +153,7 @@ def forward(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
+            skip_prefixes=None,
         )
 
         model_weights = {}
@@ -151,5 +161,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             if "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
-
-        loader.load_weights(model_weights.items())
+        return loader.load_weights(model_weights.items())
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 904ff3210943..9761c8389db2 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -8,6 +8,7 @@
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -91,11 +92,15 @@ def __init__(
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            self.config.vocab_size,
-            self.config.hidden_size,
-            prefix=maybe_prefix(prefix, "embed_tokens"),
-        )
+
+        # if PP disabled then draft will share embed with target
+        if get_pp_group().world_size > 1:
+            self.embed_tokens = VocabParallelEmbedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                prefix=maybe_prefix(prefix, "embed_tokens"),
+            )
+
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 self.config,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 13cfcc4bbb6e..9fbdc8f848e7 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -5,6 +5,7 @@
 from vllm.attention.layer import Attention
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
@@ -302,12 +303,30 @@ def load_model(self, target_model: nn.Module) -> None:
         self.attn_layer_name = next(iter(draft_attn_layer_names))
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(draft_model_config, self.model))
-        if self.vllm_config.speculative_config.method == "eagle3":
-            if "model.embed_tokens.weight" not in loaded_weights:
-                logger.info(
-                    "Loading EAGLE embedding weights from the target model.")
-                self.model.model.embed_tokens = target_model.model.embed_tokens
+
+        # share embed_tokens with the target model if needed
+        if get_pp_group().world_size == 1:
+            assert "model.embed_tokens.weight" not in loaded_weights, \
+                "For PP = 1, Eagle draft should share embed with target model"
+            logger.info(
+                "The EAGLE head shares the same vocab embedding" \
+                " with the target model."
+            )
+            self.model.model.embed_tokens = target_model.model.embed_tokens
         else:
+            assert "model.embed_tokens.weight" in loaded_weights, \
+                "For PP > 1, Eagle draft checkpoint should have its own copy" \
+                " of the model.embed_tokens.weight"
+            logger.info(
+                "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
+                " weights instead of sharing them with the target model."
+            )
+
+        # share lm_head with the target model if needed
+        # some model definitions do not define lm_head explicitly
+        # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
+        if self.vllm_config.speculative_config.method != "eagle3" and \
+                hasattr(target_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
             self.model.lm_head = target_model.lm_head
 
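
The sketch below illustrates the rule these changes implement: when pipeline parallelism is disabled (PP world size == 1), the EAGLE draft head reuses the target model's embed_tokens module instead of allocating and loading its own copy, while with PP > 1 the draft keeps the embedding loaded from its own checkpoint. This is a minimal, self-contained example; TargetModel, DraftModel, and share_embeddings are illustrative stand-ins, not vLLM classes or APIs.

    # Minimal sketch only: plain nn.Module stand-ins, not vLLM classes.
    import torch.nn as nn

    class TargetModel(nn.Module):
        def __init__(self, vocab_size: int, hidden_size: int):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    class DraftModel(nn.Module):
        def __init__(self, vocab_size: int, hidden_size: int,
                     pp_world_size: int):
            super().__init__()
            # Mirrors the diff: only allocate a separate embedding when PP > 1;
            # for PP = 1 the attribute is attached later by sharing.
            if pp_world_size > 1:
                self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def share_embeddings(draft: DraftModel, target: TargetModel,
                         pp_world_size: int) -> None:
        if pp_world_size == 1:
            # PP disabled: reuse the target's embedding, no extra copy in memory.
            draft.embed_tokens = target.embed_tokens
        # PP > 1: the draft keeps the embedding loaded from its own checkpoint.

    target = TargetModel(vocab_size=128, hidden_size=16)
    draft = DraftModel(vocab_size=128, hidden_size=16, pp_world_size=1)
    share_embeddings(draft, target, pp_world_size=1)
    assert draft.embed_tokens is target.embed_tokens  # one shared module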