1616
1717#pragma once
1818
19+ #include " tensorrt_llm/batch_manager/kvCacheConnector.h"
1920#include " tensorrt_llm/batch_manager/kvCacheEventManager.h"
2021#include " tensorrt_llm/batch_manager/kvCacheType.h"
2122#include " tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
@@ -479,7 +480,6 @@ class KVCacheBlockPool
479480 SizeType32 numKvHeads;
480481 SizeType32 sizePerHead;
481482 SizeType32 tokensPerBlock;
482- SizeType32 quantSize;
483483 SizeType32 blockSize;
484484
485485 // Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
@@ -490,15 +490,14 @@ class KVCacheBlockPool
490490 bool containsBlockScales;
491491
492492 KVCacheBlockPool (SizeType32 numLayers, SizeType32 kvFactor, SizeType32 numKvHeads, SizeType32 sizePerHead,
493- SizeType32 tokensPerBlock, SizeType32 quantSize, runtime::ITensor::SharedPtr primaryPtr = nullptr ,
493+ SizeType32 tokensPerBlock, runtime::ITensor::SharedPtr primaryPtr = nullptr ,
494494 runtime::ITensor::SharedPtr secondaryPtr = nullptr , bool containsBlockScales = false )
495495 : numLayers(numLayers)
496496 , kvFactor(kvFactor)
497497 , numKvHeads(numKvHeads)
498498 , sizePerHead(sizePerHead)
499499 , tokensPerBlock(tokensPerBlock)
500- , quantSize(quantSize)
501- , blockSize((numKvHeads * sizePerHead * tokensPerBlock) / quantSize)
500+ , blockSize(numKvHeads * sizePerHead * tokensPerBlock)
502501 , primaryPtr(std::move(primaryPtr))
503502 , secondaryPtr(std::move(secondaryPtr))
504503 , containsBlockScales(containsBlockScales)
@@ -538,7 +537,8 @@ class WindowBlockManager
538537 SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
539538 SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
540539 bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
541- std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
540+ std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
541+ std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);
542542
543543 ~WindowBlockManager ();
544544
@@ -646,6 +646,15 @@ class WindowBlockManager
646646 return mPools .at (poolIdx).blockSize ;
647647 }
648648
// Returns the number of elements per container for this manager's data type (mDataType).
// - Built with ENABLE_FP4: 2 when mDataType is nvinfer1::DataType::kFP4, otherwise 1.
// - Built without ENABLE_FP4: always 1 (mDataType is not consulted).
// NOTE(review): presumably the value 2 reflects two 4-bit FP4 values packed into one
// storage element — confirm against the code that sizes the FP4 pools.
649+ [[nodiscard]] SizeType32 getNumEltsPerContainer () const
650+ {
651+ #ifdef ENABLE_FP4
652+ return mDataType == nvinfer1::DataType::kFP4 ? 2 : 1 ;
653+ #else
654+ return 1 ;
655+ #endif
656+ }
657+
649658 [[nodiscard]] SizeType32 getNumPools (bool includeBlockScalePools = true ) const noexcept
650659 {
651660 if (includeBlockScalePools)
@@ -835,6 +844,8 @@ class WindowBlockManager
835844 bool mEnablePartialReuse ;
836845 // Whether partially matched blocks that are already in use should be copied and reused.
837846 bool mCopyOnPartialReuse ;
847+ // The kv cache connector manager
848+ std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager ;
838849};
839850
840851class BlockManager
@@ -852,7 +863,8 @@ class BlockManager
852863 SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF ,
853864 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
854865 std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enablePartialReuse = true ,
855- bool copyOnPartialReuse = true );
866+ bool copyOnPartialReuse = true ,
867+ std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr );
856868
857869 BlockManager (BlockManager const &) = delete ;
858870 BlockManager& operator =(BlockManager const &) = delete ;
@@ -1238,6 +1250,8 @@ class BaseKVCacheManager
12381250
12391251 [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockPoolPointers () const = 0;
12401252
1253+ [[nodiscard]] virtual runtime::ITensor::SharedPtr getBlockScalePoolPointers () const = 0;
1254+
12411255 [[nodiscard]] virtual runtime::ITensor::SharedPtr getLayerToPoolMapping () const = 0;
12421256
12431257 virtual void getBlockOffsetsOfBatch (
@@ -1287,6 +1301,7 @@ class BaseKVCacheManager
12871301 LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
12881302 = 0;
12891303
1304+ [[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool () const = 0;
12901305 [[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool (SizeType32 layer_idx) const = 0;
12911306 [[nodiscard]] virtual SizeType32 getPoolLayerIdx (SizeType32 layer_idx) const = 0;
12921307
@@ -1373,7 +1388,8 @@ class KVCacheManager : public BaseKVCacheManager
13731388 bool enableBlockReuse = false , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
13741389 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
13751390 std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enablePartialReuse = true ,
1376- bool copyOnpartialReuse = true );
1391+ bool copyOnpartialReuse = true ,
1392+ std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr );
13771393
13781394 KVCacheManager (std::vector<SizeType32> const & numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13791395 BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1383,7 +1399,8 @@ class KVCacheManager : public BaseKVCacheManager
13831399 bool enableBlockReuse = false , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
13841400 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
13851401 std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enablePartialReuse = true ,
1386- bool copyOnpartialReuse = true );
1402+ bool copyOnpartialReuse = true ,
1403+ std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr );
13871404
13881405 KVCacheManager (SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13891406 BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1393,7 +1410,8 @@ class KVCacheManager : public BaseKVCacheManager
13931410 bool enableBlockReuse = true , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
13941411 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
13951412 std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enablePartialReuse = true ,
1396- bool copyOnpartialReuse = true );
1413+ bool copyOnpartialReuse = true ,
1414+ std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr );
13971415
13981416 KVCacheManager (SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
13991417 BlocksPerWindow const & blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1543,7 +1561,7 @@ class KVCacheManager : public BaseKVCacheManager
15431561 return mLayerToPoolMapping ;
15441562 }
15451563
1546- [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers () const
1564+ [[nodiscard]] runtime::ITensor::SharedPtr getBlockScalePoolPointers () const override
15471565 {
15481566 // TODO: add a new optional model input so the attention plugin can access these
15491567 return mBlockScalePoolPointers ;
@@ -1624,6 +1642,7 @@ class KVCacheManager : public BaseKVCacheManager
16241642 std::vector<SizeType32> getNewlyAllocatedBlockIds (
16251643 LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override ;
16261644
1645+ runtime::ITensor::SharedPtr getUniquePrimaryPool () const override ;
16271646 runtime::ITensor::SharedPtr getPrimaryPool (SizeType32 layer_idx) const override ;
16281647
16291648 SizeType32 getPoolLayerIdx (SizeType32 layer_idx) const override
0 commit comments