fix lamb beta1pow beta2pow update
sneaxiy committed Dec 28, 2021
1 parent 404a4a6 commit c7145c2
Showing 1 changed file with 108 additions and 72 deletions.
180 changes: 108 additions & 72 deletions paddle/fluid/operators/optimizers/lamb_op.h
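
This commit changes where the LAMB optimizer advances beta1_pow and beta2_pow. As the removed blocks in the diff below show, each of the four moment-update functors (LambMomentREGUpdateFunctor, LambMomentMENUpdateFunctor, and their Sparse counterparts) previously carried beta1_pow_out_/beta2_pow_out_ pointers and rewrote the powers at the end of its per-element operator(), so the write was repeated for every element rather than performed once per step (the guard in the sparse functors also checked beta1_pow_out_ twice instead of beta1_pow_out_ and beta2_pow_out_). The commit removes those outputs from the moment functors and adds LambBetaPowUpdateFunctor, a small base class mixed into LambParamUpateFunctor: when the powers were not already advanced on the host, the parameter-update functor advances them exactly once, from element 0. The code that follows is a minimal standalone sketch of that pattern, not the Paddle code; BetaPowUpdater, ParamUpdate, and the names in main are hypothetical stand-ins.

#include <cstddef>
#include <cstdio>

// Advances beta1_pow/beta2_pow once per optimizer step: only element 0 writes.
template <typename T, bool NeedUpdateBetaPow>
struct BetaPowUpdater {
  void SetBetaPows(const T* beta1_pow, const T* beta2_pow, T* beta1_pow_out,
                   T* beta2_pow_out, T beta1, T beta2) {
    beta1_pow_ = beta1_pow;
    beta2_pow_ = beta2_pow;
    beta1_pow_out_ = beta1_pow_out;
    beta2_pow_out_ = beta2_pow_out;
    beta1_ = beta1;
    beta2_ = beta2;
  }
  void UpdateBetaPow(std::size_t i) const {
    if (i == 0) {  // exactly one element performs the write
      beta1_pow_out_[0] = beta1_pow_[0] * beta1_;
      beta2_pow_out_[0] = beta2_pow_[0] * beta2_;
    }
  }

 private:
  const T* beta1_pow_ = nullptr;
  const T* beta2_pow_ = nullptr;
  T* beta1_pow_out_ = nullptr;
  T* beta2_pow_out_ = nullptr;
  T beta1_ = 0;
  T beta2_ = 0;
};

// No-op specialization, selected when the powers were already advanced on the
// host (the CPU-resident beta pow path in the kernel below).
template <typename T>
struct BetaPowUpdater<T, false> {
  void SetBetaPows(const T*, const T*, T*, T*, T, T) {}
  void UpdateBetaPow(std::size_t) const {}
};

// Hypothetical per-element parameter update that mixes the updater in,
// mirroring how LambParamUpateFunctor calls this->UpdateBetaPow(i).
template <typename T, bool NeedUpdateBetaPow>
struct ParamUpdate : public BetaPowUpdater<T, NeedUpdateBetaPow> {
  T* param_ = nullptr;
  const T* step_ = nullptr;
  void operator()(std::size_t i) const {
    param_[i] -= step_[i];
    this->UpdateBetaPow(i);  // no-op unless NeedUpdateBetaPow and i == 0
  }
};

int main() {
  const std::size_t n = 4;
  float param[n] = {1.f, 2.f, 3.f, 4.f};
  float step[n] = {0.1f, 0.1f, 0.1f, 0.1f};
  float b1p = 0.9f, b2p = 0.999f, b1p_out = 0.f, b2p_out = 0.f;

  ParamUpdate<float, true> update;
  update.param_ = param;
  update.step_ = step;
  update.SetBetaPows(&b1p, &b2p, &b1p_out, &b2p_out, 0.9f, 0.999f);
  for (std::size_t i = 0; i < n; ++i) update(i);  // powers advance only at i == 0

  std::printf("beta1_pow %.3f -> %.4f, beta2_pow %.3f -> %.6f\n",
              b1p, b1p_out, b2p, b2p_out);
  return 0;
}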
@@ -52,19 +52,16 @@ struct LambMomentREGUpdateFunctor {
const bool* skip_update_;

LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
MT beta1_pow, MT* beta1_pow_out, MT beta2_pow,
MT* beta2_pow_out, const MT* mom1, MT* mom1_out,
const MT* mom2, MT* mom2_out, const T* grad,
const MT* param, MT* trust_ratio_div,
const bool* skip_update)
MT beta1_pow, MT beta2_pow, const MT* mom1,
MT* mom1_out, const MT* mom2, MT* mom2_out,
const T* grad, const MT* param,
MT* trust_ratio_div, const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -95,10 +92,6 @@ struct LambMomentREGUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};

@@ -113,9 +106,7 @@ struct LambMomentMENUpdateFunctor {
MT epsilon_;

const MT* beta1_pow_;
MT* beta1_pow_out_;
const MT* beta2_pow_;
MT* beta2_pow_out_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
@@ -126,8 +117,7 @@ struct LambMomentMENUpdateFunctor {
const bool* skip_update_;

LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
const MT* beta1_pow, MT* beta1_pow_out,
const MT* beta2_pow, MT* beta2_pow_out,
const MT* beta1_pow, const MT* beta2_pow,
const MT* mom1, MT* mom1_out, const MT* mom2,
MT* mom2_out, const T* grad, const MT* param,
MT* trust_ratio_div, const bool* skip_update)
@@ -136,9 +126,7 @@ struct LambMomentMENUpdateFunctor {
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -168,10 +156,6 @@ struct LambMomentMENUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};

@@ -183,9 +167,7 @@ struct SparseLambMomentREGUpdateFunctor {
T epsilon_;

T beta1_pow_;
T* beta1_pow_out_;
T beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
@@ -201,20 +183,18 @@ struct SparseLambMomentREGUpdateFunctor {
const bool* skip_update_;

SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
T beta1_pow, T* beta1_pow_out, T beta2_pow,
T* beta2_pow_out, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* grad,
const T* param, T* trust_ratio_div,
const int64_t* rows, int64_t row_numel,
int64_t row_count, const bool* skip_update)
T beta1_pow, T beta2_pow, const T* mom1,
T* mom1_out, const T* mom2, T* mom2_out,
const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
int64_t row_numel, int64_t row_count,
const bool* skip_update)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -246,10 +226,6 @@ struct SparseLambMomentREGUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}

inline HOSTDEVICE void operator()(size_t i) const {
@@ -270,9 +246,7 @@ struct SparseLambMomentMENUpdateFunctor {
T epsilon_;

const T* beta1_pow_;
T* beta1_pow_out_;
const T* beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
@@ -288,8 +262,7 @@ struct SparseLambMomentMENUpdateFunctor {
const bool* skip_update_;

SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
const T* beta1_pow, T* beta1_pow_out,
const T* beta2_pow, T* beta2_pow_out,
const T* beta1_pow, const T* beta2_pow,
const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
@@ -300,9 +273,7 @@ struct SparseLambMomentMENUpdateFunctor {
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
@@ -334,10 +305,6 @@ struct SparseLambMomentMENUpdateFunctor {
trust_ratio_div_[i] =
mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
weight_decay_ * p;
if (beta1_pow_out_ && beta1_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}

inline HOSTDEVICE void operator()(size_t i) const {
@@ -350,11 +317,44 @@ struct SparseLambMomentMENUpdateFunctor {
}
};

template <typename T, bool IsMultiPrecision>
struct LambParamUpateFunctor {
using MT = typename std::conditional<
IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
template <typename MT, bool NeedUpdateBetaPow /*=true*/>
struct LambBetaPowUpdateFunctor {
void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
MT* beta2pow_out, MT beta1, MT beta2) {
beta1pow_ = beta1pow;
beta2pow_ = beta2pow;
beta1pow_out_ = beta1pow_out;
beta2pow_out_ = beta2pow_out;
beta1_ = beta1;
beta2_ = beta2;
}

HOSTDEVICE void UpdateBetaPow(size_t i) const {
if (i == 0) {
beta1pow_out_[0] = beta1pow_[0] * beta1_;
beta2pow_out_[0] = beta2pow_[0] * beta2_;
}
}

private:
const MT* beta1pow_;
const MT* beta2pow_;
MT* beta1pow_out_;
MT* beta2pow_out_;
MT beta1_;
MT beta2_;
};

template <typename MT>
struct LambBetaPowUpdateFunctor<MT, /*NeedUpdateBetaPow=*/false> {
void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
MT* beta2pow_out, MT beta1, MT beta2) {}
HOSTDEVICE void UpdateBetaPow(size_t) const {}
};

template <typename T, typename MT, bool IsMultiPrecision, bool UpdateBetaPow>
struct LambParamUpateFunctor
: public LambBetaPowUpdateFunctor<MT, UpdateBetaPow> {
const MT* lr_;
const T* param_;
const MT* master_param_;
@@ -396,6 +396,7 @@ struct LambParamUpateFunctor {
if (IsMultiPrecision) {
master_param_out_[i] = param_out;
}
this->UpdateBetaPow(i);
}
};

@@ -501,15 +502,19 @@ class LambOpKernel : public framework::OpKernel<T> {
: nullptr;

// Update moments
bool should_update_beta_pow_later = false;
const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
<< " , Beta2Pow place: " << beta2_pow.place();
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = grad_var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<MT>(),
nullptr, *beta2_pow.template data<MT>(), nullptr,
mom1.template data<MT>(),
*beta2_pow.template data<MT>(), mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -523,12 +528,17 @@ class LambOpKernel : public framework::OpKernel<T> {
beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<MT>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<MT>(),
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace()),
beta2_pow.template data<MT>(),
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace()),
mom1.template data<MT>(),
weight_decay, beta1, beta2, epsilon,
static_cast<const MT*>(beta1_pow_ptr),
static_cast<const MT*>(beta2_pow_ptr), mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -542,7 +552,12 @@ class LambOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(IsMultiPrecision, false,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True"));
"multi_precision=True."));
constexpr bool kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(kIsSameType, true,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True."));
auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
"Input", "Grad", "Lamb");
if (grad.rows().size() == 0) {
@@ -582,8 +597,8 @@ class LambOpKernel : public framework::OpKernel<T> {
SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
*beta1_pow.template data<T>(), nullptr,
*beta2_pow.template data<T>(), nullptr, mom1.template data<T>(),
*beta1_pow.template data<T>(), *beta2_pow.template data<T>(),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -595,14 +610,18 @@ class LambOpKernel : public framework::OpKernel<T> {
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay), static_cast<T>(beta1),
static_cast<T>(beta2), static_cast<T>(epsilon),
beta1_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow.template data<T>(),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
mom1.template data<T>(),
reinterpret_cast<const T*>(beta1_pow_ptr),
reinterpret_cast<const T*>(beta2_pow_ptr), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -639,14 +658,31 @@ class LambOpKernel : public framework::OpKernel<T> {
}
trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();

LambParamUpateFunctor<T, IsMultiPrecision> param_update_functor(
lr.template data<MT>(), static_cast<const T*>(param_ptr),
static_cast<const MT*>(master_param_ptr), p_norm_t.template data<MT>(),
trust_ratio_div.template data<MT>(),
trust_ratio_div_norm_t.template data<MT>(),
static_cast<T*>(param_out_ptr), static_cast<MT*>(master_param_out_ptr),
skip_update_flag);
for_range(param_update_functor);
#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \
do { \
LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
param_update_functor( \
lr.template data<MT>(), static_cast<const T*>(param_ptr), \
static_cast<const MT*>(master_param_ptr), \
p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(), \
trust_ratio_div_norm_t.template data<MT>(), \
static_cast<T*>(param_out_ptr), \
static_cast<MT*>(master_param_out_ptr), skip_update_flag); \
if (__should_update_beta_pow) { \
param_update_functor.SetBetaPows(beta1_pow_ptr, beta2_pow_ptr, \
beta1_pow_out_ptr, beta2_pow_out_ptr, \
beta1, beta2); \
} \
for_range(param_update_functor); \
} while (0)

if (should_update_beta_pow_later) {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
} else {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
}

#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
}
};
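
The CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC macro above exists only to turn the runtime flag should_update_beta_pow_later into the compile-time UpdateBetaPow template argument without spelling out the long constructor call twice. For readers unfamiliar with the idiom, the same runtime-to-compile-time dispatch is sketched below without a macro; Dispatch, ToyLauncher, and Run are hypothetical stand-ins and are not part of lamb_op.h.

#include <cstdio>

// Toy stand-in for the launch step: in the real kernel, Run<kFlag>() would
// build LambParamUpateFunctor<T, MT, IsMultiPrecision, kFlag> and hand it to
// for_range.
struct ToyLauncher {
  template <bool kUpdateBetaPow>
  void Run() const {
    std::printf("functor also advances beta pows: %s\n",
                kUpdateBetaPow ? "yes" : "no");
  }
};

// Maps a runtime bool to a compile-time template argument; the if/else is the
// only runtime cost, and each branch instantiates a different functor type.
template <typename Launcher>
void Dispatch(bool update_beta_pow_in_functor, const Launcher& launcher) {
  if (update_beta_pow_in_functor) {
    launcher.template Run<true>();
  } else {
    launcher.template Run<false>();  // beta pows already updated on the host
  }
}

int main() {
  ToyLauncher launcher;
  Dispatch(true, launcher);   // corresponds to should_update_beta_pow_later == true
  Dispatch(false, launcher);
  return 0;
}

Either form keeps the no-op LambBetaPowUpdateFunctor specialization zero-cost; the macro form presumably just avoids duplicating the many-argument constructor call in code that predates C++17 if constexpr.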


1 comment on commit c7145c2

@paddle-bot-old

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
