Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions vllm/v1/sample/rejection_sampler.py
Original file line number Diff line number Diff line change
@@ -136,16 +136,14 @@ def forward(
metadata.cu_num_draft_tokens,
sampling_metadata,
)
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

output_token_ids = rejection_sample(
metadata.draft_token_ids,
metadata.num_draft_tokens,
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
target_logits,
bonus_token_ids,
sampling_metadata,
)
@@ -353,25 +351,24 @@ def rejection_sample(
# [num_tokens, vocab_size]
draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2
assert target_logits.ndim == 2

batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
vocab_size = target_logits.shape[-1]
device = target_logits.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
assert target_logits.shape == (num_tokens, vocab_size)

# Create output buffer.
output_token_ids = torch.full(
@@ -387,7 +384,7 @@ def rejection_sample(
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
target_argmax = target_logits.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
Expand All @@ -400,6 +397,10 @@ def rejection_sample(
if sampling_metadata.all_greedy:
return output_token_ids

# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous()

# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(
Expand Down