
Commit 79697c2

Merge branch 'main' into fix/circular_import_with_torch_models
Signed-off-by: rakib-hasan <[email protected]>
2 parents 219ecd8 + bcf5ec0

File tree

18 files changed: +442, -21 lines

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@ nvidia-modelopt[torch]~=0.33.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
 transformers==4.55.0
+prometheus_client
+prometheus_fastapi_instrumentator
 pydantic>=2.9.1
 pydantic-settings[yaml]
 omegaconf
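
The two new requirements back a Prometheus metrics path: prometheus_client provides the multiprocess-capable metric primitives, and prometheus_fastapi_instrumentator exposes them over HTTP. As a rough, hedged sketch of how the instrumentator is typically wired onto a FastAPI app (the app object here is illustrative, not taken from this commit):

# Hypothetical wiring sketch -- not part of this commit.
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

app = FastAPI()

# Instrument all routes and expose them on GET /metrics.
Instrumentator().instrument(app).expose(app)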

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 6 additions & 0 deletions
@@ -250,6 +250,12 @@ def deserialize(self):
         self._result = tensorrt_llm.bindings.executor.deserialize_result(
             self._result)
 
+    def get_result(self):
+        if tmp_res := tensorrt_llm.bindings.executor.deserialize_result(
+                self._result):
+            return tmp_res
+        return None
+
 
 @dataclass
 class LlmResponse:
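
Unlike deserialize(), which overwrites self._result in place, the new get_result() decodes on demand and leaves the stored payload untouched; _get_metrics_dict() in worker.py relies on that. A minimal, self-contained sketch of the same pattern using pickle as a stand-in for the executor serializer (class and field names are illustrative only):

import pickle
from typing import Any, Optional


class SerializedResult:
    """Toy analogue of a result object holding a serialized payload."""

    def __init__(self, payload: Any):
        self._result = pickle.dumps(payload)  # stored serialized

    def deserialize(self) -> None:
        # In-place: replaces the serialized bytes with the decoded object.
        self._result = pickle.loads(self._result)

    def get_result(self) -> Optional[Any]:
        # On-demand: decode a copy, keep self._result serialized.
        if isinstance(self._result, bytes):
            return pickle.loads(self._result)
        return None


res = SerializedResult({"ttft": 0.12})
print(res.get_result())  # {'ttft': 0.12}; self._result is still bytes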

tensorrt_llm/_utils.py

Lines changed: 15 additions & 0 deletions
@@ -20,6 +20,7 @@
 import math
 import os
 import struct
+import tempfile
 import trace
 import weakref
 from contextlib import contextmanager
@@ -1112,3 +1113,17 @@ def is_multi_device_enable():
         the number of devices
     """
     return local_mpi_size() > 1
+
+
+def set_prometheus_multiproc_dir() -> object:
+    # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
+    global prometheus_multiproc_dir
+    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
+        logger.info("User set PROMETHEUS_MULTIPROC_DIR detected.")
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
+            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"])
+    else:
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+    logger.info(
+        f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")

tensorrt_llm/executor/postproc_worker.py

Lines changed: 15 additions & 7 deletions
@@ -3,7 +3,7 @@
 from collections import deque
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
-                    Optional)
+                    Optional, Union)
 
 import zmq
 import zmq.asyncio
@@ -18,7 +18,7 @@
 
 if TYPE_CHECKING:
     from .result import (DetokenizedGenerationResultBase, GenerationResult,
-                         GenerationResultBase)
+                         GenerationResultBase, ResponseWrapper)
 
 __all__ = [
     "PostprocWorker",
@@ -57,7 +57,7 @@ class PostprocWorker:
 
     @dataclass
     class Input:
-        rsp: "tllm.Response"
+        rsp: Union["tllm.Response", "ResponseWrapper"]
 
         # The information necessary for creating a GenerationResult in the first Input for each request
         sampling_params: Optional[SamplingParams] = None
@@ -69,6 +69,7 @@ class Output(NamedTuple):
         res: Any
         is_final: bool
         error: str = ""
+        metrics: Optional[dict[str, float]] = None
 
     def __init__(
        self,
@@ -118,7 +119,9 @@ def default_record_creator(
             streaming=inp.streaming,
             tokenizer=tokenizer)
 
-    async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
+    async def _handle_input(
+            self, input: Union["PostprocWorker.Input", "ResponseWrapper"]
+    ) -> [Any, Optional[dict[str, float]]]:
         ''' Handle a single response from await_response worker. '''
         if input.rsp.result.context_logits is not None or \
            input.rsp.result.generation_logits is not None:
@@ -139,6 +142,7 @@ async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
             record._handle_response(input.rsp)  # inplace
             # Left the result_handler determine the final output dtype.
             # NOTE: This will change the CompletionOutput._postprocess_result
+            metrics_dict = record.metrics_dict
             if postproc_params := record.postproc_params:
                 result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
                 args.tokenizer = self._tokenizer
@@ -150,7 +154,7 @@ async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
 
             # TODO: Keep only the diff token_ids and text in streaming mode when
             # result_handler is not set
-            return out
+            return out, metrics_dict
 
     async def _batched_put(self):
         ''' Batched IPC send. '''
@@ -173,8 +177,12 @@ async def handle_single_input(inp: PostprocWorker.Input,
             client_id = inp.rsp.client_id
             is_final = inp.rsp.result.is_final if is_llm_response(
                 inp.rsp) else True
-            res = await self._handle_input(inp)
-            batch.append(PostprocWorker.Output(client_id, res, is_final))
+            res, metrics = await self._handle_input(inp)
+            batch.append(
+                PostprocWorker.Output(client_id=client_id,
+                                      res=res,
+                                      is_final=is_final,
+                                      metrics=metrics))
             if is_final:
                 self._records.pop(client_id)

tensorrt_llm/executor/result.py

Lines changed: 79 additions & 6 deletions
@@ -15,6 +15,7 @@
 from ..disaggregated_params import DisaggregatedParams
 from ..llmapi.tracer import global_tracer
 from ..llmapi.utils import AsyncQueue
+from ..metrics import MetricNames, MetricsCollector, RequestEventTiming
 from ..sampling_params import LogprobParams, SamplingParams
 from .utils import ErrorResponse, has_event_loop, is_llm_response
 
@@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple):
 
 
 class ResponseWrapper:
-    """Wrapper of runtime response with optional outputs computed post runtime.
+    """
+    1. Wrapper of runtime response with optional outputs computed post runtime.
+    2. A workaround to pass around RequestPerfMetrics.
     """
 
     def __init__(self,
                  response: Union["PostprocWorker.Output", tllm.Response],
-                 logprobs: Optional[LogProbsResult] = None):
+                 logprobs: Optional[LogProbsResult] = None,
+                 request_perf_metrics: Optional[dict[str, float]] = None):
         self._response = response
         self.logprobs = logprobs
+        self.request_perf_metrics = request_perf_metrics
 
     @property
     def _is_llm_response(self):
@@ -68,6 +73,14 @@ def __getattr__(self, name):
         response = object.__getattribute__(self, '_response')
         return getattr(response, name)
 
+    def __getstate__(self):
+        return (self._response, self.logprobs, self.request_perf_metrics)
+
+    def __setstate__(self, state):
+        self._response = state[0]
+        self.logprobs = state[1]
+        self.request_perf_metrics = state[2]
+
 
 @dataclass(slots=True)
 class CompletionOutput:
@@ -146,6 +159,7 @@ def __init__(self,
         self.disaggregated_params = None
         self.decoding_iter = 0
         self._done = False
+        self.metrics_dict = {}
 
         if has_event_loop():
             self.aqueue = AsyncQueue()
@@ -201,7 +215,9 @@ def _handle_sequence(self,
                          finish_reasons,
                          response_tensors,
                          sequence_index,
-                         logprobs_result=None):
+                         logprobs_result=None,
+                         req_perf_metrics_dict: Optional[dict[str,
+                                                              float]] = None):
         """ Handle a single sequence in the response. """
 
         seq_idx = sequence_index
@@ -271,14 +287,17 @@ def _handle_sequence(self,
         else:
             raise ValueError(
                 f"Unknown finish reason: {finish_reasons[src_idx]}")
+        self.record_stats(output, req_perf_metrics_dict)
 
     @nvtx_range_debug("handle_response",
                       color="red",
                       category="GenerationResultBase")
     def _handle_response(self,
                          response: Union["PostprocWorker.Output", tllm.Response,
                                          ResponseWrapper, ErrorResponse]):
+        req_perf_metrics_dict = None
         if isinstance(response, ResponseWrapper):
+            req_perf_metrics_dict = response.request_perf_metrics
            logprobs_result = response.logprobs
            response = response._response
        else:
@@ -291,6 +310,8 @@ def _handle_response(self,
                 self._outputs[0] = response.res
             else:
                 self._outputs[0]._postprocess_result = response.res
+            if response.metrics:
+                self.metrics_dict = response.metrics
 
             if response.error:
                 if self._background_error_handler is not None and (
@@ -303,7 +324,8 @@ def _handle_response(self,
                     handler(response.error_msg)
 
             response_result = response.result
-            if hasattr(response_result, "_result"):
+            if hasattr(response_result, "_result") and isinstance(
+                    response_result._result, bytes):
                 response_result.deserialize()
 
             self._done = response_result.is_final
@@ -322,11 +344,12 @@ def _handle_response(self,
             if self.sampling_params.use_beam_search:
                 for beam_idx, _ in enumerate(response_result.output_token_ids):
                     self._handle_sequence(finish_reasons, response_result,
-                                          beam_idx, logprobs_result)
+                                          beam_idx, logprobs_result,
+                                          req_perf_metrics_dict)
             else:
                 self._handle_sequence(finish_reasons, response_result,
                                       response_result.sequence_index,
-                                      logprobs_result)
+                                      logprobs_result, req_perf_metrics_dict)
 
             if response_result.context_logits is not None:
                 self._context_logits = response_result.context_logits
@@ -342,6 +365,29 @@ def _handle_response(self,
         else:
             raise ValueError(f"Unknown response type: {response}")
 
+    def record_stats(self,
+                     output: CompletionOutput,
+                     stats: Optional[dict[str, float]] = None) -> None:
+        """Record the stats of the generation result.
+
+        Args:
+            output (CompletionOutput): The output of the generation result.
+            stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None.
+        """
+        if not stats:
+            return
+        metrics_stats = {}
+        if output.finish_reason:
+            metrics_stats.update({
+                MetricsCollector.labelname_finish_reason:
+                output.finish_reason
+            })
+        processed_metrics_stat = _process_req_perf_metrics(
+            stats, len(output.token_ids), self.sampling_params.n > 1)
+        if processed_metrics_stat:
+            metrics_stats.update(processed_metrics_stat)
+        self.metrics_dict = metrics_stats
+
 
 class DetokenizedGenerationResultBase(GenerationResultBase):
     ''' The base class for the generation result with detokenization support. '''
@@ -688,3 +734,30 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
 
     return LogProbsResult(prompt=prompt_logprobs,
                           generation=generation_logprobs)
+
+
+def _process_req_perf_metrics(
+        req_perf_metrics_dict: Optional[dict[str, float]],
+        output_length: int,
+        is_multiple_response: bool = False) -> dict[MetricNames, float]:
+    stat = {}
+    if not req_perf_metrics_dict:
+        return stat
+    ttft = req_perf_metrics_dict.get(RequestEventTiming.FIRST_TOKEN_TIME, 0) - \
+        req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
+    e2e = req_perf_metrics_dict.get(RequestEventTiming.LAST_TOKEN_TIME, 0) - \
+        req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
+    request_queue_time = req_perf_metrics_dict.get(RequestEventTiming.FIRST_SCHEDULED_TIME, 0) - \
+        req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
+    stat = {
+        MetricNames.TTFT: ttft,
+        MetricNames.E2E: e2e,
+        MetricNames.REQUEST_QUEUE_TIME: request_queue_time
+    }
+    if output_length > 1 and not is_multiple_response:
+        tpot = (req_perf_metrics_dict.get(
+            RequestEventTiming.LAST_TOKEN_TIME, 0) - req_perf_metrics_dict.get(
+                RequestEventTiming.FIRST_TOKEN_TIME, 0)) / (output_length - 1)
+        stat.update({MetricNames.TPOT: tpot})
+    stat = dict(filter(lambda item: item[1] > 0, stat.items()))
+    return stat
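
To make the derived values concrete, here is a short worked example of the same arithmetic on hypothetical timestamps, in seconds; the numbers and plain-string keys are placeholders, only the formulas mirror _process_req_perf_metrics:

# Hypothetical per-request timings (seconds since an arbitrary epoch).
timing = {
    "arrival_time": 100.00,
    "first_scheduled_time": 100.05,
    "first_token_time": 100.20,
    "last_token_time": 101.00,
}
output_length = 17  # tokens generated for this sequence

ttft = timing["first_token_time"] - timing["arrival_time"]            # 0.20 s
e2e = timing["last_token_time"] - timing["arrival_time"]              # 1.00 s
queue_time = timing["first_scheduled_time"] - timing["arrival_time"]  # 0.05 s
# Time per output token skips the first token, hence the (n - 1) divisor.
tpot = (timing["last_token_time"] - timing["first_token_time"]) / (output_length - 1)

print(f"TTFT={ttft:.2f}s E2E={e2e:.2f}s queue={queue_time:.2f}s TPOT={tpot:.3f}s")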

tensorrt_llm/executor/worker.py

Lines changed: 41 additions & 6 deletions
@@ -26,6 +26,7 @@
                      print_traceback_on_error)
 from ..lora_helper import LoraConfig
 from ..lora_manager import LoraManager
+from ..metrics import RequestEventTiming
 from ..prompt_adapter_manager import PromptAdapterManager
 from ..runtime import ModelConfig
 from ..runtime.model_runner import _engine_config_to_model_config
@@ -900,10 +901,8 @@ def handle_for_worker(self, responses: List[tllm.Response]) -> None:
             assert response is not None
             queue = self.worker.return_queue(response.client_id)
 
-            logprobs_result = _get_logprobs(self.worker, response,
+            response = _maybe_wrap_response(self.worker, response,
                                             self.worker._is_pytorch_backend)
-            if logprobs_result:
-                response = ResponseWrapper(response, logprobs_result)
 
             # For AsyncQueue.sync_q, we will batch the events to avoid too many
             # event notifications, thus put without wait here.
@@ -941,10 +940,8 @@ def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
                 response = ErrorResponse(response.client_id, response.error_msg,
                                          response.request_id)
             else:
-                logprobs_result = _get_logprobs(self.worker, response,
+                response = _maybe_wrap_response(self.worker, response,
                                                 self.worker._is_pytorch_backend)
-                if logprobs_result:
-                    response = ResponseWrapper(response, logprobs_result)
 
             _send_rsp(self.worker,
                       response,
@@ -1052,3 +1049,41 @@ def _send_rsp(
         worker._pop_result(response.client_id)
     else:
         raise ValueError(f"Unknown response type: {response}")
+
+
+def _get_metrics_dict(
+        response: tllm.Response) -> dict[RequestEventTiming, float]:
+    req_perf_metrics, metrics_dict = None, {}
+    res = response.result
+    if res:
+        if hasattr(res, '_result'):
+            if result := res.get_result():
+                req_perf_metrics = result.request_perf_metrics
+        else:
+            req_perf_metrics = res.request_perf_metrics
+        if req_perf_metrics and req_perf_metrics.timing_metrics:
+            metrics_dict = {
+                RequestEventTiming.ARRIVAL_TIME:
+                req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
+                RequestEventTiming.FIRST_TOKEN_TIME:
+                req_perf_metrics.timing_metrics.first_token_time.total_seconds(
+                ),
+                RequestEventTiming.FIRST_SCHEDULED_TIME:
+                req_perf_metrics.timing_metrics.first_scheduled_time.
+                total_seconds(),
+                RequestEventTiming.LAST_TOKEN_TIME:
+                req_perf_metrics.timing_metrics.last_token_time.total_seconds()
+            }
+    return metrics_dict
+
+
+def _maybe_wrap_response(
+        worker,
+        response: tllm.Response,
+        is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:
+
+    logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
+    req_perf_metrics = _get_metrics_dict(response)
+    if logprobs_result or req_perf_metrics:
+        response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
+    return response
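
The wrapping stays conditional: a response is only boxed in a ResponseWrapper when there are logprobs or timing metrics to carry, so the common path keeps passing the raw binding object. A stripped-down sketch of that decision with stand-in types (not the real bindings):

from typing import Any, Optional


class ResponseWrapper:
    """Toy stand-in mirroring the wrapper's role in the diff."""

    def __init__(self, response: Any, logprobs: Optional[Any] = None,
                 request_perf_metrics: Optional[dict] = None):
        self._response = response
        self.logprobs = logprobs
        self.request_perf_metrics = request_perf_metrics


def maybe_wrap(response: Any, logprobs: Optional[Any],
               metrics: Optional[dict]) -> Any:
    # Wrap only when there is extra payload to forward over IPC.
    if logprobs or metrics:
        return ResponseWrapper(response, logprobs, metrics)
    return response


plain = maybe_wrap("resp-1", None, None)            # stays a plain response
wrapped = maybe_wrap("resp-2", None, {"ttft": 0.1})  # boxed with metrics
print(type(plain).__name__, type(wrapped).__name__)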

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 1 deletion
@@ -548,7 +548,7 @@ def _prepare_sampling_params(
         if sampling_params._stream_interval is None:
             sampling_params._stream_interval = getattr(self.args,
                                                        "stream_interval", 1)
-
+        sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
         return sampling_params
 
     def _check_arguments(self, prompt_len: int, query_len: int,

tensorrt_llm/llmapi/llm_args.py

Lines changed: 4 additions & 0 deletions
@@ -1311,6 +1311,10 @@ class BaseLlmArgs(StrictBaseModel):
         status="deprecated",
     )
 
+    return_perf_metrics: bool = Field(default=False,
+                                      description="Return perf metrics.",
+                                      status="prototype")
+
     _parallel_config: Optional[object] = PrivateAttr(default=None)
     _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
     _speculative_model: Optional[str] = PrivateAttr(default=None)
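
Taken together, the new flag flows from the LLM arguments into each request's SamplingParams, and the collected timings surface on the finished result. A hedged end-to-end sketch, assuming the public llmapi surface behaves as in current releases (the model path and printed keys are placeholders):

from tensorrt_llm import LLM, SamplingParams

# return_perf_metrics is the prototype flag added in this commit.
llm = LLM(model="/path/to/model", return_perf_metrics=True)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))

for out in outputs:
    # metrics_dict is populated from RequestPerfMetrics when the flag is on;
    # expected keys include TTFT, E2E, REQUEST_QUEUE_TIME and TPOT.
    print(out.metrics_dict)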
