
Commit c0e4fce

[https://nvbugs/5501557][fix] Fix out-of-bounds vector access for model with multiple layer types (#7636)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 541fd3e commit c0e4fce

File tree

3 files changed, +133 -23 lines changed


cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 35 additions & 23 deletions
@@ -83,6 +83,40 @@ namespace tk = tensorrt_llm::kernels;
 namespace tensorrt_llm::batch_manager
 {
 
+std::map<SizeType32, SizeType32> TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
+    ModelConfig const& modelConfig, WorldConfig const& worldConfig,
+    std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
+{
+    // These are the number of attention layers on this PP rank.
+    auto const numLocalAttnLayers
+        = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+    // These are the number of attention layers on all previous PP ranks.
+    auto const numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
+        worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+    // Use global ranks of attention layers to look up from maxAttentionWindowVec.
+    auto const startAttnLayerId = numLowerRankAttnLayers;
+    auto const endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
+    auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
+    std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
+    for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
+    {
+        // maxAttentionWindowVec may or may not have been stretched to the length of numLayers yet.
+        // If it has not been stretched yet, cycle through the window sizes.
+        auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
+        uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
+    }
+    std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
+    for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
+    {
+        auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
+            modelConfig, globalLayerIds, isCrossAttention, kvFactor);
+        auto const cacheSizeBytesPerToken = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
+        cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
+    }
+
+    return cacheSizeBytesPerTokenPerWindow;
+}
+
 bool TrtGptModelInflightBatching::executorConfigIsValid(
     ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig)
 {
@@ -264,32 +298,10 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
     }
     if (mModelConfig.isTransformerBased() && modelConfig.isKVCacheEnabled())
     {
-
-        auto calculateCacheSizePerToken
-            = [](ModelConfig const& modelConfig, WorldConfig const& worldConfig,
-                  std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
-        {
-            auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
-                worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention);
-            auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
-            auto windowSizeLayers
-                = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers());
-            std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
-            for (auto const& [windowSize, managedLayers] : windowSizeLayers)
-            {
-                auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
-                    modelConfig, managedLayers, isCrossAttention, kvFactor);
-                auto const cacheSizeBytesPerToken
-                    = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
-                cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
-            }
-
-            return cacheSizeBytesPerTokenPerWindow;
-        };
         auto cacheTransceiverConfig
             = executorConfig.getCacheTransceiverConfig().value_or(executor::CacheTransceiverConfig());
 
-        auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerToken(
+        auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerTokenForDisagg(
             mModelConfig, mWorldConfig, getMaxAttentionWindowVec(), mModelConfig.useCrossAttention(), 2);
         auto cacheTransPreAllocaSize = kv_cache_manager::CacheTransBufferManager::preAllocBufferSize(
             cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig);
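To make the grouping step concrete, here is a minimal standalone sketch of the cycling lookup the new helper performs (hypothetical values only: plain int stands in for SizeType32, and the rank layout is invented for illustration, not taken from the TensorRT-LLM API). A pipeline rank that owns global attention layers [3, 6) with window sizes {128, 256} groups them by cycling through the window vector instead of indexing it with a raw global layer ID:

    #include <cstdio>
    #include <map>
    #include <vector>

    int main()
    {
        std::vector<int> const maxAttentionWindowVec{128, 256};
        int const numLowerRankAttnLayers = 3; // attention layers on previous PP ranks (assumed)
        int const numLocalAttnLayers = 3;     // attention layers on this PP rank (assumed)

        auto const numWindowSizes = static_cast<int>(maxAttentionWindowVec.size());
        std::map<int, std::vector<int>> windowSizeToLayers;
        for (int layerIdx = numLowerRankAttnLayers; layerIdx < numLowerRankAttnLayers + numLocalAttnLayers;
             layerIdx++)
        {
            // Cycling keeps the index in bounds when the window vector is shorter than the layer count.
            windowSizeToLayers[maxAttentionWindowVec.at(layerIdx % numWindowSizes)].push_back(layerIdx);
        }
        for (auto const& [windowSize, layers] : windowSizeToLayers)
        {
            std::printf("window %d:", windowSize); // prints "window 128: 4" then "window 256: 3 5"
            for (auto const layerIdx : layers)
            {
                std::printf(" %d", layerIdx);
            }
            std::printf("\n");
        }
        return 0;
    }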

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h

Lines changed: 13 additions & 0 deletions
@@ -148,6 +148,19 @@ class TrtGptModelInflightBatching : public TrtGptModel
 
     ~TrtGptModelInflightBatching() override;
 
+    /// @brief Calculate the cache size per token for disaggregated serving.
+    /// @param modelConfig Model configuration.
+    /// @param worldConfig World configuration.
+    /// @param maxAttentionWindowVec Maximum attention window vector (may have fewer elements than numLayers, in
+    /// which case it cycles).
+    /// @param isCrossAttention Whether the attention is cross attention.
+    /// @param kvFactor KV factor.
+    /// @return Cache size in bytes per token for each unique window size. Note that the window size itself is not
+    /// included in the result.
+    [[nodiscard]] static std::map<SizeType32, SizeType32> calculateCacheSizePerTokenForDisagg(
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+        std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor);
+
     void terminateRequest(LlmRequestPtr const& llmRequest, bool pause = false) override;
 
     /// @brief Terminate request in the next forwardSync call that includes the request.
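As a rough guide to the returned values (inferred from the unit tests below rather than from any documented contract, and assuming float KV entries with the tests' uniform-head configuration where numKvHeads * headSize equals hiddenSize), each map entry reduces to:

    cacheSizeBytesPerTokenPerWindow[windowSize] == nbAttentionLayersInWindow * kvFactor * hiddenSize * sizeof(float)

so the per-token byte count scales with how many attention layers map to a window, not with the window size itself.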

cpp/tests/unit_tests/executor/executorTestSmall.cpp

Lines changed: 85 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 #include <random>
 #include <tuple>
+#include <unordered_map>
 
 namespace tensorrt_llm::testing
 {
@@ -201,4 +202,88 @@ INSTANTIATE_TEST_SUITE_P(Float, DecoderFloatTest, paramGenerator,
         return nameStringStream.str();
     });
 
+// Helper function to test calculateCacheSizePerTokenForDisagg with given parameters.
+std::map<runtime::SizeType32, runtime::SizeType32> calculateCacheSizePerTokenHelper(
+    std::vector<runtime::SizeType32> const& maxAttentionWindowVec, runtime::SizeType32 kvFactor = 2,
+    runtime::SizeType32 vocabSize = 32, runtime::SizeType32 nbLayers = 4, runtime::SizeType32 nbAttentionLayers = 4,
+    runtime::SizeType32 nbRnnLayers = 0, runtime::SizeType32 nbHeads = 8, runtime::SizeType32 hiddenSize = 512,
+    bool isCrossAttention = false)
+{
+    // Create a minimal ModelConfig for testing.
+    auto modelConfig = runtime::ModelConfig(
+        vocabSize, nbLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, nvinfer1::DataType::kFLOAT);
+    modelConfig.useGptAttentionPlugin(true);
+    modelConfig.setModelVariant(runtime::ModelConfig::ModelVariant::kGpt);
+    modelConfig.setKVCacheType(runtime::ModelConfig::KVCacheType::kPAGED);
+
+    auto const worldConfig = runtime::WorldConfig();
+
+    return batch_manager::TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg(
+        modelConfig, worldConfig, maxAttentionWindowVec, isCrossAttention, kvFactor);
+}
+
+// Test for TrtGptModelInflightBatching::calculateCacheSizePerTokenForDisagg with different layer types.
+TEST(TrtInflightBatchingTest, CalculateCacheSizePerTokenForDisagg)
+{
+    // Common parameters.
+    constexpr runtime::SizeType32 nbLayers = 5;
+    constexpr runtime::SizeType32 hiddenSize = 512;
+    constexpr runtime::SizeType32 kvFactor = 2;
+    constexpr runtime::SizeType32 vocabSize = 32;
+    constexpr runtime::SizeType32 nbHeads = 8;
+
+    // Test case 1: Single attention window size - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 1);
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 2: Multiple attention window sizes - attention layers only.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 5;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 0;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
+        auto const nbAttentionLayersIn128Window = 3;
+        auto const nbAttentionLayersIn256Window = 2;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 3: Single attention window size - attention and RNN layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 1);
+        EXPECT_EQ(result.at(128), nbAttentionLayers * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+
+    // Test case 4: Multiple attention window sizes - attention and RNN layers.
+    {
+        std::vector<runtime::SizeType32> maxAttentionWindowVec = {128, 256};
+        constexpr runtime::SizeType32 nbAttentionLayers = 3;
+        constexpr runtime::SizeType32 numBytesPerFloatElement = 4;
+        constexpr runtime::SizeType32 nbRnnLayers = 2;
+        auto result = calculateCacheSizePerTokenHelper(maxAttentionWindowVec, kvFactor, vocabSize, nbLayers,
+            nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, false);
+        EXPECT_EQ(result.size(), 2);
+        auto const nbAttentionLayersIn128Window = 2;
+        auto const nbAttentionLayersIn256Window = 1;
+        EXPECT_EQ(result.at(128), nbAttentionLayersIn128Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+        EXPECT_EQ(result.at(256), nbAttentionLayersIn256Window * kvFactor * hiddenSize * numBytesPerFloatElement);
+    }
+}
+
 } // namespace tensorrt_llm::testing
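As a sanity check on the expectations above (using the tests' own constants and 4-byte float elements): in test case 1 all five attention layers share the 128 window, so result.at(128) is 5 * 2 * 512 * 4 = 20480 bytes per token. In test case 2 the two window sizes cycle over the five layers as 128, 256, 128, 256, 128, which is why three layers land in the 128 window and two in the 256 window; test case 4 applies the same cycling over its three attention layers only, giving two layers at 128 and one at 256.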
