diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 2829fcb18ff..500f1718a85 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1183,13 +1183,29 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]): num_ctx_requests = inputs['attn_metadata'].num_contexts num_gen_requests = inputs['attn_metadata'].num_generations num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens + num_chunked_ctx_requests = inputs[ + 'attn_metadata'].num_chunked_ctx_requests previous_batch_tokens = inputs['input_ids'].shape[ 0] - num_ctx_tokens inputs['position_ids'][0, num_ctx_tokens:] += ( self.previous_pos_id_offsets_cuda[:previous_batch_tokens]) - inputs['attn_metadata'].kv_lens_cuda[ - num_ctx_requests:num_seqs] += ( - self.previous_kv_lens_offsets_cuda[:num_gen_requests]) + # Only TrtllmAttentionMetadata has kv_lens_cuda. + if isinstance(inputs['attn_metadata'], TrtllmAttentionMetadata): + if num_chunked_ctx_requests > 0: + # Generation requests with draft tokens are treated as chunked context requests when extend_ctx returns True. + inputs['attn_metadata'].kv_lens_cuda[ + num_ctx_requests - + num_chunked_ctx_requests:num_ctx_requests] += ( + self. + previous_kv_lens_offsets_cuda[: + num_chunked_ctx_requests] + ) + else: + inputs['attn_metadata'].kv_lens_cuda[ + num_ctx_requests:num_seqs] += ( + self. + previous_kv_lens_offsets_cuda[:num_gen_requests] + ) if self.guided_decoder is not None: self.guided_decoder.token_event.record() @@ -1285,8 +1301,9 @@ def _prepare_tp_inputs( if new_tensors_device is not None: # speculative decoding cases: [batch, 1 + draft_len], others: [batch] new_tokens_device = new_tensors_device.new_tokens - if self.without_logits: - assert isinstance(new_tensors_device, SampleStateTensorsMTP) + # When using the overlap scheduler with speculative decoding, the target model's inputs will be SampleStateTensorsMTP. + if isinstance(new_tensors_device, SampleStateTensorsMTP): + assert self.enable_spec_decode and not self.is_draft_model new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch] next_draft_tokens_device = new_tensors_device.next_draft_tokens # [batch, draft_len] @@ -1453,9 +1470,19 @@ def _prepare_tp_inputs( previous_batch_indices.append(previous_batch_idx) previous_pos_indices.extend([previous_batch_idx] * (1 + self.runtime_draft_len)) - num_cached_tokens_per_seq.append(past_seen_token_num + - self.runtime_draft_len + 1) - prompt_lengths.append(request.py_prompt_len) + if self.spec_config.spec_dec_mode.has_draft_model(): + # In the overlap scheduler workflow, if a draft model is used, we have already updated the previous batch before launching the target model, + # so we only need to add the runtime_draft_len to the past_seen_token_num.
+ num_cached_tokens_per_seq.append(past_seen_token_num + + self.runtime_draft_len) + else: + num_cached_tokens_per_seq.append(past_seen_token_num + + self.runtime_draft_len + 1) + if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( + self.attn_backend): + prompt_lengths.append(1 + self.runtime_draft_len) + else: + prompt_lengths.append(request.py_prompt_len) for request in generation_requests: request_ids.append(request.py_request_id) @@ -1637,9 +1664,13 @@ def previous_seq_slots_device(): attn_metadata.request_ids = request_ids attn_metadata.prompt_lens = prompt_lengths attn_metadata.num_contexts = len(scheduled_requests.context_requests) + # Use num_chunked_ctx_requests to record the number of extend context requests, + # so that we can update the kv_lens_cuda correctly in _preprocess_inputs. + attn_metadata.num_chunked_ctx_requests = 0 if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx( self.attn_backend): attn_metadata.num_contexts += len(extend_requests) + attn_metadata.num_chunked_ctx_requests = len(extend_requests) attn_metadata.kv_cache_params = KVCacheParams( use_cache=True, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 5c5e573492b..7fa7339e61b 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -44,7 +44,7 @@ from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, - LlmResponse, get_draft_token_length) + LlmResponse) from .model_engine import ModelEngine from .sampler import Sampler, SampleState, SampleStateTensors from .scheduler import RequestScheduler, ScheduledRequests @@ -220,6 +220,7 @@ def __init__(self, self.expected_num_active_requests = 0 self.ctx_in_transmission_requests = [] self.previous_batch: Optional[BatchState] = None + self.has_previous_draft_tokens = False self.num_scheduled_requests: int = 0 self.benchmark_req_queues_size = int( os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0)) @@ -278,11 +279,10 @@ def __init__(self, self.event_loop = trace_func(self.event_loop) if self.drafter is not None: - if self.event_loop.__name__ != self._executor_loop.__name__: + if self.event_loop.__name__ == self._executor_loop_pp.__name__: raise NotImplementedError( "Drafting is not supported for selected executor loop. " - "Please disable disagg/pipeline parallelism/overlap scheduler." - ) + "Please disable disagg/pipeline parallelism.") self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences) self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold self.max_seq_len = max_seq_len @@ -967,11 +967,15 @@ def _prepare_and_schedule_batch(self): self.model_engine.max_num_tokens, self.model_engine.spec_config.max_draft_len) self.model_engine.enable_spec_decode = self.use_spec_decode - # If speculation is off, this function sets py_draft_tokens to None - # for all active requests. If it's on, we initialize py_draft_tokens - # with dummy draft tokens to make the scheduler aware of the fact - # that speculation is about to happen. - self._prepare_draft_requests() + + # When the overlap scheduler is enabled and we already prepared the draft tokens in the previous batch, + # we don't need to initialize py_draft_tokens at this stage because we have not appended the accepted tokens to the requests yet.
+ if not self.has_previous_draft_tokens: + # If speculation is off, this function sets py_draft_tokens to None + # for all active requests. If it's on, we initialize py_draft_tokens + # with dummy draft tokens to make the scheduler aware of the fact + # that speculation is about to happen. + self._prepare_draft_requests() scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( ) @@ -1063,14 +1067,9 @@ def _executor_loop(self): scheduled_requests=scheduled_batch): self.drafter.prepare_draft_tokens( scheduled_batch, self.resource_manager) - # Pad draft tokens to the max draft length. This is for CUDA - # graph compatibility. - for req in scheduled_batch.generation_requests: - max_draft_tokens = self.max_draft_len - num_draft_tokens = get_draft_token_length(req) - req.py_draft_tokens.extend( - 0 for _ in range(max_draft_tokens - - num_draft_tokens)) + # Pad draft tokens to the max draft length. This is for CUDA graph compatibility. + self.drafter.pad_draft_tokens_for_cuda_graph( + scheduled_batch) # add_batch must be called again to restore to target requests with updated draft tokens. if self.guided_decoder is not None: self.guided_decoder.add_batch(scheduled_batch) @@ -1196,12 +1195,27 @@ def _executor_loop_overlap(self): self.guided_decoder.add_batch(scheduled_batch) self.guided_decoder.init_disagg_gen_requests() - previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device + previous_tensors = self.previous_batch and self.previous_batch.sample_state + target_inputs = None + draft_outputs = None + if self.drafter is not None and self.use_spec_decode: + target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding( + scheduled_batch, previous_tensors) + + # Use the draft model's outputs if we've launched the draft model. + # Otherwise, use the previous batch's outputs. + if target_inputs is not None: + previous_tensors_device = target_inputs + else: + previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device batch_outputs = self._forward_step(scheduled_batch, previous_tensors_device) - if self.previous_batch is not None: + if target_inputs is not None: + self._process_draft_results(scheduled_batch, + draft_outputs, draft_batch) + elif self.previous_batch is not None: self._update_requests(self.previous_batch.sample_state) if self.guided_decoder is not None: @@ -1221,7 +1235,6 @@ def _executor_loop_overlap(self): if self.previous_batch is not None: self._process_previous_batch() - self.previous_batch: Optional[BatchState] = None if self.enable_iter_perf_stats: iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ @@ -1954,6 +1967,57 @@ def _remove_inflight_ids(self, scheduled_requests): for req in scheduled_requests.all_requests(): self.inflight_req_ids.erase(req.request_id) + def _handle_speculative_decoding(self, scheduled_batch, previous_tensors): + with request_context(is_draft=True, scheduled_requests=scheduled_batch): + # Do an early check to see whether we need to forward the draft model. + # If so, the overlap should happen between the target requests and the draft requests. + # Otherwise, we can still overlap the previous target requests with the current target requests.
+ has_draft_batch = ( + self.previous_batch is not None + and self.drafter.should_forward_draft_model(scheduled_batch)) + + if has_draft_batch: + self._update_requests(self.previous_batch.sample_state) + if self.has_previous_draft_tokens: + self._prepare_draft_requests() + + target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap( + scheduled_batch, self.resource_manager, + previous_tensors.device if previous_tensors else None) + + self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None + else: + self.has_previous_draft_tokens = False + target_inputs, draft_outputs, draft_batch = None, None, None + + return target_inputs, draft_outputs, draft_batch + + def _process_draft_results(self, scheduled_batch, draft_outputs, + draft_batch): + """ + Append the draft tokens to the target requests, and clean up the draft resources. + """ + req_id_to_old_request = { + req.py_request_id: req + for req in scheduled_batch.all_requests() + } + + if self.drafter.use_static_draft_loop: + self.drafter.process_static_draft_outputs(draft_outputs, + draft_batch, + req_id_to_old_request) + elif draft_outputs is not None: + self.drafter.process_dynamic_draft_outputs(draft_outputs, + req_id_to_old_request) + + # Pad draft tokens to the max draft length. This is for CUDA graph compatibility. + self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch) + # add_batch must be called again to restore to target requests with updated draft tokens. + if self.guided_decoder is not None: + self.guided_decoder.add_batch(scheduled_batch) + if hasattr(self.drafter, "guided_decoder"): + self.guided_decoder.rollback_draft_tokens() + class DisaggPPTerminationHandler: """Handles termination synchronization across pipeline parallel ranks under disaggregated serving. 
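Reviewer note: both `_executor_loop` and the new `_process_draft_results` above rely on `pad_draft_tokens_for_cuda_graph` to keep the draft-token count fixed. Below is a minimal standalone sketch of that invariant; `FakeRequest` and `pad_draft_tokens` are illustrative names, not the TensorRT-LLM API, and the real code operates on `LlmRequest.py_draft_tokens` with the limit taken from the speculative config.

```python
from dataclasses import dataclass, field
from typing import List

MAX_DRAFT_LEN = 4  # illustrative value; the real limit comes from the spec config


@dataclass
class FakeRequest:
    # Minimal stand-in for LlmRequest: only the draft-token list matters here.
    py_draft_tokens: List[int] = field(default_factory=list)


def pad_draft_tokens(requests: List[FakeRequest], max_draft_len: int) -> None:
    # Pad every generation request with zeros up to max_draft_len so the CUDA
    # graph always sees the same draft-token shape, regardless of how many
    # tokens the drafter produced or how many were accepted.
    for req in requests:
        req.py_draft_tokens.extend(
            0 for _ in range(max_draft_len - len(req.py_draft_tokens)))


reqs = [FakeRequest([11, 12]), FakeRequest([])]
pad_draft_tokens(reqs, MAX_DRAFT_LEN)
assert [len(r.py_draft_tokens) for r in reqs] == [MAX_DRAFT_LEN, MAX_DRAFT_LEN]
```

Keeping the draft-token buffer at a fixed length is what allows the captured CUDA graph to be replayed across iterations even when the number of accepted draft tokens varies.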
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index cb045f53791..7ef277b5273 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -1,6 +1,7 @@ import copy import enum import importlib +import os from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from dataclasses import dataclass @@ -37,6 +38,13 @@ from .py_executor import PyExecutor +# Development flag to control chain drafter feature +def _get_allow_chain_drafter() -> bool: + """Get the chain drafter flag from environment variable.""" + # Use environment variable for cross-process compatibility + return os.getenv("TRTLLM_ALLOW_CHAIN_DRAFTER", "0") == "1" + + class _ExecutorCreationStage(enum.Enum): SAMPLER = "Sampler" DRAFTER = "Drafter" @@ -282,11 +290,15 @@ def create_py_executor( # generation requests when we invoke it autoregressively draft_spec_config.max_draft_len = 0 - use_chain_drafter = ( - executor_config.guided_decoding_config is None - and not pytorch_backend_config.enable_mixed_sampler - and pytorch_backend_config.attn_backend == "TRTLLM") + if _get_allow_chain_drafter(): + use_chain_drafter = ( + executor_config.guided_decoding_config is None + and not pytorch_backend_config.enable_mixed_sampler + and pytorch_backend_config.attn_backend == "TRTLLM") + else: + use_chain_drafter = False + logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}") if use_chain_drafter: def drafting_loop_wrapper(model): diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 4fd4ff4d7f7..74384206740 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Optional, final -from ..pyexecutor.llm_request import LlmRequest +from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length from ..pyexecutor.resource_manager import ResourceManager from ..pyexecutor.scheduler import ScheduledRequests @@ -52,3 +52,18 @@ def should_use_spec_decode(self, requests: List[LlmRequest], num_effective_requests = min(len(requests), max_batch_size, token_cap) return num_effective_requests <= self.max_concurrency + + @final + def pad_draft_tokens_for_cuda_graph( + self, scheduled_requests: ScheduledRequests) -> None: + """ + Pad draft tokens to the max draft length for CUDA graph compatibility. 
+ + Args: + scheduled_requests: The scheduled requests to pad + """ + for req in scheduled_requests.generation_requests: + max_draft_tokens = self.max_draft_tokens + num_draft_tokens = get_draft_token_length(req) + req.py_draft_tokens.extend( + 0 for _ in range(max_draft_tokens - num_draft_tokens)) diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 3ecb323aa33..4b3723e2ca0 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -57,7 +57,8 @@ def needs_kv_cache_rewind(self): return self.is_mtp() or self.is_eagle3_one_model() or self.is_ngram() def support_overlap_scheduler(self): - return self.is_mtp() or self.is_eagle3_one_model() + return self.is_mtp() or self.is_eagle3_one_model( + ) or self.has_draft_model() def support_guided_decoder(self): return self.is_none() or self.has_spec_drafter() diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index ac1113b7925..cde12417f50 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -12,9 +12,11 @@ from ..pyexecutor.handle_logits import HandleLogits from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager -from ..pyexecutor.sampler import Sampler, SampleState, TorchSampler +from ..pyexecutor.sampler import (Sampler, SampleState, SampleStateTensors, + TorchSampler) from ..pyexecutor.scheduler import ScheduledRequests from ..pyexecutor.seq_slot_manager import SeqSlotManager +from ..speculative.mtp import SampleStateTensorsMTP from .drafter import Drafter if TYPE_CHECKING: @@ -214,6 +216,10 @@ def _prepare_draft_batch( self._add_to_draft_batch(draft_batch, new_request, request) for request in scheduled_requests.generation_requests: + if request.state == LlmRequestState.GENERATION_COMPLETE: + # Skip generation complete requests. This could happen when enabling overlap scheduler. 
+ continue + if request.py_draft_pages_allocated == 0: # No space for draft tokens continue @@ -239,31 +245,33 @@ def _prepare_draft_batch( traceback.print_exc() raise e - def _should_disable_cuda_graph( - self, previous_batch: Optional[SampleState]) -> bool: + def _should_disable_cuda_graph(self, is_first_draft_token: bool) -> bool: """Check if CUDA graph should be disabled for the current forward pass.""" - if previous_batch is not None: + if not is_first_draft_token: return False if self.use_static_draft_loop: return False return self.spec_config.spec_dec_mode.needs_kv_cache_recompute() - def _forward_draft_model( - self, - draft_batch: ScheduledRequests, - resource_manager: ResourceManager, - previous_batch: Optional[SampleState] = None) -> Dict[str, Any]: + def forward_draft_model( + self, + draft_batch: ScheduledRequests, + resource_manager: ResourceManager, + is_first_draft_token: bool, + previous_tensors: Optional[SampleStateTensors] = None + ) -> Dict[str, Any]: """Forward pass through the draft model.""" - if self._should_disable_cuda_graph(previous_batch): + if self._should_disable_cuda_graph(is_first_draft_token): with self.draft_model_engine.no_cuda_graph(): outputs = self.draft_model_engine.forward( - draft_batch, resource_manager) + draft_batch, + resource_manager, + new_tensors_device=previous_tensors) else: - new_tensors_device = previous_batch.device if previous_batch else None outputs = self.draft_model_engine.forward( draft_batch, resource_manager, - new_tensors_device=new_tensors_device) + new_tensors_device=previous_tensors) # Handle d2t data if available. Static drafting loops should incorporate d2t # in their implementations. @@ -273,8 +281,8 @@ def _forward_draft_model( return outputs - def _sample_async(self, draft_batch: ScheduledRequests, - outputs: Dict[str, Any]) -> Optional[SampleState]: + def sample_async(self, draft_batch: ScheduledRequests, + outputs: Dict[str, Any]) -> Optional[SampleState]: """Sample tokens from draft model outputs.""" try: if self.sampler is not None: @@ -298,8 +306,8 @@ def _sample_async(self, draft_batch: ScheduledRequests, logger.error(f"Error in sampling: {str(e)}") return None - def _update_request_states(self, - scheduled_requests: ScheduledRequests) -> None: + def update_request_states(self, + scheduled_requests: ScheduledRequests) -> None: """Update request states after processing.""" for request in scheduled_requests.context_requests: if request.state != LlmRequestState.GENERATION_COMPLETE: @@ -307,12 +315,12 @@ def _update_request_states(self, if request.context_remaining_length == 0: request.state = LlmRequestState.GENERATION_IN_PROGRESS - def _update_requests(self, sample_state: SampleState) -> None: + def update_requests(self, sample_state: SampleState) -> None: """Update requests with sample state.""" if self.sampler is not None: self.sampler.update_requests(sample_state) - def _process_decoded_tokens( + def process_decoded_tokens( self, draft_batch: ScheduledRequests, req_id_to_old_request: Dict[int, LlmRequest]) -> List[LlmRequest]: """Process decoded tokens and determine which requests to continue processing.""" @@ -337,6 +345,328 @@ def _process_decoded_tokens( return new_requests + def should_forward_draft_model(self, + scheduled_batch: ScheduledRequests) -> bool: + """ + Determine if the draft model should be forwarded for the given batch. 
+ + Args: + scheduled_batch: The scheduled requests to check + + Returns: + bool: True if draft model should be forwarded, False otherwise + """ + for request in scheduled_batch.context_requests: + if request.is_first_context_chunk: + continue + return True + + for request in scheduled_batch.generation_requests: + if request.state == LlmRequestState.GENERATION_COMPLETE: + continue + + if request.max_beam_num_tokens - 1 >= self.draft_model_engine.max_seq_len: + continue + return True + + return False + + def _convert_draft_tensors( + self, + scheduled_batch: ScheduledRequests, + new_tensors_device: Optional[SampleStateTensors] = None + ) -> Optional[SampleStateTensorsMTP]: + """ + Convert tensors for draft model processing. + + Args: + scheduled_batch: The scheduled requests + new_tensors_device: The device tensors to convert + + Returns: + SampleStateTensorsMTP: Converted tensors or None + """ + if new_tensors_device is None: + return None + # Get device from the new_tokens tensor + device = new_tensors_device.new_tokens.device + + # Use the same shape as new_tensors_device.new_tokens + new_tokens = torch.zeros_like(new_tensors_device.new_tokens) + new_tokens_lens = None + next_draft_tokens = None + # Iterate through generation requests and copy tokens based on accepted draft tokens + for idx, request in enumerate(scheduled_batch.all_requests()): + if request.state != LlmRequestState.GENERATION_IN_PROGRESS: + num_accepted_tokens = request.py_num_accepted_draft_tokens + new_tokens[0, idx] = new_tensors_device.new_tokens[ + num_accepted_tokens, idx] + else: + # Create new tensors with the correct device + # We already updated the target state, so the new_tokens_lens should be all ones. + new_tokens_lens = torch.ones(scheduled_batch.batch_size, + device=device) + next_draft_tokens = torch.zeros(scheduled_batch.batch_size, + self.max_draft_tokens, + device=device) + num_accepted_tokens = request.py_num_accepted_draft_tokens + new_tokens[0, idx] = new_tensors_device.new_tokens[ + num_accepted_tokens, idx] + + # Create a new SampleStateTensorsMTP object with the additional fields + updated_tensors = SampleStateTensorsMTP( + new_tokens=new_tokens, + log_probs=new_tensors_device.log_probs, + new_tokens_lens=new_tokens_lens, + next_draft_tokens=next_draft_tokens) + + if hasattr(new_tensors_device, 'logits'): + updated_tensors.logits = new_tensors_device.logits + + return updated_tensors + + def _update_target_inputs_with_draft_tokens( + self, target_inputs: SampleStateTensorsMTP, + draft_tensors: Optional[torch.Tensor], draft_position: int, + draft_length: int, num_draft_reqs: int) -> None: + """ + Update target inputs with new draft tokens from sample state. + + Args: + target_inputs: The target input tensors to update + draft_sample_state: The draft sample state containing new tokens + iteration: The current iteration index + """ + if draft_tensors is not None: + for idx in range(num_draft_reqs): + # Skip prefill requests + if target_inputs.next_draft_tokens is None: + continue + target_inputs.new_tokens[draft_position + 1:draft_position + + draft_length + 1, idx, + 0] = draft_tensors[0:draft_length, idx] + target_inputs.next_draft_tokens[ + idx, draft_position:draft_position + + draft_length] = draft_tensors[0:draft_length, idx] + + def _setup_draft_batch_and_resources( + self, scheduled_batch: ScheduledRequests + ) -> Tuple[Optional[ScheduledRequests], Optional[Dict[int, LlmRequest]]]: + """ + Setup draft batch and prepare resources. 
+ + Args: + scheduled_batch: The scheduled requests + + Returns: + Tuple of (draft_batch, req_id_to_old_request) or (None, None) if no batch + """ + + draft_batch = self._prepare_draft_batch(scheduled_batch) + if draft_batch.batch_size == 0: + return None, None + + req_id_to_old_request = { + req.py_request_id: req + for req in scheduled_batch.all_requests() + } + + self.draft_seq_slot_manager.prepare_resources(draft_batch) + return draft_batch, req_id_to_old_request + + def process_static_draft_outputs( + self, outputs: Any, draft_batch: ScheduledRequests, + req_id_to_old_request: Dict[int, LlmRequest]) -> None: + """ + Process outputs from static draft loop, update target requests, and clean up resources. + + Args: + outputs: The outputs from the draft model + draft_batch: The draft batch that was processed + req_id_to_old_request: Mapping from draft request ID to original request + """ + outputs_host = outputs.cpu() + for token_idx in range(self.max_draft_tokens): + for req_idx, req in enumerate(draft_batch.all_requests()): + target_model_req = req_id_to_old_request[req.py_request_id] + if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS: + # Chunked prefill request in progress; no need to append draft tokens + continue + + target_req = req_id_to_old_request[req.py_request_id] + target_req.py_draft_tokens.append( + outputs_host[token_idx][req_idx]) + + # Clean up draft resources + for req in draft_batch.all_requests(): + self.draft_seq_slot_manager.free_resources(req) + + def process_dynamic_draft_outputs( + self, outputs: Any, + req_id_to_old_request: Dict[int, LlmRequest]) -> None: + """ + Process outputs from dynamic draft loop, update target requests, and clean up resources. + """ + self.update_requests(outputs) + self.process_decoded_tokens(outputs.scheduled_requests, + req_id_to_old_request) + + def _execute_draft_iteration( + self, draft_batch: ScheduledRequests, resource_manager: ResourceManager, + previous_draft_state: Optional[SampleState] + ) -> Tuple[Any, Optional[SampleState]]: + """Forward pass through the draft model.""" + outputs = self.forward_draft_model( + draft_batch, + resource_manager, + is_first_draft_token=False, + previous_tensors=previous_draft_state.device + if previous_draft_state else None) + + if previous_draft_state is not None: + self.update_requests(previous_draft_state) + + if self.guided_decoder is not None: + self.guided_decoder.add_batch(draft_batch) + self.guided_decoder.execute(outputs['logits'], + d2t=outputs.get('d2t')) + + sample_state = self.sample_async(draft_batch, outputs) + self.update_request_states(draft_batch) + + return outputs, sample_state + + def _execute_draft_loop( + self, + draft_batch: ScheduledRequests, + resource_manager: ResourceManager, + req_id_to_old_request: Dict[int, LlmRequest], + target_inputs: Optional[SampleStateTensorsMTP] = None, + num_draft_reqs: Optional[int] = None, + initial_draft_state: Optional[SampleState] = None + ) -> Optional[SampleState]: + """ + Execute the iterative draft loop. 
+ + Args: + draft_batch: The draft batch to process + resource_manager: The resource manager + req_id_to_old_request: Mapping from request ID to original request + target_inputs: Optional target inputs to update (for overlap mode) + num_draft_reqs: Number of draft requests (for overlap mode) + initial_draft_state: The initial draft state from the first forward pass + + Returns: + The final sample state + """ + # Convert context requests to generation requests + draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests + draft_batch.context_requests = [] + + previous_draft_state = initial_draft_state + + # Generate remaining draft tokens iteratively + for i in range(self.max_draft_tokens - 1): + if len(draft_batch.generation_requests) == 0: + break + + _, sample_state = self._execute_draft_iteration( + draft_batch, resource_manager, previous_draft_state) + + # Update target inputs if provided (for overlap mode) + if target_inputs is not None and num_draft_reqs is not None: + draft_tensors = sample_state and sample_state.device and sample_state.device.new_tokens + self._update_target_inputs_with_draft_tokens( + target_inputs, + draft_tensors, + draft_position=i + 1, + draft_length=1, + num_draft_reqs=num_draft_reqs) + + if sample_state is not None and previous_draft_state is not None: + new_requests = self.process_decoded_tokens( + previous_draft_state.scheduled_requests, + req_id_to_old_request) + else: + new_requests = [] + + draft_batch.generation_requests = new_requests + previous_draft_state = sample_state + + return previous_draft_state + + def generate_draft_tokens_with_overlap( + self, scheduled_batch: ScheduledRequests, + resource_manager: ResourceManager, + previous_tensors: Optional[SampleStateTensors] + ) -> Tuple[Optional[SampleStateTensorsMTP], Optional[Any], + Optional[ScheduledRequests]]: + """ + Generate draft tokens with overlap scheduling support. 
+ + Args: + scheduled_batch: The scheduled requests + resource_manager: The resource manager + previous_tensors: Previous iteration tensors + guided_decoder: The guided decoder + + Returns: + Tuple[Optional[SampleStateTensorsMTP], Optional[SampleState]]: + - Updated target inputs or None + - Draft sample state or None + """ + draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( + scheduled_batch) + if draft_batch is None: + return None, None, None + + target_inputs = self._convert_draft_tensors(scheduled_batch, + previous_tensors) + if target_inputs is None: + return None, None, None + + # Initial forward pass + outputs = self.forward_draft_model(draft_batch, + resource_manager, + is_first_draft_token=True, + previous_tensors=previous_tensors) + + num_draft_reqs = len(draft_batch.all_requests()) + if self.use_static_draft_loop: + # Only update target inputs, cleanup will be done in executor loop + self._update_target_inputs_with_draft_tokens( + target_inputs, + outputs, + draft_position=0, + draft_length=self.max_draft_tokens, + num_draft_reqs=num_draft_reqs) + return target_inputs, outputs, draft_batch + + # Handle guided decoder and sampling for non-static loop + if self.guided_decoder is not None: + self.guided_decoder.add_batch(draft_batch) + self.guided_decoder.execute(outputs['logits'], + d2t=outputs.get('d2t')) + draft_sample_state = self.sample_async(draft_batch, outputs) + + # Update target inputs with first iteration results + draft_tensors = draft_sample_state and draft_sample_state.device and draft_sample_state.device.new_tokens + self._update_target_inputs_with_draft_tokens( + target_inputs, + draft_tensors, + draft_position=0, + draft_length=1, + num_draft_reqs=num_draft_reqs) + + self.update_request_states(draft_batch) + + # Execute the iterative draft loop + previous_draft_state = self._execute_draft_loop( + draft_batch, resource_manager, req_id_to_old_request, target_inputs, + num_draft_reqs, draft_sample_state) + + return target_inputs, previous_draft_state, draft_batch + @nvtx_range("prepare_draft_tokens") def prepare_draft_tokens( self, @@ -357,84 +687,38 @@ def prepare_draft_tokens( raise ValueError("Resource manager is required") try: - draft_batch = self._prepare_draft_batch(scheduled_requests) - - if draft_batch.batch_size == 0: + draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources( + scheduled_requests) + if draft_batch is None: return - self.draft_seq_slot_manager.prepare_resources(draft_batch) - - req_id_to_old_request = { - req.py_request_id: req - for req in scheduled_requests.all_requests() - } - # Initial forward pass. May do the complete drafting loop # if use_static_draft_loop is set. 
- outputs = self._forward_draft_model(draft_batch, resource_manager) + outputs = self.forward_draft_model(draft_batch, + resource_manager, + is_first_draft_token=True) if self.use_static_draft_loop: - outputs_host = outputs.cpu() - for token_idx in range(self.max_draft_tokens): - for req_idx, req in enumerate(draft_batch.all_requests()): - target_model_req = req_id_to_old_request[ - req.py_request_id] - if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS: - # Chunked prefill request in progress; no need to append draft tokens - continue - - target_req = req_id_to_old_request[req.py_request_id] - target_req.py_draft_tokens.append( - outputs_host[token_idx][req_idx]) - - for req in draft_batch.all_requests(): - self.draft_seq_slot_manager.free_resources(req) - + self.process_static_draft_outputs(outputs, draft_batch, + req_id_to_old_request) return if self.guided_decoder is not None: self.guided_decoder.add_batch(draft_batch) self.guided_decoder.execute(outputs['logits'], d2t=outputs.get('d2t')) - sample_state = self._sample_async(draft_batch, outputs) - previous_batch = sample_state - - self._update_request_states(draft_batch) - - # Convert context requests to generation requests - draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests - draft_batch.context_requests = [] - - # Generate remaining draft tokens iteratively - for i in range(self.max_draft_tokens - 1): - if len(draft_batch.generation_requests) == 0: - break - - outputs = self._forward_draft_model(draft_batch, - resource_manager, - previous_batch) - if previous_batch is not None: - self._update_requests(previous_batch) - if self.guided_decoder is not None: - self.guided_decoder.add_batch(draft_batch) - self.guided_decoder.execute(outputs['logits'], - d2t=outputs.get('d2t')) - sample_state = self._sample_async(draft_batch, outputs) - self._update_request_states(draft_batch) - if previous_batch is not None: - new_requests = self._process_decoded_tokens( - previous_batch.scheduled_requests, - req_id_to_old_request) - else: - new_requests = [] - draft_batch.generation_requests = new_requests - previous_batch = sample_state + sample_state = self.sample_async(draft_batch, outputs) + self.update_request_states(draft_batch) + + # Execute the iterative draft loop + previous_draft_state = self._execute_draft_loop( + draft_batch, resource_manager, req_id_to_old_request, None, + None, sample_state) # Final cleanup - if previous_batch is not None: - self._update_requests(previous_batch) - self._process_decoded_tokens(previous_batch.scheduled_requests, - req_id_to_old_request) + if previous_draft_state is not None: + self.process_dynamic_draft_outputs(previous_draft_state, + req_id_to_old_request) except Exception as e: traceback.print_exc() diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index 429ccebf80e..dc23270945b 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -25,7 +25,7 @@ class NGramPoolManager(BaseResourceManager): `matches` is a list of candidate draft token ids attaching to a pattern. Arguments: - max_draft_len: int + max_draft_tokens: int The length maximum of draft tokens (can be understood as length maximum of output draft tokens). 
max_matching_ngram_size: int @@ -51,7 +51,7 @@ class NGramPoolManager(BaseResourceManager): def __init__(self, spec_config: "NGramDecodingConfig", max_num_requests: int): - self.max_draft_len = spec_config.max_draft_len + self.max_draft_tokens = spec_config.max_draft_len self.max_matching_ngram_size = spec_config.max_matching_ngram_size self.is_keep_all = spec_config.is_keep_all self.is_use_oldest = spec_config.is_use_oldest # TODO: remove this if updating strategy is supported @@ -107,7 +107,7 @@ def get_draft_tokens( -1): # Find each possible pattern-match combination, and use tuple for hash for l in range(len(sequence) - size): - r = min(l + size + self.max_draft_len, len(sequence)) + r = min(l + size + self.max_draft_tokens, len(sequence)) pattern = tuple(sequence[l:l + size]) new_match = tuple(sequence[l + size:r]) if pattern not in pool or \ @@ -138,7 +138,7 @@ def get_draft_tokens( # Update start_index self.start_index[request_id] = max( 0, prefix_len - - (self.max_draft_len + self.max_matching_ngram_size - 1)) + (self.max_draft_tokens + self.max_matching_ngram_size - 1)) return draft_tokens @@ -170,7 +170,7 @@ def __init__( super().__init__(spec_config.max_concurrency) assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." self.spec_config = spec_config - self.max_draft_len = spec_config.max_draft_len + self.max_draft_tokens = spec_config.max_draft_len self.spec_resource_manager = ngram_pool_manager def prepare_draft_tokens( diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py index d1d9510b925..cdf48cfea57 100644 --- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py +++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py @@ -14,8 +14,9 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +@pytest.mark.parametrize("disable_overlap_scheduler", [True, False]) @pytest.mark.high_cuda_memory -def test_dynamic_spec_decode(): +def test_dynamic_spec_decode(disable_overlap_scheduler: bool): total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 35: pytest.skip("Not enough memory to load target + draft model") @@ -32,7 +33,7 @@ def test_dynamic_spec_decode(): llm_common_config = dict( model=target_model_dir, attn_backend="TRTLLM", - disable_overlap_scheduler=True, + disable_overlap_scheduler=disable_overlap_scheduler, cuda_graph_config=cuda_graph_config, max_batch_size=max_batch_size, kv_cache_config=kv_cache_config, @@ -51,8 +52,8 @@ def test_dynamic_spec_decode(): ) # Mock should_use_spec_decode to return True for first two calls, then False - def mock_should_use_spec_decode(self, requests, max_batch_size, - max_num_tokens, max_draft_len): + def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens, + max_draft_len): if not hasattr(mock_should_use_spec_decode, 'call_count'): mock_should_use_spec_decode.call_count = 0 mock_should_use_spec_decode.call_count += 1 diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index bf69917ef28..fbd5e93e9d9 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -17,22 +17,34 @@ @pytest.mark.parametrize( - "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill", + 
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter", [ - [True, "TRTLLM", True, False, False, False], - [False, "TRTLLM", True, False, False, False], - [True, "FLASHINFER", True, False, False, False], - [False, "FLASHINFER", True, False, False, False], - [False, "TRTLLM", False, True, True, False], - [True, "TRTLLM", False, True, True, False], - [True, "TRTLLM", True, False, True, True], + [True, "TRTLLM", True, False, False, False, True], + [True, "TRTLLM", True, False, False, False, False], + [False, "TRTLLM", True, False, False, False, True], + [False, "TRTLLM", True, False, False, False, False], + [True, "FLASHINFER", True, False, False, False, True], + [False, "FLASHINFER", True, False, False, False, True], + [False, "TRTLLM", False, True, True, False, True], + [True, "TRTLLM", False, True, True, False, True], + [True, "TRTLLM", True, False, True, True, True], + [True, "TRTLLM", True, False, True, False, True], # TODO: nvbugs/5461761 - # [True, "TRTLLM", True, False, False, True], + # [True, "TRTLLM", True, False, False, True, True], + [True, "TRTLLM", False, False, False, False, True], + [False, "TRTLLM", False, False, False, False, True], + [True, "TRTLLM", False, False, False, False, False], + [False, "TRTLLM", False, False, False, False, False], + [True, "TRTLLM", False, False, False, True, True], + [True, "TRTLLM", False, False, False, True, False], + [True, "FLASHINFER", False, False, False, False, True], + [False, "FLASHINFER", False, False, False, False, True], ]) @pytest.mark.high_cuda_memory def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, disable_overlap_scheduler: bool, enable_block_reuse: bool, - use_one_model: bool, enable_chunked_prefill: bool): + use_one_model: bool, enable_chunked_prefill: bool, + use_chain_drafter: bool): # Eagle3 one model works with overlap scheduler and block reuse. total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if total_mem_gb < 35: @@ -76,52 +88,63 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, eagle3_one_model=use_one_model, ) - llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + # Set the development flag to control use_chain_drafter behavior + original_env_value = os.environ.get("TRTLLM_ALLOW_CHAIN_DRAFTER", "0") + try: + os.environ[ + "TRTLLM_ALLOW_CHAIN_DRAFTER"] = "1" if use_chain_drafter else "0" + # Create the LLM instance with the mocked flag controlling use_chain_drafter + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) - # Acceptance rate tests - if enable_chunked_prefill: - # Use a long prompt for chunked prefill tests. - prompts = [ - "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. 
Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and " - ] - tok_ids = llm_spec.tokenizer.encode(prompts[0]) - else: - prompts = [ - "The capital of France is", - "The president of the United States is", - ] - tok_ids = llm_spec.tokenizer.encode("The future of AI is") + # Acceptance rate tests + if enable_chunked_prefill: + # Use a long prompt for chunked prefill tests. + prompts = [ + "The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. 
Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and " + ] + tok_ids = llm_spec.tokenizer.encode(prompts[0]) + else: + prompts = [ + "The capital of France is", + "The president of the United States is", + ] + tok_ids = llm_spec.tokenizer.encode("The future of AI is") + + num_tokens = 0 + num_drafted = 0 + num_accepted = 0 + sampling_params = SamplingParams(max_tokens=128, temperature=0) + for output in llm_spec.generate_async(tok_ids, + sampling_params, + streaming=True): + new_tokens = output.outputs[0].token_ids + num_drafted += max_draft_len + num_accepted += len(new_tokens) - num_tokens - 1 + num_tokens = len(new_tokens) + + accept_rate = num_accepted / num_drafted + assert accept_rate > 0.15 - num_tokens = 0 - num_drafted = 0 - num_accepted = 0 - sampling_params = SamplingParams(max_tokens=128, temperature=0) - for output in llm_spec.generate_async(tok_ids, - sampling_params, - streaming=True): - new_tokens = output.outputs[0].token_ids - num_drafted += max_draft_len - num_accepted += len(new_tokens) - num_tokens - 1 - num_tokens = len(new_tokens) - - accept_rate = num_accepted / num_drafted - assert accept_rate > 0.15 - - # Output tests - sampling_params = SamplingParams(max_tokens=10, temperature=0) - - results_spec = llm_spec.generate(prompts, sampling_params) - generated_text_spec = [result.outputs[0].text for result in results_spec] - llm_spec.shutdown() - - llm_ref = LLM(**llm_common_config) - results_ref = llm_ref.generate(prompts, sampling_params) - generated_text_ref = [result.outputs[0].text for result in results_ref] - llm_ref.shutdown() - - for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): - # The spec decode algorithm currently guarantees identical results - assert text_spec == text_ref + # Output tests + sampling_params = SamplingParams(max_tokens=10, temperature=0) + + results_spec = llm_spec.generate(prompts, sampling_params) + generated_text_spec = [ + result.outputs[0].text for result in results_spec + ] + llm_spec.shutdown() + + llm_ref = LLM(**llm_common_config) + results_ref = llm_ref.generate(prompts, sampling_params) + generated_text_ref = [result.outputs[0].text for result in results_ref] + llm_ref.shutdown() + + for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): + # The spec decode algorithm currently guarantees identical results + assert text_spec == text_ref + finally: + # Restore the original environment variable value + os.environ["TRTLLM_ALLOW_CHAIN_DRAFTER"] = original_env_value def test_deepseek_eagle3():
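A possible follow-up for the environment-variable handling in `test_llama_eagle3`: pytest's `monkeypatch` fixture can set `TRTLLM_ALLOW_CHAIN_DRAFTER` and restore the previous value automatically, which avoids the manual try/finally block. A minimal sketch (hypothetical test name, body elided; only the flag handling is shown):

```python
import pytest


@pytest.mark.parametrize("use_chain_drafter", [True, False])
def test_chain_drafter_flag(monkeypatch, use_chain_drafter: bool):
    # monkeypatch.setenv restores the prior value (or removes the variable)
    # when the test finishes, so no explicit cleanup is needed.
    monkeypatch.setenv("TRTLLM_ALLOW_CHAIN_DRAFTER",
                       "1" if use_chain_drafter else "0")
    # ... build the LLM and run generation as in test_llama_eagle3 ...
```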