Skip to content

Commit 2ad7022

Browse files
committed
[None][feat] Enable early exit with overlap scheduler
- Update MicroBatchScheduler bindings to skip scheduling after GENERATION_TO_COMPLETE state.
- Update PyExecutor to set GENERATION_TO_COMPLETE state for requests that will complete next iteration.
- Fix _executor_loop_overlap to finish previous batch if current batch is empty.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 653aa6b commit 2ad7022

File tree

5 files changed

+33
-19
lines changed

5 files changed

+33
-19
lines changed

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
6464
LlmRequestState>(),
6565
nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt,
6666
nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
67-
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
67+
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_TO_COMPLETE)
6868
.def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"),
6969
nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime"))
7070
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ void initBindings(nb::module_& m)
103103
.def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens))
104104
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false)
105105
.def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
106+
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
106107
.def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam"))
107108
.def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens"))
108109
.def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
6565
LlmRequestState>(),
6666
py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt,
6767
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
68-
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
69-
"LlmRequestState.GENERATION_COMPLETE"))
68+
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE,
69+
"LlmRequestState.GENERATION_TO_COMPLETE"))
7070
.def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"),
7171
py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime"))
7272
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ void initBindings(pybind11::module_& m)
107107
.def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens))
108108
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false)
109109
.def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
110+
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
110111
.def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam"))
111112
.def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens"))
112113
.def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ def _executor_loop_pp(self):
847847
self.num_scheduled_requests = scheduled_batch.batch_size
848848

849849
logger.debug(
850-
f'has {len(self.active_requests)} active_request, '
850+
f'has {len(self.active_requests)} active_requests, '
851851
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
852852
f'{len(scheduled_batch.generation_requests)} generation requests'
853853
)
@@ -1079,7 +1079,7 @@ def _prepare_and_schedule_batch(self):
10791079

10801080
self.num_scheduled_requests = scheduled_batch.batch_size
10811081
logger.debug(
1082-
f'has {len(self.active_requests)} active_request, '
1082+
f'has {len(self.active_requests)} active_requests, '
10831083
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
10841084
f'{len(scheduled_batch.generation_requests)} generation requests')
10851085
return scheduled_batch, iter_stats
@@ -1342,19 +1342,20 @@ def _executor_loop_overlap(self):
13421342
if target_inputs is not None:
13431343
self._process_draft_results(scheduled_batch,
13441344
draft_outputs, draft_batch)
1345-
elif self.previous_batch is not None and not use_previous_draft_tokens:
1346-
self._update_requests(self.previous_batch.sample_state)
1345+
if target_inputs is None and self.previous_batch is not None and not use_previous_draft_tokens:
1346+
self._update_requests(self.previous_batch.sample_state)
13471347

1348-
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
1349-
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
1350-
if req.is_context_only_request and (
1351-
req.is_context_finished
1352-
or req.is_finished_due_to_length):
1353-
block_id = self.kv_cache_manager.store_blocks_for_reuse(
1354-
req, True)
1355-
self.ctx_in_transmission_requests.append(
1356-
(req, block_id))
1348+
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
1349+
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
1350+
if req.is_context_only_request and (
1351+
req.is_context_finished
1352+
or req.is_finished_due_to_length):
1353+
block_id = self.kv_cache_manager.store_blocks_for_reuse(
1354+
req, True)
1355+
self.ctx_in_transmission_requests.append(
1356+
(req, block_id))
13571357

1358+
if scheduled_batch.batch_size > 0:
13581359
if self.guided_decoder is not None:
13591360
# add_batch must be called again to have updated new tokens.
13601361
self.guided_decoder.add_batch(scheduled_batch)
@@ -1370,9 +1371,10 @@ def _executor_loop_overlap(self):
13701371
scheduled_batch.context_requests
13711372
) if self.kv_cache_transceiver else []
13721373

1373-
if self.previous_batch is not None:
1374-
self._process_previous_batch()
1374+
if self.previous_batch is not None:
1375+
self._process_previous_batch()
13751376

1377+
if scheduled_batch.batch_size > 0:
13761378
if self.enable_iter_perf_stats:
13771379
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
13781380
'num_ctx_tokens']
@@ -1801,7 +1803,17 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
18011803
request.context_chunk_size)
18021804
request.move_to_next_context_chunk()
18031805
if request.context_remaining_length == 0:
1804-
request.state = LlmRequestState.GENERATION_IN_PROGRESS
1806+
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
1807+
):
1808+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
1809+
else:
1810+
request.state = LlmRequestState.GENERATION_IN_PROGRESS
1811+
1812+
for request in scheduled_requests.generation_requests:
1813+
if request.state != LlmRequestState.GENERATION_COMPLETE:
1814+
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
1815+
):
1816+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
18051817

18061818
def _update_request_states_star_attention(
18071819
self, scheduled_requests: ScheduledRequests):

0 commit comments

Comments (0)