Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HybridParallel]fix bug of check_inf in fleet_base.py #36651

Merged
merged 2 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/paddle/distributed/fleet/base/fleet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,16 +1586,16 @@ def unscale_method(self, optimizer):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32,
temp_found_inf_fp32)

self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

# TODO(shenliang03) Since dp allreduce in the optimizer is
# after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group to speed.
paddle.distributed.all_reduce(
paddle.to_tensor(
[self._found_inf], dtype="int32"),
op=paddle.distributed.ReduceOp.MAX,
group=None)
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
self._found_inf = is_found_inf.numpy()[0]

# Only tensor_parallel and pipeline_parallel need to modify scaler
if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ def _apply_collective_grads(parameters, comm_group):
nranks = paddle.distributed.get_world_size(
) if comm_group is None else comm_group.nranks
div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': coalesced_grad,
'Y': div_factor},
outputs={'Out': coalesced_grad},
attrs={'axis': -1})

paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

_split_tensors(coalesced_grads_and_vars)


Expand Down