[HybridParallel] Support fp16 in dygraph hybrid parallel #36420

Merged (10 commits) on Oct 18, 2021
40 changes: 34 additions & 6 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -35,6 +35,8 @@
from ..meta_parallel import PipelineParallel, ShardingParallel
from ..meta_optimizers import HybridParallelOptimizer
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable

__all__ = []

@@ -1548,26 +1550,52 @@ def unscale_method(self, optimizer):
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict):
param_grads = []
param_grads_fp16 = []
param_grads_fp32 = []
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
param_grads_fp16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
_C_ops.check_finite_and_unscale(param_grads, self._scale,
param_grads, self._found_inf)

self._found_inf = paddle.cast(self._found_inf, dtype="int32")
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP32)
]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
if len(param_grads_fp16):
_C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
param_grads_fp16,
temp_found_inf_fp16)
if len(param_grads_fp32):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32,
temp_found_inf_fp32)
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

# TODO(shenliang03) Since dp allreduce in the optimizer is
# after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group to speed.
paddle.distributed.all_reduce(
self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
self._found_inf = paddle.cast(self._found_inf, dtype="bool")
paddle.to_tensor(
[self._found_inf], dtype="int32"),
op=paddle.distributed.ReduceOp.MAX,
group=None)

# Only tensor_parallel and pipeline_parallel need to modify scaler
if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
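The change above splits the gradients by dtype so that `check_finite_and_unscale` runs once per precision, then folds the two temporary flags back into `self._found_inf` before the global `all_reduce`. A minimal NumPy sketch of that control flow; the `unscale_group` helper and the gradient values are hypothetical stand-ins for the `_C_ops` kernel:

```python
import numpy as np

def unscale_group(grads, inv_scale):
    """Hypothetical stand-in for _C_ops.check_finite_and_unscale:
    unscale in place and report whether any value is non-finite."""
    found_inf = False
    for g in grads:
        g *= inv_scale
        found_inf |= not np.isfinite(g).all()
    return found_inf

# Sample gradients of mixed precision (made-up values).
grads = [np.array([1.0, 2.0], dtype=np.float16),
         np.array([3.0, np.inf], dtype=np.float32)]
inv_scale = np.float32(1.0 / 2**5)

# Split by dtype, mirroring the fp16/fp32 lists in unscale_method.
grads_fp16 = [g for g in grads if g.dtype == np.float16]
grads_fp32 = [g for g in grads if g.dtype == np.float32]

found_inf_fp16 = unscale_group(grads_fp16, inv_scale) if grads_fp16 else False
found_inf_fp32 = unscale_group(grads_fp32, inv_scale) if grads_fp32 else False

# Fold both flags into a single found_inf, as the patched method does.
found_inf = 1 if found_inf_fp16 or found_inf_fp32 else 0
print(found_inf)  # 1, because the fp32 group contains inf
```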
37 changes: 27 additions & 10 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -145,9 +145,8 @@ def forward_backward_pipeline(self, data, scaler=None):
p2p.send_backward(input_tensor_grad)

self._layers.allreduce_shared_weight_gradients()

train_loss = self._broadcast_final_loss()

with paddle.amp.auto_cast(enable=False):
Contributor

It is ok to put the guard here, but I wonder if it is needed.

Contributor Author

Removing the guard causes a precision difference when broadcasting train_loss, so it is necessary to keep the guard here.
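A small, self-contained NumPy illustration of that precision point (the sample losses are hypothetical): once the accumulated loss is carried in float16, its low-order digits are lost, so a loss that is cast or accumulated under autocast can drift from the float32 reference.

```python
import numpy as np

# Hypothetical per-micro-batch losses.
losses = [0.1234567, 0.2345678, 0.3456789, 0.4567891]

total_fp32 = np.float32(0.0)
total_fp16 = np.float16(0.0)
for x in losses:
    total_fp32 += np.float32(x)
    total_fp16 += np.float16(x)

# The float16 total differs from the float32 total in its low-order digits,
# which is exactly the kind of mismatch the auto_cast(enable=False) guard avoids.
print(total_fp32, np.float32(total_fp16))
```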

train_loss = self._broadcast_final_loss()
return train_loss

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
@@ -172,7 +171,8 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
train_loss = self.forward_backward_pipeline(data, scaler)

# optimizer
self._optimizer_step()
with paddle.amp.auto_cast(enable=False):
self._optimizer_step()

return train_loss

@@ -242,12 +242,13 @@ def _forward_step(self, input_tensor):
output_tensor, paddle.Tensor
), "Currently, loss_fn should obtain Paddle.Tensor dtype"

if self.accumulate_steps > 1:
output_tensor = output_tensor / self.accumulate_steps
with paddle.amp.auto_cast(enable=False):
if self.accumulate_steps > 1:
output_tensor = output_tensor / self.accumulate_steps

if self.total_loss is None:
self.total_loss = paddle.zeros_like(output_tensor)
self.total_loss += output_tensor.detach()
if self.total_loss is None:
self.total_loss = paddle.zeros_like(output_tensor)
self.total_loss += output_tensor.detach()

self.micro_batch_id += 1
return output_tensor
@@ -321,13 +322,29 @@ def _broadcast_final_loss(self):
if self.is_last_stage:
assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
loss = self.total_loss.detach()
is_fp32 = paddle.to_tensor(
1) if loss.dtype == paddle.float32 else paddle.to_tensor(0)
paddle.distributed.broadcast(
is_fp32,
src=self.global_rank,
use_calc_stream=True,
group=self.pp_group)
paddle.distributed.broadcast(
loss,
src=self.global_rank,
use_calc_stream=True,
group=self.pp_group)
else:
loss = paddle.zeros(shape=[1], dtype="float32")
is_fp32 = paddle.to_tensor(1)
paddle.distributed.broadcast(
is_fp32,
src=self._hcg.get_rank_from_stage(self.num_stages - 1),
use_calc_stream=True,
group=self.pp_group)
loss = paddle.zeros(
shape=[1],
dtype="float32") if is_fp32.numpy()[0] else paddle.zeros(
shape=[1], dtype="float16")
paddle.distributed.broadcast(
loss,
src=self._hcg.get_rank_from_stage(self.num_stages - 1),
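Condensed, the dtype negotiation added to `_broadcast_final_loss` looks like the sketch below: the last stage first broadcasts whether its loss is FP32, and the receiving stages allocate a buffer of the matching dtype before the second broadcast. The `is_last_stage`, `src_rank`, and `pp_group` parameters are assumptions standing in for the values the class takes from the hybrid communicate group; `use_calc_stream` follows the call signature used in this diff.

```python
import paddle
import paddle.distributed as dist

def broadcast_final_loss(total_loss, is_last_stage, src_rank, pp_group):
    """Sketch of the two-step broadcast: dtype flag first, then the loss."""
    if is_last_stage:
        loss = total_loss.detach()
        # Tell the other stages whether the loss is fp32 or fp16.
        is_fp32 = paddle.to_tensor(1 if loss.dtype == paddle.float32 else 0)
        dist.broadcast(is_fp32, src=src_rank, use_calc_stream=True, group=pp_group)
        dist.broadcast(loss, src=src_rank, use_calc_stream=True, group=pp_group)
    else:
        is_fp32 = paddle.to_tensor(1)
        dist.broadcast(is_fp32, src=src_rank, use_calc_stream=True, group=pp_group)
        # Allocate a receive buffer with the dtype the sender announced.
        dtype = "float32" if is_fp32.numpy()[0] else "float16"
        loss = paddle.zeros(shape=[1], dtype=dtype)
        dist.broadcast(loss, src=src_rank, use_calc_stream=True, group=pp_group)
    return loss
```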
13 changes: 8 additions & 5 deletions python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -198,11 +198,14 @@ def forward(ctx, run_function, all_outputs, *args):

# TODO support AMP
tracer = framework._dygraph_tracer()
if tracer._amp_level == core.AmpLevel.O0:
ctx.is_fw_autocast = False
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
ctx.is_fw_autocast = True
ctx.amp_mode = 'O1'
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
@@ -263,7 +266,7 @@ def backward(ctx, *args):
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

15 changes: 9 additions & 6 deletions python/paddle/distributed/fleet/utils/recompute.py
@@ -98,11 +98,14 @@ def forward(ctx, run_function, preserve_rng_state, *args):

# TODO support AMP
tracer = framework._dygraph_tracer()
if tracer._amp_level == core.AmpLevel.O0:
ctx.is_fw_autocast = False
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
ctx.is_fw_autocast = True
ctx.amp_mode = 'O1'
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
@@ -133,15 +136,15 @@ def backward(ctx, *args):
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
else:
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

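Taken together, the changes in `pp_utils/utils.py` and `recompute.py` capture the tracer's AMP state (autocast on/off, level, op lists) during the forward pass and replay it unchanged when the forward is recomputed in backward, so an O2 region is no longer silently recomputed at a different level. A minimal sketch of that capture-and-replay pattern, using only the tracer attributes referenced in the diff; the class and method names here are illustrative, not the PR's:

```python
import paddle
from paddle.fluid import core, framework

class AmpState:
    """Sketch: capture the tracer's AMP configuration at forward time."""

    def __init__(self):
        tracer = framework._dygraph_tracer()
        self.is_fw_autocast = tracer._amp_level != core.AmpLevel.O0
        if tracer._amp_level == core.AmpLevel.O2:
            self.amp_level = 'O2'
        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
            self.amp_level = 'O1'
        else:
            raise ValueError(
                "unsupported amp level: {}".format(tracer._amp_level))
        self.amp_white_list, self.amp_black_list = tracer._get_amp_op_list()

    def replay(self, run_function, *inputs):
        """Re-run the forward under the captured AMP settings, as the
        recompute backward pass does before taking gradients."""
        with paddle.amp.auto_cast(
                enable=self.is_fw_autocast,
                custom_white_list=self.amp_white_list,
                custom_black_list=self.amp_black_list,
                level=self.amp_level):
            return run_function(*inputs)
```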
2 changes: 1 addition & 1 deletion python/paddle/fluid/framework.py
@@ -6085,7 +6085,7 @@ def __init__(self, shape, dtype, **kwargs):

self.need_clip = kwargs.get('need_clip', True)

self.is_distributed = False
self.is_distributed = kwargs.get('is_distributed', False)
# self.block = default_main_program().global_block()

@property
138 changes: 138 additions & 0 deletions python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
@@ -0,0 +1,138 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet


def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)


batch_size = 4
micro_batch_size = 2


class TestDistPPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)

def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
set_random_seed(1024, dp_id, rank_id)

#construct model a
model_a = AlexNet(10)
scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
parameters=model_a.parameters())

scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)

# construct model b
model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
parameters=model_b.parameters())

param_len = len(model_a.parameters())
parameters = []
for param in model_a.parameters():
parameters.append(param.numpy())

for idx, param in enumerate(model_b.parameters()):
param.set_value(parameters[idx + pp_id * (param_len // 2)])

model_a, optimizer_a = paddle.amp.decorate(
models=model_a,
optimizers=optimizer_a,
level='O2',
save_dtype='float32')
model_b, optimizer_b = paddle.amp.decorate(
models=model_b,
optimizers=optimizer_b,
level='O2',
save_dtype='float32')

model_b = fleet.distributed_model(model_b)
optimizer_b = fleet.distributed_optimizer(optimizer_b)
scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5)
scaler_b = fleet.distributed_scaler(scaler_b)

# construct reader
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True)

for step_id, data in enumerate(train_reader()):
x_data = np.array([x[0] for x in data]).astype('float32').reshape(
batch_size, 1, 28, 28)
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
img.stop_gradient = True
label.stop_gradient = True

if step_id >= 5:
return True

with paddle.amp.auto_cast(enable=True, level='O2'):
loss_a = model_a(img, label)
scaler_a.scale(loss_a).backward()
with paddle.amp.auto_cast(enable=False):
scaler_a.minimize(optimizer_a, loss_a)
optimizer_a.clear_grad()
scheduler_a.step()

loss_b = model_b.train_batch(
[img, label], optimizer_b, scheduler_b, scaler=scaler_b)

print("loss: ", loss_a.numpy(), loss_b.numpy())
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=5e-3)


if __name__ == "__main__":
unittest.main()