Skip to content

Commit

Permalink
Performance improvement in Normalize GPU Kernel (apache#14139)
Browse files Browse the repository at this point in the history
* New CPU kernel for normalize

* New GPU kernel for Normalize

* Add launch bounds and increase threads to 32*32

* do not hardcode number of threads

* Try fix windows build failure

* make channels as int to fix windows build issues with omp

* Simplify cuda kernels with 1 D thread block

* Minor refactoring

* Revert thread dim for ToTensor operator
  • Loading branch information
sandeep-krishnamurthy authored and stephenrawls committed Feb 16, 2019
1 parent dc3d336 commit b7a12f4
Show file tree
Hide file tree
Showing 2 changed files with 336 additions and 141 deletions.
307 changes: 187 additions & 120 deletions src/operator/image/image_random-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,35 @@ void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
const T2 output,
const int req,
const float normalize_factor);

template<typename DType>
void NormalizeImplCUDA(mshadow::Stream<gpu> *s,
const DType *input,
DType *output,
const int req,
const int N,
const int C,
const int H,
const int W,
const float mean_d0,
const float mean_d1,
const float mean_d2,
const float std_d0,
const float std_d1,
const float std_d2);

template<typename DType>
void NormalizeBackwardImplCUDA(mshadow::Stream<gpu> *s,
const DType *out_grad,
DType *in_grad,
const int req,
const int N,
const int C,
const int H,
const int W,
const float std_d0,
const float std_d1,
const float std_d2);
#endif // MXNET_USE_CUDA

// Shape and Type inference for image to tensor operator
Expand Down Expand Up @@ -254,156 +283,165 @@ inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs,
return out_attrs->at(0) != -1;
}

template<int req>
struct normalize_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(uint32_t c, DType* out_data, const DType* in_data,
const float mean_d0, const float mean_d1, const float mean_d2,
const float std_d0, const float std_d1, const float std_d2,
const int length, const int step) {
float mean, std;
switch (c) {
case 0 : mean = mean_d0;
std = std_d0;
break;
case 1 : mean = mean_d1;
std = std_d1;
break;
case 2 : mean = mean_d2;
std = std_d2;
break;
}
#pragma omp parallel for
for (int i = 0; i < length; ++i) {
KERNEL_ASSIGN(out_data[step + c*length + i], req,
(in_data[step + c*length + i] - mean) / std);
}
template<typename DType, int req>
inline void Normalize(DType* out_data,
const DType* in_data,
const int length,
const int channels,
const int step,
const std::vector<float> mean,
const std::vector<float> std) {
// Microsoft Visual C++ compiler does not support omp collapse
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif // _MSC_VER
for (int c = 0; c < channels; ++c) {
for (int i = 0; i < length; ++i) {
KERNEL_ASSIGN(out_data[step + c*length + i], req,
(in_data[step + c*length + i] - mean[c]) / std[c]);
}
};

template<typename xpu>
void NormalizeImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const float mean_d0, const float mean_d1,
const float mean_d2, const float std_d0,
const float std_d1, const float std_d2,
const int length,
const uint32_t channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
}
}

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
DType* input = inputs[0].dptr<DType>();
DType* output = outputs[0].dptr<DType>();
mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
s, channel, output, input, mean_d0, mean_d1, mean_d2,
std_d0, std_d1, std_d2, length, step);
});
inline void NormalizeImpl(const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const int length,
const int channels,
const int step,
const std::vector<float> mean,
const std::vector<float> std) {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
DType* input = inputs[0].dptr<DType>();
DType* output = outputs[0].dptr<DType>();
Normalize<DType, req_type>(output, input, length, channels, step,
mean, std);
});
});
}

template<typename xpu>
void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

// Note: We need mean and std_dev in the kernel.
// It is costly (device copy) to pass it as vector, for gpu kernel.
// Hence, passing it as below for performance.
float mean_d0, mean_d1, mean_d2;
float std_d0, std_d1, std_d2;

// Mean and Std can be 1 or 3 D only.
// Mean and Std can be 1 or 3D only.
std::vector<float> mean(3);
std::vector<float> std(3);
if (param.mean.ndim() == 1) {
mean_d0 = mean_d1 = mean_d2 = param.mean[0];
mean[0] = mean[1] = mean[3] = param.mean[0];
} else {
mean_d0 = param.mean[0];
mean_d1 = param.mean[1];
mean_d2 = param.mean[2];
mean[0] = param.mean[0];
mean[1] = param.mean[1];
mean[2] = param.mean[2];
}

if (param.std.ndim() == 1) {
std_d0 = std_d1 = std_d2 = param.std[0];
std[0] = std[1] = std[2] = param.std[0];
} else {
std_d0 = param.std[0];
std_d1 = param.std[1];
std_d2 = param.std[2];
std[0] = param.std[0];
std[1] = param.std[1];
std[2] = param.std[2];
}

// 3D input (c, h, w)
if (inputs[0].ndim() == 3) {
if (std::is_same<xpu, gpu>::value) {
#if MXNET_USE_CUDA
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
int N, C, H, W;
DType *input = nullptr;
DType *output = nullptr;
if (inputs[0].ndim() == 3) {
N = 1;
C = static_cast<int>(inputs[0].shape_[0]);
H = static_cast<int>(inputs[0].shape_[1]);
W = static_cast<int>(inputs[0].shape_[2]);
input = (inputs[0].get<gpu, 3, DType>(s)).dptr_;
output = (outputs[0].get<gpu, 3, DType>(s)).dptr_;
} else {
N = static_cast<int>(inputs[0].shape_[0]);
C = static_cast<int>(inputs[0].shape_[1]);
H = static_cast<int>(inputs[0].shape_[2]);
W = static_cast<int>(inputs[0].shape_[3]);
input = (inputs[0].get<gpu, 4, DType>(s)).dptr_;
output = (outputs[0].get<gpu, 4, DType>(s)).dptr_;
}
NormalizeImplCUDA<DType>(s, input, output, req_type,
N, C, H, W,
mean[0], mean[1], mean[2],
std[0], std[1], std[2]);
});
});
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize operator on GPU.";
#endif // MXNET_USE_CUDA
} else if (inputs[0].ndim() == 3) {
// 3D input (c, h, w)
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const uint32_t channel = inputs[0].shape_[0];
NormalizeImpl<xpu>(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2,
std_d0, std_d1, std_d2, length, channel);
const int channel = static_cast<int>(inputs[0].shape_[0]);
const int step = 0;
NormalizeImpl(inputs, outputs, req, length, channel, step, mean, std);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, c, h, w)
const int batch_size = inputs[0].shape_[0];
const int length = inputs[0].shape_[2] * inputs[0].shape_[3];
const uint32_t channel = inputs[0].shape_[1];
const int channel = static_cast<int>(inputs[0].shape_[1]);
const int step = channel * length;

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
NormalizeImpl<xpu>(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2,
std_d0, std_d1, std_d2, length, channel, n*step);
NormalizeImpl(inputs, outputs, req, length, channel, n*step, mean, std);
}
}
}

// Backward function
template<int req>
struct normalize_backward {
template<typename DType>
MSHADOW_XINLINE static void Map(uint32_t c, DType* in_grad, const DType* out_grad,
const float std_d0, const float std_d1, const float std_d2,
const int length, const int step) {
// d/dx{(x - mean) / std_dev} => (1 / std_dev)
float std_dev;
switch (c) {
case 0 : std_dev = std_d0;
break;
case 1 : std_dev = std_d1;
break;
case 2 : std_dev = std_d2;
break;
}

template<typename DType, int req>
inline void NormalizeBackward(const DType* out_grad,
DType* in_grad,
const int length,
const int channels,
const int step,
const std::vector<float> std) {
// Microsoft Visual C++ compiler does not support omp collapse
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif // _MSC_VER
for (int c = 0; c < channels; ++c) {
for (int i = 0; i < length; ++i) {
KERNEL_ASSIGN(in_grad[step + c*length + i], req,
out_grad[step + c*length + i] * (1.0 / std_dev));
out_grad[step + c*length + i] * (1.0 / std[c]));
}
}
};

template<typename xpu>
void NormalizeBackwardImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const float std_d0, const float std_d1, const float std_d2,
const int length,
const uint32_t channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
}

inline void NormalizeBackwardImpl(const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const int length,
const int channels,
const int step,
const std::vector<float> std
) {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
DType* out_grad = inputs[0].dptr<DType>();
DType* in_grad = outputs[0].dptr<DType>();
mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
s, channel, in_grad, out_grad, std_d0, std_d1, std_d2, length, step);
NormalizeBackward<DType, req_type>(out_grad, in_grad, length,
channels, step, std);
});
});
}
Expand All @@ -419,37 +457,66 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
CHECK_EQ(req.size(), 1U);

const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
float std_d0, std_d1, std_d2;

// Std can be 1 or 3 D only
// Std can be 1 or 3D only.
std::vector<float> std(3);
if (param.std.ndim() == 1) {
std_d0 = std_d1 = std_d2 = param.std[0];
std[0] = std[1] = std[2] = param.std[0];
} else {
std_d0 = param.std[0];
std_d1 = param.std[1];
std_d2 = param.std[2];
std[0] = param.std[0];
std[1] = param.std[1];
std[2] = param.std[2];
}

// Note: inputs[0] is out_grad
const TBlob& in_data = inputs[1];

// 3D input (c, h, w)
if (in_data.ndim() == 3) {
if (std::is_same<xpu, gpu>::value) {
#if MXNET_USE_CUDA
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
int N, C, H, W;
DType *in_grad = nullptr;
DType *out_grad = nullptr;
if (in_data.ndim() == 3) {
N = 1;
C = static_cast<int>(in_data.shape_[0]);
H = static_cast<int>(in_data.shape_[1]);
W = static_cast<int>(in_data.shape_[2]);
out_grad = (inputs[0].get<gpu, 3, DType>(s)).dptr_;
in_grad = (outputs[0].get<gpu, 3, DType>(s)).dptr_;
} else {
N = static_cast<int>(in_data.shape_[0]);
C = static_cast<int>(in_data.shape_[1]);
H = static_cast<int>(in_data.shape_[2]);
W = static_cast<int>(in_data.shape_[3]);
out_grad = (inputs[0].get<gpu, 4, DType>(s)).dptr_;
in_grad = (outputs[0].get<gpu, 4, DType>(s)).dptr_;
}
NormalizeBackwardImplCUDA<DType>(s, out_grad, in_grad, req_type,
N, C, H, W,
std[0], std[1], std[2]);
});
});
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize backward operator on GPU.";
#endif // MXNET_USE_CUDA
} else if (in_data.ndim() == 3) {
// 3D input (c, h, w)
const int length = in_data.shape_[1] * in_data.shape_[2];
const uint32_t channel = in_data.shape_[0];
NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, std_d0, std_d1, std_d2, length, channel);
const int channel = static_cast<int>(in_data.shape_[0]);
const int step = 0;
NormalizeBackwardImpl(inputs, outputs, req, length, channel, step, std);
} else if (in_data.ndim() == 4) {
// 4D input (n, c, h, w)
const int batch_size = in_data.shape_[0];
const int length = in_data.shape_[2] * in_data.shape_[3];
const uint32_t channel = in_data.shape_[1];
const int channel = static_cast<int>(in_data.shape_[1]);
const int step = channel * length;

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req,
std_d0, std_d1, std_d2, length,
channel, n*step);
NormalizeBackwardImpl(inputs, outputs, req, length, channel, n*step, std);
}
}
}
Expand Down
Loading

0 comments on commit b7a12f4

Please sign in to comment.