48 changes: 37 additions & 11 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -101,6 +101,7 @@ class GenericLlmRequest
using RequestPtr = std::shared_ptr<GenericLlmRequest>;
using MillisecondsType = std::chrono::milliseconds;
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
using Duration = std::chrono::time_point<std::chrono::steady_clock>::duration;
using CacheSaltIDType = runtime::CacheSaltIDType;

GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
@@ -1255,7 +1256,7 @@ class GenericLlmRequest
{
if (mPerfMetrics.timingMetrics.firstScheduledTime == executor::RequestPerfMetrics::TimePoint{})
{
mPerfMetrics.timingMetrics.firstScheduledTime = std::chrono::steady_clock::now();
mPerfMetrics.timingMetrics.firstScheduledTime = getSteadyClockNow();
}
}

@@ -1689,22 +1690,22 @@ class GenericLlmRequest
mDecodingIter = iter;
}

void setKvCacheTransferStart(std::chrono::time_point<std::chrono::steady_clock> const& time)
void setKvCacheTransferStart(TimePoint const& time)
{
mPerfMetrics.timingMetrics.kvCacheTransferStart = time;
mPerfMetrics.timingMetrics.kvCacheTransferStart = maybeToGlobalSteadyClock(time);
}

void setKvCacheTransferEnd(std::chrono::time_point<std::chrono::steady_clock> const& time)
void setKvCacheTransferEnd(TimePoint const& time)
{
mPerfMetrics.timingMetrics.kvCacheTransferEnd = time;
mPerfMetrics.timingMetrics.kvCacheTransferEnd = maybeToGlobalSteadyClock(time);
}

std::chrono::time_point<std::chrono::steady_clock> getKvCacheTransferStart()
TimePoint getKvCacheTransferStart()
{
return mPerfMetrics.timingMetrics.kvCacheTransferStart;
}

std::chrono::time_point<std::chrono::steady_clock> getKvCacheTransferEnd()
TimePoint getKvCacheTransferEnd()
{
return mPerfMetrics.timingMetrics.kvCacheTransferEnd;
}
@@ -1788,7 +1789,7 @@ class GenericLlmRequest
if (finishReason == executor::FinishReason::kTIMED_OUT)
{
TLLM_LOG_DEBUG("Request %ld finished by timeout after %f sec", mRequestId,
std::chrono::duration<float>(std::chrono::steady_clock::now() - mStartTime).count());
std::chrono::duration<float>(getSteadyClockNow() - mStartTime).count());
}
if (finishReason == executor::FinishReason::kCANCELLED)
{
@@ -1826,10 +1827,9 @@ class GenericLlmRequest

void updatePerfMetrics(executor::IterationType iter)
{
auto const currentTokenTime = std::chrono::steady_clock::now();

if (!mPerfMetrics.firstIter)
{
auto const currentTokenTime = getSteadyClockNow();
mPerfMetrics.firstIter = iter;
mPerfMetrics.timingMetrics.firstTokenTime = currentTokenTime;
}
@@ -1838,6 +1838,7 @@ class GenericLlmRequest

if (isFinished())
{
auto const currentTokenTime = getSteadyClockNow();
mPerfMetrics.lastIter = iter;
mPerfMetrics.timingMetrics.lastTokenTime = currentTokenTime;
}
@@ -1863,6 +1864,15 @@ class GenericLlmRequest
return mUseDraftModel;
}

// If mGlobalSteadyClockOffset is set, return a global steady clock time point; otherwise return the local steady
// clock time point
[[nodiscard]] TimePoint getSteadyClockNow() const
{
const TimePoint time_point = std::chrono::steady_clock::now();

return maybeToGlobalSteadyClock(time_point);
}

RequestIdType mRequestId;
SizeType32 mPromptLen;
SizeType32 mMaxNewTokens;
@@ -1882,6 +1892,9 @@ class GenericLlmRequest
// current position of the prompt tuning table (only used in chunked prefill mode)
SizeType32 mPtableCurrentPosition{0};

// The offset between local steady clock and global steady clock (at rank 0)
inline static std::optional<Duration> mGlobalSteadyClockOffset{std::nullopt};

protected:
bool mIsStreaming;

@@ -2137,7 +2150,8 @@ class GenericLlmRequest

if (mReturnPerfMetrics)
{
mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(std::chrono::steady_clock::now());
// arrivalTime is assumed to be recorded on rank 0, so there is no need to convert it to the global clock
mPerfMetrics.timingMetrics.arrivalTime = arrivalTime.value_or(getSteadyClockNow());
}
mStartTime = std::chrono::steady_clock::now();
}
@@ -2167,6 +2181,18 @@ class GenericLlmRequest

return tensor;
}

TimePoint maybeToGlobalSteadyClock(TimePoint const& time_point) const
{
if (mGlobalSteadyClockOffset.has_value())
{
return time_point + *mGlobalSteadyClockOffset;
}
else
{
return time_point;
}
}
};

class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
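Note: the header changes above reduce to one conversion rule — when a per-process offset to rank 0's steady clock is known, every recorded time point is shifted by that offset before it is stored in the perf metrics; otherwise the local time point is used unchanged. A minimal Python sketch of that rule (names are illustrative stand-ins, not part of the API):

```python
import time
from typing import Optional

# Stand-in for GenericLlmRequest::mGlobalSteadyClockOffset: offset (in seconds)
# from this process's steady clock to rank 0's steady clock, if known.
global_steady_clock_offset: Optional[float] = None

def maybe_to_global_steady_clock(local_time: float) -> float:
    # Mirrors maybeToGlobalSteadyClock: shift a local steady-clock reading
    # onto rank 0's clock when an offset is available.
    if global_steady_clock_offset is None:
        return local_time
    return local_time + global_steady_clock_offset

def get_steady_clock_now() -> float:
    # Mirrors getSteadyClockNow: read the local steady clock, then convert it.
    return maybe_to_global_steady_clock(time.monotonic())
```

With this in place, timestamps such as kvCacheTransferStart/End taken on different ranks (see the cacheFormatter changes below) become directly comparable on rank 0's timeline.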
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -390,7 +390,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
TLLM_CHECK(connections.size() > (processIdx / peerDuplicateHeadFactor));
TLLM_CHECK(outputSplitCaches.size() > (processIdx / peerDuplicateHeadFactor));
auto startTime = std::chrono::steady_clock::now();
auto startTime = llmRequest.getSteadyClockNow();

size_t ppDomainSize = targetInfo.mDomainPPSize;
size_t bufferTpRank = (processIdx / ppDomainSize) / peerDuplicateHeadFactor;
@@ -437,7 +437,7 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
}
}

auto endTime = std::chrono::steady_clock::now();
auto endTime = llmRequest.getSteadyClockNow();
double delay = 0.0;
if (recordDelay)
{
@@ -753,7 +753,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
TLLM_CHECK(pickUpConnections.size() > processIdx);
TLLM_CHECK(recvSplitCaches.size() > processIdx);
auto startTime = std::chrono::steady_clock::now();
auto startTime = llmRequest.getSteadyClockNow();
size_t size = 0;

if (processIdx >= remainNoCoverTargetNum)
@@ -794,7 +794,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
}
}

auto endTime = std::chrono::steady_clock::now();
auto endTime = llmRequest.getSteadyClockNow();
double delay = 0.0;
if (recordDelay)
{
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -244,7 +244,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
NVTX3_SCOPED_RANGE(sendBufferFun);

TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
auto startTime = std::chrono::steady_clock::now();
auto startTime = llmRequest.getSteadyClockNow();
auto cacheIdx = processIdx % (pPDomainSize * cPDomainSize);
if (cacheIdx < bufferCoverTargetNum)
{
@@ -277,7 +277,7 @@ void MLACacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& ses
remainSendSize -= sendSize;
}
}
auto endTime = std::chrono::steady_clock::now();
auto endTime = llmRequest.getSteadyClockNow();
double delay = 0.0;
if (recordDelay)
{
@@ -451,7 +451,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
{
NVTX3_SCOPED_RANGE(recvBufferFun);
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
auto startTime = std::chrono::steady_clock::now();
auto startTime = llmRequest.getSteadyClockNow();
size_t size = 0;
if (processIdx >= remainNoCoverTargetNum)
{
@@ -484,7 +484,7 @@ void MLACacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& s
remainRecvSize -= recvSize;
}
}
auto endTime = std::chrono::steady_clock::now();
auto endTime = llmRequest.getSteadyClockNow();
double delay = 0.0;
if (recordDelay)
{
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -382,7 +382,8 @@ void initBindings(nb::module_& m)
.def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason"))
.def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"))
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
.def_rw_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);

nb::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager")
.def(nb::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), nb::arg("max_num_slots"),
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -388,7 +388,8 @@ void initBindings(pybind11::module_& m)
.def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
.def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
.def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::mGlobalSteadyClockOffset);

py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
.def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),
25 changes: 24 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -12,6 +12,8 @@

import torch

from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds

try:
from cuda.bindings import runtime as cudart
except ImportError:
@@ -254,6 +256,7 @@ def __init__(self,
self.batch_wait_iters_count = 0

# request fetcher initialization
self._set_global_steady_clock_offset()
self.executor_request_queue = ExecutorRequestQueue(
dist=self.dist,
enable_attention_dp=self.enable_attention_dp,
@@ -365,6 +368,25 @@ def start_worker(self):
self.worker_thread.start()
self.worker_started = True

def _set_global_steady_clock_offset(self):
assert self.global_rank >= 0, "rank should be >= 0"

# Sync all ranks
self.dist.barrier()
# Immediately take the local steady clock timestamp
local_timestamp = get_steady_clock_now_in_seconds()
all_rank_timestamps = self.dist.allgather(local_timestamp)
if self.global_rank == 0:
logger.info(
f"global_steady_clock_offset at each rank: {[local_timestamp - ts for ts in all_rank_timestamps]}"
)
# Compute the steady clock offset between rank 0 and current rank
global_steady_clock_offset = all_rank_timestamps[0] - local_timestamp
LlmRequest.global_steady_clock_offset = global_steady_clock_offset
logger.info(
f"Setting global_steady_clock_offset: {global_steady_clock_offset} seconds for rank {self.global_rank}"
)

def __enter__(self):
return self

@@ -1962,7 +1984,8 @@ def _handle_responses(self):
request) > 0 else []
request.decoding_iter = request.py_decoding_iter

if request.return_perf_metrics:
# Skip active requests that are not scheduled
if request.return_perf_metrics and request.py_decoding_iter >= 1:
request.update_perf_metrics(self.model_engine.iter_counter)

request_done = False
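Note: `_set_global_steady_clock_offset` is what makes the C++ conversion above meaningful — after a barrier, every rank samples its steady clock, the samples are allgathered, and each rank stores `rank0_time - local_time` as its offset (accuracy is bounded by the skew with which ranks leave the barrier). A standalone sketch of the same computation, using mpi4py as a stand-in for the executor's `self.dist` wrapper and `time.monotonic()` for `get_steady_clock_now_in_seconds` (both stand-ins are assumptions for illustration only):

```python
import time
from mpi4py import MPI  # stand-in for the executor's distributed wrapper

comm = MPI.COMM_WORLD

def compute_global_steady_clock_offset() -> float:
    # Sync all ranks so the timestamps below are taken at (nearly) the same instant.
    comm.Barrier()
    local_timestamp = time.monotonic()  # local steady clock, in seconds
    all_rank_timestamps = comm.allgather(local_timestamp)
    # Offset that maps this rank's steady clock onto rank 0's steady clock.
    return all_rank_timestamps[0] - local_timestamp

if __name__ == "__main__":
    offset = compute_global_steady_clock_offset()
    print(f"rank {comm.Get_rank()}: offset to rank 0 = {offset:.6f} s")
```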
24 changes: 18 additions & 6 deletions tensorrt_llm/executor/postproc_worker.py
@@ -69,6 +69,8 @@ class Output(NamedTuple):
is_final: bool
error: str = ""
metrics: Optional[dict[str, float]] = None
request_perf_metrics: Any = None
disaggregated_params: Any = None

def __init__(
self,
@@ -142,6 +144,11 @@ async def _handle_input(
# Left the result_handler determine the final output dtype.
# NOTE: This will change the CompletionOutput._postprocess_result
metrics_dict = record.metrics_dict
perf_metrics = None
disaggregated_params = None
if record.outputs:
perf_metrics = record.outputs[0].request_perf_metrics
disaggregated_params = record.outputs[0].disaggregated_params
if postproc_params := record.postproc_params:
result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
args.tokenizer = self._tokenizer
@@ -153,7 +160,7 @@

# TODO: Keep only the diff token_ids and text in streaming mode when
# result_handler is not set
return out, metrics_dict
return out, metrics_dict, perf_metrics, disaggregated_params

async def _batched_put(self):
''' Batched IPC send. '''
@@ -176,12 +183,17 @@ async def handle_single_input(inp: PostprocWorker.Input,
client_id = inp.rsp.client_id
is_final = inp.rsp.result.is_final if is_llm_response(
inp.rsp) else True
res, metrics = await self._handle_input(inp)
res, metrics, perf_metrics, disaggregated_params = await self._handle_input(
inp)
batch.append(
PostprocWorker.Output(client_id=client_id,
res=res,
is_final=is_final,
metrics=metrics))
PostprocWorker.Output(
client_id=client_id,
res=res,
is_final=is_final,
metrics=metrics,
request_perf_metrics=perf_metrics,
disaggregated_params=disaggregated_params,
))
if is_final:
self._records.pop(client_id)

12 changes: 12 additions & 0 deletions tensorrt_llm/executor/result.py
@@ -336,6 +336,18 @@ def _handle_response(self,
self._outputs[0] = response.res
else:
self._outputs[0]._postprocess_result = response.res

self._outputs[
0].request_perf_metrics = response.request_perf_metrics
if not self._outputs[0].disaggregated_params:
disaggregated_params = response.disaggregated_params

# A generation-only response has no disaggregated_params attached
if not disaggregated_params:
disaggregated_params = self.disaggregated_params

self._outputs[0].disaggregated_params = disaggregated_params

if response.metrics:
self.metrics_dict = response.metrics

Expand Down