@@ -269,16 +269,29 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
         = [](ModelConfig const& modelConfig, WorldConfig const& worldConfig,
             std::vector<SizeType32> const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor)
     {
-        auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
-            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention);
-        auto numKvHeadsPerLayer = std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd);
-        auto windowSizeLayers
-            = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers());
+        // The number of attention layers on this PP rank.
+        auto const numLocalAttnLayers = modelConfig.getNbAttentionLayers(
+            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+        // The number of attention layers on all preceding PP ranks.
+        auto const numLowerRankAttnLayers = modelConfig.countLowerRankLayers(ModelConfig::LayerType::kATTENTION,
+            worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
+        // Use global attention-layer indices to look up window sizes in maxAttentionWindowVec.
+        auto const startAttnLayerId = numLowerRankAttnLayers;
+        auto const endAttnLayerId = numLowerRankAttnLayers + numLocalAttnLayers;
+        auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
+        std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
+        for (SizeType32 layerIdx = startAttnLayerId; layerIdx < endAttnLayerId; layerIdx++)
+        {
+            // maxAttentionWindowVec may or may not have been stretched to numLayers entries yet.
+            // If it has not, cycle through the window sizes.
+            auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
+            uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
+        }
         std::map<SizeType32, SizeType32> cacheSizeBytesPerTokenPerWindow;
-        for (auto const& [windowSize, managedLayers] : windowSizeLayers)
+        for (auto const& [windowSize, globalLayerIds] : uniqueWindowSizeToLayers)
         {
             auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize(
-                modelConfig, managedLayers, isCrossAttention, kvFactor);
+                modelConfig, globalLayerIds, isCrossAttention, kvFactor);
             auto const cacheSizeBytesPerToken
                 = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize();
             cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken;
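
Note: a minimal standalone sketch of the new grouping step, using made-up values (2 PP ranks, 6 attention layers, and a 2-entry maxAttentionWindowVec that has not been stretched to numLayers yet). It shows how the modulo lookup cycles the window sizes over this rank's global attention-layer indices, which is the behavior the loop above relies on.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using SizeType32 = std::int32_t;

int main()
{
    // Assumed setup: 6 attention layers split over 2 PP ranks; this is rank 1.
    SizeType32 const numLowerRankAttnLayers = 3; // layers 0..2 live on rank 0
    SizeType32 const numLocalAttnLayers = 3;     // layers 3..5 live here
    // Not yet stretched to numLayers entries, so it gets cycled.
    std::vector<SizeType32> const maxAttentionWindowVec{4096, 512};

    auto const numNonUniqueWindowSizes = static_cast<SizeType32>(maxAttentionWindowVec.size());
    std::map<SizeType32, std::vector<SizeType32>> uniqueWindowSizeToLayers;
    for (SizeType32 layerIdx = numLowerRankAttnLayers;
         layerIdx < numLowerRankAttnLayers + numLocalAttnLayers; layerIdx++)
    {
        auto const windowSize = maxAttentionWindowVec.at(layerIdx % numNonUniqueWindowSizes);
        uniqueWindowSizeToLayers[windowSize].push_back(layerIdx);
    }

    // Prints: 512 -> 3 5, then 4096 -> 4 (global layer ids, grouped by window size).
    for (auto const& [windowSize, layerIds] : uniqueWindowSizeToLayers)
    {
        std::cout << windowSize << " ->";
        for (auto const id : layerIds)
        {
            std::cout << ' ' << id;
        }
        std::cout << '\n';
    }
    return 0;
}
```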
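And a sketch of the per-window bytes-per-token conversion in the second loop. The element-count formula here (kvFactor covering the K and V entries per token per layer) is an assumption for illustration, standing in for calculateCacheSizePerTokenForSingleWindowSize; the numKvHeads, head size, and dtype values are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using SizeType32 = std::int32_t;

// Hypothetical stand-in for the per-window size computation: cache elements
// per token, summed over the layers grouped under one window size.
SizeType32 elementsPerToken(
    std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 kvFactor)
{
    SizeType32 total = 0;
    for (auto const numKvHeads : numKvHeadsPerLayer)
    {
        total += kvFactor * numKvHeads * sizePerHead;
    }
    return total;
}

int main()
{
    // Two layers in this window-size group, 8 KV heads each, head size 128,
    // kvFactor = 2 (one K and one V entry per token), FP16 cache (2 bytes).
    auto const perToken = elementsPerToken({8, 8}, 128, 2);
    std::size_t const dtypeBytes = 2;
    std::cout << "bytes per token: " << perToken * dtypeBytes << '\n'; // 8192
    return 0;
}
```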