4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -87,6 +87,10 @@ class PyTorchConfig:
enable_min_latency: bool = False
allreduce_strategy: str = "AUTO"

# The iteration interval to create responses in streaming mode.
# TODO: make this a per-request parameter
stream_interval: int = 1


EXETENDED_EXECUTOR_CONFIG_FIELDS = [
'backend',
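For context, a minimal sketch of the new knob as a standalone dataclass; the class name below is illustrative, and only stream_interval and its two neighbouring fields are taken from the diff above.

from dataclasses import dataclass

@dataclass
class _PyTorchConfigSketch:  # illustrative stand-in for PyTorchConfig
    enable_min_latency: bool = False
    allreduce_strategy: str = "AUTO"
    # Create streaming responses only on every stream_interval-th iteration.
    # The default of 1 keeps the previous behavior of one response per iteration.
    stream_interval: int = 1

assert _PyTorchConfigSketch().stream_interval == 1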
19 changes: 12 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -204,6 +204,7 @@ def __init__(self,
self.print_log = model_engine.pytorch_backend_config.print_iter_log
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
self.stream_interval = model_engine.pytorch_backend_config.stream_interval
self.num_fetch_requests_cur_rank = 0
self.num_fetch_requests = 0
self.shutdown_event = threading.Event()
@@ -1607,7 +1608,7 @@ def _forward_step(self,
new_tensors_device: Optional[SampleStateTensors] = None):

@nvtx_range(
f"[Executor] _forward_step: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
f"[Executor] _forward_step {self.model_engine.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
)
def forward(scheduled_requests, resource_manager, new_tensors_device,
gather_context_logits):
@@ -1979,7 +1980,7 @@ def _enqueue_responses(self, responses: Dict[int, ExecutorResponse]):

logger.debug(
f'before gather, rank = {self.dist.rank}, responses = {responses}')
if self.enable_attention_dp:
if self.enable_attention_dp and self.dist.world_size != 1:
if not self.gather_all_responses:
responses_list = self.dist.tp_gather(responses)
else:
@@ -2042,11 +2043,14 @@ def _handle_responses(self):

request.draft_tokens = request.py_draft_tokens
request.decoding_iter = request.py_decoding_iter
response: Response = request.create_response(False, self.dist.rank)

request_done = False
if response:
request_done = response.result.is_final
new_responses.update({req_id: response})
if self.model_engine.iter_counter % self.stream_interval == 0 or request.is_finished:
response = request.create_response(False, self.dist.rank)
if response:
request_done = response.result.is_final
new_responses.update({req_id: response})

if request_done:
if request.is_disagg_context_transmission_state:
self.ctx_in_transmission_requests.append(request)
@@ -2055,7 +2059,8 @@ def _handle_responses(self):
else:
new_active_requests.append(request)
self.active_requests = new_active_requests
self._enqueue_responses(new_responses)
if len(new_responses) > 0:
self._enqueue_responses(new_responses)
for request in requests_to_terminate:
self._terminate_request(request)
return requests_to_terminate
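The _handle_responses change above creates a response only when the iteration counter hits the interval or the request has finished. A standalone sketch of that gating rule, assuming the same semantics as the diff (the helper name is hypothetical; iter_counter, stream_interval, and is_finished mirror the fields used above):

def should_emit_response(iter_counter: int, stream_interval: int, is_finished: bool) -> bool:
    # Emit on every stream_interval-th iteration, and always when the request
    # finishes, so the final tokens are never held back by the interval.
    return iter_counter % stream_interval == 0 or is_finished

# With stream_interval=4, iterations 4, 8, 12, ... produce streaming responses;
# the other iterations do not enqueue a response for in-flight requests.
assert should_emit_response(8, 4, is_finished=False)
assert not should_emit_response(9, 4, is_finished=False)
assert should_emit_response(9, 4, is_finished=True)

The companion change at the end of the hunk calls _enqueue_responses only when new_responses is non-empty, so iterations that create no responses also skip the gather work inside _enqueue_responses.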
10 changes: 9 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -1723,6 +1723,13 @@ class TorchLlmArgs(BaseLlmArgs):
"If true, enable min-latency mode. Currently only used for Llama4.",
)

# TODO: make this a per-request parameter
stream_interval: int = Field(
default=1,
description=
"The iteration interval to create responses under the streaming mode.",
)

# TODO: remove backend later
@field_validator('backend', mode='before')
def init_backend(cls, v):
@@ -1812,7 +1819,8 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
autotuner_enabled=self.autotuner_enabled,
enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
load_format=self.load_format,
enable_min_latency=self.enable_min_latency)
enable_min_latency=self.enable_min_latency,
stream_interval=self.stream_interval)

@field_validator('cuda_graph_max_batch_size')
@classmethod
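A hedged usage sketch of the new argument through the LLM API. LLM and SamplingParams are the public entry points and get_pytorch_backend_config() forwards stream_interval to PyTorchConfig, but the exact call pattern below (import path, model path, async streaming loop) is an assumption rather than part of this diff.

import asyncio
from tensorrt_llm import LLM, SamplingParams

async def main() -> None:
    # With stream_interval=4, streamed partial results arrive roughly every
    # 4th decoder iteration instead of every iteration; the final result is
    # always delivered when the request finishes.
    llm = LLM(model="/path/to/model", stream_interval=4)
    params = SamplingParams(max_tokens=64)
    async for partial in llm.generate_async("Hello, my name is",
                                            sampling_params=params,
                                            streaming=True):
        print(partial.outputs[0].text)

asyncio.run(main())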
3 changes: 3 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
@@ -91,6 +91,9 @@ methods:
postprocess_tokenizer_dir:
annotation: Optional[str]
default: null
stream_interval:
annotation: int
default: 1
# reasoning
reasoning_parser:
annotation: Optional[str]