Commit

fix error
JamesLim-sy committed Oct 9, 2021
1 parent 0f924da commit 0deaa40
Showing 4 changed files with 12 additions and 13 deletions.
13 changes: 6 additions & 7 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -279,13 +279,12 @@ __global__ void MomentumLarsKernel(
                       rescale_grad, gridDim.x, &param_norm, &grad_norm);
 #else
   const MT rescale_grad_pow = rescale_grad * rescale_grad;
-  MT param_parital_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
-  MT grad_parital_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
+  MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
+  MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
   __syncthreads();
-  MT param_norm =
-      Sqrt(math::blockReduceSum<MT>(param_parital_norm, FINAL_MASK));
+  MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
   MT grad_norm = Sqrt(rescale_grad_pow *
-                      math::blockReduceSum<MT>(grad_parital_norm, FINAL_MASK));
+                      math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
 #endif

   const MT lr = learning_rate[0];
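
Note: as a reading aid, here is a minimal NumPy sketch of the two norms this block produces (an illustrative helper, not part of the patch; param and grad stand for the flattened tensors the kernel reduces over):

import numpy as np

def lars_norms(param, grad, rescale_grad=1.0):
    # L2 norm of the parameter: the blockReduceSum over param_part_norm
    # followed by Sqrt in the kernel above.
    param_norm = np.sqrt(np.square(param).sum())
    # The kernel multiplies the summed squares by rescale_grad_pow
    # (rescale_grad squared) before the square root, which is the same as
    # scaling the gradient norm by rescale_grad.
    grad_norm = np.sqrt((rescale_grad ** 2) * np.square(grad).sum())
    return param_norm, grad_norm
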
@@ -499,9 +498,9 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
     MT* master_param_out_data = nullptr;

     if (multi_precision) {
-      auto master_param = ctx.MultiInput<framework::Tensor>("MasterParam");
+      auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
       auto master_param_out =
-          ctx.MultiOutput<framework::Tensor>("MasterParamOut");
+          ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
       master_param_data = master_param[0]->data<MT>();
       master_param_out_data =
           master_param_out[0]->mutable_data<MT>(ctx.GetPlace());
2 changes: 1 addition & 1 deletion paddle/fluid/operators/optimizers/lars_momentum_op.h
@@ -43,7 +43,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {

     T mu = static_cast<T>(ctx.Attr<float>("mu"));
     T lars_coeff = ctx.Attr<float>("lars_coeff");
-    T lars_weight_decay = (ctx.Attr<std::vector<float>>("lars_weight_decay"))[0];
+    T lars_weight_decay = ctx.Attr<std::vector<float>>("lars_weight_decay")[0];
     T epsilon = ctx.Attr<float>("epsilon");

     auto p_out = framework::EigenVector<T>::Flatten(*(param_out[0]));
4 changes: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ def net(self, main_prog, startup_prog):
         strategy.lars = True
         strategy.lars_configs = {
             "lars_coeff": 0.001,
-            "lars_weight_decay": [0.0005],
+            "lars_weight_decay": 0.0005,
             "epsilon": 0,
             "exclude_from_weight_decay": ["batch_norm", ".b"],
         }
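
Note: for context, a minimal sketch of how this config reads after the change, assuming the standard paddle.distributed.fleet DistributedStrategy API (the rest of the test's network and optimizer wiring is not shown in this hunk):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.lars = True
strategy.lars_configs = {
    "lars_coeff": 0.001,
    # Now a plain float rather than the one-element list [0.0005].
    "lars_weight_decay": 0.0005,
    "epsilon": 0,
    "exclude_from_weight_decay": ["batch_norm", ".b"],
}
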
@@ -134,7 +134,7 @@ def test_lars_apply_with_amp(self):
         strategy.lars = True
         strategy.lars_configs = {
             "lars_coeff": 0.001,
-            "lars_weight_decay": [0.0005],
+            "lars_weight_decay": 0.0005,
             "epsilon": 0,
             "exclude_from_weight_decay": ["batch_norm", ".b"],
         }
6 changes: 3 additions & 3 deletions python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -286,9 +286,9 @@ def setUp(self):
         grads = []
         velocitys = []
         learning_rates = []
-        master_params = []
         param_outs = []
         velocity_outs = []
+        master_params = []
         master_param_outs = []
         for i in range(self.params_num):
             master_param = np.random.random((123, 321)).astype("float32")
@@ -376,8 +376,8 @@ def setUp(self):
             gnorm = np.sqrt(np.square(grad).sum())
             local_lr = learning_rate * lars_coeff * pnorm / (
                 gnorm + lars_weight_decay[i] * param)
-            velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay[i]
-                                                       * param)
+            velocity_out = mu * velocity + local_lr * (
+                grad + lars_weight_decay[i] * param)
             param_out = param - velocity_out

             params.append(("SubParam_" + str(i), param))
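
Note: pulled out of the test for readability, the reference LARS step being checked looks like the following sketch; it mirrors the computation above, including its use of param (rather than its norm) in the denominator, and the function name is illustrative:

import numpy as np

def lars_reference_step(param, grad, velocity, learning_rate,
                        mu, lars_coeff, lars_weight_decay):
    pnorm = np.sqrt(np.square(param).sum())
    gnorm = np.sqrt(np.square(grad).sum())
    # Layer-wise local learning rate, as in the test's reference code.
    local_lr = learning_rate * lars_coeff * pnorm / (
        gnorm + lars_weight_decay * param)
    velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay * param)
    param_out = param - velocity_out
    return param_out, velocity_out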
