Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import math
import os
import random
from typing import Any
Expand Down Expand Up @@ -239,3 +240,56 @@ def test_suffix_acceptance(

# Heuristic: expect at least 80% acceptance rate at the end.
assert last_accept_rate > 0.60


@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
def test_eagle_logprobs(
model_name: str,
use_eagle3: bool,
):
prompt = {"role": "user", "content": "Hello world " * 10}
sampling_params = SamplingParams(temperature=0,
logprobs=1,
max_tokens=10,
ignore_eos=False)

ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False)
ref_outputs = ref_llm.chat([prompt], sampling_params)
ref_logprobs = []
for output in ref_outputs[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm

spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(
model_name,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
},
max_model_len=128,
enforce_eager=False,
) as runner:
spec_outputs = runner.model.chat([prompt], sampling_params)

# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_outputs[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])

for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob,
spec_logprob.logprob,
rel_tol=5e-2,
abs_tol=1e-1)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
20 changes: 18 additions & 2 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@
# Future Plan:
# Remove this patch when the bug is fixed.
#
# ** File: worker/patch_qwen3_next_mtp.py**
# ** 11. File: worker/patch_qwen3_next_mtp.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.worker.utils.bind_kv_cache`
# Why:
Expand All @@ -241,7 +241,7 @@
# Future Plan:
# Remove this patch after discussing with vllm community and adapting bind_kv_cache to npu.
#
# ** File: worker/patch_module.py**
# ** 12. File: worker/patch_module.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.attention.backends.gdn_attn.torch.argsort`
# Why:
Expand All @@ -257,3 +257,19 @@
# Remove this patch when bool is supported in 'torch.argsort' func of npu.
# Make 'torch.argsort' in `vllm.v1.attention.backends.gdn_attn` be stable.
#
# ** 13. File: worker/patch_rejection_sampler.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.sample.rejection_sampler`
# Why:
# - some functions from `rejection_sampler` are not supported or slow on npu.
# How:
# - add npu_top_k_top_p to 'apply_sampling_constraints' func
# - add custom triton kernel to `expand_batch_to_tokens` and `rejection_sample`
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/874
# https://github.com/vllm-project/vllm/pull/4849
# Future Plan:
# 1. make these functions as class func of RejectionSampler, create AscendRejectionSampler
# to override them, then delete the patch file `worker/patch_rejection_sampler.py`.
# 2. make these functions as costom op, then remove AscendRejectionSampler
#
1 change: 1 addition & 0 deletions vllm_ascend/patch/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@
import vllm_ascend.patch.worker.patch_qwen3_vl # noqa
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import vllm_ascend.patch.worker.patch_rope # noqa
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
11 changes: 11 additions & 0 deletions vllm_ascend/patch/worker/patch_rejection_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import vllm.v1.sample.rejection_sampler as rs

from vllm_ascend.sample.rejection_sampler import (apply_sampling_constraints,
expand_batch_to_tokens,
rejection_sample)

# TODO: delete this patch after apply_sampling_constraints and rejection_sample
# are extracted to as class func of RejectionSampler
rs.apply_sampling_constraints = apply_sampling_constraints
rs.rejection_sample = rejection_sample
rs.expand_batch_to_tokens = expand_batch_to_tokens
95 changes: 1 addition & 94 deletions vllm_ascend/sample/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
from typing import Optional

import torch
import torch.nn as nn
import torch_npu
import vllm.v1.sample.rejection_sampler as rs
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
generate_uniform_probs)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.sample.rejection_sampler import generate_uniform_probs

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

Expand All @@ -21,92 +17,6 @@
MAX_SPEC_LEN = 32


class AscendRejectionSampler(RejectionSampler, nn.Module):
"""
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
probabilities.
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
end of the sequence. The bonus token is only sampled from the target
probabilities. We pass in the bonus tokens instead of sampling them
in the rejection sampler to allow for more flexibility in the
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
strategies.
output tokens:
Tokens are finally generated with the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
"""

def forward(
self,
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
Args:
metadata:
Metadata for spec decoding.
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_logits = apply_sampling_constraints(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
)
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
bonus_token_ids,
sampling_metadata,
)
return output_token_ids


def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size]
Expand Down Expand Up @@ -844,6 +754,3 @@ def sample_recovered_tokens_kernel(
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)


rs.expand_batch_to_tokens = expand_batch_to_tokens
Loading
Loading