Skip to content

Commit 9661360

Browse files
committed
Enable NCCL symmetric for non-torch compile path
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 1c9374a commit 9661360

File tree

14 files changed

+61
-92
lines changed

14 files changed

+61
-92
lines changed

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t
5656
ONESHOT = 4,
5757
TWOSHOT = 5,
5858
LOWPRECISION = 6,
59+
MNNVL = 7,
60+
NCCL_SYMMETRIC = 8,
5961
};
6062

6163
enum class AllReduceStrategyConfig : int8_t

cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ namespace tensorrt_llm::runtime::ub
2121
{
2222
UserBufferAllocator& UserBufferAllocator::Instance()
2323
{
24-
// if environment variable TLLM_USE_NCCL_UB is set to 1, use NCCLUserBufferAllocator
25-
char* useNCCLUB = std::getenv("TLLM_USE_NCCL_UB");
26-
if (useNCCLUB != nullptr)
24+
if (use_nccl_symmetric)
2725
{
2826
static NCCLUserBufferAllocator _;
2927
return _;
@@ -110,4 +108,6 @@ UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes)
110108
return ub_buffer;
111109
}
112110

111+
bool UserBufferAllocator::use_nccl_symmetric = false;
112+
113113
}; // namespace tensorrt_llm::runtime::ub

cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class UserBufferAllocator
5959
communicator* comm();
6060
virtual UBBuffer registerUBBuffer(size_t bytes);
6161

62+
static bool use_nccl_symmetric;
63+
6264
private:
6365
communicator* mUbComm;
6466

cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ UserBuffersManager& UserBuffersManager::get_instance()
2929
return allocator;
3030
}
3131

32-
void UserBuffersManager::initialize(
33-
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
32+
void UserBuffersManager::initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
33+
int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
3434
{
3535
std::lock_guard<std::mutex> lock(mutex_);
3636
tensorrt_llm::runtime::WorldConfig world_config(tp_size, pp_size, cp_size, rank, gpus_per_node);
37+
UserBufferAllocator::Instance().use_nccl_symmetric = use_nccl_symmetric;
3738
tensorrt_llm::runtime::ub::ub_initialize(world_config);
3839
TLLM_CHECK(tensorrt_llm::runtime::ub::ub_is_initialized());
3940
buffer_size_ = buffer_size;
@@ -95,10 +96,11 @@ tensorrt_llm::runtime::ub::communicator* UserBuffersManager::comm()
9596
return tensorrt_llm::runtime::ub::ub_comm();
9697
}
9798

98-
void initialize_userbuffers_manager(
99-
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
99+
void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
100+
int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
100101
{
101-
UserBuffersManager::get_instance().initialize(tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size);
102+
UserBuffersManager::get_instance().initialize(
103+
tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size, use_nccl_symmetric);
102104
}
103105

104106
} // namespace tensorrt_llm::runtime::ub

cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ class UserBuffersManager
4646
//! @param gpus_per_node The number of GPUs per node.
4747
//! @param buffer_size The size of the buffer to allocate. All buffers allocated by this manager will have this
4848
//! size.
49-
void initialize(
50-
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
49+
//! @param use_nccl_symmetric Whether to use NCCL symmetric communication.
50+
void initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node,
51+
int64_t buffer_size, bool use_nccl_symmetric);
5152

5253
//! @brief Create a UB tensor from the given shape, strides and data type. The function will choose available UB
5354
//! buffer or create a new one if no available buffer is found.
@@ -75,7 +76,7 @@ class UserBuffersManager
7576
int64_t buffer_size_;
7677
};
7778

78-
void initialize_userbuffers_manager(
79-
int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
79+
void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
80+
int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric);
8081

8182
} // namespace tensorrt_llm::runtime::ub

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,8 @@ void initBindings(pybind11::module_& m)
448448
.value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO)
449449
.value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB)
450450
.value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT)
451-
.value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT);
451+
.value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT)
452+
.value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC);
452453

453454
// Initialize MoeLoadBalancer bindings
454455
initMoeBindings(m);

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,6 @@ class AllreduceOp
166166
size_t bytes_per_element = input.element_size();
167167
TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);
168168

169-
if (std::getenv("TLLM_USE_NCCL_UB") && mStrategy == AllReduceStrategyType::UB)
170-
{
171-
return runNCCLAllReduceUB(input, residual, norm_weight, scale, bias);
172-
}
173169
AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);
174170

175171
// Log runtime strategy
@@ -181,6 +177,8 @@ class AllreduceOp
181177
{
182178
case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
183179
case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
180+
case AllReduceStrategyType::NCCL_SYMMETRIC:
181+
return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias);
184182
case AllReduceStrategyType::MIN_LATENCY:
185183
case AllReduceStrategyType::ONESHOT:
186184
case AllReduceStrategyType::TWOSHOT:
@@ -307,7 +305,7 @@ class AllreduceOp
307305
return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
308306
}
309307

310-
std::vector<torch::Tensor> runNCCLAllReduceUB(torch::Tensor const& input,
308+
std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
311309
torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
312310
torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
313311
{
@@ -316,11 +314,20 @@ class AllreduceOp
316314
int size = input.numel();
317315
auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
318316
auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
317+
if (ub_buffer0.invalid())
318+
{
319+
auto [symmetric_input, symmetric_ub_buffer0]
320+
= torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
321+
cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(),
322+
cudaMemcpyDeviceToDevice, stream);
323+
ub_buffer0 = symmetric_ub_buffer0;
324+
}
325+
319326
TLLM_CHECK(!ub_buffer0.invalid());
320327
auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
321328

322329
NCCLCHECK(ncclAllReduce(
323-
input.data_ptr(), norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
330+
ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
324331

325332
if (mOp == AllReduceFusionOp::NONE)
326333
{
@@ -661,6 +668,10 @@ class AllreduceOp
661668
{
662669
runtime_strategy = AllReduceStrategyType::NCCL;
663670
}
671+
else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
672+
{
673+
runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC;
674+
}
664675
else
665676
{
666677
// This is for DEBUG and BENCHMARK purpose. It will overried the strategy if AUTO is set.
@@ -686,6 +697,11 @@ class AllreduceOp
686697
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank);
687698
break;
688699
}
700+
case AllReduceStrategyType::NCCL_SYMMETRIC:
701+
{
702+
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank);
703+
break;
704+
}
689705
case AllReduceStrategyType::MIN_LATENCY:
690706
{
691707
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank);
@@ -701,7 +717,7 @@ class AllreduceOp
701717
TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
702718
break;
703719
}
704-
default: break;
720+
default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break;
705721
}
706722
}
707723

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def get_custom_pass(cls, enable_userbuffers):
6666
register_ar_residual_norm(cls._custom_pass_instances[0])
6767
if enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
6868
):
69-
print("Registering UB patterns", flush=True)
7069
register_ub_patterns(cls._custom_pass_instances)
7170
else:
7271
register_add_norm(cls._custom_pass_instances[0])

tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -184,61 +184,8 @@ def extra_check_fp4_quant_pattern(match: Match) -> bool:
184184
extra_check=extra_check_fp4_quant_pattern,
185185
)
186186

187-
def register_no_quant_pattern(custom_pass: PatternMatcherPass):
188-
input_node = KeywordArg('input')
189-
fusion = KeywordArg('fusion_op')
190-
trtllm_allreduce_default = CallFunction(
191-
torch.ops.trtllm.allreduce.default, input_node,
192-
KeywordArg('residual_in'), KeywordArg('gamma'), Ignored(),
193-
Ignored(), Ignored(), mapping.tp_group, strategy, fusion,
194-
KeywordArg('eps'))
195-
no_quant_pattern = MultiOutputPattern([trtllm_allreduce_default])
196-
197-
def empty_no_quant_pattern(
198-
input: torch.Tensor,
199-
residual_in: torch.Tensor,
200-
gamma: torch.Tensor,
201-
eps: float,
202-
):
203-
return
204-
205-
def target_no_quant_pattern(
206-
input: torch.Tensor,
207-
residual_in: torch.Tensor,
208-
gamma: torch.Tensor,
209-
eps: float,
210-
):
211-
input = torch.ops.trtllm.copy_to_userbuffers(input)
212-
all_reduce_output = torch.ops.trtllm.allreduce(
213-
input, residual_in, gamma, None, None, None,
214-
mapping.tp_group, int(AllReduceStrategy.UB), fusion, eps)
215-
finalize_output = torch.ops.trtllm.userbuffers_allreduce_finalize(
216-
all_reduce_output[-1], False)
217-
return all_reduce_output[0], finalize_output
218-
219-
def extra_check_no_quant_pattern(match: Match) -> bool:
220-
input = match.ctx.pattern_to_node[input_node]
221-
if not isinstance(input, torch.fx.graph.Node):
222-
return False
223-
dtype = input.meta["tensor_meta"].dtype
224-
# UB only supports FP16/BF16 input
225-
if dtype != torch.float16 and dtype != torch.bfloat16:
226-
return False
227-
return True
228-
229-
register_replacement(
230-
empty_no_quant_pattern,
231-
target_no_quant_pattern,
232-
[],
233-
fwd_only,
234-
custom_pass,
235-
search_fn_pattern=no_quant_pattern,
236-
extra_check=extra_check_no_quant_pattern,
237-
)
238-
239187
register_fp8_quant_pattern(custom_pass)
240188
register_fp4_quant_pattern(custom_pass)
241-
# register_no_quant_pattern(custom_pass)
242189

243190
def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
244191
strategy = int(AllReduceStrategy.AUTO)

tensorrt_llm/_torch/model_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
125125
"ONESHOT": AllReduceStrategy.ONESHOT,
126126
"TWOSHOT": AllReduceStrategy.TWOSHOT,
127127
"LOWPRECISION": AllReduceStrategy.LOWPRECISION,
128-
"MNNVL": AllReduceStrategy.MNNVL
128+
"MNNVL": AllReduceStrategy.MNNVL,
129+
"NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
129130
}
130131
key = strategy.upper()
131132
return maps[key] if key in maps else AllReduceStrategy.AUTO

0 commit comments

Comments
 (0)