
Commit 887cbed

ctx_pp2_gen_pp1_draft
Signed-off-by: Patrice Castonguay <[email protected]>
1 parent 3805976 commit 887cbed

File tree

3 files changed: +69 -4 lines changed


cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 13 additions & 1 deletion
@@ -814,7 +814,19 @@ void CacheFormatter::unformat(TransferSession& session)
     if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
     {
         TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
-        return false;
+        TLLM_LOG_WARNING("self: %d dest %d", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
+            destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
+
+        auto selfTotalLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size()
+            * selfConfig.getParallelConfig().mPipelineParallelism;
+        auto destTotalLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size()
+            * destConfig.getParallelConfig().mPipelineParallelism;
+        if (selfTotalLayers != destTotalLayers)
+        {
+            TLLM_LOG_WARNING("CacheFormatter::inquireSupport: incompatible total layer counts: self=%d, dest=%d",
+                static_cast<int>(selfTotalLayers), static_cast<int>(destTotalLayers));
+            return false;
+        }
     }
     int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();
     int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism;
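The new check relaxes the earlier behavior: instead of returning false as soon as the per-rank layer lists differ in length, it now compares total layer counts, i.e. the number of layers each rank holds multiplied by the pipeline-parallel size, and only rejects the transfer when those totals differ. A minimal sketch of that rule (Python, not part of this commit; the helper name and layer counts are made up for illustration):

# Illustrative only: the compatibility rule the C++ check above encodes.
def total_layers_match(self_layers_per_rank: int, self_pp: int,
                       dest_layers_per_rank: int, dest_pp: int) -> bool:
    # Both sides must describe the same total layer count, even when the
    # layers are split across a different number of pipeline ranks.
    return self_layers_per_rank * self_pp == dest_layers_per_rank * dest_pp

# Matches the commit's ctx-PP=2 / gen-PP=1 pairing: a context server with
# PP=2 and 16 layers per rank vs. a generation server with PP=1 holding all
# 32 layers -> totals match, so the transfer is allowed.
assert total_layers_match(16, 2, 32, 1)
# Mismatched totals (32 vs. 64) are still rejected.
assert not total_layers_match(16, 2, 32, 2)

When the totals match, the code above only logs warnings and falls through rather than returning false, so differently partitioned context and generation servers can still exchange KV cache.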

examples/disaggregated/disagg_config.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,13 +1,13 @@
 hostname: localhost
 port: 8000
-model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+model: /home/scratch.trt_llm_data/llm-models/llama-3.1-model/Meta-Llama-3.1-8B
 free_gpu_memory_fraction: 0.25
 backend: "pytorch"
 disable_overlap_scheduler: True
 context_servers:
   num_instances: 1
   tensor_parallel_size: 1
-  pipeline_parallel_size: 1
+  pipeline_parallel_size: 2
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
   cache_transceiver_config:

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 54 additions & 1 deletion
@@ -122,6 +122,7 @@ class BatchState:
 @dataclasses.dataclass
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
+    scheduled_ctx_reqs: list[LlmRequest] = None


 class PyExecutor:
@@ -656,6 +657,9 @@ def _executor_loop_pp(self):
                 if self.should_stop_processing:
                     break

+                if self.kv_cache_transceiver:
+                    self._check_disagg_gen_transfer_status()
+
                 if self.enable_iter_perf_stats:
                     iter_stats = self._get_init_iter_stats(
                         len(new_requests),
@@ -664,9 +668,28 @@

                 self._pad_attention_dp_dummy_request()

-                scheduled_batch, _, _ = self._schedule()
+                scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
+                )
+
+                if self.kv_cache_transceiver:
+
+                    # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
+                    self._prepare_disagg_gen_init(
+                        fitting_disagg_gen_init_requests)
+
+                    if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
+                        logger.warning(
+                            "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
+                        )
+                        self.kv_cache_transceiver.check_context_transfer_status(
+                            1)
+                else:
+                    assert scheduled_batch.batch_size > 0, (
+                        "fail to schedule any pending request, "
+                        "probably run out of resource.")

                 self.num_scheduled_requests = scheduled_batch.batch_size
+
                 logger.debug(
                     f'has {len(self.active_requests)} active_request, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@@ -688,8 +711,28 @@ def _executor_loop_pp(self):
                     self.micro_batches[microbatch_id] = None
                 else:
                     self._add_inflight_ids(scheduled_batch)
+
+                    if self.kv_cache_transceiver:
+                        # For generation requests which have completed KV cache transfer
+                        self._prepare_disagg_gen_transmission_complete(
+                            scheduled_batch)
+
                     self.resource_manager.prepare_resources(scheduled_batch)

+                    # Generation requests that do not yet have a batch_idx need to
+                    # be in front of the batch due to the assumptions made in
+                    # model_engine.py::_forward_step. This is only important for
+                    # disaggregated serving. For non-disaggregated serving, the
+                    # generation requests always have a batch_idx.
+                    scheduled_batch.generation_requests = sorted(  # stable sort
+                        scheduled_batch.generation_requests,
+                        key=lambda req: int(req.py_batch_idx is not None),
+                    )
+
+                    if self.kv_cache_transceiver:
+                        # Return the first token to the client
+                        self._handle_first_token_response(scheduled_batch)
+
                     # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
                     if not self.dist.is_last_pp_rank:
                         sample_state = self._forward_step_inter_pp(
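The stable sort above relies on sorted() preserving the relative order of equal keys: int(req.py_batch_idx is not None) maps requests that still lack a batch_idx to 0 and all others to 1, so the former move to the front while each group keeps its original order. A self-contained illustration (the Req dataclass is a stand-in for LlmRequest, used only for this example):

import dataclasses
from typing import Optional


@dataclasses.dataclass
class Req:
    # Stand-in for LlmRequest; only the field used by the sort key is modeled.
    name: str
    py_batch_idx: Optional[int] = None


reqs = [Req("a", 0), Req("b"), Req("c", 1), Req("d")]
ordered = sorted(reqs, key=lambda r: int(r.py_batch_idx is not None))
print([r.name for r in ordered])  # ['b', 'd', 'a', 'c']: requests without a batch_idx first, order preserved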
@@ -720,6 +763,7 @@
                         iter_start_time=iter_start_time,
                         iter_stats=iter_stats,
                         microbatch_id=microbatch_id,
+                        scheduled_ctx_reqs=scheduled_batch.context_requests,
                     )

                     self.micro_batches[microbatch_id] = batch_state
@@ -784,6 +828,12 @@
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                         self._update_requests(previous_batch.sample_state)
+
+                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
+                            ctx_transmission_reqs = self._send_disagg_ctx_cache(
+                                previous_batch.scheduled_ctx_reqs
+                            ) if self.kv_cache_transceiver else []
+
                         self._handle_canceled_requests()
                         finished_requests = self._handle_responses()
                         previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
@@ -792,6 +842,9 @@
                         self._remove_inflight_ids(previous_scheduled_batch)
                     self.micro_batches[prev_microbatch_id] = None

+                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._terminate_ctx_finished_requests()
+
                 # march forward in microbatch slots
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches

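Taken together, the py_executor.py changes thread the disaggregated-context bookkeeping through the micro-batch ring: the context requests scheduled with a micro-batch are recorded in its BatchStatePP (scheduled_ctx_reqs), and their KV cache is only handed to the transceiver once the loop comes back around to that micro-batch and its sample state has been processed. A toy sketch of this store-now, send-later pattern (names and output are placeholders, not the executor's real API):

# Toy model only: per-microbatch state is parked in a ring of slots and
# consumed when the loop revisits that slot, much like scheduled_ctx_reqs
# riding in BatchStatePP until the previous batch for the slot is handled.
NUM_MICRO_BATCHES = 2  # placeholder value
micro_batches = [None] * NUM_MICRO_BATCHES


def loop_iteration(step: int, microbatch_id: int) -> int:
    scheduled_ctx_reqs = [f"ctx_req_{step}"]  # stand-in for this step's scheduling
    previous = micro_batches[microbatch_id]   # batch launched earlier in this slot
    if previous is not None:
        # Only now, after that earlier forward pass has been consumed, would the
        # KV cache for those context requests be sent to the generation side.
        print(f"step {step}: send KV cache for {previous}")
    micro_batches[microbatch_id] = scheduled_ctx_reqs  # park state for a later iteration
    return (microbatch_id + 1) % NUM_MICRO_BATCHES     # march forward in slots


mb = 0
for step in range(4):
    mb = loop_iteration(step, mb)
# ctx_req_0 is sent at step 2 and ctx_req_1 at step 3, one full slot cycle later.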