From a770ce7c26ae94f4193638627cf369d2dcdb53de Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Fri, 1 Feb 2019 18:44:24 -0800 Subject: [PATCH] parallelize on channel in kernel launch --- src/operator/image/image_random-inl.h | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 5cd4a82df21d..c9dd85af616f 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -82,11 +82,14 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs, template struct totensor_forward { template - MSHADOW_XINLINE static void Map(int l, float* out_data, const DType* in_data, - const int c, const int length, const int channel, - const int step, const float normalize_factor = 255.0f) { - KERNEL_ASSIGN(out_data[step + c*length + l], req, - (in_data[step + l*channel + c]) / normalize_factor); + MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data, + const int length, const int channel, const int step, + const float normalize_factor = 255.0f) { + #pragma omp parallel for + for (int i = 0; i < length; ++i) { + KERNEL_ASSIGN(out_data[step + c*length + i], req, + (in_data[step + i*channel + c]) / normalize_factor); + } } }; @@ -96,7 +99,7 @@ void ToTensorImpl(const OpContext &ctx, const std::vector &outputs, const std::vector &req, const int length, - const int channel, + const uint32_t channel, const int step = 0) { mshadow::Stream *s = ctx.get_stream(); @@ -104,11 +107,8 @@ void ToTensorImpl(const OpContext &ctx, MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { float* output = outputs[0].dptr(); DType* input = inputs[0].dptr(); - - for (int c = 0; c < channel; ++c) { - mxnet_op::Kernel, xpu>::Launch( - s, length, output, input, c, length, channel, step); - } + mxnet_op::Kernel, xpu>::Launch( + s, channel, output, input, length, channel, step); }); }); } @@ -129,13 +129,13 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs, // 3D Input - (h, w, c) if (inputs[0].ndim() == 3) { const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; - const int channel = inputs[0].shape_[2]; + const uint32_t channel = inputs[0].shape_[2]; ToTensorImpl(ctx, inputs, outputs, req, length, channel); } else if (inputs[0].ndim() == 4) { // 4D input (n, h, w, c) const int batch_size = inputs[0].shape_[0]; const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; - const int channel = inputs[0].shape_[3]; + const uint32_t channel = inputs[0].shape_[3]; const int step = channel * length; #pragma omp parallel for