45 changes: 45 additions & 0 deletions tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
@@ -147,6 +147,51 @@ def test_ngram_correctness(
assert matches > int(0.66 * len(ref_outputs))


def test_ngram_npu_async_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
"""
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram_npu speculative decoding + async.
"""

with VllmRunner(
model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)

with VllmRunner(
model_name,
speculative_config={
"method": "ngram_gpu",
"prompt_lookup_max": 2,
"prompt_lookup_min": 2,
"num_speculative_tokens": 3,
},
max_model_len=1024,
async_scheduling=True,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

    # Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))


def test_qwen3_vl_eagle_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
3 changes: 3 additions & 0 deletions vllm_ascend/spec_decode/__init__.py
@@ -21,12 +21,15 @@
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
from vllm_ascend.spec_decode.ngram_proposer_npu import AscendNgramProposerNPU
from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer


def get_spec_decode_method(method, vllm_config, device, runner):
if method == "ngram":
return AscendNgramProposer(vllm_config, runner)
elif method == "ngram_gpu":
return AscendNgramProposerNPU(vllm_config, device, runner)
elif method == "suffix":
return AscendSuffixDecodingProposer(vllm_config, runner)
elif method == "medusa":
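For orientation, a minimal self-contained sketch of the dispatch pattern above. Note that the method string stays "ngram_gpu" (upstream vLLM's name) even though it selects an NPU-backed proposer; the stand-in classes below are illustrative, not the real vllm_ascend proposers.

# Stand-in sketch of the method-string dispatch; not the real proposers.
class _Ngram: ...
class _NgramNPU: ...

def _get_spec_decode_method(method: str):
    if method == "ngram":
        return _Ngram()
    elif method == "ngram_gpu":  # the upstream name maps to the NPU proposer
        return _NgramNPU()
    raise ValueError(f"unknown spec-decode method: {method}")

assert isinstance(_get_spec_decode_method("ngram_gpu"), _NgramNPU)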
38 changes: 38 additions & 0 deletions vllm_ascend/spec_decode/ngram_proposer_npu.py
@@ -0,0 +1,38 @@
import torch
from vllm.v1.spec_decode.ngram_proposer_gpu import NgramProposerGPU


class AscendNgramProposerNPU(NgramProposerGPU):
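    """Ngram proposer that runs token matching on the NPU device.

    Thin wrapper over vLLM's NgramProposerGPU: there is no draft model to
    load, and graph-capture dummy runs are no-ops.
    """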
def __init__(self, vllm_config, device: torch.device, runner):
self.runner = runner
super().__init__(vllm_config, device=device)

def load_model(self, *args, **kwargs):
# No model to load.
pass

@torch.inference_mode()
def dummy_run(
self,
num_tokens,
with_prefill=None,
in_graph_capturing=None,
num_reqs=None,
num_tokens_across_dp=None,
aclgraph_runtime_mode=None,
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None,
is_profile=False,
):
pass

def propose(
self,
num_tokens_no_spec: torch.Tensor, # [batch_size]
token_ids_gpu: torch.Tensor, # [batch_size, max_len]
valid_sampled_token_ids_gpu: torch.Tensor, # [batch_size, num_spec_tokens + 1]
valid_sampled_tokens_count: torch.Tensor, # [batch_size]
) -> tuple[torch.Tensor, torch.Tensor]:
return super().propose(
num_tokens_no_spec, token_ids_gpu, valid_sampled_token_ids_gpu, valid_sampled_tokens_count
)
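For reference, a self-contained sketch of the tensor shapes propose() expects, per the annotations above. The values are dummies on CPU purely for illustration; in the runner these tensors live on the NPU device.

import torch

batch_size, max_len, num_spec_tokens = 2, 16, 3
num_tokens_no_spec = torch.tensor([5, 9])                         # [batch_size]
token_ids = torch.zeros(batch_size, max_len, dtype=torch.long)    # [batch_size, max_len]
sampled = torch.zeros(batch_size, num_spec_tokens + 1,
                      dtype=torch.long)                           # [batch_size, num_spec_tokens + 1]
sampled_count = torch.tensor([1, 4])                              # [batch_size]
# propose() returns (draft_token_ids, num_valid_draft_tokens) on the device.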
66 changes: 62 additions & 4 deletions vllm_ascend/worker/model_runner_v1.py
Expand Up @@ -22,7 +22,7 @@
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, replace
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias

@@ -75,6 +75,7 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer_gpu import copy_num_valid_draft_tokens
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker import mamba_utils
@@ -115,6 +116,7 @@
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
from vllm_ascend.spec_decode.ngram_proposer_npu import AscendNgramProposerNPU
from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer
from vllm_ascend.utils import (
calc_split_factor,
@@ -437,6 +439,7 @@ def _set_up_drafter(self):
# Set up speculative decoding.
self.drafter: (
AscendNgramProposer
| AscendNgramProposerNPU
| AscendEagleProposer
| AscendDraftModelProposer
| AscendSuffixDecodingProposer
@@ -1004,6 +1007,42 @@ def propose_draft_token_ids(
draft_token_ids = None
elif isinstance(self.drafter, (AscendNgramProposer, AscendSuffixDecodingProposer)):
draft_token_ids = self.drafter.propose(valid_sampled_token_ids)
elif isinstance(self.drafter, AscendNgramProposerNPU):
(
next_token_ids,
valid_sampled_tokens_count,
valid_sampled_token_ids_gpu,
) = self.drafter.update_token_ids_ngram(
valid_sampled_token_ids,
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
self.discard_request_mask.gpu,
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)

batch_size = next_token_ids.shape[0]

draft_token_ids, num_valid_draft_tokens = self.drafter.propose(
self.num_tokens_no_spec_gpu[:batch_size],
self.token_ids_gpu_tensor[:batch_size],
valid_sampled_token_ids_gpu,
valid_sampled_tokens_count,
)

# Cache valid draft counts for scheduler-side trimming.
self._num_valid_draft_tokens = num_valid_draft_tokens

# Async D2H copy on a dedicated stream.
copy_num_valid_draft_tokens(
self._num_valid_draft_tokens_cpu,
self._num_valid_draft_tokens_copy_stream,
self._num_valid_draft_tokens_event,
self._num_valid_draft_tokens,
self.input_batch.num_reqs,
)
elif isinstance(self.drafter, AscendMedusaProposer):
draft_token_ids = self.drafter.propose(
valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
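copy_num_valid_draft_tokens above offloads the device-to-host copy to a dedicated stream so the main stream is not stalled. A generic sketch of that pattern follows, shown with torch.cuda stream/event APIs for illustration (on Ascend the torch_npu equivalents apply); this is not the vLLM helper itself.

import torch

if torch.cuda.is_available():
    counts_dev = torch.randint(0, 4, (8,), device="cuda")
    counts_host = torch.empty(8, dtype=counts_dev.dtype, pin_memory=True)

    copy_stream = torch.cuda.Stream()
    done = torch.cuda.Event()

    # Order the side stream after work already queued on the main stream.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        counts_host.copy_(counts_dev, non_blocking=True)
        done.record(copy_stream)

    # ... the main stream keeps launching kernels here ...
    done.synchronize()  # block only when the CPU actually needs the values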
@@ -1147,6 +1186,24 @@ def execute_model(
logger.warning("RoutedExpertsCapturer is not initialized.")
if self.execute_model_state is not None:
raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.")

        # If ngram_gpu is used, copy the scheduler_output so that local
        # modifications do not affect the scheduler_output held by the
        # engine-core process. `replace` is much faster than a deepcopy.
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy()
spec_decode_tokens_copy = (
scheduler_output.scheduled_spec_decode_tokens.copy()
)
scheduler_output = replace(
scheduler_output,
num_scheduled_tokens=num_scheduled_tokens_copy,
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
)

        # self._draft_token_ids is None when `input_fits_in_drafter=False`
        # and no draft tokens are scheduled, so the spec-decoding info in
        # scheduler_output needs to be updated under async_scheduling.
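To see why the two dicts above are copied explicitly: dataclasses.replace rebuilds only the top-level object and reuses every field it is not handed, so it is cheap but shallow. A self-contained sketch with a stand-in dataclass (field names are illustrative, not the real SchedulerOutput):

from dataclasses import dataclass, field, replace

@dataclass
class _Output:
    num_scheduled_tokens: dict
    other: list = field(default_factory=list)

orig = _Output(num_scheduled_tokens={"req0": 4})
# Shallow new instance: explicitly copy only the dict we intend to mutate.
local = replace(orig, num_scheduled_tokens=orig.num_scheduled_tokens.copy())

local.num_scheduled_tokens["req0"] = 1
assert orig.num_scheduled_tokens["req0"] == 4  # original is untouched
assert local.other is orig.other               # untouched fields are shared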
@@ -1564,14 +1621,15 @@ def propose_draft_token_ids(sampled_token_ids):
if self.speculative_config:
use_padded_batch = (
self.speculative_config
and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model())
and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model()
or self.speculative_config.use_ngram_gpu())
and not self.speculative_config.disable_padded_drafter_batch
)
if use_padded_batch:
# EAGLE speculative decoding can use the GPU sampled tokens
# EAGLE/ngram_gpu speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
if self.speculative_config and not use_padded_batch:
elif self.speculative_config and not use_padded_batch:
# ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)