Commit 6e67c55

[TRTLLM-6683][feat] Support LoRA reload CPU cache evicted adapter (#6510)
Signed-off-by: Amit Zuker <[email protected]>
1 parent 824feb8 commit 6e67c55

File tree: 16 files changed (+185, -218 lines)

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 3 additions & 0 deletions

@@ -2347,6 +2347,9 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
     void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager);
 
     void moveLoraWeightsToGpu(runtime::BufferManager const& manager);
+
+    // Remove LoRA weights and LoRA config tensors
+    void removeLoraTensors();
 };
 
 } // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/llmRequest.cpp

Lines changed: 6 additions & 0 deletions

@@ -365,4 +365,10 @@ void LlmRequest::moveLoraWeightsToGpu(runtime::BufferManager const& manager)
     mLoraWeights = gpuLoraWeights;
 }
 
+void LlmRequest::removeLoraTensors()
+{
+    mLoraWeights.reset();
+    mLoraConfig.reset();
+}
+
 } // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp

Lines changed: 1 addition & 12 deletions

@@ -591,27 +591,16 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRe
     TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
     if (llmRequest->getLoraTaskId().has_value())
     {
-        auto taskId = llmRequest->getLoraTaskId().value();
         try
         {
-            return mHostLoraCache->determineNumPages(taskId);
+            return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value());
         }
         catch (std::runtime_error& e)
         {
             if (llmRequest->getLoraConfig().has_value())
             {
                 return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value());
             }
-            if (!llmRequest->getLoraWeights().has_value())
-            {
-                auto const reqId = llmRequest->mRequestId;
-                std::string errMsg
-                    = "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task "
-                    + std::to_string(taskId) + " that's not found in LoRA CPU cache."
-                      " Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization,"
-                      " so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported.";
-                throw PeftTaskNotCachedException(errMsg);
-            }
             throw;
         }
     }
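
The retained control flow above can be paraphrased in Python as the sketch below (illustrative only; host_cache and request are stand-ins, and determine_num_pages mirrors the C++ overloads rather than naming a real Python API): the page count is resolved from the CPU cache by task id, falls back to the LoRA config carried by the request on a cache miss, and otherwise the error simply propagates, since an evicted adapter can now be reloaded instead of being hard-rejected.

def determine_num_pages(host_cache, request):
    # Illustrative mirror of the hunk above; names are stand-ins, not the real API.
    if request.lora_task_id is not None:
        try:
            # Resolve the page count from the CPU cache by task id.
            return host_cache.determine_num_pages(request.lora_task_id)
        except RuntimeError:
            # Cache miss: fall back to the config tensor carried by the request.
            if request.lora_config is not None:
                return host_cache.determine_num_pages(request.lora_config)
            # Otherwise propagate; the old "task not cached" hard error is gone,
            # because an evicted adapter can now be reloaded from its checkpoint.
            raise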

cpp/tensorrt_llm/executor/loraConfig.cpp

Lines changed: 15 additions & 12 deletions

@@ -27,26 +27,29 @@ LoraConfig::LoraConfig(IdType taskId, std::optional<Tensor> weights, std::option
     , mWeights(std::move(weights))
     , mConfig(std::move(config))
 {
-    if (mWeights.has_value() || mConfig.has_value())
+    if (mConfig.has_value())
     {
-        TLLM_CHECK_WITH_INFO(mWeights.has_value() && mConfig.has_value(),
-            "Request for LoRA inference must have both lora weights and lora config");
-
-        SizeType32 constexpr expectedWeightsDims = 2;
         SizeType32 constexpr expectedConfigDims = 2;
-
-        TLLM_CHECK_WITH_INFO(
-            mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions");
         TLLM_CHECK_WITH_INFO(
             mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions");
-        TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU
-                && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN,
-            "Expected lora weights to be in CPU memory");
         TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU
                 && mConfig.value().getMemoryType() != MemoryType::kUNKNOWN,
-            "Expected lora weights to be in CPU memory");
+            "Expected lora config to be in CPU memory");
         TLLM_CHECK_WITH_INFO(
             mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32");
+    }
+    if (mWeights.has_value())
+    {
+        SizeType32 constexpr expectedWeightsDims = 2;
+        TLLM_CHECK_WITH_INFO(
+            mConfig.has_value(), "Request for LoRA inference with lora weights must also have lora config");
+
+        TLLM_CHECK_WITH_INFO(
+            mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions");
+
+        TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU
+                && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN,
+            "Expected lora weights to be in CPU memory");
 
         TLLM_CHECK_WITH_INFO(mConfig.value().getShape()[0] == mWeights.value().getShape()[0],
             "Expected dim 0 of lora weights and lora config to have the same size");

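In Python terms, the relaxed constructor contract amounts to the following minimal sketch (it mirrors the checks above with torch assertions and is not the executor bindings API; tensor shapes are placeholders): a config tensor alone is now accepted, weights without a config are still rejected, and when both are present their first dimensions must match.

import torch

def check_lora_config(weights: torch.Tensor | None, config: torch.Tensor | None) -> None:
    # Illustrative mirror of the relaxed LoraConfig validation shown above.
    if config is not None:
        assert config.dim() == 2, "Expected config tensor to have 2 dimensions"
        assert not config.is_cuda, "Expected lora config to be in CPU memory"
        assert config.dtype == torch.int32, "Expected lora config tensor to have type kINT32"
    if weights is not None:
        # Weights without config is the one combination that is still rejected.
        assert config is not None, "lora weights must also have lora config"
        assert weights.dim() == 2, "Expected weights tensor to have 2 dimensions"
        assert not weights.is_cuda, "Expected lora weights to be in CPU memory"
        assert config.shape[0] == weights.shape[0], \
            "Expected dim 0 of lora weights and lora config to have the same size"

# Config-only is now allowed, e.g. when the adapter weights are already in the CPU cache:
check_lora_config(weights=None, config=torch.zeros(4, 3, dtype=torch.int32))
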
cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 2 additions & 1 deletion

@@ -375,7 +375,8 @@ void initBindings(nb::module_& m)
         .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, nb::arg("manager"))
         .def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason"))
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
-        .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"));
+        .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"))
+        .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);
 
     nb::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager")
         .def(nb::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), nb::arg("max_num_slots"),

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 2 additions & 1 deletion

@@ -381,7 +381,8 @@ void initBindings(pybind11::module_& m)
         .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
         .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
-        .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"));
+        .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
+        .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);
 
     py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
         .def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),

cpp/tests/unit_tests/executor/loraConfigTest.cpp

Lines changed: 4 additions & 5 deletions

@@ -53,13 +53,12 @@ TEST(LoraConfigTest, invalidInputs)
 
     // This should work
     auto loraConfig = LoraConfig(1, weights, config);
+    // Having config only without weights is allowed
+    loraConfig = LoraConfig(1, std::nullopt, config);
 
     {
-        // Only one specified
-        testInvalid(1, std::nullopt, config, "must have both");
-
-        // Only one specified
-        testInvalid(1, weights, std::nullopt, "must have both");
+        // Only weights specified without config - not allowed
+        testInvalid(1, weights, std::nullopt, "lora weights must also have lora config");
     }
 
     {

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 1 addition & 0 deletions

@@ -502,6 +502,7 @@ def create_py_executor_instance(
     )
     peft_cache_manager = PeftCacheManager(
         peft_cache_config=executor_config.peft_cache_config,
+        lora_config=lora_config,
         model_config=model_binding_config,
         world_config=world_config,
     )

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions

@@ -285,6 +285,7 @@ def __init__(
 
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
                                                     None)
+        self.py_lora_path: str | None = kwargs.pop("py_lora_path", None)
         # Multimodal data
         self.py_multimodal_data = kwargs.pop("py_multimodal_data", None)
         if llm_request is not None:
@@ -490,6 +491,7 @@ def executor_request_to_llm_request(
         if executor_request.lora_config is not None else None,
         lora_config=executor_request.lora_config.config
         if executor_request.lora_config is not None else None,
+        py_lora_path=getattr(executor_request, "py_lora_path", None),
         mrope_rotary_cos_sin=mrope_rotary_cos_sin,
         mrope_position_deltas=mrope_position_deltas,
         lookahead_config=None,

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 32 additions & 7 deletions

@@ -10,9 +10,10 @@
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
-from ..._utils import binding_dtype_size, nvtx_range
+from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
 from ...logger import logger
 from ...mapping import Mapping
 from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
@@ -1170,6 +1171,7 @@ class PeftCacheManager(BaseResourceManager):
 
     def __init__(self,
                  peft_cache_config: PeftCacheConfig,
+                 lora_config: LoraConfig,
                  model_config: ModelConfig,
                  world_config: WorldConfig | None = None):
         import tensorrt_llm.bindings as _tb
@@ -1200,8 +1202,36 @@ def __init__(self,
             model_config=model_config,
             world_config=world_config,
             buffer_manager=buffer_manager)
+        self._lora_config = lora_config
+        self._lora_model_config = LoraModelConfig(
+            lora_config.lora_target_modules,
+            lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
+            binding_to_str_dtype(model_config.data_type))
+        self._lora_manager = LoraManager()
 
     def add_request_peft(self, request: LlmRequest):
+        if request.lora_task_id is not None:
+            is_task_cached = self.impl.is_task_cached(request.lora_task_id)
+            if is_task_cached:
+                # PeftCacheManager::addRequestPeft in CPP doesn't allow having only one of [config tensor, weights
+                # tensor] without the other. Since there's no need for any of them when the LoRA adapter is already
+                # cached, we can safely remove both from the request.
+                request.remove_lora_tensors()
+            elif request.lora_weights is None and request.py_lora_path:
+                self._lora_manager.load_from_ckpt(
+                    [request.py_lora_path],
+                    model_config=self._lora_model_config,
+                    runtime_mapping=None,
+                    uids=[request.lora_task_id],
+                    ckpt_source=self._lora_config.lora_ckpt_source)
+                request.lora_weights = self._lora_manager.cpp_lora_weights[
+                    request.lora_task_id]
+
+        # PeftCacheManager CPP implementation expects an extra dim at index 0
+        if request.lora_weights is not None:
+            request.lora_weights = request.lora_weights.unsqueeze(0)
+        if request.lora_config is not None:
+            request.lora_config = request.lora_config.unsqueeze(0)
         self.impl.add_request_peft(request, True)
 
     def ensure_batch(self,
@@ -1221,12 +1251,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_batch = scheduled_batch.context_requests
         generation_batch = scheduled_batch.generation_requests
         for req in context_batch:
-            if req.lora_weights is not None and req.lora_config is not None:
-                req.lora_weights = req.lora_weights.reshape(
-                    [1] + list(req.lora_weights.shape))
-                req.lora_config = req.lora_config.reshape(
-                    [1] + list(req.lora_config.shape))
-            self.impl.add_request_peft(req, True)
+            self.add_request_peft(req)
 
         py_lora_task_layer_module_configs = self.impl.ensure_batch(
             context_batch, generation_batch, False)
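
Taken together, these changes let a request that references an already-registered LoRA adapter be served even after that adapter was evicted from the CPU cache: instead of failing with PeftTaskNotCachedException, the adapter is reloaded from its checkpoint path. A rough usage sketch with the Python LLM API follows; it is hedged, not authoritative: the exact keyword arguments of LoraConfig, LLM and LoRARequest may differ across versions, the paths and adapter name are placeholders, and it assumes the adapter path carried by the LoRARequest is what ultimately reaches the request as py_lora_path.

from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig

# Placeholder paths and names; adjust to your base model and adapter checkpoint.
lora_config = LoraConfig(lora_dir=["/path/to/adapter"], max_lora_rank=8)
llm = LLM(model="/path/to/base_model", lora_config=lora_config)

lora_req = LoRARequest("my-adapter", 1, "/path/to/adapter")

# First use loads the adapter into the LoRA CPU cache.
out1 = llm.generate(["What is LoRA?"], lora_request=lora_req)

# ... other adapters may evict "my-adapter" from the CPU cache in the meantime ...

# Reusing the same task id and path now triggers a reload from the checkpoint
# instead of failing with PeftTaskNotCachedException.
out2 = llm.generate(["Explain adapter eviction."], lora_request=lora_req)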
