From 70e688ca721dc3ab7427e27f655456dca740f452 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Fri, 11 Apr 2025 00:29:30 +0000 Subject: [PATCH 01/19] [V1] MTP Support Prototype Signed-off-by: Rui Qiao --- tests/conftest.py | 2 + tests/spec_decode/conftest.py | 2 +- vllm/config.py | 3 + vllm/engine/arg_utils.py | 4 + vllm/entrypoints/llm.py | 2 + vllm/v1/attention/backends/mla/common.py | 8 +- vllm/v1/spec_decode/mtp_proposer.py | 217 +++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 12 +- 8 files changed, 245 insertions(+), 5 deletions(-) create mode 100644 vllm/v1/spec_decode/mtp_proposer.py diff --git a/tests/conftest.py b/tests/conftest.py index f02b5a8c0520..cc528584d524 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -759,6 +759,8 @@ def __init__( enforce_eager: Optional[bool] = False, **kwargs, ) -> None: + from vllm import envs + logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}") self.model = LLM( model=model_name, task=task, diff --git a/tests/spec_decode/conftest.py b/tests/spec_decode/conftest.py index 1a20e2c135c2..eca20289aa46 100644 --- a/tests/spec_decode/conftest.py +++ b/tests/spec_decode/conftest.py @@ -8,4 +8,4 @@ def use_v0_only(monkeypatch): Since this module is V0 only, set VLLM_USE_V1=0 for all tests in the module. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv('VLLM_USE_V1', '1') diff --git a/vllm/config.py b/vllm/config.py index e96d872d693e..b92e9b74be56 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2416,6 +2416,9 @@ def __post_init__(self): elif (self.draft_model_config.hf_config.model_type == "mlp_speculator"): self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type == + "deepseek_mtp"): + self.method = "mtp" else: self.method = "draft_model" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c7a580cf1051..405d4af3fa7f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -957,6 +957,8 @@ def create_engine_config( else: envs.set_vllm_use_v1(use_v1) + logger.info("use_v1: %s", use_v1) + # Set default arguments for V0 or V1 Engine. if use_v1: self._set_default_args_v1(usage_context) @@ -1281,6 +1283,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: speculative_model = self.speculative_config.get("model") if speculative_model in ("ngram", "[ngram]"): is_ngram_enabled = True + logger.info("Forcing to use V1 for speculative decoding.") + return True if not (is_ngram_enabled or is_eagle_enabled): # Other speculative decoding methods are not supported yet. _raise_or_fallback(feature_name="Speculative Decoding", diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 79f1d80f402c..f4308b06c791 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -240,6 +240,8 @@ def __init__( **kwargs, ) + from vllm import envs + logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}") # Create the Engine (autoselects V0 vs V1) self.llm_engine = LLMEngine.from_engine_args( engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index fd3be901f4c3..952ff1d1d8e3 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -304,6 +304,7 @@ class MLACommonMetadata(Generic[D]): num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor + block_table: torch.Tensor slot_mapping: torch.Tensor # New for MLA (compared to FlashAttention) @@ -341,6 +342,7 @@ def __init__(self, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata + logger.info(f"self.metadata_cls: {self.metadata_cls}") self.runner = runner scheduler_config = runner.scheduler_config model_config = runner.model_config @@ -352,8 +354,9 @@ def __init__(self, self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) # Dont try to access the runner on AMD - if self.aot_schedule: - self.page_size = self.runner.block_size + #if self.aot_schedule: + # Need page_size to compute max_context_chunk + self.page_size = self.runner.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -557,6 +560,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, return self.metadata_cls( num_actual_tokens=num_actual_tokens, query_start_loc=query_start_loc, + block_table=block_table, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py new file mode 100644 index 000000000000..b4b96bfc1eae --- /dev/null +++ b/vllm/v1/spec_decode/mtp_proposer.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.forward_context import set_forward_context +from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP +from vllm.v1.sample.metadata import SamplingMetadata + + +# FIXME(woosuk): The logic here is duplicated with the main sampling code. +# We should refactor this to reuse the same sampling implementation. +def compute_probs_and_sample_next_token( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> tuple[torch.Tensor, torch.Tensor]: + if sampling_metadata.all_greedy: + # For greedy requests, draft_probs is not used in rejection sampling. + # Therefore, we can just return the logits. + probs = logits + next_token_ids = logits.argmax(dim=-1) + return next_token_ids, probs + + is_greedy = sampling_metadata.temperature == -1 + temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits.div_(temperature.view(-1, 1)) + probs = logits.softmax(dim=-1, dtype=torch.float32) + + # NOTE(woosuk): Currently, we ignore most of the sampling parameters in + # generating the draft tokens. We only use the temperature. While this + # could degrade the acceptance rate, it does not affect the distribution + # of the generated tokens after rejection sampling. + + # TODO(woosuk): Consider seeds. + q = torch.empty_like(probs) + q.exponential_() + next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) + if not sampling_metadata.all_random: + greedy_token_ids = probs.argmax(dim=-1) + next_token_ids = torch.where( + is_greedy, + greedy_token_ids, + next_token_ids, + ) + return next_token_ids, probs + + +class MtpProposer: + + def __init__( + self, + vllm_config: VllmConfig, + runner, + ): + self.vllm_config = vllm_config + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens) + self.block_size = vllm_config.cache_config.block_size + self.runner = runner + + @staticmethod + def prepare_inputs( + # [batch_size + 1] + cu_target_query_lens: torch.Tensor, + # [batch_size] + num_rejected_tokens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # cu_target_query_lens: [0, a, a + b, a + b + c] + # num_rejected_tokens: [n1, n2, n3] + # num_tokens_per_req: [a - n1, b - n2, c - n3] + # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] + # token_indices: [0, 1, ..., a - n1 - 1, + # a, a + 1, ..., a + b - n2 - 1, + # a + b, a + b + 1, ..., a + b + c - n3 - 1] + + # [0, a, a + b, a + b + c] -> [a, b, c] + query_len_per_req = (cu_target_query_lens[1:] - + cu_target_query_lens[:-1]) + # [a, b, c] -> [a - n1, b - n2, c - n3] + num_tokens_per_req = query_len_per_req - num_rejected_tokens + + cu_num_tokens = torch.empty_like(cu_target_query_lens) + torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + cu_num_tokens[0] = 0 + + # FIXME(woosuk): Avoid synchronization. + num_tokens = cu_num_tokens[-1].item() + token_indices = torch.empty( + num_tokens, + dtype=torch.int32, + device=cu_num_tokens.device, + ) + + batch_size = num_rejected_tokens.shape[0] + BLOCK_SIZE = 1024 + prepare_input_kernel[(batch_size, )]( + token_indices, + cu_target_query_lens, + cu_num_tokens, + BLOCK_SIZE=BLOCK_SIZE, + ) + return cu_num_tokens, token_indices + + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [num_tokens] + target_slot_mapping: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + # [batch_size + 1] starting with 0 + cu_num_tokens: torch.Tensor, + # [batch_size, max_num_blocks_per_req] + block_table: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = cu_num_tokens[1:] - 1 + + input_ids = torch.empty_like(target_token_ids) + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + input_ids[:-1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + input_ids[last_token_indices] = next_token_ids + + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + # FIXME: reorder_batch() needs to be called before build() + # because fields of attn_metadata_builder needs to be updated. + # However, currently reorder_batch() takes input_batch and + # scheduler_output as arguments, we should probably refactor + # the method to use new data structures which are independent + # from input_batch and scheduler_output. + # self.runner.attn_metadata_builder.reorder_batch( + # input_batch=self.runner.input_batch, + # scheduler_output=self.runner.scheduler_output, + # ) + + attn_metadata = self.runner.attn_metadata_builder.build( + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + common_prefix_len=0, + ) + + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + positions=target_positions, + previous_hidden_states=target_hidden_states, + ) + sample_hidden_states = hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + draft_token_ids = logits.argmax(dim=-1) + + assert self.num_speculative_tokens == 1 + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + def load_model(self, target_model: nn.Module) -> None: + loader = get_model_loader(self.vllm_config.load_config) + + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + # FIXME(lily): This does not handle with distributed inference. + target_device = self.vllm_config.device_config.device + # We need to set the vllm_config here to register attention + # layers in the forward context. + with set_default_torch_dtype( + draft_model_config.dtype), set_current_vllm_config( + self.vllm_config): + self.model = DeepSeekMTP( + vllm_config=self.vllm_config).to(target_device) + + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) + + +@triton.jit +def prepare_input_kernel( + out_ptr, + cu_query_lens_ptr, + cu_num_tokens_ptr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + # [start_pos, end_pos) + start_pos = tl.load(cu_num_tokens_ptr + pid) + end_pos = tl.load(cu_num_tokens_ptr + pid + 1) + num_tokens = end_pos - start_pos + + index_start = tl.load(cu_query_lens_ptr + pid) + + num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) + for i in tl.range(num_blocks): + offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store( + out_ptr + start_pos + offset, + index_start + offset, + mask=offset < num_tokens, + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 97d8c91b4659..e533fa781850 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -41,6 +41,7 @@ from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.mtp_proposer import MtpProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache @@ -140,6 +141,7 @@ def __init__( self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) + print(f"self.attn_metadata_builder: {self.attn_metadata_builder}") self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -176,6 +178,9 @@ def __init__( self.device) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True + elif self.speculative_config.method == "mtp": + self.drafter = MtpProposer(self.vllm_config, + self) # type: ignore else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -1191,6 +1196,7 @@ def execute_model( sampled_token_ids = sampler_output.sampled_token_ids max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: + # GPU tensor to CPU list? sync point? # No spec decode tokens. valid_sampled_token_ids = sampled_token_ids.tolist() else: @@ -1210,8 +1216,10 @@ def execute_model( assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) - elif self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + elif (self.speculative_config.use_eagle() or + self.speculative_config.draft_model_config.hf_config.model_type \ + == "deepseek_mtp"): + assert isinstance(self.drafter, (EagleProposer, MtpProposer)) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] for i, token_ids in enumerate(valid_sampled_token_ids): From 4cf308b20a98d26169a8b5356eacb352db49aad2 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Tue, 20 May 2025 19:59:54 +0000 Subject: [PATCH 02/19] fix MTP tp Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/mtp_proposer.py | 28 +++++++++++++++++++++------- vllm/v1/worker/gpu_model_runner.py | 5 +++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py index b4b96bfc1eae..144780d3912c 100644 --- a/vllm/v1/spec_decode/mtp_proposer.py +++ b/vllm/v1/spec_decode/mtp_proposer.py @@ -6,7 +6,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context -from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader.loader import get_model_loader, _process_weights_after_loading from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP from vllm.v1.sample.metadata import SamplingMetadata @@ -182,13 +182,27 @@ def load_model(self, target_model: nn.Module) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - self.model = DeepSeekMTP( - vllm_config=self.vllm_config).to(target_device) - - self.model.load_weights( - loader.get_all_weights( + #self.model = DeepSeekMTP( + # vllm_config=self.vllm_config).to(target_device) + + with target_device: + self.model = DeepSeekMTP( + vllm_config=self.vllm_config) + + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) + + _process_weights_after_loading( + self.model, self.vllm_config.speculative_config.draft_model_config, - self.model)) + target_device) + + # self.model.load_weights( + # loader.get_all_weights( + # self.vllm_config.speculative_config.draft_model_config, + # self.model)) @triton.jit diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e533fa781850..01e6c4ec12f8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1287,6 +1287,11 @@ def execute_model( sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() + + import torch.distributed as dist + if dist.get_rank() == 0: + print(f"valid_sampled_token_ids: {valid_sampled_token_ids}") + print(f"spec_token_ids: {spec_token_ids}") # Clear KVConnector state after all KVs are generated. if has_kv_transfer_group(): From e2f065c23374d3543e73991c2c40a5f99aa3dfdb Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Tue, 20 May 2025 22:05:34 +0000 Subject: [PATCH 03/19] fix bugs Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/config.py | 4 +-- vllm/engine/arg_utils.py | 2 +- vllm/v1/attention/backends/mla/common.py | 3 -- vllm/v1/spec_decode/mtp_proposer.py | 38 ++++++++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 30 +++++++++++-------- 5 files changed, 46 insertions(+), 31 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5ed027d306ee..a4a4f82fdcd1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2517,7 +2517,7 @@ def __post_init__(self): self.method = "mlp_speculator" elif (self.draft_model_config.hf_config.model_type == "deepseek_mtp"): - self.method = "mtp" + self.method = "deepseek_mtp" else: self.method = "draft_model" @@ -2738,7 +2738,7 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3") + return self.method in ("eagle", "eagle3", "deepseek_mtp") def __repr__(self) -> str: method = self.method diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ebec1aab9594..d79252477acd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1339,7 +1339,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: is_ngram_enabled = True elif speculative_method == "medusa": is_medusa_enabled = True - elif speculative_method in ("eagle", "eagle3"): + elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"): is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index dc927297ed8f..83e181116577 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -302,7 +302,6 @@ class MLACommonMetadata(Generic[D]): num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor - block_table: torch.Tensor slot_mapping: torch.Tensor # New for MLA (compared to FlashAttention) @@ -342,7 +341,6 @@ def __init__(self, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata - logger.info(f"self.metadata_cls: {self.metadata_cls}") self.runner = runner scheduler_config = runner.scheduler_config model_config = runner.model_config @@ -555,7 +553,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, return self.metadata_cls( num_actual_tokens=num_actual_tokens, query_start_loc=query_start_loc, - block_table=block_table, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py index 144780d3912c..be7cda07dea8 100644 --- a/vllm/v1/spec_decode/mtp_proposer.py +++ b/vllm/v1/spec_decode/mtp_proposer.py @@ -3,13 +3,17 @@ import torch.nn as nn import triton import triton.language as tl +from typing import Optional -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.attention.layer import Attention +from vllm.config import (VllmConfig, + get_layers_from_vllm_config, set_current_vllm_config) from vllm.forward_context import set_forward_context -from vllm.model_executor.model_loader.loader import get_model_loader, _process_weights_after_loading -from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.utils import set_default_torch_dtype, process_weights_after_loading from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata # FIXME(woosuk): The logic here is duplicated with the main sampling code. @@ -120,7 +124,7 @@ def propose( # [batch_size + 1] starting with 0 cu_num_tokens: torch.Tensor, # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, + block_table: Optional[torch.Tensor], sampling_metadata: SamplingMetadata, ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = target_token_ids.shape[0] @@ -137,6 +141,11 @@ def propose( query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() + + seq_lens = (target_positions[last_token_indices] + 1) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=cu_num_tokens, seq_lens=seq_lens) # FIXME: reorder_batch() needs to be called before build() # because fields of attn_metadata_builder needs to be updated. @@ -149,11 +158,13 @@ def propose( # scheduler_output=self.runner.scheduler_output, # ) - attn_metadata = self.runner.attn_metadata_builder.build( + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = self.runner.attn_metadata_builders[0].build( num_reqs=batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, common_prefix_len=0, + common_attn_metadata=common_attn_metadata, ) with set_forward_context(attn_metadata, self.vllm_config): @@ -172,7 +183,9 @@ def propose( def load_model(self, target_model: nn.Module) -> None: loader = get_model_loader(self.vllm_config.load_config) - + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + draft_model_config = \ self.vllm_config.speculative_config.draft_model_config # FIXME(lily): This does not handle with distributed inference. @@ -182,27 +195,26 @@ def load_model(self, target_model: nn.Module) -> None: with set_default_torch_dtype( draft_model_config.dtype), set_current_vllm_config( self.vllm_config): - #self.model = DeepSeekMTP( - # vllm_config=self.vllm_config).to(target_device) with target_device: self.model = DeepSeekMTP( vllm_config=self.vllm_config) + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = next(iter(draft_attn_layer_names)) self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) - _process_weights_after_loading( + process_weights_after_loading( self.model, self.vllm_config.speculative_config.draft_model_config, target_device) - # self.model.load_weights( - # loader.get_all_weights( - # self.vllm_config.speculative_config.draft_model_config, - # self.model)) @triton.jit diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f89702ce9172..b1e628ee012c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -159,13 +159,14 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, - self.device) # type: ignore - if self.speculative_config.method == "eagle3": - self.use_aux_hidden_state_outputs = True - elif self.speculative_config.method == "mtp": - self.drafter = MtpProposer(self.vllm_config, - self) # type: ignore + if self.speculative_config.method == "deepseek_mtp": + self.drafter = MtpProposer(self.vllm_config, + self) # type: ignore + else: + self.drafter = EagleProposer(self.vllm_config, + self.device) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( vllm_config=self.vllm_config, @@ -1337,9 +1338,7 @@ def execute_model( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) - elif (self.speculative_config.use_eagle() or - self.speculative_config.draft_model_config.hf_config.model_type \ - == "deepseek_mtp"): + elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, (EagleProposer, MtpProposer)) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] @@ -1361,6 +1360,12 @@ def execute_model( device=self.device) eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] + # NOTE: deepseek_mtp uses MLA which does not have `block_table` + if hasattr(eagle_attn_metadata, "block_table"): + block_table = eagle_attn_metadata.block_table + else: + block_table = None + if spec_decode_metadata is None: # input_ids can be None for multimodal models. target_token_ids = self.input_ids[:num_scheduled_tokens] @@ -1398,7 +1403,8 @@ def execute_model( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] - + + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1406,7 +1412,7 @@ def execute_model( target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, - block_table=eagle_attn_metadata.block_table, + block_table=block_table, sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() From fce4d58c6634cebccecd39f4eb778d7281ccac84 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Tue, 20 May 2025 22:46:22 +0000 Subject: [PATCH 04/19] fix pp Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/worker/gpu_model_runner.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b1e628ee012c..65aa47d55662 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1417,11 +1417,6 @@ def execute_model( ) spec_token_ids = draft_token_ids.tolist() - import torch.distributed as dist - if dist.get_rank() == 0: - print(f"valid_sampled_token_ids: {valid_sampled_token_ids}") - print(f"spec_token_ids: {spec_token_ids}") - # Clear KVConnector state after all KVs are generated. if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() From fade742de8cb6c64950555a488213798df4cc0fc Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Tue, 20 May 2025 22:58:17 +0000 Subject: [PATCH 05/19] fix format Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- requirements/test.txt | 22 +++++++- vllm/model_executor/models/deepseek_mtp.py | 1 - vllm/v1/spec_decode/eagle.py | 30 +---------- vllm/v1/spec_decode/mtp_proposer.py | 62 ++++++---------------- vllm/v1/spec_decode/utils.py | 27 ++++++++++ vllm/v1/worker/gpu_model_runner.py | 17 +++--- 6 files changed, 74 insertions(+), 85 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 89d477017342..df3770856022 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -27,6 +27,10 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -129,6 +133,11 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval +exceptiongroup==1.3.0 + # via + # anyio + # hypothesis + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -640,7 +649,6 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter - # torch # triton shellingham==1.5.4 # via typer @@ -700,8 +708,13 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers +toml==0.10.2 + # via datamodel-code-generator tomli==2.2.1 - # via schemathesis + # via + # black + # pytest + # schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -775,13 +788,18 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via + # anyio + # black + # exceptiongroup # huggingface-hub # librosa # mistral-common # mteb + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index f9a114ff7617..03ef7bed0edc 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -19,7 +19,6 @@ from .deepseek_v2 import (DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name) - from .interfaces import SupportsPP from .utils import maybe_prefix diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5b84bc1f5ec3..56ee6c67a53c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,9 +12,9 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata +from vllm.vllm.v1.spec_decode.utils import prepare_input_kernel logger = init_logger(__name__) @@ -388,30 +388,4 @@ def compute_probs_and_sample_next_token( greedy_token_ids, next_token_ids, ) - return next_token_ids, probs - - -@triton.jit -def prepare_input_kernel( - out_ptr, - cu_query_lens_ptr, - cu_num_tokens_ptr, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - # [start_pos, end_pos) - start_pos = tl.load(cu_num_tokens_ptr + pid) - end_pos = tl.load(cu_num_tokens_ptr + pid + 1) - num_tokens = end_pos - start_pos - - index_start = tl.load(cu_query_lens_ptr + pid) - - num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) - for i in tl.range(num_blocks): - offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - tl.store( - out_ptr + start_pos + offset, - index_start + offset, - mask=offset < num_tokens, - ) + return next_token_ids, probs \ No newline at end of file diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py index be7cda07dea8..65fa15a01119 100644 --- a/vllm/v1/spec_decode/mtp_proposer.py +++ b/vllm/v1/spec_decode/mtp_proposer.py @@ -1,19 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + import torch import torch.nn as nn -import triton -import triton.language as tl -from typing import Optional from vllm.attention.layer import Attention -from vllm.config import (VllmConfig, - get_layers_from_vllm_config, set_current_vllm_config) +from vllm.config import (VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config) from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import set_default_torch_dtype, process_weights_after_loading +from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, set_default_torch_dtype) from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP -from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.vllm.v1.spec_decode.utils import prepare_input_kernel # FIXME(woosuk): The logic here is duplicated with the main sampling code. @@ -141,9 +142,9 @@ def propose( query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() - + seq_lens = (target_positions[last_token_indices] + 1) - + common_attn_metadata = CommonAttentionMetadata( query_start_loc=cu_num_tokens, seq_lens=seq_lens) @@ -185,7 +186,7 @@ def load_model(self, target_model: nn.Module) -> None: loader = get_model_loader(self.vllm_config.load_config) target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - + draft_model_config = \ self.vllm_config.speculative_config.draft_model_config # FIXME(lily): This does not handle with distributed inference. @@ -197,47 +198,18 @@ def load_model(self, target_model: nn.Module) -> None: self.vllm_config): with target_device: - self.model = DeepSeekMTP( - vllm_config=self.vllm_config) - - draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) + self.model = DeepSeekMTP(vllm_config=self.vllm_config) + + draft_attn_layer_names = (get_layers_from_vllm_config( + self.vllm_config, Attention).keys() - target_attn_layer_names) assert len(draft_attn_layer_names) == 1 self.attn_layer_name = next(iter(draft_attn_layer_names)) self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, self.model)) - + process_weights_after_loading( - self.model, + self.model, self.vllm_config.speculative_config.draft_model_config, target_device) - - - -@triton.jit -def prepare_input_kernel( - out_ptr, - cu_query_lens_ptr, - cu_num_tokens_ptr, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - # [start_pos, end_pos) - start_pos = tl.load(cu_num_tokens_ptr + pid) - end_pos = tl.load(cu_num_tokens_ptr + pid + 1) - num_tokens = end_pos - start_pos - - index_start = tl.load(cu_query_lens_ptr + pid) - - num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) - for i in tl.range(num_blocks): - offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - tl.store( - out_ptr + start_pos + offset, - index_start + offset, - mask=offset < num_tokens, - ) diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index ce81a40ee3ae..c484f7b500be 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +from vllm.triton_utils import tl, triton from vllm.v1.worker.gpu_input_batch import InputBatch @@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: return False return True + + +@triton.jit +def prepare_input_kernel( + out_ptr, + cu_query_lens_ptr, + cu_num_tokens_ptr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + # [start_pos, end_pos) + start_pos = tl.load(cu_num_tokens_ptr + pid) + end_pos = tl.load(cu_num_tokens_ptr + pid + 1) + num_tokens = end_pos - start_pos + + index_start = tl.load(cu_query_lens_ptr + pid) + + num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) + for i in tl.range(num_blocks): + offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + tl.store( + out_ptr + start_pos + offset, + index_start + offset, + mask=offset < num_tokens, + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 65aa47d55662..9bc1fbf02df6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -151,9 +151,9 @@ def __init__( self.use_aux_hidden_state_outputs = False if self.speculative_config: self.use_spec_decode = True - + # NOTE(Jiayi): currently we put the entire draft model on - # the last PP rank. This is not ideal if there are many + # the last PP rank. This is not ideal if there are many # layers in the draft model. if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": @@ -161,10 +161,10 @@ def __init__( elif self.speculative_config.use_eagle(): if self.speculative_config.method == "deepseek_mtp": self.drafter = MtpProposer(self.vllm_config, - self) # type: ignore + self) # type: ignore else: - self.drafter = EagleProposer(self.vllm_config, - self.device) # type: ignore + self.drafter = EagleProposer( + self.vllm_config, self.device) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": @@ -1365,7 +1365,7 @@ def execute_model( block_table = eagle_attn_metadata.block_table else: block_table = None - + if spec_decode_metadata is None: # input_ids can be None for multimodal models. target_token_ids = self.input_ids[:num_scheduled_tokens] @@ -1403,8 +1403,7 @@ def execute_model( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] - - + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1416,7 +1415,7 @@ def execute_model( sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - + # Clear KVConnector state after all KVs are generated. if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() From 0e6ed110ecc932c2d0ca03cb2627c15edad1d86c Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Wed, 21 May 2025 01:47:32 +0000 Subject: [PATCH 06/19] fix format Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- requirements/test.txt | 20 -------------------- tests/conftest.py | 2 -- vllm/entrypoints/llm.py | 2 -- 3 files changed, 24 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index df3770856022..76e365a5c736 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -27,10 +27,6 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration -async-timeout==5.0.1 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -133,11 +129,6 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval -exceptiongroup==1.3.0 - # via - # anyio - # hypothesis - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -708,13 +699,7 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers -toml==0.10.2 - # via datamodel-code-generator tomli==2.2.1 - # via - # black - # pytest - # schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -788,18 +773,13 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via - # anyio - # black - # exceptiongroup # huggingface-hub # librosa # mistral-common # mteb - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 diff --git a/tests/conftest.py b/tests/conftest.py index 64a760f3dd59..19c2c6247129 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -768,8 +768,6 @@ def __init__( enforce_eager: Optional[bool] = False, **kwargs, ) -> None: - from vllm import envs - logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}") self.model = LLM( model=model_name, task=task, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 30186ed0a0cc..053ee55bb6a8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -246,8 +246,6 @@ def __init__( **kwargs, ) - from vllm import envs - logger.info(f"VLLM_USE_V1: {envs.VLLM_USE_V1}") # Create the Engine (autoselects V0 vs V1) self.llm_engine = LLMEngine.from_engine_args( engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) From a79b116e99c00e919594193194d94a475838e85e Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Wed, 21 May 2025 01:53:08 +0000 Subject: [PATCH 07/19] revert changes Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- requirements/test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/test.txt b/requirements/test.txt index 76e365a5c736..89d477017342 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -640,6 +640,7 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter + # torch # triton shellingham==1.5.4 # via typer @@ -700,6 +701,7 @@ tokenizers==0.21.1 # -r requirements/test.in # transformers tomli==2.2.1 + # via schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 From ff6a1b302c6b3a3f0a22ab836cad60e96872f98f Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Wed, 21 May 2025 02:21:36 +0000 Subject: [PATCH 08/19] fix format Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index a4a4f82fdcd1..cf5153fa8f14 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2251,7 +2251,7 @@ def __post_init__(self): SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator", - "draft_model"] + "draft_model", "deepseek_mtp"] SpeculativeAcceptanceMethod = Literal["rejection_sampler", "typical_acceptance_sampler"] From ca0d63d2f44d744bda7b35acb696e153dceaf1ad Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Wed, 21 May 2025 03:55:15 +0000 Subject: [PATCH 09/19] fix unit test Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 2 +- vllm/v1/spec_decode/mtp_proposer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 56ee6c67a53c..e5024c34f0fa 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -14,7 +14,7 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata -from vllm.vllm.v1.spec_decode.utils import prepare_input_kernel +from vllm.v1.spec_decode.utils import prepare_input_kernel logger = init_logger(__name__) diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py index 65fa15a01119..434a3d2bd6af 100644 --- a/vllm/v1/spec_decode/mtp_proposer.py +++ b/vllm/v1/spec_decode/mtp_proposer.py @@ -14,7 +14,7 @@ from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata -from vllm.vllm.v1.spec_decode.utils import prepare_input_kernel +from vllm.v1.spec_decode.utils import prepare_input_kernel # FIXME(woosuk): The logic here is duplicated with the main sampling code. From 3699a9874329bda21bad77332eabab52cd170e08 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Thu, 22 May 2025 20:03:51 +0000 Subject: [PATCH 10/19] address minor comments Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- tests/spec_decode/conftest.py | 2 +- vllm/engine/arg_utils.py | 2 -- vllm/v1/spec_decode/utils.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/spec_decode/conftest.py b/tests/spec_decode/conftest.py index eca20289aa46..1a20e2c135c2 100644 --- a/tests/spec_decode/conftest.py +++ b/tests/spec_decode/conftest.py @@ -8,4 +8,4 @@ def use_v0_only(monkeypatch): Since this module is V0 only, set VLLM_USE_V1=0 for all tests in the module. """ - monkeypatch.setenv('VLLM_USE_V1', '1') + monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d79252477acd..947769f51807 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -997,8 +997,6 @@ def create_engine_config( else: envs.set_vllm_use_v1(use_v1) - logger.info("use_v1: %s", use_v1) - # Set default arguments for V0 or V1 Engine. if use_v1: self._set_default_args_v1(usage_context) diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index c484f7b500be..b3b74527734d 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -42,4 +42,4 @@ def prepare_input_kernel( out_ptr + start_pos + offset, index_start + offset, mask=offset < num_tokens, - ) + ) \ No newline at end of file From c505dd70265b7667dfb979bd0f8b7cd6386d8249 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Thu, 22 May 2025 21:12:28 +0000 Subject: [PATCH 11/19] add cudagraph compatibility Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/mtp_proposer.py | 59 ++++++++++++++++++++++++++--- vllm/v1/worker/gpu_model_runner.py | 5 ++- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py index 434a3d2bd6af..58d4fe277756 100644 --- a/vllm/v1/spec_decode/mtp_proposer.py +++ b/vllm/v1/spec_decode/mtp_proposer.py @@ -5,8 +5,8 @@ import torch.nn as nn from vllm.attention.layer import Attention -from vllm.config import (VllmConfig, get_layers_from_vllm_config, - set_current_vllm_config) +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config, set_current_vllm_config) from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import ( @@ -66,6 +66,30 @@ def __init__( vllm_config.speculative_config.num_speculative_tokens) self.block_size = vllm_config.cache_config.block_size self.runner = runner + + self.max_num_tokens = self.runner.max_num_tokens + self.device = self.runner.device + self.dtype = self.runner.dtype + self.hidden_size = self.runner.hidden_size + + self.use_cuda_graph = (self.vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not self.vllm_config.model_config.enforce_eager) + self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # persistent buffers for cuda graph + self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + self.hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) @staticmethod def prepare_inputs( @@ -167,17 +191,29 @@ def propose( common_prefix_len=0, common_attn_metadata=common_attn_metadata, ) + + if self.use_cuda_graph and \ + num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + num_input_tokens = num_tokens + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions + self.hidden_states[:num_tokens] = target_hidden_states with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( - input_ids=input_ids, - positions=target_positions, - previous_hidden_states=target_hidden_states, + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], + previous_hidden_states=self.hidden_states[:num_input_tokens], ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) + # TODO: Currently, MTP module released by deepseek only has + # one layer. Adapt this code to support multiple layers once + # there's a multi-layer MTP module. assert self.num_speculative_tokens == 1 # [batch_size, 1] return draft_token_ids.view(-1, 1) @@ -213,3 +249,16 @@ def load_model(self, target_model: nn.Module) -> None: self.model, self.vllm_config.speculative_config.draft_model_config, target_device) + + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + ) -> None: + with set_forward_context(None, self.vllm_config, + num_tokens=num_tokens): + self.model( + input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], + previous_hidden_states=self.hidden_states[:num_tokens], + ) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9bc1fbf02df6..c85d90bfd101 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1720,8 +1720,9 @@ def _dummy_run( hidden_states = outputs if self.use_spec_decode and \ - self.speculative_config.method in ('eagle', 'eagle3'): - assert isinstance(self.drafter, EagleProposer) + self.speculative_config.method in ( + 'eagle', 'eagle3', "deepseek_mtp"): + assert isinstance(self.drafter, (EagleProposer, MtpProposer)) self.drafter.dummy_run(num_tokens) logit_indices = np.cumsum(num_scheduled_tokens) - 1 From 8b8c8baa7e182210d8dd2ab39b50d2fe7e4df889 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 03:15:38 +0000 Subject: [PATCH 12/19] unify eagle and mtp Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 81 +++++++++++++++++++++-------- vllm/v1/spec_decode/mtp_proposer.py | 12 ++--- vllm/v1/worker/gpu_model_runner.py | 12 ++--- 3 files changed, 68 insertions(+), 37 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e5024c34f0fa..683386872471 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,7 +12,8 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.flash_attn import (FlashAttentionMetadata, + CommonAttentionMetadata) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import prepare_input_kernel @@ -69,7 +70,7 @@ def __init__( 1, device=device, dtype=torch.int32) - + def propose( self, # [num_tokens] @@ -108,24 +109,53 @@ def propose( # FA requires seq_len to have dtype int32. seq_lens = (target_positions[last_token_indices] + 1).int() - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - max_seq_len = seq_lens.max().item() - max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_tokens, - max_query_len=max_num_tokens, - query_start_loc=cu_num_tokens, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=target_slot_mapping, - # TODO(woosuk): Support cascade attention. - use_cascade=False, - common_prefix_len=0, - cu_prefix_query_lens=None, - prefix_kv_lens=None, - suffix_kv_lens=None, - ) + if self.method in []: + # FIXME(woosuk): The below two ops cause synchronization. Optimize. + max_seq_len = seq_lens.max().item() + max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_tokens, + max_query_len=max_num_tokens, + query_start_loc=cu_num_tokens, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table, + slot_mapping=target_slot_mapping, + # TODO(woosuk): Support cascade attention. + use_cascade=False, + common_prefix_len=0, + cu_prefix_query_lens=None, + prefix_kv_lens=None, + suffix_kv_lens=None, + ) + elif self.method == "eagle3": + query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] + max_query_len = query_lens.max().item() + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=cu_num_tokens, seq_lens=seq_lens) + # FIXME: reorder_batch() needs to be called before build() + # because fields of attn_metadata_builder needs to be updated. + # However, currently reorder_batch() takes input_batch and + # scheduler_output as arguments, we should probably refactor + # the method to use new data structures which are independent + # from input_batch and scheduler_output. + # self.runner.attn_metadata_builder.reorder_batch( + # input_batch=self.runner.input_batch, + # scheduler_output=self.runner.scheduler_output, + # ) + + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = self.runner.attn_metadata_builders[0].build( + num_reqs=batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + raise ValueError(f"Unsupported method: {self.method}") + if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) @@ -152,6 +182,11 @@ def propose( # [batch_size, 1] return draft_token_ids.view(-1, 1) + # TODO: Currently, MTP module released by deepseek only has + # one layer. Adapt this code to support multiple layers once + # there's a multi-layer MTP module. + assert self.method != "deepseek_mtp" + # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -342,9 +377,9 @@ def dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): self.model( - input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], - hidden_states=self.hidden_states[:num_tokens], + self.input_ids[:num_tokens], + self.positions[:num_tokens], + self.hidden_states[:num_tokens], ) diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py index 58d4fe277756..43e6649db09c 100644 --- a/vllm/v1/spec_decode/mtp_proposer.py +++ b/vllm/v1/spec_decode/mtp_proposer.py @@ -203,9 +203,9 @@ def propose( with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - previous_hidden_states=self.hidden_states[:num_input_tokens], + self.input_ids[:num_input_tokens], + self.positions[:num_input_tokens], + self.hidden_states[:num_input_tokens], ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -258,7 +258,7 @@ def dummy_run( with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): self.model( - input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], - previous_hidden_states=self.hidden_states[:num_tokens], + self.input_ids[:num_tokens], + self.positions[:num_tokens], + self.hidden_states[:num_tokens], ) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c85d90bfd101..7c2f4e75a902 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -159,14 +159,10 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - if self.speculative_config.method == "deepseek_mtp": - self.drafter = MtpProposer(self.vllm_config, - self) # type: ignore - else: - self.drafter = EagleProposer( - self.vllm_config, self.device) # type: ignore - if self.speculative_config.method == "eagle3": - self.use_aux_hidden_state_outputs = True + self.drafter = EagleProposer( + self.vllm_config, self.device) # type: ignore + if self.speculative_config.method == "eagle3": + self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( vllm_config=self.vllm_config, From 5d0296541511fbee467d822616e50452d082d2e7 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 05:09:29 +0000 Subject: [PATCH 13/19] fix minor bug and remove mtp_proposer Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 27 +-- vllm/v1/spec_decode/mtp_proposer.py | 264 ---------------------------- vllm/v1/worker/gpu_model_runner.py | 7 +- 3 files changed, 20 insertions(+), 278 deletions(-) delete mode 100644 vllm/v1/spec_decode/mtp_proposer.py diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 19719e3db32a..0ac119d37ce0 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -26,11 +26,14 @@ def __init__( self, vllm_config: VllmConfig, device: torch.device, + runner, ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method + + self.runner = runner self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len @@ -107,7 +110,7 @@ def propose( # FA requires seq_len to have dtype int32. seq_lens = (target_positions[last_token_indices] + 1).int() - if self.method in []: + if self.method in ["eagle", "eagle3"]: # FIXME(woosuk): The below two ops cause synchronization. Optimize. max_seq_len = seq_lens.max().item() max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() @@ -126,7 +129,7 @@ def propose( prefix_kv_lens=None, suffix_kv_lens=None, ) - elif self.method == "eagle3": + elif self.method == "deepseek_mtp": query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() @@ -144,7 +147,7 @@ def propose( # ) # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_metadata_builders[0].build( + attn_metadata = self.runner.attn_metadata_builder.build( num_reqs=batch_size, num_actual_tokens=num_tokens, max_query_len=max_query_len, @@ -166,11 +169,15 @@ def propose( with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): - last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens], + ret_hidden_states = self.model( + self.input_ids[:num_input_tokens], + self.positions[:num_input_tokens], + self.hidden_states[:num_input_tokens], ) + if self.method == "deepseek_mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) @@ -250,9 +257,9 @@ def propose( self.vllm_config, num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( - input_ids=self.input_ids[:input_batch_size], - positions=self.positions[:input_batch_size], - hidden_states=self.hidden_states[:input_batch_size], + self.input_ids[:input_batch_size], + self.positions[:input_batch_size], + self.hidden_states[:input_batch_size], ) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], diff --git a/vllm/v1/spec_decode/mtp_proposer.py b/vllm/v1/spec_decode/mtp_proposer.py deleted file mode 100644 index 43e6649db09c..000000000000 --- a/vllm/v1/spec_decode/mtp_proposer.py +++ /dev/null @@ -1,264 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from typing import Optional - -import torch -import torch.nn as nn - -from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import set_forward_context -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading, set_default_torch_dtype) -from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP -from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import prepare_input_kernel - - -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. -def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> tuple[torch.Tensor, torch.Tensor]: - if sampling_metadata.all_greedy: - # For greedy requests, draft_probs is not used in rejection sampling. - # Therefore, we can just return the logits. - probs = logits - next_token_ids = logits.argmax(dim=-1) - return next_token_ids, probs - - is_greedy = sampling_metadata.temperature == -1 - temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) - logits.div_(temperature.view(-1, 1)) - probs = logits.softmax(dim=-1, dtype=torch.float32) - - # NOTE(woosuk): Currently, we ignore most of the sampling parameters in - # generating the draft tokens. We only use the temperature. While this - # could degrade the acceptance rate, it does not affect the distribution - # of the generated tokens after rejection sampling. - - # TODO(woosuk): Consider seeds. - q = torch.empty_like(probs) - q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) - if not sampling_metadata.all_random: - greedy_token_ids = probs.argmax(dim=-1) - next_token_ids = torch.where( - is_greedy, - greedy_token_ids, - next_token_ids, - ) - return next_token_ids, probs - - -class MtpProposer: - - def __init__( - self, - vllm_config: VllmConfig, - runner, - ): - self.vllm_config = vllm_config - self.num_speculative_tokens = ( - vllm_config.speculative_config.num_speculative_tokens) - self.block_size = vllm_config.cache_config.block_size - self.runner = runner - - self.max_num_tokens = self.runner.max_num_tokens - self.device = self.runner.device - self.dtype = self.runner.dtype - self.hidden_size = self.runner.hidden_size - - self.use_cuda_graph = (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed( - self.vllm_config.compilation_config.cudagraph_capture_sizes)) - - # persistent buffers for cuda graph - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) - - @staticmethod - def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, - # [batch_size] - num_rejected_tokens: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - num_tokens_per_req = query_len_per_req - num_rejected_tokens - - cu_num_tokens = torch.empty_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - cu_num_tokens[0] = 0 - - # FIXME(woosuk): Avoid synchronization. - num_tokens = cu_num_tokens[-1].item() - token_indices = torch.empty( - num_tokens, - dtype=torch.int32, - device=cu_num_tokens.device, - ) - - batch_size = num_rejected_tokens.shape[0] - BLOCK_SIZE = 1024 - prepare_input_kernel[(batch_size, )]( - token_indices, - cu_target_query_lens, - cu_num_tokens, - BLOCK_SIZE=BLOCK_SIZE, - ) - return cu_num_tokens, token_indices - - def propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> tuple[torch.Tensor, torch.Tensor]: - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 - - input_ids = torch.empty_like(target_token_ids) - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - input_ids[:-1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - input_ids[last_token_indices] = next_token_ids - - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - seq_lens = (target_positions[last_token_indices] + 1) - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=cu_num_tokens, seq_lens=seq_lens) - - # FIXME: reorder_batch() needs to be called before build() - # because fields of attn_metadata_builder needs to be updated. - # However, currently reorder_batch() takes input_batch and - # scheduler_output as arguments, we should probably refactor - # the method to use new data structures which are independent - # from input_batch and scheduler_output. - # self.runner.attn_metadata_builder.reorder_batch( - # input_batch=self.runner.input_batch, - # scheduler_output=self.runner.scheduler_output, - # ) - - # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_metadata_builders[0].build( - num_reqs=batch_size, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) - - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - num_input_tokens = num_tokens - # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions - self.hidden_states[:num_tokens] = target_hidden_states - - with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( - self.input_ids[:num_input_tokens], - self.positions[:num_input_tokens], - self.hidden_states[:num_input_tokens], - ) - sample_hidden_states = hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids = logits.argmax(dim=-1) - - # TODO: Currently, MTP module released by deepseek only has - # one layer. Adapt this code to support multiple layers once - # there's a multi-layer MTP module. - assert self.num_speculative_tokens == 1 - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - def load_model(self, target_model: nn.Module) -> None: - loader = get_model_loader(self.vllm_config.load_config) - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) - - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - # FIXME(lily): This does not handle with distributed inference. - target_device = self.vllm_config.device_config.device - # We need to set the vllm_config here to register attention - # layers in the forward context. - with set_default_torch_dtype( - draft_model_config.dtype), set_current_vllm_config( - self.vllm_config): - - with target_device: - self.model = DeepSeekMTP(vllm_config=self.vllm_config) - - draft_attn_layer_names = (get_layers_from_vllm_config( - self.vllm_config, Attention).keys() - target_attn_layer_names) - assert len(draft_attn_layer_names) == 1 - self.attn_layer_name = next(iter(draft_attn_layer_names)) - self.model.load_weights( - loader.get_all_weights( - self.vllm_config.speculative_config.draft_model_config, - self.model)) - - process_weights_after_loading( - self.model, - self.vllm_config.speculative_config.draft_model_config, - target_device) - - @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): - self.model( - self.input_ids[:num_tokens], - self.positions[:num_tokens], - self.hidden_states[:num_tokens], - ) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 364125a625b2..518f8f74c125 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -48,7 +48,6 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.spec_decode.mtp_proposer import MtpProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache @@ -198,7 +197,7 @@ def __init__( self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): self.drafter = EagleProposer( - self.vllm_config, self.device) # type: ignore + self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": @@ -1347,7 +1346,7 @@ def execute_model( sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle(): - assert isinstance(self.drafter, (EagleProposer, MtpProposer)) + assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] for i, token_ids in enumerate(valid_sampled_token_ids): @@ -1733,7 +1732,7 @@ def _dummy_run( if self.use_spec_decode and \ self.speculative_config.method in ( 'eagle', 'eagle3', "deepseek_mtp"): - assert isinstance(self.drafter, (EagleProposer, MtpProposer)) + assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) logit_indices = np.cumsum(num_scheduled_tokens) - 1 From 56e08d78b40f632999149e98e2af42814b402558 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 05:12:18 +0000 Subject: [PATCH 14/19] add blank line Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 0ac119d37ce0..caae518d4ea8 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -408,4 +408,4 @@ def compute_probs_and_sample_next_token( greedy_token_ids, next_token_ids, ) - return next_token_ids, probs \ No newline at end of file + return next_token_ids, probs From a76facb1e25fbeb6475a33bb8068b4efd1dcb262 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 05:25:29 +0000 Subject: [PATCH 15/19] fix format Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 17 +++++++++-------- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index caae518d4ea8..cbd45ace6c83 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -10,8 +10,8 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.v1.attention.backends.flash_attn import (FlashAttentionMetadata, - CommonAttentionMetadata) +from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, + FlashAttentionMetadata) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import prepare_input_kernel @@ -32,7 +32,7 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method - + self.runner = runner self.dtype = vllm_config.model_config.dtype @@ -71,7 +71,7 @@ def __init__( 1, device=device, dtype=torch.int32) - + def propose( self, # [num_tokens] @@ -113,7 +113,8 @@ def propose( if self.method in ["eagle", "eagle3"]: # FIXME(woosuk): The below two ops cause synchronization. Optimize. max_seq_len = seq_lens.max().item() - max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() + max_num_tokens = (cu_num_tokens[1:] - + cu_num_tokens[:-1]).max().item() attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_tokens, max_query_len=max_num_tokens, @@ -134,7 +135,7 @@ def propose( max_query_len = query_lens.max().item() common_attn_metadata = CommonAttentionMetadata( - query_start_loc=cu_num_tokens, seq_lens=seq_lens) + query_start_loc=cu_num_tokens, seq_lens=seq_lens) # FIXME: reorder_batch() needs to be called before build() # because fields of attn_metadata_builder needs to be updated. # However, currently reorder_batch() takes input_batch and @@ -156,7 +157,7 @@ def propose( ) else: raise ValueError(f"Unsupported method: {self.method}") - + if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) @@ -191,7 +192,7 @@ def propose( # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. assert self.method != "deepseek_mtp" - + # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 518f8f74c125..eff6f5306d55 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -196,8 +196,8 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer( - self.vllm_config, self.device, self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, + self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": From 1271a571c1fa48595cdae2765e378cd546f28921 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 05:46:50 +0000 Subject: [PATCH 16/19] fix minor issues Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 24 +++++++++--------------- vllm/v1/spec_decode/utils.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 7 ++----- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index cbd45ace6c83..8f8542ec0ab7 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -13,7 +13,7 @@ from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata, FlashAttentionMetadata) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import prepare_input_kernel +from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel logger = init_logger(__name__) @@ -113,7 +113,7 @@ def propose( if self.method in ["eagle", "eagle3"]: # FIXME(woosuk): The below two ops cause synchronization. Optimize. max_seq_len = seq_lens.max().item() - max_num_tokens = (cu_num_tokens[1:] - + max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_tokens, @@ -133,19 +133,9 @@ def propose( elif self.method == "deepseek_mtp": query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] max_query_len = query_lens.max().item() - + common_attn_metadata = CommonAttentionMetadata( query_start_loc=cu_num_tokens, seq_lens=seq_lens) - # FIXME: reorder_batch() needs to be called before build() - # because fields of attn_metadata_builder needs to be updated. - # However, currently reorder_batch() takes input_batch and - # scheduler_output as arguments, we should probably refactor - # the method to use new data structures which are independent - # from input_batch and scheduler_output. - # self.runner.attn_metadata_builder.reorder_batch( - # input_batch=self.runner.input_batch, - # scheduler_output=self.runner.scheduler_output, - # ) # FIXME: need to consider multiple kv_cache_groups attn_metadata = self.runner.attn_metadata_builder.build( @@ -191,7 +181,11 @@ def propose( # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - assert self.method != "deepseek_mtp" + if self.method == "deepseek_mtp": + logger.warning( + "All Deepseek MTP models only have one layer. " \ + "Might need to change code to support multiple layers." + ) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -311,7 +305,7 @@ def prepare_inputs( batch_size = num_rejected_tokens.shape[0] BLOCK_SIZE = 1024 - prepare_input_kernel[(batch_size, )]( + prepare_eagle_input_kernel[(batch_size, )]( token_indices, cu_target_query_lens, cu_num_tokens, diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index b3b74527734d..334258e7f87a 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -20,7 +20,7 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: @triton.jit -def prepare_input_kernel( +def prepare_eagle_input_kernel( out_ptr, cu_query_lens_ptr, cu_num_tokens_ptr, @@ -42,4 +42,4 @@ def prepare_input_kernel( out_ptr + start_pos + offset, index_start + offset, mask=offset < num_tokens, - ) \ No newline at end of file + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eff6f5306d55..6cdcc3152dc1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -196,7 +196,7 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, + self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True @@ -1304,7 +1304,6 @@ def execute_model( sampled_token_ids = sampler_output.sampled_token_ids max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: - # GPU tensor to CPU list? sync point? # No spec decode tokens. valid_sampled_token_ids = sampled_token_ids.tolist() else: @@ -1729,9 +1728,7 @@ def _dummy_run( else: hidden_states = outputs - if self.use_spec_decode and \ - self.speculative_config.method in ( - 'eagle', 'eagle3', "deepseek_mtp"): + if self.use_spec_decode and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) From 485332a04f5e0a6f268162b758fd588bc56677c3 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 06:24:44 +0000 Subject: [PATCH 17/19] fix warning Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/config.py | 6 ++++++ vllm/v1/spec_decode/eagle.py | 5 ----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 1c6a1eb2a9e3..5e65db9ef767 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2522,6 +2522,12 @@ def __post_init__(self): elif (self.draft_model_config.hf_config.model_type == "deepseek_mtp"): self.method = "deepseek_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) else: self.method = "draft_model" diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 8f8542ec0ab7..7d91c6c35fb2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -181,11 +181,6 @@ def propose( # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - if self.method == "deepseek_mtp": - logger.warning( - "All Deepseek MTP models only have one layer. " \ - "Might need to change code to support multiple layers." - ) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] From e02f26cc46cb4291274b48b9bd3b68204aca1bd1 Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 12:16:04 +0000 Subject: [PATCH 18/19] fix unit test Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 7d91c6c35fb2..8a545801711e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -26,7 +26,7 @@ def __init__( self, vllm_config: VllmConfig, device: torch.device, - runner, + runner = None, ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config @@ -137,6 +137,8 @@ def propose( common_attn_metadata = CommonAttentionMetadata( query_start_loc=cu_num_tokens, seq_lens=seq_lens) + assert self.runner is not None + # FIXME: need to consider multiple kv_cache_groups attn_metadata = self.runner.attn_metadata_builder.build( num_reqs=batch_size, From a9c890ca2d37e595651fa94db9af731a0397389f Mon Sep 17 00:00:00 2001 From: YaoJiayi <120040070@link.cuhk.edu.cn> Date: Fri, 23 May 2025 12:23:34 +0000 Subject: [PATCH 19/19] format fix Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn> --- vllm/v1/spec_decode/eagle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 8a545801711e..3926a86ee591 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -26,7 +26,7 @@ def __init__( self, vllm_config: VllmConfig, device: torch.device, - runner = None, + runner=None, ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config @@ -138,7 +138,7 @@ def propose( query_start_loc=cu_num_tokens, seq_lens=seq_lens) assert self.runner is not None - + # FIXME: need to consider multiple kv_cache_groups attn_metadata = self.runner.attn_metadata_builder.build( num_reqs=batch_size,