Skip to content

Conversation

pcastonguay
Copy link
Collaborator

@pcastonguay pcastonguay commented Oct 7, 2025

Summary by CodeRabbit

  • New Features

    • Added optional kv_transfer_timeout_ms to configure KV cache transfer timeouts.
    • Executor now monitors KV transfers and automatically cancels and cleans up requests that exceed the timeout, emitting warnings.
    • Python API supports the new option and preserves it across save/load.
  • Documentation

    • Updated examples to document kv_transfer_timeout_ms and describe timeout behavior.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Copy link
Contributor

coderabbitai bot commented Oct 7, 2025

📝 Walkthrough

Walkthrough

Adds an optional KV cache transfer timeout to CacheTransceiverConfig across C++ API, bindings (nanobind/pybind), Python args, and docs. Implements timeout tracking and cancellation in the Python executor, initializing per-request timing, checking timeouts in main loops, and cleaning up timed-out requests.

Changes

Cohort / File(s) Summary
C++ API: CacheTransceiverConfig
cpp/include/tensorrt_llm/executor/executor.h, cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
Introduces optional kvTransferTimeoutMs: new constructor param, getter/setter, member storage; updates equality operator.
Bindings: nanobind + pybind
cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp, cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
Extends bindings to pass kv_transfer_timeout_ms; updates constructors, properties, and pickle/getstate/setstate to 3-tuple with timeout.
Python API args
tensorrt_llm/llmapi/llm_args.py
Adds kv_transfer_timeout_ms to CacheTransceiverConfig dataclass and propagates via _to_pybind.
Python executor logic
tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
Adds timeout tracking: per-request start/timeouts fields, a _check_kv_transfer_timeout method, integration points in loops, cancellation via kv_cache_transceiver, and state resets on completion.
Docs
examples/disaggregated/README.md
Documents new cache_transceiver_config.kv_transfer_timeout_ms option and behavior on timeout.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Client as Client
  participant PyExec as PyExecutor
  participant KvTx as KvCacheTransceiver
  participant Req as Request

  rect rgba(230,240,255,0.4)
    note over PyExec: Initiate KV transfer (ctx/gen)
    PyExec->>Req: set py_kv_transfer_start_time (if timeout configured)
    PyExec->>KvTx: send/recv KV cache
  end

  loop Main/Overlap loops
    PyExec->>PyExec: _check_kv_transfer_timeout()
    alt timeout configured AND elapsed > timeout
      PyExec->>Req: mark py_kv_transfer_timed_out = True
      PyExec-->>KvTx: cancel_request(req_id)
      note right of PyExec: Cleanup/terminate request
      PyExec-->>Client: return/notify timed-out request
    else transfer ongoing
      PyExec->>KvTx: poll transfer status
      KvTx-->>PyExec: status (in progress/done)
      opt done
        PyExec->>Req: clear py_kv_transfer_start_time
      end
    end
  end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 3.85% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description Check ⚠️ Warning The PR description includes the template headings but does not contain any actual description or test coverage details, leaving the Description and Test Coverage sections empty and reviewers without context or information about testing. This makes the description incomplete against the repository’s template requirements. Please fill in the Description section with a concise explanation of the issue and solution and populate the Test Coverage section with relevant test cases that validate the new functionality.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title includes a valid NVBugs ID and correct type tag and concisely describes the primary change of adding a KV cache transfer timeout. It accurately follows the repository convention and allows reviewers to understand the main intent at a glance.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ca9da1f and f8c3670.

📒 Files selected for processing (9)
  • cpp/include/tensorrt_llm/executor/executor.h (2 hunks)
  • cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp (3 hunks)
  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp (2 hunks)
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp (2 hunks)
  • examples/disaggregated/README.md (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/llm_request.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py (10 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (8)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
  • tensorrt_llm/llmapi/llm_args.py
  • cpp/include/tensorrt_llm/executor/executor.h
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
  • cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
  • tensorrt_llm/llmapi/llm_args.py
  • cpp/include/tensorrt_llm/executor/executor.h
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
  • cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
  • cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
  • cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
**/*.{h,hpp,hh,hxx,cpp,cxx,cc}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc}: Prefer anonymous namespaces over 'static' for internal linkage of functions.
All templates (class/function/member/static) must be instantiated at least once; non-POD classes should have private data members.

Files:

  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
  • cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp
**/*.{h,hpp,hh,hxx}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Document new class interfaces and function prototypes with Doxygen; use //! for single-line and //!< for members.

Files:

  • cpp/include/tensorrt_llm/executor/executor.h
**/*.{h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use include guards named 'TRTLLM_<FILE_NAME_IN_CAPS_WITH_UNDERSCORES>_H' (no leading or trailing underscore; directory names excluded).

Files:

  • cpp/include/tensorrt_llm/executor/executor.h
🧬 Code graph analysis (4)
cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp (1)
cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp (12)
  • getBackendType (53-56)
  • getBackendType (53-53)
  • setBackendType (38-41)
  • setBackendType (38-38)
  • getMaxTokensInBuffer (58-61)
  • getMaxTokensInBuffer (58-58)
  • setMaxTokensInBuffer (43-46)
  • setMaxTokensInBuffer (43-43)
  • getKvTransferTimeoutMs (63-66)
  • getKvTransferTimeoutMs (63-63)
  • setKvTransferTimeoutMs (48-51)
  • setKvTransferTimeoutMs (48-48)
cpp/include/tensorrt_llm/executor/executor.h (2)
cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp (11)
  • CacheTransceiverConfig (24-30)
  • other (32-36)
  • other (32-32)
  • setBackendType (38-41)
  • setBackendType (38-38)
  • setMaxTokensInBuffer (43-46)
  • setMaxTokensInBuffer (43-43)
  • setKvTransferTimeoutMs (48-51)
  • setKvTransferTimeoutMs (48-48)
  • getKvTransferTimeoutMs (63-66)
  • getKvTransferTimeoutMs (63-63)
tensorrt_llm/llmapi/llm_args.py (1)
  • CacheTransceiverConfig (1275-1299)
tensorrt_llm/_torch/pyexecutor/py_executor.py (4)
tensorrt_llm/_utils.py (1)
  • nvtx_range (900-919)
tensorrt_llm/logger.py (1)
  • warning (132-133)
cpp/include/tensorrt_llm/batch_manager/llmRequest.h (1)
  • LlmRequestState (47-210)
tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py (2)
  • cancel_request (94-95)
  • cancel_request (139-140)
cpp/tensorrt_llm/pybind/executor/executorConfig.cpp (2)
cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp (13)
  • CacheTransceiverConfig (24-30)
  • getBackendType (53-56)
  • getBackendType (53-53)
  • setBackendType (38-41)
  • setBackendType (38-38)
  • getMaxTokensInBuffer (58-61)
  • getMaxTokensInBuffer (58-58)
  • setMaxTokensInBuffer (43-46)
  • setMaxTokensInBuffer (43-43)
  • getKvTransferTimeoutMs (63-66)
  • getKvTransferTimeoutMs (63-63)
  • setKvTransferTimeoutMs (48-51)
  • setKvTransferTimeoutMs (48-48)
tensorrt_llm/llmapi/llm_args.py (1)
  • CacheTransceiverConfig (1275-1299)
🪛 Ruff (0.13.3)
tensorrt_llm/_torch/pyexecutor/py_executor.py

1747-1747: Undefined name ctx_in_transmission_requests

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (9)
examples/disaggregated/README.md (1)

19-21: LGTM!

The documentation clearly describes the new kv_transfer_timeout_ms configuration option and its behavior.

tensorrt_llm/_torch/pyexecutor/llm_request.py (1)

445-446: LGTM!

The new attributes follow the existing naming convention and are initialized with appropriate default values to support KV transfer timeout tracking.

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py (1)

112-113: LGTM!

The kv_transfer_timeout_ms is correctly initialized from the configuration, following the existing pattern in the constructor.

tensorrt_llm/llmapi/llm_args.py (2)

1289-1293: LGTM!

The kv_transfer_timeout_ms field is correctly defined with an appropriate type, default value, and clear description.


1298-1299: LGTM!

The _to_pybind method correctly includes the new kv_transfer_timeout_ms parameter in the _CacheTransceiverConfig construction.

cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp (1)

474-475: LGTM!

The property binding for kv_transfer_timeout_ms correctly exposes the getter and setter methods.

cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp (3)

34-35: LGTM!

The equality operator correctly includes the new mKvTransferTimeoutMs field in the comparison.


48-51: LGTM!

The setKvTransferTimeoutMs setter is correctly implemented.


63-66: LGTM!

The getKvTransferTimeoutMs getter is correctly implemented and properly const-qualified.

Comment on lines +25 to +28
std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
: mBackendType(backendType)
, mMaxTokensInBuffer(maxNumTokens)
: mMaxTokensInBuffer(maxNumTokens)
, mKvTransferTimeoutMs(kvTransferTimeoutMs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix the syntax error in the constructor initializer list.

The initializer list has an incorrect colon on line 27. C++ initializer lists should have only one colon after the constructor signature, with subsequent members separated by commas.

Apply this diff to fix the syntax error:

 CacheTransceiverConfig::CacheTransceiverConfig(
     std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
     : mBackendType(backendType)
-    : mMaxTokensInBuffer(maxNumTokens)
+    , mMaxTokensInBuffer(maxNumTokens)
     , mKvTransferTimeoutMs(kvTransferTimeoutMs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
std::optional<BackendType> backendType, std::optional<size_t> maxNumTokens, std::optional<int> kvTransferTimeoutMs)
: mBackendType(backendType)
, mMaxTokensInBuffer(maxNumTokens)
: mMaxTokensInBuffer(maxNumTokens)
, mKvTransferTimeoutMs(kvTransferTimeoutMs)
CacheTransceiverConfig::CacheTransceiverConfig(
std::optional<BackendType> backendType,
std::optional<size_t> maxNumTokens,
std::optional<int> kvTransferTimeoutMs)
: mBackendType(backendType)
, mMaxTokensInBuffer(maxNumTokens)
, mKvTransferTimeoutMs(kvTransferTimeoutMs)
{
//
}
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp around lines 25-28 the
constructor initializer list incorrectly uses a second colon before
mMaxTokensInBuffer, causing a syntax error; change that second colon to a comma
so the initializer list has a single leading colon after the constructor
signature and subsequent members separated by commas (i.e., :
mBackendType(backendType), mMaxTokensInBuffer(maxNumTokens),
mKvTransferTimeoutMs(kvTransferTimeoutMs)).

Comment on lines +467 to +469
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
std::optional < int >> (), nb::arg("backend") = std::nullopt,
nb::arg("max_tokens_in_buffer") = std::nullopt, nb::arg("kv_transfer_timeout_ms") = std::nullopt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix the syntax error in constructor binding.

The template parameter list is incorrectly split across lines. The comma after std::optional<size_t>> on line 467 breaks the template syntax.

Apply this diff to fix the syntax error:

-.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
-    std::optional < int >> (), nb::arg("backend") = std::nullopt,
+.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+    std::optional<int>>(), nb::arg("backend") = std::nullopt,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
std::optional < int >> (), nb::arg("backend") = std::nullopt,
nb::arg("max_tokens_in_buffer") = std::nullopt, nb::arg("kv_transfer_timeout_ms") = std::nullopt)
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<int>>(), nb::arg("backend") = std::nullopt,
nb::arg("max_tokens_in_buffer") = std::nullopt, nb::arg("kv_transfer_timeout_ms") = std::nullopt)
🤖 Prompt for AI Agents
In cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp around lines 467 to
469, the nb::init template parameter list is split across lines causing a syntax
error (the comma after std::optional<size_t>> breaks the template). Fix by using
a single, correct template parameter list (e.g.
nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>,
std::optional<size_t>, std::optional<int>>() ) and keep the nb::arg(...)
defaults as before so the constructor binding compiles.

Comment on lines 448 to +451
py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
std::optional < int >> (), py::arg("backend") = std::nullopt,
py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix malformed py::init signature

The py::init template parameter list is broken—std::optional<int>> ended up outside the template argument list, so this won’t compile. Keep all three optionals inside the template parameters.

-        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
-            std::optional < int >> (), py::arg("backend") = std::nullopt,
+        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+                std::optional<int>>(),
+            py::arg("backend") = std::nullopt,
             py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt)
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
std::optional < int >> (), py::arg("backend") = std::nullopt,
py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)
py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<int>>(),
py::arg("backend") = std::nullopt,
py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt)

Comment on lines +1589 to +1598
for req in self.ctx_in_transmission_requests:
if req.py_kv_transfer_start_time is None:
continue
elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
logger.warning(
f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_kv_transfer_timed_out = True

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix AttributeError when checking context transfer timeouts

self.ctx_in_transmission_requests holds (request, blockId) tuples, so iterating directly over it and accessing req.py_kv_transfer_start_time will throw at runtime. Unpack the tuple (or index into it) before touching request attributes.

-        for req in self.ctx_in_transmission_requests:
+        for req, _ in self.ctx_in_transmission_requests:
             if req.py_kv_transfer_start_time is None:
                 continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for req in self.ctx_in_transmission_requests:
if req.py_kv_transfer_start_time is None:
continue
elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
logger.warning(
f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_kv_transfer_timed_out = True
for req, _ in self.ctx_in_transmission_requests:
if req.py_kv_transfer_start_time is None:
continue
elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out:
logger.warning(
f"Terminating context request {req.py_request_id} due to KV cache transfer timeout"
)
req.py_kv_transfer_timed_out = True
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 1589 to 1598, the
loop treats items in self.ctx_in_transmission_requests as request objects but
they are (request, blockId) tuples; unpack the tuple before accessing attributes
(e.g., for req, _ in self.ctx_in_transmission_requests or req = item[0]), then
use req.py_kv_transfer_start_time, compute elapsed_time and set
req.py_kv_transfer_timed_out as before to avoid the AttributeError.

Comment on lines +1746 to +1749
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in ctx_in_transmission_requests:
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix NameError when arming context timeout clock

ctx_in_transmission_requests is undefined here, so this block will raise immediately. The intent is to walk the freshly built ctx_transmission_reqs.

-        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
-            for req in ctx_in_transmission_requests:
+        if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+            for req in ctx_transmission_reqs:
                 if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
                     req.py_kv_transfer_start_time = time.time()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in ctx_in_transmission_requests:
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
for req in ctx_transmission_reqs:
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time()
🧰 Tools
🪛 Ruff (0.13.3)

1747-1747: Undefined name ctx_in_transmission_requests

(F821)

@pcastonguay pcastonguay changed the title [None][fix] Add KV cache transfer timeout [https://nvbugs/5429636][feat] Add KV cache transfer timeout Oct 7, 2025
Signed-off-by: Patrice Castonguay <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants