7 changes: 5 additions & 2 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -1456,13 +1456,15 @@ class CacheTransceiverConfig
UCX = 2,
NIXL = 3
};
-    explicit CacheTransceiverConfig(
-        std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);
+    explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
+        std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt);

bool operator==(CacheTransceiverConfig const& other) const;
void setBackendType(std::optional<BackendType> backendType);
void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
+    void setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs);

+    [[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
[[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
[[nodiscard]] std::optional<BackendType> getBackendType() const;

@@ -1472,6 +1474,7 @@ class CacheTransceiverConfig
/// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
/// transfer may be degraded.
std::optional<size_t> mMaxTokensInBuffer;
+    std::optional<int> mKvTransferTimeoutMs;
};

/// @brief Configuration class for the model executor
18 changes: 15 additions & 3 deletions cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
@@ -22,15 +22,17 @@ namespace tensorrt_llm::executor
{

CacheTransceiverConfig::CacheTransceiverConfig(
-    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens)
+    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
     : mBackendType(backendType)
-    , mMaxTokensInBuffer(maxNumTokens)
+    : mMaxTokensInBuffer(maxNumTokens)
+    , mKvTransferTimeoutMs(kvTransferTimeoutMs)
Review comment on lines +25 to +28:
⚠️ Potential issue | 🔴 Critical

Fix the syntax error in the constructor initializer list.

The initializer list has an incorrect colon on line 27. C++ initializer lists should have only one colon after the constructor signature, with subsequent members separated by commas.

Apply this diff to fix the syntax error:

 CacheTransceiverConfig::CacheTransceiverConfig(
     std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
     : mBackendType(backendType)
-    : mMaxTokensInBuffer(maxNumTokens)
+    , mMaxTokensInBuffer(maxNumTokens)
     , mKvTransferTimeoutMs(kvTransferTimeoutMs)

{
}

bool CacheTransceiverConfig::operator==(CacheTransceiverConfig const& other) const
{
-    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType;
+    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType
+        && mKvTransferTimeoutMs == other.mKvTransferTimeoutMs;
}

void CacheTransceiverConfig::setBackendType(std::optional<BackendType> backendType)
@@ -43,6 +45,11 @@ void CacheTransceiverConfig::setMaxTokensInBuffer(std::optional<size_t> maxToken
mMaxTokensInBuffer = maxTokensInBuffer;
}

+void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs)
+{
+    mKvTransferTimeoutMs = kvTransferTimeoutMs;
+}

std::optional<CacheTransceiverConfig::BackendType> CacheTransceiverConfig::getBackendType() const
{
return mBackendType;
@@ -53,4 +60,9 @@ std::optional<size_t> CacheTransceiverConfig::getMaxTokensInBuffer() const
return mMaxTokensInBuffer;
}

+std::optional<int> CacheTransceiverConfig::getKvTransferTimeoutMs() const
+{
+    return mKvTransferTimeoutMs;
+}

} // namespace tensorrt_llm::executor
15 changes: 9 additions & 6 deletions cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
@@ -433,15 +433,15 @@ void initConfigBindings(nb::module_& m)
.def("__setstate__", guidedDecodingConfigSetstate);

auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
{
-        if (state.size() != 2)
+        if (state.size() != 3)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
-        new (&self) tle::CacheTransceiverConfig(
-            nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
+        new (&self) tle::CacheTransceiverConfig(nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]),
+            nb::cast<std::optional<size_t>>(state[1]), nb::cast<std::optional<int>>(state[2]));
};

nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -464,12 +464,15 @@
});

nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
+            std::optional < int >> (), nb::arg("backend") = std::nullopt,
+            nb::arg("max_tokens_in_buffer") = std::nullopt, nb::arg("kv_transfer_timeout_ms") = std::nullopt)
Review comment on lines +467 to +469:
⚠️ Potential issue | 🔴 Critical

Fix the syntax error in constructor binding.

The template parameter list is incorrectly split across lines. The comma after std::optional<size_t>> on line 467 breaks the template syntax.

Apply this diff to fix the syntax error:

-.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
-    std::optional < int >> (), nb::arg("backend") = std::nullopt,
+.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+    std::optional<int>>(), nb::arg("backend") = std::nullopt,

.def_prop_rw(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def("__getstate__", cacheTransceiverConfigGetstate)
.def("__setstate__", cacheTransceiverConfigSetstate);

15 changes: 9 additions & 6 deletions cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -415,15 +415,15 @@ void initConfigBindings(pybind11::module_& m)
.def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));

auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
{
-        if (state.size() != 2)
+        if (state.size() != 3)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
-        return tle::CacheTransceiverConfig(
-            state[0].cast<tle::CacheTransceiverConfig::BackendType>(), state[1].cast<std::optional<size_t>>());
+        return tle::CacheTransceiverConfig(state[0].cast<tle::CacheTransceiverConfig::BackendType>(),
+            state[1].cast<std::optional<size_t>>(), state[2].cast<std::optional<int>>());
};

py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -446,12 +446,15 @@
});

py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
+            std::optional < int >> (), py::arg("backend") = std::nullopt,
+            py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)
Review comment on lines 448 to +451:
⚠️ Potential issue | 🔴 Critical

Fix malformed py::init signature

The py::init template parameter list is broken—std::optional<int>> ended up outside the template argument list, so this won’t compile. Keep all three optionals inside the template parameters.

-        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
-            std::optional < int >> (), py::arg("backend") = std::nullopt,
+        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+                std::optional<int>>(),
+            py::arg("backend") = std::nullopt,
             py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)

.def_property(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));

auto executorConfigGetState = [](py::object const& self)
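Once the binding syntax errors flagged above are fixed, the new three-field state should round-trip through pickle. A minimal sketch, assuming the corrected bindings compile; the import path follows the existing executor bindings, and the sample values are placeholders:

```python
# Sketch only: assumes the corrected bindings are built and installed.
import pickle

from tensorrt_llm.bindings.executor import (CacheTransceiverBackendType,
                                            CacheTransceiverConfig)

# All three fields are optional keyword arguments on the new constructor.
config = CacheTransceiverConfig(backend=CacheTransceiverBackendType.UCX,
                                max_tokens_in_buffer=8192,
                                kv_transfer_timeout_ms=5000)

# __getstate__ now emits a 3-tuple, so the timeout survives serialization.
restored = pickle.loads(pickle.dumps(config))
assert restored.kv_transfer_timeout_ms == 5000
```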
3 changes: 3 additions & 0 deletions examples/disaggregated/README.md
@@ -16,6 +16,9 @@ cache_transceiver_config:
backend: <str>
# KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance.
max_tokens_in_buffer: <int>
+  # KV cache transfer timeout in milliseconds.
+  # Requests that do not finish sending or receiving their KV cache within this window are cancelled and cleaned up.
+  kv_transfer_timeout_ms: <int>
```

The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.
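The same settings can also be passed programmatically through the LLM API. A minimal sketch; the `LLM` entry point and `CacheTransceiverConfig` import follow the public llmapi, while the model path and values are placeholders:

```python
# Sketch only: assumes a disaggregated-serving build with the new field.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CacheTransceiverConfig

llm = LLM(
    model="/path/to/model",  # placeholder
    cache_transceiver_config=CacheTransceiverConfig(
        backend="UCX",
        max_tokens_in_buffer=8192,
        # Cancel and clean up any KV cache transfer stuck for more than 5 s.
        kv_transfer_timeout_ms=5000,
    ),
)
```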
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -109,6 +109,8 @@ def __init__(self, mapping: Mapping, dist: Distributed,
# get the layer num per pp rank, which is required by cache transceiver.
pp_layer_num = len(kv_cache_manager.pp_layers)
pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num)

+        self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
total_num_kv_heads_per_layer, head_dim,
tokens_per_block, world_config,
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -442,6 +442,8 @@ def __init__(
self.py_lora_task_layer_module_configs: list[
tensorrt_llm.bindings.internal.runtime.
TaskLayerModuleConfig] | None = None
+        self.py_kv_transfer_start_time = None
+        self.py_kv_transfer_timed_out = False

self.py_num_logprobs = num_logprobs
self.py_return_log_probs = return_log_probs
64 changes: 64 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -978,6 +978,7 @@ def _executor_loop_pp(self):
self.micro_batches[prev_microbatch_id] = None

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

if self._disagg_pp_termination_handler is not None:
@@ -1006,6 +1007,7 @@ def _prepare_and_schedule_batch(self):

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()
+            self._check_kv_transfer_timeout()

iter_stats = None
if self.enable_iter_perf_stats:
@@ -1179,6 +1181,7 @@ def _executor_loop(self):
self._add_kv_cache_events()

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

self._kv_connector_terminate_requests()
@@ -1364,6 +1367,7 @@ def _executor_loop_overlap(self):
ctx_transmission_reqs=ctx_transmission_reqs)

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                self._check_kv_transfer_timeout()
self._terminate_ctx_finished_requests()

self._kv_connector_terminate_requests()
@@ -1572,6 +1576,38 @@ def _check_disagg_gen_transfer_status(self):

return

@nvtx_range("_check_kv_transfer_timeout")
def _check_kv_transfer_timeout(self):
if not self.kv_cache_transceiver:
return
timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms
if timeout_ms is None or timeout_ms <= 0:
return

current_time = time.time()

for req in self.ctx_in_transmission_requests:
if req.py_kv_transfer_start_time is None:
continue
elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
logger.warning(
f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_kv_transfer_timed_out = True

Review comment on lines +1589 to +1598:
⚠️ Potential issue | 🔴 Critical

Fix AttributeError when checking context transfer timeouts

self.ctx_in_transmission_requests holds (request, blockId) tuples, so iterating directly over it and accessing req.py_kv_transfer_start_time will throw at runtime. Unpack the tuple (or index into it) before touching request attributes.

-        for req in self.ctx_in_transmission_requests:
+        for req, _ in self.ctx_in_transmission_requests:
             if req.py_kv_transfer_start_time is None:
                 continue

+        for req in self.active_requests:
+            if req.is_disagg_generation_transmission_in_progress and req.py_kv_transfer_start_time is not None:
+                elapsed_time = (current_time -
+                                req.py_kv_transfer_start_time) * 1000
+                if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
+                    logger.warning(
+                        f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout"
+                    )
+                    req.py_kv_transfer_timed_out = True
+
+        return

@nvtx_range("_pad_attention_dp_dummy_request")
def _pad_attention_dp_dummy_request(self):
"""
@@ -1646,6 +1682,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
req.context_current_position = req.prompt_len
req.decoding_iter = 1
req.py_decoding_iter = 1
+            req.py_kv_transfer_start_time = None
first_gen_tokens = req.context_phase_params.first_gen_tokens
ctx_draft_tokens = req.context_phase_params.draft_tokens
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
@@ -1669,6 +1706,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
for req in new_gen_reqs:
self.kv_cache_transceiver.request_and_receive_async(req)

+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in new_gen_reqs:
+                if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()

block_transfer = all([
req.is_disagg_generation_transmission_in_progress
for req in self.active_requests
@@ -1701,6 +1743,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
]

+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in ctx_in_transmission_requests:
+                if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()
Review comment on lines +1746 to +1749:
⚠️ Potential issue | 🔴 Critical

Fix NameError when arming context timeout clock

ctx_in_transmission_requests is undefined here, so this block will raise immediately. The intent is to walk the freshly built ctx_transmission_reqs.

-        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
-            for req in ctx_in_transmission_requests:
+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in ctx_transmission_reqs:
                 if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
                     req.py_kv_transfer_start_time = time.time()
🧰 Tools: 🪛 Ruff (0.13.3) reports line 1747: Undefined name `ctx_in_transmission_requests` (F821)


return ctx_transmission_reqs

def _get_disagg_reqs_in_error_state(self):
@@ -2018,6 +2065,18 @@ def _handle_responses(self):
requests_to_terminate.append(request)
continue

+            # Check if generation request needs cleanup due to KV cache transfer timeout
+            if request.py_kv_transfer_timed_out:
+                # We need to check if we can successfully cancel the request before
+                # terminating it and freeing the KV cache associated with it
+                is_cancelled = self.kv_cache_transceiver.cancel_request(request)
+                # This will terminate the request and send an error response
+                if is_cancelled:
+                    self._handle_errors(
+                        error_msg=f"Request {request.py_request_id} timed out",
+                        requests=[request])
+                    continue

if request.is_generation_only_request():
# If request is in transmission, so we don't need to emit a response
# Also, for the first iteration with overlap, we should skip since first
@@ -2068,6 +2127,11 @@ def _handle_responses(self):
def _terminate_ctx_finished_requests(self):
for request, block_id in self.ctx_in_transmission_requests[:]:
if request.is_disagg_context_complete_state:
+                if request.py_kv_transfer_timed_out:
+                    is_cancelled = self.kv_cache_transceiver.cancel_request(
+                        request)
+                    if is_cancelled:
+                        request.py_kv_transfer_start_time = None
if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
self._terminate_request(request)
else:
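Putting the two critical review fixes together (unpacking the `(request, block_id)` tuples and arming the start time on `ctx_transmission_reqs`), the context-side timeout sweep reduces to the following standalone sketch. The stand-in request class and free function are illustrative, not the executor's actual types:

```python
import time
from dataclasses import dataclass


@dataclass
class _Request:  # stand-in for LlmRequest's py_* timeout fields
    py_request_id: int
    py_kv_transfer_start_time: float | None = None
    py_kv_transfer_timed_out: bool = False


def check_kv_transfer_timeout(ctx_in_transmission_requests, timeout_ms):
    """Flag (request, block_id) entries whose transfer exceeded timeout_ms."""
    if timeout_ms is None or timeout_ms <= 0:
        return
    now = time.time()
    for req, _block_id in ctx_in_transmission_requests:  # unpack the tuple
        if req.py_kv_transfer_start_time is None:
            continue
        elapsed_ms = (now - req.py_kv_transfer_start_time) * 1000
        if elapsed_ms > timeout_ms and not req.py_kv_transfer_timed_out:
            req.py_kv_transfer_timed_out = True


# Arm the clock when the transfer starts, then sweep each iteration.
entries = [(_Request(1, py_kv_transfer_start_time=time.time() - 1.0), 0)]
check_kv_transfer_timeout(entries, timeout_ms=500)
assert entries[0][0].py_kv_transfer_timed_out
```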
9 changes: 8 additions & 1 deletion tensorrt_llm/llmapi/llm_args.py
@@ -1286,10 +1286,17 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
default=None,
description="The max number of tokens the transfer buffer can fit.")

+    kv_transfer_timeout_ms: Optional[int] = Field(
+        default=None,
+        description=
+        "Timeout in milliseconds for KV cache transfer. Requests exceeding this timeout will be cancelled."
+    )

def _to_pybind(self):
return _CacheTransceiverConfig(
backend=_CacheTransceiverBackendType.from_string(self.backend),
-            max_tokens_in_buffer=self.max_tokens_in_buffer)
+            max_tokens_in_buffer=self.max_tokens_in_buffer,
+            kv_transfer_timeout_ms=self.kv_transfer_timeout_ms)


@dataclass
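As a quick sanity sketch of the plumbing added here: the new pydantic field should flow through `_to_pybind` into the C++ config. This assumes the bindings above compile; `_to_pybind` is internal API and the values are placeholders:

```python
# Sketch only: _to_pybind is internal and may change between releases.
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

cfg = CacheTransceiverConfig(backend="UCX", kv_transfer_timeout_ms=5000)
pybind_cfg = cfg._to_pybind()
assert pybind_cfg.kv_transfer_timeout_ms == 5000
```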