9 changes: 3 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -40,9 +40,6 @@ def __init__(self, engine: "PyTorchModelEngine"):
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config

self.max_possible_draft_len = (self.spec_config.max_draft_len
if self.enable_spec_decode else 0)

self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
@@ -58,7 +55,7 @@ def _create_shared_static_tensors(self):
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()

token_per_request = self.max_possible_draft_len + 1
token_per_request = self.draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
@@ -78,7 +75,7 @@ def _create_shared_static_tensors(self):

@property
def enable_spec_decode(self):
return self._get_engine().is_spec_decode
return self._get_engine().enable_spec_decode

@property
def draft_len(self):
@@ -174,7 +171,7 @@ def capture(self,
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = self.max_possible_draft_len + 1
token_per_request = self.draft_len + 1
num_tokens_for_capture = (batch_size * self.max_beam_width *
token_per_request)

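For context, a minimal sketch of the sizing logic after this change, assuming the `draft_len` property (whose body is outside the visible hunk) returns 0 while spec decode is dynamically disabled; `_GraphRunnerSketch` and the hard-coded batch size are illustrative, not the real class:

```python
# Sketch only: per-request token counts now track the engine's current draft
# length instead of a fixed `max_possible_draft_len`.
class _GraphRunnerSketch:

    def __init__(self, engine):
        self.engine = engine
        self.max_supported_batch_size = 8  # assumed value for illustration
        self.max_beam_width = engine.max_beam_width

    @property
    def enable_spec_decode(self):
        # Mirrors the rename in this diff: read `enable_spec_decode`,
        # not `is_spec_decode`, from the engine.
        return self.engine.enable_spec_decode

    @property
    def draft_len(self):
        # Assumed behavior: 0 while spec decode is currently off, otherwise
        # the configured maximum draft length.
        if not self.enable_spec_decode or self.engine.spec_config is None:
            return 0
        return self.engine.spec_config.max_draft_len

    def num_tokens_for_capture(self, batch_size, max_num_tokens):
        # Same formula as the capture path: pad each request to draft_len + 1
        # tokens, capped by the engine's max token budget.
        token_per_request = self.draft_len + 1
        total = batch_size * self.max_beam_width * token_per_request
        return min(total, max_num_tokens)
```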
17 changes: 12 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1198,13 +1198,19 @@ def _executor_loop_overlap(self):
previous_tensors = self.previous_batch and self.previous_batch.sample_state
target_inputs = None
draft_outputs = None
if self.drafter is not None and self.use_spec_decode:
# If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
# When there are accepted tokens, we can't directly use the previous batch's outputs for the target model in this iteration,
# so we set the target model's input to None and skip updating the target requests after the target model forward.
update_target_requests_before_forward = self.has_previous_draft_tokens
if self.drafter is not None and (
self.use_spec_decode
or update_target_requests_before_forward):
target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
scheduled_batch, previous_tensors)

# Use the draft_model's outputs if we've launched the draft model.
# Otherwise, use the previous batch's outputs.
if target_inputs is not None:
if target_inputs is not None or update_target_requests_before_forward:
previous_tensors_device = target_inputs
else:
previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
@@ -1215,7 +1221,7 @@
if target_inputs is not None:
self._process_draft_results(scheduled_batch,
draft_outputs, draft_batch)
elif self.previous_batch is not None:
elif self.previous_batch is not None and not update_target_requests_before_forward:
self._update_requests(self.previous_batch.sample_state)

if self.guided_decoder is not None:
@@ -1973,14 +1979,15 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
# If needed, the overlap should happen between the target requests and the draft requests.
# Otherwise, we can still do overlap between the previous target requests and the current target requests.
has_draft_batch = (
self.previous_batch is not None
self.previous_batch is not None and self.use_spec_decode
and self.drafter.should_forward_draft_model(scheduled_batch))

if has_draft_batch:
if has_draft_batch or self.has_previous_draft_tokens:
self._update_requests(self.previous_batch.sample_state)
if self.has_previous_draft_tokens:
self._prepare_draft_requests()

if has_draft_batch:
target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
scheduled_batch, self.resource_manager,
previous_tensors.device if previous_tensors else None)
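A pseudocode sketch of the overlap-loop control flow these two hunks produce; `executor` stands in for the real `PyExecutor`, and helpers such as `_forward_step` that are not visible in the diff are assumptions:

```python
def overlap_iteration(executor, scheduled_batch, previous_tensors):
    """Simplified pseudocode of the overlap loop body after this change."""
    target_inputs = draft_outputs = draft_batch = None
    # Draft tokens from the previous iteration force a request update before
    # the next target forward, even if spec decode was just switched off.
    update_before_forward = executor.has_previous_draft_tokens

    if executor.drafter is not None and (executor.use_spec_decode
                                         or update_before_forward):
        target_inputs, draft_outputs, draft_batch = (
            executor._handle_speculative_decoding(scheduled_batch,
                                                  previous_tensors))

    # Use the draft model's outputs when it ran; if the target requests were
    # already updated above, do not reuse the previous batch's device state.
    if target_inputs is not None or update_before_forward:
        previous_tensors_device = target_inputs
    else:
        previous_tensors_device = (executor.previous_batch
                                   and executor.previous_batch.sample_state
                                   and executor.previous_batch.sample_state.device)

    batch_outputs = executor._forward_step(scheduled_batch,
                                           previous_tensors_device)

    if target_inputs is not None:
        executor._process_draft_results(scheduled_batch, draft_outputs,
                                        draft_batch)
    elif executor.previous_batch is not None and not update_before_forward:
        # Requests were not updated before the forward, so update them now.
        executor._update_requests(executor.previous_batch.sample_state)
    return batch_outputs
```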
6 changes: 2 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -41,11 +41,9 @@
from .py_executor import PyExecutor


# Development flag to control chain drafter feature
# Development function to control chain drafter feature
def _get_allow_chain_drafter() -> bool:
"""Get the chain drafter flag from environment variable."""
# Use environment variable for cross-process compatibility
return os.getenv("TRTLLM_ALLOW_CHAIN_DRAFTER", "0") == "1"
return True


class _ExecutorCreationStage(enum.Enum):
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -563,6 +563,9 @@ def update_requests(self, state: SampleState) -> None:
if get_draft_token_length(req) > 0:
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
else:
req.py_num_accepted_draft_tokens = 0
req.py_rewind_len = 0
processed += num_accepted
self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
req.py_decoding_iter += 1
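A small sketch of why the new `else` branch matters, with a hypothetical request object standing in for the real `LlmRequest`; without the reset, counters left over from an earlier spec-decode iteration could leak into iterations where the request carries no draft tokens:

```python
# Hypothetical, simplified request bookkeeping.
def update_draft_counters(req, num_accepted: int, draft_token_length: int) -> None:
    if draft_token_length > 0:
        req.py_num_accepted_draft_tokens = num_accepted
        req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
    else:
        # New in this diff: explicitly clear stale values so downstream logic
        # sees zeros when drafting was skipped for this request.
        req.py_num_accepted_draft_tokens = 0
        req.py_rewind_len = 0
```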
10 changes: 5 additions & 5 deletions tensorrt_llm/_torch/speculative/model_drafter.py
@@ -396,6 +396,7 @@ def _convert_draft_tensors(
new_tokens_lens = None
next_draft_tokens = None
has_draft_tokens = False
batch_size = new_tokens.shape[1]
# Iterate through generation requests and copy tokens based on accepted draft tokens
for request in scheduled_batch.all_requests():
idx = request.py_seq_slot
@@ -411,9 +412,8 @@

if has_draft_tokens:
# We already updated the target state, so the new_tokens_lens should be all ones.
new_tokens_lens = torch.ones(scheduled_batch.batch_size,
device=device)
next_draft_tokens = torch.zeros(scheduled_batch.batch_size,
new_tokens_lens = torch.ones(batch_size, device=device)
next_draft_tokens = torch.zeros(batch_size,
self.max_draft_tokens,
device=device)

@@ -438,13 +438,13 @@ def _update_target_inputs_with_draft_tokens(
Update target inputs with new draft tokens from sample state.
"""
if draft_tensors is not None:
for request in draft_batch.all_requests():
for req_idx, request in enumerate(draft_batch.all_requests()):
# Skip prefill requests
if target_inputs.next_draft_tokens is None:
continue

# Get the index of the draft/target tokens in the device tensor
draft_idx = request.py_seq_slot
draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
target_idx = req_id_to_old_request[
request.py_request_id].py_seq_slot
target_inputs.new_tokens[draft_position + 1:draft_position +
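A sketch of the two indexing changes above, under assumed conventions: `new_tokens` is laid out with the batch on dimension 1, and the rationale comments are inferred from the change rather than stated in the diff:

```python
import torch


# Assumed layout: dim 1 of new_tokens is the batch dimension, so bookkeeping
# tensors are sized from the tensor itself rather than scheduled_batch.batch_size,
# which can differ when only a subset of requests ran the draft model.
def make_draft_bookkeeping(new_tokens: torch.Tensor, max_draft_tokens: int,
                           device: torch.device):
    batch_size = new_tokens.shape[1]
    new_tokens_lens = torch.ones(batch_size, device=device)
    next_draft_tokens = torch.zeros(batch_size, max_draft_tokens, device=device)
    return new_tokens_lens, next_draft_tokens


# Inferred rationale: with a static draft loop the draft outputs are packed in
# enumeration order, so the row index is the loop index; otherwise it is the
# request's persistent sequence slot.
def draft_row_index(req_idx: int, request, use_static_draft_loop: bool) -> int:
    return req_idx if use_static_draft_loop else request.py_seq_slot
```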
1 change: 1 addition & 0 deletions tests/unittest/_torch/helpers.py
@@ -188,6 +188,7 @@ def create_mock_engine(batch_size: int):
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
enable_spec_decode=False,
spec_config=None,
_cuda_graph_mem_pool=None,
use_mrope=False,
64 changes: 42 additions & 22 deletions tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
@@ -1,22 +1,32 @@
import os
import sys
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
yield


@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
@pytest.mark.high_cuda_memory
def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
def test_dynamic_spec_decode(enforce_single_worker,
disable_overlap_scheduler: bool):
# mock_should_use_spec_decode doesn't work with multiple processes,
# so we use the enforce_single_worker fixture to set the environment variable.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")
@@ -51,32 +61,42 @@ def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
eagle3_one_model=False,
)

# Mock should_use_spec_decode to return True for first two calls, then False
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
sampling_params = SamplingParams(max_tokens=128, temperature=0)

# Output tests
prompts = [
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=20, temperature=0)

# Mock should_use_spec_decode to turn on/off spec decode dynamically.
def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
max_draft_len):
if not hasattr(mock_should_use_spec_decode, 'call_count'):
mock_should_use_spec_decode.call_count = 0
mock_should_use_spec_decode.call_count += 1
return mock_should_use_spec_decode.call_count <= 2
for req in requests:
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue

mock_should_use_spec_decode.call_count += 1
# Turn off spec decode when we've called it 5 times.
# In the current case, at the 5th call, there are 2 accepted draft tokens,
# so we get better coverage of switching spec decode on and off.
if mock_should_use_spec_decode.call_count > 5:
return False
return True

# Create a Mock object with the mock function as side_effect
mock_should_use_spec_decode = Mock(side_effect=mock_should_use_spec_decode)
# Reset mock state before using it
mock_should_use_spec_decode.reset_mock()
mock_should_use_spec_decode.call_count = 0

with patch(
'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode',
side_effect=mock_should_use_spec_decode):
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
sampling_params = SamplingParams(max_tokens=128, temperature=0)

# Output tests
prompts = [
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0)

mock_should_use_spec_decode):
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [
result.outputs[0].text for result in results_spec
]
llm_spec.shutdown()
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
Expand Down