diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 3b4094f047552..92601cdb1c4c6 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -438,10 +438,11 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): group=mp_group, use_calc_stream=True)) - for task in tasks: - # wait partial all gather tasks - if task is not None: - task.wait() + if in_dygraph_mode(): + for task in tasks: + # wait partial all gather tasks + if task is not None: + task.wait() return tensor_recv_prev, tensor_recv_next