From 75f90e7e78cf47aaca5cdba67b7af644acba146d Mon Sep 17 00:00:00 2001 From: skylee-01 <497627264@qq.com> Date: Wed, 9 Jul 2025 09:41:27 +0800 Subject: [PATCH 1/8] [Model] vllm v1 support mlp_speculator Signed-off-by: skylee-01 <497627264@qq.com> --- vllm/model_executor/models/mlp_speculator.py | 12 +++- vllm/v1/spec_decode/mlp_speculator.py | 63 ++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 36 +++++++++++ 3 files changed, 108 insertions(+), 3 deletions(-) create mode 100644 vllm/v1/spec_decode/mlp_speculator.py diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index c6a97388dc18..732d43ecb0c1 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -70,7 +70,12 @@ class MLPSpeculator(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() - config = vllm_config.model_config.hf_config + if hasattr(vllm_config, 'speculative_config'): + config = vllm_config.speculative_config.draft_model_config.hf_config + self.sampling_metadata_is_required = False + else: + config = vllm_config.model_config.hf_config + self.sampling_metadata_is_required = True self.n_predict = config.n_predict self.vocab_size = config.vocab_size self.emb_dim = config.emb_dim @@ -182,8 +187,9 @@ def generate_proposals( # TODO: not yet supporting top_k_tokens_per_head states = states.flatten(0, 1) - logits = self.logits_processor(self.head[head_index], states, - sampling_metadata) + logits = self.logits_processor( + self.head[head_index], states, sampling_metadata + if self.sampling_metadata_is_required else None) output = self.sampler(logits, sampling_metadata) last_tokens = output.sampled_token_ids diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py new file mode 100644 index 000000000000..2bc03f62f1e3 --- /dev/null +++ b/vllm/v1/spec_decode/mlp_speculator.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.v1.sample.metadata import SamplingMetadata + +# Initialize logger +logger = init_logger(__name__) + + +class MLPSpeculatorProposer: + """ + MLPSpeculator proposer class for generating token sequences + """ + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + # Save config parameters + self.vllm_config = vllm_config + self.device = device + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.hidden_size = vllm_config.speculative_config.\ + draft_model_config.get_hidden_size( + ) + self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + self.dtype = vllm_config.model_config.dtype + + def propose( + self, + input_ids: torch.Tensor, + previous_hidden_states: torch.Tensor, + num_predict_tokens: int, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + # Generate blocks and compute logits + draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata) + draft_tokens = list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens]))) + return draft_tokens + + def load_model(self, target_model: nn.Module) -> None: + self.model = 
get_model(vllm_config=self.vllm_config, + model_config=self.vllm_config. + speculative_config.draft_model_config) + + @torch.inference_mode() + def dummy_run(self, num_tokens: int) -> None: + input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device) + hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size), + dtype=self.dtype, + device=self.device) + num_predict_tokens = self.num_speculative_tokens + with set_forward_context(None, self.vllm_config, + num_tokens=num_tokens): + self.model.generate_proposals(input_ids, hidden_states, num_predict_tokens, None) \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 670e653929ce..00e446975ded 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -64,6 +64,7 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.mlp_speculator import MLPSpeculatorProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -188,6 +189,9 @@ def __init__( self.drafter = MedusaProposer( vllm_config=self.vllm_config, device=self.device) # type: ignore + elif self.speculative_config.method == "mlp_speculator": + self.drafter = MLPSpeculatorProposer(self.vllm_config, + self.device) else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}") @@ -1638,6 +1642,38 @@ def propose_draft_token_ids( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) + elif self.speculative_config.method == "mlp_speculator": + assert isinstance(self.drafter, MLPSpeculatorProposer) + + is_sample_match = sample_hidden_states.shape[0] == len( + sampled_token_ids) + # Get last token from each sequence + draft_input_ids = torch.tensor( + sampled_token_ids[0] if is_sample_match else + [tokens[-1] for tokens in sampled_token_ids], + device=sample_hidden_states.device) + + if is_sample_match: + # Calculate indices for hidden states + indices = [] + offset = 0 + for num_draft, tokens in zip( + spec_decode_metadata.num_draft_tokens, + sampled_token_ids): + indices.append(offset + len(tokens) - 1) + offset += num_draft + 1 + indices = torch.tensor(indices, device=self.device) + hidden_states = sample_hidden_states[indices] + else: + hidden_states = sample_hidden_states + + spec_token_ids = self.drafter.propose( + input_ids=draft_input_ids, + previous_hidden_states=hidden_states, + num_predict_tokens=self.vllm_config.speculative_config. + num_speculative_tokens, + sampling_metadata=sampling_metadata, + ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. 
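Note on PATCH 1/8: the new MLPSpeculatorProposer delegates drafting to MLPSpeculator.generate_proposals, which runs one small MLP head per speculative position; each head is conditioned on the previously drafted token and a running hidden state carried over from the target model. The sketch below is a framework-free illustration of that recurrence only, not vLLM code: the layer names, the concat-plus-ReLU combination, and greedy argmax sampling are simplifying assumptions (the real module uses layernorms, scaled/tied embeddings, and vLLM's sampler driven by SamplingMetadata).

import torch
import torch.nn as nn

class TinyMLPSpeculator(nn.Module):
    """Illustrative stand-in for the per-head recurrence in generate_proposals."""

    def __init__(self, vocab_size: int, emb_dim: int, n_predict: int):
        super().__init__()
        self.n_predict = n_predict
        # One embedding / projection / LM head per speculative position.
        self.emb = nn.ModuleList(
            nn.Embedding(vocab_size, emb_dim) for _ in range(n_predict))
        self.proj = nn.ModuleList(
            nn.Linear(emb_dim * 2, emb_dim) for _ in range(n_predict))
        self.head = nn.ModuleList(
            nn.Linear(emb_dim, vocab_size) for _ in range(n_predict))

    @torch.inference_mode()
    def generate_proposals(self, input_ids: torch.Tensor,
                           prev_hidden: torch.Tensor) -> torch.Tensor:
        # input_ids: [batch] last accepted token per sequence
        # prev_hidden: [batch, emb_dim] hidden state from the target model
        last_tokens, states, drafts = input_ids, prev_hidden, []
        for i in range(self.n_predict):
            z = self.emb[i](last_tokens)                  # [batch, emb_dim]
            states = torch.relu(
                self.proj[i](torch.cat([z, states], dim=-1)))
            logits = self.head[i](states)                 # [batch, vocab_size]
            last_tokens = logits.argmax(dim=-1)           # greedy; the real code samples
            drafts.append(last_tokens)
        return torch.stack(drafts, dim=1)                 # [batch, n_predict]

spec = TinyMLPSpeculator(vocab_size=128, emb_dim=16, n_predict=3)
drafts = spec.generate_proposals(torch.tensor([5, 7]), torch.zeros(2, 16))
assert drafts.shape == (2, 3)  # three draft tokens for each of the two sequences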
From 3b6292b83f15d3c56d3013a461456f60a7122e02 Mon Sep 17 00:00:00 2001 From: skylee-01 <497627264@qq.com> Date: Wed, 9 Jul 2025 10:45:34 +0800 Subject: [PATCH 2/8] Optimize code format Signed-off-by: skylee-01 <497627264@qq.com> --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 00e446975ded..ab5efc5b00fc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -177,6 +177,7 @@ def __init__( # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many # layers in the draft model. + self.drafter: Union[NgramProposer, EagleProposer, MedusaProposer, MLPSpeculatorProposer, None] = None if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) From 5fa1088c3ba14ffd387c06aa40d6be15df6a2f0c Mon Sep 17 00:00:00 2001 From: skylee-01 <497627264@qq.com> Date: Wed, 9 Jul 2025 12:43:14 +0800 Subject: [PATCH 3/8] Optimize code format Signed-off-by: skylee-01 <497627264@qq.com> --- vllm/v1/worker/gpu_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ab5efc5b00fc..00e446975ded 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -177,7 +177,6 @@ def __init__( # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many # layers in the draft model. - self.drafter: Union[NgramProposer, EagleProposer, MedusaProposer, MLPSpeculatorProposer, None] = None if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) From 7c6ec22c851f3ec2e81228d5421e5c4ea6a312dd Mon Sep 17 00:00:00 2001 From: skylee-01 <497627264@qq.com> Date: Wed, 9 Jul 2025 16:03:56 +0800 Subject: [PATCH 4/8] fix bug Signed-off-by: skylee-01 <497627264@qq.com> --- vllm/v1/spec_decode/mlp_speculator.py | 6 +++--- vllm/v1/worker/gpu_model_runner.py | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py index 2bc03f62f1e3..e6e7d78fd08a 100644 --- a/vllm/v1/spec_decode/mlp_speculator.py +++ b/vllm/v1/spec_decode/mlp_speculator.py @@ -31,7 +31,8 @@ def __init__( self.hidden_size = vllm_config.speculative_config.\ draft_model_config.get_hidden_size( ) - self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens + self.num_speculative_tokens = vllm_config.speculative_config.\ + num_speculative_tokens self.dtype = vllm_config.model_config.dtype def propose( @@ -43,8 +44,7 @@ def propose( ) -> torch.Tensor: # Generate blocks and compute logits draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata) - draft_tokens = list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens]))) - return draft_tokens + return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens]))) def load_model(self, target_model: nn.Module) -> None: self.model = get_model(vllm_config=self.vllm_config, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 00e446975ded..4abeefcf5c8f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ 
b/vllm/v1/worker/gpu_model_runner.py @@ -1644,16 +1644,13 @@ def propose_draft_token_ids( ) elif self.speculative_config.method == "mlp_speculator": assert isinstance(self.drafter, MLPSpeculatorProposer) - is_sample_match = sample_hidden_states.shape[0] == len( sampled_token_ids) # Get last token from each sequence draft_input_ids = torch.tensor( - sampled_token_ids[0] if is_sample_match else [tokens[-1] for tokens in sampled_token_ids], device=sample_hidden_states.device) - - if is_sample_match: + if not is_sample_match: # Calculate indices for hidden states indices = [] offset = 0 @@ -1666,7 +1663,6 @@ def propose_draft_token_ids( hidden_states = sample_hidden_states[indices] else: hidden_states = sample_hidden_states - spec_token_ids = self.drafter.propose( input_ids=draft_input_ids, previous_hidden_states=hidden_states, From 1d69184b5632d3e6a0f0987e1f65ae2977f18327 Mon Sep 17 00:00:00 2001 From: skylee-01 <497627264@qq.com> Date: Wed, 9 Jul 2025 16:10:21 +0800 Subject: [PATCH 5/8] Optimize code format Signed-off-by: skylee-01 <497627264@qq.com> --- vllm/v1/spec_decode/mlp_speculator.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py index e6e7d78fd08a..f1ac16075747 100644 --- a/vllm/v1/spec_decode/mlp_speculator.py +++ b/vllm/v1/spec_decode/mlp_speculator.py @@ -28,11 +28,10 @@ def __init__( self.vllm_config = vllm_config self.device = device self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs - self.hidden_size = vllm_config.speculative_config.\ - draft_model_config.get_hidden_size( - ) - self.num_speculative_tokens = vllm_config.speculative_config.\ - num_speculative_tokens + self.hidden_size = (vllm_config.speculative_config. + draft_model_config.get_hidden_size()) + self.num_speculative_tokens = (vllm_config.speculative_config. 
+ num_speculative_tokens) self.dtype = vllm_config.model_config.dtype def propose( @@ -41,7 +40,7 @@ def propose( previous_hidden_states: torch.Tensor, num_predict_tokens: int, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> list[list[int]]: # Generate blocks and compute logits draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata) return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens]))) From c39822cd91890840f6252e1160e7c854c9b7ef3d Mon Sep 17 00:00:00 2001 From: skylee-01 <497627264@qq.com> Date: Wed, 9 Jul 2025 16:18:13 +0800 Subject: [PATCH 6/8] Optimize code format Signed-off-by: skylee-01 <497627264@qq.com> --- vllm/model_executor/models/mlp_speculator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 732d43ecb0c1..8ef2ade57b62 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -187,9 +187,11 @@ def generate_proposals( # TODO: not yet supporting top_k_tokens_per_head states = states.flatten(0, 1) - logits = self.logits_processor( - self.head[head_index], states, sampling_metadata - if self.sampling_metadata_is_required else None) + if self.logits_processor: + logits = self.logits_processor(self.head[head_index], states, + sampling_metadata) + else: + logits = self.head[head_index](states) output = self.sampler(logits, sampling_metadata) last_tokens = output.sampled_token_ids From 1318345013bea2acee43171ffdaa9ec9aecfc171 Mon Sep 17 00:00:00 2001 From: lisiqi23 Date: Sun, 20 Jul 2025 21:30:38 -0700 Subject: [PATCH 7/8] Update Signed-off-by: lisiqi23 --- tests/models/registry.py | 6 ++---- tests/models/test_registry.py | 4 +--- vllm/model_executor/models/mlp_speculator.py | 5 +++-- vllm/model_executor/models/registry.py | 4 +--- vllm/v1/spec_decode/mlp_speculator.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 6 files changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 19725acd6c45..4ca4e2c7f915 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -463,10 +463,8 @@ def check_available_online( _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "MedusaModel": _HfExamplesInfo("JackFram/llama-68m", speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 - # Temporarily disabled. - # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. - # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", - # speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 + "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", + speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 trust_remote_code=True), diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 1ce90070c5c8..d1ad1aefd2f0 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -75,9 +75,7 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): @pytest.mark.parametrize( "model_arch,is_pp,init_cuda", [ - # TODO(woosuk): Re-enable this once the MLP Speculator is supported - # in V1. 
- # ("MLPSpeculatorPreTrainedModel", False, False), + ("MLPSpeculatorPreTrainedModel", False, False), ("DeepseekV2ForCausalLM", True, False), ("Qwen2VLForConditionalGeneration", True, True), ]) diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 8ef2ade57b62..83af82f8022f 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -188,8 +188,9 @@ def generate_proposals( states = states.flatten(0, 1) if self.logits_processor: - logits = self.logits_processor(self.head[head_index], states, - sampling_metadata) + logits = self.logits_processor( + self.head[head_index], states, sampling_metadata + if self.sampling_metadata_is_required else None) else: logits = self.head[head_index](states) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a85e8b0e7b1b..cdc0dc3a4ad8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -247,9 +247,7 @@ "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), - # Temporarily disabled. - # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. - # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } _TRANSFORMERS_MODELS = { diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py index f1ac16075747..80b74eaa3b96 100644 --- a/vllm/v1/spec_decode/mlp_speculator.py +++ b/vllm/v1/spec_decode/mlp_speculator.py @@ -42,7 +42,7 @@ def propose( sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # Generate blocks and compute logits - draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens,sampling_metadata) + draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens, sampling_metadata) return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens]))) def load_model(self, target_model: nn.Module) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4abeefcf5c8f..4ced180b9b3d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -190,7 +190,7 @@ def __init__( vllm_config=self.vllm_config, device=self.device) # type: ignore elif self.speculative_config.method == "mlp_speculator": - self.drafter = MLPSpeculatorProposer(self.vllm_config, + self.drafter = MLPSpeculatorProposer(self.vllm_config, # type: ignore self.device) else: raise ValueError("Unknown speculative decoding method: " From 356be5e65cbab6c4bf61316828046b05696ec2ba Mon Sep 17 00:00:00 2001 From: skylee-01 <497627264@qq.com> Date: Mon, 21 Jul 2025 12:56:50 +0800 Subject: [PATCH 8/8] Optimize code format Signed-off-by: skylee-01 <497627264@qq.com> --- tests/models/test_registry.py | 12 +++++------- vllm/v1/spec_decode/mlp_speculator.py | 24 +++++++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index d1ad1aefd2f0..01b2260abe8c 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -72,13 +72,11 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): @create_new_process_for_each_test() 
-@pytest.mark.parametrize( - "model_arch,is_pp,init_cuda", - [ - ("MLPSpeculatorPreTrainedModel", False, False), - ("DeepseekV2ForCausalLM", True, False), - ("Qwen2VLForConditionalGeneration", True, True), - ]) +@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ + ("MLPSpeculatorPreTrainedModel", False, False), + ("DeepseekV2ForCausalLM", True, False), + ("Qwen2VLForConditionalGeneration", True, True), +]) def test_registry_is_pp(model_arch, is_pp, init_cuda): assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py index 80b74eaa3b96..76caed9309e9 100644 --- a/vllm/v1/spec_decode/mlp_speculator.py +++ b/vllm/v1/spec_decode/mlp_speculator.py @@ -28,10 +28,10 @@ def __init__( self.vllm_config = vllm_config self.device = device self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs - self.hidden_size = (vllm_config.speculative_config. - draft_model_config.get_hidden_size()) - self.num_speculative_tokens = (vllm_config.speculative_config. - num_speculative_tokens) + self.hidden_size = (vllm_config.speculative_config.draft_model_config. + get_hidden_size()) + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens) self.dtype = vllm_config.model_config.dtype def propose( @@ -42,8 +42,13 @@ def propose( sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # Generate blocks and compute logits - draft_tokens = self.model.generate_proposals(input_ids, previous_hidden_states, num_predict_tokens, sampling_metadata) - return list(map(lambda x: x[0], zip(*[i.sampled_token_ids.tolist() for i in draft_tokens]))) + draft_tokens = self.model.generate_proposals(input_ids, + previous_hidden_states, + num_predict_tokens, + sampling_metadata) + return list( + map(lambda x: x[0], + zip(*[i.sampled_token_ids.tolist() for i in draft_tokens]))) def load_model(self, target_model: nn.Module) -> None: self.model = get_model(vllm_config=self.vllm_config, @@ -54,9 +59,10 @@ def load_model(self, target_model: nn.Module) -> None: def dummy_run(self, num_tokens: int) -> None: input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device) hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size), - dtype=self.dtype, - device=self.device) + dtype=self.dtype, + device=self.device) num_predict_tokens = self.num_speculative_tokens with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): - self.model.generate_proposals(input_ids, hidden_states, num_predict_tokens, None) \ No newline at end of file + self.model.generate_proposals(input_ids, hidden_states, + num_predict_tokens, None) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4ced180b9b3d..2fb13841b395 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -190,8 +190,9 @@ def __init__( vllm_config=self.vllm_config, device=self.device) # type: ignore elif self.speculative_config.method == "mlp_speculator": - self.drafter = MLPSpeculatorProposer(self.vllm_config, # type: ignore - self.device) + self.drafter = MLPSpeculatorProposer( + self.vllm_config, # type: ignore + self.device) else: raise ValueError("Unknown speculative decoding method: " f"{self.speculative_config.method}")
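Note: after PATCH 4/8 the mlp_speculator branch of propose_draft_token_ids feeds the drafter one last accepted token per request and, when draft tokens were already in flight, picks the matching row of sample_hidden_states by stepping over num_draft + 1 rows per request. The snippet below is a self-contained sketch of that selection with toy values; the helper name select_draft_inputs is hypothetical and not part of the patch.

import torch

def select_draft_inputs(sample_hidden_states: torch.Tensor,
                        sampled_token_ids: list[list[int]],
                        num_draft_tokens: list[int]):
    # One last accepted token per request becomes the speculator input.
    draft_input_ids = torch.tensor(
        [tokens[-1] for tokens in sampled_token_ids],
        device=sample_hidden_states.device)

    if sample_hidden_states.shape[0] == len(sampled_token_ids):
        # No draft tokens were scheduled: one sampled row per request already.
        hidden_states = sample_hidden_states
    else:
        # Spec decode was active: each request produced num_draft + 1 rows,
        # so keep the row that corresponds to its last accepted token.
        indices, offset = [], 0
        for num_draft, tokens in zip(num_draft_tokens, sampled_token_ids):
            indices.append(offset + len(tokens) - 1)
            offset += num_draft + 1
        hidden_states = sample_hidden_states[torch.tensor(
            indices, device=sample_hidden_states.device)]
    return draft_input_ids, hidden_states

# Two requests with 2 draft tokens each: request 0 accepted both (3 sampled
# tokens in total), request 1 accepted one (2 sampled tokens in total).
hs = torch.arange(6, dtype=torch.float32).reshape(6, 1)
ids, states = select_draft_inputs(hs, [[11, 12, 13], [21, 22]], [2, 2])
assert ids.tolist() == [13, 22]
assert states.squeeze(-1).tolist() == [2.0, 4.0]  # rows 2 and 4 of hs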