From 644e25419b101e24b053566a100ce49a49cc3e2c Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Fri, 25 Apr 2025 23:59:20 +0000 Subject: [PATCH 1/6] apply torch.compile & cudagraph to EAGLE Signed-off-by: Bryan Lu --- vllm/compilation/backends.py | 15 +++- vllm/model_executor/models/llama_eagle.py | 25 ++++--- vllm/v1/spec_decode/eagle.py | 87 ++++++++++++++++++----- vllm/v1/worker/gpu_model_runner.py | 36 ++++++++-- 4 files changed, 128 insertions(+), 35 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c493a764f56d..d8abeb74ee03 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -347,8 +347,12 @@ def configure_post_pass(self): PASS_KEY = "post_grad_custom_post_pass" if PASS_KEY in inductor_config: # Config should automatically wrap all inductor passes - assert isinstance(inductor_config[PASS_KEY], InductorPass) - self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + if isinstance(inductor_config[PASS_KEY], PostGradPassManager): + assert (inductor_config[PASS_KEY].uuid() == + self.post_grad_pass_manager.uuid()) + else: + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: @@ -404,8 +408,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: ) self.compilation_config.cache_dir = cache_dir - cache_dir = self.compilation_config.cache_dir + if compilation_counter.num_graphs_seen > 0: + cache_dir = self.compilation_config.cache_dir + \ + f'-{compilation_counter.num_graphs_seen}' + else: + cache_dir = self.compilation_config.cache_dir os.makedirs(cache_dir, exist_ok=True) + self.compilation_config.cache_dir = cache_dir rank = vllm_config.parallel_config.rank dp_rank = vllm_config.parallel_config.data_parallel_rank local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 28ad6128c4f1..e42791168530 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -6,7 +6,8 @@ import torch.nn as nn from transformers import LlamaConfig -from vllm.config import ModelConfig +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -37,17 +38,19 @@ def __init__( self.input_layernorm = nn.Identity() +@support_torch_compile class LlamaModel(nn.Module): def __init__( self, *, - model_config: ModelConfig, - start_layer_id: int = 0, + vllm_config: VllmConfig, prefix: str = "", + start_layer_id: int = 0, ) -> None: super().__init__() - self.config = model_config.hf_config + self.config = vllm_config. 
\ + speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( self.config.vocab_size, @@ -75,8 +78,7 @@ def forward( hidden_states = self.fc( torch.cat((input_embeds, hidden_states), dim=-1)) residual = None - for i in range(len(self.layers)): - layer = self.layers[i] + for layer in self.layers: hidden_states, residual = layer( positions, hidden_states, @@ -116,12 +118,13 @@ def load_weights(self, weights: Iterable[Tuple[str, class EagleLlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): + def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0): nn.Module.__init__(self) - self.config = model_config.hf_config - self.model = LlamaModel(model_config=model_config, - start_layer_id=start_layer_id, - prefix="model") + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + self.model = LlamaModel(vllm_config=vllm_config, + prefix="model", + start_layer_id=start_layer_id) logit_scale = getattr(self.config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.config.vocab_size, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 95f0c067d406..17a623a56025 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.utils import set_default_torch_dtype @@ -27,6 +27,32 @@ def __init__( vllm_config.speculative_config.num_speculative_tokens) self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size + + self.dtype = vllm_config.model_config.dtype + + self.max_num_tokens = vllm_config.scheduler_config \ + .max_num_batched_tokens + self.hidden_size = vllm_config.model_config.get_hidden_size() + + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not self.vllm_config.model_config.enforce_eager) + + self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # persistent buffers for cuda graph + self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=device) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. 
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + @@ -34,6 +60,16 @@ def __init__( device=device, dtype=torch.int32) + def copy_model_inputs_to_buffer( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> None: + self.input_ids[:input_ids.size(0)].copy_(input_ids) + self.positions[:positions.size(0)].copy_(positions) + self.hidden_states[:hidden_states.size(0)].copy_(hidden_states) + def propose( self, # [num_tokens] @@ -51,18 +87,18 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, + num_actual_draft_tokens: int, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 - input_ids = torch.empty_like(target_token_ids) # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - input_ids[:-1] = target_token_ids[1:] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - input_ids[last_token_indices] = next_token_ids + self.input_ids[last_token_indices] = next_token_ids # FA requires seq_len to have dtype int32. seq_lens = (target_positions[last_token_indices] + 1).int() @@ -71,7 +107,7 @@ def propose( max_seq_len = seq_lens.max().item() max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_tokens, + num_actual_tokens=num_actual_draft_tokens, max_query_len=max_num_tokens, query_start_loc=cu_num_tokens, max_seq_len=max_seq_len, @@ -85,12 +121,17 @@ def propose( prefix_kv_lens=None, suffix_kv_lens=None, ) + attn_metadata.num_input_tokens = num_tokens + + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions + self.hidden_states[:num_tokens] = target_hidden_states with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( - input_ids=input_ids, - hidden_states=target_hidden_states, - positions=target_positions, + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -106,12 +147,22 @@ def propose( positions = target_positions[last_token_indices] hidden_states = sample_hidden_states + + if self.use_cuda_graph and \ + batch_size <= self.cudagraph_batch_sizes[-1]: + padded_batch_size = self.vllm_config. \ + pad_for_cudagraph(batch_size) + else: + padded_batch_size = batch_size attn_metadata.num_actual_tokens = batch_size + attn_metadata.num_input_tokens = padded_batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] for _ in range(self.num_speculative_tokens - 1): # Update the inputs. - input_ids = draft_token_ids_list[-1] + # cast to int32 is crucial when eagle model is compiled. + # tensor.argmax() returns int64 by default. 
+ input_ids = draft_token_ids_list[-1].int() positions += 1 # NOTE(woosuk): We should handle the case where the draft model @@ -149,13 +200,19 @@ def propose( attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + self.hidden_states[:batch_size] = hidden_states + # Run the model. with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( - input_ids=input_ids, - hidden_states=hidden_states, - positions=clamped_positions, + input_ids=self.input_ids[:padded_batch_size], + positions=self.positions[:padded_batch_size], + hidden_states=self.hidden_states[:padded_batch_size], ) + hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(hidden_states, None) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -222,13 +279,11 @@ def load_model(self, target_model: nn.Module) -> None: draft_model_config.dtype), set_current_vllm_config( self.vllm_config): self.model = EagleLlamaForCausalLM( - model_config=draft_model_config, + vllm_config=self.vllm_config, start_layer_id=target_layer_num).to(target_device) self.model.load_weights( - loader.get_all_weights( - self.vllm_config.speculative_config.draft_model_config, - self.model)) + loader.get_all_weights(draft_model_config, self.model)) self.model.lm_head = target_model.lm_head diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 86f6a301fbb6..06abe9a8df66 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1089,7 +1089,6 @@ def execute_model( # For mid-pipeline stages, return the hidden states. return hidden_states - hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -1155,7 +1154,7 @@ def execute_model( # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states, + hidden_states[:num_scheduled_tokens], scheduler_output, ) @@ -1208,9 +1207,10 @@ def execute_model( # We need to slice token_ids, positions, and hidden_states # because the eagle head does not use cuda graph and should # not include padding. - target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] + num_actual_draft_tokens = num_scheduled_tokens + target_token_ids = self.input_ids[:num_input_tokens] + target_positions = positions[:num_input_tokens] + target_hidden_states = hidden_states[:num_input_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1229,6 +1229,19 @@ def execute_model( attn_metadata.query_start_loc, num_rejected_tokens, ) + + num_actual_draft_tokens = len(token_indices) + if self.use_cuda_graph and \ + num_actual_draft_tokens <= self.cudagraph_batch_sizes[-1]: + num_padded_draft_tokens = self.vllm_config. 
\ + pad_for_cudagraph(num_actual_draft_tokens) + + if num_padded_draft_tokens > num_actual_draft_tokens: + token_indices = torch.cat(( + token_indices, + token_indices[-1].repeat(num_padded_draft_tokens - + num_actual_draft_tokens))) + target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] target_hidden_states = hidden_states[token_indices] @@ -1243,6 +1256,7 @@ def execute_model( cu_num_tokens=cu_num_tokens, block_table=attn_metadata.block_table, sampling_metadata=sampling_metadata, + num_actual_draft_tokens=num_actual_draft_tokens, ) spec_token_ids = draft_token_ids.tolist() @@ -1470,6 +1484,18 @@ def _dummy_run( inputs_embeds=inputs_embeds, ) + if self.use_spec_decode and \ + self.speculative_config.method == 'eagle': + assert isinstance(self.drafter, EagleProposer) + with set_forward_context(None, + self.drafter.vllm_config, + num_tokens=num_tokens): + self.drafter.model( + input_ids=self.drafter.input_ids[:num_tokens], + positions=self.drafter.positions[:num_tokens], + hidden_states=self.drafter.hidden_states[:num_tokens], + ) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] From 262356d2620c09671875867e60d3ee8ecdb604bd Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Sat, 26 Apr 2025 00:19:00 +0000 Subject: [PATCH 2/6] remove redundant code Signed-off-by: Bryan Lu --- vllm/v1/spec_decode/eagle.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 17a623a56025..493edbd09d9c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -60,16 +60,6 @@ def __init__( device=device, dtype=torch.int32) - def copy_model_inputs_to_buffer( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> None: - self.input_ids[:input_ids.size(0)].copy_(input_ids) - self.positions[:positions.size(0)].copy_(positions) - self.hidden_states[:hidden_states.size(0)].copy_(hidden_states) - def propose( self, # [num_tokens] From b6a5c3d679464541291133765fb20c1e6b21d439 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Sat, 26 Apr 2025 06:59:52 +0000 Subject: [PATCH 3/6] remove outdated comments Signed-off-by: Bryan Lu --- vllm/v1/worker/gpu_model_runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ca71bbaaeaed..aaa3f881745b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1215,9 +1215,6 @@ def execute_model( if spec_decode_metadata is None: # input_ids can be None for multimodal models. - # We need to slice token_ids, positions, and hidden_states - # because the eagle head does not use cuda graph and should - # not include padding. 
num_actual_draft_tokens = num_scheduled_tokens target_token_ids = self.input_ids[:num_input_tokens] target_positions = positions[:num_input_tokens] From 7bea2d11fdfe73dc7a777327c853b0038ea4a2a6 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Mon, 28 Apr 2025 22:43:06 +0000 Subject: [PATCH 4/6] avoid pad token indices Signed-off-by: Bryan Lu --- vllm/v1/spec_decode/eagle.py | 33 ++++++++++++++++-------------- vllm/v1/worker/gpu_model_runner.py | 23 ++++----------------- 2 files changed, 22 insertions(+), 34 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index efd336c075ea..b4742e600743 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -86,7 +86,6 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - num_actual_draft_tokens: int, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -106,7 +105,7 @@ def propose( max_seq_len = seq_lens.max().item() max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_actual_draft_tokens, + num_actual_tokens=num_tokens, max_query_len=max_num_tokens, query_start_loc=cu_num_tokens, max_seq_len=max_seq_len, @@ -120,23 +119,28 @@ def propose( prefix_kv_lens=None, suffix_kv_lens=None, ) - attn_metadata.num_input_tokens = num_tokens + if self.use_cuda_graph and \ + num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + num_input_tokens = num_tokens + attn_metadata.num_input_tokens = num_input_tokens # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions if self.method == 'eagle': self.hidden_states[:num_tokens] = target_hidden_states - hidden_states = self.hidden_states[:num_tokens] + hidden_states = self.hidden_states else: # TODO: make eagle3 compatible with cuda graph - hidden_states = target_hidden_states[:num_tokens] + hidden_states = target_hidden_states with set_forward_context(attn_metadata, self.vllm_config): last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], - hidden_states=hidden_states, + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + hidden_states=hidden_states[:num_input_tokens], ) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -155,13 +159,12 @@ def propose( if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: - padded_batch_size = self.vllm_config. \ - pad_for_cudagraph(batch_size) + input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: - padded_batch_size = batch_size + input_batch_size = batch_size attn_metadata.num_actual_tokens = batch_size - attn_metadata.num_input_tokens = padded_batch_size + attn_metadata.num_input_tokens = input_batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] for _ in range(self.num_speculative_tokens - 1): @@ -218,9 +221,9 @@ def propose( # Run the model. 
with set_forward_context(attn_metadata, self.vllm_config): last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:padded_batch_size], - positions=self.positions[:padded_batch_size], - hidden_states=hidden_states[:padded_batch_size], + input_ids=self.input_ids[:input_batch_size], + positions=self.positions[:input_batch_size], + hidden_states=hidden_states[:input_batch_size], ) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aaa3f881745b..04556fdb5028 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1215,15 +1215,14 @@ def execute_model( if spec_decode_metadata is None: # input_ids can be None for multimodal models. - num_actual_draft_tokens = num_scheduled_tokens - target_token_ids = self.input_ids[:num_input_tokens] - target_positions = positions[:num_input_tokens] + target_token_ids = self.input_ids[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( - [h[:num_input_tokens] for h in aux_hidden_states], + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1) else: - target_hidden_states = hidden_states[:num_input_tokens] + target_hidden_states = hidden_states[:num_scheduled_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: @@ -1242,19 +1241,6 @@ def execute_model( attn_metadata.query_start_loc, num_rejected_tokens, ) - - num_actual_draft_tokens = len(token_indices) - if self.use_cuda_graph and \ - num_actual_draft_tokens <= self.cudagraph_batch_sizes[-1]: - num_padded_draft_tokens = self.vllm_config. 
\ - pad_for_cudagraph(num_actual_draft_tokens) - - if num_padded_draft_tokens > num_actual_draft_tokens: - token_indices = torch.cat(( - token_indices, - token_indices[-1].repeat(num_padded_draft_tokens - - num_actual_draft_tokens))) - target_token_ids = self.input_ids[token_indices] target_positions = positions[token_indices] if self.use_aux_hidden_state_outputs: @@ -1273,7 +1259,6 @@ def execute_model( cu_num_tokens=cu_num_tokens, block_table=attn_metadata.block_table, sampling_metadata=sampling_metadata, - num_actual_draft_tokens=num_actual_draft_tokens, ) spec_token_ids = draft_token_ids.tolist() From a49c3e4af3d3079dc4ca670f7ed5a0ec8c618965 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Mon, 28 Apr 2025 22:56:50 +0000 Subject: [PATCH 5/6] minor Signed-off-by: Bryan Lu --- vllm/v1/spec_decode/eagle.py | 4 ---- vllm/v1/worker/gpu_model_runner.py | 1 - 2 files changed, 5 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b4742e600743..6c030c26bd38 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -125,7 +125,6 @@ def propose( else: num_input_tokens = num_tokens attn_metadata.num_input_tokens = num_input_tokens - # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions @@ -156,13 +155,11 @@ def propose( positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] - if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size attn_metadata.num_input_tokens = input_batch_size attn_metadata.max_query_len = 1 @@ -304,7 +301,6 @@ def load_model(self, target_model: nn.Module) -> None: loaded_weights = self.model.load_weights( loader.get_all_weights(draft_model_config, self.model)) - if self.vllm_config.speculative_config.method == "eagle3": if "model.embed_tokens.weight" not in loaded_weights: logger.info( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 04556fdb5028..1f6db24fb2ed 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1488,7 +1488,6 @@ def _dummy_run( intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) - if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: From 784d1b623250d9650fe6d117664d4aaafc68ab32 Mon Sep 17 00:00:00 2001 From: Bryan Lu Date: Mon, 28 Apr 2025 23:59:38 +0000 Subject: [PATCH 6/6] dummy_run() method for eagle proposer Signed-off-by: Bryan Lu --- vllm/v1/spec_decode/eagle.py | 14 ++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 11 ++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6c030c26bd38..b60df6d3b54b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -310,6 +310,20 @@ def load_model(self, target_model: nn.Module) -> None: logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_model.lm_head + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + ) -> None: + with set_forward_context(None, self.vllm_config, + num_tokens=num_tokens): + if self.method == 'eagle': + self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + ) + # NOTE(woosuk): Currently, the below code is not used and we 
always use argmax # to sample the draft tokens. We will use this after we find a way to manage diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1f6db24fb2ed..c5d38c25cc07 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1494,16 +1494,9 @@ def _dummy_run( hidden_states = outputs if self.use_spec_decode and \ - self.speculative_config.method == 'eagle': + self.speculative_config.method in ('eagle', 'eagle3'): assert isinstance(self.drafter, EagleProposer) - with set_forward_context(None, - self.drafter.vllm_config, - num_tokens=num_tokens): - self.drafter.model( - input_ids=self.drafter.input_ids[:num_tokens], - positions=self.drafter.positions[:num_tokens], - hidden_states=self.drafter.hidden_states[:num_tokens], - ) + self.drafter.dummy_run(num_tokens) logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices]
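
A note on the pattern these patches apply: CUDA graph replay requires that the model always read its inputs from tensors with fixed addresses and run at one of a fixed set of captured batch sizes. The proposer therefore pre-allocates persistent input_ids / positions / hidden_states buffers, copies the live inputs into a prefix of them, and pads the launch size up to the nearest captured size via pad_for_cudagraph(); the outputs are then sliced back to the real batch size (hidden_states[:batch_size]). The sketch below illustrates that pattern in isolation. It is not vLLM code: the buffer dtypes, the capture-size list, the rounding logic inside pad_for_cudagraph(), and the prepare() helper are illustrative assumptions made for this example.

    # Illustrative sketch only -- not vLLM code. Buffer shapes/dtypes, the
    # capture-size list, and the padding logic are assumptions for this example.
    import bisect

    import torch


    class PersistentBufferSketch:

        def __init__(self, max_num_tokens: int, hidden_size: int,
                     cudagraph_capture_sizes: list[int],
                     device: torch.device = torch.device("cpu")) -> None:
            # Keep capture sizes sorted ascending so padding is a bisect.
            # (device defaults to CPU so the sketch runs anywhere; the real
            # proposer allocates these buffers on the GPU.)
            self.capture_sizes = sorted(cudagraph_capture_sizes)
            # Persistent buffers: graph replay reads from fixed addresses,
            # so these are allocated once at max size and reused every step.
            self.input_ids = torch.zeros(max_num_tokens,
                                         dtype=torch.int32,
                                         device=device)
            self.positions = torch.zeros(max_num_tokens,
                                         dtype=torch.int64,
                                         device=device)
            self.hidden_states = torch.zeros(max_num_tokens,
                                             hidden_size,
                                             dtype=torch.float32,
                                             device=device)

        def pad_for_cudagraph(self, num_tokens: int) -> int:
            # Round up to the smallest captured batch size that fits; if the
            # batch exceeds every captured size, run that step unpadded.
            idx = bisect.bisect_left(self.capture_sizes, num_tokens)
            if idx < len(self.capture_sizes):
                return self.capture_sizes[idx]
            return num_tokens

        def prepare(self, input_ids: torch.Tensor, positions: torch.Tensor,
                    hidden_states: torch.Tensor) -> int:
            num_tokens = input_ids.shape[0]
            # Copy live inputs into a prefix of the persistent buffers. The
            # padded tail keeps stale values, which is harmless because only
            # the first num_tokens rows of the output are read back.
            self.input_ids[:num_tokens] = input_ids.to(torch.int32)
            self.positions[:num_tokens] = positions
            self.hidden_states[:num_tokens] = hidden_states
            return self.pad_for_cudagraph(num_tokens)


    # Example: with capture sizes [1, 2, 4, 8], a batch of 3 tokens is copied
    # into the buffers and the model would be launched at size 4; the caller
    # then slices the output back to the first 3 rows.
    sketch = PersistentBufferSketch(max_num_tokens=8, hidden_size=16,
                                    cudagraph_capture_sizes=[1, 2, 4, 8])
    padded = sketch.prepare(torch.tensor([5, 7, 9]),
                            torch.tensor([0, 1, 2]),
                            torch.zeros(3, 16))
    assert padded == 4

The same fixed-shape reasoning is behind running the drafter on the padded prefix (input_ids[:input_batch_size], positions[:input_batch_size], hidden_states[:input_batch_size]) and slicing the result back to [:batch_size], and behind capturing the drafter's graphs via dummy_run() on the same persistent buffers during _dummy_run().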