Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 98 additions & 8 deletions tests/v1/e2e/spec_decode/test_spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.benchmarks.datasets import InstructCoderDataset
from vllm.config import VllmConfig
from vllm.config import VllmConfig, replace
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Metric
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model

MTP_SIMILARITY_RATE = 0.8

Expand Down Expand Up @@ -919,13 +918,104 @@ def test_draft_model_engine_args_tensor_parallelism():
"draft_tensor_parallel_size": 1, # <<< valid arg name
},
)
tgt_vllm_config: VllmConfig = engine_args.create_engine_config()
assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2
assert tgt_vllm_config.quant_config.get_name() == "fp8"
target_config: VllmConfig = engine_args.create_engine_config()
assert target_config.parallel_config.tensor_parallel_size == 2
assert target_config.quant_config.get_name() == "fp8"

speculative_config = target_config.speculative_config
draft_config: VllmConfig = replace(
target_config,
quant_config=None,
parallel_config=replace(
speculative_config.draft_parallel_config,
rank=target_config.parallel_config.rank,
),
model_config=speculative_config.draft_model_config,
)
assert draft_config.parallel_config.tensor_parallel_size == 1
assert draft_config.quant_config is None


def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig:
"""Replicate SpecDecodeBaseProposer._create_draft_vllm_config logic
so we can test it without instantiating a full proposer."""
spec_cfg = vllm_config.speculative_config
if spec_cfg.moe_backend is not None:
return replace(
vllm_config,
kernel_config=replace(
vllm_config.kernel_config,
moe_backend=spec_cfg.moe_backend,
),
)
return vllm_config


def test_draft_model_moe_backend_override():
    """A moe_backend set in speculative_config must win for the draft
    VllmConfig while the target retains its own moe_backend."""
    args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        moe_backend="flashinfer_trtllm",
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
            "moe_backend": "triton",
        },
    )
    target_config: VllmConfig = args.create_engine_config()
    assert target_config.kernel_config.moe_backend == "flashinfer_trtllm"
    assert target_config.speculative_config.moe_backend == "triton"

    draft_config = _apply_draft_moe_backend(target_config)
    assert draft_config.kernel_config.moe_backend == "triton"
    # The override must produce a copy; the target config stays untouched.
    assert target_config.kernel_config.moe_backend == "flashinfer_trtllm"


def test_draft_model_moe_backend_inherits_target():
    """With no moe_backend in speculative_config, the draft must simply
    inherit the target's moe_backend setting."""
    args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        moe_backend="flashinfer_cutlass",
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
        },
    )
    target_config: VllmConfig = args.create_engine_config()
    assert target_config.kernel_config.moe_backend == "flashinfer_cutlass"
    assert target_config.speculative_config.moe_backend is None

    draft_config = _apply_draft_moe_backend(target_config)
    assert draft_config.kernel_config.moe_backend == "flashinfer_cutlass"
    # No override means no copy is made: the very same config is reused.
    assert draft_config is target_config


def test_draft_model_moe_backend_default_auto():
    """When neither target nor draft set moe_backend explicitly, both should
    default to 'auto'."""
    engine_args = EngineArgs(
        model="Qwen/Qwen3-1.7B",
        tensor_parallel_size=1,
        speculative_config={
            "model": "Qwen/Qwen3-0.6B",
            "method": "draft_model",
            "num_speculative_tokens": 3,
        },
    )
    tgt_config: VllmConfig = engine_args.create_engine_config()
    assert tgt_config.kernel_config.moe_backend == "auto"
    assert tgt_config.speculative_config.moe_backend is None

    # No override and no explicit target setting: the helper must return the
    # identical config object still carrying the "auto" default.
    draft_config = _apply_draft_moe_backend(tgt_config)
    assert draft_config.kernel_config.moe_backend == "auto"
    assert draft_config is tgt_config


def test_draft_model_engine_args_rejects_invalid_tp_argname():
Expand Down
6 changes: 6 additions & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import Self

from vllm.config import LoadConfig
from vllm.config.kernel import MoEBackend
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
Expand Down Expand Up @@ -93,6 +94,11 @@ class SpeculativeConfig:
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
moe_backend: MoEBackend | None = None
"""MoE backend to use for the draft model. When `None`, the draft model
inherits the target model's `--moe-backend` setting. Useful when the
drafter and generator require different MoE kernels (e.g. quantized
generator with unquantized drafter)."""
max_model_len: int | None = Field(default=None, ge=1)
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
Expand Down
23 changes: 18 additions & 5 deletions vllm/v1/spec_decode/draft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing_extensions import override

from vllm.config import VllmConfig
from vllm.config.utils import replace
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model

logger = init_logger(__name__)

Expand Down Expand Up @@ -50,16 +50,29 @@ def _raise_if_draft_tp_mismatch(self):
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
)

@override
def _create_draft_vllm_config(self) -> VllmConfig:
    """Extend the base draft config with draft-model specifics: drop the
    target's quant_config, switch to the draft parallel config (keeping the
    target's rank), and swap in the draft model config."""
    draft_config = super()._create_draft_vllm_config()
    spec_config = self.speculative_config

    # The draft may run at a different tensor parallelism than the target,
    # but the worker rank is inherited from the target's parallel config.
    draft_parallel = replace(
        spec_config.draft_parallel_config,
        rank=self.vllm_config.parallel_config.rank,
    )
    return replace(
        draft_config,
        quant_config=None,
        parallel_config=draft_parallel,
        model_config=spec_config.draft_model_config,
    )

@override
def _get_model(self) -> nn.Module:
    """Load the draft model under a draft-specific vllm config, since the
    drafter may use different quantization or parallelism than the target."""
    from vllm.compilation.backends import set_model_tag

    config_for_draft = self._create_draft_vllm_config()
    with set_model_tag("draft_model"):
        return get_model(
            vllm_config=config_for_draft,
            prefix="draft_model",
        )
Expand Down
20 changes: 18 additions & 2 deletions vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import cast

Expand All @@ -13,6 +12,7 @@
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
replace,
)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
Expand Down Expand Up @@ -1213,16 +1213,32 @@ def get_model_name(self, model: nn.Module) -> str:
model = model.module
return model.__class__.__name__

def _create_draft_vllm_config(self) -> VllmConfig:
"""Return a VllmConfig with kernel-level overrides for the proposer.
Subclasses may override to apply additional config changes.
"""
spec_cfg = self.speculative_config
if spec_cfg.moe_backend is not None:
return replace(
self.vllm_config,
kernel_config=replace(
self.vllm_config.kernel_config,
moe_backend=spec_cfg.moe_backend,
),
)
return self.vllm_config

def _get_model(self) -> nn.Module:
"""
Default method to call get_model(). Can be overridden by subclasses which
need to customize model loading.
"""
from vllm.compilation.backends import set_model_tag

draft_vllm_config = self._create_draft_vllm_config()
with set_model_tag("eagle_head"):
model = get_model(
vllm_config=self.vllm_config,
vllm_config=draft_vllm_config,
model_config=self.speculative_config.draft_model_config,
load_config=self.speculative_config.draft_load_config,
)
Expand Down
25 changes: 0 additions & 25 deletions vllm/v1/spec_decode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.config import VllmConfig, replace
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
Expand Down Expand Up @@ -258,30 +257,6 @@ def compute_new_slot_mapping(
return new_slot_mapping


def create_vllm_config_for_draft_model(
    target_model_vllm_config: VllmConfig,
) -> VllmConfig:
    """The vllm_config is configured for the target model, e.g.
    its quant_config and parallel_config. But the draft model is potentially
    quantized differently, and has potentially different tensor_parallel_size.
    This function creates a new vllm_config configured for the drafter.
    The vllm_config is useful when loading the draft model with get_model().
    """
    target = target_model_vllm_config
    spec = target.speculative_config
    assert spec is not None, "speculative_config is not set"

    # Keep the target's rank but adopt the draft's parallel layout.
    draft_parallel = replace(
        spec.draft_parallel_config, rank=target.parallel_config.rank
    )
    return replace(
        target,
        quant_config=None,
        parallel_config=draft_parallel,
        model_config=spec.draft_model_config,
    )


def extend_all_queries_by_N(
common_attn_metadata: CommonAttentionMetadata,
N: int,
Expand Down
Loading