
Commit ccbd9be

[https://nvbugs/5501557][fix] Fix out-of-bounds vector access for model with multiple layer types
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent f6365e6 commit ccbd9be

File tree

1 file changed: +20 -7 lines changed

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 20 additions & 7 deletions
@@ -269,16 +269,29 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
         = [](ModelConfig const& modelConfig, WorldConfig const& worldConfig,
             std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
     {
-        auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
-            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention);
-        auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
-        auto windowSizeLayers
-            = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers());
+        // These are the number of attention layers on this PP rank.
+        const auto numLocalAttnLayers = modelConfig.getNbAttentionLayers(
+            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+        // These are the number of attention layers on all previous PP ranks.
+        const auto numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
+            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+        // Use global ranks of attention layers to lookup from maxAttentionWindowVec.
+        const auto startAttnLayerId = numLowerRankAttnLayers;
+        const auto endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
+        auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
+        std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
+        for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
+        {
+            // maxAttentionWindowVec may or may not be stretched to the length of numLayers yet.
+            // If not stretched yet, we cycle through the window sizes.
+            auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
+            uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
+        }
         std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
-        for (auto const& [windowSize, managedLayers] : windowSizeLayers)
+        for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
         {
             auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
-                modelConfig, managedLayers, isCrossAttention, kvFactor);
+                modelConfig, globalLayerIds, isCrossAttention, kvFactor);
             auto const cacheSizeBytesPerToken
                 = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
             cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
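
Note on the change (not part of the commit): the removed code grouped window sizes over modelConfig.getNbLayers(), which counts layers of every type; per the commit message, this could index maxAttentionWindowVec out of bounds for a model with multiple layer types. The new code walks only this pipeline-parallel rank's attention layers by their global indices and wraps the lookup with a modulo, so it stays in bounds whether or not maxAttentionWindowVec has been stretched to the full layer count. The sketch below is a minimal, self-contained illustration of that grouping idea; it is not the TensorRT-LLM API, and groupByWindowSize and every other name in it are hypothetical.

// Minimal sketch of the grouping logic, assuming only the standard library.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using SizeType32 = std::int32_t;

// Group the global indices of this rank's attention layers
// [startAttnLayerId, endAttnLayerId) by window size, cycling through
// windowSizes when it holds fewer entries than there are layers.
std::map<SizeType32, std::vector<SizeType32>> groupByWindowSize(
    std::vector<SizeType32> const& windowSizes, SizeType32 startAttnLayerId, SizeType32 endAttnLayerId)
{
    auto const numNonUniqueWindowSizes = static_cast<SizeType32>(windowSizes.size());
    std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
    for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
    {
        // The modulo keeps the index valid even when windowSizes is shorter
        // than the total number of layers.
        auto const windowSize = windowSizes.at(layerIdx % numNonUniqueWindowSizes);
        uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
    }
    return uniqueWindowSizeToLayers;
}

int main()
{
    // Hypothetical setup: 2 attention layers on lower PP ranks, 4 attention
    // layers on this rank, and a 3-entry window-size vector.
    auto const grouped = groupByWindowSize({1024, 1024, 4096}, /*startAttnLayerId=*/2, /*endAttnLayerId=*/6);
    for (auto const& [windowSize, layerIds] : grouped)
    {
        std::cout << "window " << windowSize << " -> layers";
        for (auto const layerId : layerIds)
        {
            std::cout << ' ' << layerId;
        }
        std::cout << '\n';
    }
    // Prints:
    // window 1024 -> layers 3 4
    // window 4096 -> layers 2 5
    return 0;
}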
