diff --git a/python/paddle/distributed/communication/stream/all_to_all.py b/python/paddle/distributed/communication/stream/all_to_all.py index 9c24d71cb0f7d5..dcdc85ccebe094 100644 --- a/python/paddle/distributed/communication/stream/all_to_all.py +++ b/python/paddle/distributed/communication/stream/all_to_all.py @@ -27,10 +27,10 @@ def _all_to_all_tensor_in_dygraph( ): if use_calc_stream: return group.process_group.all_to_all_tensor_on_calc_stream( - in_tensor, out_tensor + out_tensor, in_tensor ) - task = group.process_group.all_to_all_tensor(in_tensor, out_tensor, sync_op) + task = group.process_group.all_to_all_tensor(out_tensor, in_tensor, sync_op) if sync_op: task.wait()