3 changes: 2 additions & 1 deletion benchmarks/cpp/disaggServerBenchmark.cpp
@@ -542,7 +542,8 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const
std::nullopt, // kvCacheRetentionConfig
std::nullopt, // logitsPostProcessorName
std::nullopt, // logitsPostProcessor
-encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
+encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
+std::nullopt); // cacheSaltID
request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
return request;
}
3 changes: 2 additions & 1 deletion benchmarks/cpp/gptManagerBenchmark.cpp
@@ -837,7 +837,8 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
std::nullopt, // kvCacheRetentionConfig
std::nullopt, // logitsPostProcessorName
std::nullopt, // logitsPostProcessor
-encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
+encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
+std::nullopt); // cacheSaltID
}

void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
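Both benchmarks forward std::nullopt for the new trailing cacheSaltID argument, so their behavior is unchanged. A hedged sketch of how a caller could attach a real salt instead (the helper name and salt plumbing are illustrative, not part of this PR; setCacheSaltID and setRequestType are the actual API):

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Illustrative only: build a context-only request whose KV cache blocks are
// reusable solely by requests carrying the same salt. All other constructor
// parameters keep their defaults, so the salt is attached via the new setter.
texec::Request makeSaltedContextRequest(
    texec::VecTokens inputTokenIds, texec::SizeType32 maxTokens, texec::CacheSaltIDType cacheSaltID)
{
    texec::Request request(std::move(inputTokenIds), maxTokens);
    request.setCacheSaltID(cacheSaltID);
    request.setRequestType(texec::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
    return request;
}
```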
9 changes: 6 additions & 3 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -69,6 +69,7 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
+using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;

// Type alias for multimodal hash key (hash array + start offset)
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
@@ -115,6 +116,7 @@ struct BlockKey
// Extra keys for multimodal data (similar to VLLM's approach)
// Each extra key is a pair of (mm_hash, start_offset_in_block)
std::vector<MmKey> extraKeys;
+std::optional<CacheSaltIDType> cacheSaltID = std::nullopt;

BlockKey() = default;

@@ -129,24 +131,25 @@
}

explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
-std::vector<MmKey> extraKeys = {})
+std::vector<MmKey> extraKeys = {}, std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: usesExtraIds{usesExtraIds}
, loraTaskId{loraTaskId}
, uniqueTokens{std::move(uniqueTokens)}
, extraKeys{std::move(extraKeys)}
+, cacheSaltID{cacheSaltID}
{
}

bool operator==(BlockKey const& other) const noexcept
{
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
-&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys);
+&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
}

int partialMatch(BlockKey const& other) const noexcept
{
SizeType32 numMatched{0};
-if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys)
+if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID)
{
auto [matchEnd, otherMatchEnd] = std::mismatch(
uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());
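With the new member, the salt participates in both full and partial key matching, so two identical token sequences hashed under different salts can never share blocks. A minimal sketch of that semantics, assuming the header's kv_cache_manager namespace and an aggregate UniqueToken{tokenId, tokenExtraId} layout:

```cpp
#include <cassert>
#include "tensorrt_llm/batch_manager/kvCacheManager.h"

using namespace tensorrt_llm::batch_manager::kv_cache_manager;

void blockKeySaltDemo()
{
    VecUniqueTokens tokens{{101, 0}, {102, 0}, {103, 0}};
    // Identical tokens, different salts: the keys must not match.
    BlockKey withSaltA{false, std::nullopt, tokens, {}, 42};
    BlockKey withSaltB{false, std::nullopt, tokens, {}, 43};
    assert(!(withSaltA == withSaltB));              // full match is gated on the salt
    assert(withSaltA.partialMatch(withSaltB) == 0); // partial match is gated as well
}
```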
40 changes: 29 additions & 11 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -97,8 +97,9 @@ class GenericLlmRequest
RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
using RequestPtr = std::shared_ptr<GenericLlmRequest>;
using MillisecondsType = std::chrono::milliseconds;
+using CacheSaltIDType = runtime::CacheSaltIDType;

-// 49 parameters, 56 items in initialization list
+// 50 parameters, 57 items in initialization list
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@@ -134,7 +135,8 @@ class GenericLlmRequest
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
@@ -191,6 +193,7 @@
, mGuidedDecodingParams(std::move(guidedDecodingParams))
, mLanguageAdapterUid(languageAdapterUid)
, mAllottedTimeMs(allottedTimeMs)
+, mCacheSaltID(cacheSaltID)
{
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
{
@@ -200,7 +203,7 @@
initialize(*inputTokens, returnLogProbs);
}

-// 32 parameters, 39 items in initialization list
+// 33 parameters, 40 items in initialization list
GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const& inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@@ -218,7 +221,8 @@
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
-std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens.size())
, mMaxNewTokens(maxNewTokens)
@@ -258,6 +262,7 @@
, mContextPhaseParams(contextPhaseParams)
, mNumReturnSequences(numReturnSequences)
, mLanguageAdapterUid(languageAdapterUid)
+, mCacheSaltID(cacheSaltID)
{
if (mEncoderTokens.has_value())
{
@@ -266,7 +271,7 @@
initialize(inputTokens, returnLogProbs);
}

-// 29 items in initialization list
+// 30 items in initialization list
GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
: mRequestId(requestId)
, mPromptLen(req.getInputTokenIds().size())
@@ -297,6 +302,7 @@
, mGuidedDecodingParams(req.getGuidedDecodingParams())
, mLanguageAdapterUid(req.getLanguageAdapterUid())
, mAllottedTimeMs(req.getAllottedTimeMs())
+, mCacheSaltID(req.getCacheSaltID())
{
if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
{
@@ -1761,6 +1767,11 @@ class GenericLlmRequest
return mLanguageAdapterUid;
}

+[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
+{
+return mCacheSaltID;
+}

std::vector<SizeType32> getLanguageAdapterRouting(
SizeType32 const reqNumLanguages, SizeType32 const inputLength) const
{
@@ -2039,6 +2050,9 @@

bool mUseDraftModel{false};

+// Cache salt ID for each request.
+std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};

private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{
@@ -2219,7 +2233,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalHashes),
@@ -2231,7 +2246,8 @@
std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures),
std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType,
std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
-returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
+returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams,
+cacheSaltID)
{
}

@@ -2269,7 +2285,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
std::move(stopWordsList),
@@ -2299,7 +2316,7 @@
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
-std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
+std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID)
{
}

@@ -2321,14 +2338,15 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
-std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens),
std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority,
-numReturnSequences, languageAdapterUid, contextPhaseParams)
+numReturnSequences, languageAdapterUid, contextPhaseParams, cacheSaltID)
{
}

8 changes: 6 additions & 2 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -670,7 +670,8 @@ class Request
/// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut
/// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
/// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
-// 34 parameters
+/// @param cacheSaltID Salt ID for KV cache blocks, limiting KV cache reuse to requests that share the same salt ID.
+// 35 parameters
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -697,7 +698,8 @@
std::optional<EagleConfig> eagleConfig = std::nullopt, std::optional<Tensor> skipCrossAttnBlocks = std::nullopt,
std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
-std::optional<MillisecondsType> allottedTimeMs = std::nullopt);
+std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
+std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);

/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
@@ -745,6 +747,7 @@
[[nodiscard]] std::optional<GuidedDecodingParams> getGuidedDecodingParams() const;
[[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const;
[[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
+[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
[[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;

void setStreaming(bool streaming);
@@ -780,6 +783,7 @@
void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams);
void setLanguageAdapterUid(SizeType32 languageAdapterUid);
void setAllottedTimeMs(MillisecondsType allottedTimeMs);
+void setCacheSaltID(CacheSaltIDType cacheSaltID);

private:
friend class Serialization;
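Together with the new constructor parameter, the getter/setter pair makes the salt an ordinary per-request attribute. A short usage sketch (prompt tokens and salt values are placeholders): two requests sharing a salt may reuse each other's cached prefix blocks, while a request with a different salt stays isolated even for an identical prompt.

```cpp
#include <cassert>
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

void cacheSaltUsageDemo()
{
    texec::VecTokens prompt{1, 2, 3, 4};
    texec::Request tenantA1(prompt, /*maxTokens=*/32);
    texec::Request tenantA2(prompt, /*maxTokens=*/32);
    texec::Request tenantB(prompt, /*maxTokens=*/32);

    tenantA1.setCacheSaltID(1001); // same salt: A2 may reuse A1's cached prefix
    tenantA2.setCacheSaltID(1001);
    tenantB.setCacheSaltID(2002);  // different salt: no reuse across tenants

    assert(tenantA1.getCacheSaltID() == tenantA2.getCacheSaltID());
    assert(tenantA1.getCacheSaltID() != tenantB.getCacheSaltID());
}
```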
1 change: 1 addition & 0 deletions cpp/include/tensorrt_llm/executor/types.h
@@ -58,6 +58,7 @@ using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using MillisecondsType = std::chrono::milliseconds;
+using CacheSaltIDType = std::uint64_t;
using LogitsPostProcessor
= std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
1 change: 1 addition & 0 deletions cpp/include/tensorrt_llm/runtime/common.h
@@ -44,6 +44,7 @@ using TokenIdType = std::int32_t;
using LoraTaskIdType = std::uint64_t;
using TokenExtraIdType = std::uint64_t;
using VecTokenExtraIds = std::vector<TokenExtraIdType>;
+using CacheSaltIDType = std::uint64_t;

struct UniqueToken
{
14 changes: 12 additions & 2 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -131,7 +131,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
// Check if this multimodal content overlaps with the current block
if (endTokenIdx > startPos && startTokenIdx < startPos + length)
{
-SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos;
+uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
extraKeys.emplace_back(mmHashArray, mmStartInBlock);
}
}
@@ -151,7 +151,7 @@ std::vector<BlockKey> buildBlockKeys(
currentTokenIdx += uniqueTokens.size();

blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
-std::move(uniqueTokens), std::move(extraKeys));
+std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID());
}
return blockKeys;
}
@@ -167,6 +167,16 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
// Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);

+if (parentHash == 0 && blockKey.cacheSaltID)
+{
+// Only hash the cache salt ID for the first block in the sequence
+uint64_t c = blockKey.cacheSaltID.value();
+c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
+c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
+c = c ^ (c >> 31);
+seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}

for (auto const& uniqueToken : blockKey.uniqueTokens)
{
uint32_t a = static_cast<uint32_t>(uniqueToken.tokenId);
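The new branch mixes the salt into the hash seed only when parentHash == 0, i.e. for the first block of a sequence; every later block inherits the salt's influence through the parent-hash chain, so one extra mix is enough to separate entire block trees. A self-contained sketch of the same arithmetic (a splitmix64-style finalizer folded in with the boost-style hash_combine step this file already uses):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the PR's first-block salt mixing: finalize the salt with the
// splitmix64 constants, then fold it into the running seed.
std::uint64_t mixSaltIntoSeed(std::uint64_t seed, std::uint64_t salt)
{
    std::uint64_t c = salt;
    c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
    c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
    c = c ^ (c >> 31);
    seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
}

int main()
{
    std::uint64_t const seed = 3; // e.g. uniqueTokens.size() when parentHash == 0
    // Adjacent salts still land far apart, keeping the cache trees disjoint.
    std::printf("salt 42 -> %016llx\n", static_cast<unsigned long long>(mixSaltIntoSeed(seed, 42)));
    std::printf("salt 43 -> %016llx\n", static_cast<unsigned long long>(mixSaltIntoSeed(seed, 43)));
    return 0;
}
```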
16 changes: 13 additions & 3 deletions cpp/tensorrt_llm/executor/request.cpp
@@ -25,7 +25,7 @@

namespace tensorrt_llm::executor
{
-// 36 parameters
+// 37 parameters
Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, SamplingConfig const& samplingConfig,
OutputConfig const& outputConfig, std::optional<SizeType32> const& endId, std::optional<SizeType32> const& padId,
std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
@@ -41,7 +41,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
-std::optional<MillisecondsType> allottedTimeMs)
+std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
: mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@@ -50,7 +50,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
-allottedTimeMs))
+allottedTimeMs, cacheSaltID))
{
}

@@ -249,6 +249,11 @@ std::optional<SizeType32> Request::getLanguageAdapterUid() const
return mImpl->getLanguageAdapterUid();
}

+std::optional<CacheSaltIDType> Request::getCacheSaltID() const
+{
+return mImpl->getCacheSaltID();
+}

void Request::setStreaming(bool streaming)
{
mImpl->setStreaming(streaming);
@@ -413,4 +418,9 @@ void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
{
return mImpl->setLanguageAdapterUid(languageAdapterUid);
}

+void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
+{
+return mImpl->setCacheSaltID(cacheSaltID);
+}
} // namespace tensorrt_llm::executor