From c531b0e527258a2798712891b6c57a048b9f51be Mon Sep 17 00:00:00 2001
From: morgendave
Date: Fri, 9 May 2025 16:44:17 -0700
Subject: [PATCH] eagle mm support, primarily llama4

Signed-off-by: morgendave
---
 examples/offline_inference/spec_decode.py  | 66 +++++++++++++++++++---
 tests/v1/e2e/test_spec_decode.py           | 61 ++++++++++++++------
 vllm/model_executor/models/llama4.py       |  1 +
 vllm/model_executor/models/llama4_eagle.py | 35 ++++++++++--
 vllm/model_executor/models/llama_eagle.py  |  6 ++
 vllm/model_executor/models/llama_eagle3.py |  5 ++
 vllm/v1/spec_decode/eagle.py               | 59 ++++++++++++++++---
 vllm/v1/worker/gpu_model_runner.py         | 10 +++-
 8 files changed, 206 insertions(+), 37 deletions(-)

diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index ce735f3b27df..184c30891eca 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -13,6 +13,38 @@
 from argparse import ArgumentParser as FlexibleArgumentParser
 
 
+QUESTION = "What is the content of each image?"
+IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
+]
+
+
+def get_custom_mm_prompts(num_prompts):
+    prompts = []
+    for url in IMAGE_URLS:
+        prompts.append(
+            [
+                {"type": "image_url", "image_url": {"url": url}},
+                {"type": "text", "text": QUESTION},
+            ]
+        )
+    if num_prompts > len(IMAGE_URLS):
+        prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
+
+    return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
+
+
 def parse_args():
     parser = FlexibleArgumentParser()
     add_dataset_parser(parser)
@@ -35,6 +67,7 @@ def parse_args():
     parser.add_argument("--output-len", type=int, default=256)
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
+    parser.add_argument("--custom-mm-prompts", action="store_true")
     return parser.parse_args()
 
 
@@ -44,14 +77,26 @@ def main():
 
     model_dir = args.model_dir
    if args.model_dir is None:
+        if args.custom_mm_prompts:
+            raise ValueError(
+                "custom_mm_prompts requires mm based models; "
+                "default llama3.1-8b-instruct is not mm based; "
+                "please specify model_dir to give a mm based model"
+            )
         model_dir = "meta-llama/Llama-3.1-8B-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-    prompts = get_samples(args, tokenizer)
-    # add_special_tokens is False to avoid adding bos twice when using chat templates
-    prompt_ids = [
-        tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
-    ]
+    args.custom_skip_chat_template = True
+
+    if not args.custom_mm_prompts:
+        prompts = get_samples(args, tokenizer)
+        # add_special_tokens is False to avoid adding bos twice
+        # when using chat templates
+        prompt_ids = [
+            tokenizer.encode(prompt.prompt, add_special_tokens=False)
+            for prompt in prompts
+        ]
+    else:
+        prompts = get_custom_mm_prompts(args.num_prompts)
 
     if args.method == "eagle" or args.method == "eagle3":
         eagle_dir = args.eagle_dir
@@ -85,10 +130,17 @@ def main():
         speculative_config=speculative_config,
         disable_log_stats=False,
         max_model_len=16384,
+        limit_mm_per_prompt={"image": 5},
+        disable_chunked_mm_input=True,
     )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
-    outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
+    if not args.custom_mm_prompts:
+        outputs = llm.generate(
+            prompt_token_ids=prompt_ids, sampling_params=sampling_params
+        )
+    else:
+        outputs = llm.chat(prompts, sampling_params=sampling_params)
 
     # print the generated text
     if args.print_output:
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 2423f966acfa..31f25e94c5b4 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -3,29 +3,34 @@
 from __future__ import annotations
 
 import random
-from typing import Any
+from typing import Any, Union
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 
 
-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
+    if mm_enabled:
+        prompt_types.append("mm")
     num_prompts = 100
     prompts = []
 
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+    print(f"Prompt types: {random_prompt_type_choices}")
 
     # Generate a mixed batch of prompts, some of which can be easily
     # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
+        prompt: Union[str, list[dict[str, Any]]] = ""
         if kind == "repeat":
             prompt = f"""
             please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@
             uses the word {word} at least once.
             give no other output than that simple sentence without quotes.
""" + elif kind == "mm": + placeholders = [{ + "type": "image_url", + "image_url": { + "url": + f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" + }, + }] + prompt = [ + *placeholders, + { + "type": "text", + "text": "The meaning of the image is" + }, + ] else: raise ValueError(f"Unknown prompt type: {kind}") prompts.append([{"role": "user", "content": prompt}]) @@ -57,7 +77,6 @@ def model_name(): def test_ngram_correctness( monkeypatch: pytest.MonkeyPatch, - test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_name: str, ): @@ -67,6 +86,7 @@ def test_ngram_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + test_prompts = get_test_prompts(mm_enabled=False) ref_llm = LLM(model=model_name, max_model_len=1024) ref_outputs = ref_llm.chat(test_prompts, sampling_config) @@ -103,23 +123,32 @@ def test_ngram_correctness( cleanup_dist_env_and_memory() -@pytest.mark.parametrize("model_setup", [ - ("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), - ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), -], - ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"]) +@pytest.mark.parametrize( + ["model_setup", "mm_enabled"], [ + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + True, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + ], + ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, - test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], + mm_enabled: bool, ): + # Generate test prompts inside the function instead of using fixture + test_prompts = get_test_prompts(mm_enabled) ''' Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. 
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index fab1c163ac28..7f9a8fdabdf3 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -256,6 +256,7 @@ def __init__(
         super().__init__()
 
         self.layer_idx = extract_layer_index(prefix)
+        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
         self.hidden_size = config.hidden_size
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
index 222ab5dfaee4..ece490ff2f2a 100644
--- a/vllm/model_executor/models/llama4_eagle.py
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -37,8 +37,9 @@
 from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
                                                Llama4ForCausalLM)
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.inputs import NestedTensors
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
@@ -78,15 +79,23 @@ def __init__(
         self.norm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        input_embeds = self.embed_tokens(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = self.fc(
-            torch.cat((input_embeds, hidden_states), dim=-1))
+            torch.cat((inputs_embeds, hidden_states), dim=-1))
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
@@ -190,8 +199,9 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.model(input_ids, positions, hidden_states)
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
@@ -212,3 +222,20 @@ def load_weights(self, weights: Iterable[tuple[str,
                 model_weights[name] = loaded_weight
 
         loader.load_weights(model_weights.items())
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                self.config.image_token_index,
+            )
+
+        return inputs_embeds
diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index c7690604c1d0..a4933b77e3a5 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -148,7 +149,12 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 7fc9fe2ebb6f..71275f0d5857 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -202,7 +202,12 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def compute_logits(
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 967847c02ff2..48eea9a70d37 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -51,6 +53,9 @@ def __init__(
         # hidden size (e.g., Llama 3.3 70B).
         self.hidden_size = self.draft_model_config.get_hidden_size()
 
+        self.is_multimodal_model = vllm_config.model_config \
+            .is_multimodal_model
+
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
@@ -76,6 +81,11 @@ def __init__(
                                    device=device,
                                    dtype=torch.int32)
 
+        self.inputs_embeds = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
+
     def propose(
         self,
         # [num_tokens]
@@ -88,6 +98,7 @@ def propose(
         next_token_ids: torch.Tensor,
         common_attn_metadata: CommonAttentionMetadata,
         sampling_metadata: SamplingMetadata,
+        mm_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
@@ -128,14 +139,27 @@ def propose(
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
+        if self.is_multimodal_model:
+            input_ids = self.input_ids[:num_tokens]
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids,
+                multimodal_embeddings=mm_embeds or None,
+            )
+            self.inputs_embeds[:num_tokens] = inputs_embeds
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            input_ids = None
+        else:
+            inputs_embeds = None
+            input_ids = self.input_ids[:num_input_tokens]
 
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
             ret_hidden_states = self.model(
-                self.input_ids[:num_input_tokens],
-                self.positions[:num_input_tokens],
-                self.hidden_states[:num_input_tokens],
+                input_ids=input_ids,
+                positions=self.positions[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
+                inputs_embeds=inputs_embeds,
             )
         if self.method == "deepseek_mtp":
             last_hidden_states = ret_hidden_states
@@ -218,15 +242,24 @@ def propose(
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
+            if self.is_multimodal_model:
+                inputs_embeds = self.model.get_input_embeddings(input_ids)
+                self.inputs_embeds[:batch_size] = inputs_embeds
+                inputs_embeds = self.inputs_embeds[:input_batch_size]
+                input_ids = None
+            else:
+                inputs_embeds = None
+                input_ids = self.input_ids[:input_batch_size]
 
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
                                      num_tokens=input_batch_size):
                 last_hidden_states, hidden_states = self.model(
-                    self.input_ids[:input_batch_size],
-                    self.positions[:input_batch_size],
-                    self.hidden_states[:input_batch_size],
+                    input_ids=input_ids,
+                    positions=self.positions[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
+                    inputs_embeds=inputs_embeds,
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -390,10 +423,18 @@ def dummy_run(
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
+
             self.model(
-                self.input_ids[:num_tokens],
-                self.positions[:num_tokens],
-                self.hidden_states[:num_tokens],
+                input_ids=input_ids,
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+                inputs_embeds=inputs_embeds,
             )
 
     def validate_same_kv_cache_group(self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a5bf197ba161..623493e93f52 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1108,13 +1108,15 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
     def _gather_mm_embeddings(
         self,
         scheduler_output: "SchedulerOutput",
+        shift_computed_tokens: int = 0,
     ) -> list[torch.Tensor]:
         mm_embeds: list[torch.Tensor] = []
         for req_id in self.input_batch.req_ids:
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
             req_state = self.requests[req_id]
-            num_computed_tokens = req_state.num_computed_tokens
+            num_computed_tokens = \
+                req_state.num_computed_tokens + shift_computed_tokens
             mm_positions = req_state.mm_positions
             for i, pos_info in enumerate(mm_positions):
                 start_pos = pos_info.offset
@@ -1734,6 +1736,11 @@ def propose_draft_token_ids(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_hidden_states = hidden_states[token_indices]
+            mm_embeds = None
+            if self.is_multimodal_model:
+                mm_embeds = self._gather_mm_embeddings(scheduler_output,
+                                                       shift_computed_tokens=1)
+
             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
                 next_token_ids=next_token_ids,
                 sampling_metadata=sampling_metadata,
                 common_attn_metadata=common_attn_metadata,
+                mm_embeds=mm_embeds,
             )
             spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
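
A quick way to exercise the new multimodal EAGLE path end to end is through the offline LLM API, along the lines of the updated example script. The sketch below is illustrative only: the target and draft checkpoints are taken from the llama4_eagle_mm test case above, while the speculative_config keys and the num_speculative_tokens value are assumptions that may need adjusting for your vLLM build.

    from vllm import LLM, SamplingParams

    # Target and draft models as in the llama4_eagle_mm test case above.
    llm = LLM(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        tensor_parallel_size=4,
        speculative_config={
            # Assumed config keys; check the speculative_config schema of your vLLM version.
            "method": "eagle",
            "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
            "num_speculative_tokens": 2,
        },
        max_model_len=16384,
        limit_mm_per_prompt={"image": 5},
        disable_chunked_mm_input=True,
    )

    # One image-plus-text conversation, mirroring get_custom_mm_prompts().
    messages = [[{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg"}},
            {"type": "text", "text": "What is the content of each image?"},
        ],
    }]]

    outputs = llm.chat(messages, sampling_params=SamplingParams(temperature=0, max_tokens=256))
    print(outputs[0].outputs[0].text)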