2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -840,6 +840,8 @@ void CacheFormatter::unformat(TransferSession& session)
if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
{
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
TLLM_LOG_WARNING("self: %zu dest %zu", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
return false;
}
int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();
5 changes: 4 additions & 1 deletion scripts/build_wheel.py
@@ -71,7 +71,10 @@ def clear_folder(folder_path):
if os.path.isdir(item_path) and not os.path.islink(item_path):
rmtree(item_path)
else:
os.remove(item_path)
try:
os.remove(item_path)
except (OSError, IOError) as e:
print(f"Failed to remove {item_path}: {e}", file=sys.stderr)


def sysconfig_scheme(override_vars=None):
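For reference, a minimal, self-contained sketch of the error-tolerant cleanup pattern this hunk introduces (the imports and the rest of clear_folder are reconstructed from context and partly assumed, not copied verbatim from scripts/build_wheel.py):

import os
import sys
from shutil import rmtree


def clear_folder(folder_path: str) -> None:
    """Delete everything inside folder_path, logging per-item removal failures instead of aborting."""
    for item in os.listdir(folder_path):
        item_path = os.path.join(folder_path, item)
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            rmtree(item_path)
        else:
            try:
                os.remove(item_path)
            except OSError as e:  # IOError is an alias of OSError on Python 3
                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)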
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -96,13 +96,13 @@ def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager,
attention_type: AttentionTypeCpp,
cache_transceiver_config: CacheTransceiverConfig):
world_config = mapping_to_world_config(mapping)
num_kv_heads_per_layer = kv_cache_manager.num_kv_heads_per_layer
total_num_kv_heads_per_layer = kv_cache_manager.total_num_kv_heads_per_layer
head_dim = kv_cache_manager.head_dim
tokens_per_block = kv_cache_manager.tokens_per_block
dtype = kv_cache_manager.dtype

self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
num_kv_heads_per_layer, head_dim,
total_num_kv_heads_per_layer, head_dim,
tokens_per_block, world_config, dtype,
attention_type,
cache_transceiver_config)
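The rename matters because, under pipeline parallelism, num_kv_heads_per_layer only covers the layers owned by the local PP rank, while the C++ cache transceiver needs per-layer head counts for the whole model. A toy illustration of the difference (numbers and variable names are assumptions, not the real KVCacheManager API):

num_layers = 4
num_kv_heads = [8, 8, 4, 4]  # per-layer KV heads of the full model (after TP split)
pp_rank, pp_size = 0, 2
layers_per_rank = num_layers // pp_size
pp_layers = range(pp_rank * layers_per_rank, (pp_rank + 1) * layers_per_rank)

num_kv_heads_per_layer = [num_kv_heads[i] for i in pp_layers]  # local view: [8, 8]
total_num_kv_heads_per_layer = list(num_kv_heads)              # global view: [8, 8, 4, 4]
assert len(total_num_kv_heads_per_layer) == num_layers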
52 changes: 50 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -122,6 +122,7 @@ class BatchState:
@dataclasses.dataclass
class BatchStatePP(BatchState):
microbatch_id: int = -1
scheduled_ctx_reqs: list[LlmRequest] = None


class PyExecutor:
@@ -643,6 +644,7 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
return False

def _executor_loop_pp(self):
logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
torch.cuda.set_device(self.device_id)
microbatch_id = 0
with self._profiler() as profile_step:
@@ -656,6 +658,9 @@ def _executor_loop_pp(self):
if self.should_stop_processing:
break

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()

if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
@@ -664,9 +669,23 @@ def _executor_loop_pp(self):

self._pad_attention_dp_dummy_request()

scheduled_batch, _, _ = self._schedule()
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)

if self.kv_cache_transceiver:
# Also prepare KV cache manager resources for the disagg gen-init requests that fit
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)

if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self.kv_cache_transceiver.check_context_transfer_status(
1)

self.num_scheduled_requests = scheduled_batch.batch_size

logger.debug(
f'has {len(self.active_requests)} active_request, '
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@@ -679,7 +698,7 @@ def _executor_loop_pp(self):
can_queue = 0 not in tp_batch_sizes
else:
can_queue = scheduled_batch.batch_size > 0
if not can_queue:
if not can_queue and not self.kv_cache_transceiver:
assert len(self.inflight_req_ids) > 0, (
"fail to schedule any pending request, probably run out of resource"
)
@@ -688,8 +707,28 @@ def _executor_loop_pp(self):
self.micro_batches[microbatch_id] = None
else:
self._add_inflight_ids(scheduled_batch)

if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
self._prepare_disagg_gen_transmission_complete(
scheduled_batch)

self.resource_manager.prepare_resources(scheduled_batch)

# Generation requests that do not have a batch_idx need to be at the
# front of the batch because of the assumptions made in
# model_engine.py::_forward_step. This only matters for disaggregated
# serving; in non-disaggregated serving the generation requests always
# have a batch_idx.
scheduled_batch.generation_requests = sorted( # stable sort
scheduled_batch.generation_requests,
key=lambda req: int(req.py_batch_idx is not None),
)

if self.kv_cache_transceiver:
# Return the first token to the client
self._handle_first_token_response(scheduled_batch)

# Stage 1: Async forward (all ranks) and decoding pass (last rank only)
if not self.dist.is_last_pp_rank:
sample_state = self._forward_step_inter_pp(
@@ -720,6 +759,7 @@ def _executor_loop_pp(self):
iter_start_time=iter_start_time,
iter_stats=iter_stats,
microbatch_id=microbatch_id,
scheduled_ctx_reqs=scheduled_batch.context_requests,
)

self.micro_batches[microbatch_id] = batch_state
@@ -784,6 +824,11 @@ def _executor_loop_pp(self):
if previous_batch is not None:
with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
self._update_requests(previous_batch.sample_state)

if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
self._send_disagg_ctx_cache(
previous_batch.scheduled_ctx_reqs)

self._handle_canceled_requests()
finished_requests = self._handle_responses()
previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
@@ -792,6 +837,9 @@ def _executor_loop_pp(self):
self._remove_inflight_ids(previous_scheduled_batch)
self.micro_batches[prev_microbatch_id] = None

if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
self._terminate_ctx_finished_requests()

# march forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches

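To make the reordering above concrete, a small hedged sketch of the stable sort over generation requests (the _Req dataclass is a stand-in for LlmRequest, for illustration only):

from dataclasses import dataclass
from typing import Optional


@dataclass
class _Req:
    name: str
    py_batch_idx: Optional[int]


reqs = [_Req("a", 0), _Req("b", None), _Req("c", 1), _Req("d", None)]
# The key is 0 for requests without a batch_idx and 1 otherwise, so the stable
# sort moves "b" and "d" to the front without disturbing relative order.
reqs = sorted(reqs, key=lambda r: int(r.py_batch_idx is not None))
assert [r.name for r in reqs] == ["b", "d", "a", "c"]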
25 changes: 20 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -155,18 +155,33 @@ def __init__(
(num_kv_heads + tp_size - 1) // tp_size
for _ in range(self.num_local_layers)
]
self.total_num_kv_heads_per_layer = [
(num_kv_heads + tp_size - 1) // tp_size
for _ in range(self.num_layers)
]
else:
assert len(num_kv_heads) == self.num_layers

def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
kv_head: Optional[int]):
if kv_head is not None:
num_kv_heads_per_layer.append(
(kv_head + tp_size - 1) // tp_size)
else:
num_kv_heads_per_layer.append(0)

self.num_kv_heads_per_layer = []
if self.num_local_layers > 0:
for i in self.pp_layers:
kv_head = num_kv_heads[i]
if kv_head is not None:
self.num_kv_heads_per_layer.append(
(kv_head + tp_size - 1) // tp_size)
else:
self.num_kv_heads_per_layer.append(0)
append_to_kv_heads_per_layer(self.num_kv_heads_per_layer,
kv_head)

self.total_num_kv_heads_per_layer = []
for i in range(self.num_layers):
kv_head = num_kv_heads[i]
append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer,
kv_head)

self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
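The helper above splits each layer's KV heads across tensor-parallel ranks with ceiling division and maps layers without KV cache to zero heads. A short sketch of that arithmetic (values chosen for illustration):

def heads_per_tp_rank(kv_heads, tp_size):
    # None marks a layer without KV cache; it contributes zero heads.
    return 0 if kv_heads is None else (kv_heads + tp_size - 1) // tp_size

num_kv_heads = [8, 8, None, 3]
tp_size = 2
total_num_kv_heads_per_layer = [heads_per_tp_rank(h, tp_size) for h in num_kv_heads]
print(total_num_kv_heads_per_layer)  # [4, 4, 0, 2]  (3 heads round up to 2 per rank)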
11 changes: 11 additions & 0 deletions tests/integration/defs/accuracy/accuracy_core.py
@@ -735,3 +735,14 @@ def setup_class(cls):
logger.set_level("info")
yield
logger.set_level(original_level)


def get_accuracy_task(dataset_name: str):
try:
task_class = globals()[dataset_name]
except KeyError:
raise ValueError(f"Not registered dataset: {dataset_name}.")
if isinstance(task_class, type) and issubclass(task_class, AccuracyTask):
return task_class
raise ValueError(f"Unknown dataset: {dataset_name}.")
109 changes: 97 additions & 12 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -20,9 +20,11 @@
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import LlmArgs

from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
skip_pre_hopper)
from ..trt_test_alternative import popen
from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
get_accuracy_task)


class Result(GenerationResultBase):
@@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
temp_dir = tempfile.TemporaryDirectory()
disaggregated_serving_config_path = os.path.join(
temp_dir.name, "disaggregated_serving_config.yaml")

if tensor_parallel_size > 1:
print(
"Using the unified tp parameter for testing is not recommended. "
"Please set parallelism in the server configs instead.")

with open(disaggregated_serving_config_path, "w") as f:
yaml.dump(disaggregated_server_config, f)
ctx_server_config_path = os.path.join(temp_dir.name,
@@ -88,27 +96,40 @@
trtllm_serve_path = "trtllm-serve"
# Common arguments for both servers
common_args = [
trtllm_serve_path, model_name, "--host", "localhost", "--backend",
"pytorch"
trtllm_serve_path,
model_name,
"--host",
"localhost",
"--backend",
"pytorch",
]

if tensor_parallel_size > 1:
common_args.append(f"--tp_size={tensor_parallel_size}")
gen_tp = gen_server_config.get("tensor_parallel_size", tensor_parallel_size)
gen_pp = gen_server_config.get("pipeline_parallel_size", 1)
ctx_tp = ctx_server_config.get("tensor_parallel_size", tensor_parallel_size)
ctx_pp = ctx_server_config.get("pipeline_parallel_size", 1)

ctx_total_gpus = ctx_tp * ctx_pp
gen_total_gpus = gen_tp * gen_pp

env_ctx = os.environ.copy()
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(tensor_parallel_size)))
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

env_gen = os.environ.copy()
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
ctx_server_args = common_args + [
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
]
gen_server_args = common_args + [
"--port", "8002", "--extra_llm_api_options", gen_server_config_path
"--port", "8002", "--extra_llm_api_options", gen_server_config_path,
f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
]
if "max_num_tokens" in ctx_server_config:
ctx_server_args.append(
Expand Down Expand Up @@ -182,6 +203,56 @@ def generate_async(prompt: str,
disaggregated_server.wait()
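The launcher now sizes each server's GPU set from its own parallel config and hands the context and generation servers disjoint CUDA_VISIBLE_DEVICES ranges. A tiny sketch of that split (example values assumed):

ctx_tp, ctx_pp, gen_tp, gen_pp = 2, 1, 1, 2
ctx_total_gpus = ctx_tp * ctx_pp
gen_total_gpus = gen_tp * gen_pp
ctx_devices = ",".join(map(str, range(ctx_total_gpus)))  # "0,1"
gen_devices = ",".join(
    map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))  # "2,3"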


def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
ctx_tp: int, gen_pp: int, gen_tp: int,
test_set: LlmapiAccuracyTestHarness):
if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
pytest.fail(
f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
)

kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False
}
ctx_server_config = {
"pipeline_parallel_size": ctx_pp,
"tensor_parallel_size": ctx_tp,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"cache_transceiver_config": {
"backend": "default"
}
}
gen_server_config = {
"tensor_parallel_size": gen_tp,
"pipeline_parallel_size": gen_pp,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"cache_transceiver_config": {
"backend": "default"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
model_path) as llm:
task = test_set(model_name)
task.evaluate(llm)


@pytest.mark.timeout(3600)
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
Expand Down Expand Up @@ -315,6 +386,20 @@ def test_eagle3(self, overlap_scheduler):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
ids=["tp1pp2", "tp2pp1", "tp2pp2"])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_tp_pp_symmetric(self, tp, pp, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
tp, get_accuracy_task(testset))

@parametrize_with_ids("ctx_pp", [2, 4])
@parametrize_with_ids("gen_tp", [1, 2])
@pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
gen_tp, get_accuracy_task(testset))


@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)