diff --git a/src/operator/nn/layer_norm.cu b/src/operator/nn/layer_norm.cu
index 2d6d54c5b56a..386b8133a2ee 100644
--- a/src/operator/nn/layer_norm.cu
+++ b/src/operator/nn/layer_norm.cu
@@ -328,7 +328,6 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
 template <typename DType>
 __global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch,
                                                            const int nchannel,
-                                                           const int row_repeat,
                                                            const DType* __restrict__ in_data,
                                                            const DType* __restrict__ out_grad,
                                                            const DType* __restrict__ mean_data,
@@ -351,18 +350,15 @@ __global__ void LayerNormFusedBackwardKernel_PartGammaBeta(const int nbatch,
   DType local_beta_grad = 0;
 
   if (c < nchannel) {
-    for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y * row_repeat) {
-      for (int i = 0; i < row_repeat; ++i) {
-        int r_offset = i * blockDim.y + threadIdx.y;
-        int r = r_b + r_offset;
-        if (r < r_end) {
-          DType local_mean = mean_data[r];
-          DType local_std = std_data[r];
-          int read_idx = r * nchannel + c;
-          local_gamma_grad += (in_data[read_idx] - local_mean) / local_std
-                              * out_grad[read_idx];
-          local_beta_grad += out_grad[read_idx];
-        }
+    for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y) {
+      int r = r_b + threadIdx.y;
+      if (r < r_end) {
+        DType local_mean = mean_data[r];
+        DType local_std = std_data[r];
+        int read_idx = r * nchannel + c;
+        local_gamma_grad += (in_data[read_idx] - local_mean) / local_std
+                            * out_grad[read_idx];
+        local_beta_grad += out_grad[read_idx];
       }
     }
   }
@@ -536,11 +532,10 @@ __global__ void LayerNormFusedBackwardKernel_Data(const int nbatch,
 void GetGammaBetaGradKernelParams(const int nbatch, const int nchannel,
                                   dim3* part_grad_block_dim, dim3* part_grad_grid_dim,
                                   dim3* gb_block_dim, dim3* gb_grid_dim,
-                                  int* row_repeat, int* npart) {
+                                  int* npart) {
   *npart = 16;
   *part_grad_block_dim = dim3(32, 16);
   *part_grad_grid_dim = dim3((nchannel + 32 - 1) / 32, *npart);
-  *row_repeat = 4;
   *gb_block_dim = dim3(32, *npart);
   *gb_grid_dim = dim3((nchannel + 32 - 1) / 32);
   CheckLaunchParam(*part_grad_grid_dim, *part_grad_block_dim);
@@ -583,9 +578,9 @@ void LayerNormGradGPUContig(const LayerNormParam param,
   CHECK_EQ(gamma_grad.CheckContiguous(), true);
   CHECK_EQ(beta_grad.CheckContiguous(), true);
   dim3 part_grad_block_dim, part_grad_grid_dim, gb_block_dim, gb_grid_dim;
-  int row_repeat, npart;
+  int npart;
   GetGammaBetaGradKernelParams(nbatch, nchannel, &part_grad_block_dim, &part_grad_grid_dim,
-                               &gb_block_dim, &gb_grid_dim, &row_repeat, &npart);
+                               &gb_block_dim, &gb_grid_dim, &npart);
   if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) {
     MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, {
       Tensor<gpu, 1, DType> workspace =
@@ -599,7 +594,7 @@ void LayerNormGradGPUContig(const LayerNormParam param,
       DType* beta_grad_ptr = (beta_grad_req != kNullOp) ? beta_grad.dptr<DType>() : nullptr;
      LayerNormFusedBackwardKernel_PartGammaBeta
        <<>>
-        (nbatch, nchannel, row_repeat, in_data.dptr<DType>(), out_grad.dptr<DType>(),
+        (nbatch, nchannel, in_data.dptr<DType>(), out_grad.dptr<DType>(),
         mean_data.dptr<DType>(), std_data.dptr<DType>(), part_gamma_grad_ptr, part_beta_grad_ptr);
      MSHADOW_CUDA_POST_KERNEL_CHECK(LayerNormFusedBackwardKernel_PartGammaBeta);
      if (gamma_grad_req == kAddTo && beta_grad_req != kAddTo) {
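
For reference, here is a minimal, standalone sketch (not part of the patch) of the row-strided partial-reduction pattern the simplified loop adopts: each thread walks the rows of its block's slice with a plain blockDim.y stride instead of the former blockDim.y * row_repeat stride with an inner unrolled loop. All names below (PartialColumnSum, part_sum, nrows, ncols, npart) are illustrative and do not appear in layer_norm.cu; only the (32, 16) block shape and the 16-way grid split mirror the values set in GetGammaBetaGradKernelParams.

#include <cstdio>
#include <cuda_runtime.h>

// Each block owns a horizontal slice of rows and a 32-column stripe; every thread
// accumulates its column over that slice with a stride of blockDim.y, matching the
// simplified loop in the patch (no inner row_repeat unroll).
__global__ void PartialColumnSum(const int nrows, const int ncols,
                                 const float* __restrict__ data,
                                 float* __restrict__ part_sum) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;   // column handled by this thread
  const int rows_per_slice = nrows / (int)gridDim.y;
  const int r_begin = blockIdx.y * rows_per_slice;
  const int r_end = (blockIdx.y == gridDim.y - 1) ? nrows : r_begin + rows_per_slice;
  float acc = 0.f;
  if (c < ncols) {
    // Row-strided loop: the stride is simply blockDim.y.
    for (int r_b = r_begin; r_b < r_end; r_b += blockDim.y) {
      const int r = r_b + threadIdx.y;
      if (r < r_end) {
        acc += data[r * ncols + c];
      }
    }
    // One partial result per (slice row, column); a second pass reduces these.
    part_sum[(blockIdx.y * blockDim.y + threadIdx.y) * ncols + c] = acc;
  }
}

int main() {
  const int nrows = 1000, ncols = 128, npart = 4;   // npart plays the role of gridDim.y
  const dim3 block(32, 16), grid((ncols + 32 - 1) / 32, npart);
  float *data = nullptr, *part_sum = nullptr;
  cudaMallocManaged(&data, nrows * ncols * sizeof(float));
  cudaMallocManaged(&part_sum, npart * block.y * ncols * sizeof(float));
  for (int i = 0; i < nrows * ncols; ++i) data[i] = 1.f;
  PartialColumnSum<<<grid, block>>>(nrows, ncols, data, part_sum);
  cudaDeviceSynchronize();
  // Finish the reduction on the host: every column should sum to nrows.
  float col0 = 0.f;
  for (int p = 0; p < npart * (int)block.y; ++p) col0 += part_sum[p * ncols + 0];
  printf("column 0 sum = %.0f (expected %d)\n", col0, nrows);
  cudaFree(data);
  cudaFree(part_sum);
  return 0;
}

The second-stage reduction over the npart * blockDim.y partial rows is done on the host here only to keep the sketch short; in the actual file that step is performed by the separate gamma/beta reduction kernel launched with gb_grid_dim / gb_block_dim.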