From 6373936bf89fc8b31525fd4677f09e7e67471f4a Mon Sep 17 00:00:00 2001 From: reminisce Date: Tue, 4 Feb 2020 18:33:54 -0800 Subject: [PATCH] Remove dilation restriction for conv3d (#17491) * Remove conv3d dilation restriction * Remove comment --- src/operator/nn/convolution.cc | 4 ---- tests/python/unittest/test_operator.py | 4 ++++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 6d9f84ffc510..6c8ab3a8f7ec 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -223,8 +223,6 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*in_shape, conv::kBias, Shape1(param_.num_filter)); } - // Note: 3D dilation currently not supported. - // Calculations below done to preserve symmetry with 1D/2D code. const index_t dilated_ksize_d = param_.DilatedKernelSize(0); const index_t dilated_ksize_y = param_.DilatedKernelSize(1); const index_t dilated_ksize_x = param_.DilatedKernelSize(2); @@ -239,8 +237,6 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs, << "incorrect stride size: " << param_.stride; CHECK_GT(param_.dilate.Size(), 0U) \ << "incorrect dilate size: " << param_.dilate; - CHECK_EQ(param_.dilate.Size(), 1U) - << "Dilate is not supported in 3d convolution"; Shape<5> oshape; oshape[0] = dshape[0]; oshape[1] = param_.num_filter; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 9ae35f15748a..37f737616efb 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2600,6 +2600,10 @@ def test_convolution_dilated_impulse_response(): for dil in [ (1,1), (2,2), (3,3) ]: for ks in [ (3,3), (4,4), (2,3), (3,2), (1,1) ]: test_run_convolution_dilated_impulse_response(dil=dil, kernel_shape=ks) + # 3D + for dil in [ (1,1,1), (2,2,2), (3,3,3) ]: + for ks in [ (3,3,3), (4,4,4), (2,3,4), (3,2,4), (1,1,1) ]: + test_run_convolution_dilated_impulse_response(dil=dil, kernel_shape=ks) @with_seed()