From 56f322ea6e0a5ef9ffbb13bd20b3f7d126b0e225 Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:14:45 -0700 Subject: [PATCH 01/11] Add server level request start and first token timestamp Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- tensorrt_llm/serve/openai_server.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 8d2977f6b4a..08f64e36141 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -3,6 +3,7 @@ import os import re import signal +import time import traceback from collections import deque from contextlib import asynccontextmanager @@ -159,6 +160,14 @@ async def validation_exception_handler(_, exc): assert isinstance(self.llm, MultimodalEncoder), "llm must be a MultimodalEncoder for multimodal encoder" self.register_mm_encoder_routes() + @self.app.middleware("http") + async def add_process_time_header(raw_request: Request, call_next): + start_time = time.monotonic() + raw_request.state.server_start_ts = start_time + response = await call_next(raw_request) + return response + + async def await_disconnected(self, raw_request: Request, promise): if raw_request is None: return @@ -322,9 +331,11 @@ async def get_perf_metrics(self) -> JSONResponse: # exclude metrics.iter since it is only meaningful when the request is not finished } metrics_json["timing_metrics"] = { + "server_start_ts": metrics_dict.pop("server_start_ts", None), "arrival_time": timing_metrics.arrival_time.total_seconds(), "first_scheduled_time": timing_metrics.first_scheduled_time.total_seconds(), "first_token_time": timing_metrics.first_token_time.total_seconds(), + "server_first_token_ts":metrics_dict.pop("server_first_token_ts", None), "last_token_time": timing_metrics.last_token_time.total_seconds(), } metrics_json["kv_cache_metrics"] = { @@ -359,7 +370,7 @@ async def get_kv_cache_events(self) -> JSONResponse: pass return JSONResponse(content=events) - async def _extract_metrics(self, res: RequestOutput): + async def _extract_metrics(self, res: RequestOutput, raw_request: Optional[Request] = None): if not res.finished: return if self.metrics_collector: @@ -370,6 +381,9 @@ async def _extract_metrics(self, res: RequestOutput): "request_id": res.request_id, "perf_metrics": res.outputs[0].request_perf_metrics } + if raw_request: + item["server_start_ts"] = getattr(raw_request.state, "server_start_ts", None) + item["server_first_token_ts"] = getattr(raw_request.state, "server_first_token_ts", None) if output.disaggregated_params: item["ctx_request_id"] = output.disaggregated_params.ctx_request_id if self.perf_metrics is not None: @@ -582,7 +596,8 @@ async def completion_response(promise: RequestOutput, if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": # Include prompt token ids for context-only requests pp_result.prompt_token_ids = response.prompt_token_ids - await self._extract_metrics(response) + raw_request.state.server_first_token_ts = time.monotonic() + await self._extract_metrics(response, raw_request) return pp_result def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse: @@ -619,7 +634,7 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc pp_result = post_processor(output, args) else: pp_result = 
output.outputs[0]._postprocess_result - await self._extract_metrics(output) + await self._extract_metrics(output, raw_request) for pp_res in pp_result: yield pp_res except: @@ -646,6 +661,9 @@ async def producer(generator: AsyncIterator[Any], idx: int): await asyncio.gather(*tasks) async def generator_wrapper(generator: AsyncIterator[Any]): + first_response = await anext(generator) + raw_request.state.server_first_token_ts = time.monotonic() + yield first_response async for output in generator: yield output yield "data: [DONE]\n\n" From def4d5adb4cd2c3c99f09add6aa7421669f7e0ca Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:14:29 -0700 Subject: [PATCH 02/11] Add disagg server level timestamp Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- tensorrt_llm/llmapi/llm.py | 7 +- tensorrt_llm/serve/openai_disagg_server.py | 100 ++++++++++++--------- 2 files changed, 62 insertions(+), 45 deletions(-) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index e8b119a967a..c18ca9920eb 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -354,9 +354,6 @@ def generate_async( if self._executor is None or self._executor.is_shutdown(): raise RuntimeError("LLM is shutting down") - arrival_time = steady_clock_now( - ) if self.args.return_perf_metrics else None - sampling_params = self._prepare_sampling_params(sampling_params) cache_salt_id = get_cache_salt_id( cache_salt) if cache_salt is not None else None @@ -467,6 +464,10 @@ def generate_async( if _postproc_params: _postproc_params.postproc_args.num_prompt_tokens = len( prompt_token_ids) + + arrival_time = steady_clock_now( + ) if self.args.return_perf_metrics else None + result = self._executor.generate_async( prompt_token_ids, query_token_ids=query_token_ids, diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 2b5f7dc59c0..f91dbda1387 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -4,6 +4,7 @@ import itertools import os import signal +import time import traceback from collections import deque from contextlib import asynccontextmanager @@ -12,7 +13,7 @@ import aiohttp import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR @@ -55,7 +56,7 @@ def __init__(self, self.perf_metrics_max_requests = config.perf_metrics_max_requests if self.perf_metrics_max_requests > 0: # record corresponding keys of context and generation servers for perf metrics - # (ctx_server, gen_server, ctx_request_id) + # (ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts) self.perf_metrics_keys = deque(maxlen=self.perf_metrics_max_requests) self.perf_metrics_keys_lock = asyncio.Lock() # server_key -> {ctx_request_id: perf_metrics} @@ -132,6 +133,13 @@ async def lifespan(app: FastAPI): self.app = FastAPI(lifespan=lifespan) + @self.app.middleware("http") + async def add_process_time_header(raw_request: Request, call_next): + start_time = time.monotonic() + raw_request.state.server_start_ts = start_time + response = await call_next(raw_request) + return response + @self.app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): return 
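
Note (illustrative, not part of the patch): the timing pattern introduced in this hunk — an HTTP middleware that stamps request arrival plus a wrapper that stamps the first streamed chunk — can be sketched standalone as below. `app`, `record_arrival`, and `with_first_token_ts` are placeholder names chosen for the sketch, assuming FastAPI/Starlette and Python 3.10+ for the built-in `anext`.

# Sketch only; names here are illustrative placeholders, not identifiers from the patch.
import time
from typing import Any, AsyncIterator
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def record_arrival(raw_request: Request, call_next):
    # Stamp the request as soon as it reaches the server. time.monotonic() is a
    # steady clock, so the value is only comparable to other monotonic timestamps
    # taken on the same host.
    raw_request.state.server_start_ts = time.monotonic()
    return await call_next(raw_request)

async def with_first_token_ts(gen: AsyncIterator[Any], raw_request: Request) -> AsyncIterator[Any]:
    # Await the first chunk explicitly so its timestamp can be recorded,
    # then pass the remaining chunks through unchanged.
    first = await anext(gen)
    raw_request.state.server_first_token_ts = time.monotonic()
    yield first
    async for chunk in gen:
        yield chunk
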
JSONResponse(status_code=400, content={"error": str(exc)}) @@ -185,9 +193,9 @@ async def version(self) -> JSONResponse: ver = {"version": VERSION} return JSONResponse(content=ver) - async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int): + async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request): async with self.perf_metrics_keys_lock: - self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id)) + self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_start_ts, raw_request.state.server_first_token_ts)) async def perf_metrics(self) -> JSONResponse: if self.perf_metrics_keys is None: @@ -224,50 +232,26 @@ async def perf_metrics(self) -> JSONResponse: raise exc remain_keys = [] - for ctx_server, gen_server, ctx_request_id in self.perf_metrics_keys: + for ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts in self.perf_metrics_keys: gen_perf_metrics = self.server_perf_metrics[gen_server].pop(ctx_request_id, None) if gen_perf_metrics is None: # generation not finished - remain_keys.append((ctx_server, gen_server, ctx_request_id)) + remain_keys.append((ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts)) continue ctx_perf_metrics = self.server_perf_metrics[ctx_server].pop(ctx_request_id, None) return_metrics.append({ "ctx_server": ctx_server, "gen_server": gen_server, + "disagg_server_start_ts": server_start_ts, + "disagg_server_first_token_ts": server_first_token_ts, "ctx_perf_metrics": ctx_perf_metrics, "gen_perf_metrics": gen_perf_metrics}) self.perf_metrics_keys = deque(remain_keys, maxlen=self.perf_metrics_max_requests) return JSONResponse(content=return_metrics) - async def merge_streaming_responses(self, ctx_response, - gen_server: str, - gen_req: Union[CompletionRequest, ChatCompletionRequest]): - try: - if ctx_response is not None and len(ctx_response.choices) != 1: - raise ValueError("Context server did not return a single choice. 
This is not expected") - - #If request finished after first token not due to length, return right away and skip gen - if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]: - yield "data: [DONE]\n\n".encode('utf-8') - else: - # Then yield the generation responses - await self._increment_metric("gen_total_requests") - if isinstance(gen_req, CompletionRequest): - gen_response = await self.send_completion_request(gen_server, gen_req) - elif isinstance(gen_req, ChatCompletionRequest): - gen_response = await self.send_chat_request(gen_server, gen_req) - else: - raise TypeError("Invalid request type: {type(gen_req).__name__}") - - async for chunk in gen_response.body_iterator: - yield chunk - await self._increment_metric("gen_completed_requests") - - finally: - await self.gen_router.finish_request(gen_req) - async def openai_completion(self, req: CompletionRequest) -> Response: + async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response: try: if not isinstance(req.prompt, str): # Check if it's a list and contains integers @@ -276,15 +260,15 @@ async def openai_completion(self, req: CompletionRequest) -> Response: elif not isinstance(req.prompt, list) or not all(isinstance(x, int) for x in req.prompt): raise ValueError("Disaggregated server currently only supports single string prompt or list of integers in request") - return await self._send_disagg_request(req) + return await self._send_disagg_request(req, raw_request) except Exception as e: await self._handle_exception(e) - async def openai_chat_completion(self, req: ChatCompletionRequest) -> Response: + async def openai_chat_completion(self, req: ChatCompletionRequest, raw_request: Request) -> Response: try: - return await self._send_disagg_request(req) + return await self._send_disagg_request(req, raw_request) except Exception as e: await self._handle_exception(e) @@ -326,9 +310,44 @@ async def _send_context_request(self, ctx_server: str, ctx_req: Union[Completion return ctx_response - async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletionRequest]): + async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletionRequest], raw_request: Request): + ctx_server = None gen_server = None + ctx_request_id = None need_ctx = False + + async def _merge_streaming_responses(ctx_response, + gen_req: Union[CompletionRequest, ChatCompletionRequest]): + try: + if ctx_response is not None and len(ctx_response.choices) != 1: + raise ValueError("Context server did not return a single choice. 
This is not expected") + + #If request finished after first token not due to length, return right away and skip gen + if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]: + yield "data: [DONE]\n\n".encode('utf-8') + else: + # Then yield the generation responses + await self._increment_metric("gen_total_requests") + if isinstance(gen_req, CompletionRequest): + gen_response = await self.send_completion_request(gen_server, gen_req) + elif isinstance(gen_req, ChatCompletionRequest): + gen_response = await self.send_chat_request(gen_server, gen_req) + else: + raise TypeError("Invalid request type: {type(gen_req).__name__}") + + first_response = await anext(gen_response.body_iterator) + raw_request.state.server_first_token_ts = time.monotonic() + yield first_response + async for chunk in gen_response.body_iterator: + yield chunk + await self._increment_metric("gen_completed_requests") + if need_ctx and self.perf_metrics_keys is not None: + asyncio.create_task(self._add_perf_metrics_keys( + ctx_server, gen_server, ctx_request_id, raw_request)) + + + finally: + await self.gen_router.finish_request(gen_req) try: # Determine if need context server condition = self.conditional_disagg_config @@ -366,6 +385,7 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio # Append disaggregates parameters to generation request req.disaggregated_params = ctx_response.choices[0].disaggregated_params req.disaggregated_params.request_type = "generation_only" + ctx_request_id = req.disaggregated_params.ctx_request_id # Replace the string prompt with prompt_tokens_ids if isinstance(req, CompletionRequest): @@ -382,10 +402,6 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio gen_server, _ = await self.gen_router.get_next_server(req) logger.debug("Sending request to gen server: %s", gen_server) - if need_ctx and self.perf_metrics_keys is not None: - asyncio.create_task(self._add_perf_metrics_keys( - ctx_server, gen_server, req.disaggregated_params.ctx_request_id)) - if not req.stream: try: #If request finished after first token for reason other than length, return right away and skip gen @@ -408,7 +424,7 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio else: # Return a streaming response that combines both context and generation responses return StreamingResponse( - self.merge_streaming_responses(ctx_response, gen_server, req), + _merge_streaming_responses(ctx_response, req), media_type="text/event-stream" ) except: From b7452ce00ac6af4ff09c42592e4f590b335e2034 Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Fri, 19 Sep 2025 09:42:13 -0700 Subject: [PATCH 03/11] Add work server timestamp correlation Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- tensorrt_llm/serve/openai_disagg_server.py | 37 ++++++++++++++++++++-- tensorrt_llm/serve/openai_server.py | 7 ++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index f91dbda1387..8ab5232f132 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -59,8 +59,11 @@ def __init__(self, # (ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts) self.perf_metrics_keys = deque(maxlen=self.perf_metrics_max_requests) self.perf_metrics_keys_lock = asyncio.Lock() - # server_key -> 
{ctx_request_id: perf_metrics} + # server_url -> {ctx_request_id: perf_metrics} self.server_perf_metrics: dict[str, dict[int, dict]] = {} + + # server_url -> the perf metric timestamp offset between the disagg server and worker server + self.server_perf_ts_offsets: dict[str, float] = {} else: self.perf_metrics_keys = None self.perf_metrics_keys_lock = None @@ -105,6 +108,9 @@ async def lifespan(app: FastAPI): logger.info("Waiting for context and generation servers to be ready") await self.wait_for_servers_ready(server_start_timeout_secs) + if self.perf_metrics_max_requests > 0: + await self.query_perf_ts_offsets(self.session) + if self.metadata_server: logger.info("Starting server monitoring via metadata service") await self.ctx_router.start_server_monitoring(metadata_server_cfg.refresh_interval) @@ -248,7 +254,11 @@ async def perf_metrics(self) -> JSONResponse: "gen_perf_metrics": gen_perf_metrics}) self.perf_metrics_keys = deque(remain_keys, maxlen=self.perf_metrics_max_requests) - return JSONResponse(content=return_metrics) + response = { + "server_perf_timestamp_offsets": self.server_perf_ts_offsets, + "perf_metrics": return_metrics + } + return JSONResponse(content=response) async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response: @@ -503,6 +513,29 @@ async def send_completion_request(self, url: str, request: CompletionRequest) -> async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> ChatCompletionResponse: return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator) + async def query_perf_ts_offsets(self, session: aiohttp.ClientSession): + async def query_perf_ts_offset(server_url: str) -> Optional[float]: + try: + originate_ts = time.monotonic() + async with session.get(server_url + '/perf_ts_offset') as response: + destination_ts = time.monotonic() + if response.status == 200: + response = await response.json() + receive_ts = response['receive_ts'] + transmit_ts = response['transmit_ts'] + delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts) + offset = - ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2 + return delay, offset + else: + return None, None + except Exception: + return None + for server_url in self.ctx_servers + self.gen_servers: + delay, offset = await query_perf_ts_offset(server_url) + self.server_perf_ts_offsets[server_url] = offset + logger.info(f'Server: {server_url}, delay: {delay} second, offset: {offset} second') + logger.info(f"Server perf metrics timestamp offsets: {self.server_perf_ts_offsets}") + @classmethod async def check_server_ready(cls, session: aiohttp.ClientSession, server_url: str) -> bool: try: diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 08f64e36141..4d80a6a8152 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -215,6 +215,7 @@ def register_routes(self): # TODO: the metrics endpoint only reports iteration stats, not the runtime stats for now self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"]) self.app.add_api_route("/perf_metrics", self.get_perf_metrics, methods=["GET"]) + self.app.add_api_route("/perf_ts_offset", self.get_perf_ts_offset, methods=["GET"]) # TODO: workaround before ETCD support self.app.add_api_route("/kv_cache_events", self.get_kv_cache_events, methods=["POST"]) self.app.add_api_route("/v1/completions", @@ -314,6 +315,12 @@ async def get_iteration_stats(self) -> 
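
Note (illustrative, not part of the patch): the delay/offset arithmetic in query_perf_ts_offset above follows the standard NTP clock-synchronization estimate. A minimal sketch with the four timestamps spelled out; `t0`..`t3` and `estimate_clock_offset` are placeholder names for illustration (t0/t3 are taken on the querying server, t1/t2 are the worker's receive/transmit timestamps).

# Sketch only; the sign convention of `offset` must match how the offset is later applied.
def estimate_clock_offset(t0: float, t1: float, t2: float, t3: float) -> tuple[float, float]:
    # Round-trip delay, excluding the time the worker spent between receive and transmit.
    delay = (t3 - t0) - (t2 - t1)
    # Estimated offset of the worker's steady clock relative to the querying server's clock.
    offset = ((t1 - t0) + (t2 - t3)) / 2
    return delay, offset
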
JSONResponse: stats.append(stat) return JSONResponse(content=stats) + async def get_perf_ts_offset(self) -> JSONResponse: + receive_ts = time.monotonic() + await asyncio.sleep(0.2) + transmit_ts = time.monotonic() + return JSONResponse(content={"receive_ts": receive_ts, "transmit_ts": transmit_ts}) + async def get_perf_metrics(self) -> JSONResponse: if self.perf_metrics is None: return JSONResponse(content=[]) From 7e5362e928c65c3cfb081f08b88dee08d80fc784 Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Mon, 22 Sep 2025 00:28:13 -0700 Subject: [PATCH 04/11] fix negative processing time Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 02e95fa1e5b..f4ae4ea605c 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1958,7 +1958,7 @@ def _handle_responses(self): request) > 0 else [] request.decoding_iter = request.py_decoding_iter - if request.return_perf_metrics: + if request.return_perf_metrics and request.py_decoding_iter >= 1: request.update_perf_metrics(self.model_engine.iter_counter) request_done = False From 8b1633a00946a77ce738160baeb1aa16bf34653d Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:48:30 -0700 Subject: [PATCH 05/11] Add time sync for nodes within the same MPI WORLD Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- .../tensorrt_llm/batch_manager/llmRequest.h | 56 +++++++++++++------ .../nanobind/batch_manager/bindings.cpp | 7 ++- .../nanobind/batch_manager/llmRequest.cpp | 3 +- .../nanobind/batch_manager/llmRequest.h | 6 +- .../pybind/batch_manager/bindings.cpp | 7 ++- .../pybind/batch_manager/llmRequest.cpp | 3 +- .../pybind/batch_manager/llmRequest.h | 3 +- .../pyexecutor/executor_request_queue.py | 4 +- tensorrt_llm/_torch/pyexecutor/llm_request.py | 1 + tensorrt_llm/_torch/pyexecutor/py_executor.py | 13 ++++- 10 files changed, 73 insertions(+), 30 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 275bc75721a..e72157879ff 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -101,6 +101,7 @@ class GenericLlmRequest using RequestPtr = std::shared_ptr; using MillisecondsType = std::chrono::milliseconds; using TimePoint = std::chrono::time_point; + using Duration = std::chrono::time_point::duration; using CacheSaltIDType = runtime::CacheSaltIDType; GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr const& inputTokens, @@ -139,7 +140,9 @@ class GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt) + std::optional cacheSaltID = std::nullopt, + std::optional arrivalTime = std::nullopt, + std::optional globalSteadyClockOffset = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens->size()) , mMaxNewTokens(maxNewTokens) @@ -197,6 +200,7 @@ class GenericLlmRequest , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) , 
mCacheSaltID(cacheSaltID) + , mGlobalSteadyClockOffset(globalSteadyClockOffset) { if (mEncoderTokens.has_value() || encoderInputFeatures.has_value()) { @@ -224,7 +228,8 @@ class GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1, std::optional languageAdapterUid = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt) + std::optional cacheSaltID = std::nullopt, + std::optional globalSteadyClockOffset = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens.size()) , mMaxNewTokens(maxNewTokens) @@ -265,6 +270,7 @@ class GenericLlmRequest , mNumReturnSequences(numReturnSequences) , mLanguageAdapterUid(languageAdapterUid) , mCacheSaltID(cacheSaltID) + , mGlobalSteadyClockOffset(globalSteadyClockOffset) { if (mEncoderTokens.has_value()) { @@ -1255,7 +1261,7 @@ class GenericLlmRequest { if (mPerfMetrics.timingMetrics.firstScheduledTime == executor::RequestPerfMetrics::TimePoint{}) { - mPerfMetrics.timingMetrics.firstScheduledTime = std::chrono::steady_clock::now(); + mPerfMetrics.timingMetrics.firstScheduledTime = getCurrentSteadyClock(); } } @@ -1671,8 +1677,8 @@ class GenericLlmRequest { return false; } - auto const currentTime = std::chrono::steady_clock::now(); - auto const elapsed = (std::chrono::duration_cast(currentTime - mStartTime)); + auto const currentTime = getCurrentSteadyClock(); + auto const elapsed = (std::chrono::duration_cast(currentTime - mStartTime)); TLLM_LOG_DEBUG("Checked timeOut for request %ld with allotted Time %ld after time %ld and got %d", mRequestId, mAllottedTimeMs->count(), elapsed.count(), (elapsed >= mAllottedTimeMs)); @@ -1689,22 +1695,22 @@ class GenericLlmRequest mDecodingIter = iter; } - void setKvCacheTransferStart(std::chrono::time_point const& time) + void setKvCacheTransferStart(TimePoint const& time) { - mPerfMetrics.timingMetrics.kvCacheTransferStart = time; + mPerfMetrics.timingMetrics.kvCacheTransferStart = maybeToGlobalSteadyClock(time); } - void setKvCacheTransferEnd(std::chrono::time_point const& time) + void setKvCacheTransferEnd(TimePoint const& time) { - mPerfMetrics.timingMetrics.kvCacheTransferEnd = time; + mPerfMetrics.timingMetrics.kvCacheTransferEnd = maybeToGlobalSteadyClock(time); } - std::chrono::time_point getKvCacheTransferStart() + TimePoint getKvCacheTransferStart() { return mPerfMetrics.timingMetrics.kvCacheTransferStart; } - std::chrono::time_point getKvCacheTransferEnd() + TimePoint getKvCacheTransferEnd() { return mPerfMetrics.timingMetrics.kvCacheTransferEnd; } @@ -1788,7 +1794,7 @@ class GenericLlmRequest if (finishReason == executor::FinishReason::kTIMED_OUT) { TLLM_LOG_DEBUG("Request %ld finished by timeout after %f sec", mRequestId, - std::chrono::duration(std::chrono::steady_clock::now() - mStartTime).count()); + std::chrono::duration(getCurrentSteadyClock() - mStartTime).count()); } if (finishReason == executor::FinishReason::kCANCELLED) { @@ -1826,7 +1832,7 @@ class GenericLlmRequest void updatePerfMetrics(executor::IterationType iter) { - auto const currentTokenTime = std::chrono::steady_clock::now(); + auto const currentTokenTime = getCurrentSteadyClock(); if (!mPerfMetrics.firstIter) { @@ -2041,6 +2047,8 @@ class GenericLlmRequest // Cache salt id for each request. 
std::optional mCacheSaltID{std::nullopt}; + // The offset between local steady clock and glabol steady clock (at rank 0) + std::optional mGlobalSteadyClockOffset; private: void initialize( VecTokens const& inputTokens, bool outputLogProbs, std::optional arrivalTime = std::nullopt) @@ -2137,9 +2145,9 @@ class GenericLlmRequest if (mReturnPerfMetrics) { - mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(std::chrono::steady_clock::now()); + mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getCurrentSteadyClock()); } - mStartTime = std::chrono::steady_clock::now(); + mStartTime = getCurrentSteadyClock(); } TensorPtr createListTensor(std::list const& wordsList) @@ -2167,6 +2175,20 @@ class GenericLlmRequest return tensor; } + + TimePoint maybeToGlobalSteadyClock(TimePoint const & time_point) const { + if (mGlobalSteadyClockOffset.has_value()) { + return time_point + *mGlobalSteadyClockOffset; + } else { + return time_point; + } + } + + TimePoint getCurrentSteadyClock() const { + const TimePoint time_point = std::chrono::steady_clock::now(); + + return maybeToGlobalSteadyClock(time_point); + } }; class LlmRequest : public GenericLlmRequest @@ -2223,7 +2245,7 @@ class LlmRequest : public GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt) + std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, std::optional globalSteadyClockOffset = std::nullopt) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), @@ -2254,7 +2276,7 @@ class LlmRequest : public GenericLlmRequest : std::optional>(std::nullopt), numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID, - arrivalTime) + arrivalTime, globalSteadyClockOffset) { } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index e0325b51c8a..e220e5cf910 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -291,7 +291,8 @@ void initBindings(nb::module_& m) std::optional allotted_time_ms, std::optional context_phase_params, std::optional cache_salt_id, - std::optional arrival_time) + std::optional arrival_time, + std::optional global_steady_clock_offset) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -332,7 +333,7 @@ void initBindings(nb::module_& m) encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, - arrival_time}; + arrival_time, global_steady_clock_offset}; }, nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, @@ -358,7 +359,7 @@ void initBindings(nb::module_& m) nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt, 
nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt, nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt, - nb::arg("arrival_time") = std::nullopt) + nb::arg("arrival_time") = std::nullopt, nb::arg("global_steady_clock_offset") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size")) .def(nb::init()) .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"), diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp index 07d630cb3b2..ac40d299592 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -127,6 +127,7 @@ std::shared_ptr LlmRequest::toTrtLlm() const mAllottedTimeMs, // mContextPhaseParams, // mCacheSaltID, // - mPerfMetrics.timingMetrics.arrivalTime // + mPerfMetrics.timingMetrics.arrivalTime, // + mGlobalSteadyClockOffset // ); } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h index 4ea47fdcc8c..c81b49257c3 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -86,7 +86,8 @@ class LlmRequest : public tb::GenericLlmRequest std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, std::optional cacheSaltID = std::nullopt, - std::optional arrivalTime = std::nullopt) + std::optional arrivalTime = std::nullopt, + std::optional globalSteadyClockOffset = std::nullopt) : Base(requestId, // maxNewTokens, // std::make_shared>(std::move(inputTokens)), // @@ -149,7 +150,8 @@ class LlmRequest : public tb::GenericLlmRequest allottedTimeMs, // contextPhaseParams, // cacheSaltID, // - arrivalTime // + arrivalTime, // + globalSteadyClockOffset // ) { } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 9bcd22e39e4..3353e7a33b5 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -296,7 +296,8 @@ void initBindings(pybind11::module_& m) std::optional allotted_time_ms, std::optional context_phase_params, std::optional cache_salt_id, - std::optional arrival_time) + std::optional arrival_time, + std::optional globalSteadyClockOffset = std::nullopt) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -337,7 +338,7 @@ void initBindings(pybind11::module_& m) encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params, - language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time}; + language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time, global_steady_clock_offset}; }), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt, @@ -364,7 +365,7 @@ void initBindings(pybind11::module_& m) py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt, py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = 
std::nullopt, py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt, - py::arg("arrival_time") = std::nullopt) + nb::arg("arrival_time") = std::nullopt, nb::arg("global_steady_clock_offset") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size")) .def(py::init()) .def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"), diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp index bcc9d4bf13f..53a29e32223 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp @@ -126,6 +126,7 @@ std::shared_ptr LlmRequest::toTrtLlm() const mAllottedTimeMs, // mContextPhaseParams, // mCacheSaltID, // - mPerfMetrics.timingMetrics.arrivalTime // + mPerfMetrics.timingMetrics.arrivalTime, // + mGlobalSteadyClockOffset // ); } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h index b43fb8dd073..96d8b906b74 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h @@ -86,7 +86,8 @@ class LlmRequest : public tb::GenericLlmRequest std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, std::optional cacheSaltID = std::nullopt, - std::optional arrivalTime = std::nullopt) + std::optional arrivalTime = std::nullopt, + std::optional globalSteadyClockOffset = std::nullopt) : Base(requestId, // maxNewTokens, // std::make_shared>(std::move(inputTokens)), // diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index c03673f34e7..86ce4fc6850 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -44,7 +44,7 @@ class ExecutorRequestQueue: def __init__(self, dist: Distributed, enable_attention_dp: bool, max_batch_size: int, max_beam_width: int, max_num_active_requests: int, enable_iter_perf_stats: bool, - batch_wait_timeout_ms: float, is_disaggregated: bool): + batch_wait_timeout_ms: float, is_disaggregated: bool, monotonic_ts_offset: float): self.dist = dist self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue() self.waiting_queue: deque[RequestQueueItem] = deque() @@ -60,6 +60,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool, self.start_times = {} self.active = True self.batch_wait_timeout_ms = batch_wait_timeout_ms + self.monotonic_ts_offset = monotonic_ts_offset # State tracking self.num_fetch_requests = 0 @@ -611,6 +612,7 @@ def _merge_requests( else: req_with_children = [] for req_item in new_requests: + req_item.request.py_global_steady_clock_offset = self.monotonic_ts_offset req = executor_request_to_llm_request( req_item.id, req_item.request, req_item.child_req_ids, self._should_exclude_last_generation_logits()) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 1f971095126..a7208a8bd20 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -593,6 +593,7 @@ def executor_request_to_llm_request( context_phase_params=executor_request.context_phase_params, cache_salt_id=executor_request.cache_salt_id, arrival_time=getattr(executor_request, "py_arrival_time", None), + 
global_steady_clock_offset=getattr(executor_request, "py_global_steady_clock_offset", None), py_multimodal_data=getattr(executor_request, "py_multimodal_data", None)) if child_req_ids: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f4ae4ea605c..39651d5f1cf 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -166,6 +166,8 @@ def __init__(self, super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = global_mpi_rank() + self.dist = dist + self.monotonic_ts_offset = self._get_monotonic_ts_offset() self.peft_cache_config = peft_cache_config @@ -184,7 +186,6 @@ def __init__(self, self.draft_model_engine = getattr(self.drafter, "draft_model_engine", None) self.guided_decoder = guided_decoder - self.dist = dist self.disable_overlap_scheduler = disable_overlap_scheduler # enqueue and _fetch_new_requests used data @@ -263,6 +264,7 @@ def __init__(self, enable_iter_perf_stats=self.enable_iter_perf_stats, batch_wait_timeout_ms=self.batch_wait_timeout_ms, is_disaggregated=kv_cache_transceiver is not None, + monotonic_ts_offset = self.monotonic_ts_offset ) self.executor_request_queue.set_exclude_last_generation_logits( self.disable_overlap_scheduler, self.dist.pp_size) @@ -365,6 +367,15 @@ def start_worker(self): self.worker_thread.start() self.worker_started = True + def _get_monotonic_ts_offset(self): + assert self.global_rank >= 0, "rank should be >= 0" + self.dist.barrier() + local_timestamp = time.monotonic() + timestamps = self.dist.allgather(local_timestamp) + if self.global_rank == 0: + logger.info(f"monotonic_ts_offsets for each rank: {[local_timestamp - ts for ts in timestamps]}") + return timestamps[0] - local_timestamp + def __enter__(self): return self From 91a086d2fb23d939dcaaa3b832ea1fb30485f0ee Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Tue, 23 Sep 2025 22:40:33 -0700 Subject: [PATCH 06/11] Support perf_metrics for multiple post worker Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- tensorrt_llm/executor/postproc_worker.py | 13 ++++++++++--- tensorrt_llm/executor/result.py | 6 ++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 55bf7839f4b..ea32215679c 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -69,6 +69,8 @@ class Output(NamedTuple): is_final: bool error: str = "" metrics: Optional[dict[str, float]] = None + request_perf_metrics: Any = None + disaggregated_params: Any = None def __init__( self, @@ -142,6 +144,8 @@ async def _handle_input( # Left the result_handler determine the final output dtype. # NOTE: This will change the CompletionOutput._postprocess_result metrics_dict = record.metrics_dict + perf_metrics = record.outputs[0].request_perf_metrics + disaggregated_params = record.outputs[0].disaggregated_params if postproc_params := record.postproc_params: result_handler, args = postproc_params.post_processor, postproc_params.postproc_args args.tokenizer = self._tokenizer @@ -153,7 +157,7 @@ async def _handle_input( # TODO: Keep only the diff token_ids and text in streaming mode when # result_handler is not set - return out, metrics_dict + return out, metrics_dict, perf_metrics, disaggregated_params async def _batched_put(self): ''' Batched IPC send. 
''' @@ -176,12 +180,15 @@ async def handle_single_input(inp: PostprocWorker.Input, client_id = inp.rsp.client_id is_final = inp.rsp.result.is_final if is_llm_response( inp.rsp) else True - res, metrics = await self._handle_input(inp) + res, metrics, perf_metrics, disaggregated_params = await self._handle_input(inp) batch.append( PostprocWorker.Output(client_id=client_id, res=res, is_final=is_final, - metrics=metrics)) + metrics=metrics, + request_perf_metrics=perf_metrics, + disaggregated_params=disaggregated_params, + )) if is_final: self._records.pop(client_id) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index d19a8368297..a6882a7cdb2 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -7,6 +7,7 @@ NamedTuple, Optional, TypeAlias, Union) from weakref import WeakMethod +from tensorrt_llm.logger import logger import torch import torch.nn.functional as F @@ -336,6 +337,11 @@ def _handle_response(self, self._outputs[0] = response.res else: self._outputs[0]._postprocess_result = response.res + self._outputs[0].request_perf_metrics = response.request_perf_metrics + if response.disaggregated_params: + self._outputs[0].disaggregated_params = response.disaggregated_params + else: + self._outputs[0].disaggregated_params = self.disaggregated_params if response.metrics: self.metrics_dict = response.metrics From 6f0f5824e3667759879e3390f6a3706ff3f6ed45 Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Thu, 25 Sep 2025 17:52:23 -0700 Subject: [PATCH 07/11] code refactor Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- .../tensorrt_llm/batch_manager/llmRequest.h | 38 +++++++++------- .../pyexecutor/executor_request_queue.py | 9 ++-- tensorrt_llm/_torch/pyexecutor/py_executor.py | 19 +++++--- tensorrt_llm/executor/result.py | 18 +++++--- tensorrt_llm/llmapi/llm.py | 7 ++- tensorrt_llm/serve/openai_disagg_server.py | 43 +++++++++--------- tensorrt_llm/serve/openai_server.py | 44 ++++++++++++------- 7 files changed, 106 insertions(+), 72 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index e72157879ff..eb76e7565e7 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -140,8 +140,7 @@ class GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, - std::optional arrivalTime = std::nullopt, + std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, std::optional globalSteadyClockOffset = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens->size()) @@ -1261,7 +1260,7 @@ class GenericLlmRequest { if (mPerfMetrics.timingMetrics.firstScheduledTime == executor::RequestPerfMetrics::TimePoint{}) { - mPerfMetrics.timingMetrics.firstScheduledTime = getCurrentSteadyClock(); + mPerfMetrics.timingMetrics.firstScheduledTime = getSteadyClockNow(); } } @@ -1677,7 +1676,7 @@ class GenericLlmRequest { return false; } - auto const currentTime = getCurrentSteadyClock(); + auto const currentTime = getSteadyClockNow(); auto const elapsed = (std::chrono::duration_cast(currentTime - mStartTime)); TLLM_LOG_DEBUG("Checked timeOut for request %ld with allotted Time %ld after time %ld and got %d", mRequestId, mAllottedTimeMs->count(), elapsed.count(), 
(elapsed >= mAllottedTimeMs)); @@ -1794,7 +1793,7 @@ class GenericLlmRequest if (finishReason == executor::FinishReason::kTIMED_OUT) { TLLM_LOG_DEBUG("Request %ld finished by timeout after %f sec", mRequestId, - std::chrono::duration(getCurrentSteadyClock() - mStartTime).count()); + std::chrono::duration(getSteadyClockNow() - mStartTime).count()); } if (finishReason == executor::FinishReason::kCANCELLED) { @@ -1832,10 +1831,9 @@ class GenericLlmRequest void updatePerfMetrics(executor::IterationType iter) { - auto const currentTokenTime = getCurrentSteadyClock(); - if (!mPerfMetrics.firstIter) { + auto const currentTokenTime = getSteadyClockNow(); mPerfMetrics.firstIter = iter; mPerfMetrics.timingMetrics.firstTokenTime = currentTokenTime; } @@ -1844,6 +1842,7 @@ class GenericLlmRequest if (isFinished()) { + auto const currentTokenTime = getSteadyClockNow(); mPerfMetrics.lastIter = iter; mPerfMetrics.timingMetrics.lastTokenTime = currentTokenTime; } @@ -2047,8 +2046,9 @@ class GenericLlmRequest // Cache salt id for each request. std::optional mCacheSaltID{std::nullopt}; - // The offset between local steady clock and glabol steady clock (at rank 0) + // The offset between local steady clock and global steady clock (at rank 0) std::optional mGlobalSteadyClockOffset; + private: void initialize( VecTokens const& inputTokens, bool outputLogProbs, std::optional arrivalTime = std::nullopt) @@ -2145,9 +2145,9 @@ class GenericLlmRequest if (mReturnPerfMetrics) { - mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getCurrentSteadyClock()); + mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow()); } - mStartTime = getCurrentSteadyClock(); + mStartTime = getSteadyClockNow(); } TensorPtr createListTensor(std::list const& wordsList) @@ -2176,15 +2176,22 @@ class GenericLlmRequest return tensor; } - TimePoint maybeToGlobalSteadyClock(TimePoint const & time_point) const { - if (mGlobalSteadyClockOffset.has_value()) { + TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point) const + { + if (mGlobalSteadyClockOffset.has_value()) + { return time_point + *mGlobalSteadyClockOffset; - } else { + } + else + { return time_point; } } - TimePoint getCurrentSteadyClock() const { + // If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock + // time point + TimePoint getSteadyClockNow() const + { const TimePoint time_point = std::chrono::steady_clock::now(); return maybeToGlobalSteadyClock(time_point); @@ -2245,7 +2252,8 @@ class LlmRequest : public GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, std::optional globalSteadyClockOffset = std::nullopt) + std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, + std::optional globalSteadyClockOffset = std::nullopt) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 86ce4fc6850..1ab252d69cb 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -44,7 +44,8 @@ class ExecutorRequestQueue: 
def __init__(self, dist: Distributed, enable_attention_dp: bool, max_batch_size: int, max_beam_width: int, max_num_active_requests: int, enable_iter_perf_stats: bool, - batch_wait_timeout_ms: float, is_disaggregated: bool, monotonic_ts_offset: float): + batch_wait_timeout_ms: float, is_disaggregated: bool, + global_steady_clock_offset: float): self.dist = dist self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue() self.waiting_queue: deque[RequestQueueItem] = deque() @@ -60,7 +61,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool, self.start_times = {} self.active = True self.batch_wait_timeout_ms = batch_wait_timeout_ms - self.monotonic_ts_offset = monotonic_ts_offset + self.global_steady_clock_offset = global_steady_clock_offset # State tracking self.num_fetch_requests = 0 @@ -612,7 +613,9 @@ def _merge_requests( else: req_with_children = [] for req_item in new_requests: - req_item.request.py_global_steady_clock_offset = self.monotonic_ts_offset + if self.global_steady_clock_offset: + req_item.request.py_global_steady_clock_offset = self.global_steady_clock_offset + req = executor_request_to_llm_request( req_item.id, req_item.request, req_item.child_req_ids, self._should_exclude_last_generation_logits()) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 39651d5f1cf..e1b23e44926 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -167,7 +167,7 @@ def __init__(self, self.device_id = torch.cuda.current_device() self.global_rank = global_mpi_rank() self.dist = dist - self.monotonic_ts_offset = self._get_monotonic_ts_offset() + self.global_steady_clock_offset = self._get_global_steady_clock_offset() self.peft_cache_config = peft_cache_config @@ -264,7 +264,7 @@ def __init__(self, enable_iter_perf_stats=self.enable_iter_perf_stats, batch_wait_timeout_ms=self.batch_wait_timeout_ms, is_disaggregated=kv_cache_transceiver is not None, - monotonic_ts_offset = self.monotonic_ts_offset + global_steady_clock_offset=self.global_steady_clock_offset, ) self.executor_request_queue.set_exclude_last_generation_logits( self.disable_overlap_scheduler, self.dist.pp_size) @@ -367,14 +367,20 @@ def start_worker(self): self.worker_thread.start() self.worker_started = True - def _get_monotonic_ts_offset(self): + def _get_global_steady_clock_offset(self): assert self.global_rank >= 0, "rank should be >= 0" + + # Sync all ranks self.dist.barrier() + # Immediately take the local steady clock timestamp local_timestamp = time.monotonic() - timestamps = self.dist.allgather(local_timestamp) + all_rank_timestamps = self.dist.allgather(local_timestamp) if self.global_rank == 0: - logger.info(f"monotonic_ts_offsets for each rank: {[local_timestamp - ts for ts in timestamps]}") - return timestamps[0] - local_timestamp + logger.info( + f"global_steady_clock_offset at each rank: {[local_timestamp - ts for ts in all_rank_timestamps]}" + ) + # Compute the steady clock offset between rank 0 and current rank + return all_rank_timestamps[0] - local_timestamp def __enter__(self): return self @@ -1969,6 +1975,7 @@ def _handle_responses(self): request) > 0 else [] request.decoding_iter = request.py_decoding_iter + # Skip active requests that are not scheduled if request.return_perf_metrics and request.py_decoding_iter >= 1: request.update_perf_metrics(self.model_engine.iter_counter) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 
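
Note (illustrative, not part of the patch): the per-rank clock offset above is obtained by a barrier followed by an allgather of time.monotonic(). The sketch below shows the same idea in isolation; it uses mpi4py directly only for illustration, whereas the patch goes through its internal Distributed wrapper, and `global_steady_clock_offset` is a placeholder name.

# Sketch only, assuming mpi4py; the patch uses its own Distributed abstraction instead.
import time
from mpi4py import MPI

def global_steady_clock_offset(comm: MPI.Comm = MPI.COMM_WORLD) -> float:
    comm.Barrier()                      # line all ranks up as closely as possible
    local_ts = time.monotonic()         # sample immediately after the barrier
    all_ts = comm.allgather(local_ts)   # every rank sees every rank's sample
    # Adding this offset to a local monotonic timestamp expresses it on rank 0's clock.
    return all_ts[0] - local_ts
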
a6882a7cdb2..4923b776487 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -7,7 +7,6 @@ NamedTuple, Optional, TypeAlias, Union) from weakref import WeakMethod -from tensorrt_llm.logger import logger import torch import torch.nn.functional as F @@ -337,11 +336,18 @@ def _handle_response(self, self._outputs[0] = response.res else: self._outputs[0]._postprocess_result = response.res - self._outputs[0].request_perf_metrics = response.request_perf_metrics - if response.disaggregated_params: - self._outputs[0].disaggregated_params = response.disaggregated_params - else: - self._outputs[0].disaggregated_params = self.disaggregated_params + + self._outputs[ + 0].request_perf_metrics = response.request_perf_metrics + if not self._outputs[0].disaggregated_params: + disaggregated_params = response.disaggregated_params + + # Generation only response has no disaggregated_params attached + if not disaggregated_params: + disaggregated_params = self.disaggregated_params + + self._outputs[0].disaggregated_params = disaggregated_params + if response.metrics: self.metrics_dict = response.metrics diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index c18ca9920eb..e8b119a967a 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -354,6 +354,9 @@ def generate_async( if self._executor is None or self._executor.is_shutdown(): raise RuntimeError("LLM is shutting down") + arrival_time = steady_clock_now( + ) if self.args.return_perf_metrics else None + sampling_params = self._prepare_sampling_params(sampling_params) cache_salt_id = get_cache_salt_id( cache_salt) if cache_salt is not None else None @@ -464,10 +467,6 @@ def generate_async( if _postproc_params: _postproc_params.postproc_args.num_prompt_tokens = len( prompt_token_ids) - - arrival_time = steady_clock_now( - ) if self.args.return_perf_metrics else None - result = self._executor.generate_async( prompt_token_ids, query_token_ids=query_token_ids, diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 8ab5232f132..9719553ca03 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -56,14 +56,12 @@ def __init__(self, self.perf_metrics_max_requests = config.perf_metrics_max_requests if self.perf_metrics_max_requests > 0: # record corresponding keys of context and generation servers for perf metrics - # (ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts) + # (ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts) self.perf_metrics_keys = deque(maxlen=self.perf_metrics_max_requests) self.perf_metrics_keys_lock = asyncio.Lock() # server_url -> {ctx_request_id: perf_metrics} self.server_perf_metrics: dict[str, dict[int, dict]] = {} - # server_url -> the perf metric timestamp offset between the disagg server and worker server - self.server_perf_ts_offsets: dict[str, float] = {} else: self.perf_metrics_keys = None self.perf_metrics_keys_lock = None @@ -109,7 +107,7 @@ async def lifespan(app: FastAPI): await self.wait_for_servers_ready(server_start_timeout_secs) if self.perf_metrics_max_requests > 0: - await self.query_perf_ts_offsets(self.session) + await self.set_steady_clock_offsets(self.session) if self.metadata_server: logger.info("Starting server monitoring via metadata service") @@ -142,7 +140,7 @@ async def lifespan(app: FastAPI): @self.app.middleware("http") async def add_process_time_header(raw_request: Request, 
call_next): start_time = time.monotonic() - raw_request.state.server_start_ts = start_time + raw_request.state.server_arrival_time = start_time response = await call_next(raw_request) return response @@ -201,7 +199,7 @@ async def version(self) -> JSONResponse: async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request): async with self.perf_metrics_keys_lock: - self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_start_ts, raw_request.state.server_first_token_ts)) + self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_ts)) async def perf_metrics(self) -> JSONResponse: if self.perf_metrics_keys is None: @@ -238,27 +236,23 @@ async def perf_metrics(self) -> JSONResponse: raise exc remain_keys = [] - for ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts in self.perf_metrics_keys: + for ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts in self.perf_metrics_keys: gen_perf_metrics = self.server_perf_metrics[gen_server].pop(ctx_request_id, None) if gen_perf_metrics is None: # generation not finished - remain_keys.append((ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts)) + remain_keys.append((ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts)) continue ctx_perf_metrics = self.server_perf_metrics[ctx_server].pop(ctx_request_id, None) return_metrics.append({ "ctx_server": ctx_server, "gen_server": gen_server, - "disagg_server_start_ts": server_start_ts, + "disagg_server_arrival_time": server_arrival_time, "disagg_server_first_token_ts": server_first_token_ts, "ctx_perf_metrics": ctx_perf_metrics, "gen_perf_metrics": gen_perf_metrics}) self.perf_metrics_keys = deque(remain_keys, maxlen=self.perf_metrics_max_requests) - response = { - "server_perf_timestamp_offsets": self.server_perf_ts_offsets, - "perf_metrics": return_metrics - } - return JSONResponse(content=response) + return JSONResponse(content=return_metrics) async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response: @@ -513,28 +507,35 @@ async def send_completion_request(self, url: str, request: CompletionRequest) -> async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> ChatCompletionResponse: return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator) - async def query_perf_ts_offsets(self, session: aiohttp.ClientSession): - async def query_perf_ts_offset(server_url: str) -> Optional[float]: + async def set_steady_clock_offsets(self, session: aiohttp.ClientSession): + STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset" + async def query_steady_clock_offset(server_url: str) -> Optional[float]: try: originate_ts = time.monotonic() - async with session.get(server_url + '/perf_ts_offset') as response: + async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response: destination_ts = time.monotonic() if response.status == 200: response = await response.json() + # Compute the steady clock timestamp difference using the NTP clock synchronization algorithm. 
https://en.wikipedia.org/wiki/Network_Time_Protocol#Clock_synchronization_algorithm receive_ts = response['receive_ts'] transmit_ts = response['transmit_ts'] delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts) - offset = - ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2 + offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2 return delay, offset else: return None, None except Exception: return None + async def set_steady_clock_offset(server_url: str, offset: float) -> Optional[float]: + payload = {"offset": offset} + async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response: + if response.status != 200: + logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned") for server_url in self.ctx_servers + self.gen_servers: - delay, offset = await query_perf_ts_offset(server_url) - self.server_perf_ts_offsets[server_url] = offset + delay, offset = await query_steady_clock_offset(server_url) logger.info(f'Server: {server_url}, delay: {delay} second, offset: {offset} second') - logger.info(f"Server perf metrics timestamp offsets: {self.server_perf_ts_offsets}") + # Negate the offset so that worker servers can adjust their steady block by adding the new offset + await set_steady_clock_offset(server_url, -offset) @classmethod async def check_server_ready(cls, session: aiohttp.ClientSession, server_url: str) -> bool: diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 4d80a6a8152..1adec4834e4 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -10,10 +10,11 @@ from datetime import datetime from http import HTTPStatus from pathlib import Path -from typing import Any, AsyncGenerator, AsyncIterator, List, Optional, Union +from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List, + Optional, Union) import uvicorn -from fastapi import FastAPI, Request +from fastapi import Body, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.routing import Mount @@ -106,6 +107,8 @@ def __init__(self, self.metrics_collector = None self.perf_metrics = None self.perf_metrics_lock = None + # The steady clock offset (in seconds) between this server and the disagg server + self.disagg_server_steady_clock_offset = 0 if self.llm.args.return_perf_metrics: set_prometheus_multiproc_dir() self.metrics_collector = MetricsCollector({ @@ -163,7 +166,7 @@ async def validation_exception_handler(_, exc): @self.app.middleware("http") async def add_process_time_header(raw_request: Request, call_next): start_time = time.monotonic() - raw_request.state.server_start_ts = start_time + raw_request.state.server_arrival_time = start_time response = await call_next(raw_request) return response @@ -215,7 +218,9 @@ def register_routes(self): # TODO: the metrics endpoint only reports iteration stats, not the runtime stats for now self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"]) self.app.add_api_route("/perf_metrics", self.get_perf_metrics, methods=["GET"]) - self.app.add_api_route("/perf_ts_offset", self.get_perf_ts_offset, methods=["GET"]) + self.app.add_api_route("/steady_clock_offset", self.get_steady_clock_offset, methods=["GET"]) + # Called by the disagg server to set the disagg_server_steady_clock_offset + 
self.app.add_api_route("/steady_clock_offset", self.set_steady_clock_offset, methods=["POST"]) # TODO: workaround before ETCD support self.app.add_api_route("/kv_cache_events", self.get_kv_cache_events, methods=["POST"]) self.app.add_api_route("/v1/completions", @@ -315,7 +320,12 @@ async def get_iteration_stats(self) -> JSONResponse: stats.append(stat) return JSONResponse(content=stats) - async def get_perf_ts_offset(self) -> JSONResponse: + async def set_steady_clock_offset(self, offset: Annotated[float, Body(embed=True)]) -> Response: + self.disagg_server_steady_clock_offset = offset + logger.info(f"The steady clock offset between local and disagg server: {offset} second") + return Response(status_code=200) + + async def get_steady_clock_offset(self) -> JSONResponse: receive_ts = time.monotonic() await asyncio.sleep(0.2) transmit_ts = time.monotonic() @@ -338,12 +348,12 @@ async def get_perf_metrics(self) -> JSONResponse: # exclude metrics.iter since it is only meaningful when the request is not finished } metrics_json["timing_metrics"] = { - "server_start_ts": metrics_dict.pop("server_start_ts", None), - "arrival_time": timing_metrics.arrival_time.total_seconds(), - "first_scheduled_time": timing_metrics.first_scheduled_time.total_seconds(), - "first_token_time": timing_metrics.first_token_time.total_seconds(), - "server_first_token_ts":metrics_dict.pop("server_first_token_ts", None), - "last_token_time": timing_metrics.last_token_time.total_seconds(), + "server_arrival_time": metrics_dict.pop("server_arrival_time", None) + self.disagg_server_steady_clock_offset, + "arrival_time": timing_metrics.arrival_time.total_seconds() + self.disagg_server_steady_clock_offset, + "first_scheduled_time": timing_metrics.first_scheduled_time.total_seconds() + self.disagg_server_steady_clock_offset, + "first_token_time": timing_metrics.first_token_time.total_seconds() + self.disagg_server_steady_clock_offset, + "server_first_token_time":metrics_dict.pop("server_first_token_time", None) + self.disagg_server_steady_clock_offset, + "last_token_time": timing_metrics.last_token_time.total_seconds() + self.disagg_server_steady_clock_offset, } metrics_json["kv_cache_metrics"] = { "num_total_allocated_blocks": kv_cache_metrics.num_total_allocated_blocks, @@ -355,8 +365,8 @@ async def get_perf_metrics(self) -> JSONResponse: metrics_json["timing_metrics"].update({ # TODO: move to kv_cache_metrics "kv_cache_size": timing_metrics.kv_cache_size, - "kv_cache_transfer_start": timing_metrics.kv_cache_transfer_start.total_seconds(), - "kv_cache_transfer_end": timing_metrics.kv_cache_transfer_end.total_seconds(), + "kv_cache_transfer_start": timing_metrics.kv_cache_transfer_start.total_seconds() + self.disagg_server_steady_clock_offset, + "kv_cache_transfer_end": timing_metrics.kv_cache_transfer_end.total_seconds() + self.disagg_server_steady_clock_offset, }) if speculative_decoding.total_draft_tokens > 0: metrics_json["speculative_decoding"] = { @@ -389,8 +399,8 @@ async def _extract_metrics(self, res: RequestOutput, raw_request: Optional[Reque "perf_metrics": res.outputs[0].request_perf_metrics } if raw_request: - item["server_start_ts"] = getattr(raw_request.state, "server_start_ts", None) - item["server_first_token_ts"] = getattr(raw_request.state, "server_first_token_ts", None) + item["server_arrival_time"] = getattr(raw_request.state, "server_arrival_time", None) + item["server_first_token_time"] = getattr(raw_request.state, "server_first_token_time", None) if output.disaggregated_params: item["ctx_request_id"] = 
output.disaggregated_params.ctx_request_id if self.perf_metrics is not None: @@ -603,7 +613,7 @@ async def completion_response(promise: RequestOutput, if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": # Include prompt token ids for context-only requests pp_result.prompt_token_ids = response.prompt_token_ids - raw_request.state.server_first_token_ts = time.monotonic() + raw_request.state.server_first_token_time = time.monotonic() await self._extract_metrics(response, raw_request) return pp_result @@ -669,7 +679,7 @@ async def producer(generator: AsyncIterator[Any], idx: int): async def generator_wrapper(generator: AsyncIterator[Any]): first_response = await anext(generator) - raw_request.state.server_first_token_ts = time.monotonic() + raw_request.state.server_first_token_time = time.monotonic() yield first_response async for output in generator: yield output From cc8cd5af32504134c0c5778bac92540b5ca87ca9 Mon Sep 17 00:00:00 2001 From: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Date: Fri, 26 Sep 2025 15:50:04 -0700 Subject: [PATCH 08/11] Address comments Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> --- .../tensorrt_llm/batch_manager/llmRequest.h | 20 +++++------ .../nanobind/batch_manager/bindings.cpp | 10 +++--- .../nanobind/batch_manager/llmRequest.cpp | 3 +- .../nanobind/batch_manager/llmRequest.h | 6 ++-- .../pybind/batch_manager/bindings.cpp | 10 +++--- .../pybind/batch_manager/llmRequest.cpp | 3 +- .../pybind/batch_manager/llmRequest.h | 3 +- .../pyexecutor/executor_request_queue.py | 7 +--- tensorrt_llm/_torch/pyexecutor/llm_request.py | 1 - tensorrt_llm/_torch/pyexecutor/py_executor.py | 17 +++++---- tensorrt_llm/executor/postproc_worker.py | 25 +++++++------ tensorrt_llm/serve/openai_disagg_server.py | 32 ++++++++++------- tensorrt_llm/serve/openai_server.py | 35 ++++++++++++------- tensorrt_llm/serve/responses_utils.py | 5 +++ 14 files changed, 96 insertions(+), 81 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index eb76e7565e7..a91d9760e35 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -140,8 +140,7 @@ class GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, - std::optional globalSteadyClockOffset = std::nullopt) + std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens->size()) , mMaxNewTokens(maxNewTokens) @@ -199,7 +198,6 @@ class GenericLlmRequest , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) , mCacheSaltID(cacheSaltID) - , mGlobalSteadyClockOffset(globalSteadyClockOffset) { if (mEncoderTokens.has_value() || encoderInputFeatures.has_value()) { @@ -227,8 +225,7 @@ class GenericLlmRequest executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1, std::optional languageAdapterUid = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, - std::optional globalSteadyClockOffset = std::nullopt) + std::optional cacheSaltID = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens.size()) , 
mMaxNewTokens(maxNewTokens) @@ -269,7 +266,6 @@ class GenericLlmRequest , mNumReturnSequences(numReturnSequences) , mLanguageAdapterUid(languageAdapterUid) , mCacheSaltID(cacheSaltID) - , mGlobalSteadyClockOffset(globalSteadyClockOffset) { if (mEncoderTokens.has_value()) { @@ -1887,6 +1883,9 @@ class GenericLlmRequest // current position of the prompt tuning table (only used in chunked prefill mode) SizeType32 mPtableCurrentPosition{0}; + // The offset between local steady clock and global steady clock (at rank 0) + inline static std::optional mGlobalSteadyClockOffset{std::nullopt}; + protected: bool mIsStreaming; @@ -2046,9 +2045,6 @@ class GenericLlmRequest // Cache salt id for each request. std::optional mCacheSaltID{std::nullopt}; - // The offset between local steady clock and global steady clock (at rank 0) - std::optional mGlobalSteadyClockOffset; - private: void initialize( VecTokens const& inputTokens, bool outputLogProbs, std::optional arrivalTime = std::nullopt) @@ -2145,6 +2141,7 @@ class GenericLlmRequest if (mReturnPerfMetrics) { + // arrivalTime is assumed to be recorded at the rank 0, so no need to convert it to global clock mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow()); } mStartTime = getSteadyClockNow(); @@ -2252,8 +2249,7 @@ class LlmRequest : public GenericLlmRequest std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, - std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt, - std::optional globalSteadyClockOffset = std::nullopt) + std::optional cacheSaltID = std::nullopt, std::optional arrivalTime = std::nullopt) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), @@ -2284,7 +2280,7 @@ class LlmRequest : public GenericLlmRequest : std::optional>(std::nullopt), numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID, - arrivalTime, globalSteadyClockOffset) + arrivalTime) { } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index e220e5cf910..2f144f3abcf 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -291,8 +291,7 @@ void initBindings(nb::module_& m) std::optional allotted_time_ms, std::optional context_phase_params, std::optional cache_salt_id, - std::optional arrival_time, - std::optional global_steady_clock_offset) + std::optional arrival_time) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -333,7 +332,7 @@ void initBindings(nb::module_& m) encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, - arrival_time, global_steady_clock_offset}; + arrival_time}; }, nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, @@ -359,7 +358,7 @@ void initBindings(nb::module_& m) 
nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt, nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt, nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt, - nb::arg("arrival_time") = std::nullopt, nb::arg("global_steady_clock_offset") = std::nullopt) + nb::arg("arrival_time") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size")) .def(nb::init()) .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"), @@ -383,7 +382,8 @@ void initBindings(nb::module_& m) .def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason")) .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")) - .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors); + .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors) + .def_rw_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset); nb::class_(m, "SequenceSlotManager") .def(nb::init(), nb::arg("max_num_slots"), diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp index ac40d299592..07d630cb3b2 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -127,7 +127,6 @@ std::shared_ptr LlmRequest::toTrtLlm() const mAllottedTimeMs, // mContextPhaseParams, // mCacheSaltID, // - mPerfMetrics.timingMetrics.arrivalTime, // - mGlobalSteadyClockOffset // + mPerfMetrics.timingMetrics.arrivalTime // ); } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h index c81b49257c3..4ea47fdcc8c 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -86,8 +86,7 @@ class LlmRequest : public tb::GenericLlmRequest std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, std::optional cacheSaltID = std::nullopt, - std::optional arrivalTime = std::nullopt, - std::optional globalSteadyClockOffset = std::nullopt) + std::optional arrivalTime = std::nullopt) : Base(requestId, // maxNewTokens, // std::make_shared>(std::move(inputTokens)), // @@ -150,8 +149,7 @@ class LlmRequest : public tb::GenericLlmRequest allottedTimeMs, // contextPhaseParams, // cacheSaltID, // - arrivalTime, // - globalSteadyClockOffset // + arrivalTime // ) { } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 3353e7a33b5..2e628e72999 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -296,8 +296,7 @@ void initBindings(pybind11::module_& m) std::optional allotted_time_ms, std::optional context_phase_params, std::optional cache_salt_id, - std::optional arrival_time, - std::optional globalSteadyClockOffset = std::nullopt) + std::optional arrival_time) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -338,7 +337,7 @@ void initBindings(pybind11::module_& m) encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, 
return_perf_metrics, guided_decoding_params, - language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time, global_steady_clock_offset}; + language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time}; }), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt, @@ -365,7 +364,7 @@ void initBindings(pybind11::module_& m) py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt, py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt, py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt, - nb::arg("arrival_time") = std::nullopt, nb::arg("global_steady_clock_offset") = std::nullopt) + py::arg("arrival_time") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size")) .def(py::init()) .def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"), @@ -389,7 +388,8 @@ void initBindings(pybind11::module_& m) .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason")) .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter")) - .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors); + .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors) + .def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset); py::classh(m, "SequenceSlotManager") .def(py::init(), py::arg("max_num_slots"), diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp index 53a29e32223..bcc9d4bf13f 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp @@ -126,7 +126,6 @@ std::shared_ptr LlmRequest::toTrtLlm() const mAllottedTimeMs, // mContextPhaseParams, // mCacheSaltID, // - mPerfMetrics.timingMetrics.arrivalTime, // - mGlobalSteadyClockOffset // + mPerfMetrics.timingMetrics.arrivalTime // ); } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h index 96d8b906b74..b43fb8dd073 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h @@ -86,8 +86,7 @@ class LlmRequest : public tb::GenericLlmRequest std::optional allottedTimeMs = std::nullopt, std::optional const& contextPhaseParams = std::nullopt, std::optional cacheSaltID = std::nullopt, - std::optional arrivalTime = std::nullopt, - std::optional globalSteadyClockOffset = std::nullopt) + std::optional arrivalTime = std::nullopt) : Base(requestId, // maxNewTokens, // std::make_shared>(std::move(inputTokens)), // diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 1ab252d69cb..c03673f34e7 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -44,8 +44,7 @@ class ExecutorRequestQueue: def __init__(self, dist: Distributed, enable_attention_dp: bool, max_batch_size: int, max_beam_width: int, max_num_active_requests: int, enable_iter_perf_stats: bool, - batch_wait_timeout_ms: float, is_disaggregated: bool, - 
global_steady_clock_offset: float): + batch_wait_timeout_ms: float, is_disaggregated: bool): self.dist = dist self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue() self.waiting_queue: deque[RequestQueueItem] = deque() @@ -61,7 +60,6 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool, self.start_times = {} self.active = True self.batch_wait_timeout_ms = batch_wait_timeout_ms - self.global_steady_clock_offset = global_steady_clock_offset # State tracking self.num_fetch_requests = 0 @@ -613,9 +611,6 @@ def _merge_requests( else: req_with_children = [] for req_item in new_requests: - if self.global_steady_clock_offset: - req_item.request.py_global_steady_clock_offset = self.global_steady_clock_offset - req = executor_request_to_llm_request( req_item.id, req_item.request, req_item.child_req_ids, self._should_exclude_last_generation_logits()) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index a7208a8bd20..1f971095126 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -593,7 +593,6 @@ def executor_request_to_llm_request( context_phase_params=executor_request.context_phase_params, cache_salt_id=executor_request.cache_salt_id, arrival_time=getattr(executor_request, "py_arrival_time", None), - global_steady_clock_offset=getattr(executor_request, "py_global_steady_clock_offset", None), py_multimodal_data=getattr(executor_request, "py_multimodal_data", None)) if child_req_ids: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index e1b23e44926..fb68058ac78 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -12,6 +12,8 @@ import torch +from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds + try: from cuda.bindings import runtime as cudart except ImportError: @@ -166,8 +168,6 @@ def __init__(self, super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = global_mpi_rank() - self.dist = dist - self.global_steady_clock_offset = self._get_global_steady_clock_offset() self.peft_cache_config = peft_cache_config @@ -186,6 +186,7 @@ def __init__(self, self.draft_model_engine = getattr(self.drafter, "draft_model_engine", None) self.guided_decoder = guided_decoder + self.dist = dist self.disable_overlap_scheduler = disable_overlap_scheduler # enqueue and _fetch_new_requests used data @@ -255,6 +256,7 @@ def __init__(self, self.batch_wait_iters_count = 0 # request fetcher initialization + self._set_global_steady_clock_offset() self.executor_request_queue = ExecutorRequestQueue( dist=self.dist, enable_attention_dp=self.enable_attention_dp, @@ -264,7 +266,6 @@ def __init__(self, enable_iter_perf_stats=self.enable_iter_perf_stats, batch_wait_timeout_ms=self.batch_wait_timeout_ms, is_disaggregated=kv_cache_transceiver is not None, - global_steady_clock_offset=self.global_steady_clock_offset, ) self.executor_request_queue.set_exclude_last_generation_logits( self.disable_overlap_scheduler, self.dist.pp_size) @@ -367,20 +368,24 @@ def start_worker(self): self.worker_thread.start() self.worker_started = True - def _get_global_steady_clock_offset(self): + def _set_global_steady_clock_offset(self): assert self.global_rank >= 0, "rank should be >= 0" # Sync all ranks self.dist.barrier() # Immediately take the local steady clock timestamp - local_timestamp = time.monotonic() + local_timestamp 
= get_steady_clock_now_in_seconds() all_rank_timestamps = self.dist.allgather(local_timestamp) if self.global_rank == 0: logger.info( f"global_steady_clock_offset at each rank: {[local_timestamp - ts for ts in all_rank_timestamps]}" ) # Compute the steady clock offset between rank 0 and current rank - return all_rank_timestamps[0] - local_timestamp + global_steady_clock_offset = all_rank_timestamps[0] - local_timestamp + LlmRequest.global_steady_clock_offset = global_steady_clock_offset + logger.info( + f"Setting global_steady_clock_offset: {global_steady_clock_offset} seconds for rank {self.global_rank}" + ) def __enter__(self): return self diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index ea32215679c..10494ad738b 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -144,8 +144,11 @@ async def _handle_input( # Left the result_handler determine the final output dtype. # NOTE: This will change the CompletionOutput._postprocess_result metrics_dict = record.metrics_dict - perf_metrics = record.outputs[0].request_perf_metrics - disaggregated_params = record.outputs[0].disaggregated_params + perf_metrics = None + disaggregated_params = None + if record.outputs: + perf_metrics = record.outputs[0].request_perf_metrics + disaggregated_params = record.outputs[0].disaggregated_params if postproc_params := record.postproc_params: result_handler, args = postproc_params.post_processor, postproc_params.postproc_args args.tokenizer = self._tokenizer @@ -180,15 +183,17 @@ async def handle_single_input(inp: PostprocWorker.Input, client_id = inp.rsp.client_id is_final = inp.rsp.result.is_final if is_llm_response( inp.rsp) else True - res, metrics, perf_metrics, disaggregated_params = await self._handle_input(inp) + res, metrics, perf_metrics, disaggregated_params = await self._handle_input( + inp) batch.append( - PostprocWorker.Output(client_id=client_id, - res=res, - is_final=is_final, - metrics=metrics, - request_perf_metrics=perf_metrics, - disaggregated_params=disaggregated_params, - )) + PostprocWorker.Output( + client_id=client_id, + res=res, + is_final=is_final, + metrics=metrics, + request_perf_metrics=perf_metrics, + disaggregated_params=disaggregated_params, + )) if is_final: self._records.pop(client_id) diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 9719553ca03..dae68613f18 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -4,7 +4,6 @@ import itertools import os import signal -import time import traceback from collections import deque from contextlib import asynccontextmanager @@ -31,6 +30,7 @@ CompletionResponse, DisaggregatedParams, ErrorResponse) +from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds from tensorrt_llm.serve.router import KvCacheAwareRouter, create_router from tensorrt_llm.version import __version__ as VERSION @@ -56,7 +56,7 @@ def __init__(self, self.perf_metrics_max_requests = config.perf_metrics_max_requests if self.perf_metrics_max_requests > 0: # record corresponding keys of context and generation servers for perf metrics - # (ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts) + # (ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time) self.perf_metrics_keys = deque(maxlen=self.perf_metrics_max_requests) self.perf_metrics_keys_lock = asyncio.Lock() # server_url -> 
{ctx_request_id: perf_metrics} @@ -139,8 +139,7 @@ async def lifespan(app: FastAPI): @self.app.middleware("http") async def add_process_time_header(raw_request: Request, call_next): - start_time = time.monotonic() - raw_request.state.server_arrival_time = start_time + raw_request.state.server_arrival_time = get_steady_clock_now_in_seconds() response = await call_next(raw_request) return response @@ -199,7 +198,7 @@ async def version(self) -> JSONResponse: async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request): async with self.perf_metrics_keys_lock: - self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_ts)) + self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_time)) async def perf_metrics(self) -> JSONResponse: if self.perf_metrics_keys is None: @@ -236,18 +235,18 @@ async def perf_metrics(self) -> JSONResponse: raise exc remain_keys = [] - for ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts in self.perf_metrics_keys: + for ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time in self.perf_metrics_keys: gen_perf_metrics = self.server_perf_metrics[gen_server].pop(ctx_request_id, None) if gen_perf_metrics is None: # generation not finished - remain_keys.append((ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts)) + remain_keys.append((ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time)) continue ctx_perf_metrics = self.server_perf_metrics[ctx_server].pop(ctx_request_id, None) return_metrics.append({ "ctx_server": ctx_server, "gen_server": gen_server, "disagg_server_arrival_time": server_arrival_time, - "disagg_server_first_token_ts": server_first_token_ts, + "disagg_server_first_token_time": server_first_token_time, "ctx_perf_metrics": ctx_perf_metrics, "gen_perf_metrics": gen_perf_metrics}) self.perf_metrics_keys = deque(remain_keys, maxlen=self.perf_metrics_max_requests) @@ -340,7 +339,7 @@ async def _merge_streaming_responses(ctx_response, raise TypeError("Invalid request type: {type(gen_req).__name__}") first_response = await anext(gen_response.body_iterator) - raw_request.state.server_first_token_ts = time.monotonic() + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() yield first_response async for chunk in gen_response.body_iterator: yield chunk @@ -420,6 +419,10 @@ async def _merge_streaming_responses(ctx_response, assert isinstance(req, ChatCompletionRequest) gen_response = await self.send_chat_request(gen_server, req) await self._increment_metric("gen_completed_requests") + if need_ctx and self.perf_metrics_keys is not None: + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + asyncio.create_task(self._add_perf_metrics_keys( + ctx_server, gen_server, ctx_request_id, raw_request)) return gen_response finally: if gen_server is not None: @@ -509,11 +512,11 @@ async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> C async def set_steady_clock_offsets(self, session: aiohttp.ClientSession): STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset" - async def query_steady_clock_offset(server_url: str) -> Optional[float]: + async def query_steady_clock_offset(server_url: str) -> tuple[Optional[float], Optional[float]]: try: - 
originate_ts = time.monotonic() + originate_ts = get_steady_clock_now_in_seconds() async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response: - destination_ts = time.monotonic() + destination_ts = get_steady_clock_now_in_seconds() if response.status == 200: response = await response.json() # Compute the steady clock timestamp difference using the NTP clock synchronization algorithm. https://en.wikipedia.org/wiki/Network_Time_Protocol#Clock_synchronization_algorithm @@ -533,8 +536,11 @@ async def set_steady_clock_offset(server_url: str, offset: float) -> Optional[fl logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned") for server_url in self.ctx_servers + self.gen_servers: delay, offset = await query_steady_clock_offset(server_url) + if delay is None or offset is None: + logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment") + continue logger.info(f'Server: {server_url}, delay: {delay} second, offset: {offset} second') - # Negate the offset so that worker servers can adjust their steady block by adding the new offset + # Negate the offset so that worker servers can adjust their steady clock by adding the new offset await set_steady_clock_offset(server_url, -offset) @classmethod diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 1adec4834e4..02ad9012767 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -3,7 +3,6 @@ import os import re import signal -import time import traceback from collections import deque from contextlib import asynccontextmanager @@ -54,6 +53,7 @@ from tensorrt_llm.serve.responses_utils import ConversationHistoryStore from tensorrt_llm.serve.responses_utils import \ create_response as responses_api_create_response +from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds from tensorrt_llm.serve.responses_utils import \ process_streaming_events as responses_api_process_streaming_events from tensorrt_llm.serve.responses_utils import \ @@ -165,8 +165,7 @@ async def validation_exception_handler(_, exc): @self.app.middleware("http") async def add_process_time_header(raw_request: Request, call_next): - start_time = time.monotonic() - raw_request.state.server_arrival_time = start_time + raw_request.state.server_arrival_time = get_steady_clock_now_in_seconds() response = await call_next(raw_request) return response @@ -326,9 +325,9 @@ async def set_steady_clock_offset(self, offset: Annotated[float, Body(embed=True return Response(status_code=200) async def get_steady_clock_offset(self) -> JSONResponse: - receive_ts = time.monotonic() + receive_ts = get_steady_clock_now_in_seconds() await asyncio.sleep(0.2) - transmit_ts = time.monotonic() + transmit_ts = get_steady_clock_now_in_seconds() return JSONResponse(content={"receive_ts": receive_ts, "transmit_ts": transmit_ts}) async def get_perf_metrics(self) -> JSONResponse: @@ -347,12 +346,18 @@ async def get_perf_metrics(self) -> JSONResponse: "last_iter": metrics.last_iter, # exclude metrics.iter since it is only meaningful when the request is not finished } + server_arrival_time = metrics_dict.pop("server_arrival_time", None) + if server_arrival_time is not None: + server_arrival_time += self.disagg_server_steady_clock_offset + server_first_token_time = metrics_dict.pop("server_first_token_time", None) + if server_first_token_time is not None: + server_first_token_time += 
self.disagg_server_steady_clock_offset metrics_json["timing_metrics"] = { - "server_arrival_time": metrics_dict.pop("server_arrival_time", None) + self.disagg_server_steady_clock_offset, + "server_arrival_time": server_arrival_time, "arrival_time": timing_metrics.arrival_time.total_seconds() + self.disagg_server_steady_clock_offset, "first_scheduled_time": timing_metrics.first_scheduled_time.total_seconds() + self.disagg_server_steady_clock_offset, "first_token_time": timing_metrics.first_token_time.total_seconds() + self.disagg_server_steady_clock_offset, - "server_first_token_time":metrics_dict.pop("server_first_token_time", None) + self.disagg_server_steady_clock_offset, + "server_first_token_time": server_first_token_time, "last_token_time": timing_metrics.last_token_time.total_seconds() + self.disagg_server_steady_clock_offset, } metrics_json["kv_cache_metrics"] = { @@ -387,7 +392,7 @@ async def get_kv_cache_events(self) -> JSONResponse: pass return JSONResponse(content=events) - async def _extract_metrics(self, res: RequestOutput, raw_request: Optional[Request] = None): + async def _extract_metrics(self, res: RequestOutput, raw_request: Request): if not res.finished: return if self.metrics_collector: @@ -421,12 +426,15 @@ async def chat_stream_generator( try: if not self.postproc_worker_enabled: post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + first_response = await anext(promise) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + yield first_response async for res in promise: pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) - await self._extract_metrics(res) for pp_res in pp_results: yield pp_res yield "data: [DONE]\n\n" + await self._extract_metrics(res, raw_request) nvtx_mark("generation ends") except: logger.error(traceback.format_exc()) @@ -444,7 +452,8 @@ async def create_chat_response( # Add prompt_tokens_ids to the response if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": chat_response.prompt_token_ids = promise.prompt_token_ids - await self._extract_metrics(promise) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + await self._extract_metrics(promise, raw_request) return chat_response try: @@ -613,7 +622,7 @@ async def completion_response(promise: RequestOutput, if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": # Include prompt token ids for context-only requests pp_result.prompt_token_ids = response.prompt_token_ids - raw_request.state.server_first_token_time = time.monotonic() + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() await self._extract_metrics(response, raw_request) return pp_result @@ -651,9 +660,9 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc pp_result = post_processor(output, args) else: pp_result = output.outputs[0]._postprocess_result - await self._extract_metrics(output, raw_request) for pp_res in pp_result: yield pp_res + await self._extract_metrics(output, raw_request) except: logger.error(traceback.format_exc()) raise @@ -679,7 +688,7 @@ async def producer(generator: AsyncIterator[Any], idx: int): async def generator_wrapper(generator: AsyncIterator[Any]): first_response = await anext(generator) - raw_request.state.server_first_token_time = time.monotonic() + 
raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() yield first_response async for output in generator: yield output diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py index d4a6af268c4..6db16b1c5ca 100644 --- a/tensorrt_llm/serve/responses_utils.py +++ b/tensorrt_llm/serve/responses_utils.py @@ -35,6 +35,7 @@ StreamState, SystemContent, TextContent, ToolDescription, load_harmony_encoding) +from tensorrt_llm.bindings import steady_clock_now from tensorrt_llm.llmapi import SamplingParams from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.logger import logger @@ -78,6 +79,10 @@ def decode_tokens(tokens): return get_encoding().decode(tokens) +def get_steady_clock_now_in_seconds() -> float: + return steady_clock_now().total_seconds() + + def parse_response_input( input_msg: ResponseInputOutputItem, prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] From 7268708b899f56bb87e1c1d8358c43badd75c65c Mon Sep 17 00:00:00 2001 From: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:46:07 -0700 Subject: [PATCH 09/11] address comments Signed-off-by: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com> --- .../tensorrt_llm/batch_manager/llmRequest.h | 18 +++++++++--------- .../batch_manager/cacheFormatter.cpp | 8 ++++---- .../batch_manager/mlaCacheFormatter.cpp | 8 ++++---- tensorrt_llm/serve/openai_disagg_server.py | 10 +++++----- tensorrt_llm/serve/openai_server.py | 4 +++- 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index a91d9760e35..5c76685380a 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1864,6 +1864,15 @@ class GenericLlmRequest return mUseDraftModel; } + // If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock + // time point + [[nodiscard]] TimePoint getSteadyClockNow() const + { + const TimePoint time_point = std::chrono::steady_clock::now(); + + return maybeToGlobalSteadyClock(time_point); + } + RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -2184,15 +2193,6 @@ class GenericLlmRequest return time_point; } } - - // If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock - // time point - TimePoint getSteadyClockNow() const - { - const TimePoint time_point = std::chrono::steady_clock::now(); - - return maybeToGlobalSteadyClock(time_point); - } }; class LlmRequest : public GenericLlmRequest diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 3663fc05a53..e9e3bf67f96 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -390,7 +390,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); TLLM_CHECK(connections.size() > (processIdx / peerDuplicateHeadFactor)); TLLM_CHECK(outputSplitCaches.size() > (processIdx / peerDuplicateHeadFactor)); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); size_t ppDomainSize = targetInfo.mDomainPPSize; size_t bufferTpRank = (processIdx / ppDomainSize) / peerDuplicateHeadFactor; @@ -437,7 +437,7 @@ void 
CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { @@ -753,7 +753,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); TLLM_CHECK(pickUpConnections.size() > processIdx); TLLM_CHECK(recvSplitCaches.size() > processIdx); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); size_t size = 0; if (processIdx >= remainNoCoverTargetNum) @@ -794,7 +794,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 6e3093cd452..fc840cadbda 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -244,7 +244,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses NVTX3_SCOPED_RANGE(sendBufferFun); TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize); if (cacheIdx < bufferCoverTargetNum) { @@ -277,7 +277,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses remainSendSize -= sendSize; } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { @@ -451,7 +451,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s { NVTX3_SCOPED_RANGE(recvBufferFun); TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); size_t size = 0; if (processIdx >= remainNoCoverTargetNum) { @@ -484,7 +484,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s remainRecvSize -= recvSize; } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index dae68613f18..6e50c825d2c 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -518,18 +518,18 @@ async def query_steady_clock_offset(server_url: str) -> tuple[Optional[float], O async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response: destination_ts = get_steady_clock_now_in_seconds() if response.status == 200: - response = await response.json() + response_content = await response.json() # Compute the steady clock timestamp difference using the NTP clock synchronization algorithm. 
https://en.wikipedia.org/wiki/Network_Time_Protocol#Clock_synchronization_algorithm - receive_ts = response['receive_ts'] - transmit_ts = response['transmit_ts'] + receive_ts = response_content['receive_ts'] + transmit_ts = response_content['transmit_ts'] delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts) offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2 return delay, offset else: return None, None except Exception: - return None - async def set_steady_clock_offset(server_url: str, offset: float) -> Optional[float]: + return None, None + async def set_steady_clock_offset(server_url: str, offset: float) -> None: payload = {"offset": offset} async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response: if response.status != 200: diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 02ad9012767..3fc3b9c6a51 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -428,7 +428,9 @@ async def chat_stream_generator( post_processor, args = postproc_params.post_processor, postproc_params.postproc_args first_response = await anext(promise) raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() - yield first_response + pp_results = first_response.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(first_response, args) + for pp_res in pp_results: + yield pp_res async for res in promise: pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) for pp_res in pp_results: From d0955d4bad628b94ff2e6e3fe11c3ef9addaf1ad Mon Sep 17 00:00:00 2001 From: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:56:34 -0700 Subject: [PATCH 10/11] Add tests Signed-off-by: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com> --- .../defs/disaggregated/test_disaggregated.py | 125 +++++++++++++++--- 1 file changed, 106 insertions(+), 19 deletions(-) diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 1a0f881a32e..33ba8ff7fa9 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -37,6 +37,109 @@ def cleanup_output_files(): pass +def validate_timing_metrics(perf_metrics_item, request_context=""): + """ + Helper function to validate timing metrics relationships. 
+ + Args: + perf_metrics_item: A single performance metrics item from the /perf_metrics endpoint + request_context: String context for error messages (e.g., "request 1", "streaming") + """ + # Validate basic structure + required_keys = [ + "ctx_server", "gen_server", "ctx_perf_metrics", "gen_perf_metrics", + "disagg_server_arrival_time", "disagg_server_first_token_time" + ] + for key in required_keys: + assert key in perf_metrics_item, f"Missing key: {key} in {request_context}" + + assert perf_metrics_item["ctx_perf_metrics"][ + "ctx_request_id"] == perf_metrics_item["gen_perf_metrics"][ + "ctx_request_id"] + + # Extract timing metrics + ctx_metrics = perf_metrics_item["ctx_perf_metrics"]["perf_metrics"][ + "timing_metrics"] + gen_metrics = perf_metrics_item["gen_perf_metrics"]["perf_metrics"][ + "timing_metrics"] + disagg_arrival = perf_metrics_item["disagg_server_arrival_time"] + disagg_first_token = perf_metrics_item["disagg_server_first_token_time"] + + # Validate disaggregated server timing metrics + assert disagg_arrival is not None, f"disagg_server_arrival_time is None in {request_context}" + assert disagg_first_token is not None, f"disagg_server_first_token_time is None in {request_context}" + assert isinstance( + disagg_arrival, + (int, float + )), f"disagg_server_arrival_time is not numeric in {request_context}" + assert isinstance( + disagg_first_token, (int, float) + ), f"disagg_server_first_token_time is not numeric in {request_context}" + assert disagg_arrival > 0, f"disagg_server_arrival_time is not positive in {request_context}" + assert disagg_first_token > 0, f"disagg_server_first_token_time is not positive in {request_context}" + assert disagg_arrival <= disagg_first_token, f"disagg_server_arrival_time > disagg_server_first_token_time in {request_context}" + + # Validate server-level timing metrics for context server + ctx_server_arrival = ctx_metrics.get("server_arrival_time") + ctx_server_first_token = ctx_metrics.get("server_first_token_time") + assert ctx_server_arrival is not None, f"ctx server_arrival_time is None in {request_context}" + assert ctx_server_first_token is not None, f"ctx server_first_token_time is None in {request_context}" + assert isinstance( + ctx_server_arrival, + (int, + float)), f"ctx server_arrival_time is not numeric in {request_context}" + assert isinstance( + ctx_server_first_token, + (int, float + )), f"ctx server_first_token_time is not numeric in {request_context}" + assert ctx_server_arrival <= ctx_server_first_token, f"ctx server_arrival_time > server_first_token_time in {request_context}" + assert ctx_metrics["last_token_time"] - ctx_server_first_token < 1e-3 + + # Validate server-level timing metrics for generation server + gen_server_arrival = gen_metrics.get("server_arrival_time") + gen_server_first_token = gen_metrics.get("server_first_token_time") + assert gen_server_arrival is not None, f"gen server_arrival_time is None in {request_context}" + assert gen_server_first_token is not None, f"gen server_first_token_time is None in {request_context}" + assert isinstance( + gen_server_arrival, + (int, + float)), f"gen server_arrival_time is not numeric in {request_context}" + assert isinstance( + gen_server_first_token, + (int, float + )), f"gen server_first_token_time is not numeric in {request_context}" + assert gen_server_arrival <= gen_server_first_token, f"gen server_arrival_time > server_first_token_time in {request_context}" + + # Validate timing relationships between different levels + # Disaggregated server should receive 
request before individual servers + assert disagg_arrival <= ctx_server_arrival, f"disagg_arrival > ctx_server_arrival in {request_context}" + assert disagg_arrival <= gen_server_arrival, f"disagg_arrival > gen_server_arrival in {request_context}" + + # Context should complete before generation starts + assert ctx_server_first_token <= gen_server_arrival, f"ctx_server_first_token > gen_server_arrival in {request_context}" + + # Validate internal timing consistency + ctx_arrival_time = ctx_metrics["arrival_time"] + ctx_first_token_time = ctx_metrics["first_token_time"] + gen_arrival_time = gen_metrics["arrival_time"] + gen_first_token_time = gen_metrics["first_token_time"] + + assert ctx_arrival_time <= ctx_first_token_time, f"ctx arrival_time > first_token_time in {request_context}" + assert gen_arrival_time <= gen_first_token_time, f"gen arrival_time > first_token_time in {request_context}" + + # Test KV cache transfer timing (if present) + if "kv_cache_transfer_start" in gen_metrics and "kv_cache_transfer_end" in gen_metrics: + kv_start = gen_metrics["kv_cache_transfer_start"] + kv_end = gen_metrics["kv_cache_transfer_end"] + assert gen_metrics["kv_cache_size"] > 0 + assert kv_start <= kv_end, f"kv_cache_transfer_start > kv_cache_transfer_end in {request_context}" + assert gen_arrival_time <= kv_start, f"gen_arrival_time > kv_cache_transfer_start in {request_context}" + assert kv_end <= gen_metrics[ + "first_scheduled_time"], f"kv_cache_transfer_end > first_scheduled_time in {request_context}" + + return True + + def get_disagg_server_url_from_cfg(config_file: str) -> str: with open(config_file, 'r') as file: config = yaml.safe_load(file) @@ -557,25 +660,9 @@ def extra_endpoints_test(server_url: str): perf_metrics = json.load(resp) assert len(perf_metrics) > 0 item = perf_metrics[0] - assert "ctx_server" in item - assert "gen_server" in item - assert "ctx_perf_metrics" in item - assert "gen_perf_metrics" in item - assert item["ctx_perf_metrics"]["ctx_request_id"] == item[ - "gen_perf_metrics"]["ctx_request_id"] - ctx_metrics = item["ctx_perf_metrics"]["perf_metrics"]["timing_metrics"] - gen_metrics = item["gen_perf_metrics"]["perf_metrics"]["timing_metrics"] - # only one token is generated in ctx - assert ctx_metrics["last_token_time"] - ctx_metrics[ - "first_token_time"] < 1e-3 - assert ctx_metrics["last_token_time"] < gen_metrics["arrival_time"] - assert gen_metrics["kv_cache_size"] > 0 - assert gen_metrics["arrival_time"] < gen_metrics[ - "kv_cache_transfer_start"] - assert gen_metrics["kv_cache_transfer_start"] < gen_metrics[ - "kv_cache_transfer_end"] - assert gen_metrics["kv_cache_transfer_end"] < gen_metrics[ - "first_scheduled_time"] + + # Use helper function to validate all timing metrics comprehensively + validate_timing_metrics(item, "perf_metrics test") run_disaggregated_test(disaggregated_example_root, "perf_metrics", From 034bda5564728a69468ac046b4ddce16067a12de Mon Sep 17 00:00:00 2001 From: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com> Date: Mon, 29 Sep 2025 23:36:30 -0700 Subject: [PATCH 11/11] Fix ut Signed-off-by: nv-yilinf <206948969+nv-yilinf@users.noreply.github.com> --- .../tensorrt_llm/batch_manager/llmRequest.h | 6 ++-- tensorrt_llm/serve/openai_disagg_server.py | 9 ++---- tensorrt_llm/serve/openai_server.py | 19 +++++------ tensorrt_llm/serve/responses_utils.py | 32 +++++++++++++++++++ tests/unittest/llmapi/apps/openai_server.py | 3 +- 5 files changed, 48 insertions(+), 21 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h 
b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 5c76685380a..670dc0df70d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1672,8 +1672,8 @@ class GenericLlmRequest { return false; } - auto const currentTime = getSteadyClockNow(); - auto const elapsed = (std::chrono::duration_cast(currentTime - mStartTime)); + auto const currentTime = std::chrono::steady_clock::now(); + auto const elapsed = (std::chrono::duration_cast(currentTime - mStartTime)); TLLM_LOG_DEBUG("Checked timeOut for request %ld with allotted Time %ld after time %ld and got %d", mRequestId, mAllottedTimeMs->count(), elapsed.count(), (elapsed >= mAllottedTimeMs)); @@ -2153,7 +2153,7 @@ class GenericLlmRequest // arrivalTime is assumed to be recorded at the rank 0, so no need to convert it to global clock mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow()); } - mStartTime = getSteadyClockNow(); + mStartTime = std::chrono::steady_clock::now(); } TensorPtr createListTensor(std::list const& wordsList) diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 6e50c825d2c..644e5133f01 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -30,7 +30,8 @@ CompletionResponse, DisaggregatedParams, ErrorResponse) -from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds +from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware, + get_steady_clock_now_in_seconds) from tensorrt_llm.serve.router import KvCacheAwareRouter, create_router from tensorrt_llm.version import __version__ as VERSION @@ -137,11 +138,7 @@ async def lifespan(app: FastAPI): self.app = FastAPI(lifespan=lifespan) - @self.app.middleware("http") - async def add_process_time_header(raw_request: Request, call_next): - raw_request.state.server_arrival_time = get_steady_clock_now_in_seconds() - response = await call_next(raw_request) - return response + self.app.add_middleware(ServerArrivalTimeMiddleware) @self.app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 3fc3b9c6a51..74a628cac24 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -50,7 +50,8 @@ chat_harmony_post_processor, chat_harmony_streaming_post_processor, chat_response_post_processor, chat_stream_post_processor, completion_response_post_processor, completion_stream_post_processor) -from tensorrt_llm.serve.responses_utils import ConversationHistoryStore +from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore, + ServerArrivalTimeMiddleware) from tensorrt_llm.serve.responses_utils import \ create_response as responses_api_create_response from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds @@ -163,11 +164,7 @@ async def validation_exception_handler(_, exc): assert isinstance(self.llm, MultimodalEncoder), "llm must be a MultimodalEncoder for multimodal encoder" self.register_mm_encoder_routes() - @self.app.middleware("http") - async def add_process_time_header(raw_request: Request, call_next): - raw_request.state.server_arrival_time = get_steady_clock_now_in_seconds() - response = await call_next(raw_request) - return response + self.app.add_middleware(ServerArrivalTimeMiddleware) async def await_disconnected(self, raw_request: Request, 
@@ -271,7 +268,7 @@ def register_mm_encoder_routes(self):
     async def health(self) -> Response:
         return Response(status_code=200)
-    async def health_generate(self) -> Response:
+    async def health_generate(self, raw_request: Request) -> Response:
         """Health check that performs a minimal generation."""
         try:
             # Create a minimal chat request
@@ -283,10 +280,8 @@ async def health_generate(self) -> Response:
                 temperature=0.0  # Deterministic output
             )
-            mock_request = None
-
             # Call the chat completion logic
-            response = await self.openai_chat(health_request, mock_request)
+            response = await self.openai_chat(health_request, raw_request)
             # Check if the response indicates success (status code 200)
             if response.status_code == 200:
@@ -302,7 +297,7 @@ async def health_generate(self) -> Response:
                 return Response(status_code=500, content="Generation health check failed")
         except Exception as e:
-            logger.error(f"Health generate check encountered exception: {e}", exc_info=True)
+            logger.error(f"Health generate check encountered exception: {e}")
             return Response(status_code=500, content=f"Generation health check failed: {str(e)}")
     async def version(self) -> JSONResponse:
@@ -431,6 +426,8 @@ async def chat_stream_generator(
             pp_results = first_response.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(first_response, args)
             for pp_res in pp_results:
                 yield pp_res
+            # Make sure we handle the case where there is only one response
+            res = first_response
             async for res in promise:
                 pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
                 for pp_res in pp_results:
diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py
index 6db16b1c5ca..ab8fdae47b5 100644
--- a/tensorrt_llm/serve/responses_utils.py
+++ b/tensorrt_llm/serve/responses_utils.py
@@ -851,3 +851,35 @@ def _send_event(event: OpenAIBaseModel):
             sequence_number=-1,
             response=final_response.model_dump(),
         ))
+
+
+class ServerArrivalTimeMiddleware:
+    """
+    Custom ASGI middleware to track server arrival time.
+
+    We implement this as a pure ASGI middleware instead of using FastAPI's
+    @app.middleware("http") decorator because the decorator internally uses
+    BaseHTTPMiddleware, which wraps the ASGI `receive` callable. This wrapping
+    breaks Request.is_disconnected(): the wrapped receive doesn't
+    properly forward http.disconnect events while the middleware is waiting in
+    call_next(), preventing detection of client disconnections during long-running
+    non-streaming requests.
+
+    By implementing a pure ASGI middleware, we pass through the original receive/send
+    callables unchanged, preserving the ability to detect client disconnections.
+
+    See: https://github.com/encode/starlette/discussions/2094
+    """
+
+    def __init__(self, app):
+        self.app = app
+
+    async def __call__(self, scope, receive, send):
+        if scope["type"] == "http":
+            # Add arrival time to scope
+            scope["state"] = {}
+            scope["state"][
+                "server_arrival_time"] = get_steady_clock_now_in_seconds()
+
+        # Pass through the original receive/send - no wrapping!
+ await self.app(scope, receive, send) diff --git a/tests/unittest/llmapi/apps/openai_server.py b/tests/unittest/llmapi/apps/openai_server.py index ca98f7e1ece..39c9988d9f5 100644 --- a/tests/unittest/llmapi/apps/openai_server.py +++ b/tests/unittest/llmapi/apps/openai_server.py @@ -29,7 +29,8 @@ def __init__(self, extra_config: Optional[dict] = None) -> None: self.host = host self.port = port if port is not None else find_free_port() - self.rank = rank if rank != -1 else os.environ.get("SLURM_PROCID", 0) + self.rank = rank if rank != -1 else int( + os.environ.get("SLURM_PROCID", 0)) self.extra_config_file = None args = ["--host", f"{self.host}", "--port", f"{self.port}"] if cli_args:
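As a usage sketch (not part of the patches above, and with a hypothetical /ttft_demo route), this is how the arrival timestamp recorded by ServerArrivalTimeMiddleware can be read back inside a handler: Starlette backs request.state with scope["state"], so the value written by the pure ASGI middleware is visible without any extra plumbing.

```python
# Illustrative sketch only: the /ttft_demo route and handler are hypothetical,
# not part of this patch series.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware,
                                                get_steady_clock_now_in_seconds)

app = FastAPI()
# Pure ASGI middleware: stamps scope["state"]["server_arrival_time"] on every
# HTTP request without wrapping the receive/send callables.
app.add_middleware(ServerArrivalTimeMiddleware)


@app.get("/ttft_demo")
async def ttft_demo(raw_request: Request) -> JSONResponse:
    # Starlette exposes scope["state"] as request.state, so the timestamp the
    # middleware stored is readable here.
    arrival = getattr(raw_request.state, "server_arrival_time", None)
    now = get_steady_clock_now_in_seconds()
    return JSONResponse(content={
        "server_arrival_time": arrival,
        "elapsed_since_arrival_s": None if arrival is None else now - arrival,
    })
```

The getattr default keeps the handler tolerant of deployments where the middleware is not installed, mirroring how the servers in this series fall back to None when a timestamp is absent.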