diff --git a/python/mxnet/gluon/contrib/cnn/conv_layers.py b/python/mxnet/gluon/contrib/cnn/conv_layers.py index 9dd208702932..098463eca968 100644 --- a/python/mxnet/gluon/contrib/cnn/conv_layers.py +++ b/python/mxnet/gluon/contrib/cnn/conv_layers.py @@ -19,7 +19,7 @@ # pylint: disable= arguments-differ """Custom convolutional neural network layers in model_zoo.""" -__all__ = ['DeformableConvolution'] +__all__ = ['DeformableConvolution', 'ModulatedDeformableConvolution'] from .... import symbol from ...block import HybridBlock @@ -219,3 +219,181 @@ def __repr__(self): return s.format(name=self.__class__.__name__, mapping='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]), **self._kwargs_deformable_conv) + + +class ModulatedDeformableConvolution(HybridBlock): + """2-D Deformable Convolution v2 (Dai, 2018). + + The modulated deformable convolution operation is described in https://arxiv.org/abs/1811.11168 + + Parameters + ---------- + channels : int, + The dimensionality of the output space + i.e. the number of output channels in the convolution. + kernel_size : int or tuple/list of 2 ints, (Default value = (1,1)) + Specifies the dimensions of the convolution window. + strides : int or tuple/list of 2 ints, (Default value = (1,1)) + Specifies the strides of the convolution. + padding : int or tuple/list of 2 ints, (Default value = (0,0)) + If padding is non-zero, then the input is implicitly zero-padded + on both sides for padding number of points. + dilation : int or tuple/list of 2 ints, (Default value = (1,1)) + Specifies the dilation rate to use for dilated convolution. + groups : int, (Default value = 1) + Controls the connections between inputs and outputs. + At groups=1, all inputs are convolved to all outputs. + At groups=2, the operation becomes equivalent to having two convolution + layers side by side, each seeing half the input channels, and producing + half the output channels, and both subsequently concatenated. + num_deformable_group : int, (Default value = 1) + Number of deformable group partitions. + layout : str, (Default value = NCHW) + Dimension ordering of data and weight. Can be 'NCW', 'NWC', 'NCHW', + 'NHWC', 'NCDHW', 'NDHWC', etc. 'N', 'C', 'H', 'W', 'D' stands for + batch, channel, height, width and depth dimensions respectively. + Convolution is performed over 'D', 'H', and 'W' dimensions. + use_bias : bool, (Default value = True) + Whether the layer for generating the output features uses a bias vector. + in_channels : int, (Default value = 0) + The number of input channels to this layer. If not specified, + initialization will be deferred to the first time `forward` is called + and input channels will be inferred from the shape of input data. + activation : str, (Default value = None) + Activation function to use. See :func:`~mxnet.ndarray.Activation`. + If you don't specify anything, no activation is applied + (ie. "linear" activation: `a(x) = x`). + weight_initializer : str or `Initializer`, (Default value = None) + Initializer for the `weight` weights matrix for the convolution layer + for generating the output features. + bias_initializer : str or `Initializer`, (Default value = zeros) + Initializer for the bias vector for the convolution layer + for generating the output features. + offset_weight_initializer : str or `Initializer`, (Default value = zeros) + Initializer for the `weight` weights matrix for the convolution layer + for generating the offset. + offset_bias_initializer : str or `Initializer`, (Default value = zeros), + Initializer for the bias vector for the convolution layer + for generating the offset. + offset_use_bias: bool, (Default value = True) + Whether the layer for generating the offset uses a bias vector. + + Inputs: + - **data**: 4D input tensor with shape + `(batch_size, in_channels, height, width)` when `layout` is `NCHW`. + For other layouts shape is permuted accordingly. + + Outputs: + - **out**: 4D output tensor with shape + `(batch_size, channels, out_height, out_width)` when `layout` is `NCHW`. + out_height and out_width are calculated as:: + + out_height = floor((height+2*padding[0]-dilation[0]*(kernel_size[0]-1)-1)/stride[0])+1 + out_width = floor((width+2*padding[1]-dilation[1]*(kernel_size[1]-1)-1)/stride[1])+1 + """ + + def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, + num_deformable_group=1, layout='NCHW', use_bias=True, in_channels=0, activation=None, + weight_initializer=None, bias_initializer='zeros', + offset_weight_initializer='zeros', offset_bias_initializer='zeros', offset_use_bias=True, + op_name='ModulatedDeformableConvolution', adj=None, prefix=None, params=None): + super(ModulatedDeformableConvolution, self).__init__(prefix=prefix, params=params) + with self.name_scope(): + self._channels = channels + self._in_channels = in_channels + + assert layout in ('NCHW', 'NHWC'), "Only supports 'NCHW' and 'NHWC' layout for now" + if isinstance(kernel_size, numeric_types): + kernel_size = (kernel_size,) * 2 + if isinstance(strides, numeric_types): + strides = (strides,) * len(kernel_size) + if isinstance(padding, numeric_types): + padding = (padding,) * len(kernel_size) + if isinstance(dilation, numeric_types): + dilation = (dilation,) * len(kernel_size) + self._op_name = op_name + + offset_channels = 27 + self._kwargs_offset = { + 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, + 'pad': padding, 'num_filter': offset_channels, 'num_group': groups, + 'no_bias': not offset_use_bias, 'layout': layout} + + self._kwargs_deformable_conv = { + 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, + 'pad': padding, 'num_filter': channels, 'num_group': groups, + 'num_deformable_group': num_deformable_group, + 'no_bias': not use_bias, 'layout': layout} + + if adj: + self._kwargs_offset['adj'] = adj + self._kwargs_deformable_conv['adj'] = adj + + deformable_conv_weight_shape = [0] * (len(kernel_size) + 2) + deformable_conv_weight_shape[0] = channels + deformable_conv_weight_shape[2] = kernel_size[0] + deformable_conv_weight_shape[3] = kernel_size[1] + + self.deformable_conv_weight = self.params.get('deformable_conv_weight', + shape=deformable_conv_weight_shape, + init=weight_initializer, + allow_deferred_init=True) + + if use_bias: + self.deformable_conv_bias = self.params.get('deformable_conv_bias', shape=(channels,), + init=bias_initializer, + allow_deferred_init=True) + else: + self.deformable_conv_bias = None + + dshape = [0] * (len(kernel_size) + 2) + dshape[layout.find('N')] = 1 + dshape[layout.find('C')] = in_channels + + op = getattr(symbol, 'Convolution') + offset = op(symbol.var('data', shape=dshape), **self._kwargs_offset) + + offsetshapes = offset.infer_shape_partial()[0] + + self.offset_weight = self.params.get('offset_weight', shape=offsetshapes[1], + init=offset_weight_initializer, + allow_deferred_init=True) + + if offset_use_bias: + self.offset_bias = self.params.get('offset_bias', shape=offsetshapes[2], + init=offset_bias_initializer, + allow_deferred_init=True) + else: + self.offset_bias = None + + if activation: + self.act = Activation(activation, prefix=activation + '_') + else: + self.act = None + + def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bias=None, deformable_conv_bias=None): + if offset_bias is None: + offset = F.Convolution(x, offset_weight, cudnn_off=True, **self._kwargs_offset) + else: + offset = F.Convolution(x, offset_weight, offset_bias, cudnn_off=True, **self._kwargs_offset) + + offset_t = F.slice_axis(offset, axis=1, begin=0, end=18) + mask = F.slice_axis(offset, axis=1, begin=18, end=None) + mask = F.sigmoid(mask) * 2 + + if deformable_conv_bias is None: + act = F.contrib.ModulatedDeformableConvolution(data=x, offset=offset_t, mask=mask, + weight=deformable_conv_weight, + name='fwd', **self._kwargs_deformable_conv) + else: + act = F.contrib.ModulatedDeformableConvolution(data=x, offset=offset_t, mask=mask, + weight=deformable_conv_weight, + bias=deformable_conv_bias, name='fwd', + **self._kwargs_deformable_conv) + + if self.act: + act = self.act(act) + return act + + def _alias(self): + return 'modulated_deformable_conv' diff --git a/src/operator/contrib/modulated_deformable_convolution-inl.h b/src/operator/contrib/modulated_deformable_convolution-inl.h new file mode 100644 index 000000000000..07e8e29fe443 --- /dev/null +++ b/src/operator/contrib/modulated_deformable_convolution-inl.h @@ -0,0 +1,576 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_convolution-inl.h + * \brief + * \ref: https://github.com/Yangqing/caffe/wiki/Convolution-in-Caffe:-a-memo + * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu +*/ +#ifndef MXNET_OPERATOR_CONTRIB_MODULATED_DEFORMABLE_CONVOLUTION_INL_H_ +#define MXNET_OPERATOR_CONTRIB_MODULATED_DEFORMABLE_CONVOLUTION_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../nn/im2col.h" +#include "./nn/modulated_deformable_im2col.h" +#include "../linalg.h" + + +namespace mxnet { +namespace op { + +namespace dmconv { + enum ModulatedDeformableConvolutionOpInputs { kData, kOffset, kMask, kWeight, kBias }; + enum ModulatedDeformableConvolutionOpOutputs { kOut }; + enum ModulatedDeformableConvolutionOpResource { kTempSpace }; +} + +struct ModulatedDeformableConvolutionParam + : public dmlc::Parameter { + mxnet::TShape kernel; + mxnet::TShape stride; + mxnet::TShape dilate; + mxnet::TShape pad; + uint32_t num_filter; + uint32_t num_group; + uint32_t num_deformable_group; + uint64_t workspace; + bool no_bias; + uint32_t im2col_step; + dmlc::optional layout; + DMLC_DECLARE_PARAMETER(ModulatedDeformableConvolutionParam) { + DMLC_DECLARE_FIELD(kernel).describe("Convolution kernel size: (h, w) or (d, h, w)"); + DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, -1)) + .describe("Convolution stride: (h, w) or (d, h, w). Defaults to 1 for each dimension."); + DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, -1)) + .describe("Convolution dilate: (h, w) or (d, h, w). Defaults to 1 for each dimension."); + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, -1)) + .describe("Zero pad for convolution: (h, w) or (d, h, w). Defaults to no padding."); + DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) + .describe("Convolution filter(channel) number"); + DMLC_DECLARE_FIELD(num_group).set_default(1) + .describe("Number of group partitions."); + DMLC_DECLARE_FIELD(num_deformable_group).set_default(1) + .describe("Number of deformable group partitions."); + DMLC_DECLARE_FIELD(workspace).set_default(1024).set_range(0, 8192) + .describe("Maximum temperal workspace allowed for convolution (MB)."); + DMLC_DECLARE_FIELD(no_bias).set_default(false) + .describe("Whether to disable bias parameter."); + DMLC_DECLARE_FIELD(im2col_step).set_default(64) + .describe("Maximum number of images per im2col computation; " + "The total batch size should be divisable by this value or " + "smaller than this value; if you face out of memory problem, " + "you can try to use a smaller value here."); + DMLC_DECLARE_FIELD(layout) + .add_enum("NCW", mshadow::kNCW) + .add_enum("NCHW", mshadow::kNCHW) + .add_enum("NCDHW", mshadow::kNCDHW) + .set_default(dmlc::optional()) + .describe("Set layout for input, output and weight. Empty for\n " + "default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d."); + } +}; + +template +class ModulatedDeformableConvolutionOp : public Operator { + public: + explicit ModulatedDeformableConvolutionOp(ModulatedDeformableConvolutionParam p) { + this->param_ = p; + // convert MBytes first to Bytes and then to elements. + param_.workspace = (param_.workspace << 20) / sizeof(DType); + CHECK(param_.layout.value() == mshadow::kNCW || + param_.layout.value() == mshadow::kNCHW || + param_.layout.value() == mshadow::kNCDHW) + << "Only support NCW, NCHW and NCDHW layout"; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(req[dmconv::kOut], kWriteTo); + size_t expected = param_.no_bias ? 4 : 5; + CHECK_EQ(in_data.size(), expected); + CHECK_EQ(out_data.size(), 1U); + LayerSetUp(in_data[dmconv::kData].shape_, + in_data[dmconv::kOffset].shape_, + in_data[dmconv::kMask].shape_, + out_data[dmconv::kOut].shape_); + Stream* s = ctx.get_stream(); + // allocate workspace for col_buffer + Tensor workspace = ctx.requested[dmconv::kTempSpace] + .get_space_typed(Shape1(col_buffer_size_ + num_*output_dim_), s); + // calculate the shape of col_buffer + mxnet::TShape col_buffer_shape(num_spatial_axes_ + 2, -1); + col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); + // for (index_t i = 1; i < col_buffer_shape.ndim(); ++i) { + // col_buffer_shape[i] = out_data[0].shape_[i + 1]; + col_buffer_shape[1] = im2col_step_; + for (index_t i = 2; i < col_buffer_shape.ndim(); ++i) { + col_buffer_shape[i] = out_data[0].shape_[i]; + } + // create a column buffer using workspace and col_buffer_shape + TBlob col_buffer(workspace.dptr_, col_buffer_shape, xpu::kDevMask, DataType::kFlag); + mxnet::TShape output_buffer_shape(1, -1); + output_buffer_shape[0] = num_*output_dim_; + TBlob output_buffer(workspace.dptr_ + col_buffer_size_, output_buffer_shape, + xpu::kDevMask, DataType::kFlag); + + // initialize weight and col_buffer 3D tensors for using gemm + index_t M = conv_out_channels_ / group_; + index_t N = im2col_step_ * conv_out_spatial_dim_; + index_t K = kernel_dim_; + Tensor weight_3d = in_data[dmconv::kWeight].get_with_shape( + Shape3(group_, M, K), s); + Tensor col_buffer_3d = col_buffer.get_with_shape( + Shape3(group_, K, N), s); + Tensor output_4d = output_buffer.get_with_shape( + Shape4(num_ / im2col_step_, group_, M, N), s); + for (index_t n = 0; n < num_ / im2col_step_; ++n) { + // transform image to col_buffer in order to use gemm + modulated_deformable_im2col(s, + in_data[dmconv::kData].dptr() + n*im2col_step_*input_dim_, + in_data[dmconv::kOffset].dptr() + n*im2col_step_*input_offset_dim_, + in_data[dmconv::kMask].dptr() + n*im2col_step_ * input_mask_dim_, + in_data[dmconv::kData].shape_, + col_buffer.shape_, param_.kernel, param_.pad, param_.stride, param_.dilate, + param_.num_deformable_group, col_buffer.dptr()); + Tensor output_3d = output_4d[n]; + for (index_t g = 0; g < group_; ++g) { + // Legacy approach shown here for comparison: + // Assign(output_3d[g], req[dmconv::kOut], dot(weight_3d[g], col_buffer_3d[g])); + linalg_gemm(weight_3d[g], col_buffer_3d[g], output_3d[g], false, false, s, kWriteTo); + } + } + Tensor trans_output_4d = output_buffer.get_with_shape( + Shape4(num_ / im2col_step_, conv_out_channels_, im2col_step_, conv_out_spatial_dim_), s); + Tensor original_output_4d = out_data[dmconv::kOut].get_with_shape( + Shape4(num_ / im2col_step_, im2col_step_, conv_out_channels_, conv_out_spatial_dim_), s); + original_output_4d = swapaxis<2, 1>(trans_output_4d); + + if (bias_term_) { + Tensor bias = in_data[dmconv::kBias].get(s); + Tensor output_3d = out_data[dmconv::kOut].get_with_shape( + Shape3(num_, conv_out_channels_, conv_out_spatial_dim_), s); + // has bias term, broadcast it to the same shape of output_3d in channel dim + output_3d += mshadow::expr::broadcast<1>(bias, output_3d.shape_); + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& req, + const std::vector& in_grad, + const std::vector& aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1U); + size_t expected = param_.no_bias == 0 ? 5 : 4; + CHECK(in_data.size() == expected && in_grad.size() == expected); + CHECK_EQ(req.size(), expected); + CHECK_EQ(in_data[dmconv::kWeight].CheckContiguous(), true); + LayerSetUp(in_grad[dmconv::kData].shape_, + in_grad[dmconv::kOffset].shape_, + in_grad[dmconv::kMask].shape_, + out_grad[dmconv::kOut].shape_); + Stream *s = ctx.get_stream(); + // allocate workspace for col_buffer + Tensor workspace = ctx.requested[dmconv::kTempSpace] + .get_space_typed(Shape1(col_buffer_size_ + num_*output_dim_), s); + // calculate the shape of col_buffer + mxnet::TShape col_buffer_shape(num_spatial_axes_ + 2, -1); + col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); + col_buffer_shape[1] = im2col_step_; + for (index_t i = 2; i < col_buffer_shape.ndim(); ++i) { + col_buffer_shape[i] = out_grad[dmconv::kData].shape_[i]; + } + // create a column buffer using workspace and col_buffer_shape + TBlob col_buffer(workspace.dptr_, col_buffer_shape, xpu::kDevMask, DataType::kFlag); + mxnet::TShape output_buffer_shape(1, -1); + output_buffer_shape[0] = num_*output_dim_; + TBlob output_buffer(workspace.dptr_ + col_buffer_size_, + output_buffer_shape, xpu::kDevMask, DataType::kFlag); + + Tensor trans_output_4d = output_buffer.get_with_shape( + Shape4(num_ / im2col_step_, conv_out_channels_, im2col_step_, conv_out_spatial_dim_), s); + Tensor original_output_4d = out_grad[dmconv::kOut].get_with_shape( + Shape4(num_ / im2col_step_, im2col_step_, conv_out_channels_, conv_out_spatial_dim_), s); + trans_output_4d = swapaxis<2, 1>(original_output_4d); + + // initialize weight and col_buffer 3D tensors for using gemm + // For computing dLoss/d(in_data[kData]) + index_t M = kernel_dim_; + index_t N = im2col_step_ * conv_out_spatial_dim_; + index_t K = conv_out_channels_ / group_; + Tensor weight_3d = in_data[dmconv::kWeight].get_with_shape( + Shape3(group_, K, M), s); + Tensor out_grad_4d = output_buffer.get_with_shape( + Shape4(num_ / im2col_step_, group_, K, N), s); + Tensor col_buffer_3d = col_buffer.get_with_shape( + Shape3(group_, M, N), s); + // For computing dLoss/dWeight + Tensor dweight_3d = in_grad[dmconv::kWeight].get_with_shape( + Shape3(group_, K, M), s); + + Tensor data_grad = in_grad[dmconv::kData].FlatTo1D(s); + if (req[dmconv::kData] == kWriteTo) + data_grad = 0; + + + for (index_t n = 0; n < num_ / im2col_step_; ++n) { + Tensor out_grad_3d = out_grad_4d[n]; + for (index_t g = 0; g < group_; ++g) { + // Legacy approach shown here for comparison: + // col_buffer_3d[g] = dot(weight_3d[g].T(), out_grad_3d[g]); + linalg_gemm(weight_3d[g], out_grad_3d[g], col_buffer_3d[g], true, false, s); + } + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord(s, col_buffer.dptr(), + in_data[dmconv::kData].dptr() + n*im2col_step_*input_dim_, + in_data[dmconv::kOffset].dptr() + n*im2col_step_*input_offset_dim_, + in_data[dmconv::kMask].dptr() + n*im2col_step_*input_mask_dim_, + in_grad[dmconv::kData].shape_, col_buffer.shape_, + param_.kernel, param_.pad, param_.stride, param_.dilate, param_.num_deformable_group, + in_grad[dmconv::kOffset].dptr() + n*im2col_step_*input_offset_dim_, + in_grad[dmconv::kMask].dptr() + n*im2col_step_*input_mask_dim_, + req[dmconv::kOffset], req[dmconv::kMask]); + + // gradient w.r.t. input data + modulated_deformable_col2im(s, col_buffer.dptr(), + in_data[dmconv::kOffset].dptr() + n*im2col_step_*input_offset_dim_, + in_data[dmconv::kMask].dptr() + n*im2col_step_*input_mask_dim_, + in_grad[dmconv::kData].shape_, col_buffer.shape_, + param_.kernel, param_.pad, param_.stride, param_.dilate, param_.num_deformable_group, + in_grad[dmconv::kData].dptr() + n*im2col_step_*input_dim_, + req[dmconv::kData]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and group + modulated_deformable_im2col(s, + in_data[dmconv::kData].dptr() + n*im2col_step_*input_dim_, + in_data[dmconv::kOffset].dptr() + n*im2col_step_*input_offset_dim_, + in_data[dmconv::kMask].dptr() + n*im2col_step_*input_mask_dim_, + in_data[dmconv::kData].shape_, + col_buffer.shape_, param_.kernel, param_.pad, param_.stride, param_.dilate, + param_.num_deformable_group, col_buffer.dptr()); + + for (index_t g = 0; g < group_; ++g) { + auto request = (n == 0) ? req[dmconv::kWeight] : kAddTo; + // Legacy approach shown here for comparison: + // Assign(dweight_3d[g], request, dot(out_grad_3d[g], col_buffer_3d[g].T())); + linalg_gemm(out_grad_3d[g], col_buffer_3d[g], dweight_3d[g], false, true, s, request); + } + } + + // gradient w.r.t bias + if (bias_term_) { + Tensor dbias = in_grad[dmconv::kBias].get(s); + Tensor dout = out_grad[dmconv::kOut].get_with_shape( + Shape3(num_, conv_out_channels_, conv_out_spatial_dim_), s); + ASSIGN_DISPATCH(dbias, req[dmconv::kBias], sumall_except_dim<1>(dout)); + } + } + + private: + void LayerSetUp(const mxnet::TShape& ishape, const mxnet::TShape& offset_shape, + const mxnet::TShape& mask_shape, const mxnet::TShape& oshape) { + channel_axis_ = 1; // hard code channel axis + const index_t first_spatial_axis = channel_axis_ + 1; + const index_t num_axes = param_.kernel.ndim() + 2; + num_spatial_axes_ = num_axes - first_spatial_axis; + is_1x1_ = true; + for (index_t i = 0; i < param_.kernel.ndim(); ++i) { + is_1x1_ &= param_.kernel[i] == 1 && param_.stride[i] == 1 && param_.pad[i] == 0; + if (!is_1x1_) break; + } + + // batch size + num_ = ishape[0]; + // number of input channels + channels_ = ishape[1]; + group_ = param_.num_group; + conv_out_channels_ = param_.num_filter; + conv_in_channels_ = channels_; + bias_term_ = !param_.no_bias; + kernel_dim_ = conv_in_channels_ / group_ * param_.kernel.Size(); + weight_offset_ = conv_out_channels_ * kernel_dim_ / group_; + conv_out_spatial_dim_ = oshape.ProdShape(2, oshape.ndim()); + col_offset_ = kernel_dim_ * conv_out_spatial_dim_; + output_offset_ = conv_out_channels_ * conv_out_spatial_dim_ / group_; + // size of the column buffer used for storing im2col-ed pixels + im2col_step_ = std::min(param_.im2col_step, static_cast(num_)); + col_buffer_size_ = kernel_dim_ * group_ * im2col_step_ * conv_out_spatial_dim_; + // input/output image size (#channels * height * width) + input_dim_ = ishape.ProdShape(1, ishape.ndim()); + input_offset_dim_ = offset_shape.ProdShape(1, offset_shape.ndim()); + input_mask_dim_ = mask_shape.ProdShape(1, mask_shape.ndim()); + output_dim_ = oshape.ProdShape(1, oshape.ndim()); + num_kernels_im2col_ = conv_in_channels_ * conv_out_spatial_dim_; + num_kernels_col2im_ = input_dim_; + } + + private: + ModulatedDeformableConvolutionParam param_; + index_t channel_axis_; // channel axis of the input + index_t channels_; // number of channels of input image + index_t num_spatial_axes_; // number of spatial axes + index_t num_; // batch size + index_t group_; // number of groups + index_t conv_out_channels_; // number of output channels (num_filter) + index_t conv_out_spatial_dim_; // number of pixels of output images per channel + index_t conv_in_channels_; // number of input channels + index_t kernel_dim_; // number of input channels per group * kernel size + index_t weight_offset_; // number of output channels per group * kernel_dim_ + index_t col_offset_; + index_t output_offset_; + index_t col_buffer_size_; + index_t input_dim_; + index_t input_offset_dim_; + index_t input_mask_dim_; + index_t output_dim_; + index_t num_kernels_im2col_; + index_t num_kernels_col2im_; + index_t im2col_step_; + bool bias_term_; // has bias term? + bool is_1x1_; +}; // class ConvolutionOp + +template +Operator* CreateOp(ModulatedDeformableConvolutionParam param, int dtype, + std::vector *in_shape, + std::vector *out_shape, + Context ctx); + +#if DMLC_USE_CXX11 +class ModulatedDeformableConvolutionProp : public OperatorProperty { + public: + std::vector ListArguments() const override { + if (!param_.no_bias) { + return{ "data", "offset", "mask", "weight", "bias" }; + } else { + return{ "data", "offset", "mask", "weight" }; + } + } + + void Init(const std::vector >& kwargs) override { + using namespace mshadow; + param_.Init(kwargs); + if (param_.kernel.ndim() == 2) { + param_.layout = param_.layout ? param_.layout.value() : mshadow::kNCHW; + if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1); + if (param_.dilate.ndim() == 0) param_.dilate = Shape2(1, 1); + if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0); + } else { + LOG(FATAL) << "not implemented"; + } + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + if (!param_.no_bias) { + CHECK_EQ(in_shape->size(), 5U) << "Input:[data, offset, mask, weight, bias]"; + } else { + CHECK_EQ(in_shape->size(), 4U) << "Input:[data, offset, mask, weight]"; + } + out_shape->resize(1, mxnet::TShape()); + const mxnet::TShape &dshp = (*in_shape)[dmconv::kData]; + const mxnet::TShape &oshp = (*in_shape)[dmconv::kOffset]; + const mxnet::TShape &mshp = (*in_shape)[dmconv::kMask]; + if (dshp.ndim() == 0) return false; + if (param_.kernel.ndim() == 2) { + // 2d dmconv + CHECK_EQ(dshp.ndim(), 4U) \ + << "Input data should be 4D in batch-num_filter-y-x"; + CHECK_EQ(oshp.ndim(), 4U) \ + << "Input offset should be 4D in batch-num_filter-y-x"; + CHECK_EQ(mshp.ndim(), 4U) \ + << "Input offset should be 4D in batch-num_filter-y-x"; + Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW); + Shape<4> offsetshape = ConvertLayout(oshp.get<4>(), param_.layout.value(), kNCHW); + Shape<4> maskshape = ConvertLayout(mshp.get<4>(), param_.layout.value(), kNCHW); + Shape<4> wshape = Shape4(param_.num_filter / param_.num_group, dshape[1] / param_.num_group, + param_.kernel[0], param_.kernel[1]); + wshape = ConvertLayout(wshape, kNCHW, param_.layout.value()); + wshape[0] *= param_.num_group; + SHAPE_ASSIGN_CHECK(*in_shape, dmconv::kWeight, wshape); + if (!param_.no_bias) { + SHAPE_ASSIGN_CHECK(*in_shape, dmconv::kBias, Shape1(param_.num_filter)); + } + + const index_t ksize_y = static_cast(param_.kernel[0]); + const index_t ksize_x = static_cast(param_.kernel[1]); + if (dshape[0] > static_cast(param_.im2col_step)) { + CHECK_EQ(dshape[0] % param_.im2col_step, 0U) \ + << "input batchsize must be smaller than or divide im2col_step"; + } + CHECK_EQ(dshape[1] % param_.num_group, 0U) \ + << "input num_filter must divide group size"; + CHECK_EQ(dshape[1] % param_.num_deformable_group, 0U) \ + << "input num_filter must divide deformable group size"; + CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ + << "output num_filter must divide group size"; + CHECK_GT(param_.kernel.Size(), 0U) \ + << "incorrect kernel size: " << param_.kernel; + CHECK_GT(param_.stride.Size(), 0U) \ + << "incorrect stride size: " << param_.stride; + CHECK_GT(param_.dilate.Size(), 0U) \ + << "incorrect dilate size: " << param_.dilate; + Shape<4> oshape; + oshape[0] = dshape[0]; + oshape[1] = param_.num_filter; + oshape[2] = (dshape[2] + 2 * param_.pad[0] - + (param_.dilate[0] * (ksize_y - 1) + 1)) / param_.stride[0] + 1; + oshape[3] = (dshape[3] + 2 * param_.pad[1] - + (param_.dilate[1] * (ksize_x - 1) + 1)) / param_.stride[1] + 1; + SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCHW, param_.layout.value())); + CHECK_EQ(oshape[1] % param_.num_deformable_group, 0U) \ + << "output num_filter must divide deformable group size"; + CHECK_EQ(oshape[2], offsetshape[2]) \ + << "output height must equal to offset map height"; + CHECK_EQ(oshape[3], offsetshape[3]) \ + << "output width must equal to offset map width"; + CHECK_EQ(offsetshape[1] % (param_.kernel[0] * param_.kernel[1]), 0U) \ + << "offset filter must divide deformable group size"; + CHECK_EQ(offsetshape[1] / (2 * param_.kernel[0] * param_.kernel[1]), \ + param_.num_deformable_group) \ + << "offset filter must divide deformable group size"; + CHECK_EQ(oshape[2], maskshape[2]) \ + << "output height must equal to mask map height"; + CHECK_EQ(oshape[3], maskshape[3]) \ + << "output width must equal to mask map width"; + CHECK_EQ(maskshape[1] % (param_.kernel[0] * param_.kernel[1]), 0U) \ + << "offset filter must divide deformable group size"; + CHECK_EQ(maskshape[1] / (param_.kernel[0] * param_.kernel[1]), \ + param_.num_deformable_group) \ + << "offset filter must divide deformable group size"; + // Perform incomplete shape inference. Fill in the missing values in data shape. + // 1) We can always fill in the batch_size. + // 2) We can back-calculate the input height/width if the corresponding stride is 1. + oshape = ConvertLayout((*out_shape)[0].get<4>(), param_.layout.value(), kNCHW); + dshape[0] = oshape[0]; + if (param_.stride[0] == 1) { + dshape[2] = oshape[2] + param_.dilate[0] * (ksize_y - 1) - 2 * param_.pad[0]; + } + if (param_.stride[1] == 1) { + dshape[3] = oshape[3] + param_.dilate[1] * (ksize_x - 1) - 2 * param_.pad[1]; + } + SHAPE_ASSIGN_CHECK(*in_shape, dmconv::kData, + ConvertLayout(dshape, kNCHW, param_.layout.value())); + // Check whether the kernel sizes are valid + if (dshape[2] != 0) { + CHECK_LE(ksize_y, dshape[2] + 2 * param_.pad[0]) << "kernel size exceed input"; + } + if (dshape[3] != 0) { + CHECK_LE(ksize_x, dshape[3] + 2 * param_.pad[1]) << "kernel size exceed input"; + } + return true; + } else { + LOG(FATAL) << "not implemented"; + return false; + } + } + + bool InferType(std::vector *in_type, + std::vector *out_type, + std::vector *aux_type) const override { + CHECK_GE(in_type->size(), 1U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + for (std::size_t i = 0; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype; + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); + } + } + out_type->clear(); + out_type->push_back(dtype); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new ModulatedDeformableConvolutionProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "_contrib_ModulatedDeformableConvolution"; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return{ out_grad[dmconv::kOut], in_data[dmconv::kData], + in_data[dmconv::kOffset], in_data[dmconv::kMask], + in_data[dmconv::kWeight] }; + } + + std::vector ForwardResource( + const std::vector &in_shape) const override { + return{ ResourceRequest::kTempSpace }; + } + + std::vector BackwardResource( + const std::vector &in_shape) const override { + return{ ResourceRequest::kTempSpace }; + } + + Operator* CreateOperator(Context ctx) const override { + LOG(FATAL) << "Not Implemented."; + return NULL; + } + + Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const override; + + private: + ModulatedDeformableConvolutionParam param_; +}; // class ConvolutionProp +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_CONTRIB_MODULATED_DEFORMABLE_CONVOLUTION_INL_H_ diff --git a/src/operator/contrib/modulated_deformable_convolution.cc b/src/operator/contrib/modulated_deformable_convolution.cc new file mode 100644 index 000000000000..5fa25f797a7a --- /dev/null +++ b/src/operator/contrib/modulated_deformable_convolution.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_convolution.cc + * \brief + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu +*/ + +#include "./modulated_deformable_convolution-inl.h" + +namespace mxnet { +namespace op { +DMLC_REGISTER_PARAMETER(ModulatedDeformableConvolutionParam); + +template<> +Operator* CreateOp(ModulatedDeformableConvolutionParam param, int dtype, + std::vector *in_shape, + std::vector *out_shape, + Context ctx) { + Operator *op = nullptr; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new ModulatedDeformableConvolutionOp(param); + }) + return op; +} + +// DO_BIND_DISPATCH comes from operator_common.h +Operator *ModulatedDeformableConvolutionProp::CreateOperatorEx(Context ctx, + std::vector *in_shape, + std::vector *in_type) const { + std::vector out_shape, aux_shape; + std::vector out_type, aux_type; + CHECK(InferType(in_type, &out_type, &aux_type)); + CHECK(InferShape(in_shape, &out_shape, &aux_shape)); + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], in_shape, &out_shape, ctx); +} + +MXNET_REGISTER_OP_PROPERTY( + _contrib_ModulatedDeformableConvolution, ModulatedDeformableConvolutionProp) +.describe(R"code(Compute 2-D modulated deformable convolution on 4-D input. + +The modulated deformable convolution operation is described in https://arxiv.org/abs/1811.11168 + +For 2-D modulated deformable convolution, the shapes are + +- **data**: *(batch_size, channel, height, width)* +- **offset**: *(batch_size, num_deformable_group * kernel[0] * kernel[1] * 2, height, width)* +- **mask**: *(batch_size, num_deformable_group * kernel[0] * kernel[1], height, width)* +- **weight**: *(num_filter, channel, kernel[0], kernel[1])* +- **bias**: *(num_filter,)* +- **out**: *(batch_size, num_filter, out_height, out_width)*. + +Define:: + + f(x,k,p,s,d) = floor((x+2*p-d*(k-1)-1)/s)+1 + +then we have:: + + out_height=f(height, kernel[0], pad[0], stride[0], dilate[0]) + out_width=f(width, kernel[1], pad[1], stride[1], dilate[1]) + +If ``no_bias`` is set to be true, then the ``bias`` term is ignored. + +The default data ``layout`` is *NCHW*, namely *(batch_size, channle, height, +width)*. + +If ``num_group`` is larger than 1, denoted by *g*, then split the input ``data`` +evenly into *g* parts along the channel axis, and also evenly split ``weight`` +along the first dimension. Next compute the convolution on the *i*-th part of +the data with the *i*-th weight part. The output is obtained by concating all +the *g* results. + +If ``num_deformable_group`` is larger than 1, denoted by *dg*, then split the +input ``offset`` evenly into *dg* parts along the channel axis, and also evenly +split ``out`` evenly into *dg* parts along the channel axis. Next compute the +deformable convolution, apply the *i*-th part of the offset part on the *i*-th +out. + + +Both ``weight`` and ``bias`` are learnable parameters. + + +)code" ADD_FILELINE) +.add_argument("data", "NDArray-or-Symbol", "Input data to the ModulatedDeformableConvolutionOp.") +.add_argument("offset", "NDArray-or-Symbol", "Input offset to ModulatedDeformableConvolutionOp.") +.add_argument("mask", "NDArray-or-Symbol", "Input mask to the ModulatedDeformableConvolutionOp.") +.add_argument("weight", "NDArray-or-Symbol", "Weight matrix.") +.add_argument("bias", "NDArray-or-Symbol", "Bias parameter.") +.add_arguments(ModulatedDeformableConvolutionParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/modulated_deformable_convolution.cu b/src/operator/contrib/modulated_deformable_convolution.cu new file mode 100644 index 000000000000..fce73dd49b1f --- /dev/null +++ b/src/operator/contrib/modulated_deformable_convolution.cu @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_convolution.cu + * \brief + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu +*/ + +#include "./modulated_deformable_convolution-inl.h" +#include + +namespace mxnet { +namespace op { + + template<> + Operator* CreateOp(ModulatedDeformableConvolutionParam param, int dtype, + std::vector *in_shape, + std::vector *out_shape, + Context ctx) { + Operator *op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new ModulatedDeformableConvolutionOp(param); + }) + return op; + } + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/nn/modulated_deformable_im2col.cuh b/src/operator/contrib/nn/modulated_deformable_im2col.cuh new file mode 100644 index 000000000000..16d9cef46d4e --- /dev/null +++ b/src/operator/contrib/nn/modulated_deformable_im2col.cuh @@ -0,0 +1,541 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in modulated deformable convolution operators. + * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu + */ + +#ifndef MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_CUH_ +#define MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_CUH_ + +#include +#include +#include +#include +#include +#include "../../mxnet_op.h" +#include "../../../common/cuda_utils.h" + + + +namespace mxnet { +namespace op { + +template +__device__ DType dmcn_im2col_bilinear(const DType* bottom_data, const int data_width, + const int height, const int width, DType h, DType w) { + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + DType lh = h - h_low; + DType lw = w - w_low; + DType hh = 1 - lh, hw = 1 - lw; + + DType v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + DType v2 = 0; + if (h_low >=0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + DType v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + DType v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + DType val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ DType dmcn_get_gradient_weight(DType argmax_h, DType argmax_w, + const int h, const int w, const int height, const int width) { + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + DType weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + + +template +__device__ DType dmcn_get_coordinate_weight(DType argmax_h, DType argmax_w, + const int height, const int width, const DType* im_data, + const int data_width, const int bp_dir) { + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + DType weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + + +/*! + * \brief deformable_im2col gpu kernel. + * DO NOT call this directly. Use wrapper function im2col() instead; + */ +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const DType* data_im, const DType* data_offset, const DType* data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + DType* data_col) { + CUDA_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + DType* data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const DType* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const DType* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const DType* data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const DType* data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const DType offset_h = data_offset_ptr[data_offset_h_ptr]; + const DType offset_w = data_offset_ptr[data_offset_w_ptr]; + const DType mask = data_mask_ptr[data_mask_hw_ptr]; + DType val = static_cast(0); + const DType h_im = h_in + i * dilation_h + offset_h; + const DType w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + //const DType map_h = i * dilation_h + offset_h; + //const DType map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + + + +/*!\brief + * cpu function of deformable_im2col algorithm + * \param s device stream + * \param data_im pointer of an image (N, C, H, W, ...) in the image batch + * \param data_offset pointer of offset (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch + * \param im_shape input image shape in dimensions (N, C, H, W,) + * \param col_shape column buffer shape (#channels, N, output_im_height, output_im_width, ...) + * \param kernel_shape kernel filter shape + * \param pad pad shape + * \param stride stride shape + * \param dilation dilation shape + * \param deformable_group #offset group that deformable convolution use + * \param data_col column buffer pointer + */ +template +inline void modulated_deformable_im2col(mshadow::Stream* s, + const DType* data_im, const DType* data_offset, const DType* data_mask, + const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape, + const TShape& pad, const TShape& stride, const TShape& dilation, + const uint32_t deformable_group, DType* data_col) { + // num_axes should be smaller than block size + index_t num_spatial_axes = kernel_shape.ndim(); + CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); + index_t channel_per_deformable_group = im_shape[1] / deformable_group; + index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim()); + using namespace mxnet_op; + switch (num_spatial_axes) { + case 2: + modulated_deformable_im2col_gpu_kernel // NOLINT_NEXT_LINE(whitespace/operators) + <<::GetStream(s)>>>( + num_kernels, data_im, data_offset, data_mask, im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1], + pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1], channel_per_deformable_group, + col_shape[1], im_shape[1], deformable_group, col_shape[2], col_shape[3], data_col); + MSHADOW_CUDA_POST_KERNEL_CHECK(modulated_deformable_im2col_gpu_kernel); + break; + default: + LOG(FATAL) << "im2col_nd_gpu does not support computation with " + << num_spatial_axes << " spatial axes"; + } +} + + +/*! +* \brief deformable_col2im gpu kernel. +* \brief DO NOT call this directly. Use wrapper function deformable_col2im() instead; +*/ +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const DType* data_col, const DType* data_offset, const DType* data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + DType* grad_im, OpReqType req) { + CUDA_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const DType* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const DType* data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const DType offset_h = data_offset_ptr[data_offset_h_ptr]; + const DType offset_w = data_offset_ptr[data_offset_w_ptr]; + const DType mask = data_mask_ptr[data_mask_hw_ptr]; + const DType cur_inv_h_data = h_in + i * dilation_h + offset_h; + const DType cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const DType cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1 + ) { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + DType weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + + +/*!\brief + * gpu function of deformable_col2im algorithm + * \param s device stream + * \param data_col start pointer of the column buffer to be filled + * \param data_offset pointer of offset (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch + * \param im_shape input image shape in dimensions (N, C, H, W,) + * \param col_shape column buffer shape + * \param kernel_shape kernel filter shape + * \param pad pad shape + * \param stride stride shape + * \param dilation dilation shape + * \param deformable_group #offset group that deformable convolution use + * \param grad_im pointer of images (N, C, H, W,...) in the image batch + */ +template +inline void modulated_deformable_col2im(mshadow::Stream* s, + const DType* data_col, const DType* data_offset, const DType* data_mask, + const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape, + const TShape& pad, const TShape& stride, + const TShape& dilation, const uint32_t deformable_group, + DType* grad_im, OpReqType req) { + index_t num_spatial_axes = kernel_shape.ndim(); + index_t im_size = im_shape.ProdShape(1, im_shape.ndim()); + index_t channel_per_deformable_group = im_shape[1] / deformable_group; + index_t num_kernels = col_shape.ProdShape(0, col_shape.ndim()); + // num_axes should be smaller than block size + CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); + using namespace mxnet_op; + switch (num_spatial_axes) { + case 2: + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // NOLINT_NEXT_LINE(whitespace/operators) + modulated_deformable_col2im_gpu_kernel<<::GetStream(s)>>>( + num_kernels, data_col, data_offset, data_mask, im_shape[1], im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], channel_per_deformable_group, + col_shape[1], deformable_group, col_shape[2], col_shape[3], grad_im, req); + MSHADOW_CUDA_POST_KERNEL_CHECK(modulated_deformable_col2im_gpu_kernel); + break; + default: + LOG(FATAL) << "col2im_nd_gpu does not support computation with " + << num_spatial_axes << " spatial axes"; + } +} + + +/*! + * \brief deformable_col2im_coord gpu kernel. + * \brief DO NOT call this directly. Use wrapper function deformable_col2im_coord() instead; + */ +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const DType* data_col, const DType* data_im, + const DType* data_offset, const DType* data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + DType* grad_offset, DType* grad_mask, OpReqType offset_req, OpReqType mask_req) { + CUDA_KERNEL_LOOP(index, n) { + DType val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const DType* data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const DType* data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const DType* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const DType* data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const DType offset_h = data_offset_ptr[data_offset_h_ptr]; + const DType offset_w = data_offset_ptr[data_offset_w_ptr]; + const DType mask = data_mask_ptr[data_mask_hw_ptr]; + DType inv_h = h_in + i * dilation_h + offset_h; + DType inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const DType weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + + //grad_offset[index] = val; + KERNEL_ASSIGN(grad_offset[index], offset_req, val); + if (offset_c % 2 == 0) + KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + } +} + +/*!\brief + * gpu function of deformable_col2im_coord algorithm + * \param s device stream + * \param data_col start pointer of the column buffer to be filled + * \param data_im pointer of an image (N, C, H, W, ...) in the image batch + * \param data_offset pointer of offset (N, deformable_group*kernel_h*kernel_w*2, H, W, ...) in the offset batch + * \param im_shape input image shape in dimensions (N, C, H, W,) + * \param col_shape column buffer shape + * \param kernel_shape kernel filter shape + * \param pad pad shape + * \param stride stride shape + * \param dilation dilation shape + * \param deformable_group #offset group that deformable convolution use + * \param grad_offset pointer of the offset (N, deformable_group*kernel_h*kernel_w*2, H, W,...) in the offset batch + */ +template +inline void modulated_deformable_col2im_coord(mshadow::Stream* s, + const DType* data_col, const DType* data_im, const DType* data_offset, const DType* data_mask, + const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape, + const TShape& pad, const TShape& stride, + const TShape& dilation, const uint32_t deformable_group, + DType* grad_offset, DType* grad_mask, OpReqType offset_req, OpReqType mask_req) { + index_t num_spatial_axes = kernel_shape.ndim(); + index_t num_kernels = col_shape[1] * col_shape[2] * col_shape[3] * 2 * kernel_shape[0] * kernel_shape[1] * deformable_group; + index_t channel_per_deformable_group = col_shape[0] / deformable_group; + // num_axes should be smaller than block size + CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum); + using namespace mxnet_op; + switch (num_spatial_axes) { + case 2: + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + // NOLINT_NEXT_LINE(whitespace/operators) + + modulated_deformable_col2im_coord_gpu_kernel << ::GetStream(s) >> >( + num_kernels, data_col, data_im, data_offset, data_mask, im_shape[1], im_shape[2], im_shape[3], + kernel_shape[0], kernel_shape[1], pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], channel_per_deformable_group, + col_shape[1], 2 * kernel_shape[0] * kernel_shape[1] * deformable_group, deformable_group, col_shape[2], col_shape[3], + grad_offset, grad_mask, offset_req, mask_req); + MSHADOW_CUDA_POST_KERNEL_CHECK(modulated_deformable_col2im_coord_gpu_kernel); + break; + default: + LOG(FATAL) << "col2im_nd_gpu does not support computation with " + << num_spatial_axes << " spatial axes"; + } +} + + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_NN_DEFORMABLE_MASKED_IM2COL_CUH_ diff --git a/src/operator/contrib/nn/modulated_deformable_im2col.h b/src/operator/contrib/nn/modulated_deformable_im2col.h new file mode 100644 index 000000000000..b50eb13a8122 --- /dev/null +++ b/src/operator/contrib/nn/modulated_deformable_im2col.h @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.h + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu + */ + +#ifndef MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_H_ +#define MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_H_ + +#include +#include +#include +#include +#include +#include "../../mxnet_op.h" + +namespace mxnet { +namespace op { + +template +inline DType dmcn_im2col_bilinear_cpu(const DType* bottom_data, const int data_width, + const int height, const int width, DType h, DType w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + DType lh = h - h_low; + DType lw = w - w_low; + DType hh = 1 - lh, hw = 1 - lw; + + DType v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + DType v2 = 0; + if (h_low >=0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + DType v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + DType v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + DType w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + DType val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +/*! +* \brief deformable_col2im gpu kernel. +* \brief DO NOT call this directly. Use wrapper function deformable_col2im() instead; +*/ +struct modulated_deformable_col2im_cpu_kernel { + template + MSHADOW_XINLINE static void Map(const int index, + const DType* data_im, const DType* data_offset, const DType* data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + DType* data_col) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + DType* data_col_ptr = data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + // const DType* data_im_ptr = data_im + + // ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const DType* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const DType* data_offset_ptr = data_offset + + (b_col * deformable_group + deformable_group_index) * 2 + * kernel_h * kernel_w * height_col * width_col; + + const DType* data_mask_ptr = data_mask + + (b_col * deformable_group + deformable_group_index) * kernel_h + * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) + * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) + * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const DType offset_h = data_offset_ptr[data_offset_h_ptr]; + const DType offset_w = data_offset_ptr[data_offset_w_ptr]; + const DType mask = data_mask_ptr[data_mask_hw_ptr]; + DType val = static_cast(0); + const DType h_im = h_in + i * dilation_h + offset_h; + const DType w_im = w_in + j * dilation_w + offset_w; + // if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + // const DType map_h = i * dilation_h + offset_h; + // const DType map_w = j * dilation_w + offset_w; + // const int cur_height = height - h_in; + // const int cur_width = width - w_in; + // val = dmcn_im2col_bilinear_cpu( + // data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + // data_col_ptr += height_col * width_col; + } + } + } +}; + +/*!\brief + * cpu function of deformable_im2col algorithm + * \param s device stream + * \param data_im pointer of an image (C, H, W, ...) in the image batch + * \param data_offset pointer of offset (C, H, W, ...) in the offset batch + * \param im_shape input image shape in dimensions (N, C, H, W,) + * \param col_shape column buffer shape (#channels, output_im_height, output_im_width, ...) + * \param kernel_shape kernel filter shape + * \param pad pad shape + * \param stride stride shape + * \param dilation dilation shape + * \param deformable_group #offset group that deformable convolution use + * \param data_col column buffer pointer + */ +template +inline void modulated_deformable_im2col(mshadow::Stream* s, + const DType* data_im, const DType* data_offset, const DType* data_mask, + const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape, + const TShape& pad, const TShape& stride, const TShape& dilation, + const uint32_t deformable_group, DType* data_col) { + // num_axes should be smaller than block size + index_t num_spatial_axes = kernel_shape.ndim(); + index_t channel_per_deformable_group = im_shape[1] / deformable_group; + index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim()); + using namespace mxnet_op; + if (2 == num_spatial_axes) { + Kernel::Launch( + s, num_kernels, data_im, data_offset, data_mask, + im_shape[2], im_shape[3], kernel_shape[0], kernel_shape[1], + pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1], + channel_per_deformable_group, col_shape[1], im_shape[1], deformable_group, + col_shape[2], col_shape[3], data_col); + } else { + LOG(FATAL) << "not implemented"; + } +} + + +/*!\brief + * cpu function of deformable_col2im algorithm + * \param s device stream + * \param data_col start pointer of the column buffer to be filled + * \param data_offset pointer of offset (C, H, W, ...) in the offset batch + * \param im_shape input image shape in dimensions (N, C, H, W,) + * \param col_shape column buffer shape + * \param kernel_shape kernel filter shape + * \param pad pad shape + * \param stride stride shape + * \param dilation dilation shape + * \param deformable_group #offset group that deformable convolution use + * \param grad_im pointer of a image (C, H, W,...) in the image batch + */ +template +inline void modulated_deformable_col2im(mshadow::Stream* s, + const DType* data_col, const DType* data_offset, const DType* data_mask, + const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape, + const TShape& pad, const TShape& stride, + const TShape& dilation, const uint32_t deformable_group, + DType* grad_im, OpReqType req) { + LOG(FATAL) << "only implemented in GPU"; +} + + +/*!\brief + * cpu function of deformable_col2im_coord algorithm + * \param s device stream + * \param data_col start pointer of the column buffer to be filled + * \param data_im pointer of an image (C, H, W, ...) in the image batch + * \param data_offset pointer of offset (C, H, W, ...) in the offset batch + * \param im_shape input image shape in dimensions (N, C, H, W,) + * \param col_shape column buffer shape + * \param kernel_shape kernel filter shape + * \param pad pad shape + * \param stride stride shape + * \param dilation dilation shape + * \param deformable_group #offset group that deformable convolution use + * \param grad_offset pointer of the offset (C, H, W,...) in the offset batch + */ + +template +inline void modulated_deformable_col2im_coord(mshadow::Stream* s, + const DType* data_col, const DType* data_im, const DType* data_offset, const DType* data_mask, + const TShape& im_shape, const TShape& col_shape, const TShape& kernel_shape, + const TShape& pad, const TShape& stride, + const TShape& dilation, const uint32_t deformable_group, + DType* grad_offset, DType* grad_mask, OpReqType offset_req, OpReqType mask_req) { + LOG(FATAL) << "only implemented in GPU"; +} + +} // namespace op +} // namespace mxnet +#ifdef __CUDACC__ +#include "./modulated_deformable_im2col.cuh" +#endif +#endif // MXNET_OPERATOR_CONTRIB_NN_MODULATED_DEFORMABLE_IM2COL_H_ diff --git a/tests/python/gpu/test_gluon_contrib_gpu.py b/tests/python/gpu/test_gluon_contrib_gpu.py index 1d19d850dd8e..348e9f77acc8 100644 --- a/tests/python/gpu/test_gluon_contrib_gpu.py +++ b/tests/python/gpu/test_gluon_contrib_gpu.py @@ -57,6 +57,33 @@ def test_DeformableConvolution(): y = net(x) y.backward() +def test_ModulatedDeformableConvolution(): + """test of the deformable convolution layer with possible combinations of arguments, + currently this layer only supports gpu + """ + net = nn.HybridSequential() + net.add( + DeformableConvolution(10, kernel_size=(3, 3), strides=1, padding=0), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu', + offset_use_bias=False, use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu', + offset_use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu', + use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False, use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False), + DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False), + DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False, num_deformable_group=4), + ) + + ctx = mx.gpu() + net.initialize(force_reinit=True, ctx=ctx) + net.hybridize() + + x = mx.nd.random.uniform(shape=(8, 5, 30, 31), ctx=ctx) + with mx.autograd.record(): + y = net(x) + y.backward() if __name__ == '__main__': import nose diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index e78884f74c71..476dfac24a61 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -23,7 +23,7 @@ import itertools from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * -from common import with_seed +from common import with_seed, assert_raises_cudnn_not_satisfied import unittest def test_box_nms_op(): @@ -409,6 +409,42 @@ def test_op_mrcnn_mask_target(): assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy()) assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy()) +@with_seed() +def test_modulated_deformable_convolution(): + for num_batch in [1, 2]: + for num_channel_data, num_deformable_group in itertools.product([4, 8], [1, 2]): + for input_height, input_width in itertools.product([5, 6], [5, 6]): + for dilate in [(1, 1), (2, 2)]: + for grad_nodes in [['im_data'], ['offset_data'], ['weight']]: + output_height = input_height + output_width = input_width + im_data = np.random.rand(num_batch, num_channel_data, input_height, input_width) + offset_data = \ + np.random.rand(num_batch, num_deformable_group * 3 * 3 * 2, output_height, output_width)\ + * 0.8 + 0.1 + mask_data = np.random.rand(num_batch, num_deformable_group * 3 * 3, output_height, output_width) + mask_data = 0.5 * (1 + np.tanh(0.5 * mask_data)) # sigmoid + weight = np.random.normal(0, 0.001, (num_channel_data, num_channel_data, 3, 3)) + bias = np.zeros(num_channel_data) + + im_data_var = mx.symbol.Variable(name="im_data") + offset_data_var = mx.symbol.Variable(name="offset_data") + mask_data_var = mx.symbol.Variable(name="mask_data") + weight_var = mx.symbol.Variable(name="weight") + bias_var = mx.symbol.Variable(name="bias") + op = mx.sym.contrib.ModulatedDeformableConvolution(name='test_op', data=im_data_var, + offset=offset_data_var, mask=mask_data_var, + weight=weight_var, bias=bias_var, + num_filter=num_channel_data, pad=dilate, + kernel=(3, 3), stride=(1, 1), dilate=dilate, + num_deformable_group=num_deformable_group) + if grad_nodes[0] == 'offset_data': + # wider tolerance needed for coordinate differential + rtol, atol = 1.0, 1e-2 + else: + rtol, atol = 0.05, 1e-3 + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index 123db085e817..fdba553c8560 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -402,6 +402,36 @@ def test_contrib_unroll(): check_unroll(cell_type, num_states, 'TNC') check_unroll(cell_type, num_states, 'NTC') +@with_seed() +def test_ModulatedDeformableConvolution(): + """test of the deformable convolution layer with possible combinations of arguments, + currently this layer only supports gpu + """ + from mxnet.gluon.contrib.cnn import DeformableConvolution + net = nn.HybridSequential() + net.add( + DeformableConvolution(10, kernel_size=(3, 3), strides=1, padding=0), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu', + offset_use_bias=False, use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu', + offset_use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu', + use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False, use_bias=False), + DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, offset_use_bias=False), + DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False), + DeformableConvolution(12, kernel_size=(3, 2), strides=1, padding=0, use_bias=False, num_deformable_group=4), + ) + + ctx = mx.cpu() + + net.initialize(force_reinit=True, ctx=ctx) + net.hybridize() + + x = mx.nd.random.uniform(shape=(8, 5, 30, 31), ctx=ctx) + with mx.autograd.record(): + y = net(x) + if __name__ == '__main__': import nose