Commit 98428f3

[TRTLLM-5826][feat] Support pytorch LoRA adapter eviction (#5616)
Signed-off-by: Amit Zuker <[email protected]>
1 parent 943fd41 commit 98428f3

14 files changed: +457 -131 lines

cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp

Lines changed: 11 additions & 3 deletions
@@ -591,20 +591,28 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRequest)
     TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
     if (llmRequest->getLoraTaskId().has_value())
     {
+        auto taskId = llmRequest->getLoraTaskId().value();
         try
         {
-            return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value());
+            return mHostLoraCache->determineNumPages(taskId);
         }
         catch (std::runtime_error& e)
         {
             if (llmRequest->getLoraConfig().has_value())
             {
                 return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value());
             }
-            else
+            if (!llmRequest->getLoraWeights().has_value())
             {
-                throw;
+                auto const reqId = llmRequest->mRequestId;
+                std::string errMsg
+                    = "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task "
+                    + std::to_string(taskId) + " that's not found in LoRA CPU cache."
+                      " Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization,"
+                      " so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported.";
+                throw PeftTaskNotCachedException(errMsg);
             }
+            throw;
         }
     }
     TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 1 deletion
@@ -469,7 +469,8 @@ void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
 
     py::classh<tb::PeftCacheManager, tb::BasePeftCacheManager>(m, "PeftCacheManager")
         .def(py::init<tb::PeftCacheManagerConfig, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
-            py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"));
+            py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"))
+        .def("is_task_cached", &tb::PeftCacheManager::isTaskCached, py::arg("taskId"));
 
     py::classh<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager").def(py::init());
 }
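For illustration only (not part of this commit): the new is_task_cached binding is what the Python runtime uses to query the C++ LoRA CPU cache. A minimal sketch, assuming the PyTorch-backend resource-manager wiring shown in worker.py further below; the helper name and the engine argument are hypothetical.

from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType


def adapter_is_in_cpu_cache(engine, task_id: int) -> bool:
    # Hypothetical helper: look up the PEFT cache resource manager (as worker.py does below)
    # and query the newly bound is_task_cached on its C++ impl.
    peft_cache_manager = engine.resource_manager.resource_managers.get(
        ResourceManagerType.PEFT_CACHE_MANAGER)
    return peft_cache_manager is not None and peft_cache_manager.impl.is_task_cached(task_id)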

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 3 additions & 1 deletion
@@ -286,7 +286,9 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
     resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)
 
     # scheduling
-    capacitor_scheduler = BindCapacityScheduler(ad_config.max_batch_size, kv_cache_manager.impl)
+    capacitor_scheduler = BindCapacityScheduler(
+        ad_config.max_batch_size, kv_cache_manager.impl, peft_cache_manager=None
+    )
     mb_scheduler = BindMicroBatchScheduler(
         ad_config.max_batch_size, engine.cache_seq_interface.info.max_num_tokens
     )

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 0 deletions
@@ -432,6 +432,7 @@ def create_py_executor_instance(
                 f"Cannot overwrite existing resource manager {key}.")
             resources[key] = value
 
+    peft_cache_manager = None
     if lora_config is not None:
         from tensorrt_llm.bindings import LoraModule
 
@@ -507,6 +508,7 @@ def create_py_executor_instance(
     capacity_scheduler = BindCapacityScheduler(
         max_num_sequences,
         kv_cache_manager.impl if kv_cache_manager is not None else None,
+        peft_cache_manager.impl if peft_cache_manager is not None else None,
         executor_config.scheduler_config.capacity_scheduler_policy,
         two_step_lookahead=mapping.has_pp())
     mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 1 addition & 1 deletion
@@ -1218,7 +1218,7 @@ def update_resources(self, scheduled_batch: ScheduledRequests):
         pass
 
     def free_resources(self, request: LlmRequest):
-        pass
+        self.impl.mark_request_done(request)
 
     def shutdown(self):
         pass

tensorrt_llm/_torch/pyexecutor/scheduler.py

Lines changed: 4 additions & 1 deletion
@@ -73,12 +73,14 @@ def __init__(
         self,
         max_num_requests: int,
         kv_cache_manager,
+        peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None,
         scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
         CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
         two_step_lookahead: bool = False,
     ):
         super(BindCapacityScheduler, self).__init__()
         self.kv_cache_manager = kv_cache_manager
+        self.peft_cache_manager = peft_cache_manager
 
         self.impl = tb_internal.algorithms.CapacityScheduler(
             max_num_requests=max_num_requests,
@@ -91,7 +93,8 @@ def __init__(
     def schedule_request(
         self, active_requests: RequestList
     ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]:
-        return self.impl(active_requests, self.kv_cache_manager)
+        return self.impl(active_requests, self.kv_cache_manager,
+                         self.peft_cache_manager)
 
 
 class GuaranteedNoEvictScheduler(CapacityScheduler):

tensorrt_llm/executor/worker.py

Lines changed: 17 additions & 6 deletions
@@ -150,13 +150,23 @@ def _create_engine():
             self._runtime_model_config = _engine_config_to_model_config(
                 engine_config)
             if engine_config.build_config.plugin_config.lora_plugin:
-                self._lora_manager = LoraManager()
+                # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization
+                # (see LoraManager constructor docstring). Getting the peft cache manager from this
+                # point in the TRT flow is currently not supported (it's at the CPP
+                # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
+                # optimization is not available in TRT-python flow.
+                self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()
 
         if getattr(executor_config, "backend",
                    "") == "pytorch" and lora_config is not None:
-            self._lora_manager = LoraManager()
+            from tensorrt_llm._torch.pyexecutor.resource_manager import \
+                ResourceManagerType
+            peft_cache_manager = self.engine.resource_manager.resource_managers.get(
+                ResourceManagerType.PEFT_CACHE_MANAGER)
+            self._lora_manager = LoraManager(
+                cpp_peft_cache_manager=peft_cache_manager.impl)
             lora_model_config = self.engine.model_engine.lora_model_config
             assert lora_model_config is not None
             self._lora_model_config = lora_model_config
@@ -362,15 +372,16 @@ def _load_prompt_adapter(self,
     def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None
         if self._lora_manager is not None and request.lora_request is not None:
-            loaded_new_lora_adapter = self._load_lora_adapter(
-                request.lora_request)
+            adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
+                request.lora_request.adapter_id)
+            self._load_lora_adapter(request.lora_request)
             uid = str(request.lora_request.adapter_id)
             lora_config = tllm.LoraConfig(
                 task_id=request.lora_request.adapter_id,
                 weights=self._lora_manager.cpp_lora_weights[uid]
-                if loaded_new_lora_adapter else None,
+                if not adapter_in_cache else None,
                 config=self._lora_manager.cpp_lora_config[uid]
-                if loaded_new_lora_adapter else None)
+                if not adapter_in_cache else None)
         else:
             lora_config = None
 

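The net effect of the worker change above, condensed into a standalone sketch (the helper below is illustrative and not part of the commit; the real logic lives in _enqueue_request): adapter weights and config are attached to a request only when the adapter is not already resident in the LoRA CPU cache.

def lora_payload_for_request(lora_manager, lora_request):
    # Mirrors _enqueue_request above: skip sending weights/config when the adapter is already
    # in the LoRA CPU cache (best-effort check via the C++ PeftCacheManager).
    adapter_in_cache = lora_manager.is_adapter_in_cpu_cache(lora_request.adapter_id)
    uid = str(lora_request.adapter_id)
    weights = None if adapter_in_cache else lora_manager.cpp_lora_weights[uid]
    config = None if adapter_in_cache else lora_manager.cpp_lora_config[uid]
    return lora_request.adapter_id, weights, config
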
tensorrt_llm/lora_manager.py

Lines changed: 25 additions & 2 deletions
@@ -11,6 +11,8 @@
 import torch
 import yaml
 
+from tensorrt_llm.bindings import internal as tb_internal
+
 from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
 from .layers.linear import ColumnLinear
 from .mapping import Mapping
@@ -436,8 +438,16 @@ class LoraManager(object):
         "mlp_gate_up": 18,
     }
 
-    def __init__(self):
-        """Constructor."""
+    def __init__(
+        self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
+    ):
+        """Constructor.
+
+        Args:
+            cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for
+                a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when
+                the adapter is already loaded in the LoRA CPU cache.
+        """
         # _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
         # {
         #     uid: {
@@ -473,6 +483,19 @@ def __init__(self):
         self._cpp_lora_weights: Dict[str, torch.Tensor] = {}  # on cpu
         self._cpp_lora_config: Dict[str, torch.Tensor] = {}  # on cpu
         self.lora_target_modules: List[str] = []
+        self._cpp_peft_cache_manager = cpp_peft_cache_manager
+
+    def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
+        """Best effort to check if a LoRA adapter is in the LoRA CPU cache.
+
+        If no cpp_peft_cache_manager instance was given at the construction of this LoraManager instance, then False is
+        returned.
+        """
+        return (
+            self._cpp_peft_cache_manager.is_task_cached(adapter_uid)
+            if self._cpp_peft_cache_manager
+            else False
+        )
 
     @staticmethod
     def get_missing_qkv_modules(lora_target_modules):
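A small usage sketch of the extended LoraManager constructor and the new best-effort query (illustrative; adapter_uid=0 is an arbitrary value). Without a C++ peft cache manager, as in the TRT-python flow noted in worker.py above, the check simply reports False, so adapter weights are always sent.

from tensorrt_llm.lora_manager import LoraManager

# No C++ peft cache manager available: the best-effort check always returns False.
manager = LoraManager(cpp_peft_cache_manager=None)
assert manager.is_adapter_in_cpu_cache(adapter_uid=0) is False

# In the PyTorch-backend flow, worker.py instead passes the executor's C++ PeftCacheManager:
#     LoraManager(cpp_peft_cache_manager=peft_cache_manager.impl)
# so the check reflects the actual LoRA CPU cache contents.
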
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
from typing import OrderedDict, Type

from utils.llm_data import llm_models_root
from utils.util import duplicate_list_to_length, flatten_list, similar

from tensorrt_llm import SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi.llm import BaseLLM


def check_llama_7b_multi_unique_lora_adapters_from_request(
        lora_adapter_count_per_call: list[int], repeat_calls: int,
        repeats_per_call: int, llm_class: Type[BaseLLM], **llm_kwargs):
    """Calls llm.generate s.t. for each C in lora_adapter_count_per_call, llm.generate is called with C requests
    repeated 'repeats_per_call' times, where each request is configured with a unique LoRA adapter ID.
    This entire process is done in a loop 'repeat_calls' times with the same requests.
    Asserts the output of each llm.generate call is similar to the expected.
    """  # noqa: D205
    total_lora_adapters = sum(lora_adapter_count_per_call)
    hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
    hf_lora_dirs = [
        f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1",
        f"{llm_models_root()}/llama-models/Japanese-Alpaca-LoRA-7b-v0"
    ]
    # Each prompt should have a reference for every LoRA adapter dir (in the same order as in hf_lora_dirs)
    prompt_to_references = OrderedDict({
        "美国的首都在哪里? \n答案:": [
            "美国的首都是华盛顿。\n\n美国的",
            "纽约\n\n### カンファレンスの",
        ],
        "アメリカ合衆国の首都はどこですか? \n答え:": [
            "华盛顿。\n\n英国の首都是什",
            "ワシントン\nQ1. アメリカ合衆国",
        ],
    })

    prompts_to_generate = duplicate_list_to_length(
        flatten_list([[prompt] * len(hf_lora_dirs)
                      for prompt in prompt_to_references.keys()]),
        total_lora_adapters)
    references = duplicate_list_to_length(
        flatten_list(list(prompt_to_references.values())), total_lora_adapters)
    lora_requests = [
        LoRARequest(str(i), i, hf_lora_dirs[i % len(hf_lora_dirs)])
        for i in range(total_lora_adapters)
    ]
    llm = llm_class(hf_model_dir, **llm_kwargs)

    # Perform repeats of the same requests to test reuse and reload of adapters previously unloaded from cache
    try:
        for _ in range(repeat_calls):
            last_idx = 0
            for adapter_count in lora_adapter_count_per_call:
                sampling_params = SamplingParams(max_tokens=20)
                outputs = llm.generate(
                    prompts_to_generate[last_idx:last_idx + adapter_count] *
                    repeats_per_call,
                    sampling_params,
                    lora_request=lora_requests[last_idx:last_idx +
                                               adapter_count] *
                    repeats_per_call)
                for output, ref in zip(
                        outputs, references[last_idx:last_idx + adapter_count] *
                        repeats_per_call):
                    assert similar(output.outputs[0].text, ref)
                last_idx += adapter_count
    finally:
        llm.shutdown()


def check_llama_7b_multi_lora_from_request_test_harness(
        llm_class: Type[BaseLLM], **llm_kwargs) -> None:
    hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
    hf_lora_dir1 = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
    hf_lora_dir2 = f"{llm_models_root()}/llama-models/Japanese-Alpaca-LoRA-7b-v0"
    prompts = [
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
    ]
    references = [
        "沃尔玛\n\n## 新闻\n\n* ",
        "美国的首都是华盛顿。\n\n美国的",
        "纽约\n\n### カンファレンスの",
        "Washington, D.C.\nWashington, D.C. is the capital of the United",
        "华盛顿。\n\n英国の首都是什",
        "ワシントン\nQ1. アメリカ合衆国",
    ]
    key_words = [
        "沃尔玛",
        "华盛顿",
        "纽约",
        "Washington",
        "华盛顿",
        "ワシントン",
    ]
    lora_req1 = LoRARequest("luotuo", 1, hf_lora_dir1)
    lora_req2 = LoRARequest("Japanese", 2, hf_lora_dir2)
    sampling_params = SamplingParams(max_tokens=20)

    llm = llm_class(hf_model_dir, **llm_kwargs)
    try:
        outputs = llm.generate(prompts,
                               sampling_params,
                               lora_request=[
                                   None, lora_req1, lora_req2, None, lora_req1,
                                   lora_req2
                               ])
    finally:
        llm.shutdown()
    for output, ref, key_word in zip(outputs, references, key_words):
        assert similar(output.outputs[0].text,
                       ref) or key_word in output.outputs[0].text
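A hypothetical way a test could drive the first helper above to exercise adapter eviction and reload; LLM is the real tensorrt_llm LLM API class, but the cache-sizing kwargs are deliberately left as a placeholder since their exact names are not part of this diff.

from tensorrt_llm import LLM  # PyTorch-backend LLM API

# Two generate calls with 2 unique adapters each, repeated twice with the same requests,
# so adapters evicted from the LoRA CPU cache between calls must be reloaded transparently.
check_llama_7b_multi_unique_lora_adapters_from_request(
    lora_adapter_count_per_call=[2, 2],
    repeat_calls=2,
    repeats_per_call=3,
    llm_class=LLM,
    # **llm_kwargs: LoRA/PEFT cache sizing options go here (names omitted; not shown in this diff).
)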
