diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 57f170b59000..1f13de50bf3a 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -821,9 +821,7 @@ def sample( logits, input_batch, # Draft logits are needed for probabilistic rejection sampling. - self.req_states.draft_logits[input_batch.idx_mapping] - if self.req_states.draft_logits is not None - else None, + self.req_states.draft_logits, ) # Get the number of sampled and rejected tokens. diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py index c835d86b2cd6..9bcf629b8034 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py @@ -68,55 +68,158 @@ def strict_rejection_sample( @triton.jit -def _probabilistic_rejection_sample_kernel( +def _gather_draft_logits_and_target_argmax_kernel( + local_target_argmax_ptr, + local_target_argmax_stride, + local_target_max_ptr, + local_target_max_stride, + # [num_logits, V] + out_draft_logits_ptr, + out_draft_logits_stride, + # [num_logits, V] + target_logits_ptr, + target_logits_stride, + # [max_num_reqs, num_speculative_steps, V] + draft_logits_ptr, + draft_logits_stride_0, + draft_logits_stride_1, + # [num_logits] + expanded_idx_mapping_ptr, + # [num_logits] + expanded_local_pos_ptr, + # [max_num_reqs] + temp_ptr, + vocab_size, + num_speculative_steps, + BLOCK_SIZE: tl.constexpr, +): + logit_idx = tl.program_id(0) + req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx) + draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx) + + block_idx = tl.program_id(1) + block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block_offsets < vocab_size + temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) + + if temp == 0.0: + # Greedy sampling. Get the target logits argmax. + target_logits = tl.load( + target_logits_ptr + logit_idx * target_logits_stride + block_offsets, + mask=mask, + other=float("-inf"), + ).to(tl.float32) + value, idx = tl.max(target_logits, axis=0, return_indices=True) + token_id = block_idx * BLOCK_SIZE + idx + tl.store( + local_target_argmax_ptr + + logit_idx * local_target_argmax_stride + + block_idx, + token_id, + ) + tl.store( + local_target_max_ptr + logit_idx * local_target_max_stride + block_idx, + value, + ) + elif draft_step_idx < num_speculative_steps: + draft_logits = tl.load( + draft_logits_ptr + + req_state_idx * draft_logits_stride_0 + + draft_step_idx * draft_logits_stride_1 + + block_offsets, + mask=mask, + other=float("-inf"), + ).to(tl.float32) + tl.store( + out_draft_logits_ptr + logit_idx * out_draft_logits_stride + block_offsets, + draft_logits, + mask=mask, + ) + + +@triton.jit +def _probabilistic_rejection_kernel( # [num_reqs, num_speculative_steps + 1] sampled_ptr, sampled_stride, # [num_reqs] rejected_steps_ptr, + # [num_reqs] + rejected_pos_ptr, # [num_logits] draft_sampled_ptr, # [num_logits, V] target_probs_ptr, target_probs_stride, - # [num_reqs, num_speculative_steps, V] + # [num_logits, V] draft_probs_ptr, - draft_probs_stride_0, - draft_probs_stride_1, + draft_probs_stride, + # [num_logits, num_blocks] + local_target_argmax_ptr, + local_target_argmax_stride, + # [num_logits, num_blocks] + local_target_max_ptr, + local_target_max_stride, # [num_reqs + 1] cu_num_logits_ptr, # [num_logits] pos_ptr, # [num_reqs] idx_mapping_ptr, - # [num_reqs] + # [max_num_reqs] + temp_ptr, + # [max_num_reqs] seeds_ptr, + NUM_BLOCKS: tl.constexpr, + PADDED_NUM_BLOCKS: tl.constexpr, ): req_idx = tl.program_id(0) start_idx = tl.load(cu_num_logits_ptr + req_idx) num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx - seed = tl.load(seeds_ptr + tl.load(idx_mapping_ptr + req_idx)) + req_state_idx = tl.load(idx_mapping_ptr + req_idx) + seed = tl.load(seeds_ptr + req_state_idx) + temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) rejected_step = 0 accepted = True for i in range(num_tokens - 1): if accepted: - draft_sampled = tl.load(draft_sampled_ptr + start_idx + i + 1) - target_prob = tl.load( - target_probs_ptr + (start_idx + i) * target_probs_stride + draft_sampled - ) - draft_prob = tl.load( - draft_probs_ptr - + req_idx * draft_probs_stride_0 - + i * draft_probs_stride_1 - + draft_sampled - ) - pos = tl.load(pos_ptr + start_idx + i) - u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1))) - accepted &= target_prob > u * draft_prob + logit_idx = start_idx + i + draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1) + if temp == 0.0: + # Greedy sampling. Only accept the sampled draft token if + # it exactly matches the target argmax. + block_offsets = tl.arange(0, PADDED_NUM_BLOCKS) + block_mask = block_offsets < NUM_BLOCKS + local_max = tl.load( + local_target_max_ptr + + logit_idx * local_target_max_stride + + block_offsets, + mask=block_mask, + other=float("-inf"), + ) + max_block = tl.argmax(local_max, axis=0) + target_argmax = tl.load( + local_target_argmax_ptr + + logit_idx * local_target_argmax_stride + + max_block + ) + accepted &= target_argmax == draft_sampled + else: + target_prob = tl.load( + target_probs_ptr + logit_idx * target_probs_stride + draft_sampled + ) + draft_prob = tl.load( + draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled + ) + pos = tl.load(pos_ptr + logit_idx) + u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1))) + accepted &= target_prob > u * draft_prob tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled) rejected_step += accepted tl.store(rejected_steps_ptr + req_idx, rejected_step) + pos_val = tl.load(pos_ptr + start_idx + rejected_step) + tl.store(rejected_pos_ptr + req_idx, pos_val) @triton.jit @@ -124,63 +227,60 @@ def _compute_residual_logits_kernel( # [num_reqs, V] residual_logits_ptr, residual_logits_stride, - # [num_reqs] - residual_pos_ptr, - # [num_logits, V] - target_logits_ptr, - target_logits_stride, # [num_logits, V] target_probs_ptr, target_probs_stride, - # [num_reqs, num_speculative_steps, V] + # [num_logits, V] draft_probs_ptr, - draft_probs_stride_0, - draft_probs_stride_1, + draft_probs_stride, + # [num_logits, V] + target_logits_ptr, + target_logits_stride, # [num_reqs] rejected_step_ptr, # [num_reqs + 1] cu_num_logits_ptr, - # [num_logits] - pos_ptr, + # [num_reqs] + idx_mapping_ptr, + # [max_num_reqs] + temp_ptr, vocab_size, BLOCK_SIZE: tl.constexpr, ): req_idx = tl.program_id(0) block_idx = tl.program_id(1) + req_state_idx = tl.load(idx_mapping_ptr + req_idx) start_idx = tl.load(cu_num_logits_ptr + req_idx) end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) - rejected_draft_step = tl.load(rejected_step_ptr + req_idx) - rejected_logit_idx = start_idx + rejected_draft_step - + rejected_logit_idx = start_idx + tl.load(rejected_step_ptr + req_idx) + temp = tl.load(temp_ptr + req_state_idx).to(tl.float32) block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = block_offsets < vocab_size - if rejected_logit_idx < end_idx - 1: + if temp == 0.0 or (rejected_logit_idx == end_idx - 1): + # Greedy sampling / bonus token. In either case, use the + # target logits directly to reduce numerical error. + residual_logits = tl.load( + target_logits_ptr + + rejected_logit_idx * target_logits_stride + + block_offsets, + mask=mask, + other=float("-inf"), + ) + else: target_probs = tl.load( target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets, mask=mask, other=0.0, ) draft_probs = tl.load( - draft_probs_ptr - + req_idx * draft_probs_stride_0 - + rejected_draft_step * draft_probs_stride_1 - + block_offsets, + draft_probs_ptr + rejected_logit_idx * draft_probs_stride + block_offsets, mask=mask, other=0.0, ) residual_probs = tl.maximum(target_probs - draft_probs, 0.0) residual_logits = tl.log(residual_probs) - else: - # This is a bonus token. Directly return the target logits. - residual_logits = tl.load( - target_logits_ptr - + rejected_logit_idx * target_logits_stride - + block_offsets, - mask=mask, - other=0.0, - ) tl.store( residual_logits_ptr + req_idx * residual_logits_stride + block_offsets, @@ -188,18 +288,13 @@ def _compute_residual_logits_kernel( mask=mask, ) - # First block computes the residual logit positions. - if block_idx == 0: - pos_val = tl.load(pos_ptr + rejected_logit_idx) - tl.store(residual_pos_ptr + req_idx, pos_val) - def probabilistic_rejection_sample( - # [num_draft_tokens + num_reqs, V] + # [num_logits, V] target_logits: torch.Tensor, - # [num_reqs, num_speculative_steps, V] + # [max_num_reqs, num_speculative_steps, V] draft_logits: torch.Tensor, - # [num_draft_tokens + num_reqs] + # [num_logits] draft_sampled: torch.Tensor, # [num_reqs + 1] cu_num_logits: torch.Tensor, @@ -207,16 +302,53 @@ def probabilistic_rejection_sample( pos: torch.Tensor, # [num_reqs] idx_mapping: torch.Tensor, + # [num_logits] + expanded_idx_mapping: torch.Tensor, + # [num_logits] + expanded_local_pos: torch.Tensor, + # [max_num_reqs] temperature: torch.Tensor, + # [max_num_reqs] seed: torch.Tensor, num_speculative_steps: int, ) -> tuple[torch.Tensor, torch.Tensor]: num_reqs = cu_num_logits.shape[0] - 1 - vocab_size = target_logits.shape[-1] + num_logits, vocab_size = target_logits.shape + + BLOCK_SIZE = 1024 + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + + # Gather draft logits and target argmax for greedy sampling. + gathered_draft_logits = target_logits.new_empty(target_logits.shape) + local_target_argmax = target_logits.new_empty( + num_logits, num_blocks, dtype=torch.int64 + ) + local_target_max = target_logits.new_empty( + num_logits, num_blocks, dtype=torch.float32 + ) + _gather_draft_logits_and_target_argmax_kernel[(num_logits, num_blocks)]( + local_target_argmax, + local_target_argmax.stride(0), + local_target_max, + local_target_max.stride(0), + gathered_draft_logits, + gathered_draft_logits.stride(0), + target_logits, + target_logits.stride(0), + draft_logits, + draft_logits.stride(0), + draft_logits.stride(1), + expanded_idx_mapping, + expanded_local_pos, + temperature, + vocab_size, + num_speculative_steps, + BLOCK_SIZE=BLOCK_SIZE, + ) # Compute target and draft probs. target_probs = torch.softmax(target_logits, dim=-1) - draft_probs = torch.softmax(draft_logits, dim=-1) + draft_probs = torch.softmax(gathered_draft_logits, dim=-1) # Rejection sample. # [num_reqs, num_speculative_steps + 1] @@ -225,45 +357,49 @@ def probabilistic_rejection_sample( ) # [num_reqs] rejected_steps = sampled.new_empty(num_reqs) - _probabilistic_rejection_sample_kernel[(num_reqs,)]( + # [num_reqs] + rejected_pos = pos.new_empty(num_reqs) + _probabilistic_rejection_kernel[(num_reqs,)]( sampled, sampled.stride(0), rejected_steps, + rejected_pos, draft_sampled, target_probs, target_probs.stride(0), draft_probs, draft_probs.stride(0), - draft_probs.stride(1), + local_target_argmax, + local_target_argmax.stride(0), + local_target_max, + local_target_max.stride(0), cu_num_logits, pos, idx_mapping, + temperature, seed, num_warps=1, + NUM_BLOCKS=num_blocks, + PADDED_NUM_BLOCKS=triton.next_power_of_2(num_blocks), ) # Compute the logits and positions to resample the rejected/bonus # tokens from. # [num_reqs, vocab_size] residual_logits = target_logits.new_empty(num_reqs, vocab_size) - # [num_reqs] - residual_pos = pos.new_empty(num_reqs) - BLOCK_SIZE = 1024 - num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) _compute_residual_logits_kernel[(num_reqs, num_blocks)]( residual_logits, residual_logits.stride(0), - residual_pos, - target_logits, - target_logits.stride(0), target_probs, target_probs.stride(0), draft_probs, draft_probs.stride(0), - draft_probs.stride(1), + target_logits, + target_logits.stride(0), rejected_steps, cu_num_logits, - pos, + idx_mapping, + temperature, vocab_size, BLOCK_SIZE=BLOCK_SIZE, ) @@ -274,7 +410,7 @@ def probabilistic_rejection_sample( idx_mapping, temperature, seed, - residual_pos, + rejected_pos, apply_temperature=False, ) sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1)) @@ -333,6 +469,8 @@ def __call__( input_batch.cu_num_logits, pos, input_batch.idx_mapping, + input_batch.expanded_idx_mapping, + input_batch.expanded_local_pos, self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, self.num_speculative_steps,