@@ -3053,189 +3053,6 @@ TEST_F(KVCacheManagerTest, KVCacheManagerVariableWindowAttentionWithReuseTest)
30533053 assertBlocks (seq3, {4 }, {6 });
30543054}
30553055
3056- namespace
3057- {
3058- KVCacheManager setupKvCacheManagerForHashTest (bool enableBlockReuse)
3059- {
3060- auto constexpr numLayers = 2 ;
3061- auto constexpr numHeads = 2 ;
3062- auto constexpr sizePerHead = 64 ;
3063- auto constexpr tokensPerBlock = 4 ;
3064- auto constexpr maxNumSequences = 8 ;
3065- auto constexpr maxBeamWidth = 1 ;
3066- auto constexpr sinkTokenLength = 0 ;
3067- auto const stream = std::make_shared<tr::CudaStream>();
3068-
3069- auto constexpr maxBlocksPerSeq = 8 ;
3070- auto constexpr maxNumTokens = tokensPerBlock * maxBlocksPerSeq;
3071- auto constexpr maxAttentionWindow = maxNumTokens;
3072-
3073- auto constexpr blocksInPrimaryPool = 16 ;
3074- auto constexpr blocksInSecondaryPool = 0 ;
3075-
3076- auto constexpr onboardBlocks = true ;
3077-
3078- auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
3079-
3080- return KVCacheManager (std::vector<SizeType32>(numLayers, numHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
3081- maxNumSequences, maxBeamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt ,
3082- nvinfer1::DataType::kHALF , sinkTokenLength, stream, std::nullopt , enableBlockReuse, onboardBlocks,
3083- CacheType::kSELF , std::nullopt , nullptr ,
3084- /* enableHashKey*/ true );
3085- }
3086-
3087- std::vector<size_t > getHashAndRetrieveBlocksByHashTest (
3088- BlockManager const & blockManager, std::vector<KVCacheBlock::IdType> const & blockIds, SizeType32 windowSize)
3089- {
3090- std::vector<size_t > blockHashes;
3091- for (auto blockId : blockIds)
3092- {
3093- blockHashes.emplace_back (blockManager.getBlockById (blockId, windowSize)->getHash ());
3094- }
3095- std::vector<BlockPtr> blockPtrs;
3096- for (auto hash : blockHashes)
3097- {
3098- auto range = blockManager.getBlocksByHash (hash, windowSize);
3099- BlockPtr const prevBlock = blockPtrs.empty () ? nullptr : blockPtrs.back ();
3100- BlockPtr thisBlock = nullptr ;
3101- for (auto it = range.first ; it != range.second ; ++it)
3102- {
3103- if (it->second ->getPrevBlockInSeq () == prevBlock)
3104- {
3105- thisBlock = it->second ;
3106- break ;
3107- }
3108- }
3109- EXPECT_NE (thisBlock, nullptr );
3110- blockPtrs.emplace_back (thisBlock);
3111- }
3112- EXPECT_EQ (blockHashes.size (), blockPtrs.size ());
3113- for (size_t i = 0 ; i < blockHashes.size (); i++)
3114- {
3115- EXPECT_EQ (blockManager.getBlockById (blockIds[i], windowSize), blockPtrs[i]);
3116- }
3117- return blockHashes;
3118- }
3119- } // namespace
3120-
3121- TEST_F (KVCacheManagerTest, KVCacheManagerHashKeyTest)
3122- {
3123- auto kvCacheManager = setupKvCacheManagerForHashTest (false );
3124-
3125- auto const & blockManager = kvCacheManager.getBlockManager ();
3126-
3127- SizeType32 constexpr maxNewTokens = 4 ;
3128-
3129- // prepare tokens with token[i] = 1000 + i
3130- TokenIdType constexpr firstToken = 1000 ;
3131-
3132- auto constexpr beamWidth = 1 ;
3133- tr::SamplingConfig const samplingConfig{beamWidth};
3134- bool constexpr isStreaming{false };
3135-
3136- SizeType32 requestId = 0 ;
3137- int inputLength = 16 ;
3138- auto inputTokens = std::make_shared<VecTokens>(inputLength);
3139- std::iota (inputTokens->begin (), inputTokens->end (), firstToken);
3140- auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3141- auto constexpr beamIdx = 0 ;
3142-
3143- // /////////////////////////////////////////////////////////////////////////
3144- // add a request and then remove it without reuse
3145- kvCacheManager.addSequence (requestId, inputLength, beamWidth, llmRequest);
3146- GenerationRequest const & seq = kvCacheManager.getSequence (requestId);
3147- EXPECT_EQ (llmRequest->getContextCurrentPosition (), 0 );
3148-
3149- auto const onlyWindowSize = theOnlyWindowSize (kvCacheManager);
3150-
3151- auto & blockIds = seq.getCacheBlockIds (onlyWindowSize).at (beamIdx);
3152- EXPECT_THAT (blockIds, ::testing::ElementsAreArray ({0 , 1 , 2 , 3 }));
3153-
3154- // get blocks by hash and try to retrieve them by hash
3155- auto blockHashes = getHashAndRetrieveBlocksByHashTest (blockManager, blockIds, onlyWindowSize);
3156-
3157- EXPECT_NO_THROW (kvCacheManager.removeSequence (requestId, llmRequest));
3158-
3159- // blocks are all removed
3160- for (auto hash : blockHashes)
3161- {
3162- auto range = blockManager.getBlocksByHash (hash, onlyWindowSize);
3163- EXPECT_EQ (range.first , range.second );
3164- }
3165- EXPECT_EQ (blockManager.getNumAllocatedBlocks (), 0 );
3166- }
3167-
3168- TEST_F (KVCacheManagerTest, KVCacheManagerHashKeyWithReuseTest)
3169- {
3170- auto kvCacheManager = setupKvCacheManagerForHashTest (true );
3171-
3172- auto const & blockManager = kvCacheManager.getBlockManager ();
3173-
3174- SizeType32 constexpr maxNewTokens = 4 ;
3175-
3176- // prepare tokens with token[i] = 1000 + i
3177- TokenIdType constexpr firstToken = 1000 ;
3178-
3179- auto constexpr beamWidth = 1 ;
3180- tr::SamplingConfig const samplingConfig{beamWidth};
3181- bool constexpr isStreaming{false };
3182-
3183- SizeType32 requestId = 0 ;
3184- int inputLength = 16 ;
3185- auto inputTokens = std::make_shared<VecTokens>(inputLength);
3186- std::iota (inputTokens->begin (), inputTokens->end (), firstToken);
3187- auto llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3188- auto constexpr beamIdx = 0 ;
3189-
3190- // /////////////////////////////////////////////////////////////////////////
3191- // add a request and then remove it with reuse
3192- kvCacheManager.addSequence (requestId, inputLength, beamWidth, llmRequest);
3193- GenerationRequest const & seq0 = kvCacheManager.getSequence (requestId);
3194- EXPECT_EQ (llmRequest->getContextCurrentPosition (), 0 );
3195-
3196- EXPECT_EQ (blockManager.getNumPools (), 1 );
3197- auto const onlyWindowSize = theOnlyWindowSize (kvCacheManager);
3198-
3199- auto & blockIds0 = seq0.getCacheBlockIds (onlyWindowSize).at (beamIdx);
3200- EXPECT_THAT (blockIds0, ::testing::ElementsAreArray ({0 , 1 , 2 , 3 }));
3201-
3202- // get blocks by hash and try to retrieve them by hash
3203- auto blockHashes = getHashAndRetrieveBlocksByHashTest (blockManager, blockIds0, onlyWindowSize);
3204-
3205- EXPECT_NO_THROW (kvCacheManager.removeSequence (requestId, llmRequest));
3206-
3207- // TODO: Make reused blocks accessible by hash, after sequence removed. Test here.
3208-
3209- // /////////////////////////////////////////////////////////////////////////
3210- // add a new request with same prefix
3211- requestId = 1 ;
3212- inputLength = 20 ;
3213- inputTokens->resize (inputLength);
3214- std::iota (inputTokens->begin (), inputTokens->end (), firstToken);
3215- llmRequest = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming);
3216- kvCacheManager.addSequence (requestId, inputLength, beamWidth, llmRequest);
3217- GenerationRequest const & seq1 = kvCacheManager.getSequence (requestId);
3218- EXPECT_EQ (llmRequest->getContextCurrentPosition (), 15 );
3219- auto & blockIds1 = seq1.getCacheBlockIds (onlyWindowSize).at (beamIdx);
3220- EXPECT_THAT (blockIds1, ::testing::ElementsAreArray ({0 , 1 , 2 , 3 , 4 }));
3221-
3222- std::ignore = getHashAndRetrieveBlocksByHashTest (blockManager, blockIds1, onlyWindowSize);
3223-
3224- // blocks are reused, so reused blocks are still accessible by previous hashes
3225- for (size_t i = 0 ; i < 4 ; i++)
3226- {
3227- auto range = blockManager.getBlocksByHash (blockHashes[i], onlyWindowSize);
3228- EXPECT_NE (range.first , range.second );
3229- }
3230- // evicted block is not accessible
3231- {
3232- size_t i = 4 ;
3233- auto range = blockManager.getBlocksByHash (blockHashes[i], onlyWindowSize);
3234- EXPECT_EQ (range.first , range.second );
3235- }
3236- EXPECT_EQ (blockManager.getNumAllocatedBlocks (), 5 );
3237- }
3238-
32393056TEST_F (KVCacheManagerTest, KVCacheManagerEventStream)
32403057{
32413058 auto constexpr numLayers = 12 ;
0 commit comments