From 5b99b25e5f6ab3a20c7bcf4821a6af0a1a95f823 Mon Sep 17 00:00:00 2001 From: Dick Carter Date: Tue, 2 Jan 2018 10:47:41 -0800 Subject: [PATCH] 1 d conv with cudnn (#9184) * 1D conv/deconv handling by cudnn, with tests. * Fix python3 test issue. * Fix lint issues. * Fixed CI and doc. --- src/operator/nn/convolution-inl.h | 8 +- src/operator/nn/convolution.cu | 7 - src/operator/nn/cudnn/cudnn_convolution-inl.h | 273 ++++++++---------- .../nn/cudnn/cudnn_deconvolution-inl.h | 261 ++++++++--------- src/operator/nn/deconvolution-inl.h | 161 +++++++---- src/operator/nn/deconvolution.cc | 2 +- src/operator/nn/deconvolution.cu | 8 +- tests/python/gpu/test_operator_gpu.py | 64 ++-- tests/python/unittest/test_operator.py | 221 +++++++++----- 9 files changed, 529 insertions(+), 476 deletions(-) diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index 38971aefa2d3..1613da6c85d1 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -67,13 +67,13 @@ struct ConvolutionParam : public dmlc::Parameter { bool cudnn_off; dmlc::optional layout; DMLC_DECLARE_PARAMETER(ConvolutionParam) { - DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w) or (d, h, w)"); + DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (w,), (h, w) or (d, h, w)"); DMLC_DECLARE_FIELD(stride).set_default(TShape()) - .describe("Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension."); + .describe("Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension."); DMLC_DECLARE_FIELD(dilate).set_default(TShape()) - .describe("Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension."); + .describe("Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension."); DMLC_DECLARE_FIELD(pad).set_default(TShape()) - .describe("Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding."); + .describe("Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding."); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("Convolution filter(channel) number"); DMLC_DECLARE_FIELD(num_group).set_default(1) diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu index c31d78c226f4..7234daf0d614 100644 --- a/src/operator/nn/convolution.cu +++ b/src/operator/nn/convolution.cu @@ -41,13 +41,6 @@ Operator* CreateOp(ConvolutionParam param, int dtype, std::vector *out_shape, Context ctx) { Operator *op = NULL; - // If 1D convolution, use MXNet implementation - if (param.kernel.ndim() == 1) { - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new ConvolutionOp(param); - }) - return op; - } // depth wise conv if (param.num_filter == param.num_group && diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index f37203998e0a..8ffe97d94310 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -64,12 +64,22 @@ class CuDNNConvolutionOp : public Operator { cudnn_tensor_core_ = DataType::kFlag == kFloat16 && GetEnvAllowTensorCore(); #if CUDNN_MAJOR >= 5 - MSHADOW_LAYOUT_SWITCH(param_.layout.value(), Layout, { + auto effective_layout = param_.layout.value(); + switch (effective_layout) { + // 1D convolutions will be executed as 2D convolutions with a height of 1. 
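+      // The mapping promotes each 1D layout to its 2D counterpart (w -> (1, w));
+      // InitDescriptors() later inserts the unit height into the data/weight
+      // shapes and prepends 1 (or 0 for pad) to stride, dilation and padding.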
+ case mshadow::kNCW: effective_layout = mshadow::kNCHW; break; + case mshadow::kNWC: effective_layout = mshadow::kNHWC; break; + case mshadow::kCWN: effective_layout = mshadow::kCHWN; break; + default: break; + } + + MSHADOW_LAYOUT_SWITCH(effective_layout, Layout, { format_ = LayoutType::kCudnnFlag; }); #else - CHECK(param_.layout.value() == kNCHW || param_.layout.value() == kNCDHW) - << "Need CuDNN > 5.0 for layout support"; + CHECK(param_.layout.value() == kNCW || + param_.layout.value() == kNCHW || + param_.layout.value() == kNCDHW) << "Need CuDNN > 5.0 for layout support"; #endif // Double check to make sure this class supports the operation if (!Supports(param, forward_compute_type, backward_compute_type, ctx)) @@ -110,9 +120,6 @@ class CuDNNConvolutionOp : public Operator { const std::vector &aux_args) { using namespace mshadow; size_t expected = param_.no_bias ? 2 : 3; - DType *data_ptr = NULL; - DType *wmat_ptr = NULL; - DType *out_ptr = NULL; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1U); Stream *s = ctx.get_stream(); @@ -120,27 +127,11 @@ class CuDNNConvolutionOp : public Operator { Tensor workspace = AllocateTempWorkspace(ctx, forward_workspace_byte_); size_t workspace_size = TensorSizeBytes(workspace); - if (param_.kernel.ndim() == 2) { - Tensor data = in_data[conv::kData].get(s); - Tensor wmat = in_data[conv::kWeight].get(s); - Tensor out = out_data[conv::kOut].get(s); - CHECK_EQ(data.CheckContiguous(), true); - CHECK_EQ(wmat.CheckContiguous(), true); - CHECK_EQ(out.CheckContiguous(), true); - data_ptr = data.dptr_; - wmat_ptr = wmat.dptr_; - out_ptr = out.dptr_; - } else { - Tensor data = in_data[conv::kData].get(s); - Tensor wmat = in_data[conv::kWeight].get(s); - Tensor out = out_data[conv::kOut].get(s); - CHECK_EQ(data.CheckContiguous(), true); - CHECK_EQ(wmat.CheckContiguous(), true); - CHECK_EQ(out.CheckContiguous(), true); - data_ptr = data.dptr_; - wmat_ptr = wmat.dptr_; - out_ptr = out.dptr_; - } + // I/O's should have 2 more dims than the kernel dim + DType *data_ptr = GetNdPtr(in_data[conv::kData], param_.kernel.ndim() + 2, s); + DType *wmat_ptr = GetNdPtr(in_data[conv::kWeight], param_.kernel.ndim() + 2, s); + DType *out_ptr = GetNdPtr(out_data[conv::kOut], param_.kernel.ndim() + 2, s); + for (uint32_t g = 0; g < param_.num_group; ++g) { typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; @@ -193,37 +184,17 @@ class CuDNNConvolutionOp : public Operator { using namespace mshadow; using namespace mshadow::expr; size_t expected = param_.no_bias == 0 ? 
3 : 2; - DType *grad_ptr = NULL; - DType *wmat_ptr = NULL; - DType *gwmat_ptr = NULL; - DType *data_ptr = NULL; - DType *gdata_ptr = NULL; CHECK_EQ(out_grad.size(), 1U); CHECK(in_data.size() == expected && in_grad.size() == expected); Stream *s = ctx.get_stream(); - if (param_.kernel.ndim() == 2) { - Tensor grad = out_grad[conv::kOut].get(s); - Tensor wmat = in_data[conv::kWeight].get(s); - Tensor gwmat = in_grad[conv::kWeight].get(s); - Tensor data = in_data[conv::kData].get(s); - Tensor gdata = in_grad[conv::kData].get(s); - grad_ptr = grad.dptr_; - wmat_ptr = wmat.dptr_; - gwmat_ptr = gwmat.dptr_; - data_ptr = data.dptr_; - gdata_ptr = gdata.dptr_; - } else { - Tensor grad = out_grad[conv::kOut].get(s); - Tensor wmat = in_data[conv::kWeight].get(s); - Tensor gwmat = in_grad[conv::kWeight].get(s); - Tensor data = in_data[conv::kData].get(s); - Tensor gdata = in_grad[conv::kData].get(s); - grad_ptr = grad.dptr_; - wmat_ptr = wmat.dptr_; - gwmat_ptr = gwmat.dptr_; - data_ptr = data.dptr_; - gdata_ptr = gdata.dptr_; - } + + // I/O's should have 2 more dims than the kernel dim + DType *grad_ptr = GetNdPtr(out_grad[conv::kOut], param_.kernel.ndim() + 2, s); + DType *wmat_ptr = GetNdPtr(in_data[conv::kWeight], param_.kernel.ndim() + 2, s); + DType *gwmat_ptr = GetNdPtr(in_grad[conv::kWeight], param_.kernel.ndim() + 2, s); + DType *data_ptr = GetNdPtr(in_data[conv::kData], param_.kernel.ndim() + 2, s); + DType *gdata_ptr = GetNdPtr(in_grad[conv::kData], param_.kernel.ndim() + 2, s); + Tensor workspace = AllocateTempWorkspace(ctx, backward_workspace_byte_); size_t workspace_size = TensorSizeBytes(workspace); for (uint32_t g = 0; g < param_.num_group; ++g) { @@ -320,7 +291,8 @@ class CuDNNConvolutionOp : public Operator { auto layout_val = param.layout.value(); auto true_fp16 = DataType::kFlag == kFloat16 && (forward_compute_type == kFloat16 || backward_compute_type == kFloat16); - if (layout_val == kNDHWC || layout_val == kNHWC && true_fp16) + if (layout_val == kNDHWC || layout_val == kNWC || + layout_val == kNHWC && true_fp16) return false; // Permits graceful fallback to pseudo-fp16 on heterogenous systems @@ -374,100 +346,78 @@ class CuDNNConvolutionOp : public Operator { TShape oshape = out_shape[conv::kOut]; TShape dstride, ostride; wshape[0] /= param_.num_group; - if (param_.kernel.ndim() == 2) { - // 2d conv - +#if CUDNN_MAJOR <= 5 // As of cuDNN_v6, the unsuffixed version of cudnnSetConvolution2dDescriptor() - // requires an additional 'computeType' parameter to set the precision of the - // convolution calculation. This facility was available as of v5 in - // cudnnSetConvolution2dDescriptor_v5(), but was never accessed. - #if CUDNN_MAJOR >= 6 + // takes an additional 'computeType' parameter to set the precision of the + // convolution calculation. Supply this method signature for cuDNN versions < 6. +#define cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m, ct) \ + cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m) +#endif + if (param_.kernel.ndim() == 1 || param_.kernel.ndim() == 2) { + // 1d or 2d conv + auto pad = param_.kernel.ndim() == 2 ? param_.pad : TShape({0, param_.pad[0]}); + auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]}); + auto dilate = param_.kernel.ndim() == 2 ? 
param_.dilate : TShape({1, param_.dilate[0]}); CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_, - param_.pad[0], - param_.pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], + pad[0], + pad[1], + stride[0], + stride[1], + dilate[0], + dilate[1], CUDNN_CROSS_CORRELATION, cudnn_forward_compute_type)); CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_, - param_.pad[0], - param_.pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], + pad[0], + pad[1], + stride[0], + stride[1], + dilate[0], + dilate[1], CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_, - param_.pad[0], - param_.pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], + pad[0], + pad[1], + stride[0], + stride[1], + dilate[0], + dilate[1], CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); - #else - CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_, - param_.pad[0], - param_.pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_, - param_.pad[0], - param_.pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_, - param_.pad[0], - param_.pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION)); - #endif - - #if CUDNN_MAJOR >= 5 - wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); - CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_, - dtype_, - format_, - wshape[0], - wshape[1], - wshape[2], - wshape[3])); - #else - CHECK_EQ(param_.layout.value(), kNCHW) << "CuDNN V4 only support NCHW layout"; +#if CUDNN_MAJOR < 5 + // As of cuDNN_v5, cudnnSetFilter4dDescriptor() takes a format parameter. + // Supply this method signature for cuDNN versions < 5. 
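+  // Note: the format argument 'f' is dropped below because cuDNN v4 and
+  // earlier assumed NCHW, a restriction enforced by the CHECK_EQ that follows.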
+#define cudnnSetFilter4dDescriptor(fdesc, dt, f, w0, w1, w2, w3) \ + cudnnSetFilter4dDescriptor(fdesc, dt, w0, w1, w2, w3) + CHECK_EQ(format_, CUDNN_TENSOR_NCHW) << "CuDNN V4 and earlier only supports NCHW layout"; +#endif + if (param_.kernel.ndim() == 2) { + wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); + dstride = ConvertLayout(Strides<4>(dshape), param_.layout.value(), kNCHW); + dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW); + ostride = ConvertLayout(Strides<4>(oshape), param_.layout.value(), kNCHW); + oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW); + } else { + wshape = ConvertLayout(wshape.get<3>(), param_.layout.value(), kNCW); + wshape = TShape({wshape[0], wshape[1], 1, wshape[2]}); + dstride = ConvertLayout(Strides<3>(dshape), param_.layout.value(), kNCW); + dstride = TShape({dstride[0], dstride[1], dstride[1], dstride[2]}); + dshape = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW); + dshape = TShape({dshape[0], dshape[1], 1, dshape[2]}); + ostride = ConvertLayout(Strides<3>(oshape), param_.layout.value(), kNCW); + ostride = TShape({ostride[0], ostride[1], ostride[1], ostride[2]}); + oshape = ConvertLayout(oshape.get<3>(), param_.layout.value(), kNCW); + oshape = TShape({oshape[0], oshape[1], 1, oshape[2]}); + } CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_, - dtype_, - wshape[0], - wshape[1], - wshape[2], - wshape[3])); - #endif + dtype_, + format_, + wshape[0], + wshape[1], + wshape[2], + wshape[3])); - dstride = ConvertLayout(Shape4(dshape[1] * dshape[2] * dshape[3], - dshape[2] * dshape[3], - dshape[3], - 1), - param_.layout.value(), kNCHW); - dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW); - - ostride = ConvertLayout(Shape4(oshape[1] * oshape[2] * oshape[3], - oshape[2] * oshape[3], - oshape[3], - 1), - param_.layout.value(), kNCHW); - oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW); } else if (param_.kernel.ndim() == 3) { // 3d conv #if CUDNN_MAJOR >= 5 @@ -505,20 +455,9 @@ class CuDNNConvolutionOp : public Operator { CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); - dstride = ConvertLayout(Shape5(dshape[1] * dshape[2] * dshape[3] * dshape[4], - dshape[2] * dshape[3] * dshape[4], - dshape[3] * dshape[4], - dshape[4], - 1), - param_.layout.value(), kNCDHW); + dstride = ConvertLayout(Strides<5>(dshape), param_.layout.value(), kNCDHW); dshape = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW); - - ostride = ConvertLayout(Shape5(oshape[1] * oshape[2] * oshape[3] * oshape[4], - oshape[2] * oshape[3] * oshape[4], - oshape[3] * oshape[4], - oshape[4], - 1), - param_.layout.value(), kNCDHW); + ostride = ConvertLayout(Strides<5>(oshape), param_.layout.value(), kNCDHW); oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW); } // Set "allow tensor core" flag in convolution descriptors, if available. @@ -852,6 +791,38 @@ class CuDNNConvolutionOp : public Operator { return buffer->data(); } + // Converts a TBlob to a dptr, checking for the expected dim and that it's contiguous. 
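+  // For convolutions the expected dim is kernel.ndim() + 2: 3 for 1D (NCW),
+  // 4 for 2D (NCHW) and 5 for 3D (NCDHW) I/O tensors.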
+ DType *GetNdPtr(const TBlob& tb, int dim, Stream *s) { + DType *data_ptr = NULL; + if (dim == 3) { + Tensor data = tb.get(s); + CHECK_EQ(data.CheckContiguous(), true); + data_ptr = data.dptr_; + } else if (dim == 4) { + Tensor data = tb.get(s); + CHECK_EQ(data.CheckContiguous(), true); + data_ptr = data.dptr_; + } else if (dim == 5) { + Tensor data = tb.get(s); + CHECK_EQ(data.CheckContiguous(), true); + data_ptr = data.dptr_; + } else { + LOG(FATAL) << "Unexpected Tensor size " << dim << ", supporting only 3, 4 or 5."; + } + return data_ptr; + } + + // Converts a TShape to a Shape<> of strides. + // e.g. {shape[0], shape[1], shape[2]} -> {shape[1]*shape[2], shape[2], 1} + template + inline Shape Strides(const TShape &s) { + uint32_t ndim = s.ndim(); + TShape strides(ndim); + for (uint32_t i = 0; i != ndim; ++i) + strides[i] = s.ProdShape(i+1, ndim); + return strides.get(); + } + void InitBufferForParam() { CastTShapeToIntPtr(param_.stride, ¶m_stride_); CastTShapeToIntPtr(param_.dilate, ¶m_dilate_); diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 09e89c27bbaf..bc02d1b73f45 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -61,16 +61,26 @@ class CuDNNDeconvolutionOp : public Operator { cudnn_tensor_core_ = DataType::kFlag == kFloat16 && GetEnvAllowTensorCore(); #if CUDNN_MAJOR >= 5 - MSHADOW_LAYOUT_SWITCH(param_.layout.value(), Layout, { + auto effective_layout = param_.layout.value(); + switch (effective_layout) { + // 1D convolutions will be executed as 2D convolutions with a height of 1. + case mshadow::kNCW: effective_layout = mshadow::kNCHW; break; + case mshadow::kNWC: effective_layout = mshadow::kNHWC; break; + case mshadow::kCWN: effective_layout = mshadow::kCHWN; break; + default: break; + } + + MSHADOW_LAYOUT_SWITCH(effective_layout, Layout, { format_ = LayoutType::kCudnnFlag; }); #else - CHECK(param_.layout.value() == kNCHW || param_.layout.value() == kNCDHW) - << "Need CuDNN > 5.0 for layout support"; + CHECK(param_.layout.value() == kNCW || + param_.layout.value() == kNCHW || + param_.layout.value() == kNCDHW) << "Need CuDNN > 5.0 for layout support"; #endif // Double check to make sure this class supports the operation if (!Supports(param, forward_compute_type, backward_compute_type, ctx)) - LOG(FATAL) << "Need CuDNN >= 6.0 for dilated convolution."; + LOG(FATAL) << "Need CuDNN >= 6.0 for dilated deconvolution."; InitDescriptors(ctx, in_shape, out_shape, cudnn_forward_compute_type, cudnn_backward_compute_type); @@ -107,9 +117,6 @@ class CuDNNDeconvolutionOp : public Operator { const std::vector &aux_args) { using namespace mshadow; size_t expected = param_.no_bias ? 
2 : 3; - DType *data_ptr = NULL; - DType *wmat_ptr = NULL; - DType *out_ptr = NULL; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1U); Stream *s = ctx.get_stream(); @@ -117,27 +124,10 @@ class CuDNNDeconvolutionOp : public Operator { Tensor workspace = AllocateTempWorkspace(ctx, forward_workspace_byte_); size_t workspace_size = TensorSizeBytes(workspace); - if (param_.kernel.ndim() == 2) { - Tensor data = in_data[deconv::kData].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor out = out_data[deconv::kOut].get(s); - CHECK_EQ(data.CheckContiguous(), true); - CHECK_EQ(wmat.CheckContiguous(), true); - CHECK_EQ(out.CheckContiguous(), true); - data_ptr = data.dptr_; - wmat_ptr = wmat.dptr_; - out_ptr = out.dptr_; - } else { - Tensor data = in_data[deconv::kData].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor out = out_data[deconv::kOut].get(s); - CHECK_EQ(data.CheckContiguous(), true); - CHECK_EQ(wmat.CheckContiguous(), true); - CHECK_EQ(out.CheckContiguous(), true); - data_ptr = data.dptr_; - wmat_ptr = wmat.dptr_; - out_ptr = out.dptr_; - } + // I/O's should have 2 more dims than the kernel dim + DType *data_ptr = GetNdPtr(in_data[deconv::kData], param_.kernel.ndim() + 2, s); + DType *wmat_ptr = GetNdPtr(in_data[deconv::kWeight], param_.kernel.ndim() + 2, s); + DType *out_ptr = GetNdPtr(out_data[deconv::kOut], param_.kernel.ndim() + 2, s); for (uint32_t g = 0; g < param_.num_group; ++g) { typename DataType::ScaleType alpha = 1.0f; @@ -207,37 +197,17 @@ class CuDNNDeconvolutionOp : public Operator { using namespace mshadow; using namespace mshadow::expr; size_t expected = param_.no_bias == 0 ? 3 : 2; - DType *grad_ptr = NULL; - DType *wmat_ptr = NULL; - DType *gwmat_ptr = NULL; - DType *data_ptr = NULL; - DType *gdata_ptr = NULL; CHECK_EQ(out_grad.size(), 1U); CHECK(in_data.size() == expected && in_grad.size() == expected); Stream *s = ctx.get_stream(); - if (param_.kernel.ndim() == 2) { - Tensor grad = out_grad[deconv::kOut].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor gwmat = in_grad[deconv::kWeight].get(s); - Tensor data = in_data[deconv::kData].get(s); - Tensor gdata = in_grad[deconv::kData].get(s); - grad_ptr = grad.dptr_; - wmat_ptr = wmat.dptr_; - gwmat_ptr = gwmat.dptr_; - data_ptr = data.dptr_; - gdata_ptr = gdata.dptr_; - } else { - Tensor grad = out_grad[deconv::kOut].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor gwmat = in_grad[deconv::kWeight].get(s); - Tensor data = in_data[deconv::kData].get(s); - Tensor gdata = in_grad[deconv::kData].get(s); - grad_ptr = grad.dptr_; - wmat_ptr = wmat.dptr_; - gwmat_ptr = gwmat.dptr_; - data_ptr = data.dptr_; - gdata_ptr = gdata.dptr_; - } + + // I/O's should have 2 more dims than the kernel dim + DType *grad_ptr = GetNdPtr(out_grad[deconv::kOut], param_.kernel.ndim() + 2, s); + DType *wmat_ptr = GetNdPtr(in_data[deconv::kWeight], param_.kernel.ndim() + 2, s); + DType *gwmat_ptr = GetNdPtr(in_grad[deconv::kWeight], param_.kernel.ndim() + 2, s); + DType *data_ptr = GetNdPtr(in_data[deconv::kData], param_.kernel.ndim() + 2, s); + DType *gdata_ptr = GetNdPtr(in_grad[deconv::kData], param_.kernel.ndim() + 2, s); + CHECK_NE(req[deconv::kWeight], kWriteInplace); if (!param_.no_bias) { CHECK_NE(req[deconv::kBias], kWriteInplace); @@ -331,7 +301,8 @@ class CuDNNDeconvolutionOp : public Operator { auto layout_val = param.layout.value(); auto true_fp16 = DataType::kFlag == kFloat16 && (forward_compute_type == kFloat16 || backward_compute_type == kFloat16); - if 
(layout_val == kNDHWC || layout_val == kNHWC && true_fp16) + if (layout_val == kNDHWC || layout_val == kNWC || + layout_val == kNHWC && true_fp16) return false; // Permits graceful fallback to pseudo-fp16 on heterogenous systems @@ -374,9 +345,6 @@ class CuDNNDeconvolutionOp : public Operator { cudnnDataType_t cudnn_forward_compute_type, cudnnDataType_t cudnn_backward_compute_type) { using namespace mshadow; - #if CUDNN_MAJOR >= 5 - format_ = CUDNN_TENSOR_NCHW; - #endif size_t expected = param_.no_bias ? 2 : 3; CHECK_EQ(in_shape.size(), expected); CHECK_EQ(out_shape.size(), 1U); @@ -393,70 +361,81 @@ class CuDNNDeconvolutionOp : public Operator { TShape oshape = out_shape[deconv::kOut]; TShape dstride, ostride; wshape[0] /= param_.num_group; - - if (param_.kernel.ndim() == 2) { - // 2d conv +#if CUDNN_MAJOR <= 5 + // As of cuDNN_v6, the unsuffixed version of cudnnSetConvolution2dDescriptor() + // takes an additional 'computeType' parameter to set the precision of the + // convolution calculation. Supply this method signature for cuDNN versions < 6. +#define cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m, ct) \ + cudnnSetConvolution2dDescriptor(cdesc, p0, p1, s0, s1, d0, d1, m) +#endif + if (param_.kernel.ndim() == 1 || param_.kernel.ndim() == 2) { + // 1d or 2d conv index_t o_pad[2]; index_t o_adj[2]; - param_.InferPad(dshape, o_pad, o_adj); + if (param_.kernel.ndim() == 2) { + param_.InferPad(dshape, o_pad, o_adj); + } else { + index_t o_pad_1D[1]; + index_t o_adj_1D[1]; + param_.InferPad(dshape, o_pad_1D, o_adj_1D); + o_pad[0] = 0; + o_pad[1] = o_pad_1D[0]; + } + auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]}); + auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]}); - #if CUDNN_MAJOR >= 6 CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_, o_pad[0], o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], + stride[0], + stride[1], + dilate[0], + dilate[1], CUDNN_CROSS_CORRELATION, cudnn_forward_compute_type)); CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_, o_pad[0], o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], + stride[0], + stride[1], + dilate[0], + dilate[1], CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_, o_pad[0], o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], + stride[0], + stride[1], + dilate[0], + dilate[1], CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); - #else - CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_, - o_pad[0], - o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_, - o_pad[0], - o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION)); - CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_, - o_pad[0], - o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION)); - #endif - - #if CUDNN_MAJOR >= 5 - wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); +#if CUDNN_MAJOR < 5 + // As of cuDNN_v5, cudnnSetFilter4dDescriptor() takes a format parameter. + // Supply this method signature for cuDNN versions < 5. 
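+  // Note: the format argument 'f' is dropped below because cuDNN v4 and
+  // earlier assumed NCHW, a restriction enforced by the CHECK_EQ that follows.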
+#define cudnnSetFilter4dDescriptor(fdesc, dt, f, w0, w1, w2, w3) \ + cudnnSetFilter4dDescriptor(fdesc, dt, w0, w1, w2, w3) + CHECK_EQ(format_, CUDNN_TENSOR_NCHW) << "CuDNN V4 and earlier only supports NCHW layout"; +#endif + if (param_.kernel.ndim() == 2) { + wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); + dstride = ConvertLayout(Strides<4>(dshape), param_.layout.value(), kNCHW); + dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW); + ostride = ConvertLayout(Strides<4>(oshape), param_.layout.value(), kNCHW); + oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW); + } else { + wshape = ConvertLayout(wshape.get<3>(), param_.layout.value(), kNCW); + wshape = TShape({wshape[0], wshape[1], 1, wshape[2]}); + dstride = ConvertLayout(Strides<3>(dshape), param_.layout.value(), kNCW); + dstride = TShape({dstride[0], dstride[1], dstride[1], dstride[2]}); + dshape = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW); + dshape = TShape({dshape[0], dshape[1], 1, dshape[2]}); + ostride = ConvertLayout(Strides<3>(oshape), param_.layout.value(), kNCW); + ostride = TShape({ostride[0], ostride[1], ostride[1], ostride[2]}); + oshape = ConvertLayout(oshape.get<3>(), param_.layout.value(), kNCW); + oshape = TShape({oshape[0], oshape[1], 1, oshape[2]}); + } CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_, dtype_, format_, @@ -464,29 +443,6 @@ class CuDNNDeconvolutionOp : public Operator { wshape[1], wshape[2], wshape[3])); - #else - CHECK_EQ(param_.layout.value(), kNCHW) << "CuDNN V4 only support NCHW layout"; - CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_, - dtype_, - wshape[0], - wshape[1], - wshape[2], - wshape[3])); - #endif - - dstride = ConvertLayout(Shape4(dshape[1] * dshape[2] * dshape[3], - dshape[2] * dshape[3], - dshape[3], - 1), - param_.layout.value(), kNCHW); - dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW); - - ostride = ConvertLayout(Shape4(oshape[1] * oshape[2] * oshape[3], - oshape[2] * oshape[3], - oshape[3], - 1), - param_.layout.value(), kNCHW); - oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW); } else if (param_.kernel.ndim() == 3) { // 3d conv index_t o_pad[3]; @@ -528,20 +484,9 @@ class CuDNNDeconvolutionOp : public Operator { CUDNN_CROSS_CORRELATION, cudnn_backward_compute_type)); - dstride = ConvertLayout(Shape5(dshape[1] * dshape[2] * dshape[3] * dshape[4], - dshape[2] * dshape[3] * dshape[4], - dshape[3] * dshape[4], - dshape[4], - 1), - param_.layout.value(), kNCDHW); + dstride = ConvertLayout(Strides<5>(dshape), param_.layout.value(), kNCDHW); dshape = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW); - - ostride = ConvertLayout(Shape5(oshape[1] * oshape[2] * oshape[3] * oshape[4], - oshape[2] * oshape[3] * oshape[4], - oshape[3] * oshape[4], - oshape[4], - 1), - param_.layout.value(), kNCDHW); + ostride = ConvertLayout(Strides<5>(oshape), param_.layout.value(), kNCDHW); oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW); } // Set "allow tensor core" flag in convolution descriptors, if available. @@ -883,6 +828,38 @@ class CuDNNDeconvolutionOp : public Operator { return buffer->data(); } + // Converts a TBlob to a dptr, checking for the expected dim and that it's contiguous. 
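+  // As in the convolution case, dim is kernel.ndim() + 2: 3 for 1D (NCW),
+  // 4 for 2D (NCHW) and 5 for 3D (NCDHW) I/O tensors.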
+ DType *GetNdPtr(const TBlob& tb, int dim, Stream *s) { + DType *data_ptr = NULL; + if (dim == 3) { + Tensor data = tb.get(s); + CHECK_EQ(data.CheckContiguous(), true); + data_ptr = data.dptr_; + } else if (dim == 4) { + Tensor data = tb.get(s); + CHECK_EQ(data.CheckContiguous(), true); + data_ptr = data.dptr_; + } else if (dim == 5) { + Tensor data = tb.get(s); + CHECK_EQ(data.CheckContiguous(), true); + data_ptr = data.dptr_; + } else { + LOG(FATAL) << "Unexpected Tensor size " << dim << ", supporting only 3, 4 or 5."; + } + return data_ptr; + } + + // Converts a TShape to a Shape<> of strides. + // e.g. {shape[0], shape[1], shape[2]} -> {shape[1]*shape[2], shape[2], 1} + template + inline Shape Strides(const TShape &s) { + uint32_t ndim = s.ndim(); + TShape strides(ndim); + for (uint32_t i = 0; i != ndim; ++i) + strides[i] = s.ProdShape(i+1, ndim); + return strides.get(); + } + void InitBufferForParam() { CastTShapeToIntPtr(param_.stride, ¶m_stride_); CastTShapeToIntPtr(param_.dilate, ¶m_dilate_); diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index b7d2676fadf3..fbdfaa84faab 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -63,28 +63,28 @@ struct DeconvolutionParam : public dmlc::Parameter { bool cudnn_off; dmlc::optional layout; DMLC_DECLARE_PARAMETER(DeconvolutionParam) { - DMLC_DECLARE_FIELD(kernel).describe("Deconvolution kernel size: (h, w) or (d, h, w). " + DMLC_DECLARE_FIELD(kernel).describe("Deconvolution kernel size: (w,), (h, w) or (d, h, w). " "This is same as the kernel size used for the corresponding convolution"); DMLC_DECLARE_FIELD(stride).set_default(TShape()) - .describe("The stride used for the corresponding convolution: (h, w) or (d, h, w). " + .describe("The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w). " "Defaults to 1 for each dimension."); DMLC_DECLARE_FIELD(dilate).set_default(TShape()) - .describe("Dilation factor for each dimension of the input: (h, w) or (d, h, w). " + .describe("Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w). " "Defaults to 1 for each dimension."); DMLC_DECLARE_FIELD(pad).set_default(TShape()) .describe("The amount of implicit zero padding added during convolution for each " "dimension of the input: " - "(h, w) or (d, h, w). " + "(w,), (h, w) or (d, h, w). " "``(kernel-1)/2`` is usually a good choice. " "If `target_shape` is set, " "`pad` will be ignored and a padding that will generate the target shape " "will be used. Defaults to no padding."); DMLC_DECLARE_FIELD(adj).set_default(TShape()) - .describe("Adjustment for output shape: (h, w) or (d, h, w). " + .describe("Adjustment for output shape: (w,), (h, w) or (d, h, w). 
" "If `target_shape` is set, " "`adj` will be ignored and computed accordingly."); DMLC_DECLARE_FIELD(target_shape).set_default(TShape()) - .describe("Shape of the output tensor: (h, w) or (d, h, w)."); + .describe("Shape of the output tensor: (w,), (h, w) or (d, h, w)."); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("Number of output filters."); DMLC_DECLARE_FIELD(num_group).set_default(1) @@ -211,8 +211,8 @@ class DeconvolutionOp : public Operator { using namespace mshadow; using namespace mshadow::expr; - if (param_.kernel.ndim() != 2) { - LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported"; + if (param_.kernel.ndim() > 2) { + LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported"; } CHECK_EQ(req[deconv::kOut], kWriteTo); @@ -220,18 +220,29 @@ class DeconvolutionOp : public Operator { CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1U); Stream *s = ctx.get_stream(); - Tensor data = in_data[deconv::kData].get(s); - Tensor out = out_data[deconv::kOut].get(s); - + auto in_data_shape = in_data[deconv::kData].shape_; + Tensor data = TBlobTo4DTensor(in_data[deconv::kData], s); + Tensor out = TBlobTo4DTensor(out_data[deconv::kOut], s); index_t o_pad[2], o_adj[2]; - TShape dshape = {static_cast(data.size(2)), - static_cast(data.size(3))}; - param_.InferPad(dshape, o_pad, o_adj); + if (param_.kernel.ndim() == 2) { + param_.InferPad(TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj); + } else { + index_t o_pad_1D[1], o_adj_1D[1]; + param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D); + o_pad[0] = 0; + o_pad[1] = o_pad_1D[0]; + o_adj[0] = 0; + o_adj[1] = o_adj_1D[0]; + } + auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]}); + auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]}); + auto kernel = param_.kernel.ndim() == 2 ? 
param_.kernel : TShape({1, param_.kernel[0]}); + auto kernel_size = kernel.Size(); Shape<3> wmat_shape = Shape3(param_.num_group, data.shape_[1] / param_.num_group, - param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]); + param_.num_filter / param_.num_group * kernel_size); Tensor wmat = in_data[deconv::kWeight].get_with_shape(wmat_shape, s); #if defined(__CUDACC__) @@ -256,21 +267,21 @@ class DeconvolutionOp : public Operator { temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); if (o_pad[0] == 0 && o_pad[1] == 0) { temp_col = unpack_patch2col(out.Slice(i, i + step), - param_.kernel[0], - param_.kernel[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + kernel[0], + kernel[1], + stride[0], + stride[1], + dilate[0], + dilate[1]); } else { temp_col = unpack_patch2col(pad(out.Slice(i, i + step), o_pad[0], o_pad[1]), - param_.kernel[0], - param_.kernel[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + kernel[0], + kernel[1], + stride[0], + stride[1], + dilate[0], + dilate[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -283,24 +294,24 @@ class DeconvolutionOp : public Operator { if (o_pad[0] == 0 && o_pad[1] == 0) { out.Slice(i, i + step) = pack_col2patch(temp_col, out.Slice(i, i + step).shape_, - param_.kernel[0], - param_.kernel[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + kernel[0], + kernel[1], + stride[0], + stride[1], + dilate[0], + dilate[1]); } else { Shape<4> pshape = out.Slice(i, i + step).shape_; pshape[2] += 2 * o_pad[0]; pshape[3] += 2 * o_pad[1]; out.Slice(i, i + step) = crop(pack_col2patch(temp_col, pshape, - param_.kernel[0], - param_.kernel[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1]), + kernel[0], + kernel[1], + stride[0], + stride[1], + dilate[0], + dilate[1]), out[i][0].shape_); } } @@ -328,13 +339,31 @@ class DeconvolutionOp : public Operator { CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true); // get data Stream *s = ctx.get_stream(); - Tensor data = in_data[deconv::kData].get(s); - Tensor grad = out_grad[deconv::kOut].get(s); - Tensor gdata = in_grad[deconv::kData].get(s); + auto in_data_shape = in_data[deconv::kData].shape_; + Tensor data = TBlobTo4DTensor(in_data[deconv::kData], s); + Tensor grad = TBlobTo4DTensor(out_grad[deconv::kOut], s); + Tensor gdata = TBlobTo4DTensor(in_grad[deconv::kData], s); + + index_t o_pad[2], o_adj[2]; + if (param_.kernel.ndim() == 2) { + param_.InferPad(TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj); + } else { + index_t o_pad_1D[1], o_adj_1D[1]; + param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D); + o_pad[0] = 0; + o_pad[1] = o_pad_1D[0]; + o_adj[0] = 0; + o_adj[1] = o_adj_1D[0]; + } + auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]}); + auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]}); + auto kernel = param_.kernel.ndim() == 2 ? 
param_.kernel : TShape({1, param_.kernel[0]}); + auto kernel_size = kernel.Size(); + Shape<3> wmat_shape = Shape3(param_.num_group, data.shape_[1] / param_.num_group, - param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]); + param_.num_filter / param_.num_group * kernel_size); Tensor wmat = in_data[deconv::kWeight].get_with_shape(wmat_shape, s); Tensor gwmat = @@ -343,10 +372,6 @@ class DeconvolutionOp : public Operator { CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif - index_t o_pad[2], o_adj[2]; - TShape dshape = {static_cast(data.size(2)), - static_cast(data.size(3))}; - param_.InferPad(dshape, o_pad, o_adj); const index_t nbatch = data.size(0); Tensor workspace = @@ -366,20 +391,20 @@ class DeconvolutionOp : public Operator { temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); if (o_pad[0] == 0 && o_pad[1] == 0) { temp_col = unpack_patch2col(grad.Slice(i, i + step), - param_.kernel[0], - param_.kernel[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + kernel[0], + kernel[1], + stride[0], + stride[1], + dilate[0], + dilate[1]); } else { temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), o_pad[0], o_pad[1]), - param_.kernel[0], - param_.kernel[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + kernel[0], + kernel[1], + stride[0], + stride[1], + dilate[0], + dilate[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -422,9 +447,8 @@ class DeconvolutionOp : public Operator { private: inline index_t InitTemp(const mshadow::Shape<4> &ishape, const mshadow::Shape<4> &oshape) { - const int ksize_y = param_.kernel[0]; - const int ksize_x = param_.kernel[1]; - shape_colunit_ = mshadow::Shape2(ishape[1] * ksize_y * ksize_x, + const int ksize = param_.kernel.Size(); + shape_colunit_ = mshadow::Shape2(ishape[1] * ksize, oshape[2] * oshape[3]); shape_dstunit_ = mshadow::Shape3(param_.num_group, oshape[1] / param_.num_group, @@ -449,6 +473,15 @@ class DeconvolutionOp : public Operator { return required_size; } + inline Tensor TBlobTo4DTensor(const TBlob &tb, Stream *s) { + using namespace mshadow; + if (param_.kernel.ndim() == 2) + return tb.get(s); + else + return tb.get_with_shape( + Shape4(tb.shape_[0], tb.shape_[1], 1, tb.shape_[2]), s); + } + DeconvolutionParam param_; mshadow::Shape<2> shape_colunit_; mshadow::Shape<3> shape_dstunit_; @@ -505,8 +538,8 @@ class DeconvolutionProp : public OperatorProperty { std::vector *out_shape, std::vector *aux_shape) const override { #if MXNET_USE_CUDNN == 0 - if (param_.kernel.ndim() != 2) { - LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported"; + if (param_.kernel.ndim() > 2) { + LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported"; return false; } #endif // CUDNN diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index 45867f78593c..9d3c040c1d63 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -55,7 +55,7 @@ MXNET_REGISTER_OP_PROPERTY(Deconvolution, DeconvolutionProp) .add_argument("bias", "NDArray-or-Symbol", "Bias added to the result after the deconvolution " "operation.") .add_arguments(DeconvolutionParam::__FIELDS__()) -.describe("Computes 2D transposed convolution (aka fractionally strided convolution) of the " +.describe("Computes 1D or 2D transposed convolution (aka fractionally strided 
convolution) of the " "input tensor. This operation can be seen as the gradient of Convolution operation with " "respect to its input. Convolution usually reduces the size of the input. Transposed " "convolution works the other way, going from a smaller input to a larger output while " diff --git a/src/operator/nn/deconvolution.cu b/src/operator/nn/deconvolution.cu index 6d0787662c64..623770170d50 100644 --- a/src/operator/nn/deconvolution.cu +++ b/src/operator/nn/deconvolution.cu @@ -38,13 +38,7 @@ Operator* CreateOp(DeconvolutionParam param, int dtype, Context ctx) { // Logic here parallels that in Convolution.cu Operator *op = NULL; - // If 1D deconvolution, use MXNet implementation - if (param.kernel.ndim() == 1) { - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new DeconvolutionOp(param); - }) - return op; - } + #if MXNET_USE_CUDNN == 1 // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16). int compute_type = (dtype == mshadow::kFloat16) ? mshadow::kFloat32 : dtype; diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 7706bce56e74..31e888b4d10e 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -460,19 +460,19 @@ def test_convolution_options(): {'ctx': mx.cpu(0), 'conv_data': (2, 2, 7), 'type_dict': {'conv_data': np.float64}}, {'ctx': mx.cpu(0), 'conv_data': (2, 2, 7), 'type_dict': {'conv_data': np.float32}}] # Pad > 0 - sym = mx.sym.Convolution(num_filter=3, kernel=(3,), pad=(1,), name='conv') + sym = mx.sym.Convolution(layout='NCW', num_filter=3, kernel=(3,), pad=(1,), name='conv') sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(3,), pad=(1,), cudnn_off=True, name='conv') check_consistency_NxM([sym, sym_no_cudnn], ctx_list) # Stride > 1 - sym = mx.sym.Convolution(num_filter=3, kernel=(3,), stride=(2,), name='conv') + sym = mx.sym.Convolution(layout='NCW', num_filter=3, kernel=(3,), stride=(2,), name='conv') sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(3,), stride=(2,), cudnn_off=True, name='conv') check_consistency_NxM([sym, sym_no_cudnn], ctx_list) # Dilate > 1 - sym = mx.sym.Convolution(num_filter=3, kernel=(3,), dilate=(2,), name='conv') + sym = mx.sym.Convolution(layout='NCW', num_filter=3, kernel=(3,), dilate=(2,), name='conv') sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(3,), dilate=(2,), cudnn_off=True, name='conv') check_consistency_NxM([sym, sym_no_cudnn], ctx_list) # 1x1 convolution - sym = mx.sym.Convolution(num_filter=3, kernel=(1,), pad=(0,), name='conv') + sym = mx.sym.Convolution(layout='NCW', num_filter=3, kernel=(1,), pad=(0,), name='conv') sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(1,), pad=(0,), cudnn_off=True, name='conv') check_consistency_NxM([sym, sym_no_cudnn], ctx_list) @@ -558,6 +558,24 @@ def test_pooling_with_type(): check_consistency(sym, ctx_list) def test_deconvolution_with_type(): + # Test basic deconvolution without exercising stride, pad or dilation. 
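+    # (Those variants are exercised below in test_deconvolution_options.)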
+ # 1D deconvolution + sym = mx.sym.Deconvolution(num_filter=3, kernel=(3,), name='deconv') + ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float64}}, + {'ctx': mx.gpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float32}}, + {'ctx': mx.gpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float16}}, + {'ctx': mx.cpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float64}}, + {'ctx': mx.cpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float32}}] + # wider tolerance needed for true-fp16 test above + tol = {np.dtype(np.float16): 0.3, + np.dtype(np.float32): 1e-3, + np.dtype(np.float64): 1e-5, + np.dtype(np.uint8): 0, + np.dtype(np.int32): 0} + check_consistency(sym, ctx_list, tol=tol) + check_consistency(sym, ctx_list, tol=tol, grad_req="add") + + # 2D deconvolution sym = mx.sym.Deconvolution(num_filter=2, kernel=(3,3), name='deconv') ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float64}}, {'ctx': mx.gpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float32}}, @@ -575,24 +593,24 @@ def test_deconvolution_with_type(): def test_deconvolution_options(): -# # 1D convolution (not yet enabled) -# ctx_list = [{'ctx': mx.gpu(0), 'conv_data': (2, 2, 7), 'type_dict': {'conv_data': np.float64}}, -# {'ctx': mx.gpu(0), 'conv_data': (2, 2, 7), 'type_dict': {'conv_data': np.float32}}, -# {'ctx': mx.gpu(0), 'conv_data': (2, 2, 7), 'type_dict': {'conv_data': np.float16}}, -# {'ctx': mx.cpu(0), 'conv_data': (2, 2, 7), 'type_dict': {'conv_data': np.float64}}, -# {'ctx': mx.cpu(0), 'conv_data': (2, 2, 7), 'type_dict': {'conv_data': np.float32}}] -# # Pad > 0 -# sym = mx.sym.Convolution(num_filter=3, kernel=(3,), pad=(1,), name='conv') -# sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(3,), pad=(1,), cudnn_off=True, name='conv') -# check_consistency_NxM([sym, sym_no_cudnn], ctx_list) -# # Stride > 1 -# sym = mx.sym.Convolution(num_filter=3, kernel=(3,), stride=(2,), name='conv') -# sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(3,), stride=(2,), cudnn_off=True, name='conv') -# check_consistency_NxM([sym, sym_no_cudnn], ctx_list) -# # Dilate > 1 -# sym = mx.sym.Convolution(num_filter=3, kernel=(3,), dilate=(2,), name='conv') -# sym_no_cudnn = mx.sym.Convolution(num_filter=3, kernel=(3,), dilate=(2,), cudnn_off=True, name='conv') -# check_consistency_NxM([sym, sym_no_cudnn], ctx_list) + # 1D deconvolution + ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float64}}, + {'ctx': mx.gpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float32}}, + {'ctx': mx.gpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float16}}, + {'ctx': mx.cpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float64}}, + {'ctx': mx.cpu(0), 'deconv_data': (2, 2, 7), 'type_dict': {'deconv_data': np.float32}}] + # Pad > 0 + sym = mx.sym.Deconvolution(layout='NCW', num_filter=3, kernel=(3,), pad=(1,), name='deconv') + sym_no_cudnn = mx.sym.Deconvolution(num_filter=3, kernel=(3,), pad=(1,), cudnn_off=True, name='deconv') + check_consistency_NxM([sym, sym_no_cudnn], ctx_list) + # Stride > 1 + sym = mx.sym.Deconvolution(layout='NCW', num_filter=3, kernel=(3,), stride=(2,), name='deconv') + sym_no_cudnn = mx.sym.Deconvolution(num_filter=3, kernel=(3,), stride=(2,), cudnn_off=True, name='deconv') + check_consistency_NxM([sym, sym_no_cudnn], ctx_list) + # Dilate > 1 + sym = 
mx.sym.Deconvolution(layout='NCW', num_filter=3, kernel=(3,), dilate=(2,), name='deconv') + sym_no_cudnn = mx.sym.Deconvolution(num_filter=3, kernel=(3,), dilate=(2,), cudnn_off=True, name='deconv') + check_consistency_NxM([sym, sym_no_cudnn], ctx_list) # 2D deconvolution ctx_list = [{'ctx': mx.gpu(0), 'deconv_data': (2, 2, 10, 10), 'type_dict': {'deconv_data': np.float64}}, @@ -613,7 +631,7 @@ def test_deconvolution_options(): sym_no_cudnn = mx.sym.Deconvolution(num_filter=2, kernel=(3,3), dilate=(2,2), cudnn_off=True, name='deconv') check_consistency_NxM([sym, sym_no_cudnn], ctx_list) -# # 3D convolution (not yet enabled) +# # 3D deconvolution (not yet enabled) # ctx_list = [{'ctx': mx.cpu(0), 'conv_data': (2, 2, 5, 7, 7), 'type_dict': {'conv_data': np.float64}}, # {'ctx': mx.cpu(0), 'conv_data': (2, 2, 5, 7, 7), 'type_dict': {'conv_data': np.float64}}, # {'ctx': mx.gpu(0), 'conv_data': (2, 2, 5, 7, 7), 'type_dict': {'conv_data': np.float64}}, diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0230d5f064a9..3fbf98becc8a 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -788,8 +788,9 @@ def check_deconvolution_gradient(input_shape, num_filter, pad): During backward(), if the input of A equals output of B, and the output of A equals input of B, then the grad of weight should be the same; """ - stride = (1, 1) - kernel = (2*pad[0]+1, 2*pad[1]+1) + ndim = len(pad) + stride = (1,) * ndim + kernel = tuple(2 * np.array(pad) + 1) data_conv = mx.sym.Variable(name="data_conv") conv = mx.sym.Convolution( data=data_conv, kernel=kernel, stride=stride, pad=pad, @@ -848,10 +849,14 @@ def check_deconvolution_target_shape(input_shape, kernel, stride, pad, adj, targ data=data, kernel=kernel, stride=stride, pad=pad, adj=adj, num_filter=5) arg_names = deconv.list_arguments() arg_shapes, out_shapes, _ = deconv.infer_shape(data=input_shape) - assert out_shapes[0] == (input_shape[0], 5, 8, 8) + default_target_size = 8 + if target_shape is None: + target_shape = (default_target_size,) * len(kernel) + assert out_shapes[0] == (input_shape[0], 5) + target_shape def test_deconvolution(): + # 2D check_deconvolution_target_shape( input_shape = (2,3,4,4), kernel = (3,3), @@ -898,6 +903,53 @@ def test_deconvolution(): num_filter = 3, pad = (3,3) ) + # 1D + check_deconvolution_target_shape( + input_shape = (2,3,4), + kernel = (3,), + stride = (2,), + target_shape = (8,), + pad = (99,), # will be ignored + adj = (101,), # will be ignored + ) + check_deconvolution_target_shape( + input_shape = (2,3,4), + kernel = (3,), + stride = (2,), + pad = (1,), + adj = (1,), + ) + check_deconvolution_forward_backward( + input_shape = (1,1,5), + num_filter = 1, + kernel = (3,), + stride = (1,), + pad = (1,) + ) + check_deconvolution_forward_backward( + input_shape = (32,3,28), + num_filter = 3, + kernel = (3,), + stride = (1,), + pad = (1,) + ) + check_deconvolution_forward_backward( + input_shape = (10, 3, 403), + num_filter = 3, + kernel = (7,), + stride = (5,), + pad = (2,) + ) + check_deconvolution_gradient( + input_shape = (1,3,5), + num_filter = 3, + pad = (1,) + ) + check_deconvolution_gradient( + input_shape = (5,3,100), + num_filter = 3, + pad = (3,) + ) def check_nearest_upsampling_with_shape(shapes, scale, root_scale): @@ -1022,74 +1074,79 @@ def check_batchnorm_training(stype): def test_convolution_grouping(): - num_filter = 4 - num_group = 2 - kernel = (3, 3) - shape = (1, 4, 9, 9) + for dim in [1, 2, 3]: + num_filter = 4 + 
num_group = 2 + kernel = (3,) * dim + shape = (1, 4) + (9,) * dim - x = mx.sym.Variable('x') - w = mx.sym.Variable('w') - b = mx.sym.Variable('b') - y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel) - xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1) - wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0) - bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0) - y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i], - num_filter=num_filter//num_group, kernel=kernel) - for i in range(num_group)]) - - exe1 = y1.simple_bind(default_context(), x=shape) - exe2 = y2.simple_bind(default_context(), x=shape, w=(num_filter, shape[1]//num_group, kernel[0], kernel[1]), b=(num_filter,)) - for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays): - arr1[:] = np.random.normal(size=arr1.shape) - arr2[:] = arr1 - exe1.forward(is_train=True) - exe1.backward(exe1.outputs[0]) - exe2.forward(is_train=True) - exe2.backward(exe2.outputs[0]) - - for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): - np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-4) + x = mx.sym.Variable('x') + w = mx.sym.Variable('w') + b = mx.sym.Variable('b') + y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel) + xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1) + wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0) + bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0) + y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i], + num_filter=num_filter//num_group, kernel=kernel) + for i in range(num_group)]) + + exe1 = y1.simple_bind(default_context(), x=shape) + exe2 = y2.simple_bind(default_context(), x=shape, w=(num_filter, shape[1]//num_group) + kernel, b=(num_filter,)) + for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays): + arr1[:] = np.random.normal(size=arr1.shape) + arr2[:] = arr1 + exe1.forward(is_train=True) + exe1.backward(exe1.outputs[0]) + exe2.forward(is_train=True) + exe2.backward(exe2.outputs[0]) + + for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): + np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-4) @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. 
tracked at https://github.com/apache/incubator-mxnet/issues/8712") def test_depthwise_convolution(): - for num_base in [1, 4, 16, 32, 64]: - for kernel in [(3,3), (5,5)]: - for stride in [(1,1), (2,2)]: - for pad in [(0,0), (1,1)]: - for in_size in [7, 32]: - num_filter = num_base - num_group = num_base - shape = (2, num_base, in_size, in_size) - - x = mx.sym.Variable('x') - w = mx.sym.Variable('w') - b = mx.sym.Variable('b') - y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, - kernel=kernel, stride=stride, pad=pad) - xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1) - wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0) - bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0) - y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i], - num_filter=num_filter//num_group, kernel=kernel, - stride=stride, pad=pad) - for i in range(num_group)]) - - dev = default_context() - exe1 = y1.simple_bind(dev, x=shape) - exe2 = y2.simple_bind(mx.cpu(), x=shape, w=(num_filter, shape[1]//num_group, kernel[0], kernel[1]), - b=(num_filter,)) - for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays): - arr1[:] = np.random.normal(size=arr1.shape) - arr2[:] = arr1 - exe1.forward(is_train=True) - exe1.backward(exe1.outputs[0]) - exe2.forward(is_train=True) - exe2.backward(exe2.outputs[0]) - - for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): - np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-4) + for dim in [1,2]: + for num_base in [1, 4, 16, 32, 64]: + for kernel_x in [3, 5]: + for stride_x in [1, 2]: + for pad_x in [0, 1]: + for in_size in [7, 32]: + kernel = (kernel_x,) * dim + stride = (stride_x,) * dim + pad = (pad_x,) * dim + num_filter = num_base + num_group = num_base + shape = (2, num_base) + (in_size,) * dim + + x = mx.sym.Variable('x') + w = mx.sym.Variable('w') + b = mx.sym.Variable('b') + y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, + kernel=kernel, stride=stride, pad=pad) + xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1) + wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0) + bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0) + y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i], + num_filter=num_filter//num_group, kernel=kernel, + stride=stride, pad=pad) + for i in range(num_group)]) + + dev = default_context() + exe1 = y1.simple_bind(dev, x=shape) + exe2 = y2.simple_bind(mx.cpu(), x=shape, w=(num_filter, shape[1]//num_group)+kernel, + b=(num_filter,)) + for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays): + arr1[:] = np.random.normal(size=arr1.shape) + arr2[:] = arr1 + exe1.forward(is_train=True) + exe1.backward(exe1.outputs[0]) + exe2.forward(is_train=True) + exe2.backward(exe2.outputs[0]) + + for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays): + np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-3) def gen_broadcast_data(idx): @@ -1361,9 +1418,14 @@ def test_bmin(a, b): def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3), verbose=False): + dim = len(dil) + assert(len(kernel_shape) == dim) # Input for spike response - spike_imgs = np.zeros(shape=(1,1,33,33), dtype=np.float32) - spike_imgs[0,0,16,16] = 1.0 + data_size = 33 + data_shape = (1, 1) + (data_size,) * dim + center = 
(0,0) + (data_size // 2,) * dim + spike_imgs = np.zeros(shape=data_shape, dtype=np.float32) + spike_imgs[center] = 1.0 spike_img = mx.nd.array(spike_imgs) spike_img2 = mx.nd.array(spike_imgs) @@ -1381,14 +1443,14 @@ def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3), ndo = be.outputs[0] out_grads = np.zeros(shape=be.outputs[0].shape, dtype=np.float32) - out_grads[0,0, 16,16] = 1.0 + out_grads[center] = 1.0 out_grad = mx.nd.array(out_grads) be.backward([out_grad]) vgrad = be.grad_arrays[0].asnumpy() - out = out_o.reshape((out_o.shape[2],out_o.shape[3])) - nzx,nzy = np.nonzero(out) - assert(np.sum(out)==np.prod(kernel_shape)) - assert(np.sum(vgrad)==np.prod(kernel_shape)) + out = out_o.reshape(out_o.shape[2:]) + nz_loc = np.nonzero(out) + assert_allclose(np.sum(out),np.prod(kernel_shape),atol=1e-5) + assert_allclose(np.sum(vgrad),np.prod(kernel_shape),atol=1e-5) # Now check whether the input gradient was computed correctly input_grad = mx.nd.array(vgrad) @@ -1396,15 +1458,15 @@ def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3), be = net.bind(default_context(), args={ 'input' : input_grad, 'test_convolution_weight' : kernel_weights}) be.forward(True) out_o = be.outputs[0].asnumpy() - assert(out_o[0,0,16,16]==np.prod(kernel_shape)) + assert_allclose(out_o[center],np.prod(kernel_shape),atol=1e-5) rnd_kernel_s = np.random.uniform(low=0.0, high=1.0, size=tuple([1,1]+list(kernel_shape))).astype(np.float32) impulse_error = mx.nd.array(out_o/np.sum(out_o)) # This should be 1.0 at [0,0,16,16] rnd_kernel = mx.nd.array(rnd_kernel_s) rnd_kernel2 = mx.nd.array(rnd_kernel_s) - white_in = mx.nd.ones(shape=(1,1,33,33)) - white_in2 = mx.nd.ones(shape=(1,1,33,33)) + white_in = mx.nd.ones(shape=data_shape) + white_in2 = mx.nd.ones(shape=data_shape) be = net.bind(default_context(), args={ 'input' : white_in, 'test_convolution_weight' : rnd_kernel}, args_grad={'input' : white_in2, 'test_convolution_weight' : rnd_kernel2 } ) @@ -1421,10 +1483,15 @@ def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3), be.forward(True) out = be.outputs[0].asnumpy() # Now do a simple check of the kernel gradient - assert(out[0,0,16,16] - np.sum(kernel_gradient) - out_orig[0,0,16,16] < 0.001) + assert(out[center] - np.sum(kernel_gradient) - out_orig[center] < 0.001) def test_convolution_dilated_impulse_response(): + # 1D + for dil in [ (1,), (2,), (3,) ]: + for ks in [ (1,), (2,), (3,), (4,)]: + test_run_convolution_dilated_impulse_response(dil=dil, kernel_shape=ks) + # 2D for dil in [ (1,1), (2,2), (3,3) ]: for ks in [ (3,3), (4,4), (2,3), (3,2), (1,1) ]: test_run_convolution_dilated_impulse_response(dil=dil, kernel_shape=ks)
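The core strategy of this patch can be checked end-to-end from Python: a 1D
convolution over NCW data matches a 2D convolution over the same data viewed
as NCHW with a unit height and a 1-row kernel, which is exactly how the cuDNN
operators above execute the 1D case. A minimal sketch, not part of the commit,
assuming only a standard MXNet build (the cuDNN path additionally requires a
GPU context); shapes and names are illustrative:

    import mxnet as mx
    import numpy as np

    # (N, C, W) input and a (num_filter, C, k) 1D kernel.
    data = mx.nd.array(np.random.uniform(size=(2, 3, 7)))
    weight = mx.nd.array(np.random.uniform(size=(5, 3, 3)))

    # Native 1D convolution.
    out_1d = mx.nd.Convolution(data=data, weight=weight, no_bias=True,
                               num_filter=5, kernel=(3,), layout='NCW')

    # Same computation expressed as a 2D convolution with height 1,
    # mirroring what the cuDNN operators now do internally.
    out_2d = mx.nd.Convolution(data=data.reshape((2, 3, 1, 7)),
                               weight=weight.reshape((5, 3, 1, 3)),
                               no_bias=True, num_filter=5, kernel=(1, 3))

    np.testing.assert_allclose(out_1d.asnumpy(),
                               out_2d.reshape(out_1d.shape).asnumpy(),
                               rtol=1e-4, atol=1e-5)

For the deconvolution shape tests above, the expected output size per spatial
dimension is stride*(in - 1) + kernel - 2*pad + adj; e.g. for input 4,
kernel 3, stride 2, pad 1, adj 1 this gives 2*3 + 3 - 2 + 1 = 8, the default
target size used by check_deconvolution_target_shape.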