[NPU] Add flatten_param_grads for Trainer to improve NPU performance (#…
pangyoki authored Jan 12, 2023
1 parent 06de433 commit b976b74
Showing 4 changed files with 147 additions and 1 deletion.
7 changes: 6 additions & 1 deletion docs/trainer.md
@@ -548,10 +548,15 @@ Trainer is a simple yet feature-complete Paddle training and evaluation module, and
Whether to resume training from a checkpoint. (optional, defaults to None)
The path to a folder with a valid checkpoint for your
model. (default: None)

--skip_memory_metrics
Whether to skip memory profiler checks. (optional, defaults to True, i.e. skipped)
Whether or not to skip adding memory profiler reports
to metrics. (default: True)

--flatten_param_grads
Whether to use the flatten_param_grads strategy in the optimizer: all parameters are flattened before being passed to the Optimizer for the update. Currently this strategy only takes effect on NPU devices. (optional, defaults to False)
Whether to use the flatten_param_grads method in the optimizer;
only used on NPU devices. (default: False)

```
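A minimal sketch of how the flag documented above is expected to be set from Python, assuming the `TrainingArguments` fields shown later in this commit (the `output_dir` value is purely illustrative):

```python
from paddlenlp.trainer import TrainingArguments

# flatten_param_grads only takes effect on NPU; on other devices the
# __post_init__ check added in this commit rejects the combination.
args = TrainingArguments(
    output_dir="./checkpoints",  # illustrative path
    device="npu",
    flatten_param_grads=True,
)
```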
127 changes: 127 additions & 0 deletions paddlenlp/trainer/plugins/npu_plugin.py
@@ -0,0 +1,127 @@
# Copyright 2020-present the HuggingFace Inc. team.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types

import numpy as np
import paddle
from paddle.fluid.layer_helper import LayerHelper

from ...utils.log import logger


def npu_accelerate_plugin(optimizer):
    """npu_accelerate_plugin uses the flatten_param_grads method to improve model performance on NPU devices.
    The flatten_param_grads logic is patched into the optimizer's `step` method.

    Args:
        optimizer (`paddle.optimizer.Optimizer`):
            The Optimizer whose `step` method will be modified.
    """
    optimizer.step = types.MethodType(_optimizer_step_with_flatten_param_grads, optimizer)


def _optimizer_step_with_flatten_param_grads(optimizer):
    if not isinstance(optimizer._param_groups[0], dict):
        params_grads = []
        for param in optimizer._param_groups:
            if param.stop_gradient:
                continue
            if param._grad_ivar() is not None:
                grad_var = param._grad_ivar()
                params_grads.append((param, grad_var))

        # Currently, only ClipGradByGlobalNorm (or no gradient clipping) without regularization is supported.
        if isinstance(params_grads, list) and optimizer.regularization is None:
            if optimizer._grad_clip is None or isinstance(optimizer._grad_clip, paddle.nn.ClipGradByGlobalNorm):
                params_grads = _flatten_param_grads(optimizer, params_grads)

        optimizer._apply_optimize(
            loss=None,
            startup_program=None,
            params_grads=params_grads,
            param_group_idx=0,
        )
    else:
        raise RuntimeError("flatten_param_grads is not supported when _param_groups[0] is a dict.")


def _flatten_param_grads(optimizer, params_grads):
    optimizer.helper = LayerHelper(optimizer.__class__.__name__)
    need_flatten_params = []
    need_flatten_grads = []
    for p, g in params_grads:
        if g is None:
            continue
        g.persistable = True
        if getattr(p, "need_clip", True) is False or getattr(p, "regularizer", None) is not None:
            logger.warning(
                f"flatten_param_grads=True will be discarded since parameter {p.name}'s need_clip is False or "
                "its regularizer is set."
            )
            return params_grads

        need_flatten_params.append(p)
        need_flatten_grads.append(g)

    shape = [np.prod(p.shape) for p in need_flatten_params]

    flatten_param = optimizer.helper.create_global_variable(
        name="flatten_param",
        persistable=True,
        dtype=need_flatten_params[0].dtype,
        shape=[np.sum(shape)],
        belong_to_optimizer=True,
    )

    flatten_grad = optimizer.helper.create_global_variable(
        name="flatten_grad",
        persistable=True,
        dtype=need_flatten_grads[0].dtype,
        shape=[np.sum(shape)],
        belong_to_optimizer=True,
    )

    flatten_param.stop_gradient = False
    # In the final state of the dynamic graph, the `coalesce_tensor` op
    # does not yet support passing its output back in as an input, so
    # _legacy_C_ops is used here for now.
    # `use_align` is set to False, which differs from the behavior under
    # static graphs. `use_align` can be set to True once the final-state
    # coalesce_tensor op (_C_ops) is called instead.
    paddle._legacy_C_ops.coalesce_tensor(
        need_flatten_params,
        need_flatten_params,
        flatten_param,
        "copy_data",
        True,
        "use_align",
        False,
        "dtype",
        need_flatten_params[0].dtype,
    )

    paddle._legacy_C_ops.coalesce_tensor(
        need_flatten_grads,
        need_flatten_grads,
        flatten_grad,
        "copy_data",
        True,
        "use_align",
        False,
        "dtype",
        need_flatten_grads[0].dtype,
    )
    return [(flatten_param, flatten_grad)]
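A rough usage sketch of the plugin on its own (the model and optimizer below are hypothetical; in practice the Trainer applies the plugin automatically when `flatten_param_grads` is enabled on an NPU device):

```python
import paddle

from paddlenlp.trainer.plugins.npu_plugin import npu_accelerate_plugin

# Hypothetical model/optimizer; the optimizer must be built from a flat
# parameter list (not parameter-group dicts), with no regularizer and at
# most ClipGradByGlobalNorm as the gradient clip.
model = paddle.nn.Linear(128, 10)
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
)

# Replaces optimizer.step with _optimizer_step_with_flatten_param_grads,
# which coalesces all params/grads into one flattened tensor pair per step.
npu_accelerate_plugin(optimizer)

loss = model(paddle.randn([32, 128])).mean()
loss.backward()
optimizer.step()  # flattened update path
optimizer.clear_grad()
```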
5 changes: 5 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -594,6 +594,11 @@ def train(
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step

        if self.args.device == "npu" and self.args.flatten_param_grads:
            from .plugins.npu_plugin import npu_accelerate_plugin

            npu_accelerate_plugin(self.optimizer)

        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
                train_dataloader.batch_sampler, DistributedBatchSampler
9 changes: 9 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -263,6 +263,8 @@ class TrainingArguments:
            The path to a folder with a valid checkpoint for your model. This argument is not directly used by
            [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
            scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details.
        flatten_param_grads (`bool`, *optional*):
            Whether to use the flatten_param_grads method in the optimizer; only supported on NPU devices. Defaults to `False`.
    """

    output_dir: str = field(
@@ -496,6 +498,10 @@ class TrainingArguments:
    skip_memory_metrics: bool = field(
        default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
    )
    flatten_param_grads: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use the flatten_param_grads method in the optimizer; only supported on NPU devices."},
    )

    def __post_init__(self):
        env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -624,6 +630,9 @@ def __post_init__(self):
                "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
            )

        if self.flatten_param_grads and self.device != "npu":
            raise ValueError("flatten_param_grads can currently only be used on NPU devices.")

    def __str__(self):
        self_as_dict = asdict(self)
        self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
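A small sketch of the guard added in `__post_init__` above, assuming a non-NPU device and that no earlier validation intervenes (the field values are illustrative):

```python
from paddlenlp.trainer import TrainingArguments

# Hypothetical misconfiguration: requesting flatten_param_grads on GPU.
try:
    TrainingArguments(output_dir="./out", device="gpu", flatten_param_grads=True)
except ValueError as err:
    print(err)  # flatten_param_grads can currently only be used on NPU devices.
```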
