58 changes: 35 additions & 23 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -83,6 +83,40 @@ namespace tk = tensorrt_llm::kernels;
namespace tensorrt_llm::batch_manager
{

std::map<SizeType32, SizeType32> TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
{
// Number of attention layers on this PP rank.
auto const numLocalAttnLayers
= modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
// Number of attention layers on all previous PP ranks.
auto const numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
// Use global ranks of attention layers to lookup from maxAttentionWindowVec.
auto const startAttnLayerId = numLowerRankAttnLayers;
auto const endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
{
// maxAttentionWindowVec may or may not be stretched to the length of numLayers yet.
// If not stretched yet, we cycle through the window sizes.
auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
}
std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
{
auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
modelConfig, globalLayerIds, isCrossAttention, kvFactor);
auto const cacheSizeBytesPerToken = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
}

return cacheSizeBytesPerTokenPerWindow;
}
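
For illustration, a minimal standalone sketch of how this grouping behaves, assuming a hypothetical rank with five local attention layers and maxAttentionWindowVec = {128, 256}; the window sizes are cycled over the layer indices:

#include <cstdint>
#include <map>
#include <vector>

int main()
{
    // Hypothetical inputs: fewer window sizes than layers, so the vector is cycled.
    std::vector<int32_t> const maxAttentionWindowVec{128, 256};
    int32_t const numLocalAttnLayers = 5;
    auto const numNonUniqueWindowSizes = static_cast<int32_t>(maxAttentionWindowVec.size());
    std::map<int32_t, std::vector<int32_t>> uniqueWindowSizeToLayers;
    for (int32_t layerIdx = 0; layerIdx < numLocalAttnLayers; ++layerIdx)
    {
        auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
        uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
    }
    // Resulting grouping: 128 -> layers {0, 2, 4}, 256 -> layers {1, 3}.
    return 0;
}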

bool TrtGptModelInflightBatching::executorConfigIsValid(
ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig)
{
@@ -264,32 +298,10 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
}
if (mModelConfig.isTransformerBased() && modelConfig.isKVCacheEnabled())
{

auto calculateCacheSizePerToken
= [](ModelConfig const& modelConfig, WorldConfig const& worldConfig,
std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
{
auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention);
auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
auto windowSizeLayers
= BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers());
std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
for (auto const& [windowSize, managedLayers] : windowSizeLayers)
{
auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
modelConfig, managedLayers, isCrossAttention, kvFactor);
auto const cacheSizeBytesPerToken
= cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
}

return cacheSizeBytesPerTokenPerWindow;
};
auto cacheTransceiverConfig
= executorConfig.getCacheTransceiverConfig().value_or(executor::CacheTransceiverConfig());

auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerToken(
auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerTokenForDisagg(
mModelConfig, mWorldConfig, getMaxAttentionWindowVec(), mModelConfig.useCrossAttention(), 2);
auto cacheTransPreAllocaSize = kv_cache_manager::CacheTransBufferManager::preAllocBufferSize(
cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
13 changes: 13 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
@@ -148,6 +148,19 @@ class TrtGptModelInflightBatching : public TrtGptModel

~TrtGptModelInflightBatching() override;

/// @brief Calculate the cache size per token for disaggregated serving.
/// @param modelConfig Model configuration.
/// @param worldConfig World configuration.
/// @param maxAttentionWindowVec Maximum attention window sizes. May have fewer elements than the number of layers,
/// in which case the window sizes are cycled over the layers.
/// @param isCrossAttention Whether the attention is cross attention.
/// @param kvFactor KV factor.
/// @return Map from attention window size to cache size in bytes per token for the layers using that window size.
/// Note that the window size itself is not factored into the returned sizes.
[[nodiscard]] static std::map<SizeType32, SizeType32> calculateCacheSizePerTokenForDisagg(
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor);

void terminateRequest(LlmRequestPtr const& llmRequest, bool pause = false) override;

/// @brief Terminate request in the next forwardSync call that includes the request.
85 changes: 85 additions & 0 deletions cpp/tests/unit_tests/executor/executorTestSmall.cpp
@@ -11,6 +11,7 @@

#include <random>
#include <tuple>
#include <unordered_map>

namespace tensorrt_llm::testing
{
@@ -201,4 +202,88 @@ INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
return nameStringStream.str();
});

// Helper function to test calculateCacheSizePerTokenForDisagg with given parameters.
std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper(
std::vector<runtime::SizeType32> const& maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2,
runtime::SizeType32 vocabSize = 32, runtime::SizeType32 nbLayers = 4, runtime::SizeType32 nbAttentionLayers = 4,
runtime::SizeType32 nbRnnLayers = 0, runtime::SizeType32 nbHeads = 8, runtime::SizeType32 hiddenSize = 512,
bool isCrossAttention = false)
{
// Create minimal ModelConfig for testing.
auto modelConfig = runtime::ModelConfig(
vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
modelConfig.useGptAttentionPlugin(true);
modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);

auto const worldConfig = runtime::WorldConfig();

return batch_manager::TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
modelConfig, worldConfig, maxAttentionWindowVec, isCrossAttention, kvFactor);
}

// Test for TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg with different layer types.
TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
{
// Common parameters.
constexpr runtime::SizeType32 nbLayers = 5;
constexpr runtime::SizeType32 hiddenSize = 512;
constexpr runtime::SizeType32 kvFactor = 2;
constexpr runtime::SizeType32 vocabSize = 32;
constexpr runtime::SizeType32 nbHeads = 8;
// Test case 1: Single attention window size - attention layers only.
{
std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
constexpr runtime::SizeType32 nbAttentionLayers = 5;
constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
constexpr runtime::SizeType32 nbRnnLayers = 0;
auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
EXPECT_EQ(result.size(), 1);
EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
}

// Test case 2: Multiple attention window sizes - attention layers only.
{
std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
constexpr runtime::SizeType32 nbAttentionLayers = 5;
constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
constexpr runtime::SizeType32 nbRnnLayers = 0;
auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
EXPECT_EQ(result.size(), 2);
auto const nbAttentionLayersIn128Window = 3;
auto const nbAttentionLayersIn256Window = 2;
EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
}

// Test case 3: Single attention window size - attention and rnn layers.
{
std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
constexpr runtime::SizeType32 nbAttentionLayers = 3;
constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
constexpr runtime::SizeType32 nbRnnLayers = 2;
auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
EXPECT_EQ(result.size(), 1);
EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
}

// Test case 4: Multiple attention window sizes - attention and rnn layers.
{
std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
constexpr runtime::SizeType32 nbAttentionLayers = 3;
constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
constexpr runtime::SizeType32 nbRnnLayers = 2;
auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
EXPECT_EQ(result.size(), 2);
auto const nbAttentionLayersIn128Window = 2;
auto const nbAttentionLayersIn256Window = 1;
EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
}
}
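
As a sanity check on the expected values (assuming the per-layer KV cache size reduces to kvFactor * hiddenSize elements for this single-rank float configuration): in test case 2 the window sizes {128, 256} are cycled over five attention layers, so window 128 covers layers 0, 2, 4 and window 256 covers layers 1, 3, giving 3 * 2 * 512 * 4 = 12288 and 2 * 2 * 512 * 4 = 8192 bytes per token, respectively.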

} // namespace tensorrt_llm::testing