
Commit 2cf4707

update
haohongxiang committed Oct 14, 2021
1 parent 6a637a3 commit 2cf4707
Showing 2 changed files with 13 additions and 6 deletions.
9 changes: 5 additions & 4 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -1586,15 +1586,16 @@ def unscale_method(self, optimizer):
             _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                             param_grads_fp32,
                                             temp_found_inf_fp32)
-        self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32
+        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

         # TODO(shenliang03) Since dp allreduce in the optimizer is
         # after the gradscaler, check_finite needs to synchronize global
         # information. In the future, we should use check_group to speed.
-        self._found_inf = paddle.cast(self._found_inf, dtype="int32")
         paddle.distributed.all_reduce(
-            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
-        self._found_inf = paddle.cast(self._found_inf, dtype="bool")
+            paddle.to_tensor(
+                [self._found_inf], dtype="int32"),
+            op=paddle.distributed.ReduceOp.MAX,
+            group=None)

     # Only tensor_parallel and pipeline_parallel need to modify scaler
     if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
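
Note on this hunk: the scaler now keeps its local overflow result as a plain Python int (1/0) and wraps it in a fresh int32 tensor only for the cross-rank all-reduce, instead of casting a bool tensor to int32 and back. The MAX reduction over {0, 1} acts as a distributed logical OR: if any data-parallel rank saw an inf/nan gradient, every rank sees 1. A minimal, self-contained sketch of that pattern (the name `local_found_inf` and the 2-GPU setup are illustrative, not part of this commit):

    # Run with: python -m paddle.distributed.launch --gpus 0,1 sync_flag.py
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    local_found_inf = dist.get_rank() == 0  # pretend only rank 0 overflowed
    flag = paddle.to_tensor([1 if local_found_inf else 0], dtype="int32")

    # MAX over {0, 1} == logical OR: the flag becomes 1 on every rank
    # if it was 1 on any rank. all_reduce updates `flag` in place.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)

    if int(flag) == 1:
        pass  # all ranks agree: skip the step and shrink the loss scale
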
10 changes: 8 additions & 2 deletions python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -198,8 +198,14 @@ def forward(ctx, run_function, all_outputs, *args):

         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = False if tracer._amp_level == 0 else True
-        ctx.amp_level = 'O2' if tracer._amp_level == 2 else 'O1'
+        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        if tracer._amp_level == core.AmpLevel.O2:
+            ctx.amp_level = 'O2'
+        elif tracer._amp_level == core.AmpLevel.O1:
+            ctx.amp_level = 'O1'
+        else:
+            raise ValueError("unsupported amp level: {}".format(
+                tracer._amp_level))
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

         with paddle.no_grad():
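
Note on this hunk: `forward` snapshots the tracer's AMP state (`is_fw_autocast`, `amp_level`, and the white/black op lists) on `ctx` so the recompute backward pass can replay the forward under the same autocast configuration; comparing against `core.AmpLevel` enum values instead of the raw ints 0/2 also makes unknown levels fail loudly. A hedged sketch of how such saved state is typically consumed (illustrative helper, not the actual backward() in this file):

    import paddle

    def replay_forward(ctx, *inputs):
        # Re-enter autocast with the exact configuration captured in
        # forward(), so recomputed activations match the original pass.
        with paddle.amp.auto_cast(
                enable=ctx.is_fw_autocast,
                custom_white_list=ctx.amp_white_list,
                custom_black_list=ctx.amp_black_list,
                level=ctx.amp_level):  # 'O1' or 'O2', as mapped above
            return ctx.run_function(*inputs)
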
