Skip to content

Commit

Permalink
[HybridParallel]Support fp16 in dygraph hybrid parallel (#36420)
Browse files Browse the repository at this point in the history
* [HybridParallel]Support fp16 in dygraph hybrid parallel

* update

* update

* update for recompute

* add unittest of pp+fp16

* add unittest of recompute+fp16

* update

* modify ut
  • Loading branch information
haohongxiang authored Oct 18, 2021
1 parent bdac9ff commit 10f0a0f
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 31 deletions.
40 changes: 34 additions & 6 deletions python/paddle/distributed/fleet/base/fleet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from ..meta_parallel import PipelineParallel, ShardingParallel
from ..meta_optimizers import HybridParallelOptimizer
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable

__all__ = []

Expand Down Expand Up @@ -1548,26 +1550,52 @@ def unscale_method(self, optimizer):
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict):
param_grads = []
param_grads_fp16 = []
param_grads_fp32 = []
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
param_grads_fp16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
_C_ops.check_finite_and_unscale(param_grads, self._scale,
param_grads, self._found_inf)

self._found_inf = paddle.cast(self._found_inf, dtype="int32")
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP32)
]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
if len(param_grads_fp16):
_C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
param_grads_fp16,
temp_found_inf_fp16)
if len(param_grads_fp32):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32,
temp_found_inf_fp32)
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

# TODO(shenliang03) Since dp allreduce in the optimizer is
# after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group to speed.
paddle.distributed.all_reduce(
self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
self._found_inf = paddle.cast(self._found_inf, dtype="bool")
paddle.to_tensor(
[self._found_inf], dtype="int32"),
op=paddle.distributed.ReduceOp.MAX,
group=None)

# Only tensor_parallel and pipeline_parallel need to modify scaler
if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
Expand Down
37 changes: 27 additions & 10 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ def forward_backward_pipeline(self, data, scaler=None):
p2p.send_backward(input_tensor_grad)

self._layers.allreduce_shared_weight_gradients()

train_loss = self._broadcast_final_loss()

with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss()
return train_loss

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
Expand All @@ -172,7 +171,8 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
train_loss = self.forward_backward_pipeline(data, scaler)

# optimizer
self._optimizer_step()
with paddle.amp.auto_cast(enable=False):
self._optimizer_step()

return train_loss

Expand Down Expand Up @@ -242,12 +242,13 @@ def _forward_step(self, input_tensor):
output_tensor, paddle.Tensor
), "Currently, loss_fn should obtain Paddle.Tensor dtype"

if self.accumulate_steps > 1:
output_tensor = output_tensor / self.accumulate_steps
with paddle.amp.auto_cast(enable=False):
if self.accumulate_steps > 1:
output_tensor = output_tensor / self.accumulate_steps

if self.total_loss is None:
self.total_loss = paddle.zeros_like(output_tensor)
self.total_loss += output_tensor.detach()
if self.total_loss is None:
self.total_loss = paddle.zeros_like(output_tensor)
self.total_loss += output_tensor.detach()

self.micro_batch_id += 1
return output_tensor
Expand Down Expand Up @@ -321,13 +322,29 @@ def _broadcast_final_loss(self):
if self.is_last_stage:
assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
loss = self.total_loss.detach()
is_fp32 = paddle.to_tensor(
1) if loss.dtype == paddle.float32 else paddle.to_tensor(0)
paddle.distributed.broadcast(
is_fp32,
src=self.global_rank,
use_calc_stream=True,
group=self.pp_group)
paddle.distributed.broadcast(
loss,
src=self.global_rank,
use_calc_stream=True,
group=self.pp_group)
else:
loss = paddle.zeros(shape=[1], dtype="float32")
is_fp32 = paddle.to_tensor(1)
paddle.distributed.broadcast(
is_fp32,
src=self._hcg.get_rank_from_stage(self.num_stages - 1),
use_calc_stream=True,
group=self.pp_group)
loss = paddle.zeros(
shape=[1],
dtype="float32") if is_fp32.numpy()[0] else paddle.zeros(
shape=[1], dtype="float16")
paddle.distributed.broadcast(
loss,
src=self._hcg.get_rank_from_stage(self.num_stages - 1),
Expand Down
13 changes: 8 additions & 5 deletions python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,14 @@ def forward(ctx, run_function, all_outputs, *args):

# TODO support AMP
tracer = framework._dygraph_tracer()
if tracer._amp_level == core.AmpLevel.O0:
ctx.is_fw_autocast = False
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
ctx.is_fw_autocast = True
ctx.amp_mode = 'O1'
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
Expand Down Expand Up @@ -263,7 +266,7 @@ def backward(ctx, *args):
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

Expand Down
15 changes: 9 additions & 6 deletions python/paddle/distributed/fleet/utils/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,14 @@ def forward(ctx, run_function, preserve_rng_state, *args):

# TODO support AMP
tracer = framework._dygraph_tracer()
if tracer._amp_level == core.AmpLevel.O0:
ctx.is_fw_autocast = False
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
ctx.is_fw_autocast = True
ctx.amp_mode = 'O1'
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
Expand Down Expand Up @@ -133,15 +136,15 @@ def backward(ctx, *args):
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
else:
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_mode):
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -6097,7 +6097,7 @@ def __init__(self, shape, dtype, **kwargs):

self.need_clip = kwargs.get('need_clip', True)

self.is_distributed = False
self.is_distributed = kwargs.get('is_distributed', False)
# self.block = default_main_program().global_block()

@property
Expand Down
138 changes: 138 additions & 0 deletions python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest
import paddle
import numpy as np
import random
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet


def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)


batch_size = 4
micro_batch_size = 2


class TestDistPPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)

def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
set_random_seed(1024, dp_id, rank_id)

#construct model a
model_a = AlexNet(10)
scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
parameters=model_a.parameters())

scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)

# construct model b
model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
parameters=model_b.parameters())

param_len = len(model_a.parameters())
parameters = []
for param in model_a.parameters():
parameters.append(param.numpy())

for idx, param in enumerate(model_b.parameters()):
param.set_value(parameters[idx + pp_id * (param_len // 2)])

model_a, optimizer_a = paddle.amp.decorate(
models=model_a,
optimizers=optimizer_a,
level='O2',
save_dtype='float32')
model_b, optimizer_b = paddle.amp.decorate(
models=model_b,
optimizers=optimizer_b,
level='O2',
save_dtype='float32')

model_b = fleet.distributed_model(model_b)
optimizer_b = fleet.distributed_optimizer(optimizer_b)
scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5)
scaler_b = fleet.distributed_scaler(scaler_b)

# construct reader
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True)

for step_id, data in enumerate(train_reader()):
x_data = np.array([x[0] for x in data]).astype('float32').reshape(
batch_size, 1, 28, 28)
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
img.stop_gradient = True
label.stop_gradient = True

if step_id >= 5:
return True

with paddle.amp.auto_cast(enable=True, level='O2'):
loss_a = model_a(img, label)
scaler_a.scale(loss_a).backward()
with paddle.amp.auto_cast(enable=False):
scaler_a.minimize(optimizer_a, loss_a)
optimizer_a.clear_grad()
scheduler_a.step()

loss_b = model_b.train_batch(
[img, label], optimizer_b, scheduler_b, scaler=scaler_b)

print("loss: ", loss_a.numpy(), loss_b.numpy())
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=5e-3)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 10f0a0f

Please sign in to comment.