tests/e2e/singlecard/spec_decode/test_mtp_eagle_correctness.py (13 changes: 5 additions, 8 deletions)
@@ -23,7 +23,6 @@
from __future__ import annotations

import os
-from typing import Union

import pytest
from vllm import SamplingParams
@@ -124,11 +123,11 @@ def test_deepseek_mtp_correctness(model_name: str, num_speculative_tokens: int,
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
-def test_llama_qwen3_eagle_correctness(
-        model_name: str, model_name_main: str, num_speculative_tokens: int,
-        method: str, disable_padded_drafter_batch: bool,
-        async_scheduling: bool, draft_tensor_parallel_size: Union[None, int]):
+def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
+                                       num_speculative_tokens: int,
+                                       method: str,
+                                       disable_padded_drafter_batch: bool,
+                                       async_scheduling: bool):

example_prompts = [
"Hello, my name is",
@@ -163,8 +162,6 @@ def test_llama_qwen3_eagle_correctness(
"method": method,
"model": model_name,
"num_speculative_tokens": num_speculative_tokens,
"draft_tensor_parallel_size":
draft_tensor_parallel_size,
"max_model_len": 128,
"draft_vocab_size": 128256,
},
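For reference, the configuration this e2e test now exercises can be reproduced outside pytest roughly as follows. This is a minimal sketch, assuming vllm's `LLM(..., speculative_config=...)` entry point that the test itself uses; the model names are placeholders, not the ones pinned by the test suite.

```python
from vllm import LLM, SamplingParams

# Minimal sketch: the speculative_config no longer sets
# "draft_tensor_parallel_size", so the drafter simply inherits the target
# model's parallel layout.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder target model
    speculative_config={
        "method": "eagle3",
        "model": "yuhuili/EAGLE3-LLaMA3.1-8B-Instruct",  # placeholder drafter
        "num_speculative_tokens": 2,
        "max_model_len": 128,
        "draft_vocab_size": 128256,
    },
    max_model_len=128,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=32))
print(outputs[0].outputs[0].text)
```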
tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py (8 changes: 1 addition, 7 deletions)
@@ -4,7 +4,7 @@
import math
import os
import random
-from typing import Any, Union
+from typing import Any

import pytest
from transformers import AutoTokenizer
@@ -217,11 +217,9 @@ def test_suffix_acceptance(


@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
def test_eagle_logprobs(
model_name: str,
use_eagle3: bool,
-    draft_tensor_parallel_size: Union[None, int],
):
prompt = {"role": "user", "content": "Hello world " * 10}
sampling_params = SamplingParams(temperature=0,
@@ -248,7 +246,6 @@ def test_eagle_logprobs(
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"draft_tensor_parallel_size": draft_tensor_parallel_size,
"max_model_len": 128,
},
max_model_len=128,
@@ -274,13 +271,11 @@

@pytest.mark.parametrize("method", MODELS.keys())
@pytest.mark.parametrize("num_speculative_tokens", [3])
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_llama_qwen_eagle_acceptance(
method: str,
num_speculative_tokens: int,
-    draft_tensor_parallel_size: Union[None, int],
disable_padded_drafter_batch: bool,
async_scheduling: bool,
):
@@ -331,7 +326,6 @@ def test_llama_qwen_eagle_acceptance(
speculative_config = {
"method": method,
"num_speculative_tokens": num_speculative_tokens,
"draft_tensor_parallel_size": draft_tensor_parallel_size,
"disable_padded_drafter_batch": disable_padded_drafter_batch,
"model": spec_model_name,
}
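Because `test_eagle_logprobs` compares per-token logprobs with and without speculation, a small sketch of how such logprobs are requested and read back through the vllm API may help when reviewing it. The parameter values and model name are placeholders, and the output field layout is assumed from vllm's `SamplingParams` / `RequestOutput` types.

```python
from vllm import LLM, SamplingParams

# Sketch: request the top-2 logprobs per generated token and read them back.
# With greedy sampling (temperature=0), enabling spec decode should leave the
# generated tokens and their logprob values (numerically close to) unchanged.
params = SamplingParams(temperature=0, max_tokens=8, logprobs=2)
llm = LLM(model="Qwen/Qwen3-1.7B")  # placeholder model name
out = llm.generate(["Hello world"], params)[0]
for token_id, logprob_dict in zip(out.outputs[0].token_ids,
                                  out.outputs[0].logprobs):
    print(token_id, {tid: lp.logprob for tid, lp in logprob_dict.items()})
```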
tests/ut/spec_decode/test_eagle_proposer.py (8 changes: 0 additions, 8 deletions)
@@ -27,8 +27,6 @@ def setUp(self):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
-        self.vllm_config.parallel_config.tensor_parallel_size = 1
-        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
@@ -116,8 +114,6 @@ def setUp(self):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
-        self.vllm_config.parallel_config.tensor_parallel_size = 1
-        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
@@ -250,8 +246,6 @@ def setUp(self):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
-        self.vllm_config.parallel_config.tensor_parallel_size = 1
-        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(4)
])
@@ -366,8 +360,6 @@ def setUp(self):
self.vllm_config.model_config.dtype = torch.float16
self.vllm_config.model_config.max_model_len = 2048
self.vllm_config.model_config.uses_mrope = False
-        self.vllm_config.parallel_config.tensor_parallel_size = 1
-        self.vllm_config.speculative_config.draft_tensor_parallel_size = 1
self.vllm_config.speculative_config.num_speculative_tokens = 2
self.vllm_config.speculative_config.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(2)
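As a side note for reviewers, the `speculative_token_tree` string these `setUp` blocks construct is easy to misread. A quick sketch of what the comprehension evaluates to, shown for the `range(2)` (two speculative tokens) case:

```python
# The tree is a linear chain: one tuple per speculative step, where the
# k-th entry is a path of k zeros, i.e. it always extends the single branch.
tree = [(i + 1) * (0, ) for i in range(2)]
print(tree)       # [(0,), (0, 0)]
print(str(tree))  # "[(0,), (0, 0)]" -- the string stored in speculative_token_tree
```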
tests/ut/spec_decode/test_mtp_proposer.py (3 changes: 0 additions, 3 deletions)
@@ -42,9 +42,6 @@ def vllm_config(self):
config.model_config.max_model_len = 2048
config.model_config.uses_mrope = False
config.model_config.hf_text_config = None
-        config.model_config.hf_config = None
-        config.parallel_config.tensor_parallel_size = 1
-        config.speculative_config.draft_tensor_parallel_size = 1

config.load_config = None

vllm_ascend/spec_decode/eagle_proposer.py (21 changes: 0 additions, 21 deletions)
@@ -130,27 +130,6 @@ def __init__(self,

self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
"index_topk")
-        # NOTE:
-        # `draft_tensor_parallel_size` does not take effect for Eagle:
-        # the draft model uses the same TP size as the target model in practice.
-        # so we applied this patch to set tp=1 of draft model separately.
-        # Due to verification of `_verify_and_get_draft_tp` in vllm,
-        # the value of `draft_tensor_parallel_size` here will either be 1 separately
-        # or the same as target model.
-        # TODO(zhaomingyu13): If we want to adapt to the case where draft model tp
-        # is not 1 and differs from target model, this part should be rewritten.
-        if (vllm_config.parallel_config.tensor_parallel_size
-                != self.speculative_config.draft_tensor_parallel_size):
-            tp_group = init_model_parallel_group(
-                [[get_world_group().rank]],
-                get_world_group().rank,
-                torch.distributed.get_backend(get_world_group().device_group),
-                use_message_queue_broadcaster=True,
-                group_name="tp",
-            )
-            self.tp_group_context = patch_tensor_parallel_group(tp_group)
-        else:
-            self.tp_group_context = nullcontext()

# TODO: Remove it when the bug of fx-graph is solved
self.maybe_eager_context: ContextManager[Any] = nullcontext()
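The NOTE deleted above refers to vllm's `_verify_and_get_draft_tp` check. The sketch below only restates the constraint that comment describes, as an illustration; it is not vllm's implementation, and the "unset defaults to target TP" behaviour is an assumption. The point is that the configured draft TP can only ever be 1 or equal to the target TP, which is why the deleted branch never had to handle anything other than a single-rank group.

```python
# Illustrative only: a standalone restatement of the rule the removed
# comment refers to, not vllm's actual _verify_and_get_draft_tp code.
def verify_draft_tp(target_tp: int, draft_tp: int | None) -> int:
    if draft_tp is None:
        # Assumption: an unset value falls back to the target model's TP size.
        return target_tp
    if draft_tp not in (1, target_tp):
        raise ValueError(
            f"draft_tensor_parallel_size={draft_tp} must be 1 or equal to "
            f"the target tensor_parallel_size ({target_tp})")
    return draft_tp


# With the branch above removed, vllm-ascend no longer special-cases the
# draft_tp == 1 path; the drafter always runs under the target TP group.
assert verify_draft_tp(4, None) == 4
assert verify_draft_tp(4, 1) == 1
```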
vllm_ascend/worker/model_runner_v1.py (19 changes: 5 additions, 14 deletions)
@@ -165,10 +165,6 @@ def graph_capture(device: torch.device):
yield graph_capture_context


-def get_tp_context(drafter):
-    return getattr(drafter, "tp_group_context", nullcontext())


class ExecuteModelState(NamedTuple):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
@@ -2326,8 +2322,7 @@ def load_model(self) -> None:
model_register(self.model, self.model_config)
if self.drafter:
logger.info("Loading drafter model...")
-            with get_tp_context(self.drafter):
-                self.drafter.load_model(self.model)
+            self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
@@ -2703,15 +2698,11 @@ def may_reinitialize_input_batch(self,
kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate(
kv_cache_config.kv_cache_groups):
-            kv_cache_spec = kv_cache_group.kv_cache_spec
-            if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
-                # All layers in the UniformTypeKVCacheSpecs have the same type,
-                # Pick an arbitrary one to dispatch.
-                kv_cache_spec = next(
-                    iter(kv_cache_spec.kv_cache_specs.values()))
-            if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
+
+            if isinstance(kv_cache_group.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
continue
-            elif isinstance(kv_cache_spec, AttentionSpec):
+            elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from
# the backend.