@@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
 bool CacheFormatter::needSendCache(
     CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
 {
-    // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
     auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
     if (targetInfo.mDupHeadFactor <= 1)
     {
@@ -91,12 +90,17 @@ bool CacheFormatter::needSendCache(
         selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
     }
 
+    // Only TP ranks with rank % dupHeadFactor == 0 need to send the cache.
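+    // Illustrative example (not part of the original change): with selfTPNumInDPGroup = 4 and
+    // mDupHeadFactor = 2, TP ranks 0 and 2 in the DP group send their cache, while ranks 1 and 3
+    // skip sending because they hold duplicate copies of the same KV heads.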
     return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
 }
 
 void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
     BaseCacheFormatter::CacheState const& destConfig)
 {
+    // TODO: VSWA does not support an uneven number of layers per PP rank.
+    // If the gen PP size and the context PP size differ, the cache formatter only supports an alternating
+    // window pattern like gpt-oss, where one layer uses SWA and the next uses full attention.
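+    // Illustrative example (not part of the original change): in a gpt-oss style layout the per-layer
+    // attention windows alternate, e.g. layer 0 = SWA, layer 1 = full attention, layer 2 = SWA, and so on,
+    // so every PP rank sees the same mix of window sizes.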
+
     auto numPools = cacheManager->getBlockManager().getNumPools();
     auto layerNum = cacheManager->getBlockManager().getNumLayers();
 
@@ -163,6 +167,7 @@ void CacheFormatter::format(TransferSession& session)
     auto const& destConfig = session.getOtherState().getCacheState().value();
     auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx();
     auto& bufferManager = session.getBufferManager();
+    // Some TP ranks don't need to send the cache, since their duplicated KV heads are not needed.
     if (!needSendCache(selfConfig, destConfig, selfIdx))
     {
         return;
@@ -214,7 +219,7 @@ void CacheFormatter::format(TransferSession& session)
     int blockNum = 0;
 
     size_t allCacheBlockSize = 0;
-
+    // Gather the cache blocks of the request.
     std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> inputKvCacheBlocks;
     for (auto poolIdx = 0; poolIdx < numPools; poolIdx++)
     {
@@ -224,6 +229,7 @@ void CacheFormatter::format(TransferSession& session)
224229 " window size already exists, which is not supported" );
225230 inputKvCacheBlocks.emplace (window, std::vector<runtime::ITensor::SharedPtr>());
226231 auto maxBlockThisWindow = window / selfConfig.getModelConfig ().mTokensPerBlock ;
+        // Only blocks that fall within the window will be sent.
         SizeType32 blockNumThisWindow = 0;
         for (auto it = blockRange.begin(); it != blockRange.end(); ++it)
         {
@@ -278,6 +284,14 @@ void CacheFormatter::format(TransferSession& session)
         return;
     }
 
+    // Formatter flow:
+    // 1. Gather the cache blocks of the request.
+    // 2. Compute the buffer size for each target.
+    // 3. Prepare the pre-allocated buffer for each target according to that buffer size.
+    // 4. Call splitKVCacheDispatch to split the cache blocks according to the different parallelism and
+    //    gather the cache blocks into the corresponding buffer.
+    // 5. Send the buffer to the corresponding target. Ideally, we send only once (one buffer) per target.
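+    // Illustrative example (not part of the original change): going from context TP = 1 to gen TP = 2,
+    // step 4 splits each block's KV heads into two halves so that each gen rank can be sent one
+    // contiguous buffer in step 5.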
+
     auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend();
     int peerDuplicateHeadFactor = targetInfo.mPeerDupHeadFactor;
     auto targetNum = connections.size();
@@ -286,7 +300,7 @@ void CacheFormatter::format(TransferSession& session)
     int selfAttentionLayerNum
         = selfConfig.getParallelConfig()
               .mAttentionLayerNumPerPP[selfIdx / selfConfig.getParallelConfig().mTensorParallelism];
-
+    // Since the number of layers per PP rank may differ, we need to compute the buffer size for each target.
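+    // Illustrative example (not part of the original change): if one gen PP target covers 20 of this
+    // rank's attention layers and another covers 12, their buffers are sized proportionally rather
+    // than uniformly.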
     auto getBufferSizeForTarget = [&]()
     {
         std::vector<size_t> bufferSizeForTarget(targetNum, 0);
@@ -419,7 +433,7 @@ void CacheFormatter::format(TransferSession& session)
         }
         else
         {
-            // concurrency num
+            // The concurrency count should be <= bufferCoverTargetNum to avoid data races.
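+            // Illustrative example (not part of the original change): with bufferCoverTargetNum = 2 and
+            // 8 connections, at most 2 sends are in flight at a time, so no pre-allocated buffer is
+            // reused while it is still being transferred.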
             auto concurrencyNum
                 = std::min(std::max(static_cast<size_t>(1), bufferCoverTargetNum), connections.size());
 
@@ -505,6 +519,7 @@ void CacheFormatter::unformat(TransferSession& session)
     TLLM_CHECK(!outputBuffersPerWindow.empty());
     if (outputBuffersPerWindow.size() > 1)
     {
+        // We only support a limited set of cases for VSWA.
         if (selfConfig.getParallelConfig().mPipelineParallelism != destConfig.getParallelConfig().mPipelineParallelism)
         {
             checkAlternateWindow(mCacheManager, selfConfig, destConfig);
@@ -603,6 +618,13 @@ void CacheFormatter::unformat(TransferSession& session)
             ctxReqId);
         return;
     }
+    // Unformat flow:
+    // 1. Gather the cache blocks of the request.
+    // 2. Compute the buffer size for each target.
+    // 3. Prepare the pre-allocated buffer for each target according to that buffer size.
+    // 4. Receive the buffer from the corresponding target. Ideally, we receive only once (one buffer)
+    //    per target.
+    // 5. Call concatKvCacheV2Dispatch to concatenate the cache blocks according to the different parallelism.
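+    // Illustrative example (not part of the original change): this mirrors the send path in format().
+    // With context TP = 1 and gen TP = 2, each gen rank receives one buffer and step 5 writes its half
+    // of the KV heads into the local blocks.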
 
     runtime::ITensor::SharedPtr recvBufferTemp;
     std::vector<runtime::ITensor::SharedPtr> recvSplitCaches;
@@ -615,7 +637,7 @@ void CacheFormatter::unformat(TransferSession& session)
     int selfAttentionLayerNum
         = selfConfig.getParallelConfig()
               .mAttentionLayerNumPerPP[selfIdx / selfConfig.getParallelConfig().mTensorParallelism];
-    auto getTargetBufferEleSzie = [&]()
+    auto getTargetBufferEleSize = [&]()
     {
         if (outputBuffersPerWindow.size() > 1)
         {
@@ -627,14 +649,17 @@ void CacheFormatter::unformat(TransferSession& session)
             // TODO: LayerNumbufferTargetNum for VWSA
             return std::make_pair(bufferSizeForTarget, std::vector<SizeType32>(targetNum, 0));
         }
-        size_t valideTpSize = pickUpConnections.size() / targetInfo.mDomainPPSize;
-        TLLM_CHECK_WITH_INFO(cacheBlockSizeSum % valideTpSize == 0,
-            "cacheBlockSizeSum must be divisible by valideTpSize %ld", valideTpSize);
-        TLLM_CHECK_WITH_INFO((cacheBlockSizeSum % (selfAttentionLayerNum * valideTpSize)) == 0,
-            "cacheBlockSizeSum must be divisible by valideTpSize %ld * selfAttentionLayerNum %d", valideTpSize,
+        // For duplicated heads, gen will not receive from TP ranks that hold the duplicate heads, and
+        // will not prepare buffers for them.
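+        // Illustrative example (not part of the original change): with a context-side dupHeadFactor of 2,
+        // only every other context TP rank sends, so pickUpConnections holds half of the context TP ranks
+        // in each PP domain.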
+        size_t validTpSize = pickUpConnections.size() / targetInfo.mDomainPPSize;
+        TLLM_CHECK_WITH_INFO(cacheBlockSizeSum % validTpSize == 0,
+            "cacheBlockSizeSum must be divisible by validTpSize %ld", validTpSize);
+        TLLM_CHECK_WITH_INFO((cacheBlockSizeSum % (selfAttentionLayerNum * validTpSize)) == 0,
+            "cacheBlockSizeSum must be divisible by validTpSize %ld * selfAttentionLayerNum %d", validTpSize,
             selfAttentionLayerNum);
         TLLM_CHECK(targetNum == pickUpConnections.size());
-        size_t baseEleSize = cacheBlockSizeSum / (valideTpSize * selfAttentionLayerNum);
+        // The per-target buffer sizes sum to cacheBlockSizeSum.
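+        // Illustrative example (not part of the original change): cacheBlockSizeSum = 1920 elements,
+        // validTpSize = 2, and selfAttentionLayerNum = 16 give
+        // baseEleSize = 1920 / (2 * 16) = 60 elements per layer per TP slice.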
+        size_t baseEleSize = cacheBlockSizeSum / (validTpSize * selfAttentionLayerNum);
 
         std::vector<size_t> bufferEleSizes(targetNum, 0);
         std::vector<SizeType32> LayerNumbufferTargetNum(targetNum, 0);
@@ -647,7 +672,7 @@ void CacheFormatter::unformat(TransferSession& session)
         }
         return std::make_pair(bufferEleSizes, LayerNumbufferTargetNum);
     };
-    auto [bufferEleSizes, LayerNumbufferTargetNum] = getTargetBufferEleSzie();
+    auto [bufferEleSizes, LayerNumbufferTargetNum] = getTargetBufferEleSize();
 
     size_t remainNoCoverTargetNum = 0;
     size_t bufferCoverTargetNum = 0;