diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 0e104185bc2..0c649b34374 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1224,6 +1224,11 @@ def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
         return None
 
+    def _get_all_rank_ctx_requests(self, num_ctx_requests: int):
+        if self.enable_attention_dp:
+            return list(self.dist.tp_allgather(num_ctx_requests))
+        return None
+
     def _get_padding_params(
             self, total_num_tokens: int, num_ctx_requests: int,
             attn_all_rank_num_tokens: Optional[List[int]]
@@ -1237,6 +1242,9 @@ def _get_padding_params(
         """
         padded_num_tokens = total_num_tokens
 
+        all_rank_ctx_requests = self._get_all_rank_ctx_requests(
+            num_ctx_requests)
+
         def get_padded_piecewise_tokens(tokens):
             captured_num_tokens = self._torch_compile_backend.capture_num_tokens
             return captured_num_tokens[bisect.bisect_left(
@@ -1249,7 +1257,12 @@ def get_padded_piecewise_tokens(tokens):
                 -1]
             # Torch piecewise cuda graph is enabled.
             if attn_all_rank_num_tokens is not None:
-                can_run_piecewise_cuda_graph = (num_ctx_requests != 0 and
+                # If any rank has context requests, enable the piecewise cuda graph.
+                has_ctx_requests = num_ctx_requests != 0 or (
+                    all_rank_ctx_requests is not None
+                    and any(ctx_requests != 0
+                            for ctx_requests in all_rank_ctx_requests))
+                can_run_piecewise_cuda_graph = (has_ctx_requests and
                                                 max(attn_all_rank_num_tokens)
                                                 <= max_captured_num_tokens)
                 all_ranks_can_run_piecewise_cuda_graph = list(
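
For reference, the following is a minimal, self-contained sketch of the decision this patch introduces, not the TensorRT-LLM implementation itself. Here `all_rank_ctx_requests` stands in for the result of `self.dist.tp_allgather(num_ctx_requests)` under attention DP, and the standalone function name is illustrative only.

# Sketch (assumed standalone helper, not the patched method) of the per-rank
# decision shown in the hunk above. `all_rank_ctx_requests` models the
# allgathered per-rank context-request counts.
from typing import List, Optional


def can_run_piecewise_cuda_graph(
        num_ctx_requests: int,
        all_rank_ctx_requests: Optional[List[int]],
        attn_all_rank_num_tokens: List[int],
        max_captured_num_tokens: int) -> bool:
    # Enable the piecewise cuda graph if this rank, or any rank seen via the
    # allgather, has context requests, and no rank exceeds the largest
    # captured token count.
    has_ctx_requests = num_ctx_requests != 0 or (
        all_rank_ctx_requests is not None
        and any(ctx != 0 for ctx in all_rank_ctx_requests))
    return (has_ctx_requests
            and max(attn_all_rank_num_tokens) <= max_captured_num_tokens)


# Rank 0 has no local context requests, but rank 1 does. The old purely local
# check (num_ctx_requests != 0) would return False on rank 0; with the
# allgathered counts both ranks agree to enable the piecewise cuda graph.
assert can_run_piecewise_cuda_graph(0, [0, 2], [128, 256], 512)
assert not can_run_piecewise_cuda_graph(0, [0, 0], [128, 256], 512)

The motivation appears to be keeping the decision consistent across attention-DP ranks, mirroring the existing all_ranks_can_run_piecewise_cuda_graph allgather visible at the end of the last hunk.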