Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
parallelize on channel in kernel launch
Browse files Browse the repository at this point in the history
  • Loading branch information
sandeep-krishnamurthy committed Feb 4, 2019
1 parent d2988fa commit a770ce7
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/operator/image/image_random-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,14 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
template<int req>
struct totensor_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int l, float* out_data, const DType* in_data,
const int c, const int length, const int channel,
const int step, const float normalize_factor = 255.0f) {
KERNEL_ASSIGN(out_data[step + c*length + l], req,
(in_data[step + l*channel + c]) / normalize_factor);
MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data,
const int length, const int channel, const int step,
const float normalize_factor = 255.0f) {
#pragma omp parallel for
for (int i = 0; i < length; ++i) {
KERNEL_ASSIGN(out_data[step + c*length + i], req,
(in_data[step + i*channel + c]) / normalize_factor);
}
}
};

Expand All @@ -96,19 +99,16 @@ void ToTensorImpl(const OpContext &ctx,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const int length,
const int channel,
const uint32_t channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
float* output = outputs[0].dptr<float>();
DType* input = inputs[0].dptr<DType>();

for (int c = 0; c < channel; ++c) {
mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch(
s, length, output, input, c, length, channel, step);
}
mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch(
s, channel, output, input, length, channel, step);
});
});
}
Expand All @@ -129,13 +129,13 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
// 3D Input - (h, w, c)
if (inputs[0].ndim() == 3) {
const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
const int channel = inputs[0].shape_[2];
const uint32_t channel = inputs[0].shape_[2];
ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, h, w, c)
const int batch_size = inputs[0].shape_[0];
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const int channel = inputs[0].shape_[3];
const uint32_t channel = inputs[0].shape_[3];
const int step = channel * length;

#pragma omp parallel for
Expand Down

0 comments on commit a770ce7

Please sign in to comment.