Fix gradient tensor mutation in {adam/ftrl/rmsprop/rmspropalex}_update. #15768

Merged 19 commits on Sep 5, 2019
279 changes: 150 additions & 129 deletions src/operator/optimizer_op-inl.h
@@ -1293,15 +1293,37 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
  }
};

struct AdamUpdateKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
    const DType clip_gradient, const DType rescale_grad,
    const DType beta1, const DType beta2,
    const DType lr, const DType wd,
    const DType epsilon, const OpReqType req) {
    using namespace mshadow_op;

    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
    if (clip_gradient >= 0.f) {
      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
    }

    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
    var_data[i] = beta2 * var_data[i] +
                  (1.f - beta2) * grad_rescaled * grad_rescaled;

    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] /
                  (square_root::Map(var_data[i]) + epsilon));
  }
};
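
For reference, this is the per-element update the kernel above computes, where the rescaled (and optionally clipped) gradient is kept in a local and no bias correction is applied inside the kernel:

```math
\begin{aligned}
\tilde g_i &= \mathrm{rescale\_grad}\cdot g_i + \mathrm{wd}\cdot w_i \\
m_i &\leftarrow \beta_1 m_i + (1-\beta_1)\,\tilde g_i \\
v_i &\leftarrow \beta_2 v_i + (1-\beta_2)\,\tilde g_i^{\,2} \\
w_i &\leftarrow w_i - \mathrm{lr}\,\frac{m_i}{\sqrt{v_i}+\epsilon}
\end{aligned}
```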

template<typename xpu>
inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mshadow_op;
  using namespace mxnet_op;
  const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
@@ -1311,22 +1333,12 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

    grad = scalar<DType>(param.rescale_grad) * grad +
           scalar<DType>(param.wd) * weight;

    if (param.clip_gradient >= 0.0f) {
      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
             F<clip>(grad, DType(param.clip_gradient));
      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
            F<clip>(grad, DType(param.clip_gradient)));
    } else {
      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
    }
    Assign(out, req[0],
           weight -
           scalar<DType>(param.lr) * mean /
           (F<square_root>(var) + scalar<DType>(param.epsilon)));
    Kernel<AdamUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
      static_cast<DType>(param.epsilon), req[0]);
  });
}
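
For intuition, here is a minimal CPU-only sketch of the same element-wise Adam step (not MXNet code; the name `adam_update_elem` and the plain `std::vector` interface are hypothetical). The point of the change is visible in the signature: the gradient is taken by const reference and the rescaled/clipped value lives in a local, so the caller's gradient buffer is never overwritten, unlike the removed expression-template code that assigned into `grad`.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical stand-alone equivalent of AdamUpdateKernel::Map applied over a whole array.
// `grad` is read-only: the rescaled gradient is kept in a local variable.
void adam_update_elem(std::vector<float>* weight, std::vector<float>* mean,
                      std::vector<float>* var, const std::vector<float>& grad,
                      float lr, float beta1, float beta2, float epsilon,
                      float wd, float rescale_grad, float clip_gradient) {
  for (std::size_t i = 0; i < weight->size(); ++i) {
    // g is the rescaled (and optionally clipped) gradient; grad[i] itself stays untouched.
    float g = grad[i] * rescale_grad + (*weight)[i] * wd;
    if (clip_gradient >= 0.f) {
      g = std::max(-clip_gradient, std::min(g, clip_gradient));
    }
    (*mean)[i] = beta1 * (*mean)[i] + (1.f - beta1) * g;
    (*var)[i] = beta2 * (*var)[i] + (1.f - beta2) * g * g;
    (*weight)[i] -= lr * (*mean)[i] / (std::sqrt((*var)[i]) + epsilon);
  }
}
```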

@@ -1596,57 +1608,64 @@ struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> {
  }
};

struct RMSPropAlexUpdateKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
    DType* state_n_data, DType* state_g_data, DType* delta_data,
    const DType* weight_data, const DType* grad_data,
    const DType clip_gradient, const DType rescale_grad,
    const DType gamma1, const DType gamma2,
    const DType lr, const DType wd,
    const DType clip_weights, const DType epsilon,
    const OpReqType req) {
    using namespace mshadow_op;

    DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
    if (clip_gradient >= 0.0f) {
      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
    }

    state_n_data[i] = (1.f - gamma1) * grad_rescaled * grad_rescaled +
                      gamma1 * state_n_data[i];
    state_g_data[i] = (1.f - gamma1) * grad_rescaled +
                      gamma1 * state_g_data[i];
    delta_data[i] = gamma2 * delta_data[i] -
                    (lr * (grad_rescaled) /
                     (square_root::Map(state_n_data[i] -
                                       state_g_data[i] * state_g_data[i] + epsilon)));

    if (clip_weights >= 0.0f) {
      const DType clipped_weight = clip::Map(weight_data[i] + delta_data[i], clip_weights);
      KERNEL_ASSIGN(out_data[i], req, clipped_weight);
    } else {
      KERNEL_ASSIGN(out_data[i], req, weight_data[i] + delta_data[i]);
    }
  }
};
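
For reference, the per-element update the kernel above computes (the centered variant with an extra first-moment state `state_g` and a delta term), where \(\tilde g_i = \mathrm{rescale\_grad}\cdot g_i + \mathrm{wd}\cdot w_i\), optionally clipped, and the final weight is optionally clipped to \(\pm\mathrm{clip\_weights}\):

```math
\begin{aligned}
n_i &\leftarrow (1-\gamma_1)\,\tilde g_i^{\,2} + \gamma_1 n_i \\
g_i &\leftarrow (1-\gamma_1)\,\tilde g_i + \gamma_1 g_i \\
\Delta_i &\leftarrow \gamma_2 \Delta_i - \mathrm{lr}\,\frac{\tilde g_i}{\sqrt{n_i - g_i^2 + \epsilon}} \\
w_i &\leftarrow w_i + \Delta_i
\end{aligned}
```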

template <typename xpu>
inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
                              const OpContext &ctx,
                              const std::vector<TBlob> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mshadow_op;
  using namespace mxnet_op;
  const RMSPropAlexParam &param = nnvm::get<RMSPropAlexParam>(attrs.parsed);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> state_g = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> delta = inputs[4].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

    grad = scalar<DType>(param.rescale_grad) * grad +
           scalar<DType>(param.wd) * weight;

    if (param.clip_gradient >= 0.0f) {
      state_n = scalar<DType>(1.f - param.gamma1) *
                F<clip>(grad, DType(param.clip_gradient)) *
                F<clip>(grad, DType(param.clip_gradient)) +
                scalar<DType>(param.gamma1) * state_n;
      state_g = scalar<DType>(1.f - param.gamma1) *
                F<clip>(grad, DType(param.clip_gradient)) +
                scalar<DType>(param.gamma1) * state_g;
      delta = scalar<DType>(param.gamma2) * delta -
              scalar<DType>(param.lr) *
              (F<clip>(grad, DType(param.clip_gradient)) /
               (F<square_root>(state_n - state_g * state_g +
                               scalar<DType>(param.epsilon))));
    } else {
      state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
                scalar<DType>(param.gamma1) * state_n;
      state_g = scalar<DType>(1.f - param.gamma1) * grad +
                scalar<DType>(param.gamma1) * state_g;
      delta = scalar<DType>(param.gamma2) * delta -
              scalar<DType>(param.lr) *
              (grad / (F<square_root>(state_n - state_g * state_g +
                                      scalar<DType>(param.epsilon))));
    }
    DType* weight_data = inputs[0].dptr<DType>();
    DType* grad_data = inputs[1].dptr<DType>();
    DType* state_n_data = inputs[2].dptr<DType>();
    DType* state_g_data = inputs[3].dptr<DType>();
    DType* delta_data = inputs[4].dptr<DType>();
    DType* out_data = outputs[0].dptr<DType>();

    if (param.clip_weights >= 0.0f) {
      Assign(out, req[0], F<clip>(weight + delta, DType(param.clip_weights)));
    } else {
      Assign(out, req[0], weight + delta);
    }
    Kernel<RMSPropAlexUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
      out_data, state_n_data, state_g_data, delta_data, weight_data, grad_data,
      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
      static_cast<DType>(param.gamma1), static_cast<DType>(param.gamma2),
      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
      static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
  });
}

@@ -1688,64 +1707,52 @@ struct RMSPropParam : public dmlc::Parameter<RMSPropParam> {
  }
};

struct RMSPropUpdateKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i,
    DType* out_data, DType* state_n_data,
    const DType* weight_data, const DType* grad_data,
    const DType clip_gradient, const DType rescale_grad,
    const DType gamma1, const DType lr, const DType wd,
    const DType clip_weights, const DType epsilon,
    const OpReqType req) {
    using namespace mshadow_op;

    DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
    if (clip_gradient >= 0.0f) {
      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
    }

    state_n_data[i] = (1.f - gamma1) * (grad_rescaled * grad_rescaled) + gamma1 * state_n_data[i];

    DType weight = weight_data[i] -
                   lr * (grad_rescaled / square_root::Map(state_n_data[i] + epsilon));
    if (clip_weights >= 0.0f) {
      weight = clip::Map(weight, clip_weights);
    }
    KERNEL_ASSIGN(out_data[i], req, weight);
  }
};
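
For reference, the per-element update the kernel above computes, where \(\tilde g_i = \mathrm{rescale\_grad}\cdot g_i + \mathrm{wd}\cdot w_i\), optionally clipped, and the resulting weight is optionally clipped to \(\pm\mathrm{clip\_weights}\):

```math
\begin{aligned}
n_i &\leftarrow (1-\gamma_1)\,\tilde g_i^{\,2} + \gamma_1 n_i \\
w_i &\leftarrow w_i - \mathrm{lr}\,\frac{\tilde g_i}{\sqrt{n_i + \epsilon}}
\end{aligned}
```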

template <typename xpu>
inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                          const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mshadow_op;
  using namespace mxnet_op;
  const RMSPropParam &param = nnvm::get<RMSPropParam>(attrs.parsed);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
    DType* weight_data = inputs[0].dptr<DType>();
    DType* grad_data = inputs[1].dptr<DType>();
    DType* state_n_data = inputs[2].dptr<DType>();
    DType* out_data = outputs[0].dptr<DType>();

    grad = scalar<DType>(param.rescale_grad) * grad +
           scalar<DType>(param.wd) * weight;

    if (param.clip_gradient >= 0.0f) {
      state_n = scalar<DType>(1.f - param.gamma1) *
                F<clip>(grad, DType(param.clip_gradient)) *
                F<clip>(grad, DType(param.clip_gradient)) +
                scalar<DType>(param.gamma1) * state_n;
      if (param.clip_weights >= 0.0f) {
        Assign(out, req[0],
               F<clip>(weight -
                       scalar<DType>(param.lr) *
                       (F<clip>(grad, DType(param.clip_gradient)) /
                        (F<square_root>(state_n +
                                        scalar<DType>(param.epsilon)))),
                       DType(param.clip_weights)));
      } else {
        Assign(out, req[0], weight -
               scalar<DType>(param.lr) *
               (F<clip>(grad, DType(param.clip_gradient)) /
                (F<square_root>(state_n +
                                scalar<DType>(param.epsilon)))));
      }
    } else {
      state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
                scalar<DType>(param.gamma1) * state_n;
      if (param.clip_weights >= 0.0f) {
        Assign(out, req[0],
               F<clip>(weight -
                       scalar<DType>(param.lr) *
                       (grad /
                        (F<square_root>(state_n +
                                        scalar<DType>(param.epsilon)))),
                       DType(param.clip_weights)));
      } else {
        Assign(out, req[0], weight -
               scalar<DType>(param.lr) *
               (grad /
                (F<square_root>(state_n +
                                scalar<DType>(param.epsilon)))));
      }
    }
    Kernel<RMSPropUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
      out_data, state_n_data, weight_data, grad_data,
      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
      static_cast<DType>(param.gamma1), static_cast<DType>(param.lr), static_cast<DType>(param.wd),
      static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
  });
}

@@ -1781,15 +1788,41 @@ struct FtrlParam : public dmlc::Parameter<FtrlParam> {
  }
};

struct FtrlUpdateKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
    DType* n_data, DType* z_data, const DType* weight_data, const DType* grad_data,
    const DType clip_gradient, const DType rescale_grad,
    const DType beta, const DType lamda1,
    const DType lr, const DType wd,
    const OpReqType req) {
    using namespace mshadow_op;

    DType grad_rescaled = grad_data[i] * rescale_grad;
    if (clip_gradient >= 0.0f) {
      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
    }

    z_data[i] += grad_rescaled - (square_root::Map(n_data[i] +
                 square::Map(grad_rescaled)) - square_root::Map(n_data[i])) *
                 weight_data[i] / lr;
    n_data[i] += square::Map(grad_rescaled);

    KERNEL_ASSIGN(out_data[i], req,
                  (sign::Map(z_data[i]) * lamda1 - z_data[i]) /
                  ((beta + square_root::Map(n_data[i])) / lr + wd) *
                  gt::Map(abs::Map(z_data[i]), lamda1));
  }
};
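
For reference, the per-element FTRL update the kernel above computes, where \(\tilde g_i = \mathrm{rescale\_grad}\cdot g_i\), optionally clipped (unlike the other kernels, `wd` enters only the denominator of the weight formula):

```math
\begin{aligned}
z_i &\leftarrow z_i + \tilde g_i - \frac{\sqrt{n_i + \tilde g_i^{\,2}} - \sqrt{n_i}}{\mathrm{lr}}\, w_i \\
n_i &\leftarrow n_i + \tilde g_i^{\,2} \\
w_i &\leftarrow \frac{\operatorname{sign}(z_i)\,\lambda_1 - z_i}{(\beta + \sqrt{n_i})/\mathrm{lr} + \mathrm{wd}}\;\mathbf{1}\!\left[\,|z_i| > \lambda_1\,\right]
\end{aligned}
```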

template<typename xpu>
inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mshadow_op;
  using namespace mxnet_op;

  const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
@@ -1799,23 +1832,11 @@ inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
    Tensor<xpu, 2, DType> n = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

    grad = scalar<DType>(param.rescale_grad) * grad;

    if (param.clip_gradient >= 0.0f) {
      z += F<clip>(grad, DType(param.clip_gradient)) - (F<square_root>(n +
           F<square>(F<clip>(grad, DType(param.clip_gradient)))) - F<square_root>(n)) *
           weight / scalar<DType>(param.lr);
      n += F<square>(F<clip>(grad, DType(param.clip_gradient)));
    } else {
      z += grad - (F<square_root>(n + F<square>(grad)) - F<square_root>(n)) *
           weight / scalar<DType>(param.lr);
      n += F<square>(grad);
    }
    Assign(out, req[0],
           (F<sign>(z) * scalar<DType>(param.lamda1) - z) /
           ((scalar<DType>(param.beta) + F<square_root>(n)) /
            scalar<DType>(param.lr) + scalar<DType>(param.wd)) *
           F<gt>(F<abs>(z), scalar<DType>(param.lamda1)));
    Kernel<FtrlUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
      out.dptr_, n.dptr_, z.dptr_, weight.dptr_, grad.dptr_,
      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
      static_cast<DType>(param.beta), static_cast<DType>(param.lamda1),
      static_cast<DType>(param.lr), static_cast<DType>(param.wd), req[0]);
  });
}
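
As with the Adam sketch above, here is a minimal CPU-only illustration of the same element-wise FTRL step (not MXNet code; `ftrl_update_elem` and its `std::vector` interface are hypothetical). Again the gradient is only read, never written:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical stand-alone equivalent of FtrlUpdateKernel::Map applied over a whole array.
void ftrl_update_elem(std::vector<float>* weight, std::vector<float>* n,
                      std::vector<float>* z, const std::vector<float>& grad,
                      float lr, float lamda1, float beta, float wd,
                      float rescale_grad, float clip_gradient) {
  for (std::size_t i = 0; i < weight->size(); ++i) {
    float g = grad[i] * rescale_grad;  // grad[i] itself stays untouched
    if (clip_gradient >= 0.f) {
      g = std::max(-clip_gradient, std::min(g, clip_gradient));
    }
    (*z)[i] += g - (std::sqrt((*n)[i] + g * g) - std::sqrt((*n)[i])) * (*weight)[i] / lr;
    (*n)[i] += g * g;
    const float sign_z = (*z)[i] > 0.f ? 1.f : ((*z)[i] < 0.f ? -1.f : 0.f);
    (*weight)[i] = (std::fabs((*z)[i]) > lamda1)
        ? (sign_z * lamda1 - (*z)[i]) / ((beta + std::sqrt((*n)[i])) / lr + wd)
        : 0.f;
  }
}
```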
