diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py
index 62912491492e..a02dd62026ad 100644
--- a/vllm/v1/worker/gpu/sample/gumbel.py
+++ b/vllm/v1/worker/gpu/sample/gumbel.py
@@ -76,6 +76,8 @@ def gumbel_block_argmax(
     pos_ptr,
     processed_logits_ptr,
     processed_logits_stride,
+    processed_logits_col_ptr,
+    vocab_size,
     APPLY_TEMPERATURE: tl.constexpr,
 ):
     req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
@@ -88,8 +90,15 @@ def gumbel_block_argmax(
 
     if processed_logits_ptr is not None:
         # Store the temperature-applied logits.
+        if processed_logits_col_ptr is not None:
+            col = tl.load(processed_logits_col_ptr)
+        else:
+            col = 0
         tl.store(
-            processed_logits_ptr + req_state_idx * processed_logits_stride + block,
+            processed_logits_ptr
+            + req_state_idx * processed_logits_stride
+            + col * vocab_size
+            + block,
             logits,
             mask=mask,
         )
@@ -121,6 +130,7 @@ def _gumbel_sample_kernel(
     local_max_stride,
     processed_logits_ptr,
     processed_logits_stride,
+    processed_logits_col_ptr,
     logits_ptr,
     logits_stride,
     expanded_idx_mapping_ptr,
@@ -153,6 +163,8 @@ def _gumbel_sample_kernel(
         pos_ptr,
         processed_logits_ptr,
         processed_logits_stride,
+        processed_logits_col_ptr,
+        vocab_size,
         APPLY_TEMPERATURE=APPLY_TEMPERATURE,
     )
     token_id = block_idx * BLOCK_SIZE + idx
@@ -167,7 +179,8 @@ def gumbel_sample(
     seed: torch.Tensor,  # [max_num_reqs]
     pos: torch.Tensor,  # [num_tokens]
     apply_temperature: bool,
-    processed_logits_out: torch.Tensor | None = None,  # [num_reqs, vocab_size]
+    output_processed_logits: torch.Tensor | None = None,
+    output_processed_logits_col: torch.Tensor | None = None,
 ) -> torch.Tensor:
     num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 1024
@@ -179,8 +192,9 @@ def gumbel_sample(
         local_argmax.stride(0),
         local_max,
         local_max.stride(0),
-        processed_logits_out,
-        processed_logits_out.stride(0) if processed_logits_out is not None else 0,
+        output_processed_logits,
+        output_processed_logits.stride(0) if output_processed_logits is not None else 0,
+        output_processed_logits_col,
         logits,
         logits.stride(0),
         expanded_idx_mapping,
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
index c6b0aa364f53..efe510f16e22 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -89,9 +89,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             dtype=torch.int64,
             device=device,
         )
+        self.current_draft_step = torch.tensor(0, dtype=torch.int64, device=device)
         self.last_token_indices = torch.zeros(
             self.max_num_reqs, dtype=torch.int64, device=device
         )
+        self.arange = torch.arange(
+            self.max_num_reqs + 1, dtype=torch.int32, device="cpu"
+        )
 
         self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs(
             self.draft_model_config
@@ -228,9 +232,10 @@ def _sample_draft(
         logits: torch.Tensor,
         idx_mapping: torch.Tensor,
         pos: torch.Tensor,
-        step: int,
+        draft_step: torch.Tensor,
+        draft_logits: torch.Tensor | None,
     ) -> torch.Tensor:
-        if self.draft_logits is not None:
+        if draft_logits is not None:
             # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
             # used for draft and target sampling.
             return gumbel_sample(
@@ -240,7 +245,8 @@
                 self.seeds,
                 pos + 1,
                 apply_temperature=True,
-                processed_logits_out=self.draft_logits[:, step],
+                output_processed_logits=draft_logits,
+                output_processed_logits_col=draft_step,
             )
         else:
             return logits.argmax(dim=-1)
@@ -274,11 +280,63 @@ def prefill(
             logits,
             idx_mapping,
             pos,
-            step=0,
+            self.current_draft_step,
+            self.draft_logits,
         )
         self.hidden_states[:num_reqs] = hidden_states[last_token_indices]
         self.input_buffers.positions[:num_reqs] = pos
 
+    def multi_step_decode(
+        self,
+        num_reqs: int,
+        skip_attn: bool,
+        batch_desc: BatchExecutionDescriptor,
+        num_tokens_across_dp: torch.Tensor | None,
+    ) -> None:
+        positions = self.input_buffers.positions[:num_reqs]
+        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
+        idx_mapping = self.idx_mapping[:num_reqs]
+
+        for step in range(1, self.num_speculative_steps):
+            attn_metadata = None
+            slot_mappings_by_layer = None
+            if not skip_attn:
+                # Build attention metadata and slot mappings for each draft
+                # decode step. It is necessary to rebuild the attention
+                # metadata even when replaying the FULL graph so that any
+                # attention metadata builder state is updated.
+                slot_mappings = self.block_tables.compute_slot_mappings(
+                    idx_mapping,
+                    query_start_loc,
+                    positions,
+                    batch_desc.num_tokens,
+                )
+                slot_mappings_by_layer = build_slot_mappings_by_layer(
+                    slot_mappings, self.kv_cache_config
+                )
+                attn_metadata = self._build_draft_attn_metadata(
+                    num_reqs=num_reqs,
+                    num_reqs_padded=batch_desc.num_reqs or num_reqs,
+                    num_tokens_padded=batch_desc.num_tokens,
+                )
+
+            # Update the current draft step.
+            self.current_draft_step.fill_(step)
+
+            # Generate draft tokens for the current step.
+            if batch_desc.cg_mode == CUDAGraphMode.FULL:
+                assert self.decode_cudagraph_manager is not None
+                self.decode_cudagraph_manager.run_fullgraph(batch_desc)
+            else:
+                self.generate_draft(
+                    num_reqs,
+                    batch_desc.num_tokens,
+                    attn_metadata,
+                    slot_mappings_by_layer,
+                    num_tokens_across_dp=num_tokens_across_dp,
+                    cudagraph_runtime_mode=batch_desc.cg_mode,
+                )
+
     def generate_draft(
         self,
         num_reqs: int,
@@ -288,59 +346,52 @@ def generate_draft(
         num_tokens_across_dp: torch.Tensor | None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     ) -> None:
-        pos = self.input_buffers.positions[:num_reqs]
-        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
         idx_mapping = self.idx_mapping[:num_reqs]
 
-        for step in range(1, self.num_speculative_steps):
-            # Run the eagle model.
-            last_hidden_states, hidden_states = self.run_model(
-                num_tokens_padded,
-                attn_metadata,
-                slot_mappings,
-                num_tokens_across_dp,
-                cudagraph_runtime_mode,
-            )
-            last_hidden_states = last_hidden_states[:num_reqs]
-            hidden_states = hidden_states[:num_reqs]
-            logits = self.model.compute_logits(last_hidden_states)
+        positions = self.input_buffers.positions[:num_reqs]
+        # Run the eagle model forward pass.
+        last_hidden_states, hidden_states = self.run_model(
+            num_tokens_padded,
+            attn_metadata,
+            slot_mappings,
+            num_tokens_across_dp,
+            cudagraph_runtime_mode,
+        )
+        last_hidden_states = last_hidden_states[:num_reqs]
 
-            draft_tokens = self._sample_draft(
-                logits,
-                idx_mapping,
-                pos,
-                step=step,
-            )
-            self.draft_tokens[:num_reqs, step] = draft_tokens
-
-            if step < self.num_speculative_steps - 1:
-                # Update the inputs for the next step.
-                update_eagle_inputs(
-                    draft_tokens,
-                    hidden_states,
-                    self.input_buffers,
-                    self.hidden_states,
-                    self.max_model_len,
-                )
-                if attn_metadata is not None:
-                    self.block_tables.compute_slot_mappings(
-                        idx_mapping, query_start_loc, pos, num_tokens_padded
-                    )
+        # Sample the draft tokens.
+        logits = self.model.compute_logits(last_hidden_states)
+        draft_tokens = self._sample_draft(
+            logits,
+            idx_mapping,
+            positions,
+            self.current_draft_step,
+            self.draft_logits,
+        )
+
+        # Update the inputs for the next step.
+        update_eagle_draft_inputs(
+            draft_tokens,
+            self.current_draft_step,
+            hidden_states,
+            self.draft_tokens,
+            self.hidden_states,
+            self.input_buffers,
+            num_reqs,
+            self.max_model_len,
+            self.num_speculative_steps,
+        )
 
     def _build_draft_attn_metadata(
         self,
         num_reqs: int,
         num_reqs_padded: int,
         num_tokens_padded: int,
-        max_query_len: int,
     ) -> dict[str, Any] | None:
         if not self.draft_attn_layer_names:
             return None
-        query_start_loc_cpu = (
-            torch.arange(num_reqs_padded + 1, dtype=torch.int32, device="cpu").clamp_(
-                max=num_reqs
-            )
-            * max_query_len
+        query_start_loc_cpu = torch.clamp(
+            self.arange[: num_reqs_padded + 1], max=num_reqs
         )
         block_tables = [
             x[:num_reqs_padded] for x in self.block_tables.input_block_tables
@@ -354,7 +405,7 @@
                 : num_reqs_padded + 1
             ],
            query_start_loc_cpu=query_start_loc_cpu,
-            max_query_len=max_query_len,
+            max_query_len=1,
             seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
             max_seq_len=self.max_model_len,
             block_tables=block_tables,
@@ -373,7 +424,7 @@ def capture(
         self.last_token_indices.zero_()
 
         # Capture the prefill routine (model forward + compute_logits +
-        # gumbel_sample).
+        # sample).
         # For FULL graphs, the entire routine is recorded as one graph.
         # For PIECEWISE, only the model's compiled regions are captured
         # and the rest (compute_logits, gumbel_sample) runs eagerly.
@@ -387,10 +438,9 @@
         if self.num_speculative_steps == 1:
             return
 
-        # Capture the decode draft generation loop (model forward +
-        # compute_logits + gumbel_sample + update_eagle_inputs, for
-        # each step). For FULL graphs, the entire multi-step loop is
-        # recorded as one graph.
+        # Capture the decode draft generation routine (model forward +
+        # compute_logits + sample + update_eagle_draft_inputs) for a
+        # single step.
         assert self.decode_cudagraph_manager is not None
         self.decode_cudagraph_manager.capture(
             self.generate_draft,
@@ -461,9 +511,10 @@
 
         # Get the input ids and last token indices for the speculator.
         prepare_eagle_inputs(
+            self.last_token_indices,
+            self.current_draft_step,
             self.input_buffers,
             input_batch,
-            self.last_token_indices,
             num_sampled,
             num_rejected,
             last_sampled,
@@ -473,12 +524,18 @@
 
         # When all requests are decoding (no true prefills), each has
         # num_speculative_steps + 1 tokens, enabling FULL graph replay.
-        # Mixed or prefill-only batches fall back to PIECEWISE.
+        uniform_token_count = get_uniform_token_count(
+            num_reqs,
+            # Use the actual number of tokens, without the padding added
+            # by the target model for FULL cudagraphs.
+            input_batch.num_tokens,
+            max_query_len,
+        )
         prefill_batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
             self.prefill_cudagraph_manager,
             num_reqs,
             num_tokens,
-            get_uniform_token_count(num_reqs, num_tokens, max_query_len),
+            uniform_token_count,
             dp_size=self.dp_size,
             dp_rank=self.dp_rank,
             need_eager=is_profile,
@@ -528,48 +585,21 @@
             need_eager=is_profile,
         )
 
-        attn_metadata_updated = None
-        slot_mappings_updated = None
-        if not (dummy_run and skip_attn_for_dummy_run):
-            # Build attention metadata and slot mappings for the draft
-            # decode steps. It is necessary to rebuild the attention
-            # metadata even when replaying the FULL graph so that any
-            # attention metadata builder state is updated.
-            slot_mappings = self.block_tables.compute_slot_mappings(
-                self.idx_mapping[:num_reqs],
-                self.input_buffers.query_start_loc[: num_reqs + 1],
-                self.input_buffers.positions[:num_reqs],
-                decode_batch_desc.num_tokens,
-            )
-            slot_mappings_updated = build_slot_mappings_by_layer(
-                slot_mappings, self.kv_cache_config
-            )
-            attn_metadata_updated = self._build_draft_attn_metadata(
-                num_reqs=num_reqs,
-                num_reqs_padded=decode_batch_desc.num_reqs or num_reqs,
-                num_tokens_padded=decode_batch_desc.num_tokens,
-                max_query_len=1,
-            )
+        # Generate the remaining num_speculative_steps - 1 draft tokens.
+        self.multi_step_decode(
+            num_reqs,
+            dummy_run and skip_attn_for_dummy_run,
+            decode_batch_desc,
+            num_tokens_across_dp,
+        )
 
-        if decode_batch_desc.cg_mode == CUDAGraphMode.FULL:
-            # Replay the full graph for draft generation.
-            assert self.decode_cudagraph_manager is not None
-            self.decode_cudagraph_manager.run_fullgraph(decode_batch_desc)
-        else:
-            self.generate_draft(
-                num_reqs,
-                decode_batch_desc.num_tokens,
-                attn_metadata_updated,
-                slot_mappings_updated,
-                num_tokens_across_dp=num_tokens_across_dp,
-                cudagraph_runtime_mode=decode_batch_desc.cg_mode,
-            )
         return self.draft_tokens[:num_reqs]
 
 
 @triton.jit
 def _prepare_eagle_inputs_kernel(
     last_token_indices_ptr,
+    eagle_current_draft_step_ptr,
     eagle_input_ids_ptr,
     eagle_positions_ptr,
     eagle_query_start_loc_ptr,
@@ -630,6 +660,8 @@
     # Copy sequence lengths.
     tl.store(eagle_seq_lens_ptr + req_idx, seq_len)
     if req_idx == (num_reqs - 1):
+        # Reset the current draft step to 0.
+        tl.store(eagle_current_draft_step_ptr, 0)
         # Pad query_start_loc for CUDA graphs.
         for i in range(num_reqs, max_num_reqs + 1, BLOCK_SIZE):
             block = i + tl.arange(0, BLOCK_SIZE)
@@ -648,10 +680,11 @@ def prepare_eagle_inputs(
-    input_buffers: InputBuffers,
-    input_batch: InputBatch,  # [num_reqs]
     last_token_indices: torch.Tensor,
+    current_draft_step: torch.Tensor,
+    input_buffers: InputBuffers,
+    input_batch: InputBatch,  # [num_reqs]
     num_sampled: torch.Tensor,  # [num_reqs]
@@ -665,6 +698,7 @@ def prepare_eagle_inputs(
     num_reqs = input_batch.num_reqs
     _prepare_eagle_inputs_kernel[(num_reqs,)](
         last_token_indices,
+        current_draft_step,
         input_buffers.input_ids,
         input_buffers.positions,
         input_buffers.query_start_loc,
@@ -685,7 +719,7 @@
 
 
 @triton.jit
-def _prepare_eagle_docode_kernel(
+def _prepare_eagle_decode_kernel(
     draft_tokens_ptr,
     draft_tokens_stride,
     target_seq_lens_ptr,
@@ -742,7 +776,7 @@ def prepare_eagle_decode(
     max_num_reqs: int,
 ):
     num_reqs = draft_tokens.shape[0]
-    _prepare_eagle_docode_kernel[(num_reqs + 1,)](
+    _prepare_eagle_decode_kernel[(num_reqs + 1,)](
         draft_tokens,
         draft_tokens.stride(0),
         target_seq_lens,
@@ -758,36 +792,55 @@
 
 
 @triton.jit
-def _update_eagle_inputs_kernel(
+def _update_eagle_draft_inputs_kernel(
+    output_draft_tokens_ptr,
+    output_draft_tokens_stride,
+    next_input_hidden_states_ptr,
+    next_input_hidden_states_stride,
     input_ids_ptr,
     positions_ptr,
-    input_hidden_states_ptr,
-    input_hidden_states_stride,
     seq_lens_ptr,
-    max_model_len,
     draft_tokens_ptr,
-    output_hidden_states_ptr,
-    output_hidden_states_stride,
+    current_draft_step_ptr,
+    hidden_states_ptr,
+    hidden_states_stride,
     hidden_size,
+    max_model_len,
+    num_speculative_steps,
     BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
 
-    # Draft token -> Input ID.
+    # Write the sampled draft token into self.draft_tokens[req_idx, step].
    draft_token = tl.load(draft_tokens_ptr + req_idx)
+    step = tl.load(current_draft_step_ptr)
+    tl.store(
+        output_draft_tokens_ptr + req_idx * output_draft_tokens_stride + step,
+        draft_token,
+    )
+
+    if step >= num_speculative_steps - 1:
+        # This is the final step. Skip updating draft forward inputs.
+        return
+
+    # Write the sampled draft token into the input ids tensor for the next
+    # forward pass.
     tl.store(input_ids_ptr + req_idx, draft_token)
 
-    # Output hidden states -> Input hidden states.
+    # Copy hidden states into the input hidden states tensor for the next
+    # forward pass.
     for i in range(0, hidden_size, BLOCK_SIZE):
         block = i + tl.arange(0, BLOCK_SIZE)
         mask = block < hidden_size
-        output_hidden_states = tl.load(
-            output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
+        hidden_states = tl.load(
+            hidden_states_ptr + req_idx * hidden_states_stride + block,
             mask=mask,
         )
         tl.store(
-            input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
-            output_hidden_states,
+            next_input_hidden_states_ptr
+            + req_idx * next_input_hidden_states_stride
+            + block,
+            hidden_states,
             mask=mask,
         )
 
@@ -803,24 +856,32 @@ def _update_eagle_inputs_kernel(
     tl.store(seq_lens_ptr + req_idx, seq_len)
 
 
-def update_eagle_inputs(
+def update_eagle_draft_inputs(
     draft_tokens: torch.Tensor,
-    output_hidden_states: torch.Tensor,
-    input_buffers: InputBuffers,
+    current_draft_step: torch.Tensor,
     hidden_states: torch.Tensor,
+    output_draft_tokens: torch.Tensor,
+    next_input_hidden_states: torch.Tensor,
+    input_buffers: InputBuffers,
+    num_reqs: int,
     max_model_len: int,
+    num_speculative_steps: int,
 ):
-    num_reqs, hidden_size = output_hidden_states.shape
-    _update_eagle_inputs_kernel[(num_reqs,)](
+    _, hidden_size = hidden_states.shape
+    _update_eagle_draft_inputs_kernel[(num_reqs,)](
+        output_draft_tokens,
+        output_draft_tokens.stride(0),
+        next_input_hidden_states,
+        next_input_hidden_states.stride(0),
        input_buffers.input_ids,
         input_buffers.positions,
-        hidden_states,
-        hidden_states.stride(0),
         input_buffers.seq_lens,
-        max_model_len,
         draft_tokens,
-        output_hidden_states,
-        output_hidden_states.stride(0),
+        current_draft_step,
+        hidden_states,
+        hidden_states.stride(0),
         hidden_size,
+        max_model_len,
+        num_speculative_steps,
         BLOCK_SIZE=1024,
     )
diff --git a/vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py b/vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py
index 9d86372e624b..10b29433efb2 100644
--- a/vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py
+++ b/vllm/v1/worker/gpu/spec_decode/probabilistic_rejection_sampler_utils.py
@@ -392,8 +392,10 @@ def _resample_kernel(
         temp_ptr,
         seed_ptr,
         pos_ptr,
-        None,
-        0,
+        None,  # processed_logits_ptr
+        0,  # processed_logits_stride
+        None,  # processed_logits_col_ptr
+        vocab_size,
         APPLY_TEMPERATURE=False,
     )
     token_id = block_idx * BLOCK_SIZE + idx
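A note on the sampling change above: the draft step index moves from a host-side Python int (`step=step`) to a device-side scalar (`self.current_draft_step`) that the Triton kernel reads through `processed_logits_col_ptr`, offsetting each store by `col * vocab_size`. Below is a minimal stand-alone PyTorch sketch of that indexing, not part of the patch; the shapes and names (`num_reqs`, `num_steps`, `vocab_size`) are illustrative assumptions.

import torch

num_reqs, num_steps, vocab_size = 4, 3, 8

# Flat per-request buffer as the kernel sees it: the row stride is
# num_steps * vocab_size, and each draft step occupies one
# vocab_size-wide column within the row.
draft_logits = torch.zeros(num_reqs, num_steps * vocab_size)

req_state_idx, col = 2, 1          # request row, current draft step
logits = torch.randn(vocab_size)   # temperature-applied logits to store

# Offset the kernel computes:
#   req_state_idx * processed_logits_stride + col * vocab_size + block
start = req_state_idx * draft_logits.stride(0) + col * vocab_size
draft_logits.view(-1)[start : start + vocab_size] = logits

# Equivalent to the old host-side `draft_logits[:, step]` indexing on a
# [num_reqs, num_steps, vocab_size] view of the same storage.
assert torch.equal(
    draft_logits.view(num_reqs, num_steps, vocab_size)[req_state_idx, col],
    logits,
)

Because the step index lives in GPU memory and is advanced with `self.current_draft_step.fill_(step)`, a single-step decode graph can be captured once and replayed for every draft step, which is why `capture` now records the decode routine for one step rather than the whole multi-step loop.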