diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py index 10b2bf4e251..7d19c9cd624 100644 --- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py +++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py @@ -28,7 +28,7 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None): p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) # Use Strategy.AUTO for optimal performance _allreduce_cache[cache_key] = AllReduce( - mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype + mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype ) torch_op = _allreduce_cache[cache_key]