Skip to content

Commit e5ab4a8

Browse files
committed
[None][feat] Enable early exit with overlap scheduler
- Update MicroBatchScheduler bindings to skip scheduling after GENERATION_TO_COMPLETE state.
- Update PyExecutor to set GENERATION_TO_COMPLETE state for requests that will complete next iteration.
- Fix _executor_loop_overlap to finish previous batch if current batch is empty.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 3a5845e commit e5ab4a8

File tree

5 files changed

+33
-19
lines changed

5 files changed

+33
-19
lines changed

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
6464
LlmRequestState>(),
6565
nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt,
6666
nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
67-
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
67+
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_TO_COMPLETE)
6868
.def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"),
6969
nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime"))
7070
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ void initBindings(nb::module_& m)
103103
.def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens))
104104
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false)
105105
.def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
106+
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
106107
.def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam"))
107108
.def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens"))
108109
.def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
6565
LlmRequestState>(),
6666
py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt,
6767
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
68-
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
69-
"LlmRequestState.GENERATION_COMPLETE"))
68+
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE,
69+
"LlmRequestState.GENERATION_TO_COMPLETE"))
7070
.def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"),
7171
py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime"))
7272
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ void initBindings(pybind11::module_& m)
107107
.def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens))
108108
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false)
109109
.def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
110+
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
110111
.def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam"))
111112
.def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens"))
112113
.def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ def _executor_loop_pp(self):
850850
self.num_scheduled_requests = scheduled_batch.batch_size
851851

852852
logger.debug(
853-
f'has {len(self.active_requests)} active_request, '
853+
f'has {len(self.active_requests)} active_requests, '
854854
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
855855
f'{len(scheduled_batch.generation_requests)} generation requests'
856856
)
@@ -1089,7 +1089,7 @@ def _prepare_and_schedule_batch(self):
10891089

10901090
self.num_scheduled_requests = scheduled_batch.batch_size
10911091
logger.debug(
1092-
f'has {len(self.active_requests)} active_request, '
1092+
f'has {len(self.active_requests)} active_requests, '
10931093
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
10941094
f'{len(scheduled_batch.generation_requests)} generation requests')
10951095
return scheduled_batch, iter_stats
@@ -1359,19 +1359,20 @@ def _executor_loop_overlap(self):
13591359
if target_inputs is not None:
13601360
self._process_draft_results(scheduled_batch,
13611361
draft_outputs, draft_batch)
1362-
elif self.previous_batch is not None and not use_previous_draft_tokens:
1363-
self._update_requests(self.previous_batch.sample_state)
1362+
if target_inputs is None and self.previous_batch is not None and not use_previous_draft_tokens:
1363+
self._update_requests(self.previous_batch.sample_state)
13641364

1365-
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
1366-
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
1367-
if req.is_context_only_request and (
1368-
req.is_context_finished
1369-
or req.is_finished_due_to_length):
1370-
block_id = self.kv_cache_manager.store_blocks_for_reuse(
1371-
req, True)
1372-
self.ctx_in_transmission_requests.append(
1373-
(req, block_id))
1365+
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
1366+
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
1367+
if req.is_context_only_request and (
1368+
req.is_context_finished
1369+
or req.is_finished_due_to_length):
1370+
block_id = self.kv_cache_manager.store_blocks_for_reuse(
1371+
req, True)
1372+
self.ctx_in_transmission_requests.append(
1373+
(req, block_id))
13741374

1375+
if scheduled_batch.batch_size > 0:
13751376
if self.guided_decoder is not None:
13761377
# add_batch must be called again to have updated new tokens.
13771378
self.guided_decoder.add_batch(scheduled_batch)
@@ -1387,9 +1388,10 @@ def _executor_loop_overlap(self):
13871388
scheduled_batch.context_requests
13881389
) if self.kv_cache_transceiver else []
13891390

1390-
if self.previous_batch is not None:
1391-
self._process_previous_batch()
1391+
if self.previous_batch is not None:
1392+
self._process_previous_batch()
13921393

1394+
if scheduled_batch.batch_size > 0:
13931395
if self.enable_iter_perf_stats:
13941396
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
13951397
'num_ctx_tokens']
@@ -1862,7 +1864,17 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
18621864
request.context_chunk_size)
18631865
request.move_to_next_context_chunk()
18641866
if request.context_remaining_length == 0:
1865-
request.state = LlmRequestState.GENERATION_IN_PROGRESS
1867+
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
1868+
):
1869+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
1870+
else:
1871+
request.state = LlmRequestState.GENERATION_IN_PROGRESS
1872+
1873+
for request in scheduled_requests.generation_requests:
1874+
if request.state != LlmRequestState.GENERATION_COMPLETE:
1875+
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
1876+
):
1877+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
18661878

18671879
def _update_request_states_star_attention(
18681880
self, scheduled_requests: ScheduledRequests):

0 commit comments

Comments
 (0)