Commit

[opt] Add regularization and Nesterov for merged_momentum op (PaddlePaddle#37527)

* add regularization and Nesterov for merged_momentum

* refine unittest for use_nesterov attr

* refine op check

* refine code

* fix bug

* refine code of regularization_flag

* delete useless code
zhangbo9674 authored and Zjq9409 committed Dec 10, 2021
1 parent 56d4c1b commit c3f0ab0
Showing 4 changed files with 360 additions and 28 deletions.
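
For context, the per-parameter update the op now implements is the standard momentum rule with optional L2 decay and a Nesterov correction. A sketch of the semantics, matching Paddle's single-tensor momentum op (the notation here is editorial, not from the diff): with parameter p, gradient g, velocity v, momentum \mu, learning rate \eta, and \lambda = regularization_coeff,

g' = g + \lambda p    (regularization_method == "l2_decay"; otherwise g' = g)
v_{t+1} = \mu v_t + g'
p_{t+1} = p_t - \eta (g' + \mu v_{t+1})    (use_nesterov == true)
p_{t+1} = p_t - \eta v_{t+1}    (use_nesterov == false)

On the GPU path, g is additionally multiplied by rescale_grad before the update.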
15 changes: 14 additions & 1 deletion paddle/fluid/operators/optimizers/merged_momentum_op.cc
@@ -50,7 +50,8 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable();
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"Input learning rate");
"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
@@ -68,6 +69,18 @@ class MergedMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable()
.AsDuplicable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
"(bool, default false) "
"Use Nesterov Momentum or not.")
.SetDefault(false);
AddAttr<std::vector<std::string>>(
"regularization_method",
"(string) regularization_method, right now only "
"support l2decay or none")
.SetDefault({});
AddAttr<std::vector<float>>("regularization_coeff",
"(float) regularization_coeff")
.SetDefault({});
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
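The two new vector attributes follow an all-or-nothing convention that the kernel enforces (see merged_momentum_op.h below): either empty, or exactly one entry per Param. A minimal standalone sketch of that rule, with hypothetical names, not Paddle code:

#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical mirror of the kernel's rule: vector attributes are either
// empty (regularization disabled) or sized to the number of Param tensors.
void ValidatePerParamAttrs(size_t n,
                           const std::vector<std::string> &methods,
                           const std::vector<float> &coeffs) {
  if (methods.empty()) return;
  if (methods.size() != n || coeffs.size() != n)
    throw std::invalid_argument(
        "regularization_method/regularization_coeff need one entry per Param");
}

int main() {
  ValidatePerParamAttrs(2, {}, {});                               // defaults
  ValidatePerParamAttrs(2, {"l2_decay", "none"}, {1e-4f, 0.0f});  // per-param
  std::puts("attribute shapes accepted");
}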
174 changes: 148 additions & 26 deletions paddle/fluid/operators/optimizers/merged_momentum_op.h
@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"

@@ -85,33 +86,43 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
PADDLE_ENFORCE_EQ(n, params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(params[i], params_out[i],
platform::errors::InvalidArgument(
"Input(Param) and Output(ParamOut) must be "
"the same Tensors."));
}

auto grads = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(
n, grads.size(),
platform::errors::InvalidArgument(
"Input(Grad) number must be equal to Input(Param) number."));
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grads.size(), n));

auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
PADDLE_ENFORCE_EQ(n, velocitys.size(),
platform::errors::InvalidArgument(
"Input(Velocity) number and Input(Param) number."));
"The size of Input(Velocity) must be equal to "
"Input(Param), but got the size of Input(Velocity) "
"is %d, the size of Input(Param) is %d.",
velocitys.size(), n));

auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
PADDLE_ENFORCE_EQ(
n, velocitys_out.size(),
platform::errors::InvalidArgument("Output(VelocityOut) number must be "
"equal to Input(Param) number."));
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be "
"equal to Input(Param), but got the size of Output(VelocityOut) is "
"%d, the size of Input(Param) is %d.",
velocitys_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
platform::errors::InvalidArgument(
@@ -126,12 +137,18 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
platform::errors::InvalidArgument("Input(MasterParam) number must be "
"equal to Input(Param) number."));
PADDLE_ENFORCE_EQ(n, master_params_out.size(),
platform::errors::InvalidArgument(
"Output(MasterParamOut) number must be equal to "
"Input(MasterParam) number."));
platform::errors::InvalidArgument(
"The size of Input(MasterParam) must be "
"equal to Input(Param), but got the size of Input(MasterParam) "
"is %d, the size of Input(Param) is %d.",
master_params.size(), n));
PADDLE_ENFORCE_EQ(
n, master_params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(MasterParamOut) must be equal to "
"Input(MasterParam), but got the size of Output(MasterParamOut) "
"is %d, the size of Input(Param) is %d.",
master_params_out.size(), n));
for (size_t i = 0; i < n; ++i) {
PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
platform::errors::InvalidArgument(
@@ -147,20 +164,61 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
master_params_out.clear();
}

auto mu = ctx.Attr<float>("mu");
auto rescale_grad = ctx.Attr<float>("rescale_grad");
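// LearningRate is now duplicable: pass a single tensor shared by every
// parameter, or exactly one tensor per parameter (enforced below).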
auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
if (lrs.size() != 1) {
PADDLE_ENFORCE_EQ(
n, lrs.size(),
platform::errors::InvalidArgument(
"If the size of Input(LearningRate) is not 1, the size of "
"Input(LearningRate) must be "
"equal to Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lrs.size(), n));
}
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_methods =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeffs =
ctx.Attr<std::vector<float>>("regularization_coeff");
if (regularization_methods.size() != 0) {
PADDLE_ENFORCE_EQ(
n, regularization_methods.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_method) must be equal "
"to Input(Param), but got the size of "
"Attr(regularization_method) is %d, the size of Input(Param) is "
"%d.",
regularization_methods.size(), n));
PADDLE_ENFORCE_EQ(
n, regularization_coeffs.size(),
platform::errors::InvalidArgument(
"The size of Attr(regularization_coeff) must be equal "
"to Input(Param), but got the size of Attr(regularization_coeff) "
"is %d, the size of Input(Param) is %d.",
regularization_coeffs.size(), n));
}

VLOG(5) << "use_nesterov: " << use_nesterov
<< ", regularization_methods.size(): "
<< regularization_methods.size()
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();

using MPType = typename operators::details::MPTypeTrait<T>::Type;

auto &dev_ctx = ctx.template device_context<DeviceContext>();

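// Fast path: one shared LR, plain momentum, no regularization. All params
// go through the fused multi-tensor kernel in chunks of kMaxMergedNum.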
if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
@@ -182,14 +240,78 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}

if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
} else {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
}
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
} else {
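// Fallback path: per-parameter LRs, Nesterov, or regularization require one
// momentum functor launch per parameter.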
for (size_t idx = 0; idx < n; idx++) {
RegularizationType regularization_flag =
regularization_methods.size() > 0 &&
regularization_methods[idx] == "l2_decay"
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;

MPType regularization_coeff = static_cast<MPType>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff =
static_cast<MPType>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];

const MPType *master_in_data =
multi_precision ? master_params[idx]->data<MPType>() : nullptr;
MPType *master_out_data =
multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<MPType> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
use_nesterov, regularization_flag, regularization_coeff,
params_out[idx], velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
UseNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(UseNesterov,
RegularizationType::kNONE);
VLOG(10)
<< "Launch MergedMomentum gpu kernel use_nesterov kNONE.";
}
} else {
if (regularization_flag == RegularizationType::kL2DECAY) {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(
NoNesterov, RegularizationType::kL2DECAY);
VLOG(10)
<< "Launch MergedMomentum gpu kernel no_nesterov kL2DECAY.";
} else {
PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(NoNesterov,
RegularizationType::kNONE);
VLOG(10) << "Launch MergedMomentum gpu kernel no_nesterov kNONE.";
}
}
#undef PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL
}
}
VLOG(10)
<< "Launch MergedMomentum kernel with multi_lr and regularization.";
}
}
};

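Taken together, the reworked Compute reduces to a two-path dispatch: the pre-existing fused multi-tensor launch when nothing per-parameter is requested, and a new per-parameter fallback otherwise. A condensed, self-contained sketch of that shape (all names are hypothetical stand-ins for the Paddle types):

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for the inputs that drive the kernel's dispatch.
struct MergedMomentumConfig {
  size_t lr_count;                                  // # of LearningRate tensors
  bool use_nesterov;                                // "use_nesterov" attr
  std::vector<std::string> regularization_methods;  // "regularization_method"
};

void LaunchFused() { std::puts("fused launch in chunks of kMaxMergedNum"); }
void LaunchPerParam(size_t idx) { std::printf("per-param launch %zu\n", idx); }

void Dispatch(const MergedMomentumConfig &cfg, size_t n_params) {
  if (cfg.lr_count == 1 && !cfg.use_nesterov &&
      cfg.regularization_methods.empty()) {
    LaunchFused();  // fast path, unchanged behaviour
  } else {
    for (size_t idx = 0; idx < n_params; ++idx) {
      LaunchPerParam(idx);  // new fallback: per-param LR/Nesterov/decay
    }
  }
}

int main() {
  Dispatch({1, false, {}}, 3);                           // fused path
  Dispatch({3, true, {"l2_decay", "none", "none"}}, 3);  // fallback path
}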
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/op_function.h
@@ -827,7 +827,7 @@ GetVarBaseListFromArgs(const std::string& op_type, const std::string& arg_name,
bool dispensable = false) {
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);

if (list == nullptr || list == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
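The pybind change makes an explicitly passed Python None behave the same as an omitted argument for dispensable tensor-list parameters. A minimal standalone illustration with the CPython C API (nothing here is Paddle code):

#include <Python.h>
#include <cstdio>

// Both an empty slot and an explicit None should count as "argument absent"
// when the parameter is dispensable.
static bool IsAbsent(PyObject *obj) { return obj == nullptr || obj == Py_None; }

int main() {
  Py_Initialize();
  PyObject *args = PyTuple_New(1);
  Py_INCREF(Py_None);
  PyTuple_SET_ITEM(args, 0, Py_None);          // caller passed None explicitly
  PyObject *item = PyTuple_GET_ITEM(args, 0);  // borrowed reference
  std::printf("treated as absent: %s\n", IsAbsent(item) ? "yes" : "no");
  Py_DECREF(args);
  return Py_FinalizeEx();
}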
