-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[https://nvbugs/5429636][feat] Add KV cache transfer timeout #8178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -433,15 +433,15 @@ void initConfigBindings(nb::module_& m) | |||||||||||||
.def("__setstate__", guidedDecodingConfigSetstate); | ||||||||||||||
|
||||||||||||||
auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self) | ||||||||||||||
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); }; | ||||||||||||||
{ return nb::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); }; | ||||||||||||||
auto cacheTransceiverConfigSetstate = [](tle::CacheTransceiverConfig& self, nb::tuple const& state) | ||||||||||||||
{ | ||||||||||||||
if (state.size() != 2) | ||||||||||||||
if (state.size() != 3) | ||||||||||||||
{ | ||||||||||||||
throw std::runtime_error("Invalid CacheTransceiverConfig state!"); | ||||||||||||||
} | ||||||||||||||
new (&self) tle::CacheTransceiverConfig( | ||||||||||||||
nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), nb::cast<std::optional<size_t>>(state[1])); | ||||||||||||||
new (&self) tle::CacheTransceiverConfig(nb::cast<tle::CacheTransceiverConfig::BackendType>(state[0]), | ||||||||||||||
nb::cast<std::optional<size_t>>(state[1]), nb::cast<std::optional<int>>(state[2])); | ||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
nb::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType") | ||||||||||||||
|
@@ -464,12 +464,15 @@ void initConfigBindings(nb::module_& m) | |||||||||||||
}); | ||||||||||||||
|
||||||||||||||
nb::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig") | ||||||||||||||
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(), | ||||||||||||||
nb::arg("backend") = std::nullopt, nb::arg("max_tokens_in_buffer") = std::nullopt) | ||||||||||||||
.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>, | ||||||||||||||
std::optional < int >> (), nb::arg("backend") = std::nullopt, | ||||||||||||||
nb::arg("max_tokens_in_buffer") = std::nullopt, nb::arg("kv_transfer_timeout_ms") = std::nullopt) | ||||||||||||||
Comment on lines
+467
to
+469
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix the syntax error in constructor binding. The template parameter list is incorrectly split across lines. The comma after Apply this diff to fix the syntax error: -.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
- std::optional < int >> (), nb::arg("backend") = std::nullopt,
+.def(nb::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+ std::optional<int>>(), nb::arg("backend") = std::nullopt, 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||
.def_prop_rw( | ||||||||||||||
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType) | ||||||||||||||
.def_prop_rw("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer, | ||||||||||||||
&tle::CacheTransceiverConfig::setMaxTokensInBuffer) | ||||||||||||||
.def_prop_rw("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs, | ||||||||||||||
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs) | ||||||||||||||
.def("__getstate__", cacheTransceiverConfigGetstate) | ||||||||||||||
.def("__setstate__", cacheTransceiverConfigSetstate); | ||||||||||||||
|
||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -415,15 +415,15 @@ void initConfigBindings(pybind11::module_& m) | |||||||||||||||||||||||
.def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate)); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self) | ||||||||||||||||||||||||
{ return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); }; | ||||||||||||||||||||||||
{ return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); }; | ||||||||||||||||||||||||
auto cacheTransceiverConfigSetstate = [](py::tuple const& state) | ||||||||||||||||||||||||
{ | ||||||||||||||||||||||||
if (state.size() != 2) | ||||||||||||||||||||||||
if (state.size() != 3) | ||||||||||||||||||||||||
{ | ||||||||||||||||||||||||
throw std::runtime_error("Invalid CacheTransceiverConfig state!"); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
return tle::CacheTransceiverConfig( | ||||||||||||||||||||||||
state[0].cast<tle::CacheTransceiverConfig::BackendType>(), state[1].cast<std::optional<size_t>>()); | ||||||||||||||||||||||||
return tle::CacheTransceiverConfig(state[0].cast<tle::CacheTransceiverConfig::BackendType>(), | ||||||||||||||||||||||||
state[1].cast<std::optional<size_t>>(), state[2].cast<std::optional<int>>()); | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType") | ||||||||||||||||||||||||
|
@@ -446,12 +446,15 @@ void initConfigBindings(pybind11::module_& m) | |||||||||||||||||||||||
}); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig") | ||||||||||||||||||||||||
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>(), | ||||||||||||||||||||||||
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt) | ||||||||||||||||||||||||
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>, | ||||||||||||||||||||||||
std::optional < int >> (), py::arg("backend") = std::nullopt, | ||||||||||||||||||||||||
py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt) | ||||||||||||||||||||||||
Comment on lines
448
to
+451
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix malformed The - .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>>,
- std::optional < int >> (), py::arg("backend") = std::nullopt,
+ .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
+ std::optional<int>>(),
+ py::arg("backend") = std::nullopt,
py::arg("max_tokens_in_buffer") = std::nullopt, py::arg("kv_transfer_timeout_ms") = std::nullopt) 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||
.def_property( | ||||||||||||||||||||||||
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType) | ||||||||||||||||||||||||
.def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer, | ||||||||||||||||||||||||
&tle::CacheTransceiverConfig::setMaxTokensInBuffer) | ||||||||||||||||||||||||
.def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs, | ||||||||||||||||||||||||
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs) | ||||||||||||||||||||||||
.def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate)); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
auto executorConfigGetState = [](py::object const& self) | ||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -978,6 +978,7 @@ def _executor_loop_pp(self): | |||||||||||||||||||||||||||||||||||||
self.micro_batches[prev_microbatch_id] = None | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.kv_cache_transceiver and self.ctx_in_transmission_requests: | ||||||||||||||||||||||||||||||||||||||
self._check_kv_transfer_timeout() | ||||||||||||||||||||||||||||||||||||||
self._terminate_ctx_finished_requests() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self._disagg_pp_termination_handler is not None: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1006,6 +1007,7 @@ def _prepare_and_schedule_batch(self): | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.kv_cache_transceiver: | ||||||||||||||||||||||||||||||||||||||
self._check_disagg_gen_transfer_status() | ||||||||||||||||||||||||||||||||||||||
self._check_kv_transfer_timeout() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
iter_stats = None | ||||||||||||||||||||||||||||||||||||||
if self.enable_iter_perf_stats: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1179,6 +1181,7 @@ def _executor_loop(self): | |||||||||||||||||||||||||||||||||||||
self._add_kv_cache_events() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.kv_cache_transceiver and self.ctx_in_transmission_requests: | ||||||||||||||||||||||||||||||||||||||
self._check_kv_transfer_timeout() | ||||||||||||||||||||||||||||||||||||||
self._terminate_ctx_finished_requests() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
self._kv_connector_terminate_requests() | ||||||||||||||||||||||||||||||||||||||
|
@@ -1364,6 +1367,7 @@ def _executor_loop_overlap(self): | |||||||||||||||||||||||||||||||||||||
ctx_transmission_reqs=ctx_transmission_reqs) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.kv_cache_transceiver and self.ctx_in_transmission_requests: | ||||||||||||||||||||||||||||||||||||||
self._check_kv_transfer_timeout() | ||||||||||||||||||||||||||||||||||||||
self._terminate_ctx_finished_requests() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
self._kv_connector_terminate_requests() | ||||||||||||||||||||||||||||||||||||||
|
@@ -1572,6 +1576,38 @@ def _check_disagg_gen_transfer_status(self): | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
@nvtx_range("_check_kv_transfer_timeout") | ||||||||||||||||||||||||||||||||||||||
def _check_kv_transfer_timeout(self): | ||||||||||||||||||||||||||||||||||||||
if not self.kv_cache_transceiver: | ||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||
timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms | ||||||||||||||||||||||||||||||||||||||
if timeout_ms is None or timeout_ms <= 0: | ||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
current_time = time.time() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
for req in self.ctx_in_transmission_requests: | ||||||||||||||||||||||||||||||||||||||
if req.py_kv_transfer_start_time is None: | ||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||
elapsed_time = (current_time - req.py_kv_transfer_start_time) * 1000 | ||||||||||||||||||||||||||||||||||||||
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out: | ||||||||||||||||||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||||||||||||||||||
f"Terminating context request {req.py_request_id} due to KV cache transfer timeout" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
req.py_kv_transfer_timed_out = True | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
Comment on lines
+1589
to
+1598
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix AttributeError when checking context transfer timeouts
- for req in self.ctx_in_transmission_requests:
+ for req, _ in self.ctx_in_transmission_requests:
if req.py_kv_transfer_start_time is None:
continue 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||
for req in self.active_requests: | ||||||||||||||||||||||||||||||||||||||
if req.is_disagg_generation_transmission_in_progress and req.py_kv_transfer_start_time is not None: | ||||||||||||||||||||||||||||||||||||||
elapsed_time = (current_time - | ||||||||||||||||||||||||||||||||||||||
req.py_kv_transfer_start_time) * 1000 | ||||||||||||||||||||||||||||||||||||||
if elapsed_time > timeout_ms and not req.py_kv_transfer_timed_out: | ||||||||||||||||||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||||||||||||||||||
f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
req.py_kv_transfer_timed_out = True | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
@nvtx_range("_pad_attention_dp_dummy_request") | ||||||||||||||||||||||||||||||||||||||
def _pad_attention_dp_dummy_request(self): | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
|
@@ -1646,6 +1682,7 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch): | |||||||||||||||||||||||||||||||||||||
req.context_current_position = req.prompt_len | ||||||||||||||||||||||||||||||||||||||
req.decoding_iter = 1 | ||||||||||||||||||||||||||||||||||||||
req.py_decoding_iter = 1 | ||||||||||||||||||||||||||||||||||||||
req.py_kv_transfer_start_time = None | ||||||||||||||||||||||||||||||||||||||
first_gen_tokens = req.context_phase_params.first_gen_tokens | ||||||||||||||||||||||||||||||||||||||
ctx_draft_tokens = req.context_phase_params.draft_tokens | ||||||||||||||||||||||||||||||||||||||
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens | ||||||||||||||||||||||||||||||||||||||
|
@@ -1669,6 +1706,11 @@ def _recv_disagg_gen_cache(self, new_gen_reqs): | |||||||||||||||||||||||||||||||||||||
for req in new_gen_reqs: | ||||||||||||||||||||||||||||||||||||||
self.kv_cache_transceiver.request_and_receive_async(req) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None: | ||||||||||||||||||||||||||||||||||||||
for req in new_gen_reqs: | ||||||||||||||||||||||||||||||||||||||
if req.state == LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS: | ||||||||||||||||||||||||||||||||||||||
req.py_kv_transfer_start_time = time.time() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
block_transfer = all([ | ||||||||||||||||||||||||||||||||||||||
req.is_disagg_generation_transmission_in_progress | ||||||||||||||||||||||||||||||||||||||
for req in self.active_requests | ||||||||||||||||||||||||||||||||||||||
|
@@ -1701,6 +1743,11 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests): | |||||||||||||||||||||||||||||||||||||
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS | ||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None: | ||||||||||||||||||||||||||||||||||||||
for req in ctx_in_transmission_requests: | ||||||||||||||||||||||||||||||||||||||
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS: | ||||||||||||||||||||||||||||||||||||||
req.py_kv_transfer_start_time = time.time() | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+1746
to
+1749
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix NameError when arming context timeout clock
- if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
- for req in ctx_in_transmission_requests:
+ if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
+ for req in ctx_transmission_reqs:
if req.state == LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS:
req.py_kv_transfer_start_time = time.time() 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.13.3)1747-1747: Undefined name (F821) |
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return ctx_transmission_reqs | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def _get_disagg_reqs_in_error_state(self): | ||||||||||||||||||||||||||||||||||||||
|
@@ -2018,6 +2065,18 @@ def _handle_responses(self): | |||||||||||||||||||||||||||||||||||||
requests_to_terminate.append(request) | ||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# Check if generation request needs cleanup due to KV cache transfer timeout | ||||||||||||||||||||||||||||||||||||||
if request.py_kv_transfer_timed_out: | ||||||||||||||||||||||||||||||||||||||
# We need to check if can can successfully cancel the request before | ||||||||||||||||||||||||||||||||||||||
# terminating it and freeing KV cache associated with it | ||||||||||||||||||||||||||||||||||||||
is_cancelled = self.kv_cache_transceiver.cancel_request(request) | ||||||||||||||||||||||||||||||||||||||
# This will terminate the request and send an error response | ||||||||||||||||||||||||||||||||||||||
if is_cancelled: | ||||||||||||||||||||||||||||||||||||||
self._handle_errors( | ||||||||||||||||||||||||||||||||||||||
error_msg=f"Request {request.py_request_id} timed out", | ||||||||||||||||||||||||||||||||||||||
requests=[request]) | ||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if request.is_generation_only_request(): | ||||||||||||||||||||||||||||||||||||||
# If request is in transmission, so we don't need to emit a response | ||||||||||||||||||||||||||||||||||||||
# Also, for the first iteration with overlap, we should skip since first | ||||||||||||||||||||||||||||||||||||||
|
@@ -2068,6 +2127,11 @@ def _handle_responses(self): | |||||||||||||||||||||||||||||||||||||
def _terminate_ctx_finished_requests(self): | ||||||||||||||||||||||||||||||||||||||
for request, block_id in self.ctx_in_transmission_requests[:]: | ||||||||||||||||||||||||||||||||||||||
if request.is_disagg_context_complete_state: | ||||||||||||||||||||||||||||||||||||||
if request.py_kv_transfer_timed_out: | ||||||||||||||||||||||||||||||||||||||
is_cancelled = self.kv_cache_transceiver.cancel_request( | ||||||||||||||||||||||||||||||||||||||
request) | ||||||||||||||||||||||||||||||||||||||
if is_cancelled: | ||||||||||||||||||||||||||||||||||||||
request.py_kv_transfer_start_time = None | ||||||||||||||||||||||||||||||||||||||
if not self.block_reuse_enabled or self.kv_cache_manager.is_vswa: | ||||||||||||||||||||||||||||||||||||||
self._terminate_request(request) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix the syntax error in the constructor initializer list.
The initializer list has an incorrect colon on line 27. C++ initializer lists should have only one colon after the constructor signature, with subsequent members separated by commas.
Apply this diff to fix the syntax error:
📝 Committable suggestion
🤖 Prompt for AI Agents