Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ class BasePeftCacheManager
public:
using LlmRequestPtr = std::shared_ptr<LlmRequest>;
using RequestVector = std::vector<LlmRequestPtr>;
using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
using TaskPeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
using TaskIdToReqIds = std::unordered_map<uint64_t, std::vector<uint64_t>>;
using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;

virtual ~BasePeftCacheManager() = default;

Expand Down Expand Up @@ -99,6 +102,8 @@ class BasePeftCacheManager
class PeftCacheManager : public BasePeftCacheManager
{
public:
using EnsureBatchTaskResult = BasePeftCacheManager::EnsureBatchTaskResult;

PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);

Expand All @@ -109,12 +114,17 @@ class PeftCacheManager : public BasePeftCacheManager
PeftTable ensureBatch(RequestVector const& contextRequests, RequestVector const& generationRequests,
bool resetGpuCache = false) override;

EnsureBatchTaskResult ensureBatchMapTaskId(
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false);

[[nodiscard]] bool isTaskCached(uint64_t taskId) const;

[[nodiscard]] bool isTaskDone(uint64_t taskId) const;

[[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const;

[[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const;

void resetDeviceCache() override;

void markRequestDone(LlmRequest const& llmReq, bool pause = false) override;
Expand Down Expand Up @@ -159,7 +169,7 @@ class PeftCacheManager : public BasePeftCacheManager
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToReqIds;
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;

std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
std::tuple<std::unordered_map<uint64_t, std::future<void>>, TaskIdToReqIds> getTaskMaps(
RequestVector const& contextRequests, RequestVector const& generationRequests);

runtime::ModelConfig mModelConfig;
Expand Down
40 changes: 30 additions & 10 deletions cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bo
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>>
std::tuple<std::unordered_map<uint64_t, std::future<void>>, BasePeftCacheManager::TaskIdToReqIds>
PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests)
{
std::map<uint64_t, std::vector<uint64_t>> taskIdToReqIds;
std::map<uint64_t, std::future<void>> taskIdToFuture;
TaskIdToReqIds taskIdToReqIds;
std::unordered_map<uint64_t, std::future<void>> taskIdToFuture;
std::lock_guard<std::mutex> futuresLock(mPutFuturesMutex);
for (auto const& requests : {contextRequests, generationRequests})
{
Expand Down Expand Up @@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto
return {std::move(taskIdToFuture), taskIdToReqIds};
}

PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId(
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
Expand All @@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests);
auto taskIdToFuture = std::move(taskIdToFuture_); // captured structured bindings are a C++20 extension

std::map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
std::unordered_map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
for (auto& [taskId, taskFuture] : taskIdToFuture)
{
auto fn = [&taskIdToFuture, taskId = taskId, this]() -> std::vector<runtime::LoraCache::TaskLayerModuleConfig>
Expand Down Expand Up @@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
ensureFutures.try_emplace(taskId, std::move(f));
}

PeftTable peftTable{};
TaskPeftTable peftTable{};
for (auto const& [taskId, reqIds] : taskIdToReqIds)
{
auto&& f = ensureFutures.at(taskId);
auto const values = f.get();
for (auto const& reqId : reqIds)
peftTable.try_emplace(taskId, values);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return {std::move(peftTable), std::move(taskIdToReqIds)};
}

/// @brief Ensure the LoRA weights for every task referenced by the given requests are loaded,
///        and build a per-request PEFT table.
///
/// Delegates the actual cache loading to ensureBatchMapTaskId(), which returns a per-task
/// weight table plus the task-id -> request-ids mapping, then fans the per-task entries out
/// to one entry per request id.
///
/// @param contextRequests    requests in the context phase
/// @param generationRequests requests in the generation phase
/// @param resetGpuCache      forwarded to ensureBatchMapTaskId (see its documentation)
/// @return map from request id to the task's layer/module weight configs
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
    RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
{
    auto [taskTable, taskIdToReqIds] = ensureBatchMapTaskId(contextRequests, generationRequests, resetGpuCache);
    PeftTable requestTable{};
    for (auto const& [taskId, values] : taskTable)
    {
        // Every request sharing this task id gets the same weight-config vector.
        auto const& reqIds = taskIdToReqIds.at(taskId);
        for (auto const reqId : reqIds)
        {
            requestTable.try_emplace(reqId, values);
        }
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return requestTable;
}

bool PeftCacheManager::isTaskCached(uint64_t taskId) const
Expand All @@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
return mDeviceLoraCache->isDone(taskId);
}

/// @brief Check whether the given LoRA task is present in the device (GPU) cache.
/// @param taskId id of the LoRA task to look up
/// @return true if mDeviceLoraCache reports the task as present
bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const
{
    return mDeviceLoraCache->has(taskId);
}

void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
{
if (!terminate)
Expand Down Expand Up @@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
return 0;
}
} // namespace tensorrt_llm::batch_manager

// TODO: merge C++ LoRA caching status with Py Slot manager
Loading