fix conflicts in optimizer.py and finally can pass coverage ci. yes!
Shrinking the changed code to fewer than 1,000 lines.
JamesLim-sy committed Oct 12, 2021
2 parents b37c2d8 + 033a73c commit e1840a7
Showing 15 changed files with 382 additions and 274 deletions.
Binary file "log" removed (contents not shown).
6 changes: 6 additions & 0 deletions paddle/fluid/operators/concat_op_npu.cc
@@ -122,8 +122,14 @@ namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(concat, ops::ConcatNPUKernel<float>,
ops::ConcatNPUKernel<paddle::platform::float16>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ConcatNPUKernel<int64_t>,
#endif
ops::ConcatNPUKernel<int>);

REGISTER_OP_NPU_KERNEL(concat_grad, ops::ConcatGradNPUKernel<float>,
ops::ConcatGradNPUKernel<paddle::platform::float16>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ConcatGradNPUKernel<int64_t>,
#endif
ops::ConcatGradNPUKernel<int>);
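The #ifdef guard above means the int64 concat kernels exist only in builds where the Ascend toolchain defines PADDLE_WITH_ASCEND_INT64. A self-contained sketch of the same compile-time registration pattern (RegisterConcatKernel is a hypothetical stand-in, not a Paddle API):

#include <cstdio>

template <typename T>
void RegisterConcatKernel() {
  std::printf("registered concat kernel for a %zu-byte element type\n",
              sizeof(T));
}

int main() {
  RegisterConcatKernel<float>();
#ifdef PADDLE_WITH_ASCEND_INT64
  // Compiled in only when the build defines the flag, mirroring the hunk.
  RegisterConcatKernel<long long>();
#endif
  RegisterConcatKernel<int>();
}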
64 changes: 17 additions & 47 deletions paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc
@@ -536,32 +536,20 @@ class CudnnBNAddReluTester {
bn_bias->Resize({1, 1, 1, channels_});

// input
float *sum_ptr = sum->data<float>();
float *sum_of_square_ptr = sum_of_square->data<float>();
float *bn_scale_ptr = bn_scale->data<float>();
float *bn_bias_ptr = bn_bias->data<float>();

mean->Resize({1, 1, 1, channels_});
var->Resize({1, 1, 1, channels_});

// output
float *mean_ptr = mean->data<float>();
float *var_ptr = var->data<float>();
float *saved_mean_ptr =
saved_mean->mutable_data<float>({1, 1, 1, channels_}, place);
float *saved_var_ptr =
saved_var->mutable_data<float>({1, 1, 1, channels_}, place);
T *equiv_scale_ptr =
equiv_scale->mutable_data<T>({1, 1, 1, channels_}, place);
T *equiv_bias_ptr =
equiv_bias->mutable_data<T>({1, 1, 1, channels_}, place);
equiv_scale->Resize({1, 1, 1, channels_});
equiv_bias->Resize({1, 1, 1, channels_});
saved_mean->Resize({1, 1, 1, channels_});
saved_var->Resize({1, 1, 1, channels_});

auto param_shape = framework::vectorize<int>(bn_scale->dims());
op::CudnnBNStatsFinalize<T> bn_op(ctx, param_shape);
bn_op.Forward(ctx, sum_ptr, sum_of_square_ptr, bn_scale_ptr, bn_bias_ptr,
saved_mean_ptr, saved_var_ptr, mean_ptr, var_ptr,
equiv_scale_ptr, equiv_bias_ptr, eps_, momentum_, ele_count_,
true);
bn_op.Forward(ctx, *sum, *sum_of_square, *bn_scale, *bn_bias, saved_mean,
saved_var, mean, var, equiv_scale, equiv_bias, eps_,
momentum_, ele_count_, true);
}

// Get forward results of CudnnBNStatsFinalize + CudnnScaleBiasAddRelu
@@ -627,21 +615,13 @@ class CudnnBNAddReluTester {
&saved_var_z, &equiv_scale_z, &equiv_bias_z);
}

T *x_ptr = x.data<T>();
T *z_ptr = (fuse_add_ || has_shortcut_) ? z.data<T>() : nullptr;
T *equiv_scale_x_ptr = equiv_scale_x.data<T>();
T *equiv_bias_x_ptr = equiv_bias_x.data<T>();
T *equiv_scale_z_ptr = has_shortcut_ ? equiv_scale_z.data<T>() : nullptr;
T *equiv_bias_z_ptr = has_shortcut_ ? equiv_bias_z.data<T>() : nullptr;
T *y_ptr =
y.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
y.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));

int c = channels_;
int64_t nhw = ele_count_;
int32_t c_int32_elems = ((c + 63) & ~63) / 32;
int32_t nhw_int32_elems = (nhw + 31) & ~31;
int32_t *bitmask_ptr = bitmask.mutable_data<int32_t>(
{nhw_int32_elems, c_int32_elems, 1}, place);
bitmask.Resize(framework::make_ddim({nhw_int32_elems, c_int32_elems, 1}));
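// Note on the sizing arithmetic above: ((c + 63) & ~63) rounds the channel
// count up to a multiple of 64, and dividing by 32 converts the padded bit
// count into int32 words (each int32 holds 32 mask bits), while
// (nhw + 31) & ~31 rounds nhw up to a multiple of 32. For example, c = 100
// yields 128 / 32 = 4 words, and nhw = 50 yields 64.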

auto data_shape = framework::vectorize<int>(x.dims());
auto param_shape = framework::vectorize<int>(bn_scale_x.dims());
@@ -651,8 +631,8 @@ class CudnnBNAddReluTester {
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
has_shortcut_, data_shape, param_shape,
bitmask_shape);
sbar_op.Forward(ctx, x_ptr, equiv_scale_x_ptr, equiv_bias_x_ptr, y_ptr,
bitmask_ptr, z_ptr, equiv_scale_z_ptr, equiv_bias_z_ptr);
sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z,
equiv_bias_z, &y, &bitmask);

TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
@@ -697,19 +677,10 @@ class CudnnBNAddReluTester {
saved_mean.Resize({1, 1, 1, channels_});
saved_var.Resize({1, 1, 1, channels_});

T *dy_ptr = dy.data<T>();
T *x_ptr = x.data<T>();
float *bn_scale_ptr = bn_scale.data<float>();
float *bn_bias_ptr = bn_bias.data<float>();
float *saved_mean_ptr = saved_mean.data<float>();
float *saved_var_ptr = saved_var.data<float>();
int32_t *bitmask_ptr = bitmask.data<int32_t>();
T *dx_ptr =
dx.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
T *dz_ptr =
dz.mutable_data<T>({batch_size_, height_, width_, channels_}, place);
float *dscale_ptr = dscale.mutable_data<float>({1, 1, 1, channels_}, place);
float *dbias_ptr = dbias.mutable_data<float>({1, 1, 1, channels_}, place);
dx.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
dz.Resize(framework::make_ddim({batch_size_, height_, width_, channels_}));
dscale.Resize(framework::make_ddim({1, 1, 1, channels_}));
dbias.Resize(framework::make_ddim({1, 1, 1, channels_}));

auto data_shape = framework::vectorize<int>(x.dims());
auto param_shape = framework::vectorize<int>(bn_scale.dims());
@@ -718,9 +689,8 @@ class CudnnBNAddReluTester {
std::string act_type = "relu";
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false, data_shape,
param_shape, bitmask_shape);
sbar_op.Backward(ctx, dy_ptr, x_ptr, bn_scale_ptr, bn_bias_ptr,
saved_mean_ptr, saved_var_ptr, bitmask_ptr, dx_ptr, dz_ptr,
dscale_ptr, dbias_ptr, eps_);
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
bitmask, &dx, &dz, &dscale, &dbias, eps_);

TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
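The test changes above all follow the commit's central refactor: fused ops now take read-only tensors by const reference and outputs by Tensor pointer, so call sites only Resize outputs while the allocation via mutable_data happens once, inside the op. A toy sketch of that calling convention (Tensor here is a stand-in struct, not paddle::framework::Tensor):

#include <vector>

struct Tensor {
  std::vector<float> buf;
  const float *data() const { return buf.data(); }  // read-only view
  float *mutable_data() { return buf.data(); }      // write access for outputs
};

// Inputs by const reference, outputs by pointer; the callee, not every
// call site, decides how raw pointers are materialized.
void Forward(const Tensor &sum, const Tensor &scale, Tensor *saved_mean) {
  const float *sum_ptr = sum.data();
  float *mean_ptr = saved_mean->mutable_data();
  mean_ptr[0] = sum_ptr[0] * scale.data()[0];  // stand-in for the real math
}

int main() {
  Tensor sum{{2.0f}}, scale{{0.5f}}, mean{{0.0f}};
  Forward(sum, scale, &mean);  // mean.buf[0] == 1.0f
}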
24 changes: 18 additions & 6 deletions paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h
@@ -68,12 +68,13 @@ class CudnnBNStatsFinalize {
}
~CudnnBNStatsFinalize() {}

void Forward(const platform::CUDADeviceContext &ctx, float *sum_ptr,
float *sum_of_squares_ptr, float *scale_ptr, float *bias_ptr,
float *saved_mean_ptr, float *saved_invstd_ptr,
float *running_mean_ptr, float *running_var_ptr,
T *equiv_scale_ptr, T *equiv_bias_ptr, double eps,
float momentum, int64_t ele_count, bool is_train) {
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &sum,
const Tensor &sum_of_squares, const Tensor &scale,
const Tensor &bias, Tensor *saved_mean, Tensor *saved_invstd,
Tensor *running_mean, Tensor *running_var, Tensor *equiv_scale,
Tensor *equiv_bias, double eps, float momentum,
int64_t ele_count, bool is_train) {
auto place = ctx.GetPlace();
if (is_train) {
TrainInit(ctx);
} else {
@@ -82,6 +83,17 @@ class CudnnBNStatsFinalize {
auto &op = is_train ? train_op_ : inference_op_;

// Set variant_param for both inference_op_ and train_op_
float *sum_ptr = const_cast<float *>(sum.data<float>());
float *sum_of_squares_ptr =
const_cast<float *>(sum_of_squares.data<float>());
float *scale_ptr = const_cast<float *>(scale.data<float>());
float *bias_ptr = const_cast<float *>(bias.data<float>());
float *saved_mean_ptr = saved_mean->mutable_data<float>(place);
float *saved_invstd_ptr = saved_invstd->mutable_data<float>(place);
float *running_mean_ptr = running_mean->mutable_data<float>(place);
float *running_var_ptr = running_var->mutable_data<float>(place);
T *equiv_scale_ptr = equiv_scale->mutable_data<T>(place);
T *equiv_bias_ptr = equiv_bias->mutable_data<T>(place);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_SCALE, scale_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_BIAS, bias_ptr);
op.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_RUNNING_MEAN, running_mean_ptr);
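The const_casts in this hunk follow from the cuDNN fused-op interface: the variant-param setters accept only non-const pointers, even for buffers that are merely read, so a const Tensor input has to be cast exactly once at this boundary. A minimal illustration (SetPtr is a hypothetical stand-in for SetOpVariantParamAttrPtr):

// Models a C-style API that takes a non-const pointer for every buffer.
void SetPtr(float *p) { (void)p; }

// The cast stays safe only as long as the callee never writes through
// the pointer; confining it here keeps all call sites const-correct.
void Bind(const float *input) { SetPtr(const_cast<float *>(input)); }

int main() {
  float buf[1] = {0.0f};
  Bind(buf);
}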
94 changes: 74 additions & 20 deletions paddle/fluid/operators/fused/cudnn_norm_conv.cu.h
@@ -38,7 +38,8 @@ struct NormConvolutionArgs {
compute_type = platform::CudnnDataType<float>::type;
}

void Set(const std::vector<int> &input_shape,
void Set(const platform::CUDADeviceContext &ctx,
const std::vector<int> &input_shape,
const std::vector<int> &filter_shape,
const std::vector<int> &output_shape, int padding, int stride,
int dilation, int group) {
@@ -61,12 +62,33 @@
"The filter_shape is expected to store as nhwc, and "
"h = w = 1 or 3. But recieved filter_shape is [%s].",
framework::make_ddim(filter_shape)));
PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0),
true,
platform::errors::InvalidArgument(
"The input channel is expected to be multiple of 8, "
"and the output channel is expected to be multiple "
"of 32. But recieved input channel is %d, output "
"channel is %d.",
filter_shape[3], filter_shape[0]));
PADDLE_ENFORCE_EQ(
output_shape.size(), 4U,
platform::errors::InvalidArgument(
"The size of output_shape is expected to 4. But recieved "
"filter_shape's size is %d, filter_shape is [%s].",
output_shape.size(), framework::make_ddim(output_shape)));
is_support = IsSupport(ctx, filter_shape, stride, dilation, group);
PADDLE_ENFORCE_EQ(
is_support, true,
platform::errors::InvalidArgument(
"Current test is only supported in the platforms with "
"compatiblity greater than or equal to 70 and the kernel size "
"must be equal to 1 or 3. When the kernel size is 1, "
"the stride must be 1 if the compatiblity is equal to 70. "
"Besides, the dilation and group must be equal to 1. But recieved "
"compatiblity is %d, kernel size is %d, stride is %d, "
"dilation is %d, group is %d",
ctx.GetComputeCapability(), filter_shape[1], stride, dilation,
group));

for (size_t i = 0; i < input_shape.size(); ++i) {
in_dims.push_back(input_shape[i]);
@@ -89,6 +111,25 @@
conv_desc.set(dtype, paddings, strides, dilations, false, group);
}

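// What IsSupport (below) checks, restated: the fused path is unsupported
// unless dilation == 1, group == 1, and the compute capability is at least
// 70; on capability 70 a 1x1 kernel additionally requires stride 1, while a
// 3x3 kernel accepts any stride; above 70 both kernel sizes accept any
// stride.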
bool IsSupport(const platform::CUDADeviceContext &ctx,
const std::vector<int> &filter_shape, int stride, int dilation,
int group) {
int kernel_size = filter_shape[1];
if (dilation != 1 || group != 1) {
return false;
}
if (ctx.GetComputeCapability() == 70) {
if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) {
return true;
}
} else if (ctx.GetComputeCapability() > 70) {
if ((kernel_size == 3) || (kernel_size == 1)) {
return true;
}
}
return false;
}

cudnnDataType_t dtype;
cudnnTensorFormat_t format;
cudnnDataType_t compute_type;
@@ -104,6 +145,8 @@
platform::TensorDescriptor out_desc;
platform::TensorDescriptor out_stats_desc;
platform::ConvolutionDescriptor conv_desc;

bool is_support;
};

template <typename T>
@@ -115,15 +158,16 @@ class CudnnNormConvolution {
const std::vector<int> &output_shape, const int &padding,
const int &stride, const int &dilation,
const int &group) {
args_.Set(input_shape, filter_shape, output_shape, padding, stride,
args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
dilation, group);
}
~CudnnNormConvolution() {}

void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *filter_ptr, T *output_ptr, float *sum_ptr,
float *sum_of_squares_ptr) {
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input,
const Tensor &filter, Tensor *output, Tensor *sum,
Tensor *sum_of_squares) {
auto cudnn_handle = ctx.cudnn_handle();
auto place = ctx.GetPlace();

CudnnFusionOp *fwd_op = GetForwardOp(ctx);
size_t workspace_size = RoundUp(
@@ -132,12 +176,17 @@

// Set variant_param
// input ptr
T *input_ptr = const_cast<T *>(input.data<T>());
T *filter_ptr = const_cast<T *>(filter.data<T>());
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
fwd_op->SetOpVariantParamAttrPtr(
CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

// output ptr
T *output_ptr = output->mutable_data<T>(place);
float *sum_ptr = sum->mutable_data<float>(place);
float *sum_of_squares_ptr = sum_of_squares->mutable_data<float>(place);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
@@ -209,28 +258,34 @@ class CudnnNormConvolutionGrad {
const std::vector<int> &output_shape,
const int &padding, const int &stride,
const int &dilation, const int &group) {
args_.Set(input_shape, filter_shape, output_shape, padding, stride,
args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
dilation, group);
dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
~CudnnNormConvolutionGrad() {}

void Backward(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
T *filter_grad_ptr, bool use_addto = false) {
if (filter_grad_ptr) {
BackwardFilter(ctx, input_ptr, output_grad_ptr, filter_ptr,
filter_grad_ptr);
void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input,
const Tensor &filter, const Tensor &output_grad,
Tensor *input_grad, Tensor *filter_grad,
bool use_addto = false) {
auto place = ctx.GetPlace();
T *input_ptr = const_cast<T *>(input.data<T>());
T *filter_ptr = const_cast<T *>(filter.data<T>());
T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());

if (filter_grad) {
T *filter_grad_ptr = filter_grad->mutable_data<T>(place);
BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
}
if (input_grad_ptr) {
BackwardData(ctx, input_ptr, output_grad_ptr, filter_ptr, input_grad_ptr,
use_addto);
if (input_grad) {
T *input_grad_ptr = input_grad->mutable_data<T>(place);
BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto);
}
}

private:
void BackwardFilter(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *output_grad_ptr, T *filter_ptr, T *filter_grad_ptr) {
void BackwardFilter(const platform::CUDADeviceContext &ctx,
T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) {
auto cudnn_handle = ctx.cudnn_handle();

CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
@@ -255,9 +310,8 @@
workspace_size);
}

void BackwardData(const platform::CUDADeviceContext &ctx, T *input_ptr,
T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
bool use_addto = false) {
void BackwardData(const platform::CUDADeviceContext &ctx, T *output_grad_ptr,
T *filter_ptr, T *input_grad_ptr, bool use_addto = false) {
auto cudnn_handle = ctx.cudnn_handle();
size_t workspace_size = GetWorkspaceSizeBwdData(ctx);

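One more pattern worth noting from the Backward refactor above: optional gradients are expressed as nullable Tensor pointers, and the op allocates and computes only what the caller asked for. A schematic sketch under that reading (Tensor and Backward here are stand-ins, not the real Paddle/cuDNN types):

#include <vector>

struct Tensor { std::vector<float> buf; };

// Passing nullptr skips that gradient entirely: no allocation, no kernel.
void Backward(const Tensor &dy, Tensor *dx, Tensor *dfilter) {
  if (dfilter) {
    dfilter->buf.resize(dy.buf.size());  // allocate, then run wgrad kernel
  }
  if (dx) {
    dx->buf.resize(dy.buf.size());       // allocate, then run dgrad kernel
  }
}

int main() {
  Tensor dy{{1.0f, 2.0f}}, dx;
  Backward(dy, &dx, nullptr);  // data gradient only; filter grad skipped
}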
(Diff truncated: the remaining changed files are not shown.)
