Add CUDA kernel for 4D inputs
sandeep-krishnamurthy committed Feb 9, 2019
1 parent 2f4fad5 commit 00d95bb
Showing 2 changed files with 62 additions and 32 deletions.
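The patch threads a normalize_factor argument through both the CPU and GPU paths and adds a batched CUDA kernel for (N, H, W, C) inputs, launched with one thread block per image. The sketch below is illustrative only, not the patch itself: raw device pointers and a hypothetical to_tensor_4d_sketch kernel stand in for the mshadow Tensor types, but it shows the same HWC-to-CHW transform, the division by 255, and the one-block-per-image, 32x32-thread launch shape.

// Illustrative sketch only -- mimics the new 4D path: (N, H, W, C) byte
// input -> (N, C, H, W) float output scaled by 1/normalize_factor, one
// image per thread block, a (32, 32) thread tile striding over H and W.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void to_tensor_4d_sketch(const unsigned char* in, float* out,
                                    int H, int W, int C,
                                    float normalize_factor) {
  const int n = blockIdx.x;                       // one image per block
  const unsigned char* img_in = in + static_cast<size_t>(n) * H * W * C;
  float* img_out = out + static_cast<size_t>(n) * C * H * W;
  for (int c = 0; c < C; ++c) {
    for (int h = threadIdx.y; h < H; h += blockDim.y) {
      for (int w = threadIdx.x; w < W; w += blockDim.x) {
        // HWC -> CHW, normalized to [0, 1]
        img_out[(c * H + h) * W + w] =
            img_in[(h * W + w) * C + c] / normalize_factor;
      }
    }
  }
}

int main() {
  const int N = 2, H = 4, W = 4, C = 3;
  const size_t count = static_cast<size_t>(N) * H * W * C;
  unsigned char* d_in = nullptr;
  float* d_out = nullptr;
  cudaMalloc(&d_in, count * sizeof(unsigned char));
  cudaMalloc(&d_out, count * sizeof(float));
  cudaMemset(d_in, 255, count * sizeof(unsigned char));   // all-white pixels
  // Same launch shape as the patch: one block per image, 32x32 threads.
  to_tensor_4d_sketch<<<N, dim3(32, 32)>>>(d_in, d_out, H, W, C, 255.0f);
  cudaDeviceSynchronize();
  float first = 0.f;
  cudaMemcpy(&first, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("out[0] = %f\n", first);                         // expect 1.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}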
30 changes: 21 additions & 9 deletions src/operator/image/image_random-inl.h
@@ -52,7 +52,8 @@ template<typename DType, typename T1, typename T2>
void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
const T1 input,
const T2 output,
const int req);
const int req,
const float normalize_factor);
#endif // MXNET_USE_CUDA

// Shape and Type inference for image to tensor operator
@@ -89,10 +90,10 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
// Operator Implementation
template<typename DType, int req>
inline void ToTensor(float* out_data, const DType* in_data,
const int length,
const int channels,
const int step = 0,
const float normalize_factor = 255.0f) {
const int length,
const int channels,
const float normalize_factor,
const int step = 0) {
#pragma omp parallel for collapse(2)
for (int c = 0; c < channels; ++c) {
for (int i = 0; i < length; ++i) {
@@ -107,6 +108,7 @@ inline void ToTensorImpl(const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const int length,
const int channel,
const float normalize_factor,
const int step = 0) {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
@@ -132,6 +134,8 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
CHECK_EQ(req[0], kWriteTo)
<< "`to_tensor` does not support inplace updates";

const float normalize_factor = 255.0f;

if (std::is_same<xpu, gpu>::value) {
#if MXNET_USE_CUDA
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
@@ -140,16 +144,23 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
if (inputs[0].ndim() == 3) {
Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
Tensor<gpu, 3, float> output = outputs[0].get<gpu, 3, float>(s);
ToTensorImplCUDA<DType, Tensor<gpu, 3, DType>, Tensor<gpu, 3, float>>(s, input, output, req_type);
ToTensorImplCUDA<DType, Tensor<gpu, 3, DType>, Tensor<gpu, 3, float>>
(s, input, output, req_type, normalize_factor);
} else {
Tensor<gpu, 4, DType> input = inputs[0].get<gpu, 4, DType>(s);
Tensor<gpu, 4, float> output = outputs[0].get<gpu, 4, float>(s);
ToTensorImplCUDA<DType, Tensor<gpu, 4, DType>, Tensor<gpu, 4, float>>
(s, input, output, req_type, normalize_factor);
}
});
});
#endif // MXNET_USE_CUDA
#endif // MXNET_USE_CUDA
} else if (inputs[0].ndim() == 3) {
// 3D Input - (h, w, c)
const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
const int channel = static_cast<int>(inputs[0].shape_[2]);
ToTensorImpl(inputs, outputs, req, length, channel);
ToTensorImpl(inputs, outputs, req, length,
channel, normalize_factor);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, h, w, c)
const int batch_size = inputs[0].shape_[0];
@@ -159,7 +170,8 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
ToTensorImpl(inputs, outputs, req, length, channel, n*step);
ToTensorImpl(inputs, outputs, req, length, channel,
normalize_factor, n*step);
}
}
}
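For reference, the CPU path above amounts to the per-image index arithmetic sketched below (a standalone illustration, assuming the usual HWC-to-CHW mapping; to_tensor_one_image and the buffer names are hypothetical, not part of the patch). length is H*W, and step is the flat offset of image n, i.e. n * H * W * C, which is how the 4D branch reuses the single-image routine inside its OpenMP loop over the batch.

// Illustrative sketch only -- the per-image index arithmetic behind the
// CPU path, assuming the usual HWC -> CHW mapping.
#include <cstdio>
#include <vector>

void to_tensor_one_image(float* out, const unsigned char* in,
                         int length, int channels,
                         float normalize_factor, int step = 0) {
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < length; ++i) {
      // pixel i of channel c: (h, w, c) layout in, (c, h, w) layout out
      out[step + c * length + i] =
          in[step + i * channels + c] / normalize_factor;
    }
  }
}

int main() {
  const int N = 2, H = 2, W = 3, C = 3;
  const int length = H * W;
  const int step = H * W * C;                 // elements per image
  std::vector<unsigned char> in(static_cast<size_t>(N) * step, 255);
  std::vector<float> out(static_cast<size_t>(N) * step, 0.f);
  for (int n = 0; n < N; ++n)                 // the 4D path loops over the batch
    to_tensor_one_image(out.data(), in.data(), length, C, 255.0f, n * step);
  printf("out[0] = %f, out[last] = %f\n", out.front(), out.back());  // both 1.0
  return 0;
}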
64 changes: 41 additions & 23 deletions src/operator/image/image_random.cu
@@ -31,25 +31,19 @@ namespace image {

using namespace mshadow;

// ToTensor Kernel for 3D input
template<typename xpu, typename Dtype>
__global__ void ToTensorCudaKernel(const Tensor<xpu, 3, Dtype> input,
const Tensor<xpu, 3, float> output,
const int req,
int N, int H, int W, int C,
const float normalize_factor = 255.0f) {
const int N,
const int H,
const int W,
const int C,
const float normalize_factor) {
// We process one image per thread block.
// In the 3D case there is only one block, so blockIdx.x is not used.
/*
const int n = blockIdx.x;
const int stride = H*W*C;
// Get pointer to my blocks image
int step = 0;
if (N > 0) {
step = n * stride;
}
*/
for (int c = 0; c < C; ++c) {
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
@@ -60,12 +54,35 @@ __global__ void ToTensorCudaKernel(const Tensor<xpu, 3, Dtype> input,
}
}

// ToTensor Kernel for 4D input
template<typename xpu, typename Dtype>
__global__ void ToTensorCudaKernel(const Tensor<xpu, 4, Dtype> input,
const Tensor<xpu, 4, float> output,
const int req,
const int N,
const int H,
const int W,
const int C,
const float normalize_factor) {
// We process one image per thread block.
const int n = blockIdx.x;

for (int c = 0; c < C; ++c) {
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
KERNEL_ASSIGN(output[n][c][h][w], req,
input[n][h][w][c] / normalize_factor);
}
}
}
}

template<typename DType, typename T1, typename T2>
void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
const T1 input,
const T2 output,
const int req,
const float normalize_factor = 255.0f) {
const float normalize_factor) {
int blocks, H, W, C, N;
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
if (std::is_same<T1, Tensor<gpu, 3, DType>>::value) {
@@ -75,22 +92,23 @@ void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
W = input.size(1);
C = input.size(2);
blocks = 1;
} /*else {
} else {
// 4D Input - (N, H, W, C)
N = input.size()[0];
H = input.size()[1];
W = input.size()[2];
C = input.size()[3];
// blocks = N > 0 ? N : 1;
N = input.size(0);
H = input.size(1);
W = input.size(2);
C = input.size(3);
blocks = N > 0 ? N : 1;
blocks = N;
}*/
}
// One block per image.
// (32, 32) threads per block is optimal here: the per-thread computation
// is minimal, so the overhead of preparing more CUDA threads would dominate.
ToTensorCudaKernel<<<blocks, dim3(32, 32), 0, stream>>>(input,
output, req, N, H, W, C, normalize_factor);
MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
ToTensorCudaKernel<gpu, DType>
<<<blocks, dim3(32, 32), 0, stream>>>(input, output,
req, N, H, W, C, normalize_factor);
MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
}

NNVM_REGISTER_OP(_image_to_tensor)
