diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 241fc1447ce..50306d66a66 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -34,6 +34,7 @@ def __init__(
         attn_metadata: AttentionMetadata,
         spec_metadata: Optional[SpecMetadata] = None,
         use_mrope: bool = False,
+        max_beam_width: int = 1,
     ) -> None:
         """
         Stores a CUDA graph and its associated input buffers.
@@ -49,19 +50,21 @@ def __init__(
             e.g. FlashInfer cause graph breaks).
         """
         self.batch_size = batch_size
-
+        self.max_beam_width = max_beam_width
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
         token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
 
         # Using ones instead of zeros prevents NaNs in e.g. Deepseek
-        self.input_ids = torch.ones((batch_size * token_per_request, ),
-                                    device=device,
-                                    dtype=torch.int32)
-        self.position_ids = torch.zeros((1, batch_size * token_per_request),
-                                        device=device,
-                                        dtype=torch.int32)
+        self.input_ids = torch.ones(
+            (batch_size * max_beam_width * token_per_request, ),
+            device=device,
+            dtype=torch.int32)
+        self.position_ids = torch.zeros(
+            (1, batch_size * max_beam_width * token_per_request),
+            device=device,
+            dtype=torch.int32)
         self.mrope_position_deltas = torch.zeros(
             (batch_size, 1), device=device,
             dtype=torch.int32) if use_mrope else None
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 247f6da1754..885db68f898 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -842,8 +842,8 @@ def _get_padded_batch(
             spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
-        # The number of sequences in the batch is the number of prompts times the beam width.
-        new_batch_size = batch_size * self.max_beam_width
+        new_batch_size = batch_size
+
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
             graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
@@ -977,8 +977,8 @@ def _maybe_get_cuda_graph(
             self._cuda_graphs[batch_size] = {}
 
         self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
-            num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
-            self.use_mrope)
+            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope,
+            self.max_beam_width)
         return self._cuda_graphs[batch_size][draft_len]
 
     def __del__(self) -> None:
@@ -1372,8 +1372,11 @@ def _prepare_tp_inputs(
                 gather_ids.append(len(position_ids) - 1)
 
             request_ids.append(request.py_request_id)
-            gen_request_seq_slots.append(request.py_seq_slot)
             request.py_batch_idx = request.py_seq_slot
+            # Do not add a gen_request_seq_slot for CUDA graph dummy requests
+            # to prevent access errors due to None values
+            if not request.is_cuda_graph_dummy:
+                gen_request_seq_slots.append(request.py_seq_slot)
 
         previous_batch_len = len(previous_batch_indices)
 
@@ -1502,7 +1505,7 @@ def previous_seq_slots_device():
                 pin_memory=True,
             )
 
-        num_generation_requests = len(scheduled_requests.generation_requests)
+        num_generation_requests = len(gen_request_seq_slots)
         # Cache indirection is only used for beam search on generation requests
         if self.use_beam_search and num_generation_requests > 0:
             # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
diff --git a/tests/unittest/_torch/test_beam_search.py b/tests/unittest/_torch/test_beam_search.py
index 1b417ef284c..569add6a58c 100644
--- a/tests/unittest/_torch/test_beam_search.py
+++ b/tests/unittest/_torch/test_beam_search.py
@@ -63,7 +63,8 @@ def llm_cuda_graph(fixed_params, input_prompts):
         enable_trtllm_sampler=True,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        cuda_graph_config=CudaGraphConfig(),
+        cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
+                                          enable_padding=True),
     )
 
 
@@ -128,7 +129,7 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
 @pytest.mark.parametrize("gather_generation_logits", [True, False])
 @pytest.mark.parametrize("gather_context_logits", [True, False])
 @pytest.mark.parametrize("num_output_beams", [1, 2])
-@pytest.mark.parametrize("num_prompts", [1, 2])
+@pytest.mark.parametrize("num_prompts", [1, 2, 3])
 @pytest.mark.threadleak(enabled=False)
 def test_beam_search_output_shapes_cuda_graph_and_overlap(
         gather_context_logits: bool, gather_generation_logits: bool,
@@ -147,6 +148,10 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
+    # test padding of cuda graph with 3 prompts
+    # replicate the prompts to have more than 2 prompts available
+    if (num_prompts == 3 and len(input_prompts) == 2):
+        input_prompts = [input_prompts[0]] * 3
     outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
                                       sampling_params=sampling_params)
     assert len(outputs) == num_prompts
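For context, the sizing rule these changes introduce is: the CUDA graph batch size is keyed on the number of prompts, while the beam width is folded into the static input buffers (batch_size * max_beam_width * token_per_request tokens), and a runtime batch is padded up to the nearest captured batch size. The sketch below illustrates that arithmetic only, under the assumptions stated in its comments; the helpers graph_input_shapes and pick_graph_batch_size are hypothetical and are not part of the TensorRT-LLM API.

# Standalone sketch of the buffer-sizing rule in this diff. Helper names
# (graph_input_shapes, pick_graph_batch_size) are illustrative only and do
# not exist in TensorRT-LLM.
from typing import Optional, Sequence, Tuple

import torch


def graph_input_shapes(batch_size: int,
                       max_beam_width: int = 1,
                       max_draft_len: Optional[int] = None
                       ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Shapes of the captured graph's input_ids / position_ids buffers.

    Mirrors DecodingCUDAGraphRunner.__init__ above: each request contributes
    one token per beam, plus max_draft_len extra tokens when speculative
    decoding metadata is present.
    """
    token_per_request = (max_draft_len + 1) if max_draft_len is not None else 1
    num_tokens = batch_size * max_beam_width * token_per_request
    return (num_tokens, ), (1, num_tokens)


def pick_graph_batch_size(runtime_batch_size: int,
                          captured_batch_sizes: Sequence[int]) -> int:
    """Smallest captured batch size the runtime batch can be padded to."""
    candidates = [b for b in sorted(captured_batch_sizes)
                  if b >= runtime_batch_size]
    if not candidates:
        raise ValueError("runtime batch exceeds every captured graph size")
    return candidates[0]


if __name__ == "__main__":
    # 3 prompts, beam width 2, no speculative decoding: the batch is padded
    # to the graph captured for batch size 4, whose input_ids buffer holds
    # 4 * 2 tokens.
    graph_bs = pick_graph_batch_size(3, captured_batch_sizes=[1, 2, 4, 8])
    ids_shape, pos_shape = graph_input_shapes(graph_bs, max_beam_width=2)
    input_ids = torch.ones(ids_shape, dtype=torch.int32)    # ones avoid NaNs
    position_ids = torch.zeros(pos_shape, dtype=torch.int32)
    print(graph_bs, input_ids.shape, position_ids.shape)
    # -> 4 torch.Size([8]) torch.Size([1, 8])

This mirrors the updated test fixture: with batch_sizes=[1, 2, 4, 8] and enable_padding=True, the new 3-prompt beam-search case exercises exactly this padding path.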