Skip to content

Commit 8227616

Browse files
authored
[None][feat] Add NCCL Symmetric Integration for All Reduce (#4500)
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 980929e commit 8227616

File tree

16 files changed

+375
-66
lines changed

16 files changed

+375
-66
lines changed

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t
5656
ONESHOT = 4,
5757
TWOSHOT = 5,
5858
LOWPRECISION = 6,
59+
MNNVL = 7,
60+
NCCL_SYMMETRIC = 8,
5961
};
6062

6163
enum class AllReduceStrategyConfig : int8_t

cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp

Lines changed: 196 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,61 +14,236 @@
1414
* limitations under the License.
1515
*/
1616
#include "ub_allocator.h"
17+
#include "tensorrt_llm/common/opUtils.h"
18+
#include <set>
19+
#include <stdexcept>
1720

1821
namespace tensorrt_llm::runtime::ub
1922
{
2023
// Returns the process-wide allocator object.
//
// The concrete type handed out is chosen by the value of `use_nccl_symmetric`
// at the time of the call. Each flavour lives in its own function-local static
// and is constructed the first time its branch is reached.
UserBufferAllocator& UserBufferAllocator::Instance()
{
    if (!use_nccl_symmetric)
    {
        static UserBufferAllocator defaultAllocator;
        return defaultAllocator;
    }

    static NCCLUserBufferAllocator ncclAllocator;
    return ncclAllocator;
}
2536

26-
// Creates the user-buffer communicator for `worldConfig` on the first call.
// Once the allocator has been marked initialized, further calls return
// immediately without touching any state.
void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
    if (isInitialized())
    {
        return;
    }

    mUbComm = nullptr;
    mWorldConfig = worldConfig;
    create_communicator_grouped2(&mUbComm, worldConfig);
    // The creation call reports its result only through the out-pointer.
    TLLM_CHECK(mUbComm != nullptr);
    mIsInitialized = true;
}
3748

38-
bool UserBufferAllocator::is_initialized()
49+
bool UserBufferAllocator::isInitialized()
3950
{
40-
return is_initialized_;
51+
return mIsInitialized;
4152
}
4253

43-
// Registers a buffer of `bytes` bytes via register_user_buffer_collective()
// on the allocator's communicator.
//
// Returns a UBBuffer built from the address written by the registration call,
// the handle it returned, and the requested size. The result is not validated
// here; allocate() performs that check.
UBBuffer UserBufferAllocator::registerUBBuffer(size_t bytes)
{
    TLLM_CHECK(isInitialized());
    void* addr = nullptr;
    int const handle = register_user_buffer_collective(&addr, bytes, mUbComm);
    return UBBuffer{addr, handle, bytes};
}
5162

5263
// Registers a new buffer of `bytes` bytes, records it in mBuffers and returns it.
// Fails the TLLM_CHECK if the allocator is not initialized or if the
// registration produced an invalid buffer.
UBBuffer UserBufferAllocator::allocate(size_t bytes)
{
    TLLM_CHECK(isInitialized());

    // Virtual call: the derived allocator supplies its own registration path.
    UBBuffer ub_buffer = registerUBBuffer(bytes);
    TLLM_CHECK(!ub_buffer.invalid());

    mBuffers.emplace_back(ub_buffer);
    return ub_buffer;
}
6071

6172
// Currently a no-op: the argument is ignored and no buffer is released.
void UserBufferAllocator::deallocate(void* /*addr*/) {}
6273

6374
// Returns the buffer recorded at position `idx` by an earlier allocate() call.
//
// Preconditions (each enforced with its own TLLM_CHECK so that a failure
// identifies which one was violated):
//   - the allocator has been initialized,
//   - 0 <= idx < number of recorded buffers,
//   - the recorded buffer is not invalid.
UBBuffer UserBufferAllocator::get(int idx)
{
    TLLM_CHECK(isInitialized());

    // Compare in the unsigned domain explicitly instead of relying on the
    // implicit int -> size_t conversion to reject negative indices.
    auto const index = static_cast<size_t>(idx);
    TLLM_CHECK(idx >= 0 && index < mBuffers.size());
    TLLM_CHECK(!mBuffers[index].invalid());
    return mBuffers[index];
}
6879

6980
// Returns the communicator created by initialize().
// Fails the TLLM_CHECK when called before the allocator is initialized.
communicator* UserBufferAllocator::comm()
{
    TLLM_CHECK(isInitialized());
    communicator* const ubComm = mUbComm;
    return ubComm;
}
85+
86+
// Obtains the NCCL communicator for the group {0, ..., worldConfig.getSize() - 1}
// on the first call and marks the allocator initialized. Later calls do nothing.
void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
{
    if (isInitialized())
    {
        return;
    }

    TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator");

    // Every index below the world size is a member of the group.
    std::set<int> group;
    for (int rank = 0; rank < worldConfig.getSize(); ++rank)
    {
        group.emplace(rank);
    }

    mComm = getComm(group);
    mIsInitialized = true;
}
100+
101+
// Allocates `bytes` bytes with the dynamically resolved ncclMemAlloc and
// registers the memory as a symmetric window on this allocator's communicator
// through the dynamically resolved ncclCommWindowRegister.
//
// Returns a UBBuffer carrying the allocated address, the window, the size and
// a fixed handle value.
// Throws (via TLLM_THROW) when the NCCL library or either required symbol is
// unavailable; NCCL call results are passed through NCCLCHECK.
UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes)
{
    TLLM_CHECK(isInitialized());

    auto& ncclHelper = getNCCLHelper();
    if (!ncclHelper.isLoaded())
    {
        TLLM_THROW("NCCL library could not be loaded for dynamic symbol access");
    }

    auto const memAlloc = ncclHelper.getNCCLMemAlloc();
    auto const windowRegister = ncclHelper.getNCCLCommWindowRegister();
    // isLoaded() is set as soon as ncclMemAlloc resolves, so the window
    // registration symbol can still be missing. Calling through a null
    // function pointer would crash; fail with a clear error instead.
    if (memAlloc == nullptr || windowRegister == nullptr)
    {
        TLLM_THROW("Required NCCL symbols (ncclMemAlloc, ncclCommWindowRegister) are not available");
    }

    UBBuffer ub_buffer;
    NCCLCHECK(memAlloc(&ub_buffer.addr, bytes));
    NCCLCHECK(windowRegister((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC));
    // No registration handle is produced on this path; a fixed value is stored.
    // NOTE(review): presumably any value other than the default -1 keeps the
    // buffer from being treated as invalid - confirm against UBBuffer::invalid().
    ub_buffer.handle = 5;
    ub_buffer.size = bytes;
    return ub_buffer;
}
121+
122+
// Static member definitions
123+
std::unique_ptr<NCCLHelper> NCCLUserBufferAllocator::mNCCLHelper = nullptr;
124+
125+
NCCLHelper& NCCLUserBufferAllocator::getNCCLHelper()
126+
{
127+
if (!mNCCLHelper)
128+
{
129+
mNCCLHelper = std::make_unique<NCCLHelper>();
130+
}
131+
return *mNCCLHelper;
132+
}
133+
134+
// NCCLHelper implementation
135+
NCCLHelper::NCCLHelper()
136+
: mLibraryHandle(nullptr)
137+
, mNCCLCommWindowRegister(nullptr)
138+
, mNCCLMemAlloc(nullptr)
139+
, mIsLoaded(false)
140+
{
141+
loadNCCLLibrary();
142+
}
143+
144+
// Releases the library handle acquired in loadNCCLLibrary(), if there is one.
NCCLHelper::~NCCLHelper()
{
    if (mLibraryHandle == nullptr)
    {
        return;
    }

#ifdef _WIN32
    FreeLibrary(mLibraryHandle);
#else
    dlclose(mLibraryHandle);
#endif
    mLibraryHandle = nullptr;
}
156+
157+
void NCCLHelper::loadNCCLLibrary()
158+
{
159+
try
160+
{
161+
#ifdef _WIN32
162+
char const* libraryNames[] = {"nccl.dll"};
163+
#else
164+
char const* libraryNames[] = {"libnccl.so"};
165+
#endif
166+
167+
for (int i = 0; libraryNames[i] != nullptr; ++i)
168+
{
169+
mLibraryHandle = loadLibraryHandle(libraryNames[i]);
170+
if (mLibraryHandle)
171+
{
172+
TLLM_LOG_INFO("Successfully loaded NCCL library: %s", libraryNames[i]);
173+
break;
174+
}
175+
}
176+
177+
if (!mLibraryHandle)
178+
{
179+
TLLM_LOG_WARNING("Failed to load NCCL library");
180+
return;
181+
}
182+
183+
// Load the required symbols
184+
mNCCLCommWindowRegister
185+
= reinterpret_cast<ncclCommWindowRegisterFunc>(getSymbolAddress(mLibraryHandle, "ncclCommWindowRegister"));
186+
187+
mNCCLMemAlloc = reinterpret_cast<ncclMemAllocFunc>(getSymbolAddress(mLibraryHandle, "ncclMemAlloc"));
188+
189+
if (mNCCLCommWindowRegister == nullptr)
190+
{
191+
TLLM_LOG_WARNING("Failed to load ncclCommWindowRegister symbol, NCCL symmetric will not be supported.");
192+
}
193+
194+
if (mNCCLMemAlloc)
195+
{
196+
mIsLoaded = true;
197+
}
198+
else
199+
{
200+
TLLM_LOG_WARNING("Failed to load required NCCL symbols");
201+
}
202+
}
203+
catch (std::exception const& e)
204+
{
205+
TLLM_LOG_WARNING("Exception while loading NCCL library: %s", e.what());
206+
}
207+
}
208+
209+
void* NCCLHelper::loadLibraryHandle(char const* libName)
210+
{
211+
#ifdef _WIN32
212+
return LoadLibraryA(libName);
213+
#else
214+
return dlopen(libName, RTLD_LAZY | RTLD_GLOBAL);
215+
#endif
216+
}
217+
218+
void* NCCLHelper::getSymbolAddress(void* handle, char const* symbolName)
219+
{
220+
if (!handle)
221+
{
222+
return nullptr;
223+
}
224+
225+
#ifdef _WIN32
226+
return GetProcAddress(static_cast<HMODULE>(handle), symbolName);
227+
#else
228+
return dlsym(handle, symbolName);
229+
#endif
230+
}
231+
232+
// Returns the stored ncclCommWindowRegister pointer; it is null if the symbol was not resolved.
NCCLHelper::ncclCommWindowRegisterFunc NCCLHelper::getNCCLCommWindowRegister()
{
    return this->mNCCLCommWindowRegister;
}
236+
237+
// Returns the stored ncclMemAlloc pointer; it is null if the symbol was not resolved.
NCCLHelper::ncclMemAllocFunc NCCLHelper::getNCCLMemAlloc()
{
    return this->mNCCLMemAlloc;
}
241+
242+
bool NCCLHelper::isLoaded() const
243+
{
244+
return mIsLoaded;
245+
}
246+
247+
// Selects which allocator Instance() returns; starts out false (the non-NCCL allocator).
bool UserBufferAllocator::use_nccl_symmetric{false};
248+
74249
}; // namespace tensorrt_llm::runtime::ub

cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@
1414
* limitations under the License.
1515
*/
1616
#pragma once
17+
#include "nccl.h"
1718
#include "tensorrt_llm/runtime/worldConfig.h"
19+
#include <memory>
1820
#if ENABLE_MULTI_DEVICE
1921
#include "userbuffers.h"
22+
#ifdef _WIN32
23+
#include <windows.h>
24+
#else
25+
#include <dlfcn.h>
26+
#endif
2027
#endif
2128

2229
namespace tensorrt_llm::runtime::ub
@@ -28,11 +35,13 @@ struct UBBuffer
2835
void* addr;
2936
int handle;
3037
size_t size;
38+
ncclWindow_t window;
3139

32-
UBBuffer(void* a = nullptr, int h = -1, size_t s = 0)
40+
UBBuffer(void* a = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
3341
: addr(a)
3442
, handle(h)
3543
, size(s)
44+
, window(w)
3645
{
3746
}
3847

@@ -49,21 +58,74 @@ class UserBufferAllocator
4958

5059
UserBufferAllocator() = default;
5160

52-
void initialize(tensorrt_llm::runtime::WorldConfig const& world_config);
53-
bool is_initialized();
61+
virtual void initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig);
62+
bool isInitialized();
5463
UBBuffer allocate(size_t bytes);
5564
void deallocate(void* addr);
5665
UBBuffer get(int idx);
5766
communicator* comm();
67+
virtual UBBuffer registerUBBuffer(size_t bytes);
68+
69+
static bool use_nccl_symmetric;
5870

5971
private:
60-
UBBuffer register_ub_buffer(size_t bytes);
72+
communicator* mUbComm;
6173

62-
communicator* ub_comm_;
63-
std::vector<UBBuffer> buffers_;
64-
bool is_initialized_;
65-
tensorrt_llm::runtime::WorldConfig world_config_;
74+
protected:
75+
std::vector<UBBuffer> mBuffers;
76+
bool mIsInitialized;
77+
tensorrt_llm::runtime::WorldConfig mWorldConfig;
6678
};
79+
80+
class NCCLHelper
81+
{
82+
public:
83+
NCCLHelper();
84+
~NCCLHelper();
85+
86+
// Dynamic loading function type definition
87+
using ncclCommWindowRegisterFunc = ncclResult_t (*)(ncclComm_t, void*, size_t, ncclWindow_t*, int);
88+
using ncclMemAllocFunc = ncclResult_t (*)(void**, size_t);
89+
90+
// Get function pointer for ncclCommWindowRegister
91+
ncclCommWindowRegisterFunc getNCCLCommWindowRegister();
92+
93+
// Get function pointer for ncclMemAlloc
94+
ncclMemAllocFunc getNCCLMemAlloc();
95+
96+
// Check if NCCL library is successfully loaded
97+
bool isLoaded() const;
98+
99+
private:
100+
void loadNCCLLibrary();
101+
void* loadLibraryHandle(char const* libName);
102+
void* getSymbolAddress(void* handle, char const* symbolName);
103+
104+
#ifdef _WIN32
105+
HMODULE mLibraryHandle;
106+
#else
107+
void* mLibraryHandle;
108+
#endif
109+
110+
ncclCommWindowRegisterFunc mNCCLCommWindowRegister;
111+
ncclMemAllocFunc mNCCLMemAlloc;
112+
bool mIsLoaded;
113+
};
114+
115+
class NCCLUserBufferAllocator : public UserBufferAllocator
116+
{
117+
public:
118+
void initialize(tensorrt_llm::runtime::WorldConfig const& world_config) override;
119+
UBBuffer registerUBBuffer(size_t bytes) override;
120+
121+
// Get shared NCCLHelper instance
122+
static NCCLHelper& getNCCLHelper();
123+
124+
private:
125+
std::shared_ptr<ncclComm_t> mComm;
126+
static std::unique_ptr<NCCLHelper> mNCCLHelper;
127+
};
128+
67129
#else
68130
using communicator = void;
69131
#endif

cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void ub_initialize(int tp_size)
3636

3737
bool ub_is_initialized()
3838
{
39-
return UserBufferAllocator::Instance().is_initialized();
39+
return UserBufferAllocator::Instance().isInitialized();
4040
}
4141

4242
UBBuffer ub_allocate(size_t bytes)

0 commit comments

Comments
 (0)