
Commit 18eb3ee

pcastonguay, raayandhar, reasonsolo, and coderabbitai[bot] authored and committed
feat: Add support for disaggregation with pp with pytorch backend (NVIDIA#6369)
Signed-off-by: Patrice Castonguay <[email protected]>
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: Lizhi Zhou <[email protected]>
Signed-off-by: pcastonguay <[email protected]>
Co-authored-by: raayandhar <[email protected]>
Co-authored-by: Lizhi Zhou <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 98d10df commit 18eb3ee


15 files changed (+497, -22 lines)


cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 2 additions & 0 deletions
@@ -840,6 +840,8 @@ void CacheFormatter::unformat(TransferSession& session)
     if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
     {
         TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
+        TLLM_LOG_WARNING("self: %zu dest %zu", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
+            destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
         return false;
     }
     int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();
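
For context, the check that emits these warnings compares the number of KV-cache layers described by the sending and receiving configurations; transfer is only supported when both sides describe the same layer count. A minimal Python sketch of the idea (illustrative names only, not the C++ API):

# Minimal sketch of the layer-count compatibility check, assuming
# self_heads and dest_heads stand in for mNbKvHeadsPerLayer on the
# local and remote side (illustrative only).
def can_transfer(self_heads: list, dest_heads: list) -> bool:
    if len(self_heads) != len(dest_heads):
        # Mirror the new warning: report both sizes so mismatches are debuggable.
        print(f"only support same number of layers: self={len(self_heads)} dest={len(dest_heads)}")
        return False
    return True

print(can_transfer([4] * 32, [4] * 32))  # True
print(can_transfer([4] * 32, [4] * 16))  # warning + False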

scripts/build_wheel.py

Lines changed: 4 additions & 1 deletion
@@ -71,7 +71,10 @@ def clear_folder(folder_path):
         if os.path.isdir(item_path) and not os.path.islink(item_path):
             rmtree(item_path)
         else:
-            os.remove(item_path)
+            try:
+                os.remove(item_path)
+            except (OSError, IOError) as e:
+                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)


 def sysconfig_scheme(override_vars=None):
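
The cleanup helper now tolerates files it cannot delete (for example, files locked by another process) instead of aborting the wheel build; the failure is reported on stderr, which relies on sys already being imported in build_wheel.py. A self-contained sketch of the same pattern:

import os
import sys
from shutil import rmtree

# Self-contained sketch of the tolerant cleanup pattern above (not the
# actual build_wheel.py helper): directories are removed recursively,
# and a failure to delete a single file is logged rather than fatal.
def clear_folder(folder_path: str) -> None:
    for item in os.listdir(folder_path):
        item_path = os.path.join(folder_path, item)
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            rmtree(item_path)
        else:
            try:
                os.remove(item_path)
            except (OSError, IOError) as e:
                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)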

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 2 additions & 2 deletions
@@ -96,13 +96,13 @@ def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager,
                  attention_type: AttentionTypeCpp,
                  cache_transceiver_config: CacheTransceiverConfig):
         world_config = mapping_to_world_config(mapping)
-        num_kv_heads_per_layer = kv_cache_manager.num_kv_heads_per_layer
+        total_num_kv_heads_per_layer = kv_cache_manager.total_num_kv_heads_per_layer
         head_dim = kv_cache_manager.head_dim
         tokens_per_block = kv_cache_manager.tokens_per_block
         dtype = kv_cache_manager.dtype

         self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
-                                        num_kv_heads_per_layer, head_dim,
+                                        total_num_kv_heads_per_layer, head_dim,
                                         tokens_per_block, world_config, dtype,
                                         attention_type,
                                         cache_transceiver_config)
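
The transceiver is now constructed with the KV-head count of every layer in the model (total_num_kv_heads_per_layer) rather than only the layers owned by the local pipeline-parallel rank, so the context and generation sides of a disaggregated transfer describe the same global layout even when their pipeline splits differ. A rough illustration of the difference, with made-up sizes and helper names:

# Illustrative only: contrast the local (per-PP-rank) KV-head list with
# the whole-model list now handed to the cache transceiver.
def local_layers(num_layers: int, pp_size: int, pp_rank: int) -> range:
    per_rank = num_layers // pp_size  # assume an even split for simplicity
    return range(pp_rank * per_rank, (pp_rank + 1) * per_rank)

num_layers, tp_size = 8, 2
num_kv_heads = [8] * num_layers  # made-up model: 8 KV heads per layer

# What the KV cache manager keeps for its own layers (pp_rank 1 of 2):
num_kv_heads_per_layer = [num_kv_heads[i] // tp_size
                          for i in local_layers(num_layers, 2, 1)]

# What the transceiver now receives: one entry per layer of the model.
total_num_kv_heads_per_layer = [h // tp_size for h in num_kv_heads]

print(num_kv_heads_per_layer)        # [4, 4, 4, 4]              (local partition)
print(total_num_kv_heads_per_layer)  # [4, 4, 4, 4, 4, 4, 4, 4]  (all layers)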

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 50 additions & 2 deletions
@@ -122,6 +122,7 @@ class BatchState:
 @dataclasses.dataclass
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
+    scheduled_ctx_reqs: list[LlmRequest] = None


 class PyExecutor:
@@ -631,6 +632,7 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
         return False

     def _executor_loop_pp(self):
+        logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
         torch.cuda.set_device(self.device_id)
         microbatch_id = 0
         with self._profiler() as profile_step:
@@ -644,6 +646,9 @@
                 if self.should_stop_processing:
                     break

+                if self.kv_cache_transceiver:
+                    self._check_disagg_gen_transfer_status()
+
                 if self.enable_iter_perf_stats:
                     iter_stats = self._get_init_iter_stats(
                         len(new_requests),
@@ -652,9 +657,23 @@

                 self._pad_attention_dp_dummy_request()

-                scheduled_batch, _, _ = self._schedule()
+                scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
+                )
+
+                if self.kv_cache_transceiver:
+                    # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
+                    self._prepare_disagg_gen_init(
+                        fitting_disagg_gen_init_requests)
+
+                    if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
+                        logger.warning(
+                            "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
+                        )
+                        self.kv_cache_transceiver.check_context_transfer_status(
+                            1)

                 self.num_scheduled_requests = scheduled_batch.batch_size
+
                 logger.debug(
                     f'has {len(self.active_requests)} active_request, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@@ -667,7 +686,7 @@
                     can_queue = 0 not in tp_batch_sizes
                 else:
                     can_queue = scheduled_batch.batch_size > 0
-                if not can_queue:
+                if not can_queue and not self.kv_cache_transceiver:
                     assert len(self.inflight_req_ids) > 0, (
                         "fail to schedule any pending request, probably run out of resource"
                     )
@@ -676,8 +695,28 @@
                     self.micro_batches[microbatch_id] = None
                 else:
                     self._add_inflight_ids(scheduled_batch)
+
+                    if self.kv_cache_transceiver:
+                        # For generation requests which have completed KV cache transfer
+                        self._prepare_disagg_gen_transmission_complete(
+                            scheduled_batch)
+
                     self.resource_manager.prepare_resources(scheduled_batch)

+                    # The generation requests that are do not have batch_idx,
+                    # needs to be in front of the batch due to the assumptions
+                    # made in model_engine.py::_forward_step. This is only important
+                    # for disaggregated serving. For non-disaggregated serving,
+                    # the generation requests always have batch_idx.
+                    scheduled_batch.generation_requests = sorted(  # stable sort
+                        scheduled_batch.generation_requests,
+                        key=lambda req: int(req.py_batch_idx is not None),
+                    )
+
+                    if self.kv_cache_transceiver:
+                        # Return the first token to the client
+                        self._handle_first_token_response(scheduled_batch)
+
                     # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
                     if not self.dist.is_last_pp_rank:
                         sample_state = self._forward_step_inter_pp(
@@ -705,6 +744,7 @@
                         iter_start_time=iter_start_time,
                         iter_stats=iter_stats,
                         microbatch_id=microbatch_id,
+                        scheduled_ctx_reqs=scheduled_batch.context_requests,
                     )

                     self.micro_batches[microbatch_id] = batch_state
@@ -769,6 +809,11 @@
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                         self._update_requests(previous_batch.sample_state)
+
+                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
+                            self._send_disagg_ctx_cache(
+                                previous_batch.scheduled_ctx_reqs)
+
                         self._handle_canceled_requests()
                         finished_requests = self._handle_responses()
                         previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
@@ -777,6 +822,9 @@
                         self._remove_inflight_ids(previous_scheduled_batch)
                     self.micro_batches[prev_microbatch_id] = None

+                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._terminate_ctx_finished_requests()
+
                 # march forward in microbatch slots
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches
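
Most of these changes gate disaggregated-serving work behind self.kv_cache_transceiver, but the request reordering is easy to miss: generation requests that arrived from a context server have no py_batch_idx yet and must come before those that do, and the sort must be stable so each group keeps its relative order. A small standalone illustration of that key function (Req is a stand-in for LlmRequest):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Req:  # stand-in for LlmRequest, illustrative only
    name: str
    py_batch_idx: Optional[int] = None

requests = [Req("a", 0), Req("b"), Req("c", 1), Req("d")]

# int(...) maps "no batch_idx yet" to 0 and "has batch_idx" to 1, so the
# former sort to the front; sorted() is stable, so order within each
# group is preserved ("b" before "d", "a" before "c").
ordered = sorted(requests, key=lambda r: int(r.py_batch_idx is not None))
print([r.name for r in ordered])  # ['b', 'd', 'a', 'c']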

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 20 additions & 5 deletions
@@ -155,18 +155,33 @@ def __init__(
                 (num_kv_heads + tp_size - 1) // tp_size
                 for _ in range(self.num_local_layers)
             ]
+            self.total_num_kv_heads_per_layer = [
+                (num_kv_heads + tp_size - 1) // tp_size
+                for _ in range(self.num_layers)
+            ]
         else:
             assert len(num_kv_heads) == self.num_layers

+            def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
+                                             kv_head: Optional[int]):
+                if kv_head is not None:
+                    num_kv_heads_per_layer.append(
+                        (kv_head + tp_size - 1) // tp_size)
+                else:
+                    num_kv_heads_per_layer.append(0)
+
             self.num_kv_heads_per_layer = []
             if self.num_local_layers > 0:
                 for i in self.pp_layers:
                     kv_head = num_kv_heads[i]
-                    if kv_head is not None:
-                        self.num_kv_heads_per_layer.append(
-                            (kv_head + tp_size - 1) // tp_size)
-                    else:
-                        self.num_kv_heads_per_layer.append(0)
+                    append_to_kv_heads_per_layer(self.num_kv_heads_per_layer,
+                                                 kv_head)
+
+            self.total_num_kv_heads_per_layer = []
+            for i in range(self.num_layers):
+                kv_head = num_kv_heads[i]
+                append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer,
+                                             kv_head)

         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
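
Both lists apply the same ceiling division of each layer's KV heads across tensor-parallel ranks; the only difference is whether the loop runs over the local pipeline partition (self.pp_layers) or over every layer. A worked example of the rounding with made-up numbers, including the None case for layers that keep no KV cache:

# Made-up numbers, illustrating the helper above: ceil-divide KV heads
# across tp_size ranks, and map None (a layer with no KV cache) to 0.
tp_size = 4
num_kv_heads = [6, 6, None, 6]  # per-layer KV heads for a 4-layer model

total_num_kv_heads_per_layer = [
    0 if h is None else (h + tp_size - 1) // tp_size for h in num_kv_heads
]
print(total_num_kv_heads_per_layer)  # [2, 2, 0, 2]  (6 heads over 4 ranks rounds up to 2)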

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 11 additions & 0 deletions
@@ -735,3 +735,14 @@ def setup_class(cls):
         logger.set_level("info")
         yield
         logger.set_level(original_level)
+
+
+def get_accuracy_task(dataset_name: str):
+    try:
+        task_class = globals()[dataset_name]
+        if issubclass(task_class, AccuracyTask):
+            return task_class
+        else:
+            raise ValueError(f"Unknown dataset: {dataset_name}.")
+    except KeyError:
+        raise ValueError(f"Not registered dataset: {dataset_name}.")
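
get_accuracy_task resolves a dataset name such as "GSM8K" or "MMLU" to the task class of the same name defined in accuracy_core.py, rejecting names that are missing or that do not subclass AccuracyTask. A standalone sketch of the same lookup-by-globals pattern with toy classes:

# Toy classes, not the real AccuracyTask hierarchy; this only sketches
# the globals()-based lookup used by get_accuracy_task.
class AccuracyTask: ...
class GSM8K(AccuracyTask): ...
class NotATask: ...

def get_accuracy_task(dataset_name: str):
    try:
        task_class = globals()[dataset_name]
    except KeyError:
        raise ValueError(f"Not registered dataset: {dataset_name}.")
    if isinstance(task_class, type) and issubclass(task_class, AccuracyTask):
        return task_class
    raise ValueError(f"Unknown dataset: {dataset_name}.")

print(get_accuracy_task("GSM8K"))  # <class '__main__.GSM8K'>
# get_accuracy_task("NotATask")    # ValueError: Unknown dataset: NotATask.
# get_accuracy_task("Missing")     # ValueError: Not registered dataset: Missing.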

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 97 additions & 12 deletions
@@ -20,9 +20,11 @@
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs

-from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
+from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
+                        skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
+from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
+                            get_accuracy_task)


 class Result(GenerationResultBase):
@@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
+
+    if tensor_parallel_size > 1:
+        print(
+            f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
+        )
+
     with open(disaggregated_serving_config_path, "w") as f:
         yaml.dump(disaggregated_server_config, f)
     ctx_server_config_path = os.path.join(temp_dir.name,
@@ -88,27 +96,40 @@
     trtllm_serve_path = "trtllm-serve"
     # Common arguments for both servers
     common_args = [
-        trtllm_serve_path, model_name, "--host", "localhost", "--backend",
-        "pytorch"
+        trtllm_serve_path,
+        model_name,
+        "--host",
+        "localhost",
+        "--backend",
+        "pytorch",
     ]
-
-    if tensor_parallel_size > 1:
-        common_args.append(f"--tp_size={tensor_parallel_size}")
+    gen_tp, gen_pp = gen_server_config.get(
+        "tensor_parallel_size",
+        tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
+                                                     1)
+    ctx_tp, ctx_pp = ctx_server_config.get(
+        "tensor_parallel_size",
+        tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
+                                                     1)
+
+    ctx_total_gpus = ctx_tp * ctx_pp
+    gen_total_gpus = gen_tp * gen_pp

     env_ctx = os.environ.copy()
     env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size)))
+    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

     env_gen = os.environ.copy()
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
+        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
     ctx_server_args = common_args + [
-        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
+        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
     ]
     gen_server_args = common_args + [
-        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
+        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
     ]
     if "max_num_tokens" in ctx_server_config:
         ctx_server_args.append(
@@ -182,6 +203,56 @@ def generate_async(prompt: str,
         disaggregated_server.wait()


+def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
+                      ctx_tp: int, gen_pp: int, gen_tp: int,
+                      test_set: LlmapiAccuracyTestHarness):
+    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
+        pytest.fail(
+            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
+        )
+
+    kv_cache_config = {
+        "free_gpu_memory_fraction": 0.5,
+        "enable_block_reuse": False
+    }
+    ctx_server_config = {
+        "pipeline_parallel_size": ctx_pp,
+        "tensor_parallel_size": ctx_tp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
+    gen_server_config = {
+        "tensor_parallel_size": gen_tp,
+        "pipeline_parallel_size": gen_pp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
+    disaggregated_server_config = {
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8002"]
+        }
+    }
+    with launch_disaggregated_llm(disaggregated_server_config,
+                                  ctx_server_config, gen_server_config,
+                                  model_path) as llm:
+        task = test_set(model_name)
+        task.evaluate(llm)
+
+
 @pytest.mark.timeout(3600)
 class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
@@ -315,6 +386,20 @@ def test_eagle3(self, overlap_scheduler):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
+                             ids=["tp1pp2", "tp2pp1", "tp2pp2"])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_tp_pp_symmetric(self, tp, pp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
+                                 tp, get_accuracy_task(testset))
+
+    @parametrize_with_ids("ctx_pp", [2, 4])
+    @parametrize_with_ids("gen_tp", [1, 2])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
+                                 gen_tp, get_accuracy_task(testset))
+

 @pytest.mark.skip_less_device_memory(140000)
 @pytest.mark.timeout(3600)
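
Taken together, the new tests launch one context server and one generation server with independent parallel layouts. For example, test_ctx_pp_gen_tp_asymmetric with ctx_pp=2 and gen_tp=2 needs 1*2 + 2*1 = 4 GPUs, with the context server pinned to devices 0-1 and the generation server to devices 2-3 via CUDA_VISIBLE_DEVICES. A simplified stand-in for that device bookkeeping (not the test harness itself):

# Simplified stand-in for the GPU bookkeeping done by run_parallel_test
# and launch_disaggregated_llm (illustrative only).
def plan_devices(ctx_tp: int, ctx_pp: int, gen_tp: int, gen_pp: int,
                 available: int) -> tuple:
    ctx_total = ctx_tp * ctx_pp
    gen_total = gen_tp * gen_pp
    if ctx_total + gen_total > available:
        raise RuntimeError("not enough devices for this ctx/gen layout")
    ctx_devices = ",".join(str(i) for i in range(ctx_total))
    gen_devices = ",".join(str(i) for i in range(ctx_total, ctx_total + gen_total))
    return ctx_devices, gen_devices

# ctx_pp=2 (ctx_tp=1) context server plus gen_tp=2 (gen_pp=1) generation
# server on a 4-GPU machine:
print(plan_devices(ctx_tp=1, ctx_pp=2, gen_tp=2, gen_pp=1, available=4))
# ('0,1', '2,3')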
