Commit 3f77fe7

Add UB NCCL integration
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 174c518 commit 3f77fe7

File tree: 8 files changed (+171, -32 lines)


cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp

Lines changed: 60 additions & 21 deletions
@@ -14,61 +14,100 @@
  * limitations under the License.
  */
 #include "ub_allocator.h"
+#include "tensorrt_llm/common/opUtils.h"
+#include <set>

 namespace tensorrt_llm::runtime::ub
 {
 UserBufferAllocator& UserBufferAllocator::Instance()
 {
-    static UserBufferAllocator _;
-    return _;
+    // if environment variable TLLM_USE_NCCL_UB is set to 1, use NCCLUserBufferAllocator
+    char* useNCCLUB = std::getenv("TLLM_USE_NCCL_UB");
+    if (useNCCLUB != nullptr)
+    {
+        static NCCLUserBufferAllocator _;
+        return _;
+    }
+    else
+    {
+        static UserBufferAllocator _;
+        return _;
+    }
 }

-void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& world_config)
+void UserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
 {
-    if (!is_initialized())
+    if (!isInitialized())
     {
-        ub_comm_ = nullptr;
-        world_config_ = world_config;
-        create_communicator_grouped2(&ub_comm_, world_config_);
-        TLLM_CHECK(ub_comm_ != nullptr);
-        is_initialized_ = true;
+        mUbComm = nullptr;
+        mWorldConfig = worldConfig;
+        create_communicator_grouped2(&mUbComm, worldConfig);
+        TLLM_CHECK(mUbComm != nullptr);
+        mIsInitialized = true;
     }
 }

-bool UserBufferAllocator::is_initialized()
+bool UserBufferAllocator::isInitialized()
 {
-    return is_initialized_;
+    return mIsInitialized;
 }

-UBBuffer UserBufferAllocator::register_ub_buffer(size_t bytes)
+UBBuffer UserBufferAllocator::registerUBBuffer(size_t bytes)
 {
-    TLLM_CHECK(is_initialized());
+    TLLM_CHECK(isInitialized());
     void* addr = nullptr;
     int handle = -1;
-    handle = register_user_buffer_collective((void**) &addr, bytes, ub_comm_);
+    handle = register_user_buffer_collective((void**) &addr, bytes, mUbComm);
     return {addr, handle, bytes};
 }

 UBBuffer UserBufferAllocator::allocate(size_t bytes)
 {
-    TLLM_CHECK(is_initialized());
-    auto ub_buffer = register_ub_buffer(bytes);
+    TLLM_CHECK(isInitialized());
+    auto ub_buffer = registerUBBuffer(bytes);
     TLLM_CHECK(!ub_buffer.invalid());
-    buffers_.push_back(ub_buffer);
+    mBuffers.push_back(ub_buffer);
     return ub_buffer;
 }

 void UserBufferAllocator::deallocate(void* addr) {}

 UBBuffer UserBufferAllocator::get(int idx)
 {
-    TLLM_CHECK(is_initialized() && idx < buffers_.size() && !buffers_[idx].invalid());
-    return buffers_[idx];
+    TLLM_CHECK(isInitialized() && idx < mBuffers.size() && !mBuffers[idx].invalid());
+    return mBuffers[idx];
 }

 communicator* UserBufferAllocator::comm()
 {
-    TLLM_CHECK(is_initialized());
-    return ub_comm_;
+    TLLM_CHECK(isInitialized());
+    return mUbComm;
+}
+
+void NCCLUserBufferAllocator::initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig)
+{
+    if (!isInitialized())
+    {
+        TLLM_LOG_INFO("Initializing NCCLUserBufferAllocator");
+        std::set<int> group;
+        for (int i = 0; i < worldConfig.getSize(); i++)
+        {
+            group.insert(i);
+        }
+        mComm = getComm(group);
+        mIsInitialized = true;
+    }
 }
+
+UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes)
+{
+    TLLM_CHECK(isInitialized());
+    UBBuffer ub_buffer;
+    NCCLCHECK(ncclMemAlloc(&ub_buffer.addr, bytes));
+    NCCLCHECK(ncclCommWindowRegister((*mComm), ub_buffer.addr, bytes, &ub_buffer.window, NCCL_WIN_COLL_SYMMETRIC));
+    ub_buffer.handle = 5;
+    ub_buffer.size = bytes;
+    return ub_buffer;
+}
+
 }; // namespace tensorrt_llm::runtime::ub
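
Note on the NCCL path above: the getenv check only tests that TLLM_USE_NCCL_UB is set (any value selects NCCLUserBufferAllocator), and registerUBBuffer backs each user buffer with NCCL symmetric memory. For readers unfamiliar with that API, the sketch below shows the same registration pattern in isolation. It is illustrative only, not code from this commit: it assumes NCCL >= 2.27 (ncclMemAlloc, ncclCommWindowRegister, NCCL_WIN_COLL_SYMMETRIC) and an already-initialized ncclComm_t, and the helper name and NCCL_CHECK macro are made up for the example.

// Sketch only: allocate a buffer with ncclMemAlloc and register it as a symmetric
// window on an existing communicator, mirroring what the NCCL-backed allocator does.
#include <nccl.h>
#include <cstdio>
#include <cstdlib>

#define NCCL_CHECK(cmd)                                                                                \
    do                                                                                                 \
    {                                                                                                  \
        ncclResult_t r = (cmd);                                                                        \
        if (r != ncclSuccess)                                                                          \
        {                                                                                              \
            std::fprintf(stderr, "NCCL error %s at %s:%d\n", ncclGetErrorString(r), __FILE__, __LINE__); \
            std::exit(1);                                                                              \
        }                                                                                              \
    } while (0)

// `comm` is assumed to be created elsewhere (e.g. via ncclCommInitRank).
void* allocateSymmetricBuffer(ncclComm_t comm, size_t bytes, ncclWindow_t* window)
{
    void* addr = nullptr;
    NCCL_CHECK(ncclMemAlloc(&addr, bytes)); // NCCL-owned device allocation
    NCCL_CHECK(ncclCommWindowRegister(comm, addr, bytes, window, NCCL_WIN_COLL_SYMMETRIC));
    return addr;                            // now usable by collectives on `comm`
}

Symmetric window registration is a collective operation, so every rank must register buffers of the same size in the same order; funneling all registrations through the shared allocator singleton keeps that ordering consistent.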

cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h

Lines changed: 22 additions & 8 deletions
@@ -28,11 +28,13 @@ struct UBBuffer
     void* addr;
     int handle;
     size_t size;
+    ncclWindow_t window;

-    UBBuffer(void* a = nullptr, int h = -1, size_t s = 0)
+    UBBuffer(void* a = nullptr, int h = -1, size_t s = 0, ncclWindow_t w = nullptr)
         : addr(a)
         , handle(h)
         , size(s)
+        , window(w)
     {
     }

@@ -49,21 +51,33 @@ class UserBufferAllocator

     UserBufferAllocator() = default;

-    void initialize(tensorrt_llm::runtime::WorldConfig const& world_config);
-    bool is_initialized();
+    virtual void initialize(tensorrt_llm::runtime::WorldConfig const& worldConfig);
+    bool isInitialized();
     UBBuffer allocate(size_t bytes);
     void deallocate(void* addr);
     UBBuffer get(int idx);
     communicator* comm();
+    virtual UBBuffer registerUBBuffer(size_t bytes);

 private:
-    UBBuffer register_ub_buffer(size_t bytes);
+    communicator* mUbComm;

-    communicator* ub_comm_;
-    std::vector<UBBuffer> buffers_;
-    bool is_initialized_;
-    tensorrt_llm::runtime::WorldConfig world_config_;
+protected:
+    std::vector<UBBuffer> mBuffers;
+    bool mIsInitialized;
+    tensorrt_llm::runtime::WorldConfig mWorldConfig;
 };
+
+class NCCLUserBufferAllocator : public UserBufferAllocator
+{
+public:
+    void initialize(tensorrt_llm::runtime::WorldConfig const& world_config) override;
+    UBBuffer registerUBBuffer(size_t bytes) override;
+
+private:
+    std::shared_ptr<ncclComm_t> mComm;
+};
+
 #else
 using communicator = void;
 #endif
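
The new ncclWindow_t member on UBBuffer keeps the handle returned by ncclCommWindowRegister. This commit does not add a teardown path for NCCL-backed buffers, but a counterpart would look roughly like the sketch below (illustrative only; the function name is hypothetical and NCCL_CHECK is the macro from the previous sketch):

// Sketch: release an NCCL-backed user buffer by deregistering its symmetric
// window first, then freeing the ncclMemAlloc allocation.
void releaseSymmetricBuffer(ncclComm_t comm, void* addr, ncclWindow_t window)
{
    NCCL_CHECK(ncclCommWindowDeregister(comm, window)); // drop the window registration
    NCCL_CHECK(ncclMemFree(addr));                      // free the NCCL allocation
}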

cpp/tensorrt_llm/kernels/userbuffers/ub_interface.cpp

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ void ub_initialize(int tp_size)

 bool ub_is_initialized()
 {
-    return UserBufferAllocator::Instance().is_initialized();
+    return UserBufferAllocator::Instance().isInitialized();
 }

 UBBuffer ub_allocate(size_t bytes)

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 28 additions & 2 deletions
@@ -161,8 +161,10 @@ class AllreduceOp
         size_t size = input.numel();
         size_t seq_len = input.size(0);

-        // If strategy is set to UB, UB must be used as UB impl output is special and cannot be used
-        // by others.
+        if (std::getenv("TLLM_USE_NCCL_UB") && mStrategy == AllReduceStrategyType::UB)
+        {
+            return runNCCLAllReduceUB(input, residual, norm_weight, scale, bias);
+        }
         AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);

         // Log runtime strategy
@@ -299,6 +301,30 @@ class AllreduceOp
         return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
     }

+    std::vector<torch::Tensor> runNCCLAllReduceUB(torch::Tensor const& input,
+        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
+        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
+    {
+
+        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
+        int size = input.numel();
+        auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
+        auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
+        TLLM_CHECK(!ub_buffer0.invalid());
+        auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
+
+        NCCLCHECK(ncclAllReduce(
+            input.data_ptr(), norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
+
+        if (mOp == AllReduceFusionOp::NONE)
+        {
+            return {norm_out};
+        }
+
+        // Treat any other patterns as fallback cases.
+        return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, norm_out);
+    }
+
     std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
         torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
         torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
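
runNCCLAllReduceUB above looks up the input's user buffer, allocates the output tensor from the same user-buffer pool, and issues a single ncclAllReduce on the current CUDA stream; any fusion pattern other than NONE falls through to the existing fallbackRunSubsequentOps. Stripped of the TensorRT-LLM plumbing, the collective reduces to the call sketched below (illustrative only: an fp16 example with an assumed communicator and stream, a count in elements rather than bytes, and the NCCL_CHECK macro from the earlier sketch):

// Sketch: sum-all-reduce `count` fp16 elements from sendBuf into recvBuf on `stream`.
void allReduceHalf(void const* sendBuf, void* recvBuf, size_t count, ncclComm_t comm, cudaStream_t stream)
{
    NCCL_CHECK(ncclAllReduce(sendBuf, recvBuf, count, ncclHalf, ncclSum, comm, stream));
}

When both buffers live in windows registered with NCCL_WIN_COLL_SYMMETRIC, NCCL can serve the call with its symmetric-memory kernels, which is the point of routing allocations through the NCCL-backed allocator.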

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ def get_custom_pass(cls, enable_userbuffers):
         register_ar_residual_norm(cls._custom_pass_instances[0])
         if enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
         ):
+            print("Registering UB patterns", flush=True)
             register_ub_patterns(cls._custom_pass_instances)
         else:
             register_add_norm(cls._custom_pass_instances[0])

tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py

Lines changed: 53 additions & 0 deletions
@@ -180,8 +180,61 @@ def extra_check_fp4_quant_pattern(match: Match) -> bool:
             extra_check=extra_check_fp4_quant_pattern,
         )

+    def register_no_quant_pattern(custom_pass: PatternMatcherPass):
+        input_node = KeywordArg('input')
+        fusion = KeywordArg('fusion_op')
+        trtllm_allreduce_default = CallFunction(
+            torch.ops.trtllm.allreduce.default, input_node,
+            KeywordArg('residual_in'), KeywordArg('gamma'), Ignored(),
+            Ignored(), Ignored(), mapping.tp_group, strategy, fusion,
+            KeywordArg('eps'))
+        no_quant_pattern = MultiOutputPattern([trtllm_allreduce_default])
+
+        def empty_no_quant_pattern(
+            input: torch.Tensor,
+            residual_in: torch.Tensor,
+            gamma: torch.Tensor,
+            eps: float,
+        ):
+            return
+
+        def target_no_quant_pattern(
+            input: torch.Tensor,
+            residual_in: torch.Tensor,
+            gamma: torch.Tensor,
+            eps: float,
+        ):
+            input = torch.ops.trtllm.copy_to_userbuffers(input)
+            all_reduce_output = torch.ops.trtllm.allreduce(
+                input, residual_in, gamma, None, None, None,
+                mapping.tp_group, int(AllReduceStrategy.UB), fusion, eps)
+            finalize_output = torch.ops.trtllm.userbuffers_allreduce_finalize(
+                all_reduce_output[-1], False)
+            return all_reduce_output[0], finalize_output
+
+        def extra_check_no_quant_pattern(match: Match) -> bool:
+            input = match.ctx.pattern_to_node[input_node]
+            if not isinstance(input, torch.fx.graph.Node):
+                return False
+            dtype = input.meta["tensor_meta"].dtype
+            # UB only supports FP16/BF16 input
+            if dtype != torch.float16 and dtype != torch.bfloat16:
+                return False
+            return True
+
+        register_replacement(
+            empty_no_quant_pattern,
+            target_no_quant_pattern,
+            [],
+            fwd_only,
+            custom_pass,
+            search_fn_pattern=no_quant_pattern,
+            extra_check=extra_check_no_quant_pattern,
+        )
+
     register_fp8_quant_pattern(custom_pass)
     register_fp4_quant_pattern(custom_pass)
+    # register_no_quant_pattern(custom_pass)

     def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
         strategy = int(AllReduceStrategy.AUTO)

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+import logging
 import math
 import os
 import threading
@@ -11,6 +12,7 @@
 from tensorrt_llm.plugin.plugin import CustomAllReduceHelper

 _thread_local = threading.local()
+logger = logging.getLogger(__name__)


 def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 4 additions & 0 deletions
@@ -349,12 +349,16 @@ def __init__(
                 self.model.vocab_size_padded)

         self._torch_compile_backend = None
+        print(
+            f"torch_compile_enabled: {pytorch_backend_config.torch_compile_enabled}",
+            flush=True)

         try:
             if pytorch_backend_config.torch_compile_enabled:
                 set_torch_compiling(True)
                 use_ub = pytorch_backend_config.torch_compile_enable_userbuffers and self._init_userbuffers(
                     self.model.config.hidden_size)
+                print(f"use_ub: {use_ub}", flush=True)
                 self._torch_compile_backend = Backend(
                     pytorch_backend_config.torch_compile_inductor_enabled,
                     enable_userbuffers=use_ub,
