4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -87,6 +87,10 @@ class PyTorchConfig:
enable_min_latency: bool = False
allreduce_strategy: str = "AUTO"

# The iteration interval at which responses are created in streaming mode.
# TODO: make this a per-request parameter
stream_interval: int = 1


EXETENDED_EXECUTOR_CONFIG_FIELDS = [
'backend',
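For context, a brief illustration of what the new field controls (this sketch is not part of the change): with `stream_interval = N`, streamed responses are created only on every N-th executor iteration, plus immediately whenever a request finishes, so the default of 1 keeps the existing per-iteration behavior.

```python
# Illustrative only (not code from this change): which executor iterations would
# emit streamed responses for a given stream_interval value.
def emitting_iterations(num_iterations, stream_interval):
    return [i for i in range(1, num_iterations + 1) if i % stream_interval == 0]

print(emitting_iterations(10, 1))  # [1, 2, ..., 10] -- every iteration, the prior behavior
print(emitting_iterations(10, 4))  # [4, 8] -- responses batched to every 4th iteration
```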
19 changes: 12 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -204,6 +204,7 @@ def __init__(self,
self.print_log = model_engine.pytorch_backend_config.print_iter_log
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
self.stream_interval = model_engine.pytorch_backend_config.stream_interval
self.num_fetch_requests_cur_rank = 0
self.num_fetch_requests = 0
self.shutdown_event = threading.Event()
@@ -1607,7 +1608,7 @@ def _forward_step(self,
new_tensors_device: Optional[SampleStateTensors] = None):

@nvtx_range(
f"[Executor] _forward_step: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
f"[Executor] _forward_step {self.model_engine.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
)
def forward(scheduled_requests, resource_manager, new_tensors_device,
gather_context_logits):
@@ -1981,7 +1982,7 @@ def _enqueue_responses(self, responses: Dict[int, LlmResponse]):

logger.debug(
f'before gather, rank = {self.dist.rank}, responses = {responses}')
if self.enable_attention_dp:
if self.enable_attention_dp and self.dist.world_size != 1:
if not self.gather_all_responses:
responses_list = self.dist.tp_gather(responses)
else:
@@ -2044,11 +2045,14 @@ def _handle_responses(self):

request.draft_tokens = request.py_draft_tokens
request.decoding_iter = request.py_decoding_iter
response = request.create_response(False, self.dist.rank)

request_done = False
if response:
request_done = response.result.is_final
new_responses.update({req_id: response})
if self.model_engine.iter_counter % self.stream_interval == 0 or request.is_finished:
response = request.create_response(False, self.dist.rank)
if response:
request_done = response.result.is_final
new_responses.update({req_id: response})

if request_done:
if request.is_disagg_context_transmission_state:
self.ctx_in_transmission_requests.append(request)
@@ -2057,7 +2061,8 @@
else:
new_active_requests.append(request)
self.active_requests = new_active_requests
self._enqueue_responses(new_responses)
if len(new_responses) > 0:
self._enqueue_responses(new_responses)
for request in requests_to_terminate:
self._terminate_request(request)
return requests_to_terminate
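A condensed, standalone sketch of the gating that `_handle_responses` now applies; `FakeRequest` and the dict-based responses below are stand-ins, and only the control flow mirrors the diff: a response is created only when the iteration counter lands on the interval or the request has finished, and the enqueue step is skipped entirely when nothing was produced.

```python
# Simplified stand-in for the gating added to _handle_responses; FakeRequest and
# the dict-based responses are hypothetical, only the control flow mirrors the diff.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FakeRequest:
    req_id: int
    is_finished: bool = False

    def create_response(self) -> Optional[dict]:
        return {"req_id": self.req_id, "final": self.is_finished}


def handle_responses(requests: List[FakeRequest], iter_counter: int,
                     stream_interval: int) -> Dict[int, dict]:
    new_responses: Dict[int, dict] = {}
    for request in requests:
        # Create a response only on every stream_interval-th iteration,
        # or immediately once the request has finished.
        if iter_counter % stream_interval == 0 or request.is_finished:
            response = request.create_response()
            if response:
                new_responses[request.req_id] = response
    # Mirrors the diff: skip the (potentially collective) enqueue when empty.
    if new_responses:
        print(f"iter {iter_counter}: enqueuing {len(new_responses)} responses")
    return new_responses


reqs = [FakeRequest(1), FakeRequest(2, is_finished=True)]
handle_responses(reqs, iter_counter=3, stream_interval=4)  # only the finished request responds
handle_responses(reqs, iter_counter=4, stream_interval=4)  # both requests respond
```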
9 changes: 7 additions & 2 deletions tensorrt_llm/evaluate/interface.py
@@ -78,7 +78,8 @@ def _get_sampline_params(self, sampling_params: Optional[SamplingParams],

def evaluate(self,
llm: Any,
sampling_params: Optional[SamplingParams] = None) -> float:
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False) -> float:
profiler.start("trtllm exec")
outputs, references, auxiliaries = [], [], []
for prompt, sampling_args, reference, *aux in tqdm(
@@ -87,7 +88,11 @@ def evaluate(self,
prompt = self.do_apply_chat_template(llm, prompt)
sampling_params = self._get_sampline_params(sampling_params,
sampling_args)
output = llm.generate_async(prompt, sampling_params)
output = llm.generate_async(
prompt,
sampling_params,
streaming=streaming,
)
outputs.append(output)
references.append(reference)
auxiliaries.append(aux)
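The evaluator keeps its submit-then-collect shape: each prompt is issued asynchronously, now optionally in streaming mode, and the handles are resolved after the loop. Below is a generic sketch of that pattern using `concurrent.futures` as a stand-in for the LLM's async handles; `run_model` is a placeholder, not a library call.

```python
# Generic submit-then-collect sketch; run_model is a placeholder standing in for
# llm.generate_async, and the futures stand in for its async request handles.
from concurrent.futures import ThreadPoolExecutor


def run_model(prompt, streaming=False):
    return f"completion of {prompt!r} (streaming={streaming})"


prompts = ["What is 2 + 2?", "Name a prime number."]
with ThreadPoolExecutor() as pool:
    # Submit every request first, as the evaluate() loop does when appending outputs...
    futures = [pool.submit(run_model, p, streaming=True) for p in prompts]
    # ...then resolve the handles once all requests are in flight.
    results = [f.result() for f in futures]
print(results)
```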
10 changes: 9 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -1723,6 +1723,13 @@ class TorchLlmArgs(BaseLlmArgs):
"If true, enable min-latency mode. Currently only used for Llama4.",
)

# TODO: make this a per-request parameter
stream_interval: int = Field(
default=1,
description=
"The iteration interval to create responses under the streaming mode.",
)

# TODO: remove backend later
@field_validator('backend', mode='before')
def init_backend(cls, v):
@@ -1812,7 +1819,8 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
autotuner_enabled=self.autotuner_enabled,
enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
load_format=self.load_format,
enable_min_latency=self.enable_min_latency)
enable_min_latency=self.enable_min_latency,
stream_interval=self.stream_interval)

@field_validator('cuda_graph_max_batch_size')
@classmethod
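Taken together with the executor changes above, the new argument flows from the LLM constructor into `PyTorchConfig`. A hedged usage sketch follows; the model path and prompt are placeholders, and consuming the streamed output is left out because it is not shown in this diff.

```python
# Usage sketch: the model path and prompt are placeholders; how the streamed
# output is consumed downstream is intentionally omitted.
from tensorrt_llm import LLM, SamplingParams

llm = LLM("/path/to/Meta-Llama-3.1-8B", stream_interval=4)  # respond every 4th iteration
sampling_params = SamplingParams(max_tokens=64)
handle = llm.generate_async("Hello, my name is", sampling_params, streaming=True)
```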
5 changes: 3 additions & 2 deletions tests/integration/defs/accuracy/accuracy_core.py
@@ -147,7 +147,8 @@ def evaluate(self,
llm: Union[LLM, PyTorchLLM],
extra_acc_spec: Optional[str] = None,
extra_evaluator_kwargs: Optional[dict] = None,
sampling_params: Optional[SamplingParams] = None):
sampling_params: Optional[SamplingParams] = None,
streaming: bool = False):
assert self.EVALUATOR_CLS is not None

if llm.args.speculative_config is None:
@@ -193,7 +194,7 @@
evaluator_kwargs.update(extra_evaluator_kwargs)
evaluator = self.EVALUATOR_CLS(num_samples=num_samples,
**evaluator_kwargs)
accuracy = evaluator.evaluate(llm, sampling_params)
accuracy = evaluator.evaluate(llm, sampling_params, streaming)
if self.HIGHER_IS_BETTER:
assert accuracy >= threshold, f"Expected accuracy >= {threshold}, but got {accuracy}."
else:
11 changes: 11 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -53,6 +53,17 @@ def test_nvfp4(self):
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_blackwell
def test_nvfp4_streaming(self):
model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
with LLM(model_path, stream_interval=4) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm, streaming=True)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm, streaming=True)


class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
3 changes: 3 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
@@ -91,6 +91,9 @@ methods:
postprocess_tokenizer_dir:
annotation: Optional[str]
default: null
stream_interval:
annotation: int
default: 1
# reasoning
reasoning_parser:
annotation: Optional[str]