 
 #include <random>
 #include <tuple>
+#include <unordered_map>
 
 namespace tensorrt_llm::testing
 {
@@ -201,4 +202,88 @@ INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
         return nameStringStream.str();
     });
 
+// Helper function to test calculateCacheSizePerTokenForDisagg with the given model parameters.
+std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper(
+    std::vector<runtime::SizeType32> const& maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2,
+    runtime::SizeType32 vocabSize = 32, runtime::SizeType32 nbLayers = 4, runtime::SizeType32 nbAttentionLayers = 4,
+    runtime::SizeType32 nbRnnLayers = 0, runtime::SizeType32 nbHeads = 8, runtime::SizeType32 hiddenSize = 512,
+    bool isCrossAttention = false)
+{
+    // Create a minimal ModelConfig for testing.
+    auto modelConfig = runtime::ModelConfig(
+        vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
+    modelConfig.useGptAttentionPlugin(true);
+    modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
+    modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);
+
+    auto const worldConfig = runtime::WorldConfig();
+
+    return batch_manager::TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
+        modelConfig, worldConfig, maxAttentionWindowVec, isCrossAttention, kvFactor);
+}
+
+// Test for TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg with different layer types.
+TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
+{
+    // Common parameters.
+    constexpr runtime::SizeType32 nbLayers = 5;
+    constexpr runtime::SizeType32 hiddenSize = 512;
+    constexpr runtime::SizeType32 kvFactor = 2;
+    constexpr runtime::SizeType32 vocabSize = 32;
+    constexpr runtime::SizeType32 nbHeads = 8;
+    // Test case 1: Single attention window size - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 1);
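+        // All 5 attention layers share the single 128 window; each layer is expected to account for
+        // kvFactor (K and V) * hiddenSize * numBytesPerFloatElement bytes per token.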
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 2: Multiple attention window sizes - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
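+        // The window sizes are presumably assigned to layers cyclically (layer i gets window i % 2),
+        // so 3 of the 5 attention layers land in the 128 window and the remaining 2 in the 256 window.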
+        auto const nbAttentionLayersIn128Window = 3;
+        auto const nbAttentionLayersIn256Window = 2;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 3: Single attention window size - attention and rnn layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
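+        // Only the attention layers are expected to contribute to the per-window KV-cache size;
+        // the 2 rnn layers do not add to it.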
+        EXPECT_EQ(result.size(), 1);
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 4: Multiple attention window sizes - attention and rnn layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
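+        // Again only the 3 attention layers count; with the windows presumably assigned cyclically,
+        // 2 of them fall into the 128 window and 1 into the 256 window.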
+        auto const nbAttentionLayersIn128Window = 2;
+        auto const nbAttentionLayersIn256Window = 1;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+}
+
 } // namespace tensorrt_llm::testing