
Merge all larsMomentum ops into one cuda kernel #35476

Merged 26 commits into PaddlePaddle:develop on Oct 13, 2021

Conversation

@JamesLim-sy (Contributor) commented Sep 5, 2021

PR types

Performance optimization

PR changes

OPs

Describe

  • Feature:
    As an extension of #35652, this PR merges all lars optimizer ops inside ResNet50 into a single cuda kernel, eliminating most of the __global__ kernel launch overhead that the lars optimizer causes in the model.

  • Optimization method
    Package all param, grad and attr inputs of the lars optimizer ops in the model, and extend the optimization method of #35652 with a single for loop inside the __global__ kernel (see the sketch after this list). In this way, the timeline of the lars optimizer in ResNet50 changes from Fig. 1 to Fig. 2, and the __global__ kernel launch overhead shrinks.

Fig. 1 Separated lars optimizer

Fig. 2 Packaged lars optimizer

  • Performance
    For the ResNet50 model, after merging all lars optimizer ops into just two cuda kernels (one for the amp lars op, one for the rest), training throughput increases from 3082 ips to 3114 ips, about 1.04% (measured on one cuda card).

  • Note:
    Currently, the number of merged lars ops is limited to no more than 150. This will be modified in the next PR.
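To make the merged-launch idea concrete, here is a minimal sketch. All names are hypothetical (MergedLarsParamSketch, MergedLarsKernelSketch, kMaxMergedOps), and it omits the LARS local-learning-rate computation from L2 norms that the real kernel performs; it only illustrates packing per-op pointers so one __global__ launch can loop over all merged ops:

// Illustrative sketch only, not the PR's actual API.
#include <cuda_runtime.h>

constexpr int kMaxMergedOps = 64;  // kept small so the by-value kernel argument
                                   // stays well under the 4 KB parameter limit

template <typename T>
struct MergedLarsParamSketch {
  const T* grad[kMaxMergedOps];
  T* param[kMaxMergedOps];
  T* velocity[kMaxMergedOps];
  T weight_decay[kMaxMergedOps];
  int numel[kMaxMergedOps];
  int op_num;
};

template <typename T>
__global__ void MergedLarsKernelSketch(MergedLarsParamSketch<T> cfg, T mu,
                                       T local_lr) {
  // One launch covers every packed op: the former per-op kernels become
  // iterations of this outer loop, removing the per-op launch overhead.
  // (The real kernel derives local_lr per op from parameter/gradient norms.)
  for (int op = 0; op < cfg.op_num; ++op) {
    const int n = cfg.numel[op];
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += gridDim.x * blockDim.x) {
      T g = cfg.grad[op][i] + cfg.weight_decay[op] * cfg.param[op][i];
      T v = cfg.velocity[op][i] * mu + local_lr * g;
      cfg.velocity[op][i] = v;
      cfg.param[op][i] -= v;
    }
  }
}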

paddle-bot-old (bot) commented Sep 5, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

(eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
}
v_out = v * mu + local_lr * (g + lars_weight_decay * p);
p_out = p - v_out;
JamesLim-sy (author):

The code on lines 46-75 largely duplicates the code on lines 79-107; the next commit will merge them.
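For reference, both duplicated blocks implement the same LARS momentum update shown in the excerpt above. Assuming the usual numerator lr · lars_coeff · ||p|| (only the denominator appears in the excerpt), the step is:

$$
\begin{aligned}
\text{local\_lr} &= \text{lr} \cdot \text{lars\_coeff} \cdot \frac{\lVert p \rVert_2}{\lVert g \rVert_2 + \text{lars\_weight\_decay} \cdot \lVert p \rVert_2 + \epsilon},\\
v_{\text{out}} &= \mu\, v + \text{local\_lr}\,(g + \text{lars\_weight\_decay} \cdot p),\\
p_{\text{out}} &= p - v_{\text{out}}.
\end{aligned}
$$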

AddAttr<std::vector<float>>(
"lars_weight_decay",
"(float, default 0.0005) Merged LARS weight decay params")
.SetDefault({0.0005});
JamesLim-sy (author):

Considering that once lars_op merging is enabled, the merged lars ops may each carry a different lars_weight_decay value, this attr is changed to a std::vector.

Contributor:

With this default, does it mean that when lars_weight_decay has length 1, all parameters share that single value? And when its length is not 1, shouldn't it be the same length as Param? Maybe add a check for that in InferShape?

JamesLim-sy (author):

I previously added this check in infer_shape:

auto lars_weight_decays =
    ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
PADDLE_ENFORCE_EQ(lars_weight_decays.size(), grad_dim.size(),
                  platform::errors::InvalidArgument(
                      "Lars_weight_decay and Grad input of LarsMomentumOp "
                      "should have the same quantity. But number of "
                      "Lars_weight_decay is [%d] and Grad is [%d].",
                      lars_weight_decays.size(), grad_dim.size()));

Version control got a bit messed up during development and this code was lost; the next commit will add it back.

*p_n = Sqrt(math::blockReduceSum<MT>(p_partial_sum, FINAL_MASK));
*g_n = Sqrt(rescale_grad_pow *
*p_n = sqrt(math::blockReduceSum<MT>(p_partial_sum, FINAL_MASK));
*g_n = sqrt(rescale_grad_pow *
JamesLim-sy (author):

This was a typo: it calls sqrt, but it should call the function predefined in this file:
__device__ __forceinline__ float Sqrt(float x) { return sqrtf(x); }
The next commit will change it back.

template <typename T, typename MT>
LARS_FUNCTION_FLAG void L2NormKernel(
__global__ void L2NormKernel(
#endif
JamesLim-sy (author):

Agreed, the way this is written is indeed not very elegant~~

Contributor:

Perhaps some of the cooperative_groups operations could be wrapped up, for example:

struct CooperativeGroups {
#if CUDA_VERSION >= 11000
 public:
  // Only meaningful on CUDA 11+, where cooperative_groups::this_grid() exists;
  // marked __device__ since it is meant to be used inside kernels.
  __device__ CooperativeGroups() : cg_(cooperative_groups::this_grid()) {}
  __device__ void Sync() { cg_.sync(); }

 private:
  cooperative_groups::grid_group cg_;
#endif
};

Then this function definition could be written as below; under CUDA 10.x, CooperativeGroups is an empty class, so it shouldn't add much overhead, right?

void L2NormKernel(const CooperativeGroups& cg, ...) {
    ...
}

JamesLim-sy (author):

Salute, will change it in the next commit as suggested.

@@ -38,6 +28,8 @@ limitations under the License. */
#define LARS_BLOCK_SIZE 512
#endif

#define LARS_MAX_MERGED_OPS 200

JamesLim-sy (author):

The maximum number of merged lars ops is set to 200, mainly to support the ResNet model, where the largest merged group of lars ops is 100+, so I simply wrote 200 here. This approach is rather crude, though: it imposes a hard cap on how many lars ops can be merged in a model, and once the threshold is exceeded an error is raised directly. A follow-up PR will fix this.

Collaborator:

Is it possible to support more than 200 ops, e.g. by launching the MergedLars kernel once for every 200 ops?
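For illustration, a minimal sketch of this chunked-launch idea, assuming a hypothetical LaunchMergedLarsKernel helper that packs and launches a contiguous slice of ops:

// Sketch only; LaunchMergedLarsKernel is a hypothetical helper, not the PR's real API.
#include <algorithm>

#define LARS_MAX_MERGED_OPS 200

// Packs ops [start, start + count) and launches one merged LARS kernel for
// that slice (implementation omitted in this sketch).
void LaunchMergedLarsKernel(int start, int count);

// Launch the merged kernel once per group of at most LARS_MAX_MERGED_OPS ops,
// instead of erroring out when the threshold is exceeded.
void LaunchInChunks(int total_op_num) {
  for (int start = 0; start < total_op_num; start += LARS_MAX_MERGED_OPS) {
    int count = std::min(LARS_MAX_MERGED_OPS, total_op_num - start);
    LaunchMergedLarsKernel(start, count);
  }
}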

SunNy820828449 previously approved these changes Oct 11, 2021

@SunNy820828449 (Contributor) left a comment:

LGTM

AddAttr<float>("lars_weight_decay",
"(float, default 0.0005) LARS weight decay")
.SetDefault(0.0005);
AddAttr<std::vector<float>>(
Collaborator:

I suspect this might affect save/load of checkpoints?

JamesLim-sy (author):

Per the discussion, it has no impact.


protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInputs("Param"), true,
Contributor:

This could also be checked with OP_INOUT_CHECK.
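For illustration, assuming OP_INOUT_CHECK follows its usual (condition, "Input"/"Output", variable name, op type) form in Paddle operators, the check might look like:

// Sketch only; mirrors how OP_INOUT_CHECK is typically used in Paddle ops.
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut", "LarsMomentum");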

JamesLim-sy (author):

Will change it in the next commit as suggested.

platform::errors::NotFound(
"Input(LearningRate) of LarsMomentum should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
Contributor:

Since Param is now a vector, checking only whether the first param is a LOD_TENSOR seems insufficient.

JamesLim-sy (author):

Then it feels like this check could be moved into InferShape.
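A sketch of what checking every Param entry could look like (illustrative only; it reuses the error style from the surrounding code):

// Sketch only: verify every entry of the duplicable "Param" input is a
// LOD_TENSOR instead of checking just the first one.
auto param_var_types = ctx->GetInputsVarType("Param");
for (size_t i = 0; i < param_var_types.size(); ++i) {
  PADDLE_ENFORCE_EQ(
      param_var_types[i], framework::proto::VarType::LOD_TENSOR,
      platform::errors::InvalidArgument(
          "The Param[%d] of LarsMomentumOp should be a LoDTensor.", i));
}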

"been initialized. You may need to confirm "
"whether exe.run(startup_program) is put "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
Contributor:

Given this check, isn't the check at line 57 redundant?

JamesLim-sy (author):

Yes, it is redundant; this part was copied from momentum_op, and that check was mindlessly copied over along with it.

"quantity. But number of Param is [%d] and Velocity is [%d].",
param_dim.size(), velocity_dim.size()));

if (ctx->GetInputsVarType("Grad")[0] ==
Contributor:

Is it possible for Grad not to be a LOD_TENSOR?

JamesLim-sy (author):

It will always be a LOD_TENSOR; this check can be removed.


if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
for (size_t i = 0; i < param_dim.size(); ++i) {
Contributor:

Put the for loop on the outside, and move the if on line 86 inside the for loop.
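A sketch of the suggested restructuring, assuming the if on line 86 is the Grad var-type check from the excerpt above (illustrative only):

// Sketch only: loop over every param and do the per-element check inside,
// rather than gating the whole loop on the first Grad entry's type.
for (size_t i = 0; i < param_dim.size(); ++i) {
  if (ctx->GetInputsVarType("Grad")[i] ==
      framework::proto::VarType::LOD_TENSOR) {
    // ... per-element dim checks from the original loop body ...
  }
}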

int tid = threadIdx.x + blockIdx.x * blockDim.x;
const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
for (int i = 0; i < op_num; ++i) {
int numel = merged_params->numel_arr[i];
Contributor:

This code can be reused; better to factor out one more function than to copy and paste multiple identical blocks.

JamesLim-sy (author):

OK, will change it as suggested.

grad[i]->data<T>(), learning_rate[i]->data<MT>(), p_buffer,
g_buffer, mu, lars_coeff, weight_decay_arr[i], epsilon,
rescale_grad, param[i]->numel());
}
Contributor:

The two for loops at lines 461-482 can also be merged into one.

auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
auto master_param_out =
    ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");

for (int i = 0; i < op_num; ++i) {
  const MT* master_param_ptr =
      multi_precision ? master_param[i]->data<MT>() : nullptr;
  MT* master_param_out_ptr =
      multi_precision ? master_param_out[i]->mutable_data<MT>(ctx.GetPlace())
                      : nullptr;
  SeparatedLarsMomentumOpCUDAKernel<T, MT>(
      cuda_ctx, param[i]->data<T>(),
      param_out[i]->mutable_data<T>(ctx.GetPlace()),
      velocity[i]->data<MT>(),
      velocity_out[i]->mutable_data<MT>(ctx.GetPlace()),
      grad[i]->data<T>(), learning_rate[i]->data<MT>(), p_buffer,
      g_buffer, mu, lars_coeff, weight_decay_arr[i], epsilon,
      rescale_grad, param[i]->numel(), master_param_ptr, master_param_out_ptr);
}

JamesLim-sy (author), Oct 11, 2021:

Actually I have a small question here: in non-AMP mode, wouldn't writing auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam"); directly outside the previous if (multi_precision) block raise an error?

cuda_ctx, param_data, param_out_data, velocity_data,
velocity_out_data, grad_data, lr, p_buffer, g_buffer, mu, lars_coeff,
lars_weight_decay, epsilon, rescale_grad, numel, master_param_data,
master_param_out_data);
Contributor:

This function now has 220+ lines of code; please keep each function under 100 lines, otherwise readability suffers.

T epsilon = ctx.Attr<float>("epsilon");

auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
if (!merge_operation) {
Contributor:

Is this the CPU kernel? Since there is no merged implementation here, shouldn't whether there is a for loop be decided by the number of Param inputs rather than by this attr?

JamesLim-sy (author):

Yes, it is the CPU kernel; will change it as suggested.
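A sketch of that suggested shape for the CPU path, where the loop bound comes from the number of Param inputs rather than a merge_operation attr (names and the Eigen-style update are assumptions based on the excerpt above):

// Sketch only: iterate over all Param inputs; a single op degenerates to one
// iteration, so no merge_operation attribute is needed.
auto params = ctx.MultiInput<framework::LoDTensor>("Param");
auto param_outs = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
for (size_t i = 0; i < params.size(); ++i) {
  auto p = framework::EigenVector<T>::Flatten(*params[i]);
  auto p_out = framework::EigenVector<T>::Flatten(*param_outs[i]);
  // ... apply the same per-tensor LARS momentum update as before ...
}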

@@ -1960,6 +1960,7 @@ def __init__(self,
name=None,
exclude_from_weight_decay=None,
epsilon=0,
merged_ops=False,
Contributor:

This parameter name also feels a bit odd.

JamesLim-sy (author):

T_T, I'll think of a more suitable name.

@sneaxiy (Collaborator) left a comment:

LGTM

@Superjomn (Contributor) left a comment:

LGTM

@JamesLim-sy merged commit 0c31579 into PaddlePaddle:develop on Oct 13, 2021
@JamesLim-sy changed the title from "Merge lars op" to "Merge all larsMomentum ops into one cuda kernel" on Jan 5, 2022