@@ -204,19 +204,14 @@ INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
204204
205205// Helper function to test calculateCacheSizePerToken with given parameters.
206206std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper (
207- std::vector<runtime::SizeType32> const & maxAttentionWindowVec,
208- runtime::SizeType32 kvFactor = 2 ,
209- runtime::SizeType32 vocabSize = 32 ,
210- runtime::SizeType32 nbLayers = 4 ,
211- runtime::SizeType32 nbAttentionLayers = 4 ,
212- runtime::SizeType32 nbRnnLayers = 0 ,
213- runtime::SizeType32 nbHeads = 8 ,
214- runtime::SizeType32 hiddenSize = 512 ,
207+ std::vector<runtime::SizeType32> const & maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2 ,
208+ runtime::SizeType32 vocabSize = 32 , runtime::SizeType32 nbLayers = 4 , runtime::SizeType32 nbAttentionLayers = 4 ,
209+ runtime::SizeType32 nbRnnLayers = 0 , runtime::SizeType32 nbHeads = 8 , runtime::SizeType32 hiddenSize = 512 ,
215210 bool isCrossAttention = false )
216211{
217212 // Create minimal ModelConfig for testing.
218- auto modelConfig = runtime::ModelConfig (vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers,
219- nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT );
213+ auto modelConfig = runtime::ModelConfig (
214+ vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT );
220215 modelConfig.useGptAttentionPlugin (true );
221216 modelConfig.setModelVariant (runtime::ModelConfig::ModelVariant::kGpt );
222217 modelConfig.setKVCacheType (runtime::ModelConfig::KVCacheType::kPAGED );
@@ -242,8 +237,8 @@ TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
242237 constexpr runtime::SizeType32 nbAttentionLayers = 5 ;
243238 constexpr runtime::SizeType32 numBytesPerFloatElement = 4 ;
244239 constexpr runtime::SizeType32 nbRnnLayers = 0 ;
245- auto result = calculateCacheSizePerTokenHelper (
246- maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
240+ auto result = calculateCacheSizePerTokenHelper (maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
241+ nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
247242 EXPECT_EQ (result.size (), 1 );
248243 EXPECT_EQ (result.at (128 ), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
249244 }
@@ -254,22 +249,23 @@ TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
254249 constexpr runtime::SizeType32 nbAttentionLayers = 5 ;
255250 constexpr runtime::SizeType32 numBytesPerFloatElement = 4 ;
256251 constexpr runtime::SizeType32 nbRnnLayers = 0 ;
257- auto result = calculateCacheSizePerTokenHelper (maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
252+ auto result = calculateCacheSizePerTokenHelper (maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
253+ nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
258254 EXPECT_EQ (result.size (), 2 );
259- const auto nbAttentionLayersIn128Window = 3 ;
260- const auto nbAttentionLayersIn256Window = 2 ;
255+ auto const nbAttentionLayersIn128Window = 3 ;
256+ auto const nbAttentionLayersIn256Window = 2 ;
261257 EXPECT_EQ (result.at (128 ), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
262258 EXPECT_EQ (result.at (256 ), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
263259 }
264260
265261 // Test case 3: Single attention window size - attention and rnn layers.
266262 {
267- std::vector<runtime::SizeType32> maxAttentionWindowVec = {128 };
263+ std::vector<runtime::SizeType32> maxAttentionWindowVec = {128 };
268264 constexpr runtime::SizeType32 nbAttentionLayers = 3 ;
269265 constexpr runtime::SizeType32 numBytesPerFloatElement = 4 ;
270266 constexpr runtime::SizeType32 nbRnnLayers = 2 ;
271- auto result = calculateCacheSizePerTokenHelper (
272- maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
267+ auto result = calculateCacheSizePerTokenHelper (maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
268+ nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
273269 EXPECT_EQ (result.size (), 1 );
274270 EXPECT_EQ (result.at (128 ), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
275271 }
@@ -280,10 +276,11 @@ TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
280276 constexpr runtime::SizeType32 nbAttentionLayers = 3 ;
281277 constexpr runtime::SizeType32 numBytesPerFloatElement = 4 ;
282278 constexpr runtime::SizeType32 nbRnnLayers = 2 ;
283- auto result = calculateCacheSizePerTokenHelper (maxAttentionWindowVec, kvFactor, vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
279+ auto result = calculateCacheSizePerTokenHelper (maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
280+ nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false );
284281 EXPECT_EQ (result.size (), 2 );
285- const auto nbAttentionLayersIn128Window = 2 ;
286- const auto nbAttentionLayersIn256Window = 1 ;
282+ auto const nbAttentionLayersIn128Window = 2 ;
283+ auto const nbAttentionLayersIn256Window = 1 ;
287284 EXPECT_EQ (result.at (128 ), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
288285 EXPECT_EQ (result.at (256 ), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
289286 }
0 commit comments