@@ -966,6 +966,12 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
966
966
if (!supportedPrimitiveDescriptors.empty ())
967
967
return ;
968
968
969
+ const int simd_w = mayiuse (cpu::x64::avx512_common) ? 16 : 8 ;
970
+ if (group != 1 && (((getParentEdgeAt (0 )->getDims ()[1 ] / group) % simd_w != 0 )
971
+ || ((getChildEdgeAt (0 )->getDims ()[1 ] / group) % simd_w != 0 ))) {
972
+ enforceRef = true ;
973
+ }
974
+
969
975
size_t inputsNumber = getOriginalInputsNumber ();
970
976
InferenceEngine::LayerConfig config;
971
977
config.dynBatchSupport = false ;
@@ -986,19 +992,20 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
986
992
config.outConfs [0 ].inPlace = -1 ;
987
993
988
994
impl_desc_type impl_type;
989
- // if (mayiuse(cpu::x64::avx512_common)) {
990
- // impl_type = impl_desc_type::jit_avx512;
991
- // } else if (mayiuse(cpu::x64::avx2)) {
992
- // impl_type = impl_desc_type::jit_avx2;
993
- // } else if (mayiuse(cpu::x64::sse41)) {
994
- // impl_type = impl_desc_type::jit_sse42;
995
- // } else {
996
- // impl_type = impl_desc_type::ref;
997
- // }
998
- impl_type = impl_desc_type::ref;
999
-
1000
- if (false && mayiuse (cpu::x64::sse41)) {
1001
- // optimzed implementation
995
+ if (enforceRef) {
996
+ impl_type = impl_desc_type::ref;
997
+ } else if (mayiuse (cpu::x64::avx512_common)) {
998
+ impl_type = impl_desc_type::jit_avx512;
999
+ } else if (mayiuse (cpu::x64::avx2)) {
1000
+ impl_type = impl_desc_type::jit_avx2;
1001
+ } else if (mayiuse (cpu::x64::sse41)) {
1002
+ impl_type = impl_desc_type::jit_sse42;
1003
+ } else {
1004
+ impl_type = impl_desc_type::ref;
1005
+ }
1006
+
1007
+ if (!enforceRef && mayiuse (cpu::x64::sse41)) {
1008
+ // optimized implementation
1002
1009
auto dataFormat = memory::format_tag::nhwc;
1003
1010
auto offFormat = memory::format_tag::nchw;
1004
1011
auto weiFormat = group > 1 ? mayiuse (avx512_common) ? memory::format_tag::gOIhw16i16o : memory::format_tag::gOIhw8i8o
@@ -1097,13 +1104,15 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {
1097
1104
1098
1105
jcp.nthr = dnnl_get_max_threads ();
1099
1106
1100
- // if (mayiuse(cpu::x64::avx512_common)) {
1101
- // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1102
- // } else if (mayiuse(cpu::x64::avx2)) {
1103
- // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1104
- // } else if (mayiuse(cpu::x64::sse41)) {
1105
- // def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1106
- // }
1107
+ if (enforceRef) {
1108
+ return ;
1109
+ } else if (mayiuse (cpu::x64::avx512_common)) {
1110
+ def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1111
+ } else if (mayiuse (cpu::x64::avx2)) {
1112
+ def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1113
+ } else if (mayiuse (cpu::x64::sse41)) {
1114
+ def_conv_kernel.reset (new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1115
+ }
1107
1116
1108
1117
if (def_conv_kernel)
1109
1118
def_conv_kernel->create_ker ();
0 commit comments