
Commit f083d44

formatting
1 parent 1a6e317 commit f083d44

3 files changed: 30 additions, 33 deletions


cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 9 additions & 9 deletions
@@ -83,18 +83,19 @@ namespace tk = tensorrt_llm::kernels;
 namespace tensorrt_llm::batch_manager
 {
 
-std::map<SizeType32, SizeType32> TrtGptModelInflightBatching::calculateCacheSizePerToken(ModelConfig const& modelConfig, WorldConfig const& worldConfig,
-    std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
+std::map<SizeType32, SizeType32> TrtGptModelInflightBatching::calculateCacheSizePerToken(ModelConfig const& modelConfig,
+    WorldConfig const& worldConfig, std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention,
+    SizeType32 kvFactor)
 {
     // These are the number of attention layers on this PP rank.
-    const auto numLocalAttnLayers = modelConfig.getNbAttentionLayers(
-        worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+    auto const numLocalAttnLayers
+        = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     // These are the number of attention layers on all previous PP ranks.
-    const auto numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
+    auto const numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
         worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
     // Use global ranks of attention layers to lookup from maxAttentionWindowVec.
-    const auto startAttnLayerId = numLowerRankAttnLayers;
-    const auto endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
+    auto const startAttnLayerId = numLowerRankAttnLayers;
+    auto const endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
     auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
     std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
     for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
@@ -109,8 +110,7 @@ std::map<SizeType32, SizeType32> TrtGptModelInflightBatching::calculateCacheSize
     {
         auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
             modelConfig, globalLayerIds, isCrossAttention, kvFactor);
-        auto const cacheSizeBytesPerToken
-            = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
+        auto const cacheSizeBytesPerToken = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
         cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
     }
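
For readers skimming the diff: the function above builds a map from attention-window size to KV-cache bytes per token for the attention layers owned by the local pipeline-parallel rank. The snippet below is a simplified, standalone sketch of that grouping logic, not the TensorRT-LLM implementation; the cyclic layer-to-window assignment, the flat hiddenSizePerLayer parameter, and bytesPerElement are illustrative assumptions (the real code derives these from ModelConfig, WorldConfig, and the KV dtype):

    // Simplified, standalone sketch of the grouping logic in the hunk above.
    // NOT the TensorRT-LLM implementation: the real function reads layer counts,
    // per-layer sizes, and the KV dtype from ModelConfig/WorldConfig; here they
    // are plain integer parameters, and the cyclic layer-to-window assignment is
    // an assumption inferred from the test expectations further down.
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <vector>

    using SizeType32 = std::int32_t;

    std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow(
        std::vector<SizeType32> const& maxAttentionWindowVec, SizeType32 numLocalAttnLayers, SizeType32 kvFactor,
        SizeType32 hiddenSizePerLayer, SizeType32 bytesPerElement)
    {
        // Group local attention layers by the window size assigned to them.
        std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
        auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
        for (SizeType32 layerIdx = 0; layerIdx < numLocalAttnLayers; layerIdx++)
        {
            auto const windowSize = maxAttentionWindowVec[layerIdx % numNonUniqueWindowSizes];
            uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
        }
        // Per window size: bytes per token = layers in group * K/V factor * hidden size * element size.
        std::map<SizeType32, SizeType32> result;
        for (auto const& [windowSize, layerIds] : uniqueWindowSizeToLayers)
        {
            auto const numLayers = static_cast<SizeType32>(layerIds.size());
            result[windowSize] = numLayers * kvFactor * hiddenSizePerLayer * bytesPerElement;
        }
        return result;
    }

    int main()
    {
        // Mirrors test case 2 in the test diff below: windows {128, 256}, 5 attention layers.
        for (auto const& [windowSize, bytesPerToken] : cacheSizeBytesPerTokenPerWindow({128, 256}, 5, 2, 512, 4))
        {
            std::cout << windowSize << " -> " << bytesPerToken << " bytes/token\n";
        }
        return 0;
    }

With those inputs it prints 128 -> 12288 and 256 -> 8192 bytes per token, matching the 3-layer and 2-layer window groups the unit test expects.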

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h

Lines changed: 3 additions & 3 deletions
@@ -148,9 +148,9 @@ class TrtGptModelInflightBatching : public TrtGptModel
 
     ~TrtGptModelInflightBatching() override;
 
-
-    [[nodiscard]] static std::map<SizeType32, SizeType32> calculateCacheSizePerToken(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
-        std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor);
+    [[nodiscard]] static std::map<SizeType32, SizeType32> calculateCacheSizePerToken(
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+        std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor);
 
     void terminateRequest(LlmRequestPtr const& llmRequest, bool pause = false) override;
 

cpp/tests/unit_tests/executor/executorTestSmall.cpp

Lines changed: 18 additions & 21 deletions
@@ -204,19 +204,14 @@ INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
 
 // Helper function to test calculateCacheSizePerToken with given parameters.
 std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper(
-    std::vector<runtime::SizeType32> const& maxAttentionWindowVec,
-    runtime::SizeType32 kvFactor = 2,
-    runtime::SizeType32 vocabSize = 32,
-    runtime::SizeType32 nbLayers = 4,
-    runtime::SizeType32 nbAttentionLayers = 4,
-    runtime::SizeType32 nbRnnLayers = 0,
-    runtime::SizeType32 nbHeads = 8,
-    runtime::SizeType32 hiddenSize = 512,
+    std::vector<runtime::SizeType32> const& maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2,
+    runtime::SizeType32 vocabSize = 32, runtime::SizeType32 nbLayers = 4, runtime::SizeType32 nbAttentionLayers = 4,
+    runtime::SizeType32 nbRnnLayers = 0, runtime::SizeType32 nbHeads = 8, runtime::SizeType32 hiddenSize = 512,
     bool isCrossAttention = false)
 {
     // Create minimal ModelConfig for testing.
-    auto modelConfig = runtime::ModelConfig(vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers,
-        nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
+    auto modelConfig = runtime::ModelConfig(
+        vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
     modelConfig.useGptAttentionPlugin(true);
     modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
     modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);
@@ -242,8 +237,8 @@ TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
         constexpr runtime::SizeType32 nbAttentionLayers = 5;
         constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
         constexpr runtime::SizeType32 nbRnnLayers = 0;
-        auto result = calculateCacheSizePerTokenHelper(
-            maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
         EXPECT_EQ(result.size(), 1);
         EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
     }
@@ -254,22 +249,23 @@ TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
         constexpr runtime::SizeType32 nbAttentionLayers = 5;
         constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
         constexpr runtime::SizeType32 nbRnnLayers = 0;
-        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
         EXPECT_EQ(result.size(), 2);
-        const auto nbAttentionLayersIn128Window = 3;
-        const auto nbAttentionLayersIn256Window = 2;
+        auto const nbAttentionLayersIn128Window = 3;
+        auto const nbAttentionLayersIn256Window = 2;
         EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
         EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
     }
 
     // Test case 3: Single attention window size - attention and rnn layers.
     {
-        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128}; 
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
         constexpr runtime::SizeType32 nbAttentionLayers = 3;
         constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
         constexpr runtime::SizeType32 nbRnnLayers = 2;
-        auto result = calculateCacheSizePerTokenHelper(
-            maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
         EXPECT_EQ(result.size(), 1);
         EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
     }
@@ -280,10 +276,11 @@ TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
         constexpr runtime::SizeType32 nbAttentionLayers = 3;
         constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
         constexpr runtime::SizeType32 nbRnnLayers = 2;
-        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
         EXPECT_EQ(result.size(), 2);
-        const auto nbAttentionLayersIn128Window = 2;
-        const auto nbAttentionLayersIn256Window = 1;
+        auto const nbAttentionLayersIn128Window = 2;
+        auto const nbAttentionLayersIn256Window = 1;
         EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
         EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
     }
