diff --git a/tests/v1/spec_decode/test_rejection_sample.py b/tests/v1/spec_decode/test_rejection_sample.py new file mode 100644 index 000000000000..f1cc402c30fa --- /dev/null +++ b/tests/v1/spec_decode/test_rejection_sample.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from scipy import stats + +from vllm.platforms import current_platform +from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( + sample_recovered_and_bonus_tokens, +) + +device = current_platform.device_type + + +@pytest.mark.skipif(device != "cuda", reason="Requires CUDA") +def test_sample_recovered_and_bonus_tokens_correctness(): + """Verify that sample_recovered_and_bonus_tokens produces samples whose + empirical distribution matches the theoretical distribution. + + For each non-bonus (recovered) token position the correct distribution is + max(0, target - draft) renormalized; for the bonus token position it is + target directly. We draw N samples with independent seeds and run a + chi-squared goodness-of-fit test against those expected distributions. 
+ """ + # 3 requests with 2 / 1 / 2 draft tokens each, giving 8 total token + # req 0 (2 draft + 1 bonus): tok 0, 1, 2 + # req 1 (1 draft + 1 bonus): tok 3, 4 + # req 2 (2 draft + 1 bonus): tok 5, 6, 7 + num_reqs = 3 + num_tokens = 8 + vocab_size = 3 + + target_probs = torch.tensor( + [ + [0.6, 0.3, 0.1], # req 0 recovered + [0.3, 0.4, 0.3], # req 0 recovered + [0.1, 0.7, 0.2], # req 0 bonus + [0.5, 0.3, 0.2], # req 1 recovered + [0.2, 0.5, 0.3], # req 1 bonus + [0.4, 0.4, 0.2], # req 2 recovered + [0.2, 0.6, 0.2], # req 2 recovered + [0.3, 0.3, 0.4], # req 2 bonus + ], + dtype=torch.float32, + device=device, + ) + draft_probs = torch.tensor( + [ + [0.1, 0.2, 0.7], # req 0 draft 0 + [0.1, 0.2, 0.7], # req 0 draft 1 + [0.3, 0.1, 0.6], # req 1 draft 0 + [0.1, 0.3, 0.6], # req 2 draft 0 + [0.1, 0.2, 0.7], # req 2 draft 1 + ], + dtype=torch.float32, + device=device, + ) + + # Expected distributions (recovered = max(0, target-draft) normalised): + # tok 0: max(0, [0.5, 0.1, -0.6]) → [5/6, 1/6, 0] + # tok 1: max(0, [0.2, 0.2, -0.4]) → [1/2, 1/2, 0] + # tok 2: bonus → [0.1, 0.7, 0.2] + # tok 3: max(0, [0.2, 0.2, -0.4]) → [1/2, 1/2, 0] + # tok 4: bonus → [0.2, 0.5, 0.3] + # tok 5: max(0, [0.3, 0.1, -0.4]) → [3/4, 1/4, 0] + # tok 6: max(0, [0.1, 0.4, -0.5]) → [1/5, 4/5, 0] + # tok 7: bonus → [0.3, 0.3, 0.4] + expected_dists = torch.tensor( + [ + [5 / 6, 1 / 6, 0.0], # req 0 recovered + [0.5, 0.5, 0.0], # req 0 recovered + [0.1, 0.7, 0.2], # req 0 bonus + [0.5, 0.5, 0.0], # req 1 recovered + [0.2, 0.5, 0.3], # req 1 bonus + [3 / 4, 1 / 4, 0.0], # req 2 recovered + [1 / 5, 4 / 5, 0.0], # req 2 recovered + [0.3, 0.3, 0.4], # req 2 bonus + ], + dtype=torch.float32, + device=device, + ) + + cu_num_logits = torch.tensor([0, 3, 5, 8], dtype=torch.int32, device=device) + idx_mapping = torch.tensor( + [0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32, device=device + ) + # Exact values don't matter, only used for seeding + pos = torch.arange(num_tokens, dtype=torch.int64, device=device) + 
+ # Draw N samples with varying seeds to build empirical distributions + N = 30_000 + counts = torch.zeros(num_tokens, vocab_size, dtype=torch.int64, device=device) + for trial in range(N): + # Give each request a distinct seed that changes every trial. + seed = torch.arange( + trial * num_reqs + 1, + (trial + 1) * num_reqs + 1, + dtype=torch.int64, + device=device, + ) + samples = sample_recovered_and_bonus_tokens( + target_probs, draft_probs, cu_num_logits, idx_mapping, + torch.ones(num_reqs, device=device), seed, pos) # nonzero temperature + counts.scatter_add_( + 1, + samples.unsqueeze(1), + torch.ones(num_tokens, 1, dtype=torch.int64, device=device), + ) + + # Chi-squared test to compare empirical and expected distributions + # Chi-squared test cannot handle zero frequencies, + # so we check those explicitly first + expected_freq = expected_dists * N # [num_tokens, vocab_size] + nonzero = expected_freq > 0 # [num_tokens, vocab_size] + assert torch.where(~nonzero, counts, 0).sum() == 0, "Sampled a zero-prob token" + + alpha = 1e-4 + for tok_pos in range(num_tokens): + obs = counts[tok_pos, nonzero[tok_pos]].cpu().numpy().astype(float) + exp = expected_freq[tok_pos, nonzero[tok_pos]].cpu().numpy().astype(float) + _, p_value = stats.chisquare(obs, f_exp=exp, sum_check=False) + assert p_value > alpha, ( + f"Token position {tok_pos}: empirical distribution significantly " + f"differs from expected ({p_value=:.6f})" + ) diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index a15da926da4e..055c108accbd 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -81,6 +81,9 @@ class InputBatch: cu_num_logits: torch.Tensor cu_num_logits_np: np.ndarray + # [num_draft_tokens] + draft_logits: torch.Tensor + # Whether any requests in batch use structured output. 
has_structured_output_reqs: bool @@ -89,6 +92,7 @@ def make_dummy( cls, num_reqs: int, num_tokens: int, + vocab_size: int, input_buffers: InputBuffers, device: torch.device, ) -> "InputBatch": @@ -150,6 +154,7 @@ def make_dummy( logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, + draft_logits=torch.empty(0, vocab_size, device=device, dtype=torch.float32), has_structured_output_reqs=False, ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index ccab6cec8c78..e6f723678e8a 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -69,7 +69,6 @@ from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import ( set_eagle3_aux_hidden_state_layers, ) -from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker @@ -389,7 +388,7 @@ def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: # NOTE(woosuk): During the initial memory profiling, the sampler may skip # top_k, top_p, and logprobs, using less GPU memory than what is possible # during actual execution. 
- self.sampler( + self.sampler.sample( logits, idx_mapping, idx_mapping_np, @@ -601,11 +600,23 @@ def prepare_inputs( expanded_local_pos = torch.zeros( num_reqs, dtype=torch.int32, device=self.device ) + draft_logits = torch.empty( + 0, self.vocab_size, device=self.device, dtype=torch.float32 + ) else: num_draft_tokens = np.array( [len(draft_tokens.get(req_id, ())) for req_id in req_ids], dtype=np.int32, ) + draft_logits = torch.concat( + [ + self.req_states.draft_logits[ + self.req_states.req_id_to_index[req_id], :num_draft_token + ] + for req_id, num_draft_token in zip(req_ids, num_draft_tokens) + ], + dim=0, + ) total_num_draft_tokens = int(num_draft_tokens.sum()) total_num_logits = num_reqs + total_num_draft_tokens @@ -750,6 +761,7 @@ def prepare_inputs( logits_indices=logits_indices, cu_num_logits=cu_num_logits, cu_num_logits_np=cu_num_logits_np, + draft_logits=draft_logits, has_structured_output_reqs=scheduler_output.has_structured_output_requests, ) @@ -792,36 +804,37 @@ def sample( grammar_output.grammar_bitmask, ) - # Sample tokens and compute logprobs (if needed). - sampler_output = self.sampler( - logits, - input_batch.expanded_idx_mapping, - input_batch.idx_mapping_np, - input_batch.cu_num_logits_np, - sample_pos, - input_ids, - input_batch.expanded_local_pos, - ) - if input_batch.num_draft_tokens == 0: - # No draft tokens (common case). - num_sampled = torch.ones( - input_batch.num_reqs, dtype=torch.int32, device=self.device + # Sample tokens and compute logprobs (if needed). + sampler_output = self.sampler.sample( + logits, + input_batch.expanded_idx_mapping, + input_batch.idx_mapping_np, + input_batch.cu_num_logits_np, + sample_pos, + input_ids, + input_batch.expanded_local_pos, ) else: # Rejection sampling for spec decoding. 
- sampled_tokens, num_sampled = rejection_sample( - sampler_output.sampled_token_ids, - input_ids, + sampler_output = self.sampler.rejection_sample( + logits, + input_batch.draft_logits, input_batch.cu_num_logits, + input_batch.cu_num_logits_np, + input_batch.idx_mapping, + input_batch.expanded_idx_mapping, + input_batch.idx_mapping_np, + sample_pos, + input_ids, + input_batch.expanded_local_pos, self.num_speculative_steps, ) - sampler_output.sampled_token_ids = sampled_tokens # Get the number of sampled and rejected tokens. # For chunked prefills, num_sampled and num_rejected are both 0. num_sampled, num_rejected = get_num_sampled_and_rejected( - num_sampled, + sampler_output.num_sampled, input_batch.seq_lens, input_batch.cu_num_logits, input_batch.idx_mapping, @@ -935,6 +948,7 @@ def execute_model( input_batch = InputBatch.make_dummy( num_reqs=num_reqs, num_tokens=num_tokens_after_padding, + vocab_size=self.vocab_size, input_buffers=self.input_buffers, device=self.device, ) @@ -1090,7 +1104,7 @@ def sample_tokens( input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected ) if self.speculator is not None: - draft_tokens = self.speculator.propose( + speculation = self.speculator.propose( input_batch, hidden_states, aux_hidden_states, @@ -1101,8 +1115,15 @@ def sample_tokens( self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.seeds.gpu, ) - self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens - self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens) + self.req_states.draft_tokens[input_batch.idx_mapping] = ( + speculation.draft_tokens + ) + self.req_states.draft_logits[input_batch.idx_mapping] = ( + speculation.draft_logits + ) + self.draft_tokens_handler.set_draft_tokens( + input_batch, speculation.draft_tokens + ) if self.use_async_scheduling: return async_output diff --git a/vllm/v1/worker/gpu/sample/output.py b/vllm/v1/worker/gpu/sample/output.py index 13e8cf1d6c1e..69b26ad9f388 100644 --- 
a/vllm/v1/worker/gpu/sample/output.py +++ b/vllm/v1/worker/gpu/sample/output.py @@ -12,3 +12,4 @@ class SamplerOutput: sampled_token_ids: torch.Tensor logprobs_tensors: LogprobsTensors | None num_nans: torch.Tensor | None + num_sampled: torch.Tensor diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 87b10bcc11a1..e79757242d6e 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -15,6 +15,12 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.penalties import PenaltiesState from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates +from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( + rejection_sample as rejection_sample_functional, +) +from vllm.v1.worker.gpu.spec_decode.rejection_sample import ( + sample_recovered_and_bonus_tokens, +) from vllm.v1.worker.gpu.states import RequestState @@ -53,7 +59,7 @@ def apply_staged_writes(self) -> None: self.logit_bias_state.apply_staged_writes() self.bad_words_state.apply_staged_writes() - def __call__( + def sample( self, logits: torch.Tensor, idx_mapping: torch.Tensor, @@ -66,7 +72,7 @@ def __call__( # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear # that num_nans is computed before applying penalties and temperature. num_nans = get_num_nans(logits) if self.compute_nans else None - sampled, processed_logits = self.sample( + processed_logits = self._process_logits( logits, idx_mapping, idx_mapping_np, @@ -74,6 +80,14 @@ def __call__( input_ids, expanded_local_pos, ) + sampled = gumbel_sample( + processed_logits, + idx_mapping, + self.sampling_states.temperature.gpu, + self.sampling_states.seeds.gpu, + pos, + apply_temperature=False, + ) max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) if max_num_logprobs != NO_LOGPROBS: @@ -87,6 +101,10 @@ def __call__( else: logprobs_tensors = None + # No draft tokens (common case). 
+ num_reqs = len(logits) + num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=logits.device) + # These are GPU tensors. sampler_output = SamplerOutput( # The sampled tokens are expanded to 2D tensor with shape @@ -95,10 +113,75 @@ def __call__( sampled_token_ids=sampled.view(-1, 1), logprobs_tensors=logprobs_tensors, num_nans=num_nans, + num_sampled=num_sampled, ) return sampler_output - def sample( + def rejection_sample( + self, + logits: torch.Tensor, # [num_draft_tokens + num_reqs, vocab_size] + draft_logits: torch.Tensor, # [num_draft_tokens, vocab_size] + cu_num_logits: torch.Tensor, # [num_reqs + 1] + cu_num_logits_np: np.ndarray, # [num_reqs + 1] + idx_mapping: torch.Tensor, # [max_num_reqs] + expanded_idx_mapping: torch.Tensor, # [num_draft_tokens + num_reqs] + idx_mapping_np: np.ndarray, # [num_draft_tokens + num_reqs] + pos: torch.Tensor, # [num_draft_tokens + num_reqs] + input_ids: torch.Tensor, # [num_draft_tokens + num_reqs] + expanded_local_pos: torch.Tensor, # [num_draft_tokens + num_reqs] + num_speculative_steps: int, + ) -> SamplerOutput: + # TODO: Check whether functions expect expanded idx_mapping or not + processed_logits = self._process_logits( + logits, + expanded_idx_mapping, + idx_mapping_np, + pos, + input_ids, + expanded_local_pos, + ) + processed_probs = torch.softmax(processed_logits, dim=-1) + draft_probs = torch.softmax(draft_logits, dim=-1) + recovered_ids = sample_recovered_and_bonus_tokens( + processed_probs, + draft_probs, + cu_num_logits, + expanded_idx_mapping, + self.sampling_states.temperature.gpu, + self.sampling_states.seeds.gpu, + pos, + ) + sampled, num_sampled = rejection_sample_functional( + input_ids, + recovered_ids, + processed_probs, + draft_probs, + cu_num_logits, + self.sampling_states.seeds.gpu, + pos, + self.sampling_states.temperature.gpu, + num_speculative_steps, + ) + max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) + if max_num_logprobs != NO_LOGPROBS: + expanded_logits = 
logits.shape[0] != idx_mapping_np.shape[0] + cu_num_logits_list = cu_num_logits_np.tolist() if expanded_logits else None + logprobs_tensors = compute_topk_logprobs( + processed_logits, max_num_logprobs, sampled, cu_num_logits_list + ) + else: + logprobs_tensors = None + + # These are GPU tensors. + sampler_output = SamplerOutput( + sampled_token_ids=sampled, + logprobs_tensors=logprobs_tensors, + num_nans=get_num_nans(logits) if self.compute_nans else None, + num_sampled=num_sampled, + ) + return sampler_output + + def _process_logits( self, logits: torch.Tensor, idx_mapping: torch.Tensor, @@ -106,7 +189,7 @@ def sample( pos: torch.Tensor, input_ids: torch.Tensor, expanded_local_pos: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: # Copy logits to a new FP32 tensor. logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) @@ -143,13 +226,4 @@ def sample( logits, idx_mapping, idx_mapping_np ) - # Sample the next token. - sampled = gumbel_sample( - logits, - idx_mapping, - self.sampling_states.temperature.gpu, - self.sampling_states.seeds.gpu, - pos, - apply_temperature=False, - ) - return sampled, logits + return logits diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 6cd13cebf995..27da40174871 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -17,9 +17,10 @@ ) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers -from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample +from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model +from vllm.v1.worker.gpu.spec_decode.outputs import Speculation from vllm.v1.worker.utils import AttentionGroup logger = init_logger(__name__) @@ -69,6 +70,13 @@ 
def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.int64, device=device, ) + self.draft_logits = torch.zeros( + self.max_num_reqs, + self.num_speculative_steps, + self.vocab_size, + dtype=torch.float32, + device=device, + ) self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device) @@ -139,19 +147,21 @@ def generate_draft( ) last_hidden_states = last_hidden_states[:num_reqs] hidden_states = hidden_states[:num_reqs] - logits = self.model.compute_logits(last_hidden_states) + draft_logits = self.model.compute_logits(last_hidden_states) + apply_temperature(draft_logits, idx_mapping, self.temperature) # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # used for draft and target sampling. draft_tokens = gumbel_sample( - logits, + draft_logits, idx_mapping, self.temperature, self.seeds, pos + 1, - apply_temperature=True, + apply_temperature=False, ) self.draft_tokens[:num_reqs, step] = draft_tokens + self.draft_logits[:num_reqs, step] = draft_logits if step < self.num_speculative_steps - 1: # Update the inputs for the next step. @@ -198,7 +208,7 @@ def propose( temperature: torch.Tensor, # [max_num_reqs] seeds: torch.Tensor, - ) -> torch.Tensor: + ) -> Speculation: # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # number of rejected tokens, we maintain the size of eagle's input_ids and # hidden_states the same as the target model's. 
This means, we pad each @@ -234,7 +244,7 @@ def propose( num_tokens_across_dp=None, # FIXME ) sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states) + draft_logits = self.model.compute_logits(sample_hidden_states) num_reqs = input_batch.num_reqs # NOTE(woosuk): For draft sampling, we only consider the temperature @@ -251,20 +261,26 @@ def propose( torch.gather(input_batch.positions, 0, last_token_indices, out=pos) # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # used for draft and target sampling. + apply_temperature(draft_logits, idx_mapping, self.temperature) draft_tokens = gumbel_sample( - logits, + draft_logits, idx_mapping, self.temperature, self.seeds, pos + 1, - apply_temperature=True, + apply_temperature=False, ) if self.num_speculative_steps == 1: # Early exit. - return draft_tokens.view(-1, 1) + return Speculation( + draft_tokens=draft_tokens.view(-1, 1), + draft_logits=draft_logits.view(-1, 1, self.vocab_size), + ) - # Save the draft tokens for the first step. + # Save the draft tokens/probs for the first step. self.draft_tokens[:num_reqs, 0] = draft_tokens + self.draft_logits[:num_reqs, 0] = draft_logits + # Prepare the inputs for the decode steps. prepare_eagle_decode( draft_tokens, @@ -287,7 +303,10 @@ def propose( if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL: # Run full CUDA graph. self.cudagraph_manager.run_fullgraph(cudagraph_size) - return self.draft_tokens[:num_reqs] + return Speculation( + draft_tokens=self.draft_tokens[:num_reqs], + draft_logits=self.draft_logits[:num_reqs], + ) # Run eager or piecewise CUDA graph. 
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs @@ -321,7 +340,10 @@ def propose( num_tokens_across_dp=None, # FIXME cudagraph_runtime_mode=cudagraph_mode, ) - return self.draft_tokens[:num_reqs] + return Speculation( + draft_tokens=self.draft_tokens[:num_reqs], + draft_logits=self.draft_logits[:num_reqs], + ) @triton.jit diff --git a/vllm/v1/worker/gpu/spec_decode/outputs.py b/vllm/v1/worker/gpu/spec_decode/outputs.py new file mode 100644 index 000000000000..07159b59e2b5 --- /dev/null +++ b/vllm/v1/worker/gpu/spec_decode/outputs.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + +import torch + + +@dataclass +class Speculation: + # [num_reqs, num_speculative_steps] + draft_tokens: torch.Tensor + # [num_reqs, num_speculative_steps, vocab_size] + draft_logits: torch.Tensor diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py index 8a7bf28bacbd..bfde4797c8c3 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sample.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sample.py @@ -10,62 +10,222 @@ def _rejection_sample_kernel( sampled_ptr, # [num_reqs, num_speculative_steps + 1] sampled_stride, num_sampled_ptr, # [num_reqs] - target_sampled_ptr, # [num_draft_tokens + num_reqs] input_ids_ptr, # [num_draft_tokens + num_reqs] cu_num_logits_ptr, # [num_reqs + 1] + target_probs_ptr, # [num_draft_tokens + num_reqs, vocab_size] + target_probs_stride, + draft_probs_ptr, # [num_draft_tokens, vocab_size] + draft_probs_stride, + recovered_ids_ptr, # [num_draft_tokens + num_reqs] + seeds_ptr, # [num_reqs] + pos_ptr, # [num_draft_tokens + num_reqs] + temperature_ptr, # [max_num_reqs] ): req_idx = tl.program_id(0) start_idx = tl.load(cu_num_logits_ptr + req_idx) end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) + seed = tl.load(seeds_ptr + req_idx) + 
temperature = tl.load(temperature_ptr + req_idx) num_tokens = end_idx - start_idx + is_zero_temp = temperature == 0.0 num_sampled = 0 rejected = False for i in range(num_tokens - 1): if not rejected: - target_sampled = tl.load(target_sampled_ptr + start_idx + i) - draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1) - tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled) + draft_id = tl.load(input_ids_ptr + start_idx + i + 1) + if is_zero_temp: + token_id = tl.load(recovered_ids_ptr + start_idx + i).to(tl.int32) + if token_id != draft_id: + rejected = True + else: + target_prob = tl.load( + target_probs_ptr + (start_idx + i) * target_probs_stride + draft_id + ) + draft_prob = tl.load( + draft_probs_ptr + + (start_idx + i - req_idx) * draft_probs_stride + + draft_id + ) + pos = tl.load(pos_ptr + start_idx + i) + u = tl.rand(seed=seed, offset=pos) + if target_prob >= u * draft_prob: # Accept + token_id = draft_id + else: # Reject + token_id = tl.load(recovered_ids_ptr + start_idx + i).to(tl.int32) + rejected = True + tl.store(sampled_ptr + req_idx * sampled_stride + i, token_id) num_sampled += 1 - if target_sampled != draft_sampled: - rejected = True if not rejected: - target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1) - tl.store( - sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled - ) + bonus_id = tl.load(recovered_ids_ptr + start_idx + num_tokens - 1) + tl.store(sampled_ptr + req_idx * sampled_stride + num_tokens - 1, bonus_id) num_sampled += 1 tl.store(num_sampled_ptr + req_idx, num_sampled) def rejection_sample( - # [num_draft_tokens + num_reqs] - target_sampled: torch.Tensor, # [num_draft_tokens + num_reqs] input_ids: torch.Tensor, + # [num_draft_tokens + num_reqs] + recovered_ids: torch.Tensor, + # [num_draft_tokens + num_reqs, vocab_size] + target_probs: torch.Tensor, + # [num_draft_tokens, vocab_size], + draft_probs: torch.Tensor, # [num_reqs + 1] cu_num_logits: torch.Tensor, + # [num_reqs] + 
+ seeds: torch.Tensor, + # [num_draft_tokens + num_reqs] + pos: torch.Tensor, + # [max_num_reqs] + temperature: torch.Tensor, num_speculative_steps: int, ) -> tuple[torch.Tensor, torch.Tensor]: num_reqs = cu_num_logits.shape[0] - 1 sampled = torch.empty( num_reqs, num_speculative_steps + 1, - dtype=target_sampled.dtype, - device=target_sampled.device, + dtype=input_ids.dtype, + device=input_ids.device, ) num_sampled = torch.empty( num_reqs, dtype=torch.int32, - device=target_sampled.device, + device=input_ids.device, ) _rejection_sample_kernel[(num_reqs,)]( sampled, sampled.stride(0), num_sampled, - target_sampled, input_ids, cu_num_logits, + target_probs, + target_probs.stride(0), + draft_probs, + draft_probs.stride(0), + recovered_ids, + seeds, + pos, + temperature, num_warps=1, ) return sampled, num_sampled + + +@triton.jit +def _sample_recovered_and_bonus_tokens_kernel( + local_argmax_ptr, + local_argmax_stride, + local_max_ptr, + local_max_stride, + target_probs_ptr, + target_probs_stride, + draft_probs_ptr, + draft_probs_stride, + cu_num_logits_ptr, + temperature_ptr, + seeds_ptr, + pos_ptr, + idx_mapping_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + token_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + token_idx) + end_idx = tl.load(cu_num_logits_ptr + req_state_idx + 1) + seed = tl.load(seeds_ptr + req_state_idx) + pos = tl.load(pos_ptr + token_idx) + temperature = tl.load(temperature_ptr + req_state_idx) + is_bonus_token_idx = token_idx == end_idx - 1 + + block_idx = tl.program_id(1) + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + target_probs = tl.load( + target_probs_ptr + token_idx * target_probs_stride + block, + mask=mask, + other=0.0, + ) + + if temperature == 0.0: + # Greedy: take the argmax of target_probs + value, idx = tl.max(target_probs, axis=0, return_indices=True) + else: + # Generate exponential noise + # Calculate the seed for exponential noise. 
+ gumbel_seed = tl.randint(seed, pos) + u = tl.rand(gumbel_seed, block) + u = tl.maximum(u, 1e-7) + exp_noise = -tl.log(u) + + if is_bonus_token_idx: + # Sample from target_probs + value, idx = tl.max(target_probs / exp_noise, axis=0, return_indices=True) + else: + # Sample from max(target_probs - draft_probs, 0) + draft_probs = tl.load( + draft_probs_ptr + + (token_idx - req_state_idx) * draft_probs_stride + + block, + mask=mask, + other=0.0, + ) + target_probs -= draft_probs + # No need to clamp 0 because the maximum is guaranteed to be >= 0 anyway + value, idx = tl.max(target_probs / exp_noise, axis=0, return_indices=True) + + token_id = block_idx * BLOCK_SIZE + idx + tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id) + tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value) + + +def sample_recovered_and_bonus_tokens( + target_probs: torch.Tensor, # [num_draft_tokens + num_reqs, vocab_size] + draft_probs: torch.Tensor, # [num_draft_tokens, vocab_size] + cu_num_logits: torch.Tensor, # [num_reqs + 1] + idx_mapping: torch.Tensor, # [num_draft_tokens + num_reqs] + temperature: torch.Tensor, # [max_num_reqs] + seed: torch.Tensor, # [max_num_reqs] + pos: torch.Tensor, # [num_draft_tokens + num_reqs] +) -> torch.Tensor: # [num_draft_tokens + num_reqs] + """Return one token id per logit position (argmax of target if temperature == 0): + - draft positions: recovered id drawn from the (target - draft) residual + - last position of each request: bonus id drawn from target + """ + num_tokens, vocab_size = target_probs.shape + BLOCK_SIZE = 1024 + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + local_argmax = torch.empty( + num_tokens, + num_blocks, + dtype=torch.int64, + device=target_probs.device, + ) + local_max = torch.empty( + num_tokens, + num_blocks, + dtype=torch.float32, + device=target_probs.device, + ) + _sample_recovered_and_bonus_tokens_kernel[(num_tokens, num_blocks)]( + local_argmax, + local_argmax.stride(0), + local_max, + local_max.stride(0), + target_probs, + target_probs.stride(0), + draft_probs, + 
draft_probs.stride(0), + cu_num_logits, + temperature, + seed, + pos, + idx_mapping, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + # NOTE(woosuk): Use int64 for later indexing. + max_block_idx = local_max.argmax(dim=-1, keepdim=True) + sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1) + return sampled diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index b338d32a3e39..cb9e1ad6e0c5 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -70,6 +70,13 @@ def __init__( dtype=torch.int64, device=device, ) + self.draft_logits = torch.zeros( + self.max_num_reqs, + self.num_speculative_steps, + self.vocab_size, + dtype=torch.float32, + device=device, + ) self.next_prefill_tokens = torch.zeros( self.max_num_reqs, dtype=torch.int32, device=device )