47 changes: 39 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1183,13 +1183,29 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
num_ctx_requests = inputs['attn_metadata'].num_contexts
num_gen_requests = inputs['attn_metadata'].num_generations
num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
num_chunked_ctx_requests = inputs[
'attn_metadata'].num_chunked_ctx_requests
previous_batch_tokens = inputs['input_ids'].shape[
0] - num_ctx_tokens
inputs['position_ids'][0, num_ctx_tokens:] += (
self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
inputs['attn_metadata'].kv_lens_cuda[
num_ctx_requests:num_seqs] += (
self.previous_kv_lens_offsets_cuda[:num_gen_requests])
# Only TrtllmAttentionMetadata has kv_lens_cuda.
if isinstance(inputs['attn_metadata'], TrtllmAttentionMetadata):
if num_chunked_ctx_requests > 0:
# The generation requests with draft_tokens are treated as chunked context requests when extend_ctx returns True.
inputs['attn_metadata'].kv_lens_cuda[
num_ctx_requests -
num_chunked_ctx_requests:num_ctx_requests] += (
self.
previous_kv_lens_offsets_cuda[:
num_chunked_ctx_requests]
)
else:
inputs['attn_metadata'].kv_lens_cuda[
num_ctx_requests:num_seqs] += (
self.
previous_kv_lens_offsets_cuda[:num_gen_requests]
)

if self.guided_decoder is not None:
self.guided_decoder.token_event.record()
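The slicing above assumes the batch layout built later in _prepare_tp_inputs: the extend (chunked-context) requests sit at the tail of the context block, followed by the generation requests. A minimal sketch of the same indexing, with made-up sizes and plain torch tensors standing in for attn_metadata:

import torch

num_ctx_requests = 5           # 3 plain context requests + 2 extend ("chunked") requests
num_chunked_ctx_requests = 2
num_gen_requests = 3
num_seqs = num_ctx_requests + num_gen_requests

kv_lens = torch.zeros(num_seqs, dtype=torch.int32)
prev_offsets = torch.full((3,), 4, dtype=torch.int32)  # offsets carried over from the previous batch

if num_chunked_ctx_requests > 0:
    # Extend requests occupy the last slots of the context block.
    kv_lens[num_ctx_requests - num_chunked_ctx_requests:num_ctx_requests] += \
        prev_offsets[:num_chunked_ctx_requests]
else:
    kv_lens[num_ctx_requests:num_seqs] += prev_offsets[:num_gen_requests]

print(kv_lens.tolist())  # [0, 0, 0, 4, 4, 0, 0, 0]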
@@ -1285,8 +1301,9 @@ def _prepare_tp_inputs(
if new_tensors_device is not None:
# speculative decoding cases: [batch, 1 + draft_len], others: [batch]
new_tokens_device = new_tensors_device.new_tokens
if self.without_logits:
assert isinstance(new_tensors_device, SampleStateTensorsMTP)
# When using the overlap scheduler with speculative decoding, the target model's inputs will be SampleStateTensorsMTP.
if isinstance(new_tensors_device, SampleStateTensorsMTP):
assert self.enable_spec_decode and not self.is_draft_model
new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch]
next_draft_tokens_device = new_tensors_device.next_draft_tokens # [batch, draft_len]

Expand Down Expand Up @@ -1453,9 +1470,19 @@ def _prepare_tp_inputs(
previous_batch_indices.append(previous_batch_idx)
previous_pos_indices.extend([previous_batch_idx] *
(1 + self.runtime_draft_len))
num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len + 1)
prompt_lengths.append(request.py_prompt_len)
if self.spec_config.spec_dec_mode.has_draft_model():
# In the overlap scheduler workflow, if a draft model is used, the previous batch was already updated before launching the target model,
# so we only need to add the runtime_draft_len to the past_seen_token_num.
num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len)
else:
num_cached_tokens_per_seq.append(past_seen_token_num +
self.runtime_draft_len + 1)
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
self.attn_backend):
prompt_lengths.append(1 + self.runtime_draft_len)
else:
prompt_lengths.append(request.py_prompt_len)

for request in generation_requests:
request_ids.append(request.py_request_id)
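Worked numbers for the branches above, using made-up values (a past_seen_token_num of 100 and a runtime draft length of 3); the comments restate the reasoning from the diff:

past_seen_token_num = 100
runtime_draft_len = 3

# Two-model path under the overlap scheduler: the previous batch was already
# updated before the target model launched, so the accepted token is not
# counted again here.
cached_with_draft_model = past_seen_token_num + runtime_draft_len         # 103

# Other speculative modes still need to account for the token accepted in the
# previous step.
cached_without_draft_model = past_seen_token_num + runtime_draft_len + 1  # 104

# When extend_ctx() is True, the request appears to be handled as a (chunked)
# context request covering only the new tokens, so the reported prompt length
# is 1 + draft_len rather than the original prompt length.
prompt_len_extend_ctx = 1 + runtime_draft_len                             # 4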
@@ -1637,9 +1664,13 @@ def previous_seq_slots_device():
attn_metadata.request_ids = request_ids
attn_metadata.prompt_lens = prompt_lengths
attn_metadata.num_contexts = len(scheduled_requests.context_requests)
# Use num_chunked_ctx_requests to record the number of extend context requests,
# so that we can update the kv_lens_cuda correctly in _preprocess_inputs.
attn_metadata.num_chunked_ctx_requests = 0
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
self.attn_backend):
attn_metadata.num_contexts += len(extend_requests)
attn_metadata.num_chunked_ctx_requests = len(extend_requests)

attn_metadata.kv_cache_params = KVCacheParams(
use_cache=True,
104 changes: 84 additions & 20 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -44,7 +44,7 @@
from .kv_cache_connector import KvCacheConnectorManager
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse, get_draft_token_length)
LlmResponse)
from .model_engine import ModelEngine
from .sampler import Sampler, SampleState, SampleStateTensors
from .scheduler import RequestScheduler, ScheduledRequests
@@ -220,6 +220,7 @@ def __init__(self,
self.expected_num_active_requests = 0
self.ctx_in_transmission_requests = []
self.previous_batch: Optional[BatchState] = None
self.has_previous_draft_tokens = False
self.num_scheduled_requests: int = 0
self.benchmark_req_queues_size = int(
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
@@ -278,11 +279,10 @@ def __init__(self,
self.event_loop = trace_func(self.event_loop)

if self.drafter is not None:
if self.event_loop.__name__ != self._executor_loop.__name__:
if self.event_loop.__name__ == self._executor_loop_pp.__name__:
raise NotImplementedError(
"Drafting is not supported for selected executor loop. "
"Please disable disagg/pipeline parallelism/overlap scheduler."
)
"Please disable disagg/pipeline parallelism scheduler.")
self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences)
self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
self.max_seq_len = max_seq_len
@@ -967,11 +967,15 @@ def _prepare_and_schedule_batch(self):
self.model_engine.max_num_tokens,
self.model_engine.spec_config.max_draft_len)
self.model_engine.enable_spec_decode = self.use_spec_decode
# If speculation is off, this function sets py_draft_tokens to None
# for all active requests. If it's on, we initialize py_draft_tokens
# with dummy draft tokens to make the scheduler aware of the fact
# that speculation is about to happen.
self._prepare_draft_requests()

# When the overlap scheduler is enabled and the draft tokens were already prepared in the previous batch,
# we don't need to initialize py_draft_tokens at this stage because the accepted tokens haven't been appended to the requests yet.
if not self.has_previous_draft_tokens:
# If speculation is off, this function sets py_draft_tokens to None
# for all active requests. If it's on, we initialize py_draft_tokens
# with dummy draft tokens to make the scheduler aware of the fact
# that speculation is about to happen.
self._prepare_draft_requests()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
@@ -1063,14 +1067,9 @@ def _executor_loop(self):
scheduled_requests=scheduled_batch):
self.drafter.prepare_draft_tokens(
scheduled_batch, self.resource_manager)
# Pad draft tokens to the max draft length. This is for CUDA
# graph compatibility.
for req in scheduled_batch.generation_requests:
max_draft_tokens = self.max_draft_len
num_draft_tokens = get_draft_token_length(req)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens -
num_draft_tokens))
# Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
self.drafter.pad_draft_tokens_for_cuda_graph(
scheduled_batch)
# add_batch must be called again to restore to target requests with updated draft tokens.
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
@@ -1196,12 +1195,27 @@ def _executor_loop_overlap(self):
self.guided_decoder.add_batch(scheduled_batch)
self.guided_decoder.init_disagg_gen_requests()

previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
previous_tensors = self.previous_batch and self.previous_batch.sample_state
target_inputs = None
draft_outputs = None
if self.drafter is not None and self.use_spec_decode:
target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
scheduled_batch, previous_tensors)

# Use the draft_model's outputs if we've launched the draft model.
# Otherwise, use the previous batch's outputs.
if target_inputs is not None:
previous_tensors_device = target_inputs
else:
previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device

batch_outputs = self._forward_step(scheduled_batch,
previous_tensors_device)

if self.previous_batch is not None:
if target_inputs is not None:
self._process_draft_results(scheduled_batch,
draft_outputs, draft_batch)
elif self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)

if self.guided_decoder is not None:
@@ -1221,7 +1235,6 @@ def _executor_loop_overlap(self):

if self.previous_batch is not None:
self._process_previous_batch()
self.previous_batch: Optional[BatchState] = None

if self.enable_iter_perf_stats:
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
@@ -1954,6 +1967,57 @@ def _remove_inflight_ids(self, scheduled_requests):
for req in scheduled_requests.all_requests():
self.inflight_req_ids.erase(req.request_id)

def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
with request_context(is_draft=True, scheduled_requests=scheduled_batch):
# Do an early check to see if we need to run the draft model forward.
# If needed, the overlap should happen between the target requests and the draft requests.
# Otherwise, we can still do overlap between the previous target requests and the current target requests.
has_draft_batch = (
self.previous_batch is not None
and self.drafter.should_forward_draft_model(scheduled_batch))

if has_draft_batch:
self._update_requests(self.previous_batch.sample_state)
if self.has_previous_draft_tokens:
self._prepare_draft_requests()

target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
scheduled_batch, self.resource_manager,
previous_tensors.device if previous_tensors else None)

self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
else:
self.has_previous_draft_tokens = False
target_inputs, draft_outputs, draft_batch = None, None, None

return target_inputs, draft_outputs, draft_batch

def _process_draft_results(self, scheduled_batch, draft_outputs,
draft_batch):
"""
Append the draft tokens to the target requests, and clean up the draft resources.
"""
req_id_to_old_request = {
req.py_request_id: req
for req in scheduled_batch.all_requests()
}

if self.drafter.use_static_draft_loop:
self.drafter.process_static_draft_outputs(draft_outputs,
draft_batch,
req_id_to_old_request)
elif draft_outputs is not None:
self.drafter.process_dynamic_draft_outputs(draft_outputs,
req_id_to_old_request)

# Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
# add_batch must be called again to restore to target requests with updated draft tokens.
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
if hasattr(self.drafter, "guided_decoder"):
self.guided_decoder.rollback_draft_tokens()


class DisaggPPTerminationHandler:
"""Handles termination synchronization across pipeline parallel ranks under disaggregated serving.
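The net effect of the py_executor.py changes on the overlap loop is a new source for the "previous tensors" that seed the target forward: if the draft model ran this iteration, its outputs are used; otherwise the loop falls back to the previous batch's sampled tensors as before. A condensed, runnable sketch of just that selection (stand-in values, not the real tensor types):

def pick_previous_tensors(target_inputs, previous_sample_state_device):
    # target_inputs is non-None only when _handle_speculative_decoding actually
    # ran the draft model; in that case it already folds in the previous target
    # step, so it takes precedence over the previous batch's device tensors.
    if target_inputs is not None:
        return target_inputs
    return previous_sample_state_device

assert pick_previous_tensors("draft_model_outputs", "previous_batch") == "draft_model_outputs"
assert pick_previous_tensors(None, "previous_batch") == "previous_batch"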
20 changes: 16 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -1,6 +1,7 @@
import copy
import enum
import importlib
import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
@@ -37,6 +38,13 @@
from .py_executor import PyExecutor


# Development flag to control the chain drafter feature
def _get_allow_chain_drafter() -> bool:
"""Get the chain drafter flag from environment variable."""
# Use environment variable for cross-process compatibility
return os.getenv("TRTLLM_ALLOW_CHAIN_DRAFTER", "0") == "1"


class _ExecutorCreationStage(enum.Enum):
SAMPLER = "Sampler"
DRAFTER = "Drafter"
@@ -282,11 +290,15 @@ def create_py_executor(
# generation requests when we invoke it autoregressively
draft_spec_config.max_draft_len = 0

use_chain_drafter = (
executor_config.guided_decoding_config is None
and not pytorch_backend_config.enable_mixed_sampler
and pytorch_backend_config.attn_backend == "TRTLLM")
if _get_allow_chain_drafter():
use_chain_drafter = (
executor_config.guided_decoding_config is None
and not pytorch_backend_config.enable_mixed_sampler
and pytorch_backend_config.attn_backend == "TRTLLM")
else:
use_chain_drafter = False

logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}")
if use_chain_drafter:

def drafting_loop_wrapper(model):
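The chain-drafter gate is read from the environment rather than from a Python-level option so that spawned worker processes observe the same value. An illustrative way to opt in from the parent process before the executor is created (only the variable name comes from the diff; the call site comment is a placeholder):

import os

# Must be set before any worker processes are spawned so they inherit it.
os.environ["TRTLLM_ALLOW_CHAIN_DRAFTER"] = "1"   # opt in to the chain drafter path
# executor = create_py_executor(...)             # hypothetical call site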
17 changes: 16 additions & 1 deletion tensorrt_llm/_torch/speculative/drafter.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Optional, final

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.llm_request import LlmRequest, get_draft_token_length
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests

@@ -52,3 +52,18 @@ def should_use_spec_decode(self, requests: List[LlmRequest],

num_effective_requests = min(len(requests), max_batch_size, token_cap)
return num_effective_requests <= self.max_concurrency

@final
def pad_draft_tokens_for_cuda_graph(
self, scheduled_requests: ScheduledRequests) -> None:
"""
Pad draft tokens to the max draft length for CUDA graph compatibility.

Args:
scheduled_requests: The scheduled requests to pad
"""
for req in scheduled_requests.generation_requests:
max_draft_tokens = self.max_draft_tokens
num_draft_tokens = get_draft_token_length(req)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))
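A tiny self-contained illustration of the padding behavior, assuming max_draft_tokens is 4 and a request already carrying two draft tokens (the stand-in request class and the token values are made up):

class _FakeRequest:
    def __init__(self, draft_tokens):
        self.py_draft_tokens = list(draft_tokens)

max_draft_tokens = 4
req = _FakeRequest([17, 23])                  # two real draft tokens
num_draft_tokens = len(req.py_draft_tokens)
req.py_draft_tokens.extend(
    0 for _ in range(max_draft_tokens - num_draft_tokens))
assert req.py_draft_tokens == [17, 23, 0, 0]  # fixed length keeps CUDA graph shapes stable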
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/speculative/interface.py
@@ -57,7 +57,8 @@ def needs_kv_cache_rewind(self):
return self.is_mtp() or self.is_eagle3_one_model() or self.is_ngram()

def support_overlap_scheduler(self):
return self.is_mtp() or self.is_eagle3_one_model()
return self.is_mtp() or self.is_eagle3_one_model(
) or self.has_draft_model()

def support_guided_decoder(self):
return self.is_none() or self.has_spec_drafter()