diff --git a/python/mxnet/gluon/contrib/cnn/conv_layers.py b/python/mxnet/gluon/contrib/cnn/conv_layers.py index 098463eca968..c4924c130a28 100644 --- a/python/mxnet/gluon/contrib/cnn/conv_layers.py +++ b/python/mxnet/gluon/contrib/cnn/conv_layers.py @@ -313,7 +313,8 @@ def __init__(self, channels, kernel_size=(1, 1), strides=(1, 1), padding=(0, 0), dilation = (dilation,) * len(kernel_size) self._op_name = op_name - offset_channels = 27 + offset_channels = num_deformable_group * 3 * kernel_size[0] * kernel_size[1] + self.offset_split_index = num_deformable_group * 2 * kernel_size[0] * kernel_size[1] self._kwargs_offset = { 'kernel': kernel_size, 'stride': strides, 'dilate': dilation, 'pad': padding, 'num_filter': offset_channels, 'num_group': groups, @@ -377,8 +378,8 @@ def hybrid_forward(self, F, x, offset_weight, deformable_conv_weight, offset_bia else: offset = F.Convolution(x, offset_weight, offset_bias, cudnn_off=True, **self._kwargs_offset) - offset_t = F.slice_axis(offset, axis=1, begin=0, end=18) - mask = F.slice_axis(offset, axis=1, begin=18, end=None) + offset_t = F.slice_axis(offset, axis=1, begin=0, end=self.offset_split_index) + mask = F.slice_axis(offset, axis=1, begin=self.offset_split_index, end=None) mask = F.sigmoid(mask) * 2 if deformable_conv_bias is None: diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index fdba553c8560..0ed0d4e8a545 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -411,6 +411,10 @@ def test_ModulatedDeformableConvolution(): net = nn.HybridSequential() net.add( DeformableConvolution(10, kernel_size=(3, 3), strides=1, padding=0), + DeformableConvolution(10, kernel_size=(1, 1), strides=1, padding=0), + DeformableConvolution(10, kernel_size=(5, 5), strides=1, padding=0), + DeformableConvolution(10, kernel_size=(3, 5), strides=1, padding=0), + DeformableConvolution(10, kernel_size=(5, 1), strides=1, padding=0, num_deformable_group=2), DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu', offset_use_bias=False, use_bias=False), DeformableConvolution(10, kernel_size=(3, 2), strides=1, padding=0, activation='relu',