Commit ae752b7

Enable NCCL symmetric for non-torch compile path
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent d8614d1 · commit ae752b7

13 files changed: +61 additions, -94 deletions

cpp/tensorrt_llm/kernels/customAllReduceKernels.h

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t
     ONESHOT = 4,
     TWOSHOT = 5,
     LOWPRECISION = 6,
+    MNNVL = 7,
+    NCCL_SYMMETRIC = 8,
 };

 enum class AllReduceStrategyConfig : int8_t

cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp

Lines changed: 3 additions & 3 deletions
@@ -21,9 +21,7 @@ namespace tensorrt_llm::runtime::ub
 {
 UserBufferAllocator& UserBufferAllocator::Instance()
 {
-    // if environment variable TLLM_USE_NCCL_UB is set to 1, use NCCLUserBufferAllocator
-    char* useNCCLUB = std::getenv("TLLM_USE_NCCL_UB");
-    if (useNCCLUB != nullptr)
+    if (use_nccl_symmetric)
     {
         static NCCLUserBufferAllocator _;
         return _;
@@ -110,4 +108,6 @@ UBBuffer NCCLUserBufferAllocator::registerUBBuffer(size_t bytes)
     return ub_buffer;
 }

+bool UserBufferAllocator::use_nccl_symmetric = false;
+
 }; // namespace tensorrt_llm::runtime::ub
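
The allocator is now a flag-gated singleton: Instance() branches on a static boolean that defaults to false, instead of reading TLLM_USE_NCCL_UB from the environment. A Python model of that control flow, purely illustrative (the real types are the C++ classes above):

class UserBufferAllocator:
    """Stand-in for the C++ base allocator."""
    # Mirrors the new static member; defaults to False, as in ub_allocator.cpp.
    use_nccl_symmetric = False


class NCCLUserBufferAllocator(UserBufferAllocator):
    """Stand-in for the NCCL-backed allocator."""


_default_allocator = UserBufferAllocator()
_nccl_allocator = NCCLUserBufferAllocator()


def instance() -> UserBufferAllocator:
    # Equivalent of UserBufferAllocator::Instance(): the flag picks which
    # singleton is handed out, so it must be set before the first real use.
    if UserBufferAllocator.use_nccl_symmetric:
        return _nccl_allocator
    return _default_allocator

userbuffersManager.cpp (below) writes the flag before calling ub_initialize, which is what guarantees the first allocator lookup already sees the requested mode.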

cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.h

Lines changed: 2 additions & 0 deletions
@@ -59,6 +59,8 @@ class UserBufferAllocator
     communicator* comm();
     virtual UBBuffer registerUBBuffer(size_t bytes);

+    static bool use_nccl_symmetric;
+
 private:
     communicator* mUbComm;

cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.cpp

Lines changed: 7 additions & 5 deletions
@@ -29,11 +29,12 @@ UserBuffersManager& UserBuffersManager::get_instance()
     return allocator;
 }

-void UserBuffersManager::initialize(
-    int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
+void UserBuffersManager::initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
+    int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
 {
     std::lock_guard<std::mutex> lock(mutex_);
     tensorrt_llm::runtime::WorldConfig world_config(tp_size, pp_size, cp_size, rank, gpus_per_node);
+    UserBufferAllocator::Instance().use_nccl_symmetric = use_nccl_symmetric;
     tensorrt_llm::runtime::ub::ub_initialize(world_config);
     TLLM_CHECK(tensorrt_llm::runtime::ub::ub_is_initialized());
     buffer_size_ = buffer_size;
@@ -95,10 +96,11 @@ tensorrt_llm::runtime::ub::communicator* UserBuffersManager::comm()
     return tensorrt_llm::runtime::ub::ub_comm();
 }

-void initialize_userbuffers_manager(
-    int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size)
+void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
+    int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric)
 {
-    UserBuffersManager::get_instance().initialize(tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size);
+    UserBuffersManager::get_instance().initialize(
+        tp_size, pp_size, cp_size, rank, gpus_per_node, buffer_size, use_nccl_symmetric);
 }

 } // namespace tensorrt_llm::runtime::ub

cpp/tensorrt_llm/kernels/userbuffers/userbuffersManager.h

Lines changed: 5 additions & 4 deletions
@@ -46,8 +46,9 @@ class UserBuffersManager
     //! @param gpus_per_node The number of GPUs per node.
     //! @param buffer_size The size of the buffer to allocate. All buffers allocated by this manager will have this
     //! size.
-    void initialize(
-        int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
+    //! @param use_nccl_symmetric Whether to use NCCL symmetric communication.
+    void initialize(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node,
+        int64_t buffer_size, bool use_nccl_symmetric);

     //! @brief Create a UB tensor from the given shape, strides and data type. The function will choose available UB
     //! buffer or create a new one if no available buffer is found.
@@ -75,7 +76,7 @@ class UserBuffersManager
     int64_t buffer_size_;
 };

-void initialize_userbuffers_manager(
-    int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank, int64_t gpus_per_node, int64_t buffer_size);
+void initialize_userbuffers_manager(int64_t tp_size, int64_t pp_size, int64_t cp_size, int64_t rank,
+    int64_t gpus_per_node, int64_t buffer_size, bool use_nccl_symmetric);

 } // namespace tensorrt_llm::runtime::ub
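
Seen through the Python binding used elsewhere in this commit, the new flag is the trailing positional argument. An annotated sketch with placeholder values (the import path is an assumption; model_engine.py only refers to the module as `ub`):

# Hypothetical import; the binding module is referenced simply as `ub` in
# model_engine.py, and its exact location is not shown in this commit.
from tensorrt_llm.bindings.internal import runtime as ub

ub.initialize_userbuffers_manager(
    8,                  # tp_size
    1,                  # pp_size
    1,                  # cp_size
    0,                  # rank
    8,                  # gpus_per_node
    256 * 1024 * 1024,  # buffer_size: every buffer from this manager has this size
    True,               # use_nccl_symmetric: new in this commit
)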

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -448,7 +448,8 @@ void initBindings(pybind11::module_& m)
         .value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO)
         .value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB)
         .value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT)
-        .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT);
+        .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT)
+        .value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC);

     // Initialize MoeLoadBalancer bindings
     initMoeBindings(m);
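
After this binding lands, the new member is visible from Python alongside the existing ones. A quick sanity check (a sketch: the attribute name comes from the .value(...) call above, while the module path and enum export name are assumptions):

# Hypothetical module path; see the note on the `ub` import above.
from tensorrt_llm.bindings.internal import runtime

# NCCL_SYMMETRIC = 8 in customAllReduceKernels.h, so the round-trip holds.
assert int(runtime.AllReduceStrategyType.NCCL_SYMMETRIC) == 8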

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 23 additions & 7 deletions
@@ -166,10 +166,6 @@ class AllreduceOp
     size_t bytes_per_element = input.element_size();
     TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);

-    if (std::getenv("TLLM_USE_NCCL_UB") && mStrategy == AllReduceStrategyType::UB)
-    {
-        return runNCCLAllReduceUB(input, residual, norm_weight, scale, bias);
-    }
     AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);

     // Log runtime strategy
@@ -181,6 +177,8 @@
     {
     case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
     case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
+    case AllReduceStrategyType::NCCL_SYMMETRIC:
+        return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias);
     case AllReduceStrategyType::MIN_LATENCY:
     case AllReduceStrategyType::ONESHOT:
     case AllReduceStrategyType::TWOSHOT:
@@ -307,7 +305,7 @@
     return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
 }

-std::vector<torch::Tensor> runNCCLAllReduceUB(torch::Tensor const& input,
+std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
     torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
     torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
 {
@@ -316,11 +314,20 @@
     int size = input.numel();
     auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
     auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
+    if (ub_buffer0.invalid())
+    {
+        auto [symmetric_input, symmetric_ub_buffer0]
+            = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
+        cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(),
+            cudaMemcpyDeviceToDevice, stream);
+        ub_buffer0 = symmetric_ub_buffer0;
+    }
+
     TLLM_CHECK(!ub_buffer0.invalid());
     auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());

     NCCLCHECK(ncclAllReduce(
-        input.data_ptr(), norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
+        ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));

     if (mOp == AllReduceFusionOp::NONE)
     {
@@ -661,6 +668,10 @@
     {
         runtime_strategy = AllReduceStrategyType::NCCL;
     }
+    else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
+    {
+        runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC;
+    }
     else
     {
         // This is for DEBUG and BENCHMARK purpose. It will overried the strategy if AUTO is set.
@@ -686,6 +697,11 @@
         TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank);
         break;
     }
+    case AllReduceStrategyType::NCCL_SYMMETRIC:
+    {
+        TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank);
+        break;
+    }
     case AllReduceStrategyType::MIN_LATENCY:
     {
         TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank);
@@ -701,7 +717,7 @@
         TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
         break;
     }
-    default: break;
+    default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break;
     }
 }
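
The renamed runNCCLAllReduceSymmetric adds a staging fallback: when the incoming tensor is not already backed by a registered user buffer, it is copied into a freshly created one so that ncclAllReduce always reads symmetric (registered) memory; note the send pointer changing from input.data_ptr() to ub_buffer0.addr. A Python-flavored sketch of that control flow (illustrative only; the helper names mirror the C++ calls above rather than any real Python API):

def nccl_allreduce_symmetric(input_tensor, ub_manager, nccl_comm, stream):
    # Fast path: the producer already wrote into a registered UB buffer.
    ub_buffer0 = ub_manager.search_buffer(input_tensor.data_ptr())
    if ub_buffer0.invalid():
        # Fallback added by this commit: stage the payload into a fresh
        # symmetric buffer with an async device-to-device copy.
        _, ub_buffer0 = create_userbuffers_tensor(input_tensor.sizes(),
                                                  input_tensor.dtype)
        memcpy_device_to_device(ub_buffer0.addr, input_tensor.data_ptr(),
                                input_tensor.numel() * input_tensor.element_size(),
                                stream)
    norm_out, _ = create_userbuffers_tensor(input_tensor.sizes(),
                                            input_tensor.dtype)
    # Send from the registered buffer, never from the raw torch allocation.
    nccl_comm.all_reduce(send=ub_buffer0.addr, recv=norm_out.data_ptr(),
                         count=input_tensor.numel(), op="sum", stream=stream)
    return norm_out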

tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py

Lines changed: 0 additions & 53 deletions
@@ -184,61 +184,8 @@ def extra_check_fp4_quant_pattern(match: Match) -> bool:
         extra_check=extra_check_fp4_quant_pattern,
     )

-    def register_no_quant_pattern(custom_pass: PatternMatcherPass):
-        input_node = KeywordArg('input')
-        fusion = KeywordArg('fusion_op')
-        trtllm_allreduce_default = CallFunction(
-            torch.ops.trtllm.allreduce.default, input_node,
-            KeywordArg('residual_in'), KeywordArg('gamma'), Ignored(),
-            Ignored(), Ignored(), mapping.tp_group, strategy, fusion,
-            KeywordArg('eps'))
-        no_quant_pattern = MultiOutputPattern([trtllm_allreduce_default])
-
-        def empty_no_quant_pattern(
-            input: torch.Tensor,
-            residual_in: torch.Tensor,
-            gamma: torch.Tensor,
-            eps: float,
-        ):
-            return
-
-        def target_no_quant_pattern(
-            input: torch.Tensor,
-            residual_in: torch.Tensor,
-            gamma: torch.Tensor,
-            eps: float,
-        ):
-            input = torch.ops.trtllm.copy_to_userbuffers(input)
-            all_reduce_output = torch.ops.trtllm.allreduce(
-                input, residual_in, gamma, None, None, None,
-                mapping.tp_group, int(AllReduceStrategy.UB), fusion, eps)
-            finalize_output = torch.ops.trtllm.userbuffers_allreduce_finalize(
-                all_reduce_output[-1], False)
-            return all_reduce_output[0], finalize_output
-
-        def extra_check_no_quant_pattern(match: Match) -> bool:
-            input = match.ctx.pattern_to_node[input_node]
-            if not isinstance(input, torch.fx.graph.Node):
-                return False
-            dtype = input.meta["tensor_meta"].dtype
-            # UB only supports FP16/BF16 input
-            if dtype != torch.float16 and dtype != torch.bfloat16:
-                return False
-            return True
-
-        register_replacement(
-            empty_no_quant_pattern,
-            target_no_quant_pattern,
-            [],
-            fwd_only,
-            custom_pass,
-            search_fn_pattern=no_quant_pattern,
-            extra_check=extra_check_no_quant_pattern,
-        )
-
     register_fp8_quant_pattern(custom_pass)
     register_fp4_quant_pattern(custom_pass)
-    # register_no_quant_pattern(custom_pass)

     def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
         strategy = int(AllReduceStrategy.AUTO)

tensorrt_llm/_torch/model_config.py

Lines changed: 2 additions & 1 deletion
@@ -126,7 +126,8 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
         "ONESHOT": AllReduceStrategy.ONESHOT,
         "TWOSHOT": AllReduceStrategy.TWOSHOT,
         "LOWPRECISION": AllReduceStrategy.LOWPRECISION,
-        "MNNVL": AllReduceStrategy.MNNVL
+        "MNNVL": AllReduceStrategy.MNNVL,
+        "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
     }
     key = strategy.upper()
     return maps[key] if key in maps else AllReduceStrategy.AUTO
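
Strategy strings are upper-cased before the lookup, and anything unrecognized still degrades to AUTO rather than raising, so the new name slots in without breaking old configs:

assert get_all_reduce_strategy("nccl_symmetric") == AllReduceStrategy.NCCL_SYMMETRIC
assert get_all_reduce_strategy("NCCL_SYMMETRIC") == AllReduceStrategy.NCCL_SYMMETRIC
# Unknown names silently fall back, as before:
assert get_all_reduce_strategy("definitely-not-a-strategy") == AllReduceStrategy.AUTO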

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 7 additions & 11 deletions
@@ -316,16 +316,14 @@ def __init__(
         self._init_model_capacity()

         self._torch_compile_backend = None
-        print(
-            f"torch_compile_enabled: {pytorch_backend_config.torch_compile_enabled}",
-            flush=True)

         try:
+            if pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC":
+                self._init_userbuffers(self.model.config.hidden_size)
             if pytorch_backend_config.torch_compile_enabled:
                 set_torch_compiling(True)
                 use_ub = pytorch_backend_config.torch_compile_enable_userbuffers and self._init_userbuffers(
                     self.model.config.hidden_size)
-                print(f"use_ub: {use_ub}", flush=True)
                 self._torch_compile_backend = Backend(
                     pytorch_backend_config.torch_compile_inductor_enabled,
                     enable_userbuffers=use_ub,
@@ -997,7 +995,6 @@ def _load_model(self,
                     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
                     lora_config: Optional[LoraConfig] = None,
                     **kwargs):
-
         config = checkpoint_loader.load_config(
             checkpoint_dir,
             trust_remote_code=True,
@@ -2233,12 +2230,11 @@ def _init_userbuffers(self, hidden_size):
         # Disable UB for unsupported platforms
         if not ub.ub_supported():
             return False
-        ub.initialize_userbuffers_manager(self.mapping.tp_size,
-                                          self.mapping.pp_size,
-                                          self.mapping.cp_size,
-                                          self.mapping.rank,
-                                          self.mapping.gpus_per_node,
-                                          hidden_size * self.max_num_tokens * 2)
+        ub.initialize_userbuffers_manager(
+            self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
+            self.mapping.rank, self.mapping.gpus_per_node,
+            hidden_size * self.max_num_tokens * 2, True)
+
         return True

     def load_weights_from_target_model(self,
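
Putting it together, the user-facing switch is the allreduce_strategy string on the PyTorch backend config: when it equals "NCCL_SYMMETRIC", __init__ now calls _init_userbuffers eagerly (passing use_nccl_symmetric=True down to the manager) even with torch.compile disabled, which is the "non-torch compile path" from the commit title. A hedged sketch of enabling it from the LLM API (the field name comes from the comparison above; the exact config plumbing is an assumption):

from tensorrt_llm import LLM  # assumed high-level entry point

llm = LLM(
    model="/path/to/checkpoint",  # placeholder path
    tensor_parallel_size=2,
    # Assumed to reach pytorch_backend_config.allreduce_strategy; with this
    # value, userbuffers are initialized even without torch.compile.
    allreduce_strategy="NCCL_SYMMETRIC",
)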
