3 changes: 3 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -2347,6 +2347,9 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager);

void moveLoraWeightsToGpu(runtime::BufferManager const& manager);

// Remove LoRA weights and LoRA config tensors
void removeLoraTensors();
};

} // namespace tensorrt_llm::batch_manager
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/batch_manager/llmRequest.cpp
@@ -365,4 +365,10 @@ void LlmRequest::moveLoraWeightsToGpu(runtime::BufferManager const& manager)
mLoraWeights = gpuLoraWeights;
}

void LlmRequest::removeLoraTensors()
{
mLoraWeights.reset();
mLoraConfig.reset();
}

} // namespace tensorrt_llm::batch_manager
13 changes: 1 addition & 12 deletions cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
@@ -591,27 +591,16 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRequest)
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
if (llmRequest->getLoraTaskId().has_value())
{
auto taskId = llmRequest->getLoraTaskId().value();
try
{
return mHostLoraCache->determineNumPages(taskId);
return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value());
}
catch (std::runtime_error& e)
{
if (llmRequest->getLoraConfig().has_value())
{
return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value());
}
if (!llmRequest->getLoraWeights().has_value())
{
auto const reqId = llmRequest->mRequestId;
std::string errMsg
= "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task "
+ std::to_string(taskId) + " that's not found in LoRA CPU cache."
" Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization,"
" so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported.";
throw PeftTaskNotCachedException(errMsg);
}
throw;
}
}
27 changes: 15 additions & 12 deletions cpp/tensorrt_llm/executor/loraConfig.cpp
@@ -27,26 +27,29 @@ LoraConfig::LoraConfig(IdType taskId, std::optional<Tensor> weights, std::optional<Tensor> config)
, mWeights(std::move(weights))
, mConfig(std::move(config))
{
if (mWeights.has_value() || mConfig.has_value())
if (mConfig.has_value())
{
TLLM_CHECK_WITH_INFO(mWeights.has_value() && mConfig.has_value(),
"Request for LoRA inference must have both lora weights and lora config");

SizeType32 constexpr expectedWeightsDims = 2;
SizeType32 constexpr expectedConfigDims = 2;

TLLM_CHECK_WITH_INFO(
mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions");
TLLM_CHECK_WITH_INFO(
mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions");
TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU
&& mWeights.value().getMemoryType() != MemoryType::kUNKNOWN,
"Expected lora weights to be in CPU memory");
TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU
&& mConfig.value().getMemoryType() != MemoryType::kUNKNOWN,
"Expected lora weights to be in CPU memory");
"Expected lora config to be in CPU memory");
TLLM_CHECK_WITH_INFO(
mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32");
}
if (mWeights.has_value())
{
SizeType32 constexpr expectedWeightsDims = 2;
TLLM_CHECK_WITH_INFO(
mConfig.has_value(), "Request for LoRA inference with lora weights must also have lora config");

TLLM_CHECK_WITH_INFO(
mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions");

TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU
&& mWeights.value().getMemoryType() != MemoryType::kUNKNOWN,
"Expected lora weights to be in CPU memory");

TLLM_CHECK_WITH_INFO(mConfig.value().getShape()[0] == mWeights.value().getShape()[0],
"Expected dim 0 of lora weights and lora config to have the same size");
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -375,7 +375,8 @@ void initBindings(nb::module_& m)
.def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, nb::arg("manager"))
.def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason"))
.def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"));
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter"))
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);

nb::class_<tb::SequenceSlotManager>(m, "SequenceSlotManager")
.def(nb::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), nb::arg("max_num_slots"),
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -381,7 +381,8 @@ void initBindings(pybind11::module_& m)
.def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
.def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
.def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"));
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);

py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
.def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),
9 changes: 4 additions & 5 deletions cpp/tests/unit_tests/executor/loraConfigTest.cpp
@@ -53,13 +53,12 @@ TEST(LoraConfigTest, invalidInputs)

// This should work
auto loraConfig = LoraConfig(1, weights, config);
// Having config only without weights is allowed
loraConfig = LoraConfig(1, std::nullopt, config);

{
// Only one specified
testInvalid(1, std::nullopt, config, "must have both");

// Only one specified
testInvalid(1, weights, std::nullopt, "must have both");
// Only weights specified without config - not allowed
testInvalid(1, weights, std::nullopt, "lora weights must also have lora config");
}

{
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -502,6 +502,7 @@ def create_py_executor_instance(
)
peft_cache_manager = PeftCacheManager(
peft_cache_config=executor_config.peft_cache_config,
lora_config=lora_config,
model_config=model_binding_config,
world_config=world_config,
)
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -285,6 +285,7 @@ def __init__(

self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
None)
self.py_lora_path: str | None = kwargs.pop("py_lora_path", None)
# Multimodal data
self.py_multimodal_data = kwargs.pop("py_multimodal_data", None)
if llm_request is not None:
@@ -490,6 +491,7 @@ def executor_request_to_llm_request(
if executor_request.lora_config is not None else None,
lora_config=executor_request.lora_config.config
if executor_request.lora_config is not None else None,
py_lora_path=getattr(executor_request, "py_lora_path", None),
mrope_rotary_cos_sin=mrope_rotary_cos_sin,
mrope_position_deltas=mrope_position_deltas,
lookahead_config=None,
39 changes: 32 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -10,9 +10,10 @@
import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
from tensorrt_llm.sampling_params import SamplingParams

from ..._utils import binding_dtype_size, nvtx_range
from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
from ...logger import logger
from ...mapping import Mapping
from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
@@ -1170,6 +1171,7 @@ class PeftCacheManager(BaseResourceManager):

def __init__(self,
peft_cache_config: PeftCacheConfig,
lora_config: LoraConfig,
model_config: ModelConfig,
world_config: WorldConfig | None = None):
import tensorrt_llm.bindings as _tb
@@ -1200,8 +1202,36 @@ def __init__(self,
model_config=model_config,
world_config=world_config,
buffer_manager=buffer_manager)
self._lora_config = lora_config
self._lora_model_config = LoraModelConfig(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
binding_to_str_dtype(model_config.data_type))
self._lora_manager = LoraManager()

def add_request_peft(self, request: LlmRequest):
if request.lora_task_id is not None:
is_task_cached = self.impl.is_task_cached(request.lora_task_id)
if is_task_cached:
# PeftCacheManager::addRequestPeft in CPP doesn't allow having only one of [config tensor, weights
# tensor] without the other. Since there's no need for any of them when the LoRA adapter is already
# cached, we can safely remove both from the request.
request.remove_lora_tensors()
elif request.lora_weights is None and request.py_lora_path:
self._lora_manager.load_from_ckpt(
[request.py_lora_path],
model_config=self._lora_model_config,
runtime_mapping=None,
uids=[request.lora_task_id],
ckpt_source=self._lora_config.lora_ckpt_source)
request.lora_weights = self._lora_manager.cpp_lora_weights[
request.lora_task_id]
request.lora_config = self._lora_manager.cpp_lora_config[
request.lora_task_id]

# PeftCacheManager CPP implementation expects an extra dim at index 0
if request.lora_weights is not None:
request.lora_weights = request.lora_weights.unsqueeze(0)
if request.lora_config is not None:
request.lora_config = request.lora_config.unsqueeze(0)
self.impl.add_request_peft(request, True)

def ensure_batch(self,
@@ -1221,12 +1251,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
context_batch = scheduled_batch.context_requests
generation_batch = scheduled_batch.generation_requests
for req in context_batch:
if req.lora_weights is not None and req.lora_config is not None:
req.lora_weights = req.lora_weights.reshape(
[1] + list(req.lora_weights.shape))
req.lora_config = req.lora_config.reshape(
[1] + list(req.lora_config.shape))
self.impl.add_request_peft(req, True)
self.add_request_peft(req)

py_lora_task_layer_module_configs = self.impl.ensure_batch(
context_batch, generation_batch, False)
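Two conventions the new add_request_peft relies on, per its comments above: a task already resident in the CPU cache must not carry partial tensors (hence remove_lora_tensors), and the C++ peft cache expects per-request LoRA tensors with an extra leading dimension. A small illustrative sketch of that shape handling; the module count, row layout and sizes are assumptions, not values from this PR:

    import torch

    num_modules = 2
    lora_weights = torch.zeros(num_modules, 128, dtype=torch.float16)  # flattened weights per LoRA module
    lora_config = torch.zeros(num_modules, 3, dtype=torch.int32)       # e.g. [module id, layer id, adapter size]

    # What add_request_peft does before handing the request to the C++ cache:
    lora_weights = lora_weights.unsqueeze(0)  # [num_modules, 128] -> [1, num_modules, 128]
    lora_config = lora_config.unsqueeze(0)    # [num_modules, 3]   -> [1, num_modules, 3]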
7 changes: 7 additions & 0 deletions tensorrt_llm/_utils.py
@@ -180,6 +180,7 @@ def str_dtype_to_torch(dtype):
bool=DataType.BOOL,
fp8=DataType.FP8,
)
_binding_to_str_dtype = {v: k for k, v in _str_to_binding_dtype_dict.items()}

_binding_dtype_size = {
DataType.INT64: 8,
@@ -194,6 +195,12 @@ def str_dtype_to_torch(dtype):
}


def binding_to_str_dtype(binding_dtype) -> str:
ret = _binding_to_str_dtype.get(binding_dtype)
assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'
return ret


def binding_dtype_size(dtype: DataType):
return _binding_dtype_size[dtype]
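A quick usage sketch of the new reverse lookup; the two assertions follow directly from the bool and fp8 entries of _str_to_binding_dtype_dict shown above:

    from tensorrt_llm._utils import binding_to_str_dtype
    from tensorrt_llm.bindings import DataType

    assert binding_to_str_dtype(DataType.BOOL) == "bool"
    assert binding_to_str_dtype(DataType.FP8) == "fp8"
    # Any dtype outside the dict trips the assertion with "Unsupported binding dtype".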

6 changes: 4 additions & 2 deletions tensorrt_llm/executor/worker.py
@@ -372,6 +372,7 @@ def _load_prompt_adapter(self,

def _enqueue_request(self, request: GenerationRequest) -> int:
assert request.id is not None
py_lora_path = None
if self._lora_manager is not None and request.lora_request is not None:
adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
request.lora_request.adapter_id)
@@ -381,8 +382,8 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
task_id=request.lora_request.adapter_id,
weights=self._lora_manager.cpp_lora_weights[uid]
if not adapter_in_cache else None,
config=self._lora_manager.cpp_lora_config[uid]
if not adapter_in_cache else None)
config=self._lora_manager.cpp_lora_config[uid])
py_lora_path = request.lora_request.lora_path
else:
lora_config = None

@@ -497,6 +498,7 @@ def _deduce_max_tokens(request: GenerationRequest,
kv_cache_retention_config=request.kv_cache_retention_config,
context_phase_params=context_phase_params,
type=request_type)
executor_request.py_lora_path = py_lora_path

if self._is_pytorch_backend and request.multimodal_params is not None:
if request.multimodal_params.multimodal_data is not None:
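Taken together with the resource-manager change, the worker now records the adapter's local path on the executor request, so an adapter that was evicted from the LoRA CPU cache can be reloaded from disk instead of failing the request. An illustrative caller-side sketch; the LoRARequest import path and field names are assumptions based on the attributes referenced above (adapter_id, lora_path), and the path itself is hypothetical:

    from tensorrt_llm.executor.request import LoRARequest  # assumed import path

    # lora_int_id becomes the LoRA task id; lora_path is what _enqueue_request stores as py_lora_path.
    lora_req = LoRARequest(lora_name="my-adapter", lora_int_id=1,
                           lora_path="/path/to/my-adapter")  # hypothetical adapter directory
    assert lora_req.adapter_id == 1
    # A generate() call carrying lora_request=lora_req can now be served even if task 1 was
    # evicted from the CPU LoRA cache, because the checkpoint path travels with the request.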
17 changes: 10 additions & 7 deletions tests/unittest/_torch/test_resource_manager.py
@@ -5,18 +5,19 @@
import unittest

import numpy as np
import pytest
import torch

import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
PeftCacheConfig,
PeftCacheManager)
from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.bindings.internal.batch_manager import \
PeftTaskNotCachedException
from tensorrt_llm.lora_manager import LoraConfig

DataType = tensorrt_llm.bindings.DataType
LoraModule = tensorrt_llm.bindings.LoraModule
@@ -247,7 +248,7 @@ def _create_request(self,
lora_config = torch.from_numpy(lora_config)

input_tokens = [i + 1 for i in range(max_new_tokens)]
request = tensorrt_llm.bindings.internal.batch_manager.LlmRequest(
request = LlmRequest(
request_id=request_id,
max_new_tokens=max_new_tokens,
input_tokens=input_tokens,
@@ -261,22 +262,21 @@ def _create_request(self,
return request

def get_lora_data(self):
"""Create mock LoRA weights and config that match the C++ validation expectations.
"""Create mock LoRA weights and config.

Returns:
tuple: (weights tensor, config tensor) formatted correctly for the C++ implementation.
tuple: (weights tensor, config tensor).
"""
lora_weights = np.load(self.TP1_WEIGHTS_PATH).astype(np.float16)
lora_weights = np.expand_dims(lora_weights, axis=0)
lora_config = np.load(self.TP1_CONFIG_PATH)
lora_config = np.expand_dims(lora_config, axis=0)
return lora_weights, lora_config

def test_successful_mocked_peft_cache_manager_initialization(self):
peft_cache_config = self.create_peft_cache_config()

peft_cache_manager = PeftCacheManager(
peft_cache_config=peft_cache_config,
lora_config=LoraConfig(),
model_config=self.model_config,
)

@@ -290,6 +290,7 @@ def test_add_request_peft_empty_weights_config(self):

peft_cache_manager = PeftCacheManager(
peft_cache_config=peft_cache_config,
lora_config=LoraConfig(),
model_config=self.model_config,
)

@@ -307,6 +308,7 @@ def test_add_request_peft_empty_batch(self):

peft_cache_manager = PeftCacheManager(
peft_cache_config=peft_cache_config,
lora_config=LoraConfig(),
model_config=self.model_config,
)

@@ -322,6 +324,7 @@ def test_add_request_peft(self):

peft_cache_manager = PeftCacheManager(
peft_cache_config=peft_cache_config,
lora_config=LoraConfig(),
model_config=self.model_config,
)

@@ -349,13 +352,13 @@ def test_add_request_peft(self):

self.assertEqual(len(peft_table), self.num_lora_modules)

@pytest.mark.skip(reason="https://nvbugs/5324252")
def test_put_get(self):
"""Test adding a request with properly configured LoRA weights and config."""
peft_cache_config = self.create_peft_cache_config()

peft_cache_manager = PeftCacheManager(
peft_cache_config=peft_cache_config,
lora_config=LoraConfig(),
model_config=self.model_config,
)
