From bff192d4c94b59e3ceaac176da064ac1ca0863ed Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Mon, 23 Mar 2026 11:29:33 +0100 Subject: [PATCH 01/10] [Feature] Add MoE backend configuration for draft models in SpeculativeConfig - Introduced `moe_backend` attribute to `SpeculativeConfig` to specify the MoE backend for draft models. - Updated `create_vllm_config_for_draft_model` to handle the new `moe_backend` setting, ensuring compatibility between drafter and generator configurations. Signed-off-by: Andrii Skliar --- vllm/config/speculative.py | 6 ++++++ vllm/v1/spec_decode/utils.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8ff6d9753566..66f0075426c3 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -9,6 +9,7 @@ from typing_extensions import Self from vllm.config import LoadConfig +from vllm.config.kernel import MoEBackend from vllm.config.model import ModelConfig from vllm.config.parallel import ParallelConfig from vllm.config.utils import config @@ -93,6 +94,11 @@ class SpeculativeConfig: """Quantization method that was used to quantize the draft model weights. If `None`, we assume the model weights are not quantized. Note that it only takes effect when using the draft model-based speculative method.""" + moe_backend: MoEBackend | None = None + """MoE backend to use for the draft model. When `None`, the draft model + inherits the target model's `--moe-backend` setting. Useful when the + drafter and generator require different MoE kernels (e.g. quantized + generator with unquantized drafter).""" max_model_len: int | None = Field(default=None, ge=1) """The maximum model length of the draft model. Used when testing the ability to skip speculation for some sequences.""" diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index cfc30c3e67f2..b19f2e72186c 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -272,11 +272,19 @@ def create_vllm_config_for_draft_model( new_parallel_config = replace( old_spec_config.draft_parallel_config, rank=old.parallel_config.rank ) + + draft_moe_backend = old_spec_config.moe_backend + if draft_moe_backend is not None: + new_kernel_config = replace(old.kernel_config, moe_backend=draft_moe_backend) + else: + new_kernel_config = old.kernel_config + new: VllmConfig = replace( old, quant_config=None, parallel_config=new_parallel_config, model_config=old_spec_config.draft_model_config, + kernel_config=new_kernel_config, ) return new From 41afbe194da0d8f26e028e621a9d2b6144362a5a Mon Sep 17 00:00:00 2001 From: Andrii Date: Mon, 23 Mar 2026 09:17:47 -0700 Subject: [PATCH 02/10] [Feature] Implement apply_draft_moe_backend utility and enhance draft model tests - Added `apply_draft_moe_backend` function to override `moe_backend` in `VllmConfig` based on `speculative_config`. - Updated `eagle.py` and `medusa.py` to utilize the new utility for model configuration. - Introduced parameterized tests in `test_spec_decode.py` to validate the behavior of draft model configurations with various `moe_backend` scenarios. Signed-off-by: [Andrii Skliar] Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 78 +++++++++++++++++++- vllm/v1/spec_decode/eagle.py | 4 +- vllm/v1/spec_decode/medusa.py | 4 +- vllm/v1/spec_decode/utils.py | 22 ++++-- 4 files changed, 98 insertions(+), 10 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index 4695f6f19662..435ea39419a4 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -25,7 +25,10 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.v1.metrics.reader import Metric -from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model +from vllm.v1.spec_decode.utils import ( + apply_draft_moe_backend, + create_vllm_config_for_draft_model, +) MTP_SIMILARITY_RATE = 0.8 @@ -928,6 +931,79 @@ def test_draft_model_engine_args_tensor_parallelism(): assert draft_vllm_config.quant_config is None +@pytest.mark.parametrize( + [ + "target_moe_backend", + "draft_moe_backend", + "expected_target", + "expected_draft", + "expected_applied", + "has_spec_config", + ], + [ + # Draft overrides target + ("flashinfer_trtllm", "triton", + "flashinfer_trtllm", "triton", "triton", True), + # Draft inherits target when unset + ("flashinfer_cutlass", None, + "flashinfer_cutlass", "flashinfer_cutlass", "flashinfer_cutlass", + True), + # Both default to auto + ("auto", None, "auto", "auto", "auto", True), + # No speculative config at all + ("auto", None, "auto", None, "auto", False), + ], + ids=[ + "draft_overrides", + "draft_inherits_target", + "both_default_auto", + "no_spec_config", + ], +) +def test_draft_moe_backend( + target_moe_backend: str, + draft_moe_backend: str | None, + expected_target: str, + expected_draft: str | None, + expected_applied: str, + has_spec_config: bool, +): + """Both create_vllm_config_for_draft_model and apply_draft_moe_backend + must propagate (or inherit) moe_backend correctly.""" + spec_cfg: dict[str, Any] | None = None + if has_spec_config: + spec_cfg = { + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + } + if draft_moe_backend is not None: + spec_cfg["moe_backend"] = draft_moe_backend + + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + moe_backend=target_moe_backend, + speculative_config=spec_cfg, + ) + tgt_cfg: VllmConfig = engine_args.create_engine_config() + assert tgt_cfg.kernel_config.moe_backend == expected_target + + # apply_draft_moe_backend (used by Eagle/MTP/Medusa) + applied = apply_draft_moe_backend(tgt_cfg) + assert applied.kernel_config.moe_backend == expected_applied + # Must never touch model_config / parallel_config + assert applied.model_config is tgt_cfg.model_config + assert applied.parallel_config is tgt_cfg.parallel_config + if draft_moe_backend is None: + assert applied is tgt_cfg # no-op returns same object + + # create_vllm_config_for_draft_model (used by DraftModelProposer) + if has_spec_config: + draft_cfg = create_vllm_config_for_draft_model(tgt_cfg) + assert draft_cfg.kernel_config.moe_backend == expected_draft + + def test_draft_model_engine_args_rejects_invalid_tp_argname(): """The user should pass "draft_tensor_parallel_size" rather than "tensor_parallel_size". We enforce this with validation.""" diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 445bb403b4b3..82809d322a4e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1223,10 +1223,12 @@ def _get_model(self) -> nn.Module: need to customize model loading. """ from vllm.compilation.backends import set_model_tag + from vllm.v1.spec_decode.utils import apply_draft_moe_backend + draft_vllm_config = apply_draft_moe_backend(self.vllm_config) with set_model_tag("eagle_head"): model = get_model( - vllm_config=self.vllm_config, + vllm_config=draft_vllm_config, model_config=self.speculative_config.draft_model_config, load_config=self.speculative_config.draft_load_config, ) diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 80b0f0a9870a..f2d20b0329fd 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -56,10 +56,12 @@ def propose( def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag + from vllm.v1.spec_decode.utils import apply_draft_moe_backend + draft_vllm_config = apply_draft_moe_backend(self.vllm_config) with set_model_tag("medusa_head"): self.model = get_model( - vllm_config=self.vllm_config, + vllm_config=draft_vllm_config, model_config=self.spec_config.draft_model_config, ) assert not ( diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index b19f2e72186c..3ecac5eee085 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -257,6 +257,19 @@ def compute_new_slot_mapping( return new_slot_mapping +def apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig: + """Return a VllmConfig with kernel_config.moe_backend overridden by + speculative_config.moe_backend when the latter is set. + """ + spec_config = vllm_config.speculative_config + if spec_config is not None and spec_config.moe_backend is not None: + new_kernel_config = replace( + vllm_config.kernel_config, moe_backend=spec_config.moe_backend + ) + return replace(vllm_config, kernel_config=new_kernel_config) + return vllm_config + + def create_vllm_config_for_draft_model( target_model_vllm_config: VllmConfig, ) -> VllmConfig: @@ -273,18 +286,13 @@ def create_vllm_config_for_draft_model( old_spec_config.draft_parallel_config, rank=old.parallel_config.rank ) - draft_moe_backend = old_spec_config.moe_backend - if draft_moe_backend is not None: - new_kernel_config = replace(old.kernel_config, moe_backend=draft_moe_backend) - else: - new_kernel_config = old.kernel_config + applied = apply_draft_moe_backend(old) new: VllmConfig = replace( - old, + applied, quant_config=None, parallel_config=new_parallel_config, model_config=old_spec_config.draft_model_config, - kernel_config=new_kernel_config, ) return new From 5eb53f927ac65218a7a8ebd59a8d5e7740e31863 Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Mon, 23 Mar 2026 17:21:13 +0100 Subject: [PATCH 03/10] fix formatting Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index 435ea39419a4..e6bc1d65503f 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -942,12 +942,16 @@ def test_draft_model_engine_args_tensor_parallelism(): ], [ # Draft overrides target - ("flashinfer_trtllm", "triton", - "flashinfer_trtllm", "triton", "triton", True), + ("flashinfer_trtllm", "triton", "flashinfer_trtllm", "triton", "triton", True), # Draft inherits target when unset - ("flashinfer_cutlass", None, - "flashinfer_cutlass", "flashinfer_cutlass", "flashinfer_cutlass", - True), + ( + "flashinfer_cutlass", + None, + "flashinfer_cutlass", + "flashinfer_cutlass", + "flashinfer_cutlass", + True, + ), # Both default to auto ("auto", None, "auto", "auto", "auto", True), # No speculative config at all From 09639e5a5ce5787577e86205e05af26323dfbb2b Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Mon, 23 Mar 2026 17:37:02 +0100 Subject: [PATCH 04/10] [Test] Enhance draft model tests with parameterization and additional assertions - Updated `test_spec_decode.py` to improve parameterization of draft model tests, allowing for more comprehensive validation of `moe_backend` behavior. - Removed redundant parameters and assertions, streamlining the test logic. - Added a new test to verify that `apply_draft_moe_backend` behaves as a no-op when no speculative configuration is provided. Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 118 ++++++++++++------- 1 file changed, 75 insertions(+), 43 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index e6bc1d65503f..720608af5687 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -937,75 +937,107 @@ def test_draft_model_engine_args_tensor_parallelism(): "draft_moe_backend", "expected_target", "expected_draft", - "expected_applied", - "has_spec_config", ], [ - # Draft overrides target - ("flashinfer_trtllm", "triton", "flashinfer_trtllm", "triton", "triton", True), - # Draft inherits target when unset - ( - "flashinfer_cutlass", - None, - "flashinfer_cutlass", - "flashinfer_cutlass", - "flashinfer_cutlass", + ("flashinfer_trtllm", "triton", "flashinfer_trtllm", "triton"), + ("flashinfer_cutlass", None, "flashinfer_cutlass", "flashinfer_cutlass"), + ("auto", None, "auto", "auto"), + ], + ids=["draft_overrides", "draft_inherits_target", "both_default_auto"], +) +@pytest.mark.parametrize( + ["target_model", "spec_method_cfg", "tensor_parallel_size", "trust_remote_code"], + [ + pytest.param( + "Qwen/Qwen3-1.7B", + { + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + }, + 1, + False, + id="draft_model", + ), + pytest.param( + "eagle618/deepseek-v3-random", + { + "model": "eagle618/eagle-deepseek-v3-random", + "method": "eagle", + "num_speculative_tokens": 3, + }, + 1, + False, + id="eagle", + ), + pytest.param( + "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4", + { + "method": "mtp", + "num_speculative_tokens": 1, + "max_model_len": 2048, + }, + 4, True, + marks=[ + large_gpu_mark(min_gb=80), + *multi_gpu_marks(num_gpus=4), + ], + id="nemotron_mtp", ), - # Both default to auto - ("auto", None, "auto", "auto", "auto", True), - # No speculative config at all - ("auto", None, "auto", None, "auto", False), - ], - ids=[ - "draft_overrides", - "draft_inherits_target", - "both_default_auto", - "no_spec_config", ], ) def test_draft_moe_backend( target_moe_backend: str, draft_moe_backend: str | None, expected_target: str, - expected_draft: str | None, - expected_applied: str, - has_spec_config: bool, + expected_draft: str, + target_model: str, + spec_method_cfg: dict[str, Any], + tensor_parallel_size: int, + trust_remote_code: bool, ): """Both create_vllm_config_for_draft_model and apply_draft_moe_backend - must propagate (or inherit) moe_backend correctly.""" - spec_cfg: dict[str, Any] | None = None - if has_spec_config: - spec_cfg = { - "model": "Qwen/Qwen3-0.6B", - "method": "draft_model", - "num_speculative_tokens": 3, - } - if draft_moe_backend is not None: - spec_cfg["moe_backend"] = draft_moe_backend + must propagate (or inherit) moe_backend correctly across drafting + methods.""" + spec_cfg = {**spec_method_cfg} + if draft_moe_backend is not None: + spec_cfg["moe_backend"] = draft_moe_backend engine_args = EngineArgs( - model="Qwen/Qwen3-1.7B", - tensor_parallel_size=1, + model=target_model, + tensor_parallel_size=tensor_parallel_size, + trust_remote_code=trust_remote_code, moe_backend=target_moe_backend, speculative_config=spec_cfg, ) tgt_cfg: VllmConfig = engine_args.create_engine_config() assert tgt_cfg.kernel_config.moe_backend == expected_target - # apply_draft_moe_backend (used by Eagle/MTP/Medusa) + # apply_draft_moe_backend (used by Eagle/MTP/Medusa proposers) applied = apply_draft_moe_backend(tgt_cfg) - assert applied.kernel_config.moe_backend == expected_applied - # Must never touch model_config / parallel_config + assert applied.kernel_config.moe_backend == expected_draft assert applied.model_config is tgt_cfg.model_config assert applied.parallel_config is tgt_cfg.parallel_config if draft_moe_backend is None: - assert applied is tgt_cfg # no-op returns same object + assert applied is tgt_cfg # create_vllm_config_for_draft_model (used by DraftModelProposer) - if has_spec_config: - draft_cfg = create_vllm_config_for_draft_model(tgt_cfg) - assert draft_cfg.kernel_config.moe_backend == expected_draft + draft_cfg = create_vllm_config_for_draft_model(tgt_cfg) + assert draft_cfg.kernel_config.moe_backend == expected_draft + + +def test_apply_draft_moe_backend_noop_without_spec_config(): + """apply_draft_moe_backend is a no-op when there is no speculative_config.""" + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + ) + vllm_config: VllmConfig = engine_args.create_engine_config() + assert vllm_config.speculative_config is None + + result = apply_draft_moe_backend(vllm_config) + assert result is vllm_config def test_draft_model_engine_args_rejects_invalid_tp_argname(): From 9be16a291f5942126f386107a75397b9d94d9497 Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Tue, 24 Mar 2026 11:47:41 +0100 Subject: [PATCH 05/10] [Refactor] Replace apply_draft_moe_backend with create_vllm_config_for_spec_decode - Updated `test_spec_decode.py` to utilize `create_vllm_config_for_spec_decode` in place of `apply_draft_moe_backend`, enhancing clarity in test assertions. - Refactored `draft_model.py`, `eagle.py`, and `medusa.py` to adopt the new configuration utility, ensuring consistent model loading behavior. - Introduced `create_vllm_config_for_spec_decode` to apply kernel-level overrides from speculative configurations. Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 15 ++++++++------- vllm/v1/spec_decode/draft_model.py | 10 ++++++---- vllm/v1/spec_decode/eagle.py | 10 ++++++++-- vllm/v1/spec_decode/medusa.py | 4 ++-- vllm/v1/spec_decode/utils.py | 19 +++++++++---------- 5 files changed, 33 insertions(+), 25 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index 720608af5687..20c44d331c5e 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -26,8 +26,8 @@ from vllm.platforms import current_platform from vllm.v1.metrics.reader import Metric from vllm.v1.spec_decode.utils import ( - apply_draft_moe_backend, create_vllm_config_for_draft_model, + create_vllm_config_for_spec_decode, ) MTP_SIMILARITY_RATE = 0.8 @@ -997,7 +997,7 @@ def test_draft_moe_backend( tensor_parallel_size: int, trust_remote_code: bool, ): - """Both create_vllm_config_for_draft_model and apply_draft_moe_backend + """Both create_vllm_config_for_spec_decode and create_vllm_config_for_draft_model must propagate (or inherit) moe_backend correctly across drafting methods.""" spec_cfg = {**spec_method_cfg} @@ -1014,8 +1014,8 @@ def test_draft_moe_backend( tgt_cfg: VllmConfig = engine_args.create_engine_config() assert tgt_cfg.kernel_config.moe_backend == expected_target - # apply_draft_moe_backend (used by Eagle/MTP/Medusa proposers) - applied = apply_draft_moe_backend(tgt_cfg) + # create_vllm_config_for_spec_decode (used by Eagle/MTP/Medusa proposers) + applied = create_vllm_config_for_spec_decode(tgt_cfg) assert applied.kernel_config.moe_backend == expected_draft assert applied.model_config is tgt_cfg.model_config assert applied.parallel_config is tgt_cfg.parallel_config @@ -1027,8 +1027,9 @@ def test_draft_moe_backend( assert draft_cfg.kernel_config.moe_backend == expected_draft -def test_apply_draft_moe_backend_noop_without_spec_config(): - """apply_draft_moe_backend is a no-op when there is no speculative_config.""" +def test_create_vllm_config_for_spec_decode_noop_without_spec_config(): + """create_vllm_config_for_spec_decode is a no-op when there is no + speculative_config.""" engine_args = EngineArgs( model="Qwen/Qwen3-1.7B", tensor_parallel_size=1, @@ -1036,7 +1037,7 @@ def test_apply_draft_moe_backend_noop_without_spec_config(): vllm_config: VllmConfig = engine_args.create_engine_config() assert vllm_config.speculative_config is None - result = apply_draft_moe_backend(vllm_config) + result = create_vllm_config_for_spec_decode(vllm_config) assert result is vllm_config diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 4361d6f0bc75..116a866a1e4f 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -50,16 +50,18 @@ def _raise_if_draft_tp_mismatch(self): "Please pass 'draft_tensor_parallel_size' in the speculative_config." ) + @override + def _create_draft_vllm_config(self) -> VllmConfig: + return create_vllm_config_for_draft_model(self.vllm_config) + @override def _get_model(self) -> nn.Module: - # Draft models may be quantized or on different parallelism, - # so we load them with a modified vllm config from vllm.compilation.backends import set_model_tag - temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config) + draft_vllm_config = self._create_draft_vllm_config() with set_model_tag("draft_model"): model = get_model( - vllm_config=temp_vllm_config, + vllm_config=draft_vllm_config, prefix="draft_model", ) return model diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 82809d322a4e..0c401b67763b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -43,6 +43,7 @@ PADDING_SLOT_ID, compute_new_slot_mapping, copy_and_expand_eagle_inputs_kernel, + create_vllm_config_for_spec_decode, eagle_prepare_inputs_padded_kernel, eagle_prepare_next_token_padded_kernel, eagle_step_update_slot_mapping_and_metadata, @@ -1217,15 +1218,20 @@ def get_model_name(self, model: nn.Module) -> str: model = model.module return model.__class__.__name__ + def _create_draft_vllm_config(self) -> VllmConfig: + """Return a VllmConfig with kernel-level overrides for the proposer. + Subclasses may override to apply additional config changes. + """ + return create_vllm_config_for_spec_decode(self.vllm_config) + def _get_model(self) -> nn.Module: """ Default method to call get_model(). Can be overridden by subclasses which need to customize model loading. """ from vllm.compilation.backends import set_model_tag - from vllm.v1.spec_decode.utils import apply_draft_moe_backend - draft_vllm_config = apply_draft_moe_backend(self.vllm_config) + draft_vllm_config = self._create_draft_vllm_config() with set_model_tag("eagle_head"): model = get_model( vllm_config=draft_vllm_config, diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index f2d20b0329fd..3efa9ece91f6 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -56,9 +56,9 @@ def propose( def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag - from vllm.v1.spec_decode.utils import apply_draft_moe_backend + from vllm.v1.spec_decode.utils import create_vllm_config_for_spec_decode - draft_vllm_config = apply_draft_moe_backend(self.vllm_config) + draft_vllm_config = create_vllm_config_for_spec_decode(self.vllm_config) with set_model_tag("medusa_head"): self.model = get_model( vllm_config=draft_vllm_config, diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 3ecac5eee085..211cd313f6bb 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -257,9 +257,11 @@ def compute_new_slot_mapping( return new_slot_mapping -def apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig: - """Return a VllmConfig with kernel_config.moe_backend overridden by - speculative_config.moe_backend when the latter is set. +def create_vllm_config_for_spec_decode( + vllm_config: VllmConfig, +) -> VllmConfig: + """Apply kernel-level overrides (e.g. moe_backend) from speculative_config. + Returns the original object unchanged when no override is needed. """ spec_config = vllm_config.speculative_config if spec_config is not None and spec_config.moe_backend is not None: @@ -273,11 +275,8 @@ def apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig: def create_vllm_config_for_draft_model( target_model_vllm_config: VllmConfig, ) -> VllmConfig: - """The vllm_config is configured for the target model, e.g. - its quant_config and parallel_config. But the draft model is potentially - quantized differently, and has potentially different tensor_parallel_size. - This function creates a new vllm_config configured for the drafter. - The vllm_config is useful when loading the draft model with get_model(). + """Extends create_vllm_config_for_spec_decode with quant_config, + parallel_config, and model_config overrides for standalone draft models. """ old = target_model_vllm_config assert old.speculative_config is not None, "speculative_config is not set" @@ -286,10 +285,10 @@ def create_vllm_config_for_draft_model( old_spec_config.draft_parallel_config, rank=old.parallel_config.rank ) - applied = apply_draft_moe_backend(old) + base = create_vllm_config_for_spec_decode(old) new: VllmConfig = replace( - applied, + base, quant_config=None, parallel_config=new_parallel_config, model_config=old_spec_config.draft_model_config, From e856556388e3e69ffe22751bd04c23f7f5022182 Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Tue, 24 Mar 2026 23:11:13 +0100 Subject: [PATCH 06/10] [Refactor] Simplify draft model configuration handling - Removed the `create_vllm_config_for_draft_model` and `create_vllm_config_for_spec_decode` utility functions, integrating their logic directly into the relevant classes. - Updated `DraftModelProposer`, `SpecDecodeBaseProposer`, and `MedusaProposer` to utilize the `replace` function for configuration overrides, enhancing clarity and maintainability. - Refactored tests in `test_spec_decode.py` to align with the new configuration approach, ensuring accurate validation of `moe_backend` propagation. Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 59 +++++++++++--------- vllm/v1/spec_decode/draft_model.py | 15 ++++- vllm/v1/spec_decode/eagle.py | 12 +++- vllm/v1/spec_decode/medusa.py | 13 ++++- vllm/v1/spec_decode/utils.py | 40 ------------- 5 files changed, 65 insertions(+), 74 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index 20c44d331c5e..fb445ed40f85 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random from collections.abc import Iterable -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Any import pytest @@ -25,10 +25,6 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.v1.metrics.reader import Metric -from vllm.v1.spec_decode.utils import ( - create_vllm_config_for_draft_model, - create_vllm_config_for_spec_decode, -) MTP_SIMILARITY_RATE = 0.8 @@ -926,7 +922,16 @@ def test_draft_model_engine_args_tensor_parallelism(): assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2 assert tgt_vllm_config.quant_config.get_name() == "fp8" - draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config) + sc = tgt_vllm_config.speculative_config + draft_vllm_config: VllmConfig = replace( + tgt_vllm_config, + quant_config=None, + parallel_config=replace( + sc.draft_parallel_config, + rank=tgt_vllm_config.parallel_config.rank, + ), + model_config=sc.draft_model_config, + ) assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 assert draft_vllm_config.quant_config is None @@ -997,9 +1002,8 @@ def test_draft_moe_backend( tensor_parallel_size: int, trust_remote_code: bool, ): - """Both create_vllm_config_for_spec_decode and create_vllm_config_for_draft_model - must propagate (or inherit) moe_backend correctly across drafting - methods.""" + """speculative_config.moe_backend must propagate (or inherit) correctly + across drafting methods.""" spec_cfg = {**spec_method_cfg} if draft_moe_backend is not None: spec_cfg["moe_backend"] = draft_moe_backend @@ -1014,31 +1018,32 @@ def test_draft_moe_backend( tgt_cfg: VllmConfig = engine_args.create_engine_config() assert tgt_cfg.kernel_config.moe_backend == expected_target - # create_vllm_config_for_spec_decode (used by Eagle/MTP/Medusa proposers) - applied = create_vllm_config_for_spec_decode(tgt_cfg) + # Base proposer path: override moe_backend from speculative_config + sc = tgt_cfg.speculative_config + if sc.moe_backend is not None: + applied = replace( + tgt_cfg, + kernel_config=replace(tgt_cfg.kernel_config, moe_backend=sc.moe_backend), + ) + else: + applied = tgt_cfg assert applied.kernel_config.moe_backend == expected_draft assert applied.model_config is tgt_cfg.model_config assert applied.parallel_config is tgt_cfg.parallel_config if draft_moe_backend is None: assert applied is tgt_cfg - # create_vllm_config_for_draft_model (used by DraftModelProposer) - draft_cfg = create_vllm_config_for_draft_model(tgt_cfg) - assert draft_cfg.kernel_config.moe_backend == expected_draft - - -def test_create_vllm_config_for_spec_decode_noop_without_spec_config(): - """create_vllm_config_for_spec_decode is a no-op when there is no - speculative_config.""" - engine_args = EngineArgs( - model="Qwen/Qwen3-1.7B", - tensor_parallel_size=1, + # DraftModelProposer path: extends base with quant/parallel/model overrides + draft_cfg = replace( + applied, + quant_config=None, + parallel_config=replace( + sc.draft_parallel_config, + rank=tgt_cfg.parallel_config.rank, + ), + model_config=sc.draft_model_config, ) - vllm_config: VllmConfig = engine_args.create_engine_config() - assert vllm_config.speculative_config is None - - result = create_vllm_config_for_spec_decode(vllm_config) - assert result is vllm_config + assert draft_cfg.kernel_config.moe_backend == expected_draft def test_draft_model_engine_args_rejects_invalid_tp_argname(): diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 116a866a1e4f..81fb2533dd23 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import replace + import torch import torch.nn as nn from typing_extensions import override @@ -9,7 +11,6 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer -from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model logger = init_logger(__name__) @@ -52,7 +53,17 @@ def _raise_if_draft_tp_mismatch(self): @override def _create_draft_vllm_config(self) -> VllmConfig: - return create_vllm_config_for_draft_model(self.vllm_config) + base = super()._create_draft_vllm_config() + spec = self.speculative_config + return replace( + base, + quant_config=None, + parallel_config=replace( + spec.draft_parallel_config, + rank=self.vllm_config.parallel_config.rank, + ), + model_config=spec.draft_model_config, + ) @override def _get_model(self) -> nn.Module: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 0c401b67763b..e73d13c2c6af 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -43,7 +43,6 @@ PADDING_SLOT_ID, compute_new_slot_mapping, copy_and_expand_eagle_inputs_kernel, - create_vllm_config_for_spec_decode, eagle_prepare_inputs_padded_kernel, eagle_prepare_next_token_padded_kernel, eagle_step_update_slot_mapping_and_metadata, @@ -1222,7 +1221,16 @@ def _create_draft_vllm_config(self) -> VllmConfig: """Return a VllmConfig with kernel-level overrides for the proposer. Subclasses may override to apply additional config changes. """ - return create_vllm_config_for_spec_decode(self.vllm_config) + spec_cfg = self.speculative_config + if spec_cfg.moe_backend is not None: + return replace( + self.vllm_config, + kernel_config=replace( + self.vllm_config.kernel_config, + moe_backend=spec_cfg.moe_backend, + ), + ) + return self.vllm_config def _get_model(self) -> nn.Module: """ diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 3efa9ece91f6..ce6534e0ff14 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.config import VllmConfig +from vllm.config import VllmConfig, replace from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -56,9 +56,16 @@ def propose( def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag - from vllm.v1.spec_decode.utils import create_vllm_config_for_spec_decode - draft_vllm_config = create_vllm_config_for_spec_decode(self.vllm_config) + draft_vllm_config = self.vllm_config + if self.spec_config.moe_backend is not None: + draft_vllm_config = replace( + draft_vllm_config, + kernel_config=replace( + draft_vllm_config.kernel_config, + moe_backend=self.spec_config.moe_backend, + ), + ) with set_model_tag("medusa_head"): self.model = get_model( vllm_config=draft_vllm_config, diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 211cd313f6bb..e99f3492a835 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -from vllm.config import VllmConfig, replace from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, @@ -257,45 +256,6 @@ def compute_new_slot_mapping( return new_slot_mapping -def create_vllm_config_for_spec_decode( - vllm_config: VllmConfig, -) -> VllmConfig: - """Apply kernel-level overrides (e.g. moe_backend) from speculative_config. - Returns the original object unchanged when no override is needed. - """ - spec_config = vllm_config.speculative_config - if spec_config is not None and spec_config.moe_backend is not None: - new_kernel_config = replace( - vllm_config.kernel_config, moe_backend=spec_config.moe_backend - ) - return replace(vllm_config, kernel_config=new_kernel_config) - return vllm_config - - -def create_vllm_config_for_draft_model( - target_model_vllm_config: VllmConfig, -) -> VllmConfig: - """Extends create_vllm_config_for_spec_decode with quant_config, - parallel_config, and model_config overrides for standalone draft models. - """ - old = target_model_vllm_config - assert old.speculative_config is not None, "speculative_config is not set" - old_spec_config = old.speculative_config - new_parallel_config = replace( - old_spec_config.draft_parallel_config, rank=old.parallel_config.rank - ) - - base = create_vllm_config_for_spec_decode(old) - - new: VllmConfig = replace( - base, - quant_config=None, - parallel_config=new_parallel_config, - model_config=old_spec_config.draft_model_config, - ) - return new - - def extend_all_queries_by_N( common_attn_metadata: CommonAttentionMetadata, N: int, From d78840be074ea9fc764f11ab4e0a0521e242deb0 Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Tue, 24 Mar 2026 23:22:36 +0100 Subject: [PATCH 07/10] Refactor test_spec_decode.py and medusa.py for improved configuration handling - Updated variable names for clarity and consistency in test_spec_decode.py, enhancing readability. - Simplified the model loading process in MedusaProposer by directly using the vllm_config without unnecessary replacements. - Ensured that the speculative configuration is correctly propagated in the tests, maintaining the integrity of the model's behavior. Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 61 ++++++++++---------- vllm/v1/spec_decode/medusa.py | 13 +---- 2 files changed, 34 insertions(+), 40 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index fb445ed40f85..e0caf4be7e4d 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -918,22 +918,22 @@ def test_draft_model_engine_args_tensor_parallelism(): "draft_tensor_parallel_size": 1, # <<< valid arg name }, ) - tgt_vllm_config: VllmConfig = engine_args.create_engine_config() - assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2 - assert tgt_vllm_config.quant_config.get_name() == "fp8" + target_config: VllmConfig = engine_args.create_engine_config() + assert target_config.parallel_config.tensor_parallel_size == 2 + assert target_config.quant_config.get_name() == "fp8" - sc = tgt_vllm_config.speculative_config - draft_vllm_config: VllmConfig = replace( - tgt_vllm_config, + speculative_config = target_config.speculative_config + draft_config: VllmConfig = replace( + target_config, quant_config=None, parallel_config=replace( - sc.draft_parallel_config, - rank=tgt_vllm_config.parallel_config.rank, + speculative_config.draft_parallel_config, + rank=target_config.parallel_config.rank, ), - model_config=sc.draft_model_config, + model_config=speculative_config.draft_model_config, ) - assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 - assert draft_vllm_config.quant_config is None + assert draft_config.parallel_config.tensor_parallel_size == 1 + assert draft_config.quant_config is None @pytest.mark.parametrize( @@ -1004,46 +1004,49 @@ def test_draft_moe_backend( ): """speculative_config.moe_backend must propagate (or inherit) correctly across drafting methods.""" - spec_cfg = {**spec_method_cfg} + spec_method_args = {**spec_method_cfg} if draft_moe_backend is not None: - spec_cfg["moe_backend"] = draft_moe_backend + spec_method_args["moe_backend"] = draft_moe_backend engine_args = EngineArgs( model=target_model, tensor_parallel_size=tensor_parallel_size, trust_remote_code=trust_remote_code, moe_backend=target_moe_backend, - speculative_config=spec_cfg, + speculative_config=spec_method_args, ) - tgt_cfg: VllmConfig = engine_args.create_engine_config() - assert tgt_cfg.kernel_config.moe_backend == expected_target + target_config: VllmConfig = engine_args.create_engine_config() + assert target_config.kernel_config.moe_backend == expected_target # Base proposer path: override moe_backend from speculative_config - sc = tgt_cfg.speculative_config - if sc.moe_backend is not None: + speculative_config = target_config.speculative_config + if speculative_config.moe_backend is not None: applied = replace( - tgt_cfg, - kernel_config=replace(tgt_cfg.kernel_config, moe_backend=sc.moe_backend), + target_config, + kernel_config=replace( + target_config.kernel_config, + moe_backend=speculative_config.moe_backend, + ), ) else: - applied = tgt_cfg + applied = target_config assert applied.kernel_config.moe_backend == expected_draft - assert applied.model_config is tgt_cfg.model_config - assert applied.parallel_config is tgt_cfg.parallel_config + assert applied.model_config is target_config.model_config + assert applied.parallel_config is target_config.parallel_config if draft_moe_backend is None: - assert applied is tgt_cfg + assert applied is target_config # DraftModelProposer path: extends base with quant/parallel/model overrides - draft_cfg = replace( + draft_config = replace( applied, quant_config=None, parallel_config=replace( - sc.draft_parallel_config, - rank=tgt_cfg.parallel_config.rank, + speculative_config.draft_parallel_config, + rank=target_config.parallel_config.rank, ), - model_config=sc.draft_model_config, + model_config=speculative_config.draft_model_config, ) - assert draft_cfg.kernel_config.moe_backend == expected_draft + assert draft_config.kernel_config.moe_backend == expected_draft def test_draft_model_engine_args_rejects_invalid_tp_argname(): diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index ce6534e0ff14..80b0f0a9870a 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.config import VllmConfig, replace +from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -57,18 +57,9 @@ def propose( def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag - draft_vllm_config = self.vllm_config - if self.spec_config.moe_backend is not None: - draft_vllm_config = replace( - draft_vllm_config, - kernel_config=replace( - draft_vllm_config.kernel_config, - moe_backend=self.spec_config.moe_backend, - ), - ) with set_model_tag("medusa_head"): self.model = get_model( - vllm_config=draft_vllm_config, + vllm_config=self.vllm_config, model_config=self.spec_config.draft_model_config, ) assert not ( From a0a5653d172a6243acf2ad7cd9bd7da346684843 Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Tue, 24 Mar 2026 23:59:14 +0100 Subject: [PATCH 08/10] Refactor draft model tests in test_spec_decode.py for moe_backend handling - Introduced a new helper function to apply draft moe_backend logic, improving test clarity and reducing redundancy. - Added tests to verify the correct propagation and inheritance of moe_backend settings between target and draft configurations. - Ensured that default behaviors for moe_backend are correctly validated in various scenarios. Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 181 ++++++++----------- 1 file changed, 75 insertions(+), 106 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index e0caf4be7e4d..bd7005e5de57 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -936,117 +936,86 @@ def test_draft_model_engine_args_tensor_parallelism(): assert draft_config.quant_config is None -@pytest.mark.parametrize( - [ - "target_moe_backend", - "draft_moe_backend", - "expected_target", - "expected_draft", - ], - [ - ("flashinfer_trtllm", "triton", "flashinfer_trtllm", "triton"), - ("flashinfer_cutlass", None, "flashinfer_cutlass", "flashinfer_cutlass"), - ("auto", None, "auto", "auto"), - ], - ids=["draft_overrides", "draft_inherits_target", "both_default_auto"], -) -@pytest.mark.parametrize( - ["target_model", "spec_method_cfg", "tensor_parallel_size", "trust_remote_code"], - [ - pytest.param( - "Qwen/Qwen3-1.7B", - { - "model": "Qwen/Qwen3-0.6B", - "method": "draft_model", - "num_speculative_tokens": 3, - }, - 1, - False, - id="draft_model", - ), - pytest.param( - "eagle618/deepseek-v3-random", - { - "model": "eagle618/eagle-deepseek-v3-random", - "method": "eagle", - "num_speculative_tokens": 3, - }, - 1, - False, - id="eagle", - ), - pytest.param( - "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4", - { - "method": "mtp", - "num_speculative_tokens": 1, - "max_model_len": 2048, - }, - 4, - True, - marks=[ - large_gpu_mark(min_gb=80), - *multi_gpu_marks(num_gpus=4), - ], - id="nemotron_mtp", - ), - ], -) -def test_draft_moe_backend( - target_moe_backend: str, - draft_moe_backend: str | None, - expected_target: str, - expected_draft: str, - target_model: str, - spec_method_cfg: dict[str, Any], - tensor_parallel_size: int, - trust_remote_code: bool, -): - """speculative_config.moe_backend must propagate (or inherit) correctly - across drafting methods.""" - spec_method_args = {**spec_method_cfg} - if draft_moe_backend is not None: - spec_method_args["moe_backend"] = draft_moe_backend +def _apply_draft_moe_backend(vllm_config: VllmConfig) -> VllmConfig: + """Replicate SpecDecodeBaseProposer._create_draft_vllm_config logic + so we can test it without instantiating a full proposer.""" + spec_cfg = vllm_config.speculative_config + if spec_cfg.moe_backend is not None: + return replace( + vllm_config, + kernel_config=replace( + vllm_config.kernel_config, + moe_backend=spec_cfg.moe_backend, + ), + ) + return vllm_config + + +def test_draft_model_moe_backend_override(): + """When moe_backend is set in speculative_config, the draft VllmConfig + should use it while the target keeps its own setting.""" + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + moe_backend="flashinfer_trtllm", + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + "moe_backend": "triton", + }, + ) + tgt_config: VllmConfig = engine_args.create_engine_config() + assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm" + assert tgt_config.speculative_config.moe_backend == "triton" + + draft_config = _apply_draft_moe_backend(tgt_config) + assert draft_config.kernel_config.moe_backend == "triton" + # Target config must be unaffected. + assert tgt_config.kernel_config.moe_backend == "flashinfer_trtllm" + +def test_draft_model_moe_backend_inherits_target(): + """When moe_backend is not set in speculative_config, the draft should + inherit the target's moe_backend.""" engine_args = EngineArgs( - model=target_model, - tensor_parallel_size=tensor_parallel_size, - trust_remote_code=trust_remote_code, - moe_backend=target_moe_backend, - speculative_config=spec_method_args, + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + moe_backend="flashinfer_cutlass", + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + }, ) - target_config: VllmConfig = engine_args.create_engine_config() - assert target_config.kernel_config.moe_backend == expected_target + tgt_config: VllmConfig = engine_args.create_engine_config() + assert tgt_config.kernel_config.moe_backend == "flashinfer_cutlass" + assert tgt_config.speculative_config.moe_backend is None - # Base proposer path: override moe_backend from speculative_config - speculative_config = target_config.speculative_config - if speculative_config.moe_backend is not None: - applied = replace( - target_config, - kernel_config=replace( - target_config.kernel_config, - moe_backend=speculative_config.moe_backend, - ), - ) - else: - applied = target_config - assert applied.kernel_config.moe_backend == expected_draft - assert applied.model_config is target_config.model_config - assert applied.parallel_config is target_config.parallel_config - if draft_moe_backend is None: - assert applied is target_config - - # DraftModelProposer path: extends base with quant/parallel/model overrides - draft_config = replace( - applied, - quant_config=None, - parallel_config=replace( - speculative_config.draft_parallel_config, - rank=target_config.parallel_config.rank, - ), - model_config=speculative_config.draft_model_config, + draft_config = _apply_draft_moe_backend(tgt_config) + assert draft_config.kernel_config.moe_backend == "flashinfer_cutlass" + assert draft_config is tgt_config + + +def test_draft_model_moe_backend_default_auto(): + """When neither target nor draft set moe_backend explicitly, both should + default to 'auto'.""" + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=1, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + }, ) - assert draft_config.kernel_config.moe_backend == expected_draft + tgt_config: VllmConfig = engine_args.create_engine_config() + assert tgt_config.kernel_config.moe_backend == "auto" + assert tgt_config.speculative_config.moe_backend is None + + draft_config = _apply_draft_moe_backend(tgt_config) + assert draft_config.kernel_config.moe_backend == "auto" + assert draft_config is tgt_config def test_draft_model_engine_args_rejects_invalid_tp_argname(): From eef0b1c839f7d155684a5c5f82557c143bf85528 Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Wed, 25 Mar 2026 11:02:59 +0100 Subject: [PATCH 09/10] Enhance draft model configuration handling in DraftModelProposer - Introduced a copy of the draft_parallel_config to maintain the original configuration while updating the rank based on the current vllm_config. - Simplified the configuration replacement logic for better clarity and maintainability. Signed-off-by: Andrii Skliar --- vllm/v1/spec_decode/draft_model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 81fb2533dd23..f050c42add4e 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy from dataclasses import replace import torch @@ -55,13 +56,13 @@ def _raise_if_draft_tp_mismatch(self): def _create_draft_vllm_config(self) -> VllmConfig: base = super()._create_draft_vllm_config() spec = self.speculative_config + + draft_parallel_config = copy.copy(spec.draft_parallel_config) + draft_parallel_config.rank = self.vllm_config.parallel_config.rank return replace( base, quant_config=None, - parallel_config=replace( - spec.draft_parallel_config, - rank=self.vllm_config.parallel_config.rank, - ), + parallel_config=draft_parallel_config, model_config=spec.draft_model_config, ) From cbbc799f20a628e4451422198ce7f6c0c526755b Mon Sep 17 00:00:00 2001 From: Andrii Skliar Date: Wed, 25 Mar 2026 14:05:07 +0100 Subject: [PATCH 10/10] Refactor imports and configuration handling in spec_decode modules - Removed unnecessary imports and streamlined the usage of the `replace` function across `test_spec_decode.py`, `draft_model.py`, and `eagle.py`. - Enhanced clarity in the configuration handling within `DraftModelProposer` by directly utilizing the `replace` function for updating the draft parallel configuration. Signed-off-by: Andrii Skliar --- tests/v1/e2e/spec_decode/test_spec_decode.py | 4 ++-- vllm/v1/spec_decode/draft_model.py | 11 +++++------ vllm/v1/spec_decode/eagle.py | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/v1/e2e/spec_decode/test_spec_decode.py b/tests/v1/e2e/spec_decode/test_spec_decode.py index bd7005e5de57..9ea41c774c46 100644 --- a/tests/v1/e2e/spec_decode/test_spec_decode.py +++ b/tests/v1/e2e/spec_decode/test_spec_decode.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random from collections.abc import Iterable -from dataclasses import dataclass, replace +from dataclasses import dataclass from typing import Any import pytest @@ -20,7 +20,7 @@ from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR from vllm.benchmarks.datasets import InstructCoderDataset -from vllm.config import VllmConfig +from vllm.config import VllmConfig, replace from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index f050c42add4e..9633e2ef6ca2 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy -from dataclasses import replace - import torch import torch.nn as nn from typing_extensions import override from vllm.config import VllmConfig +from vllm.config.utils import replace from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer @@ -57,12 +55,13 @@ def _create_draft_vllm_config(self) -> VllmConfig: base = super()._create_draft_vllm_config() spec = self.speculative_config - draft_parallel_config = copy.copy(spec.draft_parallel_config) - draft_parallel_config.rank = self.vllm_config.parallel_config.rank return replace( base, quant_config=None, - parallel_config=draft_parallel_config, + parallel_config=replace( + spec.draft_parallel_config, + rank=self.vllm_config.parallel_config.rank, + ), model_config=spec.draft_model_config, ) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index c101d05ddff8..62333526c26d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast -from dataclasses import replace from importlib.util import find_spec from typing import cast @@ -13,6 +12,7 @@ CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, + replace, ) from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context