|
14 | 14 | * limitations under the License. |
15 | 15 | */ |
16 | 16 | #include "ub_allocator.h" |
| 17 | +#include "tensorrt_llm/common/opUtils.h" |
| 18 | +#include <set> |
17 | 19 |
|
18 | 20 | namespace tensorrt_llm::runtime::ub |
19 | 21 | { |
20 | 22 | UserBufferAllocator& UserBufferAllocator::Instance() |
21 | 23 | { |
22 | | - static UserBufferAllocator _; |
23 | | - return _; |
| 24 | + // if environment variable TLLM_USE_NCCL_UB is set to 1, use NCCLUserBufferAllocator |
| 25 | + char* useNCCLUB = std::getenv("TLLM_USE_NCCL_UB"); |
| 26 | + if (useNCCLUB != nullptr) |
| 27 | + { |
| 28 | + static NCCLUserBufferAllocator _; |
| 29 | + return _; |
| 30 | + } |
| 31 | + else |
| 32 | + { |
| 33 | + static UserBufferAllocator _; |
| 34 | + return _; |
| 35 | + } |
24 | 36 | } |
25 | 37 |
|
26 | | -void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& world_config) |
| 38 | +void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig) |
27 | 39 | { |
28 | | - if (!is_initialized()) |
| 40 | + if (!isInitialized()) |
29 | 41 | { |
30 | | - ub_comm_ = nullptr; |
31 | | - world_config_ = world_config; |
32 | | - create_communicator_grouped2(&ub_comm_, world_config_); |
33 | | - TLLM_CHECK(ub_comm_ != nullptr); |
34 | | - is_initialized_ = true; |
| 42 | + mUbComm = nullptr; |
| 43 | + mWorldConfig = worldConfig; |
| 44 | + create_communicator_grouped2(&mUbComm, worldConfig); |
| 45 | + TLLM_CHECK(mUbComm != nullptr); |
| 46 | + mIsInitialized = true; |
35 | 47 | } |
36 | 48 | } |
37 | 49 |
|
38 | | -bool UserBufferAllocator::is_initialized() |
| 50 | +bool UserBufferAllocator::isInitialized() |
39 | 51 | { |
40 | | - return is_initialized_; |
| 52 | + return mIsInitialized; |
41 | 53 | } |
42 | 54 |
|
43 | | -UBBuffer UserBufferAllocator::register_ub_buffer(size_t bytes) |
| 55 | +UBBuffer UserBufferAllocator::registerUBBuffer(size_t bytes) |
44 | 56 | { |
45 | | - TLLM_CHECK(is_initialized()); |
| 57 | + TLLM_CHECK(isInitialized()); |
46 | 58 | void* addr = nullptr; |
47 | 59 | int handle = -1; |
48 | | - handle = register_user_buffer_collective((void**) &addr, bytes, ub_comm_); |
| 60 | + handle = register_user_buffer_collective((void**) &addr, bytes, mUbComm); |
49 | 61 | return {addr, handle, bytes}; |
50 | 62 | } |
51 | 63 |
|
52 | 64 | UBBuffer UserBufferAllocator::allocate(size_t bytes) |
53 | 65 | { |
54 | | - TLLM_CHECK(is_initialized()); |
55 | | - auto ub_buffer = register_ub_buffer(bytes); |
| 66 | + TLLM_CHECK(isInitialized()); |
| 67 | + auto ub_buffer = registerUBBuffer(bytes); |
56 | 68 | TLLM_CHECK(!ub_buffer.invalid()); |
57 | | - buffers_.push_back(ub_buffer); |
| 69 | + mBuffers.push_back(ub_buffer); |
58 | 70 | return ub_buffer; |
59 | 71 | } |
60 | 72 |
|
61 | 73 | void UserBufferAllocator::deallocate(void* addr) {} |
62 | 74 |
|
63 | 75 | UBBuffer UserBufferAllocator::get(int idx) |
64 | 76 | { |
65 | | - TLLM_CHECK(is_initialized() && idx < buffers_.size() && !buffers_[idx].invalid()); |
66 | | - return buffers_[idx]; |
| 77 | + TLLM_CHECK(isInitialized() && idx < mBuffers.size() && !mBuffers[idx].invalid()); |
| 78 | + return mBuffers[idx]; |
67 | 79 | } |
68 | 80 |
|
69 | 81 | communicator* UserBufferAllocator::comm() |
70 | 82 | { |
71 | | - TLLM_CHECK(is_initialized()); |
72 | | - return ub_comm_; |
| 83 | + TLLM_CHECK(isInitialized()); |
| 84 | + return mUbComm; |
| 85 | +} |
| 86 | + |
| 87 | +void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig) |
| 88 | +{ |
| 89 | + if (!isInitialized()) |
| 90 | + { |
| 91 | + TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator"); |
| 92 | + std::set<int> group; |
| 93 | + for (int i = 0; i < worldConfig.getSize(); i++) |
| 94 | + { |
| 95 | + group.insert(i); |
| 96 | + } |
| 97 | + mComm = getComm(group); |
| 98 | + mIsInitialized = true; |
| 99 | + } |
73 | 100 | } |
| 101 | + |
| 102 | +UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes) |
| 103 | +{ |
| 104 | + TLLM_CHECK(isInitialized()); |
| 105 | + UBBuffer ub_buffer; |
| 106 | + NCCLCHECK(ncclMemAlloc(&ub_buffer.addr, bytes)); |
| 107 | + NCCLCHECK(ncclCommWindowRegister((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC)); |
| 108 | + ub_buffer.handle = 5; |
| 109 | + ub_buffer.size = bytes; |
| 110 | + return ub_buffer; |
| 111 | +} |
| 112 | + |
74 | 113 | }; // namespace tensorrt_llm::runtime::ub |
0 commit comments