Skip to content

Commit 4237fc5

Browse files
committed
Add LoopbackAgent tests to transferAgentTest
Signed-off-by: Guy Lev <[email protected]>
1 parent e41fa8f commit 4237fc5

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

cpp/tests/unit_tests/executor/transferAgentTest.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
#include <gmock/gmock.h>
1717
#include <gtest/gtest.h>
1818

19+
#include <filesystem>
20+
21+
namespace fs = std::filesystem;
22+
1923
using namespace tensorrt_llm::executor::kv_cache;
2024

2125
class RegisteredHostMemory
@@ -339,3 +343,110 @@ TEST_F(TransferAgentTest, SyncMessage)
339343
nixlAgent0->invalidateRemoteAgent(agent1);
340344
nixlAgent1->invalidateRemoteAgent(agent0);
341345
}
346+
347+
class LoopbackAgentTest : public ::testing::Test,
348+
public ::testing::WithParamInterface<bool> // NOLINT(cppcoreguidelines-pro-type-member-init)
349+
{
350+
public:
351+
void SetUp() override
352+
{
353+
auto dirPath = fs::absolute("test_loopback_agent_tmp");
354+
std::error_code ec;
355+
fs::create_directories(dirPath, ec);
356+
TLLM_CHECK_WITH_INFO(!ec, "Failed to create test directory: %s", ec.message().c_str());
357+
mDirectory = dirPath.string();
358+
}
359+
360+
void TearDown() override
361+
{
362+
std::error_code ec;
363+
fs::remove_all(mDirectory, ec);
364+
if (ec)
365+
std::cerr << "Warning: Failed to clean up test directory: " << ec.message() << std::endl;
366+
}
367+
368+
[[nodiscard]] std::unique_ptr<BaseLoopbackAgent> makeLoopbackAgent(BaseAgentConfig const& config)
369+
{
370+
return tensorrt_llm::executor::kv_cache::makeLoopbackAgent("nixl", &config);
371+
}
372+
373+
[[nodiscard]] std::string getDirectory() const
374+
{
375+
return mDirectory;
376+
}
377+
378+
private:
379+
std::string mDirectory;
380+
};
381+
382+
TEST_P(LoopbackAgentTest, Basic)
383+
{
384+
std::string const agentName{"loopbackAgent"};
385+
BaseAgentConfig config{agentName, true, GetParam()};
386+
auto loopbackAgent = makeLoopbackAgent(config);
387+
388+
TLLM_CHECK(loopbackAgent);
389+
390+
std::vector<char> memory(100, 1);
391+
char* cuda_mem;
392+
TLLM_CUDA_CHECK(cudaMalloc(&cuda_mem, 100));
393+
cudaMemcpy(cuda_mem, memory.data(), 100, cudaMemcpyHostToDevice);
394+
395+
std::vector<FileDesc> fileDescVec;
396+
fileDescVec.emplace_back(getDirectory() + "/basic_test.bin", O_CREAT | O_RDWR, 0664, 100);
397+
std::vector<char> fileData(100, 10);
398+
ssize_t bytesWritten = write(fileDescVec[0].getFd(), fileData.data(), fileData.size());
399+
TLLM_CHECK_WITH_INFO(bytesWritten == fileData.size(), "Failed to write to file");
400+
401+
MemoryDescs memDescs{MemoryType::kVRAM, {MemoryDesc{cuda_mem, 100, 0}}};
402+
FileDescs fileDescs{fileDescVec};
403+
404+
loopbackAgent->registerMemory(memDescs);
405+
loopbackAgent->registerFiles(fileDescs);
406+
407+
auto status = loopbackAgent->submitLoopbackRequests(memDescs, fileDescs, false);
408+
status->wait();
409+
410+
cudaMemcpy(memory.data(), cuda_mem, 100, cudaMemcpyDeviceToHost);
411+
412+
TLLM_CHECK(memory == fileData);
413+
TLLM_CUDA_CHECK(cudaFree(cuda_mem));
414+
415+
loopbackAgent->deregisterMemory(memDescs);
416+
loopbackAgent->deregisterFiles(fileDescs);
417+
}
418+
419+
TEST_P(LoopbackAgentTest, Basic2)
420+
{
421+
std::string const agentName{"loopbackAgent"};
422+
BaseAgentConfig config{agentName, true, GetParam()};
423+
auto loopbackAgent = makeLoopbackAgent(config);
424+
425+
TLLM_CHECK(loopbackAgent);
426+
427+
std::vector<char> memory(100, 1);
428+
MemoryDesc memDesc(memory);
429+
MemoryDescs memDescs{MemoryType::kVRAM, {memDesc}};
430+
431+
std::vector<FileDesc> fileDescVec;
432+
fileDescVec.emplace_back(getDirectory() + "/basic2_test.bin", O_CREAT | O_RDWR, 0664, 100);
433+
434+
FileDescs fileDescs{fileDescVec};
435+
436+
loopbackAgent->registerMemory(memDescs);
437+
loopbackAgent->registerFiles(fileDescs);
438+
439+
auto status = loopbackAgent->submitLoopbackRequests(memDescs, fileDescs, true);
440+
status->wait();
441+
442+
std::vector<char> fileData(100);
443+
ssize_t bytesRead = read(fileDescs.getDescs()[0].getFd(), fileData.data(), fileData.size());
444+
TLLM_CHECK_WITH_INFO(bytesRead == fileData.size(), "Failed to read from file");
445+
446+
TLLM_CHECK(fileData == memory);
447+
448+
loopbackAgent->deregisterMemory(memDescs);
449+
loopbackAgent->deregisterFiles(fileDescs);
450+
}
451+
452+
INSTANTIATE_TEST_SUITE_P(, LoopbackAgentTest, ::testing::Values(true, false));

0 commit comments

Comments
 (0)