diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index 4695f6f19662..9ea41c774c46 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -20,12 +20,11 @@ from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR from vllm.benchmarks.datasets import InstructCoderDataset -from vllm.config import VllmConfig +from vllm.config import VllmConfig, replace from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.v1.metrics.reader import Metric -from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model MTP_SIMILARITY_RATE = 0.8 @@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism(): "draft_tensor_parallel_size": 1, # <<< valid arg name }, ) - tgt_vllm_config: VllmConfig = engine_args.create_engine_config() - assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2 - assert tgt_vllm_config.quant_config.get_name() == "fp8" + target_config: VllmConfig = engine_args.create_engine_config() + assert target_config.parallel_config.tensor_parallel_size == 2 + assert target_config.quant_config.get_name() == "fp8" + + speculative_config = target_config.speculative_config + draft_config: VllmConfig = replace( + target_config, + quant_config=None, + parallel_config=replace( + speculative_config.draft_parallel_config, + rank=target_config.parallel_config.rank, + ), + model_config=speculative_config.draft_model_config, + ) + assert draft_config.parallel_config.tensor_parallel_size == 1 + assert draft_config.quant_config is None + + +def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig: + """Replicate SpecDecodeBaseProposer._create_draft_vllm_config logic + so we can test it without instantiating a full proposer.""" + spec_cfg = vllm_config.speculative_config + if spec_cfg.moe_backend is not None: + return replace( + vllm_config, + kernel_config=replace( + vllm_config.kernel_config, + moe_backend=spec_cfg.moe_backend, + ), + ) + return vllm_config + + +def test_draft_model_moe_backend_override(): + """When moe_backend is set in speculative_config, the draft VllmConfig + should use it while the target keeps its own setting.""" + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + moe_backend="flashinfer_trtllm", + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + "moe_backend": "triton", + }, + ) + tgt_config: VllmConfig = engine_args.create_engine_config() + assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm" + assert tgt_config.speculative_config.moe_backend == "triton" + + draft_config = _apply_draft_moe_backend(tgt_config) + assert draft_config.kernel_config.moe_backend == "triton" + # Target config must be unaffected. + assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm" + + +def test_draft_model_moe_backend_inherits_target(): + """When moe_backend is not set in speculative_config, the draft should + inherit the target's moe_backend.""" + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + moe_backend="flashinfer_cutlass", + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + }, + ) + tgt_config: VllmConfig = engine_args.create_engine_config() + assert tgt_config.kernel_config.moe_backend == "flashinfer_cutlass" + assert tgt_config.speculative_config.moe_backend is None + + draft_config = _apply_draft_moe_backend(tgt_config) + assert draft_config.kernel_config.moe_backend == "flashinfer_cutlass" + assert draft_config is tgt_config + + +def test_draft_model_moe_backend_default_auto(): + """When neither target nor draft set moe_backend explicitly, both should + default to 'auto'.""" + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + }, + ) + tgt_config: VllmConfig = engine_args.create_engine_config() + assert tgt_config.kernel_config.moe_backend == "auto" + assert tgt_config.speculative_config.moe_backend is None - draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config) - assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 - assert draft_vllm_config.quant_config is None + draft_config = _apply_draft_moe_backend(tgt_config) + assert draft_config.kernel_config.moe_backend == "auto" + assert draft_config is tgt_config def test_draft_model_engine_args_rejects_invalid_tp_argname(): diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index e9dc4cac5c11..8c81b36a8f91 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -9,6 +9,7 @@ from typing_extensions import Self from vllm.config import LoadConfig +from vllm.config.kernel import MoEBackend from vllm.config.model import ModelConfig from vllm.config.parallel import ParallelConfig from vllm.config.utils import config @@ -93,6 +94,11 @@ class SpeculativeConfig: """Quantization method that was used to quantize the draft model weights. If `None`, we assume the model weights are not quantized. Note that it only takes effect when using the draft model-based speculative method.""" + moe_backend: MoEBackend | None = None + """MoE backend to use for the draft model. When `None`, the draft model + inherits the target model's `--moe-backend` setting. Useful when the + drafter and generator require different MoE kernels (e.g. quantized + generator with unquantized drafter).""" max_model_len: int | None = Field(default=None, ge=1) """The maximum model length of the draft model. Used when testing the ability to skip speculation for some sequences.""" diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 4361d6f0bc75..9633e2ef6ca2 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -6,10 +6,10 @@ from typing_extensions import override from vllm.config import VllmConfig +from vllm.config.utils import replace from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer -from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model logger = init_logger(__name__) @@ -50,16 +50,29 @@ def _raise_if_draft_tp_mismatch(self): "Please pass 'draft_tensor_parallel_size' in the speculative_config." ) + @override + def _create_draft_vllm_config(self) -> VllmConfig: + base = super()._create_draft_vllm_config() + spec = self.speculative_config + + return replace( + base, + quant_config=None, + parallel_config=replace( + spec.draft_parallel_config, + rank=self.vllm_config.parallel_config.rank, + ), + model_config=spec.draft_model_config, + ) + @override def _get_model(self) -> nn.Module: - # Draft models may be quantized or on different parallelism, - # so we load them with a modified vllm config from vllm.compilation.backends import set_model_tag - temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config) + draft_vllm_config = self._create_draft_vllm_config() with set_model_tag("draft_model"): model = get_model( - vllm_config=temp_vllm_config, + vllm_config=draft_vllm_config, prefix="draft_model", ) return model diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4b20413ca702..62333526c26d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast -from dataclasses import replace from importlib.util import find_spec from typing import cast @@ -13,6 +12,7 @@ CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, + replace, ) from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context @@ -1213,6 +1213,21 @@ def get_model_name(self, model: nn.Module) -> str: model = model.module return model.__class__.__name__ + def _create_draft_vllm_config(self) -> VllmConfig: + """Return a VllmConfig with kernel-level overrides for the proposer. + Subclasses may override to apply additional config changes. + """ + spec_cfg = self.speculative_config + if spec_cfg.moe_backend is not None: + return replace( + self.vllm_config, + kernel_config=replace( + self.vllm_config.kernel_config, + moe_backend=spec_cfg.moe_backend, + ), + ) + return self.vllm_config + def _get_model(self) -> nn.Module: """ Default method to call get_model(). Can be overridden by subclasses which @@ -1220,9 +1235,10 @@ def _get_model(self) -> nn.Module: """ from vllm.compilation.backends import set_model_tag + draft_vllm_config = self._create_draft_vllm_config() with set_model_tag("eagle_head"): model = get_model( - vllm_config=self.vllm_config, + vllm_config=draft_vllm_config, model_config=self.speculative_config.draft_model_config, load_config=self.speculative_config.draft_load_config, ) diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 48840967b4b8..b85459c86f24 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -from vllm.config import VllmConfig, replace from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.utils import ( @@ -258,30 +257,6 @@ def compute_new_slot_mapping( return new_slot_mapping -def create_vllm_config_for_draft_model( - target_model_vllm_config: VllmConfig, -) -> VllmConfig: - """The vllm_config is configured for the target model, e.g. - its quant_config and parallel_config. But the draft model is potentially - quantized differently, and has potentially different tensor_parallel_size. - This function creates a new vllm_config configured for the drafter. - The vllm_config is useful when loading the draft model with get_model(). - """ - old = target_model_vllm_config - assert old.speculative_config is not None, "speculative_config is not set" - old_spec_config = old.speculative_config - new_parallel_config = replace( - old_spec_config.draft_parallel_config, rank=old.parallel_config.rank - ) - new: VllmConfig = replace( - old, - quant_config=None, - parallel_config=new_parallel_config, - model_config=old_spec_config.draft_model_config, - ) - return new - - def extend_all_queries_by_N( common_attn_metadata: CommonAttentionMetadata, N: int,