5 changes: 3 additions & 2 deletions .github/workflows/_e2e_test.yaml
@@ -122,8 +122,8 @@ jobs:
           pytest -sv --durations=0 tests/e2e/singlecard/test_cpu_offloading.py
 
           # ------------------------------------ v1 spec decode test ------------------------------------ #
-          pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
-          pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+          pytest -sv --durations=0 tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py
+          pytest -sv --durations=0 tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
 
   e2e-2-cards:
     name: multicard-2
@@ -294,6 +294,7 @@ jobs:
         env:
           VLLM_WORKER_MULTIPROC_METHOD: spawn
         run: |
+          pytest -sv --durations=0 tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py
           pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_deepseek_multistream_moe_tp2
           pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_kimi_k2_thinking_w4a16_tp4
           pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py
152 changes: 152 additions & 0 deletions tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py
@@ -0,0 +1,152 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py`.
"""

import os

import pytest
from vllm.config import CompilationConfig
from vllm.v1.metrics.reader import Counter, Vector

from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

MODELS = ["Qwen/Qwen3-Next-80B-A3B-Instruct"]


# TODO: add full decode only (when ready)
@pytest.mark.parametrize("model_name", MODELS)
def test_qwen3_next_mtp_acceptance_tp4(model_name):
    # Golden per-position acceptance rates expected for this model/config.
    golden = [0.85, 0.46, 0.19]

    example_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    max_tokens = 1024

    with VllmRunner(model_name,
                    tensor_parallel_size=4,
                    max_model_len=4096,
                    gpu_memory_utilization=0.8,
                    distributed_executor_backend="mp",
                    disable_log_stats=False,
                    speculative_config={
                        "method": "qwen3_next_mtp",
                        "num_speculative_tokens": 3,
                    },
                    compilation_config=CompilationConfig(
                        cudagraph_capture_sizes=[20])) as spec_vllm_model:
        _ = spec_vllm_model.generate_greedy(example_prompts, max_tokens)
        metrics = spec_vllm_model.model.get_metrics()

    # Aggregate the draft count and the number of accepted tokens at each
    # speculative position across all reported metrics.
    num_drafts = 0
    num_accepted_tokens_per_pos = [0] * 3
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos in range(len(metric.values)):
                num_accepted_tokens_per_pos[pos] += metric.values[pos]

    acceptance_per_pos = [
        num_accepted_tokens / num_drafts
        for num_accepted_tokens in num_accepted_tokens_per_pos
    ]

    # Allow up to 0.05 absolute drift from the golden rate at each position.
    match = all(abs(a - b) < 0.05 for a, b in zip(acceptance_per_pos, golden))
    if not match:
        print(f"acceptance_per_pos: {acceptance_per_pos}")
        print(f"golden: {golden}")

    assert match
    cleanup_dist_env_and_memory()
Contributor

high

The VllmRunner context manager automatically handles resource cleanup, including calling cleanup_dist_env_and_memory(), in its __exit__ method. This explicit call is redundant and should be removed to improve code clarity.


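As a sketch of the pattern the reviewer describes — hedged, since VllmRunner's actual implementation in tests/e2e/conftest.py is not shown in this diff — a context manager that performs cleanup in __exit__ looks roughly like this; ManagedRunner and its attributes are illustrative names, not project code:

from tests.e2e.conftest import cleanup_dist_env_and_memory


class ManagedRunner:
    """Illustrative stand-in for a runner that cleans up in __exit__."""

    def __init__(self, model):
        self.model = model  # resources owned by the runner

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # __exit__ runs whether the with-block body returned or raised, so
        # callers never need a trailing cleanup_dist_env_and_memory() call
        # of their own.
        del self.model
        cleanup_dist_env_and_memory()

If VllmRunner follows this pattern, the explicit cleanup calls at the end of the tests in this file duplicate work the with statement already guarantees.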

@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("num_speculative_tokens", [1])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
def test_qwen3_next_mtp_correctness_tp4(model_name: str,
                                        num_speculative_tokens: int,
                                        disable_padded_drafter_batch: bool):
    '''
    The outputs of the original LLM and the speculative LLM should be the
    same when using MTP speculative decoding.
    '''
    example_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    max_tokens = 20
    with VllmRunner(model_name,
                    tensor_parallel_size=4,
                    max_model_len=4096,
                    gpu_memory_utilization=0.8,
                    distributed_executor_backend="mp",
                    speculative_config={
                        "method": "mtp",
                        "num_speculative_tokens": num_speculative_tokens,
                        "disable_padded_drafter_batch":
                        disable_padded_drafter_batch,
                    },
                    compilation_config=CompilationConfig(
                        cudagraph_capture_sizes=[20])) as spec_llm:
        spec_outputs = spec_llm.generate_greedy(example_prompts, max_tokens)
    del spec_llm

    with VllmRunner(model_name,
                    tensor_parallel_size=4,
                    max_model_len=4096,
                    gpu_memory_utilization=0.8,
                    distributed_executor_backend="mp",
                    compilation_config=CompilationConfig(
                        cudagraph_capture_sizes=[20])) as ref_llm:
        ref_outputs = ref_llm.generate_greedy(example_prompts, max_tokens)
    del ref_llm

    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        ref_token_ids = ref_output[0]
        spec_token_ids = spec_output[0]
        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
            matches += 1
        else:
            misses += 1
            print(f"ref_output: {ref_output[1]}")
            print(f"spec_output: {spec_output[1]}")

    # Heuristic: expect at least 66% of the prompts to match exactly.
    # Upon failure, inspect the outputs to check for inaccuracy.
    assert matches > int(0.66 * len(ref_outputs))
    cleanup_dist_env_and_memory()
Contributor

high

The VllmRunner context manager automatically handles resource cleanup, including calling cleanup_dist_env_and_memory(), in its __exit__ method. This explicit call is redundant and should be removed.
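One way to act on this suggestion, sketched under the same assumption (that VllmRunner.__exit__ already calls cleanup_dist_env_and_memory()), is to let a pytest fixture own the runner's lifetime; the fixture name qwen3_runner is hypothetical, not part of this PR:

import pytest

from tests.e2e.conftest import VllmRunner


@pytest.fixture
def qwen3_runner():
    # Hypothetical fixture: pytest guarantees teardown after the yield, and
    # VllmRunner.__exit__ performs the cleanup, so the test body needs no
    # explicit cleanup_dist_env_and_memory() call.
    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                    tensor_parallel_size=4) as runner:
        yield runner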

50 changes: 0 additions & 50 deletions tests/e2e/multicard/test_qwen3_next.py
@@ -62,56 +62,6 @@ def test_qwen3_next_distributed_mp_full_decode_only_tp4():
     del vllm_model
 
 
-def test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4():
-    example_prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
-
-    max_tokens = 15
-
-    with VllmRunner(
-            "Qwen/Qwen3-Next-80B-A3B-Instruct",
-            tensor_parallel_size=4,
-            max_model_len=4096,
-            gpu_memory_utilization=0.8,
-            distributed_executor_backend="mp",
-            enforce_eager=True,
-    ) as vllm_model:
-        ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-    del vllm_model
-
-    with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
-                    tensor_parallel_size=4,
-                    max_model_len=4096,
-                    gpu_memory_utilization=0.8,
-                    distributed_executor_backend="mp",
-                    enforce_eager=True,
-                    speculative_config={
-                        "method": "qwen3_next_mtp",
-                        "num_speculative_tokens": 1
-                    }) as spec_vllm_model:
-        spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
-                                                       max_tokens)
-    del spec_vllm_model
-
-    matches = 0
-    misses = 0
-    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-        ref_token_ids = ref_output[0]
-        spec_token_ids = spec_output[0]
-        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
-            matches += 1
-        else:
-            misses += 1
-            print(f"ref_output: {ref_output[1]}")
-            print(f"spec_output: {spec_output[1]}")
-
-    assert matches > int(0.66 * len(ref_outputs))
-
-
 # TODO: will conduct accuracy verification after the subsequent version becomes stable
 @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
 def test_qwen3_next_w8a8dynamic_distributed_tp4_ep():