Commit 355db6d

Code refactor
Signed-off-by: Yilin Fan <[email protected]>
1 parent f17a763 commit 355db6d

8 files changed: +87, -66 lines changed

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 10 additions & 9 deletions
@@ -1261,7 +1261,7 @@ class GenericLlmRequest
    {
        if (mPerfMetrics.timingMetrics.firstScheduledTime == executor::RequestPerfMetrics::TimePoint{})
        {
-            mPerfMetrics.timingMetrics.firstScheduledTime = getCurrentSteadyClock();
+            mPerfMetrics.timingMetrics.firstScheduledTime = getSteadyClockNow();
        }
    }

@@ -1677,7 +1677,7 @@ class GenericLlmRequest
        {
            return false;
        }
-        auto const currentTime = getCurrentSteadyClock();
+        auto const currentTime = getSteadyClockNow();
        auto const elapsed = (std::chrono::duration_cast<Duration>(currentTime - mStartTime));
        TLLM_LOG_DEBUG("Checked timeOut for request %ld with allotted Time %ld after time %ld and got %d", mRequestId,
            mAllottedTimeMs->count(), elapsed.count(), (elapsed >= mAllottedTimeMs));
@@ -1794,7 +1794,7 @@ class GenericLlmRequest
        if (finishReason == executor::FinishReason::kTIMED_OUT)
        {
            TLLM_LOG_DEBUG("Request %ld finished by timeout after %f sec", mRequestId,
-                std::chrono::duration<float>(getCurrentSteadyClock() - mStartTime).count());
+                std::chrono::duration<float>(getSteadyClockNow() - mStartTime).count());
        }
        if (finishReason == executor::FinishReason::kCANCELLED)
        {
@@ -1832,10 +1832,9 @@ class GenericLlmRequest

    void updatePerfMetrics(executor::IterationType iter)
    {
-        auto const currentTokenTime = getCurrentSteadyClock();
-
        if (!mPerfMetrics.firstIter)
        {
+            auto const currentTokenTime = getSteadyClockNow();
            mPerfMetrics.firstIter = iter;
            mPerfMetrics.timingMetrics.firstTokenTime = currentTokenTime;
        }
@@ -1844,6 +1843,7 @@ class GenericLlmRequest

        if (isFinished())
        {
+            auto const currentTokenTime = getSteadyClockNow();
            mPerfMetrics.lastIter = iter;
            mPerfMetrics.timingMetrics.lastTokenTime = currentTokenTime;
        }
@@ -2060,7 +2060,7 @@ class GenericLlmRequest
    // Cache salt id for each request.
    std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};

-    // The offset between local steady clock and glabol steady clock (at rank 0)
+    // The offset between local steady clock and global steady clock (at rank 0)
    std::optional<Duration> mGlobalSteadyClockOffset;
private:
    void initialize(
@@ -2158,9 +2158,9 @@ class GenericLlmRequest

        if (mReturnPerfMetrics)
        {
-            mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getCurrentSteadyClock());
+            mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow());
        }
-        mStartTime = getCurrentSteadyClock();
+        mStartTime = getSteadyClockNow();
    }

    TensorPtr createListTensor(std::list<VecTokens> const& wordsList)
@@ -2197,7 +2197,8 @@ class GenericLlmRequest
        }
    }

-    TimePoint getCurrentSteadyClock() const {
+    // If mGlobalSteadyClockOffset is set, return a global steady clock time point, otherwise return local steady clock time point
+    TimePoint getSteadyClockNow() const {
        const TimePoint time_point = std::chrono::steady_clock::now();

        return maybeToGlobalSteadyClock(time_point);
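Note on the renamed helper: getSteadyClockNow() still reads std::chrono::steady_clock and only shifts the result when mGlobalSteadyClockOffset is present. A rough Python sketch of the same idea (illustrative only, not the C++ implementation; names here are hypothetical):

import time
from typing import Optional

def steady_clock_now(global_steady_clock_offset: Optional[float] = None) -> float:
    # Return a steady-clock reading, shifted onto the global (rank 0) timeline
    # when an offset is known, otherwise the plain local reading.
    local_time = time.monotonic()  # local steady clock, in seconds
    if global_steady_clock_offset is not None:
        return local_time + global_steady_clock_offset  # align to rank 0's clock
    return local_time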

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -75,4 +75,3 @@ triton==3.3.1; platform_machine == "x86_64"
tiktoken
blobfile
openai-harmony==0.0.4
-nvidia-cutlass-dsl==4.1.0; python_version >= "3.12"

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 5 additions & 3 deletions
@@ -44,7 +44,7 @@ class ExecutorRequestQueue:
    def __init__(self, dist: Distributed, enable_attention_dp: bool,
                 max_batch_size: int, max_beam_width: int,
                 max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float, is_disaggregated: bool, monotonic_ts_offset: float):
+                 batch_wait_timeout_ms: float, is_disaggregated: bool, global_steady_clock_offset: float):
        self.dist = dist
        self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
        self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -60,7 +60,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
        self.start_times = {}
        self.active = True
        self.batch_wait_timeout_ms = batch_wait_timeout_ms
-        self.monotonic_ts_offset = monotonic_ts_offset
+        self.global_steady_clock_offset = global_steady_clock_offset

        # State tracking
        self.num_fetch_requests = 0
@@ -612,7 +612,9 @@ def _merge_requests(
        else:
            req_with_children = []
            for req_item in new_requests:
-                req_item.request.py_global_steady_clock_offset = self.monotonic_ts_offset
+                if self.global_steady_clock_offset:
+                    req_item.request.py_global_steady_clock_offset = self.global_steady_clock_offset
+
                req = executor_request_to_llm_request(
                    req_item.id, req_item.request, req_item.child_req_ids,
                    self._should_exclude_last_generation_logits())

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 11 additions & 6 deletions
@@ -166,7 +166,7 @@ def __init__(self,
        self.device_id = torch.cuda.current_device()
        self.global_rank = global_mpi_rank()
        self.dist = dist
-        self.monotonic_ts_offset = self._get_monotonic_ts_offset()
+        self.global_steady_clock_offset = self._get_global_steady_clock_offset()

        self.peft_cache_config = peft_cache_config

@@ -262,7 +262,7 @@ def __init__(self,
            enable_iter_perf_stats=self.enable_iter_perf_stats,
            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
            is_disaggregated=kv_cache_transceiver is not None,
-            monotonic_ts_offset = self.monotonic_ts_offset
+            global_steady_clock_offset=self.global_steady_clock_offset,
        )
        self.executor_request_queue.set_exclude_last_generation_logits(
            self.disable_overlap_scheduler, self.dist.pp_size)
@@ -365,14 +365,18 @@ def start_worker(self):
        self.worker_thread.start()
        self.worker_started = True

-    def _get_monotonic_ts_offset(self):
+    def _get_global_steady_clock_offset(self):
        assert self.global_rank >= 0, "rank should be >= 0"
+
+        # Sync all ranks
        self.dist.barrier()
+        # Immediately take the local steady clock timestamp
        local_timestamp = time.monotonic()
-        timestamps = self.dist.allgather(local_timestamp)
+        all_rank_timestamps = self.dist.allgather(local_timestamp)
        if self.global_rank == 0:
-            logger.info(f"monotonic_ts_offsets for each rank: {[local_timestamp - ts for ts in timestamps]}")
-        return timestamps[0] - local_timestamp
+            logger.info(f"global_steady_clock_offset at each rank: {[local_timestamp - ts for ts in all_rank_timestamps]}")
+        # Compute the steady clock offset between rank 0 and current rank
+        return all_rank_timestamps[0] - local_timestamp

    def __enter__(self):
        return self
@@ -1904,6 +1908,7 @@ def _handle_responses(self):
            request.draft_tokens = request.py_draft_tokens
            request.decoding_iter = request.py_decoding_iter

+            # Skip active requests that are not scheduled
            if request.return_perf_metrics and request.py_decoding_iter >= 1:
                request.update_perf_metrics(self.model_engine.iter_counter)

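Note on _get_global_steady_clock_offset(): every rank samples time.monotonic() immediately after a barrier, all-gathers the samples, and treats rank 0's sample as the reference, so the returned offset maps the local steady clock onto rank 0's timeline. A standalone sketch of the same exchange, assuming a plain mpi4py communicator instead of the Distributed wrapper used in this file:

import time
from mpi4py import MPI  # assumption: mpi4py stands in for the Distributed wrapper

def get_global_steady_clock_offset(comm: MPI.Comm) -> float:
    comm.Barrier()                      # release all ranks at (roughly) the same instant
    local_timestamp = time.monotonic()  # sample the local steady clock immediately
    all_rank_timestamps = comm.allgather(local_timestamp)
    # Rank 0's reading minus ours, so local_time + offset is approximately rank-0 time
    return all_rank_timestamps[0] - local_timestamp

offset = get_global_steady_clock_offset(MPI.COMM_WORLD)

The barrier only bounds the error by the barrier-exit jitter, which is presumably why the per-rank offsets are logged on rank 0 rather than treated as exact.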

tensorrt_llm/executor/result.py

Lines changed: 10 additions & 5 deletions
@@ -7,7 +7,6 @@
                    NamedTuple, Optional, TypeAlias, Union)
from weakref import WeakMethod

-from tensorrt_llm.logger import logger
import torch
import torch.nn.functional as F

@@ -323,11 +322,17 @@ def _handle_response(self,
            self._outputs[0] = response.res
        else:
            self._outputs[0]._postprocess_result = response.res
+
        self._outputs[0].request_perf_metrics = response.request_perf_metrics
-        if response.disaggregated_params:
-            self._outputs[0].disaggregated_params = response.disaggregated_params
-        else:
-            self._outputs[0].disaggregated_params = self.disaggregated_params
+        if not self._outputs[0].disaggregated_params:
+            disaggregated_params = response.disaggregated_params
+
+            # Generation only response has no disaggregated_params attached
+            if not disaggregated_params:
+                disaggregated_params = self.disaggregated_params
+
+            self._outputs[0].disaggregated_params = disaggregated_params
+
        if response.metrics:
            self.metrics_dict = response.metrics


tensorrt_llm/llmapi/llm.py

Lines changed: 3 additions & 4 deletions
@@ -354,6 +354,9 @@ def generate_async(
        if self._executor is None or self._executor.is_shutdown():
            raise RuntimeError("LLM is shutting down")

+        arrival_time = steady_clock_now(
+        ) if self.args.return_perf_metrics else None
+
        sampling_params = self._prepare_sampling_params(sampling_params)
        cache_salt_id = get_cache_salt_id(
            cache_salt) if cache_salt is not None else None
@@ -464,10 +467,6 @@ def generate_async(
        if _postproc_params:
            _postproc_params.postproc_args.num_prompt_tokens = len(
                prompt_token_ids)
-
-        arrival_time = steady_clock_now(
-        ) if self.args.return_perf_metrics else None
-
        result = self._executor.generate_async(
            prompt_token_ids,
            query_token_ids=query_token_ids,

tensorrt_llm/serve/openai_disagg_server.py

Lines changed: 22 additions & 21 deletions
@@ -57,14 +57,12 @@ def __init__(self,
        self.perf_metrics_max_requests = config.perf_metrics_max_requests
        if self.perf_metrics_max_requests > 0:
            # record corresponding keys of context and generation servers for perf metrics
-            # (ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts)
+            # (ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts)
            self.perf_metrics_keys = deque(maxlen=self.perf_metrics_max_requests)
            self.perf_metrics_keys_lock = asyncio.Lock()
            # server_url -> {ctx_request_id: perf_metrics}
            self.server_perf_metrics: dict[str, dict[int, dict]] = {}

-            # server_url -> the perf metric timestamp offset between the disagg server and worker server
-            self.server_perf_ts_offsets: dict[str, float] = {}
        else:
            self.perf_metrics_keys = None
            self.perf_metrics_keys_lock = None
@@ -110,7 +108,7 @@ async def lifespan(app: FastAPI):
            await self.wait_for_servers_ready(server_start_timeout_secs)

            if self.perf_metrics_max_requests > 0:
-                await self.query_perf_ts_offsets(self.session)
+                await self.set_steady_clock_offsets(self.session)

            if self.metadata_server:
                logger.info("Starting server monitoring via metadata service")
@@ -143,7 +141,7 @@ async def lifespan(app: FastAPI):
        @self.app.middleware("http")
        async def add_process_time_header(raw_request: Request, call_next):
            start_time = time.monotonic()
-            raw_request.state.server_start_ts = start_time
+            raw_request.state.server_arrival_time = start_time
            response = await call_next(raw_request)
            return response

@@ -202,7 +200,7 @@ async def version(self) -> JSONResponse:

    async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request):
        async with self.perf_metrics_keys_lock:
-            self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_start_ts, raw_request.state.server_first_token_ts))
+            self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_ts))

    async def perf_metrics(self) -> JSONResponse:
        if self.perf_metrics_keys is None:
@@ -239,27 +237,23 @@ async def perf_metrics(self) -> JSONResponse:
                raise exc

        remain_keys = []
-        for ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts in self.perf_metrics_keys:
+        for ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts in self.perf_metrics_keys:
            gen_perf_metrics = self.server_perf_metrics[gen_server].pop(ctx_request_id, None)
            if gen_perf_metrics is None:
                # generation not finished
-                remain_keys.append((ctx_server, gen_server, ctx_request_id, server_start_ts, server_first_token_ts))
+                remain_keys.append((ctx_server, gen_server, ctx_request_id, server_arrival_time, server_first_token_ts))
                continue
            ctx_perf_metrics = self.server_perf_metrics[ctx_server].pop(ctx_request_id, None)
            return_metrics.append({
                "ctx_server": ctx_server,
                "gen_server": gen_server,
-                "disagg_server_start_ts": server_start_ts,
+                "disagg_server_arrival_time": server_arrival_time,
                "disagg_server_first_token_ts": server_first_token_ts,
                "ctx_perf_metrics": ctx_perf_metrics,
                "gen_perf_metrics": gen_perf_metrics})
        self.perf_metrics_keys = deque(remain_keys, maxlen=self.perf_metrics_max_requests)

-        response = {
-            "server_perf_timestamp_offsets": self.server_perf_ts_offsets,
-            "perf_metrics": return_metrics
-        }
-        return JSONResponse(content=response)
+        return JSONResponse(content=return_metrics)


    async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response:
@@ -514,28 +508,35 @@ async def send_completion_request(self, url: str, request: CompletionRequest) ->
    async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> ChatCompletionResponse:
        return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator)

-    async def query_perf_ts_offsets(self, session: aiohttp.ClientSession):
-        async def query_perf_ts_offset(server_url: str) -> Optional[float]:
+    async def set_steady_clock_offsets(self, session: aiohttp.ClientSession):
+        STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset"
+        async def query_steady_clock_offset(server_url: str) -> Optional[float]:
            try:
                originate_ts = time.monotonic()
-                async with session.get(server_url + '/perf_ts_offset') as response:
+                async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response:
                    destination_ts = time.monotonic()
                    if response.status == 200:
                        response = await response.json()
+                        # Compute the steady clock timestamp difference using the NTP clock synchronization algorithm. https://en.wikipedia.org/wiki/Network_Time_Protocol#Clock_synchronization_algorithm
                        receive_ts = response['receive_ts']
                        transmit_ts = response['transmit_ts']
                        delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts)
-                        offset = - ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2
+                        offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2
                        return delay, offset
                    else:
                        return None, None
            except Exception:
                return None
+        async def set_steady_clock_offset(server_url: str, offset: float) -> Optional[float]:
+            payload = {"offset": offset}
+            async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response:
+                if response.status != 200:
+                    logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned")
        for server_url in self.ctx_servers + self.gen_servers:
-            delay, offset = await query_perf_ts_offset(server_url)
-            self.server_perf_ts_offsets[server_url] = offset
+            delay, offset = await query_steady_clock_offset(server_url)
            logger.info(f'Server: {server_url}, delay: {delay} second, offset: {offset} second')
-            logger.info(f"Server perf metrics timestamp offsets: {self.server_perf_ts_offsets}")
+            # Negate the offset so that worker servers can adjust their steady block by adding the new offset
+            await set_steady_clock_offset(server_url, -offset)

    @classmethod
    async def check_server_ready(cls, session: aiohttp.ClientSession, server_url: str) -> bool:
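Note on query_steady_clock_offset(): it is the standard NTP four-timestamp exchange. The disagg server records originate_ts before the GET and destination_ts after the reply, and the worker reports receive_ts and transmit_ts; round-trip delay and clock offset follow from those four numbers. A self-contained sketch of just the arithmetic (hypothetical timestamp values, no HTTP):

def ntp_offset_and_delay(originate_ts: float, receive_ts: float,
                         transmit_ts: float, destination_ts: float) -> tuple[float, float]:
    # Round-trip delay, excluding the time the worker spent handling the request
    delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts)
    # Offset of the worker's steady clock relative to the disagg server's
    offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2
    return delay, offset

# Example: worker clock roughly 5 s ahead of the disagg server, 30 ms round trip
delay, offset = ntp_offset_and_delay(100.00, 105.01, 105.02, 100.04)
print(delay, offset)  # approx. 0.03 s delay, +4.995 s offset

With the sign fixed in this commit, the offset is "worker minus disagg server", so the negated value POSTed back to each worker can simply be added to that worker's steady clock to land on the disagg server's timeline.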
