Skip to content

Commit 188924b

Browse files
committed
Address comments
Signed-off-by: Yilin Fan <[email protected]>
1 parent 6bd99cc commit 188924b

File tree

14 files changed

+91
-76
lines changed

14 files changed

+91
-76
lines changed

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ class GenericLlmRequest
140140
std::optional<SizeType32> languageAdapterUid = std::nullopt,
141141
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
142142
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
143-
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt,
144-
std::optional<Duration> globalSteadyClockOffset = std::nullopt)
143+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
145144
: mRequestId(requestId)
146145
, mPromptLen(inputTokens->size())
147146
, mMaxNewTokens(maxNewTokens)
@@ -199,7 +198,6 @@ class GenericLlmRequest
199198
, mLanguageAdapterUid(languageAdapterUid)
200199
, mAllottedTimeMs(allottedTimeMs)
201200
, mCacheSaltID(cacheSaltID)
202-
, mGlobalSteadyClockOffset(globalSteadyClockOffset)
203201
{
204202
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
205203
{
@@ -227,8 +225,7 @@ class GenericLlmRequest
227225
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
228226
std::optional<SizeType32> languageAdapterUid = std::nullopt,
229227
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
230-
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
231-
std::optional<Duration> globalSteadyClockOffset = std::nullopt)
228+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
232229
: mRequestId(requestId)
233230
, mPromptLen(inputTokens.size())
234231
, mMaxNewTokens(maxNewTokens)
@@ -269,7 +266,6 @@ class GenericLlmRequest
269266
, mNumReturnSequences(numReturnSequences)
270267
, mLanguageAdapterUid(languageAdapterUid)
271268
, mCacheSaltID(cacheSaltID)
272-
, mGlobalSteadyClockOffset(globalSteadyClockOffset)
273269
{
274270
if (mEncoderTokens.has_value())
275271
{
@@ -1897,6 +1893,9 @@ class GenericLlmRequest
18971893
// current position of the prompt tuning table (only used in chunked prefill mode)
18981894
SizeType32 mPtableCurrentPosition{0};
18991895

1896+
// The offset between local steady clock and global steady clock (at rank 0)
1897+
inline static std::optional<Duration> mGlobalSteadyClockOffset{std::nullopt};
1898+
19001899
protected:
19011900
bool mIsStreaming;
19021901

@@ -2059,9 +2058,6 @@ class GenericLlmRequest
20592058
// Cache salt id for each request.
20602059
std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};
20612060

2062-
// The offset between local steady clock and global steady clock (at rank 0)
2063-
std::optional<Duration> mGlobalSteadyClockOffset;
2064-
20652061
private:
20662062
void initialize(
20672063
VecTokens const& inputTokens, bool outputLogProbs, std::optional<TimePoint> arrivalTime = std::nullopt)
@@ -2158,6 +2154,7 @@ class GenericLlmRequest
21582154

21592155
if (mReturnPerfMetrics)
21602156
{
2157+
// arrivalTime is assumed to have been recorded on rank 0, so it does not need to be converted to the global clock
21612158
mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow());
21622159
}
21632160
mStartTime = getSteadyClockNow();
@@ -2265,8 +2262,7 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22652262
std::optional<SizeType32> languageAdapterUid = std::nullopt,
22662263
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
22672264
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
2268-
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt,
2269-
std::optional<Duration> globalSteadyClockOffset = std::nullopt)
2265+
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
22702266
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
22712267
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
22722268
std::move(stopWordsList),
@@ -2297,7 +2293,7 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
22972293
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
22982294
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
22992295
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID,
2300-
arrivalTime, globalSteadyClockOffset)
2296+
arrivalTime)
23012297
{
23022298
}
23032299

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,7 @@ void initBindings(nb::module_& m)
291291
std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
292292
std::optional<executor::ContextPhaseParams> context_phase_params,
293293
std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id,
294-
std::optional<tb::LlmRequest::TimePoint> arrival_time,
295-
std::optional<tb::LlmRequest::TimePoint::duration> global_steady_clock_offset)
294+
std::optional<tb::LlmRequest::TimePoint> arrival_time)
296295
{
297296
auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
298297
{
@@ -333,7 +332,7 @@ void initBindings(nb::module_& m)
333332
encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids,
334333
num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics,
335334
guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id,
336-
arrival_time, global_steady_clock_offset};
335+
arrival_time};
337336
},
338337
nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"),
339338
nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt,
@@ -359,7 +358,7 @@ void initBindings(nb::module_& m)
359358
nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt,
360359
nb::arg("language_adapter_uid") = std::nullopt, nb::arg("allotted_time_ms") = std::nullopt,
361360
nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt,
362-
nb::arg("arrival_time") = std::nullopt, nb::arg("global_steady_clock_offset") = std::nullopt)
361+
nb::arg("arrival_time") = std::nullopt)
363362
.def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size"))
364363
.def(nb::init<tb::LlmRequest const&>())
365364
.def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"),
@@ -383,7 +382,8 @@ void initBindings(nb::module_& m)
383382
.def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason"))
384383
.def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
385384
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"))
386-
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);
385+
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
386+
.def_rw_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);
387387

388388
nb::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager")
389389
.def(nb::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), nb::arg("max_num_slots"),

cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
127127
mAllottedTimeMs, //
128128
mContextPhaseParams, //
129129
mCacheSaltID, //
130-
mPerfMetrics.timingMetrics.arrivalTime, //
131-
mGlobalSteadyClockOffset //
130+
mPerfMetrics.timingMetrics.arrivalTime //
132131
);
133132
}

cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
8686
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
8787
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
8888
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
89-
std::optional<TimePoint> arrivalTime = std::nullopt,
90-
std::optional<TimePoint::duration> globalSteadyClockOffset = std::nullopt)
89+
std::optional<TimePoint> arrivalTime = std::nullopt)
9190
: Base(requestId, //
9291
maxNewTokens, //
9392
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
@@ -150,8 +149,7 @@ class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
150149
allottedTimeMs, //
151150
contextPhaseParams, //
152151
cacheSaltID, //
153-
arrivalTime, //
154-
globalSteadyClockOffset //
152+
arrivalTime //
155153
)
156154
{
157155
}

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,7 @@ void initBindings(pybind11::module_& m)
296296
std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
297297
std::optional<executor::ContextPhaseParams> context_phase_params,
298298
std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id,
299-
std::optional<tb::LlmRequest::TimePoint> arrival_time,
300-
std::optional<TimePoint::duration> globalSteadyClockOffset = std::nullopt)
299+
std::optional<tb::LlmRequest::TimePoint> arrival_time)
301300
{
302301
auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
303302
{
@@ -338,7 +337,7 @@ void initBindings(pybind11::module_& m)
338337
encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr,
339338
llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config,
340339
skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params,
341-
language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time, global_steady_clock_offset};
340+
language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time};
342341
}),
343342
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
344343
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
@@ -365,7 +364,7 @@ void initBindings(pybind11::module_& m)
365364
py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
366365
py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
367366
py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt,
368-
nb::arg("arrival_time") = std::nullopt, nb::arg("global_steady_clock_offset") = std::nullopt)
367+
py::arg("arrival_time") = std::nullopt)
369368
.def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size"))
370369
.def(py::init<tb::LlmRequest const&>())
371370
.def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
@@ -389,7 +388,8 @@ void initBindings(pybind11::module_& m)
389388
.def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
390389
.def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
391390
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
392-
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);
391+
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
392+
.def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);
393393

394394
py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
395395
.def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),

cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
126126
mAllottedTimeMs, //
127127
mContextPhaseParams, //
128128
mCacheSaltID, //
129-
mPerfMetrics.timingMetrics.arrivalTime, //
130-
mGlobalSteadyClockOffset //
129+
mPerfMetrics.timingMetrics.arrivalTime //
131130
);
132131
}

cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
8686
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
8787
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
8888
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
89-
std::optional<TimePoint> arrivalTime = std::nullopt,
90-
std::optional<TimePoint::duration> globalSteadyClockOffset = std::nullopt)
89+
std::optional<TimePoint> arrivalTime = std::nullopt)
9190
: Base(requestId, //
9291
maxNewTokens, //
9392
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ class ExecutorRequestQueue:
4444
def __init__(self, dist: Distributed, enable_attention_dp: bool,
4545
max_batch_size: int, max_beam_width: int,
4646
max_num_active_requests: int, enable_iter_perf_stats: bool,
47-
batch_wait_timeout_ms: float, is_disaggregated: bool,
48-
global_steady_clock_offset: float):
47+
batch_wait_timeout_ms: float, is_disaggregated: bool):
4948
self.dist = dist
5049
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
5150
self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -61,7 +60,6 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
6160
self.start_times = {}
6261
self.active = True
6362
self.batch_wait_timeout_ms = batch_wait_timeout_ms
64-
self.global_steady_clock_offset = global_steady_clock_offset
6563

6664
# State tracking
6765
self.num_fetch_requests = 0
@@ -613,9 +611,6 @@ def _merge_requests(
613611
else:
614612
req_with_children = []
615613
for req_item in new_requests:
616-
if self.global_steady_clock_offset:
617-
req_item.request.py_global_steady_clock_offset = self.global_steady_clock_offset
618-
619614
req = executor_request_to_llm_request(
620615
req_item.id, req_item.request, req_item.child_req_ids,
621616
self._should_exclude_last_generation_logits())

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,6 @@ def executor_request_to_llm_request(
586586
context_phase_params=executor_request.context_phase_params,
587587
cache_salt_id=executor_request.cache_salt_id,
588588
arrival_time=getattr(executor_request, "py_arrival_time", None),
589-
global_steady_clock_offset=getattr(executor_request, "py_global_steady_clock_offset", None),
590589
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
591590
None))
592591
if child_req_ids:

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import torch
1414

15+
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
16+
1517
try:
1618
from cuda.bindings import runtime as cudart
1719
except ImportError:
@@ -165,8 +167,6 @@ def __init__(self,
165167
super(PyExecutor, self).__init__()
166168
self.device_id = torch.cuda.current_device()
167169
self.global_rank = global_mpi_rank()
168-
self.dist = dist
169-
self.global_steady_clock_offset = self._get_global_steady_clock_offset()
170170

171171
self.peft_cache_config = peft_cache_config
172172

@@ -185,6 +185,7 @@ def __init__(self,
185185
self.draft_model_engine = getattr(self.drafter, "draft_model_engine",
186186
None)
187187
self.guided_decoder = guided_decoder
188+
self.dist = dist
188189
self.disable_overlap_scheduler = disable_overlap_scheduler
189190

190191
# enqueue and _fetch_new_requests used data
@@ -253,6 +254,7 @@ def __init__(self,
253254
self.batch_wait_iters_count = 0
254255

255256
# request fetcher initialization
257+
self._set_global_steady_clock_offset()
256258
self.executor_request_queue = ExecutorRequestQueue(
257259
dist=self.dist,
258260
enable_attention_dp=self.enable_attention_dp,
@@ -262,7 +264,6 @@ def __init__(self,
262264
enable_iter_perf_stats=self.enable_iter_perf_stats,
263265
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
264266
is_disaggregated=kv_cache_transceiver is not None,
265-
global_steady_clock_offset=self.global_steady_clock_offset,
266267
)
267268
self.executor_request_queue.set_exclude_last_generation_logits(
268269
self.disable_overlap_scheduler, self.dist.pp_size)
@@ -365,20 +366,24 @@ def start_worker(self):
365366
self.worker_thread.start()
366367
self.worker_started = True
367368

368-
def _get_global_steady_clock_offset(self):
369+
def _set_global_steady_clock_offset(self):
369370
assert self.global_rank >= 0, "rank should be >= 0"
370371

371372
# Sync all ranks
372373
self.dist.barrier()
373374
# Immediately take the local steady clock timestamp
374-
local_timestamp = time.monotonic()
375+
local_timestamp = get_steady_clock_now_in_seconds()
375376
all_rank_timestamps = self.dist.allgather(local_timestamp)
376377
if self.global_rank == 0:
377378
logger.info(
378379
f"global_steady_clock_offset at each rank: {[local_timestamp - ts for ts in all_rank_timestamps]}"
379380
)
380381
# Compute the steady clock offset between rank 0 and current rank
381-
return all_rank_timestamps[0] - local_timestamp
382+
global_steady_clock_offset = all_rank_timestamps[0] - local_timestamp
383+
LlmRequest.global_steady_clock_offset = global_steady_clock_offset
384+
logger.info(
385+
f"Setting global_steady_clock_offset: {global_steady_clock_offset} seconds for rank {self.global_rank}"
386+
)
382387

383388
def __enter__(self):
384389
return self

0 commit comments

Comments
 (0)