Commit a50f2d1

add support for the overlap scheduler + little refactoring
1 parent 9105f24 commit a50f2d1

6 files changed: +88 -59 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 16 additions & 9 deletions
@@ -276,7 +276,9 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
         executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens

     def _create_kv_cache_manager(
-            self, model_engine: PyTorchModelEngine) -> KVCacheManager:
+            self,
+            model_engine: PyTorchModelEngine,
+            for_estimation: bool = False) -> KVCacheManager:
         executor_config = self._executor_config
         mapping = self._mapping
         assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models."
@@ -317,15 +319,16 @@ def _create_kv_cache_manager(
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
                 max_beam_width=executor_config.max_beam_width,
-                kv_connector_manager=self._kv_connector_manager,
+                kv_connector_manager=self._kv_connector_manager
+                if not for_estimation else None,
             )
         elif is_nemotron_hybrid(config):
             if executor_config.max_beam_width > 1:
                 raise ValueError(
                     "MambaHybridCacheManager + beam search is not supported yet."
                 )

-            if self._kv_connector_manager is not None:
+            if not for_estimation and self._kv_connector_manager is not None:
                 raise ValueError(
                     "Connector manager is not supported for MambaHybridCacheManager."
                 )
@@ -387,25 +390,29 @@ def _create_kv_cache_manager(
                 max_num_tokens=executor_config.max_num_tokens,
                 model_config=binding_model_config,
                 max_beam_width=executor_config.max_beam_width,
-                kv_connector_manager=self._kv_connector_manager,
+                kv_connector_manager=self._kv_connector_manager
+                if not for_estimation else None,
             )
         # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
         if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
             executor_config.max_seq_len = kv_cache_manager.max_seq_len

         return kv_cache_manager

-    def build_managers(self, resources: Dict) -> None:
+    def build_managers(self,
+                       resources: Dict,
+                       for_estimation: bool = False) -> None:
         """Construct KV caches for model and draft model (if applicable)."""
-        kv_cache_manager = self._create_kv_cache_manager(self._model_engine)
+        kv_cache_manager = self._create_kv_cache_manager(
+            self._model_engine, for_estimation)

-        if self._kv_connector_manager is not None and self._draft_model_engine is not None:
+        if not for_estimation and self._kv_connector_manager is not None and self._draft_model_engine is not None:
             raise ValueError(
                 "Connector manager is not supported for draft model.")

         draft_kv_cache_manager = self._create_kv_cache_manager(
-            self._draft_model_engine
-        ) if self._draft_model_engine is not None else None
+            self._draft_model_engine,
+            for_estimation) if self._draft_model_engine is not None else None

         resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
         resources[
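
The pattern introduced here: build_managers threads a for_estimation flag into _create_kv_cache_manager, which passes None instead of the KV connector manager while the executor is only estimating KV cache memory. A minimal, self-contained sketch of that gating pattern (the class and helper names below are illustrative stand-ins, not the real TensorRT-LLM signatures):

# Illustrative sketch only: SketchKvCacheCreator and its members are
# hypothetical stand-ins for the pattern in _util.py, not the real API.
from typing import Any, Dict, Optional


class SketchKvCacheCreator:

    def __init__(self, connector_manager: Optional[Any] = None):
        self._kv_connector_manager = connector_manager

    def _create_kv_cache_manager(self,
                                 model_engine: str,
                                 for_estimation: bool = False) -> Dict[str, Any]:
        # During the estimation pass the connector must not be attached.
        connector = self._kv_connector_manager if not for_estimation else None
        return {"engine": model_engine, "kv_connector_manager": connector}

    def build_managers(self,
                       resources: Dict[str, Any],
                       for_estimation: bool = False) -> None:
        resources["kv_cache_manager"] = self._create_kv_cache_manager(
            "model", for_estimation)


# Estimation pass suppresses the connector; the final pass attaches it.
creator = SketchKvCacheCreator(connector_manager=object())
resources: Dict[str, Any] = {}
creator.build_managers(resources, for_estimation=True)
assert resources["kv_cache_manager"]["kv_connector_manager"] is None
creator.build_managers(resources, for_estimation=False)
assert resources["kv_cache_manager"]["kv_connector_manager"] is not None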

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 22 additions & 16 deletions
@@ -274,11 +274,6 @@ def _maybe_init_kv_connector_manager(self):
                 "KV Cache Connector is not supported with pipeline parallelism."
             )

-        if not self.disable_overlap_scheduler:
-            raise NotImplementedError(
-                "KV Cache Connector is not supported with overlap scheduler."
-            )
-
         # TODO: This does NOT support pipeline parallel.
         layer_kv_tensors = {
             layer_idx: self.kv_cache_manager.get_buffers(layer_idx)
@@ -948,6 +943,19 @@ def _execute_guided_decoder(self, scheduled_batch, logits):
             self.guided_decoder.build(scheduled_batch)
             self.guided_decoder.execute(scheduled_batch, logits)

+    def _execute_kv_connector(self, scheduled_batch):
+        if self.kv_connector_manager:
+            self.kv_connector_manager.take_scheduled_requests_pending_load(
+                scheduled_batch)
+            self.kv_connector_manager.handle_metadata()
+            self.kv_connector_manager.worker.start_load_kv()
+
+    def _terminate_async_save_requests(self):
+        if self.kv_connector_manager:
+            reqs_to_terminate = self.kv_connector_manager.get_finished()
+            for req in reqs_to_terminate:
+                self.resource_manager.free_resources(req)
+
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
         with self._profiler() as profile_step:
@@ -976,14 +984,9 @@ def _executor_loop(self):

                 # Return the first token to the client
                 self._handle_first_token_response(scheduled_batch)
-                scheduled_batch.is_warmup = self.is_warmup
                 self.resource_manager.prepare_resources(scheduled_batch)

-                if self.kv_connector_manager:
-                    self.kv_connector_manager.take_scheduled_requests_pending_load(
-                        scheduled_batch)
-                    self.kv_connector_manager.handle_metadata()
-                    self.kv_connector_manager.worker.start_load_kv()
+                self._execute_kv_connector(scheduled_batch)

                 if scheduled_batch.batch_size > 0 or (
                         self.enable_attention_dp and self.dist.tp_size > 1):
@@ -1017,10 +1020,8 @@ def _executor_loop(self):

                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                     self._terminate_ctx_finished_requests()
-                elif self.kv_connector_manager:
-                    reqs_to_terminate = self.kv_connector_manager.get_finished()
-                    for req in reqs_to_terminate:
-                        self.resource_manager.free_resources(req)
+
+                self._terminate_async_save_requests()

                 if self.enable_iter_perf_stats:
                     iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
@@ -1086,9 +1087,12 @@ def _executor_loop_overlap(self):
                 # For generation requests which have completed KV cache transfer
                 self._prepare_disagg_gen_transmission_complete(
                     scheduled_batch)
-
                 self.resource_manager.prepare_resources(scheduled_batch)

+                self._execute_kv_connector(scheduled_batch)
+
+                if scheduled_batch.batch_size > 0:
+
                     # The generation requests that are do not have batch_idx,
                     # needs to be in front of the batch due to the assumptions
                     # made in model_engine.py::_forward_step. This is only important
@@ -1141,6 +1145,8 @@ def _executor_loop_overlap(self):
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                     self._terminate_ctx_finished_requests()

+                self._terminate_async_save_requests()
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
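
The refactoring half of this commit shows up here: the per-iteration connector work is pulled into _execute_kv_connector and _terminate_async_save_requests, so both _executor_loop and _executor_loop_overlap can call them and the earlier NotImplementedError guarding the overlap scheduler can be dropped. A simplified, runnable sketch of that shape (the executor class below is a stand-in; only the two helper bodies mirror the diff):

# Simplified stand-in for the refactored loop structure; only the two helper
# bodies mirror the diff, everything else here is illustrative.
from unittest.mock import MagicMock


class SketchExecutor:

    def __init__(self, kv_connector_manager=None):
        self.kv_connector_manager = kv_connector_manager
        self.resource_manager = MagicMock()

    def _execute_kv_connector(self, scheduled_batch):
        # Safe to call unconditionally from either loop: no-op when no
        # connector is configured.
        if self.kv_connector_manager:
            self.kv_connector_manager.take_scheduled_requests_pending_load(
                scheduled_batch)
            self.kv_connector_manager.handle_metadata()
            self.kv_connector_manager.worker.start_load_kv()

    def _terminate_async_save_requests(self):
        if self.kv_connector_manager:
            for req in self.kv_connector_manager.get_finished():
                self.resource_manager.free_resources(req)

    def run_iteration(self, scheduled_batch):
        # Both loops follow this shape: kick off KV loads before the forward
        # pass, reap finished async saves after it.
        self._execute_kv_connector(scheduled_batch)
        # ... forward pass / sampling elided ...
        self._terminate_async_save_requests()


connector = MagicMock()
connector.get_finished.return_value = ["req0"]
SketchExecutor(connector).run_iteration(scheduled_batch=MagicMock())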

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 2 deletions
@@ -392,7 +392,7 @@ def create_py_executor(
     with mem_monitor.observe_creation_stage(
             _ExecutorCreationStage.INIT_KV_CACHE
             if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
-        kv_cache_creator.build_managers(resources)
+        kv_cache_creator.build_managers(resources, estimating_kv_cache)

     # Resource managers for speculative decoding
     # For user-specified drafters, use extra_resource_managers in PyTorchBackend config
@@ -443,7 +443,7 @@ def create_py_executor(
         # create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring
         # the original value before creating the final KV cache.
         executor_config.max_seq_len = max_seq_len
-        kv_cache_creator.build_managers(resources)
+        kv_cache_creator.build_managers(resources, False)

         for eng in [model_engine, draft_model_engine]:
             if eng is None:
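
Taken together with the estimation path, the creation flow now builds the managers once in estimation mode and once for real after executor_config.max_seq_len is restored. A rough sketch of how the two call sites relate (only the argument threading mirrors the diff; the surrounding objects are assumed stand-ins):

# Rough sketch of the two build_managers call sites; kv_cache_creator,
# executor_config and resources are assumed stand-ins for the real objects.
def build_kv_caches(kv_cache_creator, resources, executor_config,
                    estimating_kv_cache, original_max_seq_len):
    # First pass: when estimating, the connector manager is suppressed.
    kv_cache_creator.build_managers(resources, estimating_kv_cache)

    if estimating_kv_cache:
        # The estimation pass capped max_seq_len; restore it and build the
        # final KV cache with the connector enabled.
        executor_config.max_seq_len = original_max_seq_len
        kv_cache_creator.build_managers(resources, False)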

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 4 additions & 6 deletions
@@ -377,8 +377,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                     req.py_request_id,
                     seq_len + (len(req.query_id) if self.mapping.cp_rank
                                == self.mapping.cp_size - 1 else 0),
-                    req_beam_width, req, self.kv_connector_manager
-                    if not scheduled_batch.is_warmup else None)
+                    req_beam_width, req, self.kv_connector_manager)
             else:
                 # TODO(jthomson04): This is begging for a mega refactor, and can likely be significantly simplified.
                 # In add sequence, the connector API's get_num_new_matched_tokens is called.
@@ -388,10 +387,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                 # When that happens, the request will go through this same code path, but with is_kv_cache_connector_async_onboard set to True.
                 # Because of this, we need to filter this case out to avoid adding the same sequence twice.
                 if req.is_first_context_chunk and not req.is_kv_cache_connector_async_onboard:
-                    self.impl.add_sequence(
-                        req.py_request_id, req.prompt_len, req_beam_width, req,
-                        self.kv_connector_manager
-                        if not scheduled_batch.is_warmup else None)
+                    self.impl.add_sequence(req.py_request_id, req.prompt_len,
+                                           req_beam_width, req,
+                                           self.kv_connector_manager)
                 for _ in range(self.num_extra_kv_tokens):
                     self.impl.add_token(req.py_request_id)
                 for _ in range(get_draft_token_length(req)):

tensorrt_llm/_torch/pyexecutor/scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ def __init__(self):
         self.context_requests: RequestList = []
         self.generation_requests: RequestList = []
         self.paused_requests: RequestList = []
-        self.is_warmup: bool = False

     @property
     def is_generation_only(self) -> bool:

tests/integration/defs/llmapi/test_llm_api_connector.py

Lines changed: 44 additions & 25 deletions
@@ -71,15 +71,16 @@ def model_fn(*args, **kwargs):


 @pytest.mark.threadleak(enabled=False)
-def test_connector_simple(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_simple(model_with_connector, use_overlap_scheduler):
     NUM_TOKENS = 8

     model_fn, scheduler, worker = model_with_connector

     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))

@@ -93,7 +94,9 @@ def test_connector_simple(model_with_connector):

     model.generate(["Hello, world"], sampling_params)

-    assert scheduler.build_connector_meta.call_count == NUM_TOKENS
+    # With the overlap scheduler, we generate one extra token.
+    assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)

     # We should have a single `SchedulerOutput` per forward pass.
     for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
@@ -105,7 +108,8 @@ def test_connector_simple(model_with_connector):
         assert len(scheduler_output.requests[0].new_tokens) == 1

     # We call `start_load_kv` once at the beginning of each forward pass.
-    assert worker.start_load_kv.call_count == NUM_TOKENS
+    assert worker.start_load_kv.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)

     # Only called once when the request is received.
     assert scheduler.get_num_new_matched_tokens.call_count == 1
@@ -114,31 +118,36 @@ def test_connector_simple(model_with_connector):
                      for call in worker.wait_for_layer_load.call_args_list) + 1

     # Called num_layers * num_forward_passes times.
-    assert worker.wait_for_layer_load.call_count == num_layers * NUM_TOKENS
-    assert worker.save_kv_layer.call_count == num_layers * NUM_TOKENS
+    assert worker.wait_for_layer_load.call_count == num_layers * (
+        NUM_TOKENS + int(use_overlap_scheduler))
+    assert worker.save_kv_layer.call_count == num_layers * (
+        NUM_TOKENS + int(use_overlap_scheduler))

     for i, call in enumerate(worker.wait_for_layer_load.call_args_list):
         assert call.args[0] == i % num_layers

     for i, call in enumerate(worker.save_kv_layer.call_args_list):
         assert call.args[0] == i % num_layers

-    assert worker.wait_for_save.call_count == NUM_TOKENS
+    assert worker.wait_for_save.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)

     assert scheduler.request_finished.call_count == 1
-    assert worker.get_finished.call_count == NUM_TOKENS
+    assert worker.get_finished.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)


 @pytest.mark.threadleak(enabled=False)
-def test_connector_async_onboard(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_async_onboard(model_with_connector, use_overlap_scheduler):
     NUM_TOKENS = 8

     model_fn, scheduler, worker = model_with_connector

     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))

@@ -153,23 +162,25 @@ def test_connector_async_onboard(model_with_connector):
         "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
     ], SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True))

-    # Once for the initial poll, then once for each token.
-    assert worker.get_finished.call_count == NUM_TOKENS + 1
+    # Once for the initial poll, then once for each token. One extra token when using the overlap scheduler.
+    assert worker.get_finished.call_count == NUM_TOKENS + 1 + int(
+        use_overlap_scheduler)

     # In the first iteration, there should be a single request id provided.
     assert len(worker.get_finished.call_args_list[0].args[1]) == 1


 @pytest.mark.threadleak(enabled=False)
-def test_connector_async_save(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_async_save(model_with_connector, use_overlap_scheduler):
     NUM_TOKENS = 8

     model_fn, scheduler, worker = model_with_connector

     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))

@@ -188,12 +199,13 @@ def test_connector_async_save(model_with_connector):

     assert scheduler.request_finished.call_count == 1

-    # On the last call to get_finished, we should be providing the async saving request.
-    assert worker.get_finished.call_count == NUM_TOKENS
+    # On the last call to get_finished, we should be providing the async saving request. One extra token when using the overlap scheduler.
+    assert worker.get_finished.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)

-    for i in range(NUM_TOKENS):
-        args = worker.get_finished.call_args_list[i].args
-        if i != NUM_TOKENS - 1:
+    for i, call in enumerate(worker.get_finished.call_args_list):
+        args = call.args
+        if i != len(worker.get_finished.call_args_list) - 1:
             assert args == ([], [])
         else:
             assert len(args[0]) == 1
@@ -202,7 +214,9 @@ def test_connector_async_save(model_with_connector):


 @pytest.mark.threadleak(enabled=False)
-def test_connector_scheduler_output(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_scheduler_output(model_with_connector,
+                                    use_overlap_scheduler):
     NUM_INPUT_TOKENS = 48
     NUM_TOKENS = 32
     BLOCK_SIZE = 32
@@ -212,7 +226,7 @@ def test_connector_scheduler_output(model_with_connector):
     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1))

@@ -226,7 +240,9 @@ def test_connector_scheduler_output(model_with_connector):

     model.generate([0] * NUM_INPUT_TOKENS, sampling_params)

-    assert scheduler.build_connector_meta.call_count == NUM_TOKENS
+    # Additional token when using the overlap scheduler.
+    assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
+        use_overlap_scheduler)

     for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
         sched_output = call.args[0]
@@ -241,7 +257,8 @@ def test_connector_scheduler_output(model_with_connector):
         else:
             assert len(request.new_tokens) == 1

-        if request.computed_position % BLOCK_SIZE == 0:
+        if (request.computed_position +
+                int(use_overlap_scheduler)) % BLOCK_SIZE == 0:
             assert len(request.new_block_ids) == 1
         else:
             assert request.new_block_ids == []
@@ -257,7 +274,9 @@ def test_connector_scheduler_output(model_with_connector):


 @pytest.mark.threadleak(enabled=False)
-def test_connector_scheduler_output_chunked_context(model_with_connector):
+@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
+def test_connector_scheduler_output_chunked_context(model_with_connector,
+                                                     use_overlap_scheduler):
     model_fn, scheduler, worker = model_with_connector

     CHUNK_SIZE = 128
@@ -266,7 +285,7 @@ def test_connector_scheduler_output_chunked_context(model_with_connector):
     model = model_fn(
         model="Qwen/Qwen2-0.5B",
         backend="pytorch",
-        disable_overlap_scheduler=True,
+        disable_overlap_scheduler=not use_overlap_scheduler,
         cuda_graph_config=None,
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
         enable_chunked_prefill=True,
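
Each test is now parametrized over use_overlap_scheduler, and every expected call count grows by one when it is enabled, because the overlap scheduler runs one extra forward pass. A standalone illustration of that counting convention (it does not use the real connector fixtures):

# Standalone illustration of the call-count convention used in these tests;
# it does not touch the real model_with_connector fixture.
import pytest

NUM_TOKENS = 8


def expected_forward_passes(use_overlap_scheduler: bool) -> int:
    # One forward pass per generated token, plus one extra pass when the
    # overlap scheduler is enabled.
    return NUM_TOKENS + int(use_overlap_scheduler)


@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
def test_call_count_convention(use_overlap_scheduler):
    assert expected_forward_passes(use_overlap_scheduler) == (
        9 if use_overlap_scheduler else 8)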
