@@ -166,10 +166,6 @@ class AllreduceOp
         size_t bytes_per_element = input.element_size();
         TLLM_LOG_DEBUG("All reduce message size is %zu", size * bytes_per_element);

-        if (std::getenv("TLLM_USE_NCCL_UB") && mStrategy == AllReduceStrategyType::UB)
-        {
-            return runNCCLAllReduceUB(input, residual, norm_weight, scale, bias);
-        }
         AllReduceStrategyType runtime_strategy = getRuntimeStrategy(seq_len, size);

         // Log runtime strategy
@@ -181,6 +177,8 @@ class AllreduceOp
         {
         case AllReduceStrategyType::UB: return runUBAllReduce(input, residual, norm_weight, scale, bias);
         case AllReduceStrategyType::NCCL: return runNCCLAllReduce(input, residual, norm_weight, scale, bias);
+        case AllReduceStrategyType::NCCL_SYMMETRIC:
+            return runNCCLAllReduceSymmetric(input, residual, norm_weight, scale, bias);
         case AllReduceStrategyType::MIN_LATENCY:
         case AllReduceStrategyType::ONESHOT:
         case AllReduceStrategyType::TWOSHOT:
@@ -307,7 +305,7 @@ class AllreduceOp
         return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
     }

-    std::vector<torch::Tensor> runNCCLAllReduceUB(torch::Tensor const& input,
+    std::vector<torch::Tensor> runNCCLAllReduceSymmetric(torch::Tensor const& input,
         torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
         torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
     {
@@ -316,11 +314,20 @@ class AllreduceOp
         int size = input.numel();
         auto& ub_manager = tensorrt_llm::runtime::ub::UserBuffersManager::get_instance();
         auto ub_buffer0 = ub_manager.search_buffer(input.data_ptr());
+        if (ub_buffer0.invalid())
+        {
+            auto [symmetric_input, symmetric_ub_buffer0]
+                = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());
+            cudaMemcpyAsync(symmetric_ub_buffer0.addr, input.data_ptr(), size * input.element_size(),
+                cudaMemcpyDeviceToDevice, stream);
+            ub_buffer0 = symmetric_ub_buffer0;
+        }
+
         TLLM_CHECK(!ub_buffer0.invalid());
         auto [norm_out, ub_buffer1] = torch_ext::create_userbuffers_tensor(input.sizes(), input.scalar_type());

         NCCLCHECK(ncclAllReduce(
-            input.data_ptr(), norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));
+            ub_buffer0.addr, norm_out.mutable_data_ptr(), size, (*getDtypeMap())[mType], ncclSum, *mNcclComm, stream));

         if (mOp == AllReduceFusionOp::NONE)
         {
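Note on the hunk above: when search_buffer() does not find a registered user buffer for the input pointer, the new branch creates a symmetric user-buffer tensor of the same shape and dtype, stages the data into it with a device-to-device copy on the same stream, and then feeds that registered buffer to ncclAllReduce(). Below is a minimal standalone sketch of the same staging pattern in plain CUDA + NCCL, outside of TensorRT-LLM's UserBuffersManager; the function and parameter names are illustrative, not taken from the codebase.

#include <cuda_runtime.h>
#include <nccl.h>

// Illustrative only: sum-reduce `count` floats across ranks. If `input` is not a
// communication-registered buffer, stage it into the pre-registered `staged`
// buffer first so the collective always reads from registered memory.
void allreduceWithStaging(float const* input, float* staged, float* output, size_t count,
    bool inputIsRegistered, ncclComm_t comm, cudaStream_t stream)
{
    float const* sendPtr = input;
    if (!inputIsRegistered)
    {
        // Device-to-device copy on the same stream keeps it ordered before the collective.
        cudaMemcpyAsync(staged, input, count * sizeof(float), cudaMemcpyDeviceToDevice, stream);
        sendPtr = staged;
    }
    // Every rank in the communicator must call this collectively.
    ncclAllReduce(sendPtr, output, count, ncclFloat, ncclSum, comm, stream);
}

One follow-up observation: the return value of cudaMemcpyAsync() in the added branch is unchecked; wrapping it with the codebase's CUDA error-checking macro (TLLM_CUDA_CHECK, if that is what surrounding code uses) would surface copy failures before the collective launches.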
@@ -661,6 +668,10 @@ class AllreduceOp
         {
             runtime_strategy = AllReduceStrategyType::NCCL;
         }
+        else if (mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
+        {
+            runtime_strategy = AllReduceStrategyType::NCCL_SYMMETRIC;
+        }
         else
         {
             // This is for DEBUG and BENCHMARK purposes. It will override the strategy if AUTO is set.
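Note on the hunk above: an explicitly requested NCCL or NCCL_SYMMETRIC strategy now short-circuits the heuristic selection and is passed through unchanged as the runtime strategy. A compressed sketch of the same logic, assuming the two branches stay free of strategy-specific side effects (illustrative, not the committed code):

// Honor an explicit user choice of NCCL or NCCL_SYMMETRIC; anything else falls
// through to the AUTO / heuristic handling below.
if (mStrategy == AllReduceStrategyType::NCCL || mStrategy == AllReduceStrategyType::NCCL_SYMMETRIC)
{
    runtime_strategy = mStrategy;
}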
@@ -686,6 +697,11 @@ class AllreduceOp
             TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL", rank);
             break;
         }
+        case AllReduceStrategyType::NCCL_SYMMETRIC:
+        {
+            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: NCCL_SYMMETRIC", rank);
+            break;
+        }
         case AllReduceStrategyType::MIN_LATENCY:
         {
             TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: MIN_LATENCY", rank);
@@ -701,7 +717,7 @@ class AllreduceOp
             TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
             break;
         }
-        default: break;
+        default: TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UNKNOWN: %d", rank, strategy); break;
         }
     }
