|
16 | 16 | #include <gmock/gmock.h> |
17 | 17 | #include <gtest/gtest.h> |
18 | 18 |
|
| 19 | +#include <filesystem> |
| 20 | + |
| 21 | +namespace fs = std::filesystem; |
| 22 | + |
19 | 23 | using namespace tensorrt_llm::executor::kv_cache; |
20 | 24 |
|
21 | 25 | class RegisteredHostMemory |
@@ -339,3 +343,110 @@ TEST_F(TransferAgentTest, SyncMessage) |
339 | 343 | nixlAgent0->invalidateRemoteAgent(agent1); |
340 | 344 | nixlAgent1->invalidateRemoteAgent(agent0); |
341 | 345 | } |
| 346 | + |
| 347 | +class LoopbackAgentTest : public ::testing::Test, |
| 348 | + public ::testing::WithParamInterface<bool> // NOLINT(cppcoreguidelines-pro-type-member-init) |
| 349 | +{ |
| 350 | +public: |
| 351 | + void SetUp() override |
| 352 | + { |
| 353 | + auto dirPath = fs::absolute("test_loopback_agent_tmp"); |
| 354 | + std::error_code ec; |
| 355 | + fs::create_directories(dirPath, ec); |
| 356 | + TLLM_CHECK_WITH_INFO(!ec, "Failed to create test directory: %s", ec.message().c_str()); |
| 357 | + mDirectory = dirPath.string(); |
| 358 | + } |
| 359 | + |
| 360 | + void TearDown() override |
| 361 | + { |
| 362 | + std::error_code ec; |
| 363 | + fs::remove_all(mDirectory, ec); |
| 364 | + if (ec) |
| 365 | + std::cerr << "Warning: Failed to clean up test directory: " << ec.message() << std::endl; |
| 366 | + } |
| 367 | + |
| 368 | + [[nodiscard]] std::unique_ptr<BaseLoopbackAgent> makeLoopbackAgent(BaseAgentConfig const& config) |
| 369 | + { |
| 370 | + return tensorrt_llm::executor::kv_cache::makeLoopbackAgent("nixl", &config); |
| 371 | + } |
| 372 | + |
| 373 | + [[nodiscard]] std::string getDirectory() const |
| 374 | + { |
| 375 | + return mDirectory; |
| 376 | + } |
| 377 | + |
| 378 | +private: |
| 379 | + std::string mDirectory; |
| 380 | +}; |
| 381 | + |
| 382 | +TEST_P(LoopbackAgentTest, Basic) |
| 383 | +{ |
| 384 | + std::string const agentName{"loopbackAgent"}; |
| 385 | + BaseAgentConfig config{agentName, true, GetParam()}; |
| 386 | + auto loopbackAgent = makeLoopbackAgent(config); |
| 387 | + |
| 388 | + TLLM_CHECK(loopbackAgent); |
| 389 | + |
| 390 | + std::vector<char> memory(100, 1); |
| 391 | + char* cuda_mem; |
| 392 | + TLLM_CUDA_CHECK(cudaMalloc(&cuda_mem, 100)); |
| 393 | + cudaMemcpy(cuda_mem, memory.data(), 100, cudaMemcpyHostToDevice); |
| 394 | + |
| 395 | + std::vector<FileDesc> fileDescVec; |
| 396 | + fileDescVec.emplace_back(getDirectory() + "/basic_test.bin", O_CREAT | O_RDWR, 0664, 100); |
| 397 | + std::vector<char> fileData(100, 10); |
| 398 | + ssize_t bytesWritten = write(fileDescVec[0].getFd(), fileData.data(), fileData.size()); |
| 399 | + TLLM_CHECK_WITH_INFO(bytesWritten == fileData.size(), "Failed to write to file"); |
| 400 | + |
| 401 | + MemoryDescs memDescs{MemoryType::kVRAM, {MemoryDesc{cuda_mem, 100, 0}}}; |
| 402 | + FileDescs fileDescs{fileDescVec}; |
| 403 | + |
| 404 | + loopbackAgent->registerMemory(memDescs); |
| 405 | + loopbackAgent->registerFiles(fileDescs); |
| 406 | + |
| 407 | + auto status = loopbackAgent->submitLoopbackRequests(memDescs, fileDescs, false); |
| 408 | + status->wait(); |
| 409 | + |
| 410 | + cudaMemcpy(memory.data(), cuda_mem, 100, cudaMemcpyDeviceToHost); |
| 411 | + |
| 412 | + TLLM_CHECK(memory == fileData); |
| 413 | + TLLM_CUDA_CHECK(cudaFree(cuda_mem)); |
| 414 | + |
| 415 | + loopbackAgent->deregisterMemory(memDescs); |
| 416 | + loopbackAgent->deregisterFiles(fileDescs); |
| 417 | +} |
| 418 | + |
| 419 | +TEST_P(LoopbackAgentTest, Basic2) |
| 420 | +{ |
| 421 | + std::string const agentName{"loopbackAgent"}; |
| 422 | + BaseAgentConfig config{agentName, true, GetParam()}; |
| 423 | + auto loopbackAgent = makeLoopbackAgent(config); |
| 424 | + |
| 425 | + TLLM_CHECK(loopbackAgent); |
| 426 | + |
| 427 | + std::vector<char> memory(100, 1); |
| 428 | + MemoryDesc memDesc(memory); |
| 429 | + MemoryDescs memDescs{MemoryType::kVRAM, {memDesc}}; |
| 430 | + |
| 431 | + std::vector<FileDesc> fileDescVec; |
| 432 | + fileDescVec.emplace_back(getDirectory() + "/basic2_test.bin", O_CREAT | O_RDWR, 0664, 100); |
| 433 | + |
| 434 | + FileDescs fileDescs{fileDescVec}; |
| 435 | + |
| 436 | + loopbackAgent->registerMemory(memDescs); |
| 437 | + loopbackAgent->registerFiles(fileDescs); |
| 438 | + |
| 439 | + auto status = loopbackAgent->submitLoopbackRequests(memDescs, fileDescs, true); |
| 440 | + status->wait(); |
| 441 | + |
| 442 | + std::vector<char> fileData(100); |
| 443 | + ssize_t bytesRead = read(fileDescs.getDescs()[0].getFd(), fileData.data(), fileData.size()); |
| 444 | + TLLM_CHECK_WITH_INFO(bytesRead == fileData.size(), "Failed to read from file"); |
| 445 | + |
| 446 | + TLLM_CHECK(fileData == memory); |
| 447 | + |
| 448 | + loopbackAgent->deregisterMemory(memDescs); |
| 449 | + loopbackAgent->deregisterFiles(fileDescs); |
| 450 | +} |
| 451 | + |
| 452 | +INSTANTIATE_TEST_SUITE_P(, LoopbackAgentTest, ::testing::Values(true, false)); |
0 commit comments