17 changes: 10 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -34,6 +34,7 @@ def __init__(
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
use_mrope: bool = False,
max_beam_width: int = 1,
) -> None:
"""
Stores a CUDA graph and its associated input buffers.
@@ -49,19 +50,21 @@ def __init__(
e.g. FlashInfer cause graph breaks).
"""
self.batch_size = batch_size

self.max_beam_width = max_beam_width
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1

# Using ones instead of zeros prevents NaNs in e.g. Deepseek
self.input_ids = torch.ones((batch_size * token_per_request, ),
device=device,
dtype=torch.int32)
self.position_ids = torch.zeros((1, batch_size * token_per_request),
device=device,
dtype=torch.int32)
self.input_ids = torch.ones(
(batch_size * max_beam_width * token_per_request, ),
device=device,
dtype=torch.int32)
self.position_ids = torch.zeros(
(1, batch_size * max_beam_width * token_per_request),
device=device,
dtype=torch.int32)
self.mrope_position_deltas = torch.zeros(
(batch_size,
1), device=device, dtype=torch.int32) if use_mrope else None
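The key change above is that the statically captured input buffers are now sized for the worst case of batch_size * max_beam_width * token_per_request, since a CUDA graph's input tensors cannot be reallocated between replays. Below is a minimal, illustrative sketch of that sizing rule; the helper name and signature are assumptions for clarity, not the TensorRT-LLM API.

import torch

def allocate_graph_inputs(batch_size: int,
                          max_beam_width: int = 1,
                          max_draft_len: int = 0,
                          device: str = "cuda"):
    # Illustrative sketch, not the DecodingCUDAGraphRunner implementation.
    # One slot per (request, beam, token); the buffers are captured once and
    # must cover the largest shape that will ever be replayed.
    token_per_request = max_draft_len + 1
    num_tokens = batch_size * max_beam_width * token_per_request
    # Ones (not zeros) for input IDs, mirroring the comment in the diff about
    # avoiding NaNs in models such as Deepseek.
    input_ids = torch.ones((num_tokens, ), device=device, dtype=torch.int32)
    position_ids = torch.zeros((1, num_tokens), device=device, dtype=torch.int32)
    return input_ids, position_ids

# For example, batch_size=8, max_beam_width=4, max_draft_len=3 gives
# 8 * 4 * 4 = 128 token slots per buffer.
ids, pos = allocate_graph_inputs(8, max_beam_width=4, max_draft_len=3, device="cpu")
assert ids.shape == (128, ) and pos.shape == (1, 128)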
15 changes: 9 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -842,8 +842,8 @@ def _get_padded_batch(
spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
batch_size = scheduled_requests.batch_size
# The number of sequences in the batch is the number of prompts times the beam width.
new_batch_size = batch_size * self.max_beam_width
new_batch_size = batch_size

if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
graph_batch_size = self.dist.tp_allgather(
[can_run_cuda_graph, batch_size])
@@ -977,8 +977,8 @@ def _maybe_get_cuda_graph(
self._cuda_graphs[batch_size] = {}

self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
self.use_mrope)
batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
self.max_beam_width)
return self._cuda_graphs[batch_size][draft_len]

def __del__(self) -> None:
@@ -1372,8 +1372,11 @@ def _prepare_tp_inputs(
gather_ids.append(len(position_ids) - 1)

request_ids.append(request.py_request_id)
gen_request_seq_slots.append(request.py_seq_slot)
request.py_batch_idx = request.py_seq_slot
# Do not add a gen_request_seq_slot for CUDA graph dummy requests
# to prevent access errors due to None values
if not request.is_cuda_graph_dummy:
gen_request_seq_slots.append(request.py_seq_slot)

previous_batch_len = len(previous_batch_indices)

@@ -1502,7 +1505,7 @@ def previous_seq_slots_device():
pin_memory=True,
)

num_generation_requests = len(scheduled_requests.generation_requests)
num_generation_requests = len(gen_request_seq_slots)
# Cache indirection is only used for beam search on generation requests
if self.use_beam_search and num_generation_requests > 0:
# The CUDA graph needs the beam width set during warmup (when the graph is captured) so that the cache indirection buffer is correctly picked up by the graph
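Two related fixes appear above: dummy generation requests that exist only to pad a CUDA graph batch no longer contribute a seq slot (their py_seq_slot can be None), and the cache-indirection logic is now sized by len(gen_request_seq_slots) rather than by the raw number of generation requests. A hedged sketch of that filtering follows, with SimpleRequest standing in for the real request objects (it is not the TensorRT-LLM class).

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class SimpleRequest:
    # Stand-in for the executor's request object; fields mirror the diff above.
    py_request_id: int
    py_seq_slot: Optional[int]
    is_cuda_graph_dummy: bool = False

def collect_gen_request_seq_slots(generation_requests: List[SimpleRequest]) -> List[int]:
    # Dummy requests used only to pad the batch carry no usable seq slot, so
    # they are skipped; downstream consumers (e.g. the beam-search cache
    # indirection copy) are then sized by the slots actually collected.
    return [
        req.py_seq_slot for req in generation_requests
        if not req.is_cuda_graph_dummy
    ]

requests = [SimpleRequest(0, 3), SimpleRequest(1, None, is_cuda_graph_dummy=True)]
assert collect_gen_request_seq_slots(requests) == [3]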
9 changes: 7 additions & 2 deletions tests/unittest/_torch/test_beam_search.py
@@ -63,7 +63,8 @@ def llm_cuda_graph(fixed_params, input_prompts):
enable_trtllm_sampler=True,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(),
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
enable_padding=True),
)


@@ -128,7 +129,7 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@pytest.mark.parametrize("gather_context_logits", [True, False])
@pytest.mark.parametrize("num_output_beams", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2, 3])
@pytest.mark.threadleak(enabled=False)
def test_beam_search_output_shapes_cuda_graph_and_overlap(
gather_context_logits: bool, gather_generation_logits: bool,
@@ -147,6 +148,10 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
)
# Test CUDA graph padding with 3 prompts:
# replicate the input prompts so that more than 2 prompts are available
if (num_prompts == 3 and len(input_prompts) == 2):
input_prompts = [input_prompts[0]] * 3
outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
sampling_params=sampling_params)
assert len(outputs) == num_prompts
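The fixture now captures graphs for batch sizes 1, 2, 4 and 8 with padding enabled, and the new num_prompts=3 case checks that a batch of 3 requests is padded up to the next captured size. The small helper below illustrates that padding rule; it is an assumption for clarity, not the scheduler's implementation.

def padded_graph_batch_size(batch_size: int, captured_sizes=(1, 2, 4, 8)):
    # Pick the smallest captured CUDA graph batch size that fits the batch;
    # the extra slots are filled with dummy requests.
    for size in sorted(captured_sizes):
        if size >= batch_size:
            return size
    return None  # no captured graph is large enough; fall back to eager mode

assert padded_graph_batch_size(3) == 4
assert padded_graph_batch_size(2) == 2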