diff --git a/tests/models/registry.py b/tests/models/registry.py
index 19725acd6c45..4ca4e2c7f915 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -463,10 +463,8 @@ def check_available_online(
 _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
     "MedusaModel": _HfExamplesInfo("JackFram/llama-68m",
                                    speculative_model="abhigoyal/vllm-medusa-llama-68m-random"),  # noqa: E501
-    # Temporarily disabled.
-    # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
-    # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
-    #                                                 speculative_model="ibm-ai-platform/llama-160m-accelerator"),  # noqa: E501
+    "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m",
+                                                    speculative_model="ibm-ai-platform/llama-160m-accelerator"),  # noqa: E501
     "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
                                         speculative_model="luccafong/deepseek_mtp_draft_random",  # noqa: E501
                                         trust_remote_code=True),
diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py
index 1ce90070c5c8..01b2260abe8c 100644
--- a/tests/models/test_registry.py
+++ b/tests/models/test_registry.py
@@ -72,15 +72,11 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
 
 
 @create_new_process_for_each_test()
-@pytest.mark.parametrize(
-    "model_arch,is_pp,init_cuda",
-    [
-        # TODO(woosuk): Re-enable this once the MLP Speculator is supported
-        # in V1.
-        # ("MLPSpeculatorPreTrainedModel", False, False),
-        ("DeepseekV2ForCausalLM", True, False),
-        ("Qwen2VLForConditionalGeneration", True, True),
-    ])
+@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
+    ("MLPSpeculatorPreTrainedModel", False, False),
+    ("DeepseekV2ForCausalLM", True, False),
+    ("Qwen2VLForConditionalGeneration", True, True),
+])
 def test_registry_is_pp(model_arch, is_pp, init_cuda):
     assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp
 
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index c6a97388dc18..83af82f8022f 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -70,7 +70,12 @@ class MLPSpeculator(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        if hasattr(vllm_config, 'speculative_config'):
+            config = vllm_config.speculative_config.draft_model_config.hf_config
+            self.sampling_metadata_is_required = False
+        else:
+            config = vllm_config.model_config.hf_config
+            self.sampling_metadata_is_required = True
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
@@ -182,8 +187,12 @@ def generate_proposals(
             # TODO: not yet supporting top_k_tokens_per_head
             states = states.flatten(0, 1)
 
-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
+            if self.logits_processor:
+                logits = self.logits_processor(
+                    self.head[head_index], states, sampling_metadata
+                    if self.sampling_metadata_is_required else None)
+            else:
+                logits = self.head[head_index](states)
 
             output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index a85e8b0e7b1b..cdc0dc3a4ad8 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -247,9 +247,7 @@
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "MedusaModel": ("medusa", "Medusa"),
-    # Temporarily disabled.
-    # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
-    # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
+    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
 
 _TRANSFORMERS_MODELS = {
diff --git a/vllm/v1/spec_decode/mlp_speculator.py b/vllm/v1/spec_decode/mlp_speculator.py
new file mode 100644
index 000000000000..76caed9309e9
--- /dev/null
+++ b/vllm/v1/spec_decode/mlp_speculator.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.v1.sample.metadata import SamplingMetadata
+
+# Initialize logger
+logger = init_logger(__name__)
+
+
+class MLPSpeculatorProposer:
+    """
+    MLPSpeculator proposer class for generating token sequences
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        # Save config parameters
+        self.vllm_config = vllm_config
+        self.device = device
+        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        self.hidden_size = (vllm_config.speculative_config.draft_model_config.
+                            get_hidden_size())
+        self.num_speculative_tokens = (
+            vllm_config.speculative_config.num_speculative_tokens)
+        self.dtype = vllm_config.model_config.dtype
+
+    def propose(
+        self,
+        input_ids: torch.Tensor,
+        previous_hidden_states: torch.Tensor,
+        num_predict_tokens: int,
+        sampling_metadata: SamplingMetadata,
+    ) -> list[list[int]]:
+        # Generate blocks and compute logits
+        draft_tokens = self.model.generate_proposals(input_ids,
+                                                     previous_hidden_states,
+                                                     num_predict_tokens,
+                                                     sampling_metadata)
+        return list(
+            map(lambda x: x[0],
+                zip(*[i.sampled_token_ids.tolist() for i in draft_tokens])))
+
+    def load_model(self, target_model: nn.Module) -> None:
+        self.model = get_model(vllm_config=self.vllm_config,
+                               model_config=self.vllm_config.
+                               speculative_config.draft_model_config)
+
+    @torch.inference_mode()
+    def dummy_run(self, num_tokens: int) -> None:
+        input_ids = torch.zeros((self.max_num_seqs, 1), device=self.device)
+        hidden_states = torch.zeros((self.max_num_seqs, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+        num_predict_tokens = self.num_speculative_tokens
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            self.model.generate_proposals(input_ids, hidden_states,
+                                          num_predict_tokens, None)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 670e653929ce..2fb13841b395 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -64,6 +64,7 @@
 from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.spec_decode.mlp_speculator import MLPSpeculatorProposer
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -188,6 +189,10 @@ def __init__(
                 self.drafter = MedusaProposer(
                     vllm_config=self.vllm_config,
                     device=self.device)  # type: ignore
+            elif self.speculative_config.method == "mlp_speculator":
+                self.drafter = MLPSpeculatorProposer(
+                    self.vllm_config,  # type: ignore
+                    self.device)
             else:
                 raise ValueError("Unknown speculative decoding method: "
                                  f"{self.speculative_config.method}")
@@ -1638,6 +1643,34 @@ def propose_draft_token_ids(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
+        elif self.speculative_config.method == "mlp_speculator":
+            assert isinstance(self.drafter, MLPSpeculatorProposer)
+            is_sample_match = sample_hidden_states.shape[0] == len(
+                sampled_token_ids)
+            # Get last token from each sequence
+            draft_input_ids = torch.tensor(
+                [tokens[-1] for tokens in sampled_token_ids],
+                device=sample_hidden_states.device)
+            if not is_sample_match:
+                # Calculate indices for hidden states
+                indices = []
+                offset = 0
+                for num_draft, tokens in zip(
+                        spec_decode_metadata.num_draft_tokens,
+                        sampled_token_ids):
+                    indices.append(offset + len(tokens) - 1)
+                    offset += num_draft + 1
+                indices = torch.tensor(indices, device=self.device)
+                hidden_states = sample_hidden_states[indices]
+            else:
+                hidden_states = sample_hidden_states
+            spec_token_ids = self.drafter.propose(
+                input_ids=draft_input_ids,
+                previous_hidden_states=hidden_states,
+                num_predict_tokens=self.vllm_config.speculative_config.
+                num_speculative_tokens,
+                sampling_metadata=sampling_metadata,
+            )
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.