diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 7418c60ac7f..cd035d7a735 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -122,8 +122,8 @@ jobs: pytest -sv --durations=0 tests/e2e/singlecard/test_cpu_offloading.py # ------------------------------------ v1 spec decode test ------------------------------------ # - pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py - pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py + pytest -sv --durations=0 tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py + pytest -sv --durations=0 tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py e2e-2-cards: name: multicard-2 @@ -294,6 +294,7 @@ jobs: env: VLLM_WORKER_MULTIPROC_METHOD: spawn run: | + pytest -sv --durations=0 tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_deepseek_multistream_moe_tp2 pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_kimi_k2_thinking_w4a16_tp4 pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py diff --git a/tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py b/tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py new file mode 100644 index 00000000000..7ea0ca94064 --- /dev/null +++ b/tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py @@ -0,0 +1,152 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +Run `pytest tests/e2e/multicard/spec_decode/test_mtp_qwen3_next.py`. +""" + +import os + +import pytest +from vllm.config import CompilationConfig +from vllm.v1.metrics.reader import Counter, Vector + +from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODELS = ["Qwen/Qwen3-Next-80B-A3B-Instruct"] + + +# TODO: add full decode only (when ready) +@pytest.mark.parametrize("model_name", MODELS) +def test_qwen3_next_mtp_acceptance_tp4(model_name): + golden = [0.85, 0.46, 0.19] + + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + max_tokens = 1024 + + with VllmRunner(model_name, + tensor_parallel_size=4, + max_model_len=4096, + gpu_memory_utilization=0.8, + distributed_executor_backend="mp", + disable_log_stats=False, + speculative_config={ + "method": "qwen3_next_mtp", + "num_speculative_tokens": 3, + }, + compilation_config=CompilationConfig( + cudagraph_capture_sizes=[20])) as spec_vllm_model: + _ = spec_vllm_model.generate_greedy(example_prompts, max_tokens) + metrics = spec_vllm_model.model.get_metrics() + + num_drafts = 0 + num_accepted_tokens_per_pos = [0] * 3 + for metric in metrics: + if metric.name == "vllm:spec_decode_num_drafts": + assert isinstance(metric, Counter) + num_drafts += metric.value + elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": + assert isinstance(metric, Vector) + for pos in range(len(metric.values)): + num_accepted_tokens_per_pos[pos] += metric.values[pos] + + acceptance_per_pos = [ + num_accepted_tokens / num_drafts + for num_accepted_tokens in num_accepted_tokens_per_pos + ] + + match = all(abs(a - b) < 0.05 for a, b in zip(acceptance_per_pos, golden)) + if not match: + print(f"acceptance_per_pos: {acceptance_per_pos}") + print(f"golden: {golden}") + + assert match + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("model_name", MODELS) +@pytest.mark.parametrize("num_speculative_tokens", [1]) +@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False]) +def test_qwen3_next_mtp_correctness_tp4(model_name: str, + num_speculative_tokens: int, + disable_padded_drafter_batch: bool): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + max_tokens = 20 + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using mtp speculative decoding. + ''' + with VllmRunner(model_name, + tensor_parallel_size=4, + max_model_len=4096, + gpu_memory_utilization=0.8, + distributed_executor_backend="mp", + speculative_config={ + "method": + "mtp", + "num_speculative_tokens": + num_speculative_tokens, + "disable_padded_drafter_batch": + disable_padded_drafter_batch, + }, + compilation_config=CompilationConfig( + cudagraph_capture_sizes=[20])) as spec_llm: + spec_outputs = spec_llm.generate_greedy(example_prompts, max_tokens) + del spec_llm + + with VllmRunner(model_name, + tensor_parallel_size=4, + max_model_len=4096, + gpu_memory_utilization=0.8, + distributed_executor_backend="mp", + compilation_config=CompilationConfig( + cudagraph_capture_sizes=[20])) as ref_llm: + ref_outputs = ref_llm.generate_greedy(example_prompts, max_tokens) + del ref_llm + + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + ref_token_ids = ref_output[0] + spec_token_ids = spec_output[0] + if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output[1]}") + print(f"spec_output: {spec_output[1]}") + + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) + cleanup_dist_env_and_memory() diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py index 11e37af2561..ca6d077c44c 100644 --- a/tests/e2e/multicard/test_qwen3_next.py +++ b/tests/e2e/multicard/test_qwen3_next.py @@ -62,56 +62,6 @@ def test_qwen3_next_distributed_mp_full_decode_only_tp4(): del vllm_model -def test_qwen3_next_distributed_mp_eager_mtp_similarity_tp4(): - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - max_tokens = 15 - - with VllmRunner( - "Qwen/Qwen3-Next-80B-A3B-Instruct", - tensor_parallel_size=4, - max_model_len=4096, - gpu_memory_utilization=0.8, - distributed_executor_backend="mp", - enforce_eager=True, - ) as vllm_model: - ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model - - with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct", - tensor_parallel_size=4, - max_model_len=4096, - gpu_memory_utilization=0.8, - distributed_executor_backend="mp", - enforce_eager=True, - speculative_config={ - "method": "qwen3_next_mtp", - "num_speculative_tokens": 1 - }) as spec_vllm_model: - spec_outputs = spec_vllm_model.generate_greedy(example_prompts, - max_tokens) - del spec_vllm_model - - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - ref_token_ids = ref_output[0] - spec_token_ids = spec_output[0] - if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output[1]}") - print(f"spec_output: {spec_output[1]}") - - assert matches > int(0.66 * len(ref_outputs)) - - # TODO: will conduct accuracy verification after the subsequent version becomes stable @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) def test_qwen3_next_w8a8dynamic_distributed_tp4_ep(): diff --git a/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py b/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py new file mode 100644 index 00000000000..421a0e88edb --- /dev/null +++ b/tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py @@ -0,0 +1,206 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +""" + +from __future__ import annotations + +import os + +import pytest +from vllm import SamplingParams +from vllm.config import CompilationConfig + +from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODELS = ["wemaster/deepseek_mtp_main_random_bf16"] +MODELS_EAGLE = [ + "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B", + "RedHatAI/Qwen3-8B-speculator.eagle3" +] +MODELS_MAIN = ["LLM-Research/Meta-Llama-3.1-8B-Instruct", "Qwen/Qwen3-8B"] +VALID_COMBINATIONS = {("eagle", "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B", + "LLM-Research/Meta-Llama-3.1-8B-Instruct"), + ("eagle3", "RedHatAI/Qwen3-8B-speculator.eagle3", + "Qwen/Qwen3-8B")} + + +@pytest.mark.parametrize("model_name", MODELS) +@pytest.mark.parametrize("num_speculative_tokens", [1, 2, 3]) +@pytest.mark.parametrize("cudagraph_mode", ["PIECEWISE", "FULL_DECODE_ONLY"]) +@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False]) +def test_deepseek_mtp_correctness(model_name: str, num_speculative_tokens: int, + cudagraph_mode: str, + disable_padded_drafter_batch: bool): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using mtp speculative decoding. + ''' + with VllmRunner(model_name, + tensor_parallel_size=1, + max_num_seqs=256, + gpu_memory_utilization=0.7, + distributed_executor_backend="mp", + enable_expert_parallel=True, + speculative_config={ + "method": + "mtp", + "num_speculative_tokens": + num_speculative_tokens, + "disable_padded_drafter_batch": + disable_padded_drafter_batch, + }, + max_model_len=2000, + compilation_config=CompilationConfig( + cudagraph_mode=cudagraph_mode, + cudagraph_capture_sizes=[20], + )) as spec_llm: + sampling_config = SamplingParams(temperature=0, + max_tokens=256, + ignore_eos=False) + spec_outputs = spec_llm.generate(example_prompts, sampling_config) + + with VllmRunner(model_name, + tensor_parallel_size=1, + gpu_memory_utilization=0.7, + max_model_len=256, + compilation_config=CompilationConfig( + cudagraph_mode=cudagraph_mode, + cudagraph_capture_sizes=[20], + )) as ref_llm: + sampling_config = SamplingParams(temperature=0, + max_tokens=256, + ignore_eos=False) + ref_outputs = ref_llm.generate(example_prompts, sampling_config) + + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + ref_token_ids = ref_output[0][0] + spec_token_ids = spec_output[0][0] + if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output[1][0]}") + print(f"spec_output: {spec_output[1][0]}") + + # Heuristic: expect at least 66% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.66 * len(ref_outputs)) + cleanup_dist_env_and_memory() + del spec_llm + + +@pytest.mark.parametrize("model_name", MODELS_EAGLE) +@pytest.mark.parametrize("model_name_main", MODELS_MAIN) +@pytest.mark.parametrize("num_speculative_tokens", [1, 2]) +@pytest.mark.parametrize("method", ["eagle", "eagle3"]) +@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False]) +@pytest.mark.parametrize("async_scheduling", [True, False]) +def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str, + num_speculative_tokens: int, + method: str, + disable_padded_drafter_batch: bool, + async_scheduling: bool): + + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + if (method, model_name, model_name_main) not in VALID_COMBINATIONS or \ + (async_scheduling and disable_padded_drafter_batch): + pytest.skip( + f"Invalid combination: method={method}, model_name={model_name}, model_name_main={model_name_main}, or case not support yet" + ) + + sampling_params = SamplingParams( + max_tokens=300, + temperature=0.0, + ignore_eos=False, + ) + + with VllmRunner(model_name_main, + tensor_parallel_size=1, + pipeline_parallel_size=1, + data_parallel_size=1, + disable_log_stats=False, + max_model_len=4096, + seed=1024, + async_scheduling=async_scheduling, + speculative_config={ + "disable_padded_drafter_batch": + disable_padded_drafter_batch, + "method": method, + "model": model_name, + "num_speculative_tokens": num_speculative_tokens, + "max_model_len": 128, + "draft_vocab_size": 128256, + }, + compilation_config=CompilationConfig( + cudagraph_mode="FULL_DECODE_ONLY", + cudagraph_capture_sizes=[12])) as llm: + spec_outputs = llm.generate(example_prompts, sampling_params) + cleanup_dist_env_and_memory() + del llm + + with VllmRunner(model_name_main, + tensor_parallel_size=1, + pipeline_parallel_size=1, + data_parallel_size=1, + disable_log_stats=False, + max_model_len=4096, + seed=1024, + async_scheduling=async_scheduling, + compilation_config=CompilationConfig( + cudagraph_mode="FULL_DECODE_ONLY", + cudagraph_capture_sizes=[12])) as llm: + ref_outputs = llm.generate(example_prompts, sampling_params) + cleanup_dist_env_and_memory() + del llm + + matches = 0 + misses = 0 + threshold = 0.66 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + ref_token_ids = ref_output[0][0] + spec_token_ids = spec_output[0][0] + if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output[1][0]}") + print(f"spec_output: {spec_output[1][0]}") + + # Heuristic: expect at least 66.6% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(threshold * len(ref_outputs)) + cleanup_dist_env_and_memory() diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py similarity index 74% rename from tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py rename to tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py index 6e6e46bb1e3..58d4c709f0e 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py @@ -12,7 +12,7 @@ from vllm.config import CompilationConfig from vllm.v1.metrics.reader import Counter, Vector -from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory +from tests.e2e.conftest import VllmRunner os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -130,127 +130,6 @@ def test_ngram_correctness( assert matches > int(0.66 * len(ref_outputs)) -@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) -def test_eagle_correctness( - test_prompts: list[list[dict[str, Any]]], - sampling_config: SamplingParams, - model_name: str, - use_eagle3: bool, -): - ''' - Compare the outputs of a original LLM and a speculative LLM - should be the same when using eagle speculative decoding. - ''' - # NOTE: e2e of eagle has many problems before. - # We first check whether it is functioning properly. - # Should fix the e2e with VllmRunner in future. - spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name() - tokenizer = AutoTokenizer.from_pretrained(model_name, - trust_remote_code=True) - prompts = [{ - "role": "user", - "content": "Hello, my name is" - }, { - "role": "user", - "content": "The president of the United States is" - }, { - "role": "user", - "content": "The capital of France is" - }, { - "role": "user", - "content": "The future of AI is" - }] - prompts = [ - tokenizer.apply_chat_template( - [prompt], - tokenize=False, - add_generation_prompt=True, - ) for prompt in prompts - ] - - sampling_params = SamplingParams( - max_tokens=300, - temperature=0.8, - top_p=0.7, - top_k=4, - ignore_eos=False, - ) - - # Create an LLM. - llm = LLM( - model=model_name, - tensor_parallel_size=1, - pipeline_parallel_size=1, - data_parallel_size=1, - disable_log_stats=False, - max_model_len=4096, - seed=1024, - async_scheduling=True, - compilation_config={ - "level": 3, - "cudagraph_mode": "FULL_DECODE_ONLY", - "cudagraph_num_of_warmups": 1, - "cudagraph_capture_sizes": [12], - }, - speculative_config={ - "disable_padded_drafter_batch": False, - "method": "eagle3" if use_eagle3 else "eagle", - "model": spec_model_name, - "num_speculative_tokens": 2, - "max_model_len": 128, - "draft_vocab_size": 128256, - }, - ) - llm.generate(prompts, sampling_params) - cleanup_dist_env_and_memory() - del llm - - -@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) -def test_eaqgle_fullgraph_correctness( - test_prompts: list[list[dict[str, Any]]], - sampling_config: SamplingParams, - model_name: str, - use_eagle3: bool, -): - ''' - Compare the outputs of a original LLM and a speculative LLM - should be the same when using eagle3 speculative decoding - in full-graph mode. - ''' - spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name() - with VllmRunner(model_name, max_model_len=1024) as ref_llm: - ref_outputs = ref_llm.model.chat(test_prompts, sampling_config) - - with VllmRunner(model_name, - speculative_config={ - "method": "eagle3" if use_eagle3 else "eagle", - "model": spec_model_name, - "num_speculative_tokens": 4, - }, - compilation_config={ - "level": 3, - "cudagraph_mode": "FULL_DECODE_ONLY", - "cudagraph_num_of_warmups": 1, - "cudagraph_capture_sizes": [5, 10, 15, 20], - }, - max_model_len=1024) as runner: - spec_outputs = runner.model.chat(test_prompts, sampling_config) - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") - - # Heuristic: expect at least 70% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.66 * len(ref_outputs)) - - def test_suffix_correctness( test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py deleted file mode 100644 index 9369c4e2f45..00000000000 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import annotations - -import os - -import pytest -from vllm import SamplingParams -from vllm.config import CompilationConfig, CUDAGraphMode - -from tests.e2e.conftest import VllmRunner - -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - - -@pytest.fixture -def sampling_config(): - return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False) - - -@pytest.fixture -def model_name(): - return "wemaster/deepseek_mtp_main_random_bf16" - - -def mtp_correctness(sampling_config: SamplingParams, - model_name: str, - num_speculative_tokens: int, - graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE, - enforce_eager=False, - disable_padded_drafter_batch=True): - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - ''' - Compare the outputs of a original LLM and a speculative LLM - should be the same when using mtp speculative decoding. - ''' - with VllmRunner(model_name, - tensor_parallel_size=1, - gpu_memory_utilization=0.7, - max_model_len=256, - cudagraph_capture_sizes=[12], - enforce_eager=enforce_eager) as ref_llm: - ref_outputs = ref_llm.generate(example_prompts, sampling_config) - - graph_mode_str = "PIECEWISE" - if graph_mode == CUDAGraphMode.FULL: - graph_mode_str = "FULL_DECODE_ONLY" - - with VllmRunner(model_name, - tensor_parallel_size=1, - max_num_seqs=256, - gpu_memory_utilization=0.7, - distributed_executor_backend="mp", - enable_expert_parallel=True, - speculative_config={ - "method": - "mtp", - "num_speculative_tokens": - num_speculative_tokens, - "disable_padded_drafter_batch": - disable_padded_drafter_batch, - }, - enforce_eager=enforce_eager, - max_model_len=2000, - compilation_config=CompilationConfig( - cudagraph_mode=graph_mode_str, - cudagraph_capture_sizes=[12], - )) as spec_llm: - spec_outputs = spec_llm.generate(example_prompts, sampling_config) - - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - ref_token_ids = ref_output[0][0] - spec_token_ids = spec_output[0][0] - if ref_token_ids == spec_token_ids[:len(ref_token_ids)]: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output[1][0]}") - print(f"spec_output: {spec_output[1][0]}") - - # Heuristic: expect at least 66% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.66 * len(ref_outputs)) - del spec_llm - - -def test_mtp1_correctness_eager( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1, enforce_eager=True) - - -def test_mtp2_correctness_eager( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 2, enforce_eager=True) - - -def test_mtp1_correctness_piecewise_graph( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1) - - -def test_mtp2_correctness_piecewise_graph( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 2) - - -def test_mtp1_correctness_full_graph( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL) - - -def test_mtp2_correctness_full_graph( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) - - -def test_mtp1_correctness_eager_with_pad( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, - model_name, - 1, - enforce_eager=True, - disable_padded_drafter_batch=False) - - -def test_mtp2_correctness_eager_with_pad( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, - model_name, - 2, - enforce_eager=True, - disable_padded_drafter_batch=False) - - -@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed") -def test_mtp1_correctness_piecewise_graph_with_pad( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, - model_name, - 1, - disable_padded_drafter_batch=False) - - -@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed") -def test_mtp2_correctness_piecewise_graph_with_pad( - sampling_config: SamplingParams, - model_name: str, -): - mtp_correctness(sampling_config, - model_name, - 2, - disable_padded_drafter_batch=False) \ No newline at end of file