From f8c3670d9b2a12edb1667f143b18b91e0ea2a74b Mon Sep 17 00:00:00 2001
From: raayandhar
Date: Tue, 7 Oct 2025 01:02:38 -0700
Subject: [PATCH 1/2] update to ToM, clean up a bit, move to cancel_request

Signed-off-by: raayandhar
---
 cpp/include/tensorrt_llm/executor/executor.h  |  7 ++-
 .../executor/cacheTransceiverConfig.cpp       | 18 +++++-
 .../nanobind/executor/executorConfig.cpp      | 15 +++--
 .../pybind/executor/executorConfig.cpp        | 15 +++--
 examples/disaggregated/README.md              |  3 +
 .../_torch/pyexecutor/kv_cache_transceiver.py |  2 +
 tensorrt_llm/_torch/pyexecutor/llm_request.py |  2 +
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 56 +++++++++++++++++++
 tensorrt_llm/llmapi/llm_args.py               |  9 ++-
 9 files changed, 109 insertions(+), 18 deletions(-)

diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index b5769177b28..565b02c028d 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -1456,13 +1456,15 @@ class CacheTransceiverConfig
         UCX = 2,
         NIXL = 3
     };
-    explicit CacheTransceiverConfig(
-        std::optional<BackendType> backendType = std::nullopt, std::optional<size_t> maxNumTokens = std::nullopt);
+    explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
+        std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt);
 
     bool operator==(CacheTransceiverConfig const& other) const;
     void setBackendType(std::optional<BackendType> backendType);
     void setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer);
+    void setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs);
 
+    [[nodiscard]] std::optional<int> getKvTransferTimeoutMs() const;
     [[nodiscard]] std::optional<size_t> getMaxTokensInBuffer() const;
     [[nodiscard]] std::optional<BackendType> getBackendType() const;
 
@@ -1472,6 +1474,7 @@ class CacheTransceiverConfig
     /// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache
     /// transfer may be degraded.
     std::optional<size_t> mMaxTokensInBuffer;
+    std::optional<int> mKvTransferTimeoutMs;
 };
 
 /// @brief Configuration class for the model executor
diff --git a/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp b/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
index 6919d213642..e3abbb59ee2 100644
--- a/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
+++ b/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
@@ -22,15 +22,17 @@ namespace tensorrt_llm::executor
 {
 
 CacheTransceiverConfig::CacheTransceiverConfig(
-    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens)
+    std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
     : mBackendType(backendType)
-    , mMaxTokensInBuffer(maxNumTokens)
+    , mMaxTokensInBuffer(maxNumTokens)
+    , mKvTransferTimeoutMs(kvTransferTimeoutMs)
 {
 }
 
 bool CacheTransceiverConfig::operator==(CacheTransceiverConfig const& other) const
 {
-    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType;
+    return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType
+        && mKvTransferTimeoutMs == other.mKvTransferTimeoutMs;
 }
 
 void CacheTransceiverConfig::setBackendType(std::optional<BackendType> backendType)
@@ -43,6 +45,11 @@ void CacheTransceiverConfig::setMaxTokensInBuffer(std::optional<size_t> maxTokensInBuffer)
     mMaxTokensInBuffer = maxTokensInBuffer;
 }
 
+void CacheTransceiverConfig::setKvTransferTimeoutMs(std::optional<int> kvTransferTimeoutMs)
+{
+    mKvTransferTimeoutMs = kvTransferTimeoutMs;
+}
+
 std::optional<CacheTransceiverConfig::BackendType> CacheTransceiverConfig::getBackendType() const
 {
     return mBackendType;
@@ -53,4 +60,9 @@ std::optional<size_t> CacheTransceiverConfig::getMaxTokensInBuffer() const
     return mMaxTokensInBuffer;
 }
 
+std::optional<int> CacheTransceiverConfig::getKvTransferTimeoutMs() const
+{
+    return mKvTransferTimeoutMs;
+}
+
 } // namespace tensorrt_llm::executor
diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
index 0334eb14f6a..ba06fbb4426 100644
--- a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
@@ -433,15 +433,15 @@ void initConfigBindings(nb::module_& m)
         .def("__setstate__", guidedDecodingConfigSetstate);
 
     auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    { return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
     auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
        {
             throw std::runtime_error("Invalid CacheTransceiverConfig state!");
         }
-        new (&self) tle::CacheTransceiverConfig(
-            nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1]));
+        new (&self) tle::CacheTransceiverConfig(nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]),
+            nb::cast<std::optional<size_t>>(state[1]), nb::cast<std::optional<int>>(state[2]));
     };
 
     nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -464,12 +464,15 @@ void initConfigBindings(nb::module_& m)
         });
 
     nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+                 std::optional<int>>(), nb::arg("backend") = std::nullopt,
+            nb::arg("max_tokens_in_buffer") = std::nullopt, nb::arg("kv_transfer_timeout_ms") = std::nullopt)
         .def_prop_rw(
             "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
         .def_prop_rw("max_tokens_in_buffer",
             &tle::CacheTransceiverConfig::getMaxTokensInBuffer, &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
+        .def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
+            &tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
         .def("__getstate__", cacheTransceiverConfigGetstate)
         .def("__setstate__", cacheTransceiverConfigSetstate);
 
diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
index 74e2fe56c16..f23971eb453 100644
--- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
@@ -415,15 +415,15 @@ void initConfigBindings(pybind11::module_& m)
         .def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));
 
     auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
-    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); };
+    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
     auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
     {
-        if (state.size() != 2)
+        if (state.size() != 3)
         {
             throw std::runtime_error("Invalid CacheTransceiverConfig state!");
         }
-        return tle::CacheTransceiverConfig(
-            state[0].cast<tle::CacheTransceiverConfig::BackendType>(), state[1].cast<std::optional<size_t>>());
+        return tle::CacheTransceiverConfig(state[0].cast<tle::CacheTransceiverConfig::BackendType>(),
+            state[1].cast<std::optional<size_t>>(), state[2].cast<std::optional<int>>());
     };
 
     py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
@@ -446,12 +446,15 @@ void initConfigBindings(pybind11::module_& m)
         });
 
     py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
-        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
-            py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
+        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+                 std::optional<int>>(), py::arg("backend") = std::nullopt,
+            py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)
         .def_property(
             "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
         .def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
             &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
+        .def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
+            &tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
         .def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));
 
     auto executorConfigGetState = [](py::object const& self)
diff --git a/examples/disaggregated/README.md b/examples/disaggregated/README.md
index 3943d5fefba..cb731a91cc0 100644
--- a/examples/disaggregated/README.md
+++ b/examples/disaggregated/README.md
@@ -16,6 +16,9 @@ cache_transceiver_config:
   backend:
   # KV cache buffer size. Set it ≥ the maximum ISL (Input Sequence Length) for best performance.
   max_tokens_in_buffer:
+  # KV cache transfer timeout in milliseconds.
+  # Requests that do not finish sending or receiving their KV cache within this timeout are cancelled and cleaned up.
+  kv_transfer_timeout_ms:
 ```
 
 The following is an example, consisting of the `ctx_extra-llm-api-config.yaml` and `gen_extra-llm-api-config.yaml` files needed in the sections below.
diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
index c762cd3378a..319000d28c6 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -109,6 +109,8 @@ def __init__(self, mapping: Mapping, dist: Distributed,
         # get the layer num per pp rank, which is required by cache transceiver.
         pp_layer_num = len(kv_cache_manager.pp_layers)
         pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num)
+
+        self.kv_transfer_timeout_ms = cache_transceiver_config.kv_transfer_timeout_ms
         self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
                                         total_num_kv_heads_per_layer, head_dim,
                                         tokens_per_block, world_config,
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 0c305078445..83ea30f72e4 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -442,6 +442,8 @@ def __init__(
         self.py_lora_task_layer_module_configs: list[
             tensorrt_llm.bindings.internal.runtime.
             TaskLayerModuleConfig] | None = None
+        self.py_kv_transfer_start_time = None
+        self.py_kv_transfer_timed_out = False
 
         self.py_num_logprobs = num_logprobs
         self.py_return_log_probs = return_log_probs
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 3ccd466eb63..750d12d30e6 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -978,6 +978,7 @@ def _executor_loop_pp(self):
                     self.micro_batches[prev_microbatch_id] = None
 
             if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                self._check_kv_transfer_timeout()
                 self._terminate_ctx_finished_requests()
 
         if self._disagg_pp_termination_handler is not None:
@@ -1006,6 +1007,7 @@ def _prepare_and_schedule_batch(self):
 
         if self.kv_cache_transceiver:
             self._check_disagg_gen_transfer_status()
+            self._check_kv_transfer_timeout()
 
         iter_stats = None
         if self.enable_iter_perf_stats:
@@ -1179,6 +1181,7 @@ def _executor_loop(self):
                     self._add_kv_cache_events()
 
             if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                self._check_kv_transfer_timeout()
                 self._terminate_ctx_finished_requests()
 
             self._kv_connector_terminate_requests()
@@ -1364,6 +1367,7 @@ def _executor_loop_overlap(self):
                         ctx_transmission_reqs=ctx_transmission_reqs)
 
             if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                self._check_kv_transfer_timeout()
                 self._terminate_ctx_finished_requests()
 
             self._kv_connector_terminate_requests()
@@ -1572,6 +1576,38 @@ def _check_disagg_gen_transfer_status(self):
 
         return
 
+    @nvtx_range("_check_kv_transfer_timeout")
+    def _check_kv_transfer_timeout(self):
+        if not self.kv_cache_transceiver:
+            return
+        timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms
+        if timeout_ms is None or timeout_ms <= 0:
+            return
+
+        current_time = time.time()
+
+        # ctx_in_transmission_requests holds (request, block_id) pairs, so unpack them here.
+        for req, _ in self.ctx_in_transmission_requests:
+            if req.py_kv_transfer_start_time is None:
+                continue
+            elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
+            if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
+                logger.warning(
+                    f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
+                )
+                req.py_kv_transfer_timed_out = True
+
+        for req in self.active_requests:
+            if req.is_disagg_generation_transmission_in_progress and req.py_kv_transfer_start_time is not None:
+                elapsed_time = (current_time -
+                                req.py_kv_transfer_start_time) * 1000
+                if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
+                    logger.warning(
+                        f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout"
+                    )
+                    req.py_kv_transfer_timed_out = True
+
+        return
+
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
@@ -1646,6 +1682,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
             req.context_current_position = req.prompt_len
             req.decoding_iter = 1
             req.py_decoding_iter = 1
+            req.py_kv_transfer_start_time = None
             first_gen_tokens = req.context_phase_params.first_gen_tokens
             ctx_draft_tokens = req.context_phase_params.draft_tokens
             req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
@@ -1669,6 +1706,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs):
         for req in new_gen_reqs:
             self.kv_cache_transceiver.request_and_receive_async(req)
 
+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in new_gen_reqs:
+                if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()
+
         block_transfer = all([
             req.is_disagg_generation_transmission_in_progress
             for req in self.active_requests
@@ -1701,6 +1743,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
             if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
         ]
 
+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in ctx_in_transmission_requests:
+                if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
+                    req.py_kv_transfer_start_time = time.time()
+
         return ctx_transmission_reqs
 
     def _get_disagg_reqs_in_error_state(self):
@@ -2018,6 +2065,12 @@ def _handle_responses(self):
                     requests_to_terminate.append(request)
                     continue
 
+                # Check if generation request needs cleanup due to KV cache transfer timeout
+                if request.py_kv_transfer_timed_out:
+                    # Previously, we were doing _handle_errors, which sends an error response.
+                    # We should consider how we should be doing this now?
+                    self.kv_cache_transceiver.cancel_request(request)
+
                 if request.is_generation_only_request():
                     # If request is in transmission, so we don't need to emit a response
                     # Also, for the first iteration with overlap, we should skip since first
@@ -2068,6 +2121,9 @@ def _handle_responses(self):
     def _terminate_ctx_finished_requests(self):
         for request, block_id in self.ctx_in_transmission_requests[:]:
             if request.is_disagg_context_complete_state:
+                if request.py_kv_transfer_timed_out:
+                    request.py_kv_transfer_start_time = None
+                    self.kv_cache_transceiver.cancel_request(request)
                 if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
                     self._terminate_request(request)
                 else:
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 18f00e0a62a..82df030ca07 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1286,10 +1286,17 @@ class CacheTransceiverConfig(StrictBaseModel, PybindMirror):
         default=None,
         description="The max number of tokens the transfer buffer can fit.")
 
+    kv_transfer_timeout_ms: Optional[int] = Field(
+        default=None,
+        description=
+        "Timeout in milliseconds for KV cache transfer. Requests exceeding this timeout will be cancelled."
+    )
+
     def _to_pybind(self):
         return _CacheTransceiverConfig(
             backend=_CacheTransceiverBackendType.from_string(self.backend),
-            max_tokens_in_buffer=self.max_tokens_in_buffer)
+            max_tokens_in_buffer=self.max_tokens_in_buffer,
+            kv_transfer_timeout_ms=self.kv_transfer_timeout_ms)
 
 
 @dataclass

From 1d1323a3499d8a1a54eb4aa2feeabe3aaf1d7bb0 Mon Sep 17 00:00:00 2001
From: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Date: Wed, 8 Oct 2025 05:11:57 -0700
Subject: [PATCH 2/2] properly handling errors

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 750d12d30e6..6ea3f346220 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -2067,9 +2067,15 @@ def _handle_responses(self):
 
                 # Check if generation request needs cleanup due to KV cache transfer timeout
                 if request.py_kv_transfer_timed_out:
-                    # Previously, we were doing _handle_errors, which sends an error response.
-                    # We should consider how we should be doing this now?
-                    self.kv_cache_transceiver.cancel_request(request)
+                    # We need to check if we can successfully cancel the request before
+                    # terminating it and freeing the KV cache associated with it
+                    is_cancelled = self.kv_cache_transceiver.cancel_request(request)
+                    # This will terminate the request and send an error response
+                    if is_cancelled:
+                        self._handle_errors(
+                            error_msg=f"Request {request.py_request_id} timed out",
+                            requests=[request])
+                    continue
 
                 if request.is_generation_only_request():
                     # If request is in transmission, so we don't need to emit a response
@@ -2122,8 +2128,10 @@ def _terminate_ctx_finished_requests(self):
         for request, block_id in self.ctx_in_transmission_requests[:]:
             if request.is_disagg_context_complete_state:
                 if request.py_kv_transfer_timed_out:
-                    request.py_kv_transfer_start_time = None
-                    self.kv_cache_transceiver.cancel_request(request)
+                    is_cancelled = self.kv_cache_transceiver.cancel_request(
+                        request)
+                    if is_cancelled:
+                        request.py_kv_transfer_start_time = None
                 if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa:
                     self._terminate_request(request)
                 else:
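Taken together, the two patches implement a per-request stopwatch: a start time is recorded when a KV cache transfer begins, the request is flagged once the elapsed time exceeds `kv_transfer_timeout_ms`, and it is only terminated after `cancel_request` confirms the cancellation. The following is a minimal, self-contained sketch of that flow; `FakeRequest` and `FakeTransceiver` are illustrative stand-ins, not the real `LlmRequest` or cache transceiver classes.

```python
import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:
    # Stand-in for LlmRequest: only the fields the timeout logic touches.
    request_id: int
    kv_transfer_start_time: Optional[float] = None
    kv_transfer_timed_out: bool = False


class FakeTransceiver:
    # Stand-in for the cache transceiver. Cancellation can fail (e.g. the
    # transfer already completed), so callers must check the return value.
    def cancel_request(self, req: FakeRequest) -> bool:
        return req.kv_transfer_start_time is not None


def check_kv_transfer_timeout(reqs, timeout_ms: Optional[int]) -> None:
    """Flag requests whose KV transfer has been running longer than timeout_ms."""
    if timeout_ms is None or timeout_ms <= 0:
        return
    now = time.time()
    for req in reqs:
        if req.kv_transfer_start_time is None or req.kv_transfer_timed_out:
            continue
        elapsed_ms = (now - req.kv_transfer_start_time) * 1000
        if elapsed_ms > timeout_ms:
            req.kv_transfer_timed_out = True


def terminate_timed_out(reqs, transceiver: FakeTransceiver):
    """Terminate only the requests whose cancellation the transceiver confirmed."""
    terminated = []
    for req in reqs:
        if req.kv_transfer_timed_out and transceiver.cancel_request(req):
            req.kv_transfer_start_time = None
            terminated.append(req)
    return terminated


if __name__ == "__main__":
    # One request started its transfer 1 second ago; a 500 ms budget expires it.
    reqs = [FakeRequest(request_id=1, kv_transfer_start_time=time.time() - 1.0)]
    check_kv_transfer_timeout(reqs, timeout_ms=500)
    print([r.request_id for r in terminate_timed_out(reqs, FakeTransceiver())])
```

The two-step split mirrors the patches: detection happens once per executor iteration, while the actual cleanup is deferred until cancellation is confirmed, so a transfer that finishes just before the check is not torn down twice.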
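For completeness, this is roughly how the new knob would be set from the LLM API rather than from a YAML file. It is a hedged sketch: it assumes `CacheTransceiverConfig` is importable from `tensorrt_llm.llmapi` and that `LLM` accepts a `cache_transceiver_config` argument, as in the existing disaggregated examples; the model path is a placeholder.

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CacheTransceiverConfig

# Cancel KV cache transfers that take longer than 30 seconds.
cache_cfg = CacheTransceiverConfig(
    backend="UCX",
    max_tokens_in_buffer=8192,
    kv_transfer_timeout_ms=30_000,
)

llm = LLM(
    model="/path/to/model",  # placeholder path
    cache_transceiver_config=cache_cfg,
)
```

The YAML form shown in the README (`kv_transfer_timeout_ms` under `cache_transceiver_config`) maps to this same field on the pydantic config added in `llm_args.py`.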