From b148f75d0ff87b9165f2878b316cc2877faa34b6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Mar 2025 15:23:09 -0700 Subject: [PATCH 01/16] Implement Eagle proposer Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 158 +++++++++++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 vllm/v1/spec_decode/eagle.py diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py new file mode 100644 index 000000000000..389a7ae957f6 --- /dev/null +++ b/vllm/v1/spec_decode/eagle.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.sample.metadata import SamplingMetadata + + +class EagleProposer: + + def __init__( + self, + vllm_config: VllmConfig, + num_speculative_tokens: int, + ): + self.model = ... + self.vllm_config = vllm_config + self.num_speculative_tokens = num_speculative_tokens + self.block_size = vllm_config.cache_config.block_size + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] + cu_num_tokens: torch.Tensor, + max_num_tokens: int, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: # [batch_size, num_speculative_tokens] + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + + last_token_indices = cu_num_tokens[1:] - 1 + input_ids = target_token_ids + input_ids[last_token_indices] = next_token_ids + + seq_lens = target_positions[last_token_indices] + 1 + max_seq_len = seq_lens.max().item() # FIXME: Avoid synchronization. + slot_mapping = compute_slot_mapping( + positions=target_positions, + block_table=block_table, + block_size=self.block_size, + ) + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_tokens, + max_query_len=max_num_tokens, + query_start_loc=cu_num_tokens, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=slot_mapping, + # TODO(woosuk): Add cascade attention. + use_cascade=False, + common_prefix_len=0, + cu_prefix_query_lens=None, + prefix_kv_lens=None, + suffix_kv_lens=None, + ) + + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + hidden_states=target_hidden_states, + positions=target_positions, + ) + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + + # Sample the next token. + draft_token_ids = sample_token_ids(logits, sampling_metadata) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + return draft_token_ids.view(-1, 1) + + # Generate the remaining draft tokens. + draft_token_ids_list: list[torch.Tensor] = [draft_token_ids] + positions = target_positions[last_token_indices] + hidden_states = sample_hidden_states + attn_metadata.num_actual_tokens = batch_size + attn_metadata.query_start_loc = torch.arange(batch_size, + device=positions.device) + for _ in range(self.num_speculative_tokens - 1): + # Update the inputs. 
+ input_ids = draft_token_ids_list[-1] + positions += 1 + attn_metadata.max_query_len = 1 + attn_metadata.max_seq_len += 1 + attn_metadata.seq_lens += 1 + attn_metadata.slot_mapping = compute_slot_mapping( + positions=positions, + block_table=block_table, + block_size=self.block_size, + ) + + # Run the model. + hidden_states = self.model( + input_ids=input_ids, + hidden_states=hidden_states, + positions=positions, + ) + logits = self.model.compute_logits(hidden_states, None) + draft_token_ids = sample_token_ids(logits, sampling_metadata) + draft_token_ids_list.append(draft_token_ids) + + # [batch_size, num_speculative_tokens] + return torch.stack(draft_token_ids_list, dim=1) + + +def sample_token_ids( + logits: torch.Tensor, # [batch_size, vocab_size] + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: # [batch_size] + # NOTE(woosuk): We don't need to apply all the sampling parameters + # for generating the draft tokens. + if sampling_metadata.all_greedy: + # All greedy. + next_token_ids = logits.argmax(dim=-1) + else: + logits.div_(sampling_metadata.temperature) + probs = logits.softmax(dim=-1, dtype=torch.float32) + + # TODO(woosuk): Consider seeds? + q = torch.empty_like(logits) + q.exponential_() + next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + + if not sampling_metadata.all_random: + greedy_token_ids = logits.argmax(dim=-1) + next_token_ids = torch.where( + sampling_metadata.temperature == -1, + greedy_token_ids, + next_token_ids, + ) + return next_token_ids + + +def compute_slot_mapping( + # [num_tokens] + positions: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + block_size: int, +) -> torch.Tensor: # [num_tokens] + # [num_tokens] + block_numbers = positions // block_size + block_ids = block_table.gather(dim=1, index=block_numbers) + slot_mapping = block_ids * block_size + positions % block_size + return slot_mapping From 657d311759645f544e15693f4d91cf00fdf8db9b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Mar 2025 15:27:00 -0700 Subject: [PATCH 02/16] minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 389a7ae957f6..6bd83f5c9d9a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -83,7 +83,7 @@ def propose( return draft_token_ids.view(-1, 1) # Generate the remaining draft tokens. 
- draft_token_ids_list: list[torch.Tensor] = [draft_token_ids] + draft_token_ids_list = [draft_token_ids] positions = target_positions[last_token_indices] hidden_states = sample_hidden_states attn_metadata.num_actual_tokens = batch_size From 4e2a2d18db2e7506457fff4e090ef4e1279dd3f1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Mar 2025 15:43:40 -0700 Subject: [PATCH 03/16] minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6bd83f5c9d9a..1591b2a10877 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -29,7 +29,7 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - # [batch_size + 1] + # [batch_size + 1] starting with 0 cu_num_tokens: torch.Tensor, max_num_tokens: int, # [batch_size, max_num_blocks_per_req] @@ -38,9 +38,14 @@ def propose( ) -> torch.Tensor: # [batch_size, num_speculative_tokens] num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 + input_ids = target_token_ids + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + input_ids[:-1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] input_ids[last_token_indices] = next_token_ids seq_lens = target_positions[last_token_indices] + 1 From 1b340f23117f420001e8c9ac8ef895452e207549 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Mar 2025 15:56:51 -0700 Subject: [PATCH 04/16] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1591b2a10877..b4be93ceff82 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -13,11 +13,16 @@ def __init__( self, vllm_config: VllmConfig, num_speculative_tokens: int, + device: torch.device, ): self.model = ... self.vllm_config = vllm_config self.num_speculative_tokens = num_speculative_tokens self.block_size = vllm_config.cache_config.block_size + self.arange = torch.arange( + vllm_config.scheduler_config.max_num_seqs, + device=device, + ) def propose( self, @@ -92,13 +97,12 @@ def propose( positions = target_positions[last_token_indices] hidden_states = sample_hidden_states attn_metadata.num_actual_tokens = batch_size - attn_metadata.query_start_loc = torch.arange(batch_size, - device=positions.device) + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size] for _ in range(self.num_speculative_tokens - 1): # Update the inputs. 
input_ids = draft_token_ids_list[-1] positions += 1 - attn_metadata.max_query_len = 1 attn_metadata.max_seq_len += 1 attn_metadata.seq_lens += 1 attn_metadata.slot_mapping = compute_slot_mapping( @@ -150,10 +154,8 @@ def sample_token_ids( def compute_slot_mapping( - # [num_tokens] - positions: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, + positions: torch.Tensor, # [num_tokens] + block_table: torch.Tensor, # [batch_size, max_num_blocks_per_req] block_size: int, ) -> torch.Tensor: # [num_tokens] # [num_tokens] From 382e6d0a501238a1b68252cf2bbfae8b54c35d42 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Mar 2025 15:58:02 -0700 Subject: [PATCH 05/16] Minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b4be93ceff82..7d42d4b2290b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -19,10 +19,8 @@ def __init__( self.vllm_config = vllm_config self.num_speculative_tokens = num_speculative_tokens self.block_size = vllm_config.cache_config.block_size - self.arange = torch.arange( - vllm_config.scheduler_config.max_num_seqs, - device=device, - ) + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, + device=device) def propose( self, From a4f043815e0720007c5d8899b1ac3a87b6d4d35e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Mar 2025 15:58:31 -0700 Subject: [PATCH 06/16] Fix Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 7d42d4b2290b..fb3611cf2b8a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -110,11 +110,12 @@ def propose( ) # Run the model. - hidden_states = self.model( - input_ids=input_ids, - hidden_states=hidden_states, - positions=positions, - ) + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + hidden_states=hidden_states, + positions=positions, + ) logits = self.model.compute_logits(hidden_states, None) draft_token_ids = sample_token_ids(logits, sampling_metadata) draft_token_ids_list.append(draft_token_ids) From d4b0cf47c518abdc649bfe142a8961f028584849 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Mar 2025 21:50:36 -0700 Subject: [PATCH 07/16] minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index fb3611cf2b8a..4de022244da0 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,12 +12,12 @@ class EagleProposer: def __init__( self, vllm_config: VllmConfig, - num_speculative_tokens: int, device: torch.device, ): self.model = ... 
self.vllm_config = vllm_config - self.num_speculative_tokens = num_speculative_tokens + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens) self.block_size = vllm_config.cache_config.block_size self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, device=device) From 07dfa92a665effe22f5c71ac0f1007ff187c584a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 29 Mar 2025 04:52:16 -0700 Subject: [PATCH 08/16] minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4de022244da0..9b22c41d7710 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -66,7 +66,7 @@ def propose( seq_lens=seq_lens, block_table=block_table, slot_mapping=slot_mapping, - # TODO(woosuk): Add cascade attention. + # TODO(woosuk): Support cascade attention. use_cascade=False, common_prefix_len=0, cu_prefix_query_lens=None, @@ -124,6 +124,8 @@ def propose( return torch.stack(draft_token_ids_list, dim=1) +# TODO(woosuk): The logic here is duplicated with the main sampling code. +# We should refactor this to reuse the same sampling implementation. def sample_token_ids( logits: torch.Tensor, # [batch_size, vocab_size] sampling_metadata: SamplingMetadata, @@ -157,7 +159,6 @@ def compute_slot_mapping( block_table: torch.Tensor, # [batch_size, max_num_blocks_per_req] block_size: int, ) -> torch.Tensor: # [num_tokens] - # [num_tokens] block_numbers = positions // block_size block_ids = block_table.gather(dim=1, index=block_numbers) slot_mapping = block_ids * block_size + positions % block_size From e5e559e6a7bff1bf7996a8b8ab1cf7dcd9951189 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 29 Mar 2025 09:45:47 -0700 Subject: [PATCH 09/16] max_num_tokens Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 7 +++--- vllm/v1/worker/gpu_model_runner.py | 35 +++++++++++++++++++----------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 9b22c41d7710..e108f1e5c424 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -34,7 +34,6 @@ def propose( next_token_ids: torch.Tensor, # [batch_size + 1] starting with 0 cu_num_tokens: torch.Tensor, - max_num_tokens: int, # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -43,7 +42,7 @@ def propose( batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 - input_ids = target_token_ids + input_ids = torch.empty_like(target_token_ids) # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] input_ids[:-1] = target_token_ids[1:] @@ -52,7 +51,9 @@ def propose( input_ids[last_token_indices] = next_token_ids seq_lens = target_positions[last_token_indices] + 1 - max_seq_len = seq_lens.max().item() # FIXME: Avoid synchronization. + # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
+ max_seq_len = seq_lens.max().item() + max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() slot_mapping = compute_slot_mapping( positions=target_positions, block_table=block_table, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4511a9aa85fd..50da85bd33f4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -36,6 +36,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -150,18 +151,14 @@ def __init__( self.use_spec_decode = False if self.speculative_config: self.use_spec_decode = True - assert self.speculative_config.method == "ngram", \ - "Currently, only ngram spec decode is supported in V1." if get_pp_group().is_last_rank: - self.drafter = NgramProposer() - # Trigger Numba JIT compilation for N-gram proposer. - # This usually takes less than 1 second. - self.drafter.propose( - np.zeros(1024, dtype=np.int32), - self.speculative_config.prompt_lookup_min, - self.speculative_config.prompt_lookup_max, - self.speculative_config.num_speculative_tokens, - ) + if self.speculative_config.method == "ngram": + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "eagle": + self.drafter = EagleProposer(self.vllm_config, self.device) + else: + raise ValueError("Unknown speculative decoding method: " + f"{self.speculative_config.method}") self.rejection_sampler = RejectionSampler() # Request states. @@ -1127,6 +1124,7 @@ def execute_model( sampled_token_ids, self.input_batch.vocab_size, ) + # Mask out the sampled tokens that should not be sampled. 
for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[i].clear() @@ -1134,8 +1132,19 @@ def execute_model( if not self.use_spec_decode: spec_token_ids = None else: - spec_token_ids = self.generate_draft_token_ids( - valid_sampled_token_ids, sampling_metadata) + token_indices = None + cu_num_tokens = None + next_token_ids = None # Sampled, Bonus, or Recovered + input_ids = self.input_ids[:num_scheduled_tokens] + spec_token_ids = self.drafter.propose( + target_token_ids=input_ids[token_indices], + target_positions=positions[token_indices], + target_hidden_states=hidden_states[token_indices], + next_token_ids=next_token_ids, + cu_num_tokens=cu_num_tokens, + block_table=attn_metadata.block_table, + sampling_metadata=sampling_metadata, + ) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, From 4a4bb60e0b8dee2acac5882f50f8992597ec8ad1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 29 Mar 2025 18:09:15 -0700 Subject: [PATCH 10/16] upstream Signed-off-by: Woosuk Kwon --- vllm/config.py | 6 +- vllm/engine/arg_utils.py | 4 +- vllm/v1/spec_decode/eagle.py | 102 +++++++++++++++++++++++------ vllm/v1/worker/gpu_input_batch.py | 11 +++- vllm/v1/worker/gpu_model_runner.py | 65 +++++++++++++++--- 5 files changed, 153 insertions(+), 35 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6a15109c6744..122b05c9a45e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2133,9 +2133,9 @@ def __post_init__(self): # Replace hf_config for EAGLE draft_model if self.method == "eagle": - if self.enable_chunked_prefill: - raise ValueError( - "Chunked prefill and EAGLE are not compatible.") + # if self.enable_chunked_prefill: + # raise ValueError( + # "Chunked prefill and EAGLE are not compatible.") from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca511c7434f8..7e2cdf3b444f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1647,7 +1647,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: if (self.speculative_model is not None or self.num_speculative_tokens is not None): # This is supported but experimental (handled below). - if self.speculative_model in ("ngram", "[ngram]"): + supported_methods = ["ngram", "eagle"] + if any(method in self.speculative_model.lower() + for method in supported_methods): pass else: _raise_or_fallback(feature_name="Speculative Decoding", diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e108f1e5c424..3fb749b728c9 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import triton +import triton.language as tl from vllm.config import VllmConfig from vllm.forward_context import set_forward_context @@ -14,7 +16,7 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): - self.model = ... + self.model = None self.vllm_config = vllm_config self.num_speculative_tokens = ( vllm_config.speculative_config.num_speculative_tokens) @@ -30,6 +32,8 @@ def propose( target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, # [batch_size + 1] starting with 0 @@ -54,11 +58,6 @@ def propose( # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
max_seq_len = seq_lens.max().item() max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() - slot_mapping = compute_slot_mapping( - positions=target_positions, - block_table=block_table, - block_size=self.block_size, - ) attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_tokens, max_query_len=max_num_tokens, @@ -66,7 +65,7 @@ def propose( max_seq_len=max_seq_len, seq_lens=seq_lens, block_table=block_table, - slot_mapping=slot_mapping, + slot_mapping=target_slot_mapping, # TODO(woosuk): Support cascade attention. use_cascade=False, common_prefix_len=0, @@ -104,11 +103,13 @@ def propose( positions += 1 attn_metadata.max_seq_len += 1 attn_metadata.seq_lens += 1 - attn_metadata.slot_mapping = compute_slot_mapping( - positions=positions, - block_table=block_table, - block_size=self.block_size, - ) + # Compute the slot mapping. + block_numbers = positions // self.block_size + block_ids = block_table.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + attn_metadata.slot_mapping = (block_ids * self.block_size + + positions % self.block_size) # Run the model. with set_forward_context(attn_metadata, self.vllm_config): @@ -124,6 +125,49 @@ def propose( # [batch_size, num_speculative_tokens] return torch.stack(draft_token_ids_list, dim=1) + @staticmethod + def prepare_inputs( + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. + num_tokens = cu_num_tokens[-1].item() + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) + + batch_size = num_rejected_tokens.shape[0] + BLOCK_SIZE = 1024 + prepare_input_kernel[(batch_size, )]( + token_indices, + cu_target_query_lens, + cu_num_tokens, + BLOCK_SIZE=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + # TODO(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. 
@@ -155,12 +199,28 @@ def sample_token_ids( return next_token_ids -def compute_slot_mapping( - positions: torch.Tensor, # [num_tokens] - block_table: torch.Tensor, # [batch_size, max_num_blocks_per_req] - block_size: int, -) -> torch.Tensor: # [num_tokens] - block_numbers = positions // block_size - block_ids = block_table.gather(dim=1, index=block_numbers) - slot_mapping = block_ids * block_size + positions % block_size - return slot_mapping +@triton.jit +def prepare_input_kernel( + out_ptr, + cu_query_lens_ptr, + cu_num_tokens_ptr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + # [start_pos, end_pos) + start_pos = tl.load(cu_num_tokens_ptr + pid) + end_pos = tl.load(cu_num_tokens_ptr + pid + 1) + num_tokens = end_pos - start_pos + + index_start = tl.load(cu_query_lens_ptr + pid) + indices = index_start + tl.arange(0, BLOCK_SIZE) + + num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) + for i in tl.range(num_blocks): + offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store( + out_ptr + start_pos + offset, + indices, + mask=offset < num_tokens, + ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 351b35815580..a64cb97e0123 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -39,9 +39,18 @@ class CachedRequestState: lora_request: Optional[LoRARequest] = None + def __post_init__(self): + self.num_prompt_tokens = len(self.prompt_token_ids) + @property def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) + return self.num_prompt_tokens + len(self.output_token_ids) + + def get_token_id(self, idx: int) -> int: + if idx < self.num_prompt_tokens: + return self.prompt_token_ids[idx] + else: + return self.output_token_ids[idx - self.num_prompt_tokens] class InputBatch: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 50da85bd33f4..035f93ef7d06 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -34,9 +34,9 @@ ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -1130,16 +1130,63 @@ def execute_model( valid_sampled_token_ids[i].clear() if not self.use_spec_decode: + # Speculative decoding is not enabled. spec_token_ids = None - else: - token_indices = None - cu_num_tokens = None - next_token_ids = None # Sampled, Bonus, or Recovered - input_ids = self.input_ids[:num_scheduled_tokens] + elif self.speculative_config.method == "ngram": + pass + elif self.speculative_config.method == "eagle": + # TODO(woosuk): Refactor the loop. + next_token_ids: list[int] = [] + for i, token_ids in enumerate(valid_sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + + if spec_decode_metadata is None: + # input_ids can be None for multimodal models. + target_token_ids = self.input_ids[:num_scheduled_tokens] + target_positions = positions + target_hidden_states = hidden_states + target_slot_mapping = attn_metadata.slot_mapping + cu_num_tokens = attn_metadata.query_start_loc + else: + # TODO(woosuk): Refactor this. + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor( + num_rejected_tokens, + dtype=torch.int32, + device=self.device, + ) + cu_num_tokens, token_indices = self.drafter.prepare_inputs( + attn_metadata.query_start_loc, + num_rejected_tokens, + ) + target_token_ids = self.input_ids[token_indices] + target_positions = positions[token_indices] + target_hidden_states = hidden_states[token_indices] + target_slot_mapping = attn_metadata.slot_mapping[token_indices] + spec_token_ids = self.drafter.propose( - target_token_ids=input_ids[token_indices], - target_positions=positions[token_indices], - target_hidden_states=hidden_states[token_indices], + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, block_table=attn_metadata.block_table, From d8e901a57dda0a26474a5baf760c60881db50bee Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 29 Mar 2025 18:16:53 -0700 Subject: [PATCH 11/16] minor Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 035f93ef7d06..4ccc0ded7a9b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1124,7 +1124,6 @@ def execute_model( sampled_token_ids, self.input_batch.vocab_size, ) - # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[i].clear() @@ -1133,7 +1132,8 @@ def execute_model( # Speculative decoding is not enabled. spec_token_ids = None elif self.speculative_config.method == "ngram": - pass + spec_token_ids = self.generate_draft_token_ids( + valid_sampled_token_ids, sampling_metadata) elif self.speculative_config.method == "eagle": # TODO(woosuk): Refactor the loop. 
next_token_ids: list[int] = [] From 83c8b59c747f64b7bfcf0022e0d5e5e55c59f017 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 30 Mar 2025 14:37:14 -0700 Subject: [PATCH 12/16] dummy model Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 24 +++++++++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 6 +++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3fb749b728c9..a1132c09ddad 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.nn as nn import triton import triton.language as tl @@ -16,7 +17,6 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): - self.model = None self.vllm_config = vllm_config self.num_speculative_tokens = ( vllm_config.speculative_config.num_speculative_tokens) @@ -168,6 +168,28 @@ def prepare_inputs( ) return cu_num_tokens, token_indices + def load_model(self, target_model: nn.Module) -> None: + self.model = DummyEagleModel() + self.model.get_input_embeddings = target_model.get_input_embeddings + self.model.compute_logits = target_model.compute_logits + + +# FIXME(woosuk): This is a dummy model for testing. +# Remove this once we have a real model. +class DummyEagleModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + input_embeddings = self.get_input_embeddings(input_ids) + return hidden_states + input_embeddings # Dummy return. + # TODO(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2c28079f1f2c..a842a28a2d32 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1191,7 +1191,7 @@ def execute_model( cu_num_tokens=cu_num_tokens, block_table=attn_metadata.block_table, sampling_metadata=sampling_metadata, - ) + ).tolist() return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1249,6 +1249,10 @@ def load_model(self) -> None: self.scheduler_config, self.lora_config, self.device) + if (hasattr(self, "drafter") + and self.speculative_config.method != "ngram"): + logger.info("Loading drafter model...") + self.drafter.load_model(self.model) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory logger.info("Model loading took %.4f GB and %.6f seconds", From 64d2ed7479b1759648012dd8eaa78163df7ed5b5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 30 Mar 2025 18:12:41 -0700 Subject: [PATCH 13/16] fix Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a1132c09ddad..3117e2208666 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -203,7 +203,7 @@ def sample_token_ids( # All greedy. next_token_ids = logits.argmax(dim=-1) else: - logits.div_(sampling_metadata.temperature) + logits.div_(sampling_metadata.temperature.view(-1, 1)) probs = logits.softmax(dim=-1, dtype=torch.float32) # TODO(woosuk): Consider seeds? 
From a7f0600de6a24fa85ca55d46379949d75f1a9a10 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 30 Mar 2025 22:18:59 -0700 Subject: [PATCH 14/16] Return draft_probs Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 76 ++++++++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 8 +++- 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3117e2208666..57c6b652593d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -41,7 +41,7 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: # [batch_size, num_speculative_tokens] + ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 @@ -82,16 +82,18 @@ def propose( ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - - # Sample the next token. - draft_token_ids = sample_token_ids(logits, sampling_metadata) + draft_token_ids, draft_probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - return draft_token_ids.view(-1, 1) + # [batch_size, 1] and [batch_size, 1, vocab_size] + return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] + draft_probs_list = [draft_probs] + positions = target_positions[last_token_indices] hidden_states = sample_hidden_states attn_metadata.num_actual_tokens = batch_size @@ -119,11 +121,16 @@ def propose( positions=positions, ) logits = self.model.compute_logits(hidden_states, None) - draft_token_ids = sample_token_ids(logits, sampling_metadata) + draft_token_ids, probs = compute_probs_and_sample_next_token( + logits, sampling_metadata) draft_token_ids_list.append(draft_token_ids) + draft_probs_list.append(probs) # [batch_size, num_speculative_tokens] - return torch.stack(draft_token_ids_list, dim=1) + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + # [batch_size, num_speculative_tokens, vocab_size] + draft_probs = torch.stack(draft_probs_list, dim=1) + return draft_token_ids, draft_probs @staticmethod def prepare_inputs( @@ -191,34 +198,41 @@ def forward( return hidden_states + input_embeddings # Dummy return. -# TODO(woosuk): The logic here is duplicated with the main sampling code. +# FIXME(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. -def sample_token_ids( - logits: torch.Tensor, # [batch_size, vocab_size] +def compute_probs_and_sample_next_token( + logits: torch.Tensor, sampling_metadata: SamplingMetadata, -) -> torch.Tensor: # [batch_size] - # NOTE(woosuk): We don't need to apply all the sampling parameters - # for generating the draft tokens. +) -> tuple[torch.Tensor, torch.Tensor]: if sampling_metadata.all_greedy: - # All greedy. + # For greedy requests, draft_probs is not used in rejection sampling. + # Therefore, we can just return the logits. + probs = logits next_token_ids = logits.argmax(dim=-1) - else: - logits.div_(sampling_metadata.temperature.view(-1, 1)) - probs = logits.softmax(dim=-1, dtype=torch.float32) - - # TODO(woosuk): Consider seeds? 
- q = torch.empty_like(logits) - q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) - - if not sampling_metadata.all_random: - greedy_token_ids = logits.argmax(dim=-1) - next_token_ids = torch.where( - sampling_metadata.temperature == -1, - greedy_token_ids, - next_token_ids, - ) - return next_token_ids + return next_token_ids, probs + + is_greedy = sampling_metadata.temperature == -1 + temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits.div_(temperature.view(-1, 1)) + probs = logits.softmax(dim=-1, dtype=torch.float32) + + # NOTE(woosuk): Currently, we ignore most of the sampling parameters in + # generating the draft tokens. We only use the temperature. While this + # could degrade the acceptance rate, it does not affect the distribution + # of the generated tokens after rejection sampling. + + # TODO(woosuk): Consider seeds. + q = torch.empty_like(probs) + q.exponential_() + next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + if not sampling_metadata.all_random: + greedy_token_ids = probs.argmax(dim=-1) + next_token_ids = torch.where( + is_greedy, + greedy_token_ids, + next_token_ids, + ) + return next_token_ids, probs @triton.jit diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a639fe4b11b5..ca1b4d7dc485 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1182,7 +1182,7 @@ def execute_model( target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] - spec_token_ids = self.drafter.propose( + draft_token_ids, draft_probs = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1191,7 +1191,11 @@ def execute_model( cu_num_tokens=cu_num_tokens, block_table=attn_metadata.block_table, sampling_metadata=sampling_metadata, - ).tolist() + ) + spec_token_ids = draft_token_ids.tolist() + # TODO(woosuk): Cache draft_probs and use it for rejection sampling + # in the next step. + del draft_probs return ModelRunnerOutput( req_ids=self.input_batch.req_ids, From d5db76aafe7c2ce905280ecd50479023e82dcd85 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 31 Mar 2025 23:54:44 -0700 Subject: [PATCH 15/16] simplify Signed-off-by: Woosuk Kwon --- vllm/engine/arg_utils.py | 7 ++----- vllm/v1/spec_decode/ngram_proposer.py | 9 +++++++++ vllm/v1/worker/gpu_model_runner.py | 8 +++++--- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8097998fcafc..579b3b5a0917 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1479,11 +1479,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") - if speculative_model: - if speculative_model in ("ngram", "[ngram]"): - is_ngram_enabled = True - elif "eagle" in speculative_model.lower(): - is_eagle_enabled = True + if speculative_model in ("ngram", "[ngram]"): + is_ngram_enabled = True if not (is_ngram_enabled or is_eagle_enabled): # Other speculative decoding methods are not supported yet. 
_raise_or_fallback(feature_name="Speculative Decoding", diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 0bef349e99e2..8f6d20d11ff3 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -4,9 +4,14 @@ import numpy as np from numba import jit +from vllm.config import VllmConfig + class NgramProposer: + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + def propose( self, context_token_ids: np.ndarray, @@ -50,6 +55,10 @@ def propose( return result return None + def load_model(self, *args, **kwargs): + # No model to load. + pass + @jit(nopython=True) def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 967754f05023..b216dbe960d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -161,7 +161,8 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.method == "eagle": - self.drafter = EagleProposer(self.vllm_config, self.device) + self.drafter = EagleProposer(self.vllm_config, + self.device) # type: ignore else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -1143,9 +1144,11 @@ def execute_model( # Speculative decoding is not enabled. spec_token_ids = None elif self.speculative_config.method == "ngram": + assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) elif self.speculative_config.method == "eagle": + assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] for i, token_ids in enumerate(valid_sampled_token_ids): @@ -1264,8 +1267,7 @@ def load_model(self) -> None: self.scheduler_config, self.lora_config, self.device) - if (hasattr(self, "drafter") - and self.speculative_config.method != "ngram"): + if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) time_after_load = time.perf_counter() From 7a1d5ffa83f6f815d2c77bf860ad71b682b2d0a6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 1 Apr 2025 12:25:04 -0700 Subject: [PATCH 16/16] fix Signed-off-by: Woosuk Kwon --- vllm/config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c1f933c3601d..6ec5d1bc28fa 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2152,9 +2152,10 @@ def __post_init__(self): # Replace hf_config for EAGLE draft_model if self.method == "eagle": - # if self.enable_chunked_prefill: - # raise ValueError( - # "Chunked prefill and EAGLE are not compatible.") + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") from vllm.transformers_utils.configs.eagle import ( EAGLEConfig)
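
For reference, the token-index selection that the Triton `prepare_input_kernel` performs in the patches above can be mirrored in plain PyTorch. The sketch below is an illustrative reference only, not part of the patch series; the helper name `prepare_inputs_reference` and the example values are hypothetical.

# Reference sketch (not part of the patches above): reproduces the index
# arithmetic of EagleProposer.prepare_inputs / prepare_input_kernel in
# plain PyTorch. Example values are illustrative only.
import torch


def prepare_inputs_reference(
    cu_target_query_lens: torch.Tensor,  # [batch_size + 1], e.g. [0, a, a + b, a + b + c]
    num_rejected_tokens: torch.Tensor,   # [batch_size], e.g. [n1, n2, n3]
) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-request query lengths: [a, b, c]
    query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
    # Tokens kept per request after rejection: [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens

    # Cumulative counts of the kept tokens, starting at 0.
    cu_num_tokens = torch.zeros_like(cu_target_query_lens)
    cu_num_tokens[1:] = torch.cumsum(num_tokens_per_req, dim=0)

    # For request i, keep its first (query_len_i - n_i) positions,
    # starting at cu_target_query_lens[i] in the flattened token array.
    token_indices = torch.cat([
        cu_target_query_lens[i] +
        torch.arange(int(num_tokens_per_req[i]),
                     device=cu_target_query_lens.device)
        for i in range(num_rejected_tokens.shape[0])
    ])
    return cu_num_tokens, token_indices


if __name__ == "__main__":
    # Three requests with query lengths [3, 2, 4] and [1, 0, 2] rejected
    # tokens respectively (illustrative values).
    cu = torch.tensor([0, 3, 5, 9])
    rejected = torch.tensor([1, 0, 2])
    cu_num_tokens, token_indices = prepare_inputs_reference(cu, rejected)
    print(cu_num_tokens)   # [0, 2, 4, 6]
    print(token_indices)   # [0, 1, 3, 4, 5, 6]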