From 1bbc15d1b685fe7c53844eb18f0af55ec8666331 Mon Sep 17 00:00:00 2001
From: zixi-qi
Date: Fri, 19 Sep 2025 02:28:32 +0000
Subject: [PATCH 1/2] Add unit tests and e2e test for MTP inference

Signed-off-by: zixi-qi
---
 tests/v1/e2e/test_spec_decode.py |  77 +++++++++++
 tests/v1/spec_decode/test_mtp.py | 223 +++++++++++++++++++++++++++++++
 2 files changed, 300 insertions(+)
 create mode 100644 tests/v1/spec_decode/test_mtp.py

diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index bf90f50b1082..967ee69fa7c0 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -223,3 +223,80 @@ def test_eagle_correctness(
         del spec_llm
         torch.cuda.empty_cache()
         cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("mimo_mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
+],
+                         ids=["mimo_mtp"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
+def test_mtp_correctness(
+        monkeypatch: pytest.MonkeyPatch,
+        sampling_config: SamplingParams,
+        model_setup: tuple[str, str, int],
+        mm_enabled: bool,
+        attn_backend: str,
+):
+    '''
+    Compare the outputs of the original LLM and the speculative LLM;
+    they should match when using MTP speculative decoding.
+    model_setup: (method, model_name, tp_size)
+    '''
+    if attn_backend == "TREE_ATTN":
+        pytest.skip("MTP does not support tree-based speculative decoding")
+
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token MTP spec decode on current platform")
+
+    # Generate test prompts inside the function instead of using a fixture
+    test_prompts = get_test_prompts(mm_enabled)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_MLA_DISABLE", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+            m.setenv("VLLM_ROCM_USE_AITER", "1")
+
+        method, model_name, tp_size = model_setup
+
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size,
+                      trust_remote_code=True)
+        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "num_speculative_tokens": 1,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+        )
+        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+        matches = 0
+        misses = 0
+        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+            if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(f"ref_output: {ref_output.outputs[0].text}")
+                print(f"spec_output: {spec_output.outputs[0].text}")
+
+        # Heuristic: expect at least 80% of the prompts to match exactly.
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.8 * len(ref_outputs))
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py
new file mode 100644
index 000000000000..c8d0b62b815b
--- /dev/null
+++ b/tests/v1/spec_decode/test_mtp.py
@@ -0,0 +1,223 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest import mock
+
+import pytest
+import torch
+
+from tests.utils import get_attn_backend_list_based_on_platform
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      get_attention_backend)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
+                         VllmConfig)
+from vllm.config.load import LoadConfig
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.platforms import current_platform
+from vllm.v1.spec_decode.eagle import EagleProposer

+mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
+
+
+def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
+    """Create an MTP proposer with unified model configuration."""
+    model_config = ModelConfig(model=mimo_7b_dir,
+                               runner="generate",
+                               max_model_len=100,
+                               trust_remote_code=True)
+
+    speculative_config = SpeculativeConfig(
+        target_model_config=model_config,
+        target_parallel_config=ParallelConfig(),
+        model=mimo_7b_dir,
+        method="mimo_mtp",
+        num_speculative_tokens=num_speculative_tokens,
+    )
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=CacheConfig(),
+        speculative_config=speculative_config,
+        device_config=DeviceConfig(device=current_platform.device_type),
+        parallel_config=ParallelConfig(),
+        load_config=LoadConfig(),
+        scheduler_config=SchedulerConfig())
+
+    return EagleProposer(vllm_config=vllm_config,
+                         device=current_platform.device_type)
+
+
+@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
+@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
+@mock.patch('vllm.v1.spec_decode.eagle.get_model')
+def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
+                                mock_get_pp_group):
+    """Test MTP-specific model loading with the unified model approach."""
+
+    # Setup mocks
+    mock_model = mock.MagicMock()
+    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+    mock_get_model.return_value = mock_model
+
+    target_attn_layers = {"target_attn_1": mock.MagicMock()}
+    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
+    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
+
+    mock_pp_group = mock.MagicMock()
+    mock_pp_group.world_size = 1
+    mock_get_pp_group.return_value = mock_pp_group
+
+    # Create target model
+    class _TargetModelStub(LlamaForCausalLM):
+        model: mock.MagicMock
+        lm_head: mock.MagicMock
+
+    target_model = mock.create_autospec(_TargetModelStub, instance=True)
+    target_model.model = mock.MagicMock()
+    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+
+    # Create MTP proposer
+    proposer = _create_mtp_proposer(num_speculative_tokens=4)
+    proposer.load_model(target_model)
+
+    # Verify MTP-specific behavior:
+    # Model is loaded
+    mock_get_model.assert_called_once()
+    # MTP shares lm_head with target model
+    assert proposer.model.lm_head == target_model.lm_head
+    # MTP shares embed_tokens with target model
+    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+
+
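+# NOTE: The propose() test below runs without loading real weights: the MTP
+# model is mocked so each forward pass returns zero hidden states, and
+# compute_logits is mocked to put all probability mass on one token id per
+# step (42 + step), making the proposed token ids deterministic.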
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
+@pytest.mark.parametrize("num_speculative_tokens", [1])
+def test_mtp_propose_returns_hidden_states(attn_backend,
+                                           num_speculative_tokens,
+                                           monkeypatch):
+    """Test that MTP's forward method returns hidden states directly."""
+
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token spec decode on current platform")
+
+    if attn_backend == "TREE_ATTN":
+        pytest.skip("MTP does not support tree-based speculative decoding")
+
+    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
+    device = torch.device(current_platform.device_type)
+    batch_size = 2
+    seq_lens = [5, 3]
+    total_tokens = sum(seq_lens)
+    vocab_size = 100
+
+    proposer = _create_mtp_proposer(num_speculative_tokens)
+    hidden_size = proposer.hidden_size
+
+    # Mock the MTP model to verify it returns hidden states directly
+    model_mock = mock.MagicMock()
+
+    # MTP returns hidden states directly
+    if num_speculative_tokens == 1:
+        model_mock.return_value = torch.zeros(total_tokens,
+                                              hidden_size,
+                                              device=device)
+    else:
+        # Multiple forward passes for multi-token speculation
+        forward_returns = []
+        for i in range(num_speculative_tokens):
+            if i == 0:
+                h_states = torch.zeros(total_tokens,
+                                       hidden_size,
+                                       device=device)
+            else:
+                h_states = torch.zeros(batch_size, hidden_size, device=device)
+            forward_returns.append(h_states)
+        model_mock.side_effect = forward_returns
+
+    # Mock compute_logits
+    def create_deterministic_logits(batch_size, vocab_size, token_offset):
+        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
+        logits[:, token_offset] = 100.0
+        return logits
+
+    if num_speculative_tokens == 1:
+        model_mock.compute_logits.return_value = create_deterministic_logits(
+            batch_size, vocab_size, 42)
+    else:
+        logits_returns = [
+            create_deterministic_logits(batch_size, vocab_size, 42 + i)
+            for i in range(num_speculative_tokens)
+        ]
+        model_mock.compute_logits.side_effect = logits_returns
+
+    proposer.model = model_mock
+    proposer.attn_layer_names = ["layer.0"]
+
+    # Prepare inputs
+    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
+    common_attn_metadata = create_common_attn_metadata(batch_spec,
+                                                       block_size=16,
+                                                       device=device)
+
+    target_token_ids = torch.randint(0,
+                                     vocab_size, (total_tokens, ),
+                                     device=device)
+    target_positions = torch.cat([
+        torch.arange(seq_lens[0], device=device),
+        torch.arange(seq_lens[1], device=device)
+    ])
+    target_hidden_states = torch.randn(total_tokens,
+                                       hidden_size,
+                                       device=device)
+    next_token_ids = torch.randint(0,
+                                   vocab_size, (batch_size, ),
+                                   dtype=torch.int32,
+                                   device=device)
+    sampling_metadata = mock.MagicMock()
+
+    # Setup attention metadata
+    if attn_backend == "FLASH_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.FLASH_ATTN_VLLM_V1)
+    elif attn_backend == "TRITON_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.TRITON_ATTN_VLLM_V1)
+    else:
+        raise ValueError(f"Unsupported attention backend: {attn_backend}")
+
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer.attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    proposer.runner = mock.MagicMock()
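+    # Give the runner mock the nested attn_groups layout that the proposer
+    # reads its attention metadata builder from.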
+    proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builders = [
+        attn_metadata_builder
+    ]
+
+    # Run propose
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              last_token_indices=None,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)
+
+    # Verify the model was called
+    assert model_mock.called
+    # Verify output shape: one proposed token id per speculative step
+    assert result.shape == (batch_size, num_speculative_tokens)

From 4bb4b5ed8728f7c3d8ac43c545a0f5f8baec284e Mon Sep 17 00:00:00 2001
From: zixi-qi
Date: Fri, 19 Sep 2025 05:23:44 +0000
Subject: [PATCH 2/2] rename to deepseek_mtp

Signed-off-by: zixi-qi
---
 tests/v1/e2e/test_spec_decode.py | 4 ++--
 tests/v1/spec_decode/test_mtp.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 967ee69fa7c0..fbf6671741ee 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -226,9 +226,9 @@ def test_eagle_correctness(
 
 
 @pytest.mark.parametrize(["model_setup", "mm_enabled"], [
-    (("mimo_mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
+    (("deepseek_mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
 ],
-                         ids=["mimo_mtp"])
+                         ids=["deepseek_mtp"])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_mtp_correctness(
diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py
index c8d0b62b815b..b12eca159cb1 100644
--- a/tests/v1/spec_decode/test_mtp.py
+++ b/tests/v1/spec_decode/test_mtp.py
@@ -33,7 +33,7 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
         target_model_config=model_config,
         target_parallel_config=ParallelConfig(),
         model=mimo_7b_dir,
-        method="mimo_mtp",
+        method="deepseek_mtp",
         num_speculative_tokens=num_speculative_tokens,
     )
 