[cherry-pick] Support FP16 in HybridParallel and Fix bugs in HybridOptimizer (#36707)

* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer (#36237)

* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer

* update

* update

* fix bugs in mp_layers, pp_layers and HybridParallelClipGrad (#36144)

* fix calling bug of HybridParallelClipGrad

* fix bugs of HybridParallelClipGrad

* add unittest of pp with HybridParallelClipGrad

* fix bugs in mp_layers.py

* update

* fix bugs in pp_layers.py

* update

* [HybridParallel]Rebuild code for pipeline (#36396)

* add no_sync for parameters sync

* add pipeline for moe

* [HybridParallel]Support fp16 in dygraph hybrid parallel (#36420)

* [HybridParallel]Support fp16 in dygraph hybrid parallel

* update

* update

* update for recompute

* add unittest of pp+fp16

* add unittest of recompute+fp16

* update

* modify ut

* modify ut of cond (#36475)

* fix bugs of ClipGradByGlobalNorm in HybridParallel (#36555)

* fix bugs of ClipGradByGlobalNorm

* add unittests

* add unittests

* [HybridParallel]fix bug of check_inf in fleet_base.py (#36651)

* fix bug of check_inf

* fix allreduce

* support ClipGradByGlobalNorm in sharding (#36012)

* support ClipGradByGlobalNorm in sharding

* support ClipGradByGlobalNorm in sharding

* test=allcase

* Update test_linalg_cond.py

* Update hybrid_parallel_util.py

* Update hybrid_parallel_util.py

Co-authored-by: ShenLiang <[email protected]>
Co-authored-by: zhaoyingli <[email protected]>
3 people authored Oct 26, 2021
1 parent 77034fc commit 5b357e0
Showing 19 changed files with 533 additions and 94 deletions.
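With these changes, a dygraph hybrid-parallel job can train in FP16: the GradScaler returned by fleet drives the patched unscale_method shown in the first hunk below, and ClipGradByGlobalNorm is rerouted through HybridParallelClipGrad. A minimal usage sketch, under assumptions: build_model and data_loader are hypothetical placeholders, the parallel degrees are arbitrary, and fleet.distributed_scaler is taken to be the wrapper that installs the patched unscale logic.

import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
# Placeholder degrees: 2-way data parallel x 2-way tensor (model) parallel.
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

model = build_model()  # hypothetical builder using the parallel layers from mp_layers
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0))  # replaced by HybridParallelClipGrad below

model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)  # wrapped into HybridParallelOptimizer

scaler = paddle.amp.GradScaler(init_loss_scaling=2 ** 15)
scaler = fleet.distributed_scaler(scaler)  # assumed to install the patched unscale_method

for batch in data_loader():  # hypothetical data loader
    with paddle.amp.auto_cast():
        loss = model(batch)
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(optimizer, scaled)
    optimizer.clear_grad()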
40 changes: 34 additions & 6 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -35,6 +35,8 @@
from ..meta_parallel import PipelineParallel, ShardingParallel
from ..meta_optimizers import HybridParallelOptimizer
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable

__all__ = []

@@ -1547,26 +1549,52 @@ def unscale_method(self, optimizer):
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict):
param_grads = []
param_grads_fp16 = []
param_grads_fp32 = []
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
param_grads_fp16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
_C_ops.check_finite_and_unscale(param_grads, self._scale,
param_grads, self._found_inf)

self._found_inf = paddle.cast(self._found_inf, dtype="int32")
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP32)
]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
if len(param_grads_fp16):
_C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
param_grads_fp16,
temp_found_inf_fp16)
if len(param_grads_fp32):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32,
temp_found_inf_fp32)

self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

# TODO(shenliang03) Since the dp allreduce in the optimizer happens
# after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group to speed this up.
paddle.distributed.all_reduce(
self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
self._found_inf = paddle.cast(self._found_inf, dtype="bool")
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
self._found_inf = is_found_inf.numpy()[0]

# Only tensor_parallel and pipeline_parallel need to modify scaler
if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
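The reworked unscale_method above splits gradients by dtype, runs check_finite_and_unscale on the FP16 and FP32 buckets separately, and then allreduces a single found_inf flag with ReduceOp.MAX so every rank agrees on whether to skip the step. A framework-free sketch of that control flow, assuming made-up gradients and modeling the collective as a plain max over per-rank flags:

import numpy as np

def unscale_and_check(grads, scale):
    # Split by dtype, unscale in place, and report whether any value is inf/nan,
    # mirroring the two check_finite_and_unscale calls above.
    found_inf = False
    for bucket in ([g for g in grads if g.dtype == np.float16],
                   [g for g in grads if g.dtype == np.float32]):
        for g in bucket:
            g /= scale
            found_inf = found_inf or not np.isfinite(g).all()
    return found_inf

rng = np.random.default_rng(0)
ranks_grads = [  # hypothetical gradients held by two ranks
    [rng.standard_normal(4).astype(np.float16), rng.standard_normal(4).astype(np.float32)],
    [np.array([np.inf, 1.0], dtype=np.float32)],  # this rank overflowed
]
per_rank = [unscale_and_check(grads, scale=2.0 ** 15) for grads in ranks_grads]
global_found_inf = max(per_rank)  # stands in for all_reduce(..., op=ReduceOp.MAX)
# True here, so every rank would skip this optimizer step and shrink the loss scale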
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
from .hybrid_parallel_optimizer import HybridParallelOptimizer
from .hybrid_parallel_gradscaler import HybridParallelGradScaler
from .dygraph_sharding_optimizer import DygraphShardingOptimizer

__all__ = []
@@ -50,7 +50,12 @@ def __init__(self, clip, hcg):
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []

sum_square_dist_fp16 = []
sum_square_dist_fp32 = []
sum_square_not_dist_fp16 = []
sum_square_not_dist_fp32 = []

for p, g in params_grads:
if g is None:
continue
@@ -62,32 +67,98 @@ def _dygraph_clip(self, params_grads):
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)
sum_square_list.append(sum_square)

# all parameters have been filtered out
if len(sum_square_list) == 0:
return params_grads

global_norm_var = layers.concat(sum_square_list)
global_norm_var = layers.reduce_sum(global_norm_var)
# add all reduce to get global norm in world size
paddle.distributed.all_reduce(global_norm_var,
self._hcg.get_check_parallel_group())
global_norm_var = layers.sqrt(global_norm_var)
not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
hasattr(p, 'is_firstly_shared') and
getattr(p, 'is_firstly_shared', True))

if not_shared_enable:
if p.is_distributed:
if p.dtype == paddle.float16:
sum_square_dist_fp16.append(sum_square)
elif p.dtype == paddle.float32:
sum_square_dist_fp32.append(sum_square)
else:
if p.dtype == paddle.float16:
sum_square_not_dist_fp16.append(sum_square)
elif p.dtype == paddle.float32:
sum_square_not_dist_fp32.append(sum_square)

# global norm of distributed FP16 params_and_grads
if len(sum_square_dist_fp16) == 0:
global_norm_dist_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
else:
global_norm_dist_fp16 = layers.concat(sum_square_dist_fp16)
global_norm_dist_fp16 = layers.reduce_sum(global_norm_dist_fp16)
global_norm_dist_fp16 = paddle.cast(
global_norm_dist_fp16, dtype=paddle.float32)

# global norm of non-distributed FP16 params_and_grads
if len(sum_square_not_dist_fp16) == 0:
global_norm_not_dist_fp16 = paddle.to_tensor(
[0.], dtype=paddle.float32)
else:
global_norm_not_dist_fp16 = layers.concat(sum_square_not_dist_fp16)
global_norm_not_dist_fp16 = layers.reduce_sum(
global_norm_not_dist_fp16)
global_norm_not_dist_fp16 = paddle.cast(
global_norm_not_dist_fp16, dtype=paddle.float32)

# global norm of distributed FP32 params_and_grads
global_norm_dist_fp32 = layers.concat(sum_square_dist_fp32) if len(
sum_square_dist_fp32) != 0 else paddle.to_tensor(
[0.], dtype=paddle.float32)
global_norm_dist_fp32 = layers.reduce_sum(global_norm_dist_fp32)

# global norm of non-distributed FP32 params_and_grads
global_norm_not_dist_fp32 = layers.concat(
sum_square_not_dist_fp32) if len(
sum_square_not_dist_fp32) != 0 else paddle.to_tensor(
[0.], dtype=paddle.float32)
global_norm_not_dist_fp32 = layers.reduce_sum(global_norm_not_dist_fp32)

global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_fp32
global_norm_var_not_dist = global_norm_not_dist_fp16 + global_norm_not_dist_fp32

# add all reduce to get global norm of distributed params_and_grads
if self._hcg.get_model_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group())

# add all reduce to get global norm of non-distributed params_and_grads within the pipeline-parallel group
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group())

# In sharding mode, parameters and gradients are mapped to different ranks in the optimizer.
# ClipGradByGlobalNorm needs an allreduce to get the global norm.
if self._hcg.get_sharding_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group())

global_norm_var_fp32 = layers.sqrt(global_norm_var_dist +
global_norm_var_not_dist)

max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
shape=[1], dtype=global_norm_var_fp32.dtype, value=self.clip_norm)
clip_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=global_norm_var, y=max_global_norm))
x=global_norm_var_fp32, y=max_global_norm))
clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
new_grad = layers.elementwise_mul(x=g, y=clip_var)
if p.dtype == paddle.float16:
new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
else:
new_grad = layers.elementwise_mul(x=g, y=clip_var)
params_and_grads.append((p, new_grad))

return params_and_grads
@@ -96,7 +167,7 @@ def __getattr__(self, item):
return getattr(self._clip, item)

def __call__(self, params_grads):
return self._clip(params_grads)
return self._dygraph_clip(params_grads)


class HybridParallelOptimizer:
@@ -112,19 +183,24 @@ def __init__(self, optimizer, hcg, strategy):
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

# NOTE(shenliang03): Because of the pure DataParallel mode, the gradient synchronization
# is achieved through reducer, so there is no need to call fuse_allreduce in oprimizer.
# is achieved through reducer, so there is no need to call fuse_allreduce in optimizer.
self._dp_enable = not self._use_dp_mode and self._need_dp

self._sharding_enable = (
self._hcg.get_sharding_parallel_world_size() > 1)

if isinstance(self._inner_opt._grad_clip,
ClipGradByGlobalNorm) and not self._use_dp_mode:
logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
"optmizer'grad clip will be changed.")

self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
logger.warning("While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " \
"or Sharding, the grad clip of original optimizer will be changed.")

if self._sharding_enable:
# change sharding inner_optimizer's _grad_clip
self._inner_opt._inner_optimizer._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
else:
self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)

@imperative_base.no_grad
@framework.dygraph_only
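Summarizing the norm computation HybridParallelClipGrad now performs: squared gradient norms are accumulated in four buckets (distributed vs. replicated, FP16 vs. FP32), the FP16 sums are cast up to FP32, the distributed and non-distributed partial sums are allreduced over different communication groups (mp vs. pp/sharding), and the clip coefficient is clip_norm / max(global_norm, clip_norm), cast to FP16 for FP16 gradients. A numpy stand-in with made-up per-rank partial sums, modeling each allreduce as a plain sum:

import numpy as np

clip_norm = 1.0

# Hypothetical per-rank partial sums of squared gradients.
sum_sq_dist = [0.9, 1.3]      # params sliced across mp ranks; allreduced over the mp group
sum_sq_not_dist = [2.1, 0.4]  # params owned per pp / sharding rank; allreduced over those groups

global_norm = np.sqrt(np.float32(sum(sum_sq_dist) + sum(sum_sq_not_dist)))

# Same shape as elementwise_div / elementwise_max in the diff:
# the coefficient only shrinks gradients when the global norm exceeds clip_norm.
clip_coef = np.float32(clip_norm) / np.maximum(global_norm, np.float32(clip_norm))

grad_fp32 = np.array([3.0, -4.0], dtype=np.float32)
grad_fp16 = grad_fp32.astype(np.float16)
clipped_fp32 = grad_fp32 * clip_coef
clipped_fp16 = grad_fp16 * clip_coef.astype(np.float16)  # FP16 grads use the FP16 cast of the coefficient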
@@ -70,7 +70,7 @@ def __init__(self,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True
self.weight.is_distributed = True if self.is_mp else False

def forward(self, x):
if self.is_mp:
@@ -135,7 +135,7 @@ def __init__(self,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True
self.weight.is_distributed = True if self.is_mp else False

if has_bias:
# initialize bias to zero like Megatron
@@ -144,7 +144,7 @@ def __init__(self,
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True
self.bias.is_distributed = True if self.is_mp else False
else:
self.bias = None

@@ -212,7 +212,7 @@ def __init__(self,
dtype=self._dtype,
is_bias=False)

self.weight.is_distributed = True
self.weight.is_distributed = True if self.is_mp else False

if has_bias:
self.bias = self.create_parameter(
@@ -261,6 +261,10 @@ def _synchronize_shared_weights(self):
src=min(comm['ranks']),
group=comm['group'])

for param in comm['layer'].parameters():
if self.global_rank != min(comm['ranks']):
setattr(param, 'is_firstly_shared', False)

def allreduce_shared_weight_gradients(self):
for key, comm in self.shared_comm.items():
param = getattr(self.shared_layers[key], comm['weight_attr'])
@@ -316,6 +320,9 @@ def _build_layer(self):
self.shared_layers[layer.layer_name] = layer.build_layer()
self.shared_weight_attrs[
layer.layer_name] = layer.shared_weight_attr
for param in self.shared_layers[
layer.layer_name].parameters():
setattr(param, "is_firstly_shared", True)

if layer.forward_func is None:
self.run_function.append(self.shared_layers[
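The pp_layers change above tags every parameter of a shared layer with is_firstly_shared: True on the rank that builds the first copy (the minimum rank in the comm group), False on the other pipeline stages. The not_shared_enable check in _dygraph_clip then skips the later copies, so a weight tied across stages contributes to the global norm exactly once. A single-process sketch of that dedup rule with hypothetical parameter records:

# (name, squared_norm, is_firstly_shared) -- None marks an ordinary, unshared parameter
params = [
    ("shared_embedding.weight", 4.0, True),   # first occurrence of the tied weight
    ("shared_embedding.weight", 4.0, False),  # copy held by a later pipeline stage
    ("linear.weight", 1.0, None),             # ordinary parameter
]

# Mirrors not_shared_enable: count the parameter unless it is a shared copy marked not-first.
total_sq = sum(sq for _, sq, flag in params if flag is None or flag)
assert total_sq == 5.0  # the tied weight is counted once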