Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SpecDecode][Kernel] Use Flashinfer for Rejection Sampling in Speculative Decoding #7244

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix spec decode sampler tests
LiuXiaoxuanPKU committed Aug 6, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit d7b254bd2339c7e638fa929f2cf71e7f0f335b5b
14 changes: 7 additions & 7 deletions tests/samplers/test_rejection_sampler.py
Original file line number Diff line number Diff line change
@@ -25,15 +25,15 @@ def mock_causal_accepted_tensor(

accepted = (torch.arange(k).expand(batch_size, k) <=
last_accepted_indices.unsqueeze(-1).broadcast_to(
batch_size, k)).to(device="cuda")
batch_size, k))

# Sprinkle accepted values after the contiguous initial accepted values.
# This replicates the behavior of rejection sampling, which may "accept"
# a token that cannot be accepted because of causality.
sprinkle_candidates = (
torch.arange(k).expand(batch_size, k) >
last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
sprinkle = torch.rand(batch_size, k) > 0.5
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
return accepted

@@ -86,7 +86,7 @@ def test_correct_output_format(which_tokens_accepted: str,

rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens)
rejection_sampler.init_gpu_tensors(rank=0)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
recovered_token_ids,
@@ -138,7 +138,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)
rejection_sampler.init_gpu_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -167,7 +167,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
device: str):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)
rejection_sampler.init_gpu_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -211,7 +211,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
torch.set_default_device(device)

rejection_sampler = RejectionSampler(strict_mode=True)
rejection_sampler.init_gpu_tensors(rank=0)
rejection_sampler.init_gpu_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -339,7 +339,7 @@ def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
self.vocab_size = vocab_size
self.vocab_range = (0, vocab_size)

self.rejection_sampler.init_gpu_tensors(rank=0)
self.rejection_sampler.init_gpu_tensors(device=0)

# Keep test simple, use k=1
self.k = 1
18 changes: 9 additions & 9 deletions tests/samplers/test_typical_acceptance_sampler.py
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
"""
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -111,7 +111,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -171,7 +171,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
@@ -225,7 +225,7 @@ def test_temperature_zero_target_distribution(seed: int,

typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
@@ -285,7 +285,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
@@ -352,7 +352,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
@@ -414,7 +414,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
@@ -447,7 +447,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0,
posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
@@ -485,7 +485,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(
(batch_size, k), dtype=torch.long)
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/spec_decode_base_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Union

import torch
import torch.jit
@@ -36,9 +36,12 @@ def __init__(self,
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0

def init_gpu_tensors(self, rank: int) -> None:
def init_gpu_tensors(self, device: Union[int, str]) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
if isinstance(device, int):
device = f"cuda:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)