diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 275bc75721a..670dc0df70d 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -101,6 +101,7 @@ class GenericLlmRequest using RequestPtr = std::shared_ptr; using MillisecondsType = std::chrono::milliseconds; using TimePoint = std::chrono::time_point; + using Duration = std::chrono::time_point::duration; using CacheSaltIDType = runtime::CacheSaltIDType; GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr const& inputTokens, @@ -1255,7 +1256,7 @@ class GenericLlmRequest { if (mPerfMetrics.timingMetrics.firstScheduledTime == executor::RequestPerfMetrics::TimePoint{}) { - mPerfMetrics.timingMetrics.firstScheduledTime = std::chrono::steady_clock::now(); + mPerfMetrics.timingMetrics.firstScheduledTime = getSteadyClockNow(); } } @@ -1689,22 +1690,22 @@ class GenericLlmRequest mDecodingIter = iter; } - void setKvCacheTransferStart(std::chrono::time_point const& time) + void setKvCacheTransferStart(TimePoint const& time) { - mPerfMetrics.timingMetrics.kvCacheTransferStart = time; + mPerfMetrics.timingMetrics.kvCacheTransferStart = maybeToGlobalSteadyClock(time); } - void setKvCacheTransferEnd(std::chrono::time_point const& time) + void setKvCacheTransferEnd(TimePoint const& time) { - mPerfMetrics.timingMetrics.kvCacheTransferEnd = time; + mPerfMetrics.timingMetrics.kvCacheTransferEnd = maybeToGlobalSteadyClock(time); } - std::chrono::time_point getKvCacheTransferStart() + TimePoint getKvCacheTransferStart() { return mPerfMetrics.timingMetrics.kvCacheTransferStart; } - std::chrono::time_point getKvCacheTransferEnd() + TimePoint getKvCacheTransferEnd() { return mPerfMetrics.timingMetrics.kvCacheTransferEnd; } @@ -1788,7 +1789,7 @@ class GenericLlmRequest if (finishReason == executor::FinishReason::kTIMED_OUT) { TLLM_LOG_DEBUG("Request %ld finished by timeout after %f sec", mRequestId, - std::chrono::duration(std::chrono::steady_clock::now() - mStartTime).count()); + std::chrono::duration(getSteadyClockNow() - mStartTime).count()); } if (finishReason == executor::FinishReason::kCANCELLED) { @@ -1826,10 +1827,9 @@ class GenericLlmRequest void updatePerfMetrics(executor::IterationType iter) { - auto const currentTokenTime = std::chrono::steady_clock::now(); - if (!mPerfMetrics.firstIter) { + auto const currentTokenTime = getSteadyClockNow(); mPerfMetrics.firstIter = iter; mPerfMetrics.timingMetrics.firstTokenTime = currentTokenTime; } @@ -1838,6 +1838,7 @@ class GenericLlmRequest if (isFinished()) { + auto const currentTokenTime = getSteadyClockNow(); mPerfMetrics.lastIter = iter; mPerfMetrics.timingMetrics.lastTokenTime = currentTokenTime; } @@ -1863,6 +1864,15 @@ class GenericLlmRequest return mUseDraftModel; } + // If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock + // time point + [[nodiscard]] TimePoint getSteadyClockNow() const + { + const TimePoint time_point = std::chrono::steady_clock::now(); + + return maybeToGlobalSteadyClock(time_point); + } + RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -1882,6 +1892,9 @@ class GenericLlmRequest // current position of the prompt tuning table (only used in chunked prefill mode) SizeType32 mPtableCurrentPosition{0}; + // The offset between local steady clock and global steady clock (at rank 0) + inline static std::optional 
mGlobalSteadyClockOffset{std::nullopt}; + protected: bool mIsStreaming; @@ -2137,7 +2150,8 @@ class GenericLlmRequest if (mReturnPerfMetrics) { - mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(std::chrono::steady_clock::now()); + // arrivalTime is assumed to be recorded at the rank 0, so no need to convert it to global clock + mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow()); } mStartTime = std::chrono::steady_clock::now(); } @@ -2167,6 +2181,18 @@ class GenericLlmRequest return tensor; } + + TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point) const + { + if (mGlobalSteadyClockOffset.has_value()) + { + return time_point + *mGlobalSteadyClockOffset; + } + else + { + return time_point; + } + } }; class LlmRequest : public GenericLlmRequest diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index 3663fc05a53..e9e3bf67f96 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -390,7 +390,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); TLLM_CHECK(connections.size() > (processIdx / peerDuplicateHeadFactor)); TLLM_CHECK(outputSplitCaches.size() > (processIdx / peerDuplicateHeadFactor)); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); size_t ppDomainSize = targetInfo.mDomainPPSize; size_t bufferTpRank = (processIdx / ppDomainSize) / peerDuplicateHeadFactor; @@ -437,7 +437,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { @@ -753,7 +753,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); TLLM_CHECK(pickUpConnections.size() > processIdx); TLLM_CHECK(recvSplitCaches.size() > processIdx); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); size_t size = 0; if (processIdx >= remainNoCoverTargetNum) @@ -794,7 +794,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 6e3093cd452..fc840cadbda 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -244,7 +244,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses NVTX3_SCOPED_RANGE(sendBufferFun); TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize); if (cacheIdx < bufferCoverTargetNum) { @@ -277,7 +277,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses remainSendSize -= sendSize; } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { @@ -451,7 +451,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s { 
NVTX3_SCOPED_RANGE(recvBufferFun); TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); - auto startTime = std::chrono::steady_clock::now(); + auto startTime = llmRequest.getSteadyClockNow(); size_t size = 0; if (processIdx >= remainNoCoverTargetNum) { @@ -484,7 +484,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s remainRecvSize -= recvSize; } } - auto endTime = std::chrono::steady_clock::now(); + auto endTime = llmRequest.getSteadyClockNow(); double delay = 0.0; if (recordDelay) { diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index e0325b51c8a..2f144f3abcf 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -382,7 +382,8 @@ void initBindings(nb::module_& m) .def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason")) .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")) - .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors); + .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors) + .def_rw_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset); nb::class_(m, "SequenceSlotManager") .def(nb::init(), nb::arg("max_num_slots"), diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 9bcd22e39e4..2e628e72999 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -388,7 +388,8 @@ void initBindings(pybind11::module_& m) .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason")) .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter")) - .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors); + .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors) + .def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset); py::classh(m, "SequenceSlotManager") .def(py::init(), py::arg("max_num_slots"), diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f3754f897db..ceb6572ef0d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -12,6 +12,8 @@ import torch +from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds + try: from cuda.bindings import runtime as cudart except ImportError: @@ -254,6 +256,7 @@ def __init__(self, self.batch_wait_iters_count = 0 # request fetcher initialization + self._set_global_steady_clock_offset() self.executor_request_queue = ExecutorRequestQueue( dist=self.dist, enable_attention_dp=self.enable_attention_dp, @@ -365,6 +368,25 @@ def start_worker(self): self.worker_thread.start() self.worker_started = True + def _set_global_steady_clock_offset(self): + assert self.global_rank >= 0, "rank should be >= 0" + + # Sync all ranks + self.dist.barrier() + # Immediately take the local steady clock timestamp + local_timestamp = get_steady_clock_now_in_seconds() + all_rank_timestamps = self.dist.allgather(local_timestamp) + if self.global_rank == 0: + logger.info( + f"global_steady_clock_offset at each rank: {[local_timestamp - ts for ts in all_rank_timestamps]}" + ) + # Compute the steady 
clock offset between rank 0 and current rank + global_steady_clock_offset = all_rank_timestamps[0] - local_timestamp + LlmRequest.global_steady_clock_offset = global_steady_clock_offset + logger.info( + f"Setting global_steady_clock_offset: {global_steady_clock_offset} seconds for rank {self.global_rank}" + ) + def __enter__(self): return self @@ -1962,7 +1984,8 @@ def _handle_responses(self): request) > 0 else [] request.decoding_iter = request.py_decoding_iter - if request.return_perf_metrics: + # Skip active requests that are not scheduled + if request.return_perf_metrics and request.py_decoding_iter >= 1: request.update_perf_metrics(self.model_engine.iter_counter) request_done = False diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 55bf7839f4b..10494ad738b 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -69,6 +69,8 @@ class Output(NamedTuple): is_final: bool error: str = "" metrics: Optional[dict[str, float]] = None + request_perf_metrics: Any = None + disaggregated_params: Any = None def __init__( self, @@ -142,6 +144,11 @@ async def _handle_input( # Left the result_handler determine the final output dtype. # NOTE: This will change the CompletionOutput._postprocess_result metrics_dict = record.metrics_dict + perf_metrics = None + disaggregated_params = None + if record.outputs: + perf_metrics = record.outputs[0].request_perf_metrics + disaggregated_params = record.outputs[0].disaggregated_params if postproc_params := record.postproc_params: result_handler, args = postproc_params.post_processor, postproc_params.postproc_args args.tokenizer = self._tokenizer @@ -153,7 +160,7 @@ async def _handle_input( # TODO: Keep only the diff token_ids and text in streaming mode when # result_handler is not set - return out, metrics_dict + return out, metrics_dict, perf_metrics, disaggregated_params async def _batched_put(self): ''' Batched IPC send. 
''' @@ -176,12 +183,17 @@ async def handle_single_input(inp: PostprocWorker.Input, client_id = inp.rsp.client_id is_final = inp.rsp.result.is_final if is_llm_response( inp.rsp) else True - res, metrics = await self._handle_input(inp) + res, metrics, perf_metrics, disaggregated_params = await self._handle_input( + inp) batch.append( - PostprocWorker.Output(client_id=client_id, - res=res, - is_final=is_final, - metrics=metrics)) + PostprocWorker.Output( + client_id=client_id, + res=res, + is_final=is_final, + metrics=metrics, + request_perf_metrics=perf_metrics, + disaggregated_params=disaggregated_params, + )) if is_final: self._records.pop(client_id) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index d19a8368297..4923b776487 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -336,6 +336,18 @@ def _handle_response(self, self._outputs[0] = response.res else: self._outputs[0]._postprocess_result = response.res + + self._outputs[ + 0].request_perf_metrics = response.request_perf_metrics + if not self._outputs[0].disaggregated_params: + disaggregated_params = response.disaggregated_params + + # Generation only response has no disaggregated_params attached + if not disaggregated_params: + disaggregated_params = self.disaggregated_params + + self._outputs[0].disaggregated_params = disaggregated_params + if response.metrics: self.metrics_dict = response.metrics diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 2b5f7dc59c0..644e5133f01 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -12,7 +12,7 @@ import aiohttp import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR @@ -30,6 +30,8 @@ CompletionResponse, DisaggregatedParams, ErrorResponse) +from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware, + get_steady_clock_now_in_seconds) from tensorrt_llm.serve.router import KvCacheAwareRouter, create_router from tensorrt_llm.version import __version__ as VERSION @@ -55,11 +57,12 @@ def __init__(self, self.perf_metrics_max_requests = config.perf_metrics_max_requests if self.perf_metrics_max_requests > 0: # record corresponding keys of context and generation servers for perf metrics - # (ctx_server, gen_server, ctx_request_id) + # (ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time) self.perf_metrics_keys = deque(maxlen=self.perf_metrics_max_requests) self.perf_metrics_keys_lock = asyncio.Lock() - # server_key -> {ctx_request_id: perf_metrics} + # server_url -> {ctx_request_id: perf_metrics} self.server_perf_metrics: dict[str, dict[int, dict]] = {} + else: self.perf_metrics_keys = None self.perf_metrics_keys_lock = None @@ -104,6 +107,9 @@ async def lifespan(app: FastAPI): logger.info("Waiting for context and generation servers to be ready") await self.wait_for_servers_ready(server_start_timeout_secs) + if self.perf_metrics_max_requests > 0: + await self.set_steady_clock_offsets(self.session) + if self.metadata_server: logger.info("Starting server monitoring via metadata service") await self.ctx_router.start_server_monitoring(metadata_server_cfg.refresh_interval) @@ -132,6 +138,8 @@ async def lifespan(app: FastAPI): self.app = 
FastAPI(lifespan=lifespan) + self.app.add_middleware(ServerArrivalTimeMiddleware) + @self.app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): return JSONResponse(status_code=400, content={"error": str(exc)}) @@ -185,9 +193,9 @@ async def version(self) -> JSONResponse: ver = {"version": VERSION} return JSONResponse(content=ver) - async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int): + async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request): async with self.perf_metrics_keys_lock: - self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id)) + self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_time)) async def perf_metrics(self) -> JSONResponse: if self.perf_metrics_keys is None: @@ -224,50 +232,26 @@ async def perf_metrics(self) -> JSONResponse: raise exc remain_keys = [] - for ctx_server, gen_server, ctx_request_id in self.perf_metrics_keys: + for ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time in self.perf_metrics_keys: gen_perf_metrics = self.server_perf_metrics[gen_server].pop(ctx_request_id, None) if gen_perf_metrics is None: # generation not finished - remain_keys.append((ctx_server, gen_server, ctx_request_id)) + remain_keys.append((ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_time)) continue ctx_perf_metrics = self.server_perf_metrics[ctx_server].pop(ctx_request_id, None) return_metrics.append({ "ctx_server": ctx_server, "gen_server": gen_server, + "disagg_server_arrival_time": server_arrival_time, + "disagg_server_first_token_time": server_first_token_time, "ctx_perf_metrics": ctx_perf_metrics, "gen_perf_metrics": gen_perf_metrics}) self.perf_metrics_keys = deque(remain_keys, maxlen=self.perf_metrics_max_requests) return JSONResponse(content=return_metrics) - async def merge_streaming_responses(self, ctx_response, - gen_server: str, - gen_req: Union[CompletionRequest, ChatCompletionRequest]): - try: - if ctx_response is not None and len(ctx_response.choices) != 1: - raise ValueError("Context server did not return a single choice. 
This is not expected") - - #If request finished after first token not due to length, return right away and skip gen - if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]: - yield "data: [DONE]\n\n".encode('utf-8') - else: - # Then yield the generation responses - await self._increment_metric("gen_total_requests") - if isinstance(gen_req, CompletionRequest): - gen_response = await self.send_completion_request(gen_server, gen_req) - elif isinstance(gen_req, ChatCompletionRequest): - gen_response = await self.send_chat_request(gen_server, gen_req) - else: - raise TypeError("Invalid request type: {type(gen_req).__name__}") - - async for chunk in gen_response.body_iterator: - yield chunk - await self._increment_metric("gen_completed_requests") - - finally: - await self.gen_router.finish_request(gen_req) - async def openai_completion(self, req: CompletionRequest) -> Response: + async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response: try: if not isinstance(req.prompt, str): # Check if it's a list and contains integers @@ -276,15 +260,15 @@ async def openai_completion(self, req: CompletionRequest) -> Response: elif not isinstance(req.prompt, list) or not all(isinstance(x, int) for x in req.prompt): raise ValueError("Disaggregated server currently only supports single string prompt or list of integers in request") - return await self._send_disagg_request(req) + return await self._send_disagg_request(req, raw_request) except Exception as e: await self._handle_exception(e) - async def openai_chat_completion(self, req: ChatCompletionRequest) -> Response: + async def openai_chat_completion(self, req: ChatCompletionRequest, raw_request: Request) -> Response: try: - return await self._send_disagg_request(req) + return await self._send_disagg_request(req, raw_request) except Exception as e: await self._handle_exception(e) @@ -326,9 +310,44 @@ async def _send_context_request(self, ctx_server: str, ctx_req: Union[Completion return ctx_response - async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletionRequest]): + async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletionRequest], raw_request: Request): + ctx_server = None gen_server = None + ctx_request_id = None need_ctx = False + + async def _merge_streaming_responses(ctx_response, + gen_req: Union[CompletionRequest, ChatCompletionRequest]): + try: + if ctx_response is not None and len(ctx_response.choices) != 1: + raise ValueError("Context server did not return a single choice. 
This is not expected") + + #If request finished after first token not due to length, return right away and skip gen + if ctx_response is not None and ctx_response.choices[0].finish_reason not in ["length", "not_finished"]: + yield "data: [DONE]\n\n".encode('utf-8') + else: + # Then yield the generation responses + await self._increment_metric("gen_total_requests") + if isinstance(gen_req, CompletionRequest): + gen_response = await self.send_completion_request(gen_server, gen_req) + elif isinstance(gen_req, ChatCompletionRequest): + gen_response = await self.send_chat_request(gen_server, gen_req) + else: + raise TypeError("Invalid request type: {type(gen_req).__name__}") + + first_response = await anext(gen_response.body_iterator) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + yield first_response + async for chunk in gen_response.body_iterator: + yield chunk + await self._increment_metric("gen_completed_requests") + if need_ctx and self.perf_metrics_keys is not None: + asyncio.create_task(self._add_perf_metrics_keys( + ctx_server, gen_server, ctx_request_id, raw_request)) + + + finally: + await self.gen_router.finish_request(gen_req) try: # Determine if need context server condition = self.conditional_disagg_config @@ -366,6 +385,7 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio # Append disaggregates parameters to generation request req.disaggregated_params = ctx_response.choices[0].disaggregated_params req.disaggregated_params.request_type = "generation_only" + ctx_request_id = req.disaggregated_params.ctx_request_id # Replace the string prompt with prompt_tokens_ids if isinstance(req, CompletionRequest): @@ -382,10 +402,6 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio gen_server, _ = await self.gen_router.get_next_server(req) logger.debug("Sending request to gen server: %s", gen_server) - if need_ctx and self.perf_metrics_keys is not None: - asyncio.create_task(self._add_perf_metrics_keys( - ctx_server, gen_server, req.disaggregated_params.ctx_request_id)) - if not req.stream: try: #If request finished after first token for reason other than length, return right away and skip gen @@ -400,6 +416,10 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio assert isinstance(req, ChatCompletionRequest) gen_response = await self.send_chat_request(gen_server, req) await self._increment_metric("gen_completed_requests") + if need_ctx and self.perf_metrics_keys is not None: + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + asyncio.create_task(self._add_perf_metrics_keys( + ctx_server, gen_server, ctx_request_id, raw_request)) return gen_response finally: if gen_server is not None: @@ -408,7 +428,7 @@ async def _send_disagg_request(self, req: Union[CompletionRequest, ChatCompletio else: # Return a streaming response that combines both context and generation responses return StreamingResponse( - self.merge_streaming_responses(ctx_response, gen_server, req), + _merge_streaming_responses(ctx_response, req), media_type="text/event-stream" ) except: @@ -487,6 +507,39 @@ async def send_completion_request(self, url: str, request: CompletionRequest) -> async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> ChatCompletionResponse: return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator) + async def set_steady_clock_offsets(self, session: 
aiohttp.ClientSession): + STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset" + async def query_steady_clock_offset(server_url: str) -> tuple[Optional[float], Optional[float]]: + try: + originate_ts = get_steady_clock_now_in_seconds() + async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response: + destination_ts = get_steady_clock_now_in_seconds() + if response.status == 200: + response_content = await response.json() + # Compute the steady clock timestamp difference using the NTP clock synchronization algorithm. https://en.wikipedia.org/wiki/Network_Time_Protocol#Clock_synchronization_algorithm + receive_ts = response_content['receive_ts'] + transmit_ts = response_content['transmit_ts'] + delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts) + offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2 + return delay, offset + else: + return None, None + except Exception: + return None, None + async def set_steady_clock_offset(server_url: str, offset: float) -> None: + payload = {"offset": offset} + async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response: + if response.status != 200: + logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned") + for server_url in self.ctx_servers + self.gen_servers: + delay, offset = await query_steady_clock_offset(server_url) + if delay is None or offset is None: + logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment") + continue + logger.info(f'Server: {server_url}, delay: {delay} second, offset: {offset} second') + # Negate the offset so that worker servers can adjust their steady clock by adding the new offset + await set_steady_clock_offset(server_url, -offset) + @classmethod async def check_server_ready(cls, session: aiohttp.ClientSession, server_url: str) -> bool: try: diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 8d2977f6b4a..74a628cac24 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -9,10 +9,11 @@ from datetime import datetime from http import HTTPStatus from pathlib import Path -from typing import Any, AsyncGenerator, AsyncIterator, List, Optional, Union +from typing import (Annotated, Any, AsyncGenerator, AsyncIterator, List, + Optional, Union) import uvicorn -from fastapi import FastAPI, Request +from fastapi import Body, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.routing import Mount @@ -49,9 +50,11 @@ chat_harmony_post_processor, chat_harmony_streaming_post_processor, chat_response_post_processor, chat_stream_post_processor, completion_response_post_processor, completion_stream_post_processor) -from tensorrt_llm.serve.responses_utils import ConversationHistoryStore +from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore, + ServerArrivalTimeMiddleware) from tensorrt_llm.serve.responses_utils import \ create_response as responses_api_create_response +from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds from tensorrt_llm.serve.responses_utils import \ process_streaming_events as responses_api_process_streaming_events from tensorrt_llm.serve.responses_utils import \ @@ -105,6 +108,8 @@ def __init__(self, self.metrics_collector = None self.perf_metrics = None self.perf_metrics_lock = None + # 
The steady clock offset (in seconds) between this server and the disagg server + self.disagg_server_steady_clock_offset = 0 if self.llm.args.return_perf_metrics: set_prometheus_multiproc_dir() self.metrics_collector = MetricsCollector({ @@ -159,6 +164,9 @@ async def validation_exception_handler(_, exc): assert isinstance(self.llm, MultimodalEncoder), "llm must be a MultimodalEncoder for multimodal encoder" self.register_mm_encoder_routes() + self.app.add_middleware(ServerArrivalTimeMiddleware) + + async def await_disconnected(self, raw_request: Request, promise): if raw_request is None: return @@ -206,6 +214,9 @@ def register_routes(self): # TODO: the metrics endpoint only reports iteration stats, not the runtime stats for now self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"]) self.app.add_api_route("/perf_metrics", self.get_perf_metrics, methods=["GET"]) + self.app.add_api_route("/steady_clock_offset", self.get_steady_clock_offset, methods=["GET"]) + # Called by the disagg server to set the disagg_server_steady_clock_offset + self.app.add_api_route("/steady_clock_offset", self.set_steady_clock_offset, methods=["POST"]) # TODO: workaround before ETCD support self.app.add_api_route("/kv_cache_events", self.get_kv_cache_events, methods=["POST"]) self.app.add_api_route("/v1/completions", @@ -257,7 +268,7 @@ def register_mm_encoder_routes(self): async def health(self) -> Response: return Response(status_code=200) - async def health_generate(self) -> Response: + async def health_generate(self, raw_request: Request) -> Response: """Health check that performs a minimal generation.""" try: # Create a minimal chat request @@ -269,10 +280,8 @@ async def health_generate(self) -> Response: temperature=0.0 # Deterministic output ) - mock_request = None - # Call the chat completion logic - response = await self.openai_chat(health_request, mock_request) + response = await self.openai_chat(health_request, raw_request) # Check if the response indicates success (status code 200) if response.status_code == 200: @@ -288,7 +297,7 @@ async def health_generate(self) -> Response: return Response(status_code=500, content="Generation health check failed") except Exception as e: - logger.error(f"Health generate check encountered exception: {e}", exc_info=True) + logger.error(f"Health generate check encountered exception: {e}") return Response(status_code=500, content=f"Generation health check failed: {str(e)}") async def version(self) -> JSONResponse: @@ -305,6 +314,17 @@ async def get_iteration_stats(self) -> JSONResponse: stats.append(stat) return JSONResponse(content=stats) + async def set_steady_clock_offset(self, offset: Annotated[float, Body(embed=True)]) -> Response: + self.disagg_server_steady_clock_offset = offset + logger.info(f"The steady clock offset between local and disagg server: {offset} second") + return Response(status_code=200) + + async def get_steady_clock_offset(self) -> JSONResponse: + receive_ts = get_steady_clock_now_in_seconds() + await asyncio.sleep(0.2) + transmit_ts = get_steady_clock_now_in_seconds() + return JSONResponse(content={"receive_ts": receive_ts, "transmit_ts": transmit_ts}) + async def get_perf_metrics(self) -> JSONResponse: if self.perf_metrics is None: return JSONResponse(content=[]) @@ -321,11 +341,19 @@ async def get_perf_metrics(self) -> JSONResponse: "last_iter": metrics.last_iter, # exclude metrics.iter since it is only meaningful when the request is not finished } + server_arrival_time = metrics_dict.pop("server_arrival_time", None) + if 
server_arrival_time is not None: + server_arrival_time += self.disagg_server_steady_clock_offset + server_first_token_time = metrics_dict.pop("server_first_token_time", None) + if server_first_token_time is not None: + server_first_token_time += self.disagg_server_steady_clock_offset metrics_json["timing_metrics"] = { - "arrival_time": timing_metrics.arrival_time.total_seconds(), - "first_scheduled_time": timing_metrics.first_scheduled_time.total_seconds(), - "first_token_time": timing_metrics.first_token_time.total_seconds(), - "last_token_time": timing_metrics.last_token_time.total_seconds(), + "server_arrival_time": server_arrival_time, + "arrival_time": timing_metrics.arrival_time.total_seconds() + self.disagg_server_steady_clock_offset, + "first_scheduled_time": timing_metrics.first_scheduled_time.total_seconds() + self.disagg_server_steady_clock_offset, + "first_token_time": timing_metrics.first_token_time.total_seconds() + self.disagg_server_steady_clock_offset, + "server_first_token_time": server_first_token_time, + "last_token_time": timing_metrics.last_token_time.total_seconds() + self.disagg_server_steady_clock_offset, } metrics_json["kv_cache_metrics"] = { "num_total_allocated_blocks": kv_cache_metrics.num_total_allocated_blocks, @@ -337,8 +365,8 @@ async def get_perf_metrics(self) -> JSONResponse: metrics_json["timing_metrics"].update({ # TODO: move to kv_cache_metrics "kv_cache_size": timing_metrics.kv_cache_size, - "kv_cache_transfer_start": timing_metrics.kv_cache_transfer_start.total_seconds(), - "kv_cache_transfer_end": timing_metrics.kv_cache_transfer_end.total_seconds(), + "kv_cache_transfer_start": timing_metrics.kv_cache_transfer_start.total_seconds() + self.disagg_server_steady_clock_offset, + "kv_cache_transfer_end": timing_metrics.kv_cache_transfer_end.total_seconds() + self.disagg_server_steady_clock_offset, }) if speculative_decoding.total_draft_tokens > 0: metrics_json["speculative_decoding"] = { @@ -359,7 +387,7 @@ async def get_kv_cache_events(self) -> JSONResponse: pass return JSONResponse(content=events) - async def _extract_metrics(self, res: RequestOutput): + async def _extract_metrics(self, res: RequestOutput, raw_request: Request): if not res.finished: return if self.metrics_collector: @@ -370,6 +398,9 @@ async def _extract_metrics(self, res: RequestOutput): "request_id": res.request_id, "perf_metrics": res.outputs[0].request_perf_metrics } + if raw_request: + item["server_arrival_time"] = getattr(raw_request.state, "server_arrival_time", None) + item["server_first_token_time"] = getattr(raw_request.state, "server_first_token_time", None) if output.disaggregated_params: item["ctx_request_id"] = output.disaggregated_params.ctx_request_id if self.perf_metrics is not None: @@ -390,12 +421,19 @@ async def chat_stream_generator( try: if not self.postproc_worker_enabled: post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + first_response = await anext(promise) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + pp_results = first_response.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(first_response, args) + for pp_res in pp_results: + yield pp_res + # Making sure we can handling the situation where there is only one response + res = first_response async for res in promise: pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) - await self._extract_metrics(res) for pp_res in pp_results: yield pp_res yield 
"data: [DONE]\n\n" + await self._extract_metrics(res, raw_request) nvtx_mark("generation ends") except: logger.error(traceback.format_exc()) @@ -413,7 +451,8 @@ async def create_chat_response( # Add prompt_tokens_ids to the response if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": chat_response.prompt_token_ids = promise.prompt_token_ids - await self._extract_metrics(promise) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + await self._extract_metrics(promise, raw_request) return chat_response try: @@ -582,7 +621,8 @@ async def completion_response(promise: RequestOutput, if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": # Include prompt token ids for context-only requests pp_result.prompt_token_ids = response.prompt_token_ids - await self._extract_metrics(response) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + await self._extract_metrics(response, raw_request) return pp_result def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse: @@ -619,9 +659,9 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc pp_result = post_processor(output, args) else: pp_result = output.outputs[0]._postprocess_result - await self._extract_metrics(output) for pp_res in pp_result: yield pp_res + await self._extract_metrics(output, raw_request) except: logger.error(traceback.format_exc()) raise @@ -646,6 +686,9 @@ async def producer(generator: AsyncIterator[Any], idx: int): await asyncio.gather(*tasks) async def generator_wrapper(generator: AsyncIterator[Any]): + first_response = await anext(generator) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + yield first_response async for output in generator: yield output yield "data: [DONE]\n\n" diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py index d4a6af268c4..ab8fdae47b5 100644 --- a/tensorrt_llm/serve/responses_utils.py +++ b/tensorrt_llm/serve/responses_utils.py @@ -35,6 +35,7 @@ StreamState, SystemContent, TextContent, ToolDescription, load_harmony_encoding) +from tensorrt_llm.bindings import steady_clock_now from tensorrt_llm.llmapi import SamplingParams from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.logger import logger @@ -78,6 +79,10 @@ def decode_tokens(tokens): return get_encoding().decode(tokens) +def get_steady_clock_now_in_seconds() -> float: + return steady_clock_now().total_seconds() + + def parse_response_input( input_msg: ResponseInputOutputItem, prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] @@ -846,3 +851,35 @@ def _send_event(event: OpenAIBaseModel): sequence_number=-1, response=final_response.model_dump(), )) + + +class ServerArrivalTimeMiddleware: + """ + Custom ASGI middleware to track server arrival time. + + We implement this as a pure ASGI middleware instead of using FastAPI's + @app.middleware("http") decorator because the decorator internally uses + BaseHTTPMiddleware, which wraps the ASGI `receive` callable. This wrapping + breaks Request.is_disconnected() functionality - the wrapped receive doesn't + properly forward http.disconnect events while the middleware is waiting in + call_next(), preventing detection of client disconnections during long-running + non-streaming requests. 
+ + By implementing pure ASGI middleware, we pass through the original receive/send + callables unchanged, preserving the ability to detect client disconnections. + + See: https://github.com/encode/starlette/discussions/2094 + """ + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + # Add arrival time to scope + scope["state"] = {} + scope["state"][ + "server_arrival_time"] = get_steady_clock_now_in_seconds() + + # Pass through the original receive/send - no wrapping! + await self.app(scope, receive, send) diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 280a4c3a554..6907fd2ab25 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -37,6 +37,109 @@ def cleanup_output_files(): pass +def validate_timing_metrics(perf_metrics_item, request_context=""): + """ + Helper function to validate timing metrics relationships. + + Args: + perf_metrics_item: A single performance metrics item from the /perf_metrics endpoint + request_context: String context for error messages (e.g., "request 1", "streaming") + """ + # Validate basic structure + required_keys = [ + "ctx_server", "gen_server", "ctx_perf_metrics", "gen_perf_metrics", + "disagg_server_arrival_time", "disagg_server_first_token_time" + ] + for key in required_keys: + assert key in perf_metrics_item, f"Missing key: {key} in {request_context}" + + assert perf_metrics_item["ctx_perf_metrics"][ + "ctx_request_id"] == perf_metrics_item["gen_perf_metrics"][ + "ctx_request_id"] + + # Extract timing metrics + ctx_metrics = perf_metrics_item["ctx_perf_metrics"]["perf_metrics"][ + "timing_metrics"] + gen_metrics = perf_metrics_item["gen_perf_metrics"]["perf_metrics"][ + "timing_metrics"] + disagg_arrival = perf_metrics_item["disagg_server_arrival_time"] + disagg_first_token = perf_metrics_item["disagg_server_first_token_time"] + + # Validate disaggregated server timing metrics + assert disagg_arrival is not None, f"disagg_server_arrival_time is None in {request_context}" + assert disagg_first_token is not None, f"disagg_server_first_token_time is None in {request_context}" + assert isinstance( + disagg_arrival, + (int, float + )), f"disagg_server_arrival_time is not numeric in {request_context}" + assert isinstance( + disagg_first_token, (int, float) + ), f"disagg_server_first_token_time is not numeric in {request_context}" + assert disagg_arrival > 0, f"disagg_server_arrival_time is not positive in {request_context}" + assert disagg_first_token > 0, f"disagg_server_first_token_time is not positive in {request_context}" + assert disagg_arrival <= disagg_first_token, f"disagg_server_arrival_time > disagg_server_first_token_time in {request_context}" + + # Validate server-level timing metrics for context server + ctx_server_arrival = ctx_metrics.get("server_arrival_time") + ctx_server_first_token = ctx_metrics.get("server_first_token_time") + assert ctx_server_arrival is not None, f"ctx server_arrival_time is None in {request_context}" + assert ctx_server_first_token is not None, f"ctx server_first_token_time is None in {request_context}" + assert isinstance( + ctx_server_arrival, + (int, + float)), f"ctx server_arrival_time is not numeric in {request_context}" + assert isinstance( + ctx_server_first_token, + (int, float + )), f"ctx server_first_token_time is not numeric in {request_context}" + assert 
ctx_server_arrival <= ctx_server_first_token, f"ctx server_arrival_time > server_first_token_time in {request_context}"
+    assert ctx_metrics["last_token_time"] - ctx_server_first_token < 1e-3
+
+    # Validate server-level timing metrics for generation server
+    gen_server_arrival = gen_metrics.get("server_arrival_time")
+    gen_server_first_token = gen_metrics.get("server_first_token_time")
+    assert gen_server_arrival is not None, f"gen server_arrival_time is None in {request_context}"
+    assert gen_server_first_token is not None, f"gen server_first_token_time is None in {request_context}"
+    assert isinstance(
+        gen_server_arrival,
+        (int,
+         float)), f"gen server_arrival_time is not numeric in {request_context}"
+    assert isinstance(
+        gen_server_first_token,
+        (int, float
+         )), f"gen server_first_token_time is not numeric in {request_context}"
+    assert gen_server_arrival <= gen_server_first_token, f"gen server_arrival_time > server_first_token_time in {request_context}"
+
+    # Validate timing relationships between different levels
+    # Disaggregated server should receive request before individual servers
+    assert disagg_arrival <= ctx_server_arrival, f"disagg_arrival > ctx_server_arrival in {request_context}"
+    assert disagg_arrival <= gen_server_arrival, f"disagg_arrival > gen_server_arrival in {request_context}"
+
+    # Context should complete before generation starts
+    assert ctx_server_first_token <= gen_server_arrival, f"ctx_server_first_token > gen_server_arrival in {request_context}"
+
+    # Validate internal timing consistency
+    ctx_arrival_time = ctx_metrics["arrival_time"]
+    ctx_first_token_time = ctx_metrics["first_token_time"]
+    gen_arrival_time = gen_metrics["arrival_time"]
+    gen_first_token_time = gen_metrics["first_token_time"]
+
+    assert ctx_arrival_time <= ctx_first_token_time, f"ctx arrival_time > first_token_time in {request_context}"
+    assert gen_arrival_time <= gen_first_token_time, f"gen arrival_time > first_token_time in {request_context}"
+
+    # Test KV cache transfer timing (if present)
+    if "kv_cache_transfer_start" in gen_metrics and "kv_cache_transfer_end" in gen_metrics:
+        kv_start = gen_metrics["kv_cache_transfer_start"]
+        kv_end = gen_metrics["kv_cache_transfer_end"]
+        assert gen_metrics["kv_cache_size"] > 0
+        assert kv_start <= kv_end, f"kv_cache_transfer_start > kv_cache_transfer_end in {request_context}"
+        assert gen_arrival_time <= kv_start, f"gen_arrival_time > kv_cache_transfer_start in {request_context}"
+        assert kv_end <= gen_metrics[
+            "first_scheduled_time"], f"kv_cache_transfer_end > first_scheduled_time in {request_context}"
+
+    return True
+
+
 def get_disagg_server_url_from_cfg(config_file: str) -> str:
     with open(config_file, 'r') as file:
         config = yaml.safe_load(file)
@@ -556,25 +659,9 @@ def extra_endpoints_test(server_url: str):
         perf_metrics = json.load(resp)
         assert len(perf_metrics) > 0
         item = perf_metrics[0]
-        assert "ctx_server" in item
-        assert "gen_server" in item
-        assert "ctx_perf_metrics" in item
-        assert "gen_perf_metrics" in item
-        assert item["ctx_perf_metrics"]["ctx_request_id"] == item[
-            "gen_perf_metrics"]["ctx_request_id"]
-        ctx_metrics = item["ctx_perf_metrics"]["perf_metrics"]["timing_metrics"]
-        gen_metrics = item["gen_perf_metrics"]["perf_metrics"]["timing_metrics"]
-        # only one token is generated in ctx
-        assert ctx_metrics["last_token_time"] - ctx_metrics[
-            "first_token_time"] < 1e-3
-        assert ctx_metrics["last_token_time"] < gen_metrics["arrival_time"]
-        assert gen_metrics["kv_cache_size"] > 0
-        assert gen_metrics["arrival_time"] < gen_metrics[
-            "kv_cache_transfer_start"]
-        assert gen_metrics["kv_cache_transfer_start"] < gen_metrics[
-            "kv_cache_transfer_end"]
-        assert gen_metrics["kv_cache_transfer_end"] < gen_metrics[
-            "first_scheduled_time"]
+
+        # Use helper function to validate all timing metrics comprehensively
+        validate_timing_metrics(item, "perf_metrics test")
 
     run_disaggregated_test(disaggregated_example_root,
                            "perf_metrics",
diff --git a/tests/unittest/llmapi/apps/openai_server.py b/tests/unittest/llmapi/apps/openai_server.py
index ca98f7e1ece..39c9988d9f5 100644
--- a/tests/unittest/llmapi/apps/openai_server.py
+++ b/tests/unittest/llmapi/apps/openai_server.py
@@ -29,7 +29,8 @@ def __init__(self,
                  extra_config: Optional[dict] = None) -> None:
         self.host = host
         self.port = port if port is not None else find_free_port()
-        self.rank = rank if rank != -1 else os.environ.get("SLURM_PROCID", 0)
+        self.rank = rank if rank != -1 else int(
+            os.environ.get("SLURM_PROCID", 0))
         self.extra_config_file = None
         args = ["--host", f"{self.host}", "--port", f"{self.port}"]
         if cli_args:
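
Note on the offset math: set_steady_clock_offsets in openai_disagg_server.py measures each worker's steady-clock offset with the standard NTP clock-synchronization formulas (originate/receive/transmit/destination timestamps). The following is a minimal, self-contained sketch of that calculation stripped of the HTTP plumbing; WORKER_SKEW_SECONDS and the worker_clock/disagg_clock helpers are invented here purely to simulate two skewed steady clocks and are not part of the patch.

    import time

    WORKER_SKEW_SECONDS = 1.5  # hypothetical skew: worker steady clock runs 1.5 s ahead


    def disagg_clock() -> float:
        """Steady clock on the disagg server (the reference clock)."""
        return time.monotonic()


    def worker_clock() -> float:
        """Simulated steady clock on a ctx/gen worker, skewed from the reference."""
        return time.monotonic() + WORKER_SKEW_SECONDS


    def estimate_offset() -> tuple[float, float]:
        # T1: disagg server records the originate timestamp and sends the request
        originate_ts = disagg_clock()
        # T2/T3: worker records receive and transmit timestamps while handling it
        receive_ts = worker_clock()
        time.sleep(0.01)  # stands in for the worker-side processing time
        transmit_ts = worker_clock()
        # T4: disagg server records the destination timestamp on the response
        destination_ts = disagg_clock()

        # Round-trip delay, excluding the time spent inside the worker
        delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts)
        # Estimated worker-minus-disagg clock offset (NTP formula)
        offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2
        return delay, offset


    if __name__ == "__main__":
        delay, offset = estimate_offset()
        print(f"delay={delay:.4f}s offset={offset:.4f}s (expected ~{WORKER_SKEW_SECONDS}s)")

In the patch, the disagg server then POSTs the negated offset to each worker's /steady_clock_offset endpoint; the worker stores it as disagg_server_steady_clock_offset and adds it to the timestamps it reports from /perf_metrics, so context, generation, and disagg-server timings all land on the disagg server's steady clock.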