1818#include " tensorrt_llm/common/assert.h"
1919#include " tensorrt_llm/common/cudaUtils.h"
2020#include " tensorrt_llm/common/memoryUtils.h"
21+ #include " tensorrt_llm/executor/transferAgent.h"
2122#include " tensorrt_llm/executor/types.h"
2223#include " tensorrt_llm/kernels/kvCacheIndex.h"
2324#include " tensorrt_llm/kernels/kvCacheUtils.h"
3233#include < chrono>
3334#include < cmath>
3435#include < cstddef>
36+ #include < filesystem>
3537#include < memory>
3638#include < set>
3739#include < thread>
@@ -45,6 +47,7 @@ namespace tk = tensorrt_llm::kernels;
4547namespace tlk = tensorrt_llm::batch_manager::kv_cache_manager;
4648namespace tle = tensorrt_llm::executor;
4749namespace tr = tensorrt_llm::runtime;
50+ namespace fs = std::filesystem;
4851
4952using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
5053
@@ -178,7 +181,39 @@ TEST_F(KVCacheManagerTest, BlockManagerTest)
178181 blockManager.addSequence (seq3, numBlocksPerBeam, numBlocksPerBeam - 1 , maxAttentionWindow), std::runtime_error);
179182}
180183
181- template <typename T, nvinfer1::DataType type, int mask>
184+ template <typename T>
185+ void writePatternToOffloadedBlocksDRAM (T* rawBlockPtr, int blockSize, int mask)
186+ {
187+ for (int i = 0 ; i < blockSize; ++i)
188+ {
189+ rawBlockPtr[i] = i & mask;
190+ }
191+ }
192+
193+ template <typename T>
194+ void writePatternToOffloadedBlocksGDS (
195+ std::string const & directory, int blockId, SizeType32 numPools, int blockSize, int mask)
196+ {
197+ for (size_t poolIdx = 0 ; poolIdx < numPools; ++poolIdx)
198+ {
199+ std::string filename
200+ = directory + " /block_" + std::to_string (blockId) + " _pool_" + std::to_string (poolIdx) + " .bin" ;
201+ int fd = ::open (filename.c_str (), O_WRONLY);
202+ if (fd >= 0 )
203+ {
204+ auto poolBlockSize = blockSize / numPools;
205+ std::vector<T> buffer (poolBlockSize);
206+ for (int i = 0 ; i < poolBlockSize; ++i)
207+ {
208+ buffer[i] = i & mask;
209+ }
210+ ::write (fd, buffer.data(), poolBlockSize * sizeof(T));
211+ ::close (fd);
212+ }
213+ }
214+ }
215+
216+ template <typename T, nvinfer1::DataType type, int mask, KvCacheTransferMode transferMode>
182217void runPartialCopyTest ()
183218{
184219 auto constexpr numLayers = 12 ;
@@ -199,6 +234,10 @@ void runPartialCopyTest()
199234 auto constexpr sinkTokenLen = 0 ;
200235 auto constexpr canUseOneMoreBlock = true ;
201236
237+ auto dirPath = fs::absolute (" test_partial_copy_tmp" );
238+ fs::create_directories (dirPath);
239+ std::string directory = dirPath.string ();
240+
202241 SizeType32 constexpr maxNewTokens{0 };
203242 auto constexpr beamWidth = 1 ;
204243 auto constexpr beamIdx = 0 ;
@@ -252,20 +291,27 @@ void runPartialCopyTest()
252291 auto block = blockManager.getBlockById (cacheBlockId, maxAttentionWindow);
253292 EXPECT_TRUE (block->isPrimary ());
254293 // offload so we can write to block in CPU code
255- blockManager.offloadBlock (block, maxAttentionWindow);
294+ blockManager.offloadBlock (block, maxAttentionWindow, transferMode, directory );
256295 EXPECT_FALSE (block->isPrimary ());
257296 // need to sync so D2H transfer is done before accessing blocks
258297 EXPECT_EQ (cudaDeviceSynchronize (), cudaSuccess);
259298 // fill with predictable pattern
260299 auto memoryPoolIndex = block->getMemoryPoolBlockIndex ();
261300 auto blockPtr{tr::ITensor::slice (secondaryPoolPtr, memoryPoolIndex, 1 )};
262301 auto rawBlockPtr = reinterpret_cast <T*>(blockPtr->data ());
263- for (int i = 0 ; i < blockSize; ++i)
302+ // Write value
303+ if constexpr (transferMode == KvCacheTransferMode::DRAM)
304+ {
305+ writePatternToOffloadedBlocksDRAM<T>(rawBlockPtr, blockSize, mask);
306+ }
307+ else if constexpr (transferMode == KvCacheTransferMode::GDS)
264308 {
265- rawBlockPtr[i] = i & mask;
309+ auto block_id = block->getBlockId ();
310+ auto numPools = blockManager.getNumPools (false );
311+ writePatternToOffloadedBlocksGDS<T>(directory, block_id, numPools, blockSize, mask);
266312 }
267313 // onboard
268- blockManager.onboardBlock (block, maxAttentionWindow);
314+ blockManager.onboardBlock (block, maxAttentionWindow, transferMode, directory );
269315 EXPECT_TRUE (block->isPrimary ());
270316 EXPECT_EQ (cudaDeviceSynchronize (), cudaSuccess);
271317 EXPECT_TRUE (blockManager.verifyQueueIntegrity (maxAttentionWindow));
@@ -340,60 +386,71 @@ void runPartialCopyTest()
340386 }
341387 }
342388 EXPECT_EQ (numBad, 0 );
343- blockManager.onboardBlock (block2, maxAttentionWindow);
389+ blockManager.onboardBlock (block2, maxAttentionWindow, transferMode, directory );
344390 EXPECT_TRUE (block2->isPrimary ());
345391 EXPECT_EQ (cudaDeviceSynchronize (), cudaSuccess);
346392
347393 blockManager.releaseBlocks (seq1, llmRequest1);
348394 blockManager.releaseBlocks (seq2, llmRequest2);
395+
396+ fs::remove_all (directory);
349397}
350398
351399TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyINT64)
352400{
353- runPartialCopyTest<std::uint64_t , nvinfer1::DataType::kINT64 , -1 >();
401+ runPartialCopyTest<std::uint64_t , nvinfer1::DataType::kINT64 , -1 , KvCacheTransferMode::DRAM>();
402+ runPartialCopyTest<std::uint64_t , nvinfer1::DataType::kINT64 , -1 , KvCacheTransferMode::GDS>();
354403}
355404
356405TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyINT32)
357406{
358- runPartialCopyTest<std::uint32_t , nvinfer1::DataType::kINT32 , -1 >();
407+ runPartialCopyTest<std::uint32_t , nvinfer1::DataType::kINT32 , -1 , KvCacheTransferMode::DRAM>();
408+ runPartialCopyTest<std::uint32_t , nvinfer1::DataType::kINT32 , -1 , KvCacheTransferMode::GDS>();
359409}
360410
361411TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyFLOAT)
362412{
363- runPartialCopyTest<std::uint32_t , nvinfer1::DataType::kFLOAT , -1 >();
413+ runPartialCopyTest<std::uint32_t , nvinfer1::DataType::kFLOAT , -1 , KvCacheTransferMode::DRAM>();
414+ runPartialCopyTest<std::uint32_t , nvinfer1::DataType::kFLOAT , -1 , KvCacheTransferMode::GDS>();
364415}
365416
366417#ifdef ENABLE_BF16
367418TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyBF16)
368419{
369- runPartialCopyTest<std::uint16_t , nvinfer1::DataType::kBF16 , 65535 >();
420+ runPartialCopyTest<std::uint16_t , nvinfer1::DataType::kBF16 , 65535 , KvCacheTransferMode::DRAM>();
421+ runPartialCopyTest<std::uint16_t , nvinfer1::DataType::kBF16 , 65535 , KvCacheTransferMode::GDS>();
370422}
371423#endif
372424
373425TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyHALF)
374426{
375- runPartialCopyTest<std::uint16_t , nvinfer1::DataType::kHALF , 65535 >();
427+ runPartialCopyTest<std::uint16_t , nvinfer1::DataType::kHALF , 65535 , KvCacheTransferMode::DRAM>();
428+ runPartialCopyTest<std::uint16_t , nvinfer1::DataType::kHALF , 65535 , KvCacheTransferMode::GDS>();
376429}
377430
378431TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyBOOL)
379432{
380- runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kBOOL , 255 >();
433+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kBOOL , 255 , KvCacheTransferMode::DRAM>();
434+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kBOOL , 255 , KvCacheTransferMode::GDS>();
381435}
382436
383437TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyUINT8)
384438{
385- runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kUINT8 , 255 >();
439+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kUINT8 , 255 , KvCacheTransferMode::DRAM>();
440+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kUINT8 , 255 , KvCacheTransferMode::GDS>();
386441}
387442
388443TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyINT8)
389444{
390- runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kINT8 , 255 >();
445+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kINT8 , 255 , KvCacheTransferMode::DRAM>();
446+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kINT8 , 255 , KvCacheTransferMode::GDS>();
391447}
392448
393449#ifdef ENABLE_FP8
394450TEST_F (KVCacheManagerTest, BlockManagerTestPartialCopyFP8)
395451{
396- runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kFP8 , 255 >();
452+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kFP8 , 255 , KvCacheTransferMode::DRAM>();
453+ runPartialCopyTest<std::uint8_t , nvinfer1::DataType::kFP8 , 255 , KvCacheTransferMode::GDS>();
397454}
398455#endif
399456
0 commit comments