
Commit 5b99b25

1D conv with cudnn (#9184)
* 1D conv/deconv handling by cuDNN, with tests.

* Fix Python 3 test issue.

* Fix lint issues.

* Fix CI and docs.
DickJC123 authored and piiswrong committed Jan 2, 2018
1 parent 4aff838 commit 5b99b25
Showing 9 changed files with 529 additions and 476 deletions.
8 changes: 4 additions & 4 deletions src/operator/nn/convolution-inl.h
@@ -67,13 +67,13 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
   bool cudnn_off;
   dmlc::optional<int> layout;
   DMLC_DECLARE_PARAMETER(ConvolutionParam) {
-    DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w) or (d, h, w)");
+    DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (w,), (h, w) or (d, h, w)");
     DMLC_DECLARE_FIELD(stride).set_default(TShape())
-    .describe("Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension.");
+    .describe("Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(dilate).set_default(TShape())
-    .describe("Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension.");
+    .describe("Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(pad).set_default(TShape())
-    .describe("Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding.");
+    .describe("Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.");
     DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000)
     .describe("Convolution filter(channel) number");
     DMLC_DECLARE_FIELD(num_group).set_default(1)
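With the updated docstrings, every spatial attribute of Convolution also accepts a length-1 tuple such as (w,). A minimal usage sketch on NCW data (shapes are illustrative; assumes an MXNet build that includes this commit):

import mxnet as mx

# NCW layout: (batch, channels, width)
x = mx.nd.random.uniform(shape=(8, 16, 100))
w = mx.nd.random.uniform(shape=(32, 16, 3))  # (num_filter, channels, kernel_w)
b = mx.nd.zeros((32,))

# 1-D convolution: kernel/stride/dilate/pad are all length-1 tuples.
y = mx.nd.Convolution(data=x, weight=w, bias=b,
                      kernel=(3,), stride=(1,), dilate=(1,), pad=(1,),
                      num_filter=32)
print(y.shape)  # (8, 32, 100): width preserved because pad == (kernel-1)/2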
7 changes: 0 additions & 7 deletions src/operator/nn/convolution.cu
@@ -41,13 +41,6 @@ Operator* CreateOp<gpu>(ConvolutionParam param, int dtype,
                         std::vector<TShape> *out_shape,
                         Context ctx) {
   Operator *op = NULL;
-  // If 1D convolution, use MXNet implementation
-  if (param.kernel.ndim() == 1) {
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      op = new ConvolutionOp<gpu, DType>(param);
-    })
-    return op;
-  }
 
   // depth wise conv
   if (param.num_filter == param.num_group &&
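With the early return removed, 1-D convolutions on GPU now fall through to the cuDNN dispatch (and the depthwise special case) below. In the spirit of this PR's tests, a hedged consistency check that exercises both paths via the cudnn_off field declared in convolution-inl.h (assumes a CUDA/cuDNN build and a visible GPU):

import mxnet as mx
import numpy as np

ctx = mx.gpu(0)
x = mx.nd.random.uniform(shape=(2, 4, 50), ctx=ctx)
w = mx.nd.random.uniform(shape=(8, 4, 5), ctx=ctx)

args = dict(data=x, weight=w, no_bias=True,
            kernel=(5,), pad=(2,), num_filter=8)
y_cudnn = mx.nd.Convolution(cudnn_off=False, **args)  # cuDNN path (new for 1-D in this commit)
y_mxnet = mx.nd.Convolution(cudnn_off=True, **args)   # native MXNet kernel
np.testing.assert_allclose(y_cudnn.asnumpy(), y_mxnet.asnumpy(),
                           rtol=1e-4, atol=1e-5)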
273 changes: 122 additions & 151 deletions src/operator/nn/cudnn/cudnn_convolution-inl.h

Large diffs are not rendered by default.
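The rewritten wrapper is not rendered here, but the idea it implements is standard: a 1-D convolution is a 2-D convolution whose kernel and input height are 1, which maps directly onto cuDNN's 4-D descriptors. A sketch of that equivalence at the operator level, a mirror of the concept rather than the hidden implementation (the deconvolution wrapper in the next file follows the same pattern):

import mxnet as mx
import numpy as np

x = mx.nd.random.uniform(shape=(2, 3, 20))
w = mx.nd.random.uniform(shape=(6, 3, 3))

y1d = mx.nd.Convolution(data=x, weight=w, no_bias=True,
                        kernel=(3,), num_filter=6)

# The same computation, phrased as a 2-D convolution over a height-1 input.
y2d = mx.nd.Convolution(data=x.reshape((2, 3, 1, 20)),
                        weight=w.reshape((6, 3, 1, 3)),
                        no_bias=True, kernel=(1, 3), num_filter=6)

np.testing.assert_allclose(y1d.asnumpy(),
                           y2d.reshape((2, 6, 18)).asnumpy(),
                           rtol=1e-5, atol=1e-6)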

261 changes: 119 additions & 142 deletions src/operator/nn/cudnn/cudnn_deconvolution-inl.h

Large diffs are not rendered by default.
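The deconvolution wrapper gets the same 1-D treatment. One practical corner it must handle is that strided convolutions are not shape-invertible, which is what the adj (or target_shape) fields in deconvolution-inl.h below are for. A hedged 1-D round-trip (illustrative shapes; assumes a build with this commit):

import mxnet as mx

x = mx.nd.random.uniform(shape=(2, 6, 18))
wc = mx.nd.random.uniform(shape=(3, 6, 5))  # Convolution weight: (num_filter, channels, kernel_w)
y = mx.nd.Convolution(data=x, weight=wc, no_bias=True,
                      kernel=(5,), stride=(2,), pad=(2,), num_filter=3)
print(y.shape)  # (2, 3, 9)

# Widths 18 and 17 both convolve to 9, so `adj` disambiguates the inverse:
# (9 - 1)*2 - 2*2 + 5 + 1 = 18.
wd = mx.nd.random.uniform(shape=(3, 6, 5))  # Deconvolution weight: (channels, num_filter, kernel_w)
z = mx.nd.Deconvolution(data=y, weight=wd, no_bias=True,
                        kernel=(5,), stride=(2,), pad=(2,), adj=(1,),
                        num_filter=6)
print(z.shape)  # (2, 6, 18)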

161 changes: 97 additions & 64 deletions src/operator/nn/deconvolution-inl.h
@@ -63,28 +63,28 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
   bool cudnn_off;
   dmlc::optional<int> layout;
   DMLC_DECLARE_PARAMETER(DeconvolutionParam) {
-    DMLC_DECLARE_FIELD(kernel).describe("Deconvolution kernel size: (h, w) or (d, h, w). "
+    DMLC_DECLARE_FIELD(kernel).describe("Deconvolution kernel size: (w,), (h, w) or (d, h, w). "
                   "This is same as the kernel size used for the corresponding convolution");
     DMLC_DECLARE_FIELD(stride).set_default(TShape())
-    .describe("The stride used for the corresponding convolution: (h, w) or (d, h, w). "
+    .describe("The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w). "
              "Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(dilate).set_default(TShape())
-    .describe("Dilation factor for each dimension of the input: (h, w) or (d, h, w). "
+    .describe("Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w). "
              "Defaults to 1 for each dimension.");
     DMLC_DECLARE_FIELD(pad).set_default(TShape())
     .describe("The amount of implicit zero padding added during convolution for each "
              "dimension of the input: "
-             "(h, w) or (d, h, w). "
+             "(w,), (h, w) or (d, h, w). "
              "``(kernel-1)/2`` is usually a good choice. "
              "If `target_shape` is set, "
              "`pad` will be ignored and a padding that will generate the target shape "
              "will be used. Defaults to no padding.");
     DMLC_DECLARE_FIELD(adj).set_default(TShape())
-    .describe("Adjustment for output shape: (h, w) or (d, h, w). "
+    .describe("Adjustment for output shape: (w,), (h, w) or (d, h, w). "
              "If `target_shape` is set, "
              "`adj` will be ignored and computed accordingly.");
     DMLC_DECLARE_FIELD(target_shape).set_default(TShape())
-    .describe("Shape of the output tensor: (h, w) or (d, h, w).");
+    .describe("Shape of the output tensor: (w,), (h, w) or (d, h, w).");
     DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000)
     .describe("Number of output filters.");
     DMLC_DECLARE_FIELD(num_group).set_default(1)
@@ -211,27 +211,38 @@ class DeconvolutionOp : public Operator {
     using namespace mshadow;
     using namespace mshadow::expr;
 
-    if (param_.kernel.ndim() != 2) {
-      LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported";
+    if (param_.kernel.ndim() > 2) {
+      LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported";
     }
 
     CHECK_EQ(req[deconv::kOut], kWriteTo);
     size_t expected = param_.no_bias ? 2 : 3;
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(out_data.size(), 1U);
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data = in_data[deconv::kData].get<xpu, 4, DType>(s);
-    Tensor<xpu, 4, DType> out = out_data[deconv::kOut].get<xpu, 4, DType>(s);
+    auto in_data_shape = in_data[deconv::kData].shape_;
+    Tensor<xpu, 4, DType> data = TBlobTo4DTensor(in_data[deconv::kData], s);
+    Tensor<xpu, 4, DType> out = TBlobTo4DTensor(out_data[deconv::kOut], s);
     index_t o_pad[2], o_adj[2];
-    TShape dshape = {static_cast<nnvm::dim_t>(data.size(2)),
-                     static_cast<nnvm::dim_t>(data.size(3))};
-    param_.InferPad(dshape, o_pad, o_adj);
+    if (param_.kernel.ndim() == 2) {
+      param_.InferPad(TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj);
+    } else {
+      index_t o_pad_1D[1], o_adj_1D[1];
+      param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D);
+      o_pad[0] = 0;
+      o_pad[1] = o_pad_1D[0];
+      o_adj[0] = 0;
+      o_adj[1] = o_adj_1D[0];
+    }
+    auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]});
+    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]});
+    auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : TShape({1, param_.kernel[0]});
+    auto kernel_size = kernel.Size();
 
     Shape<3> wmat_shape =
         Shape3(param_.num_group,
                data.shape_[1] / param_.num_group,
-               param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]);
+               param_.num_filter / param_.num_group * kernel_size);
     Tensor<xpu, 3, DType> wmat =
         in_data[deconv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
 #if defined(__CUDACC__)
@@ -256,21 +267,21 @@ class DeconvolutionOp : public Operator {
       temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_);
       if (o_pad[0] == 0 && o_pad[1] == 0) {
         temp_col = unpack_patch2col(out.Slice(i, i + step),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       } else {
         temp_col = unpack_patch2col(pad(out.Slice(i, i + step),
                                         o_pad[0], o_pad[1]),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       }
       const index_t gstride = temp_col.size(0) / param_.num_group;
       for (uint32_t gid = 0; gid < param_.num_group; ++gid) {
@@ -283,24 +294,24 @@ class DeconvolutionOp : public Operator {
         if (o_pad[0] == 0 && o_pad[1] == 0) {
           out.Slice(i, i + step) = pack_col2patch(temp_col,
                                                   out.Slice(i, i + step).shape_,
-                                                  param_.kernel[0],
-                                                  param_.kernel[1],
-                                                  param_.stride[0],
-                                                  param_.stride[1],
-                                                  param_.dilate[0],
-                                                  param_.dilate[1]);
+                                                  kernel[0],
+                                                  kernel[1],
+                                                  stride[0],
+                                                  stride[1],
+                                                  dilate[0],
+                                                  dilate[1]);
         } else {
           Shape<4> pshape = out.Slice(i, i + step).shape_;
           pshape[2] += 2 * o_pad[0];
           pshape[3] += 2 * o_pad[1];
           out.Slice(i, i + step) = crop(pack_col2patch(temp_col,
                                                        pshape,
-                                                       param_.kernel[0],
-                                                       param_.kernel[1],
-                                                       param_.stride[0],
-                                                       param_.stride[1],
-                                                       param_.dilate[0],
-                                                       param_.dilate[1]),
+                                                       kernel[0],
+                                                       kernel[1],
+                                                       stride[0],
+                                                       stride[1],
+                                                       dilate[0],
+                                                       dilate[1]),
                                         out[i][0].shape_);
         }
       }
@@ -328,13 +339,31 @@ class DeconvolutionOp : public Operator {
     CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true);
     // get data
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data = in_data[deconv::kData].get<xpu, 4, DType>(s);
-    Tensor<xpu, 4, DType> grad = out_grad[deconv::kOut].get<xpu, 4, DType>(s);
-    Tensor<xpu, 4, DType> gdata = in_grad[deconv::kData].get<xpu, 4, DType>(s);
+    auto in_data_shape = in_data[deconv::kData].shape_;
+    Tensor<xpu, 4, DType> data = TBlobTo4DTensor(in_data[deconv::kData], s);
+    Tensor<xpu, 4, DType> grad = TBlobTo4DTensor(out_grad[deconv::kOut], s);
+    Tensor<xpu, 4, DType> gdata = TBlobTo4DTensor(in_grad[deconv::kData], s);
+
+    index_t o_pad[2], o_adj[2];
+    if (param_.kernel.ndim() == 2) {
+      param_.InferPad(TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj);
+    } else {
+      index_t o_pad_1D[1], o_adj_1D[1];
+      param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D);
+      o_pad[0] = 0;
+      o_pad[1] = o_pad_1D[0];
+      o_adj[0] = 0;
+      o_adj[1] = o_adj_1D[0];
+    }
+    auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]});
+    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]});
+    auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : TShape({1, param_.kernel[0]});
+    auto kernel_size = kernel.Size();
+
     Shape<3> wmat_shape =
         Shape3(param_.num_group,
                data.shape_[1] / param_.num_group,
-               param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]);
+               param_.num_filter / param_.num_group * kernel_size);
     Tensor<xpu, 3, DType> wmat =
         in_data[deconv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
     Tensor<xpu, 3, DType> gwmat =
@@ -343,10 +372,6 @@ class DeconvolutionOp : public Operator {
     CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
         << "Must init CuBLAS handle in stream";
 #endif
-    index_t o_pad[2], o_adj[2];
-    TShape dshape = {static_cast<nnvm::dim_t>(data.size(2)),
-                     static_cast<nnvm::dim_t>(data.size(3))};
-    param_.InferPad(dshape, o_pad, o_adj);
 
     const index_t nbatch = data.size(0);
     Tensor<xpu, 1, DType> workspace =
@@ -366,20 +391,20 @@ class DeconvolutionOp : public Operator {
       temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_);
       if (o_pad[0] == 0 && o_pad[1] == 0) {
         temp_col = unpack_patch2col(grad.Slice(i, i + step),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       } else {
         temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), o_pad[0], o_pad[1]),
-                                    param_.kernel[0],
-                                    param_.kernel[1],
-                                    param_.stride[0],
-                                    param_.stride[1],
-                                    param_.dilate[0],
-                                    param_.dilate[1]);
+                                    kernel[0],
+                                    kernel[1],
+                                    stride[0],
+                                    stride[1],
+                                    dilate[0],
+                                    dilate[1]);
       }
       const index_t gstride = temp_col.size(0) / param_.num_group;
       for (uint32_t gid = 0; gid < param_.num_group; ++gid) {
@@ -422,9 +447,8 @@ class DeconvolutionOp : public Operator {
  private:
   inline index_t InitTemp(const mshadow::Shape<4> &ishape,
                           const mshadow::Shape<4> &oshape) {
-    const int ksize_y = param_.kernel[0];
-    const int ksize_x = param_.kernel[1];
-    shape_colunit_ = mshadow::Shape2(ishape[1] * ksize_y * ksize_x,
+    const int ksize = param_.kernel.Size();
+    shape_colunit_ = mshadow::Shape2(ishape[1] * ksize,
                                      oshape[2] * oshape[3]);
     shape_dstunit_ = mshadow::Shape3(param_.num_group,
                                      oshape[1] / param_.num_group,
@@ -449,6 +473,15 @@ class DeconvolutionOp : public Operator {
     return required_size;
   }
 
+  inline Tensor<xpu, 4, DType> TBlobTo4DTensor(const TBlob &tb, Stream<xpu> *s) {
+    using namespace mshadow;
+    if (param_.kernel.ndim() == 2)
+      return tb.get<xpu, 4, DType>(s);
+    else
+      return tb.get_with_shape<xpu, 4, DType>(
+          Shape4(tb.shape_[0], tb.shape_[1], 1, tb.shape_[2]), s);
+  }
+
   DeconvolutionParam param_;
   mshadow::Shape<2> shape_colunit_;
   mshadow::Shape<3> shape_dstunit_;
@@ -505,8 +538,8 @@ class DeconvolutionProp : public OperatorProperty {
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
 #if MXNET_USE_CUDNN == 0
-    if (param_.kernel.ndim() != 2) {
-      LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported";
+    if (param_.kernel.ndim() > 2) {
+      LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported";
       return false;
     }
 #endif  // CUDNN
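The fallback DeconvolutionOp now covers 1-D by promoting the problem to 2-D: a dummy height of 1 is prepended to kernel/stride/dilate, zero pad/adj entries are inserted for that axis, and TBlobTo4DTensor reshapes NCW data to NC1W. A sketch of the same promotion at the NDArray level; deconv1d_via_2d is a hypothetical helper that mirrors the operator's internal trick, not part of its API:

import mxnet as mx

def deconv1d_via_2d(x, w, kernel, stride=1, pad=0, num_filter=None):
    # Prepend a dummy height dimension of size 1, exactly as
    # DeconvolutionOp::Forward promotes its 1-D attributes above.
    n, c, width = x.shape
    y = mx.nd.Deconvolution(data=x.reshape((n, c, 1, width)),
                            weight=w.reshape(w.shape[:2] + (1,) + w.shape[2:]),
                            no_bias=True,
                            kernel=(1, kernel), stride=(1, stride),
                            pad=(0, pad), num_filter=num_filter)
    return y.reshape((y.shape[0], y.shape[1], y.shape[3]))

x = mx.nd.random.uniform(shape=(2, 8, 10))
w = mx.nd.random.uniform(shape=(8, 4, 3))  # (channels, num_filter, kernel_w)
y = deconv1d_via_2d(x, w, kernel=3, stride=2, pad=1, num_filter=4)
print(y.shape)  # (2, 4, 19): (10 - 1)*2 - 2*1 + 3 = 19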
2 changes: 1 addition & 1 deletion src/operator/nn/deconvolution.cc
@@ -55,7 +55,7 @@ MXNET_REGISTER_OP_PROPERTY(Deconvolution, DeconvolutionProp)
 .add_argument("bias", "NDArray-or-Symbol", "Bias added to the result after the deconvolution "
               "operation.")
 .add_arguments(DeconvolutionParam::__FIELDS__())
-.describe("Computes 2D transposed convolution (aka fractionally strided convolution) of the "
+.describe("Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the "
           "input tensor. This operation can be seen as the gradient of Convolution operation with "
           "respect to its input. Convolution usually reduces the size of the input. Transposed "
           "convolution works the other way, going from a smaller input to a larger output while "
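The registration now advertises 1-D as well, and per the parameter docs above, setting target_shape lets the operator derive pad and adj for you. A small hedged example (illustrative shapes; assumes a build with this commit):

import mxnet as mx

x = mx.nd.random.uniform(shape=(1, 8, 10))
w = mx.nd.random.uniform(shape=(8, 4, 3))  # (channels, num_filter, kernel_w)

# `pad` and `adj` are inferred so the output width is exactly 20.
y = mx.nd.Deconvolution(data=x, weight=w, no_bias=True,
                        kernel=(3,), stride=(2,), target_shape=(20,),
                        num_filter=4)
print(y.shape)  # (1, 4, 20)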
8 changes: 1 addition & 7 deletions src/operator/nn/deconvolution.cu
@@ -38,13 +38,7 @@ Operator* CreateOp<gpu>(DeconvolutionParam param, int dtype,
                         Context ctx) {
   // Logic here parallels that in Convolution.cu
   Operator *op = NULL;
-  // If 1D deconvolution, use MXNet implementation
-  if (param.kernel.ndim() == 1) {
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      op = new DeconvolutionOp<gpu, DType>(param);
-    })
-    return op;
-  }
 
 #if MXNET_USE_CUDNN == 1
   // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16).
   int compute_type = (dtype == mshadow::kFloat16) ? mshadow::kFloat32 : dtype;
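As in convolution.cu, the 1-D early return is gone, so GPU deconvolutions of every supported rank reach the cuDNN dispatch, including the pseudo-fp16 handling visible above (fp16 I/O, fp32 compute). A hedged sketch (assumes a CUDA/cuDNN build and a visible GPU):

import mxnet as mx

ctx = mx.gpu(0)
x = mx.nd.random.uniform(shape=(2, 8, 10), ctx=ctx).astype('float16')
w = mx.nd.random.uniform(shape=(8, 4, 3), ctx=ctx).astype('float16')

# fp16 in and out; per the dispatch above, cuDNN accumulates in fp32.
y = mx.nd.Deconvolution(data=x, weight=w, no_bias=True,
                        kernel=(3,), num_filter=4)
print(y.shape)  # (2, 4, 12), dtype float16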