diff --git a/python/paddle/distributed/communication/stream/all_to_all.py b/python/paddle/distributed/communication/stream/all_to_all.py index 202f2b6c6c8d8d..a2902a07509178 100644 --- a/python/paddle/distributed/communication/stream/all_to_all.py +++ b/python/paddle/distributed/communication/stream/all_to_all.py @@ -27,10 +27,10 @@ def _all_to_all_tensor_in_dygraph( ): if use_calc_stream: return group.process_group.all_to_all_tensor_on_calc_stream( - out_tensor, in_tensor + in_tensor, out_tensor ) - task = group.process_group.all_to_all_tensor(out_tensor, in_tensor, sync_op) + task = group.process_group.all_to_all_tensor(in_tensor, out_tensor, sync_op) if sync_op: task.wait()