Skip to content

Commit b3d6c60

Browse files
committed
...
1 parent 3548bd6 commit b3d6c60

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.cpp

+29 -20
Original file line number | Diff line number | Diff line change
@@ -966,6 +966,12 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
966966
if (!supportedPrimitiveDescriptors.empty())
967967
return;
968968

969+
const int simd_w = mayiuse(cpu::x64::avx512_common) ? 16 : 8;
970+
if (group != 1 && (((getParentEdgeAt(0)->getShape().getStaticDims()[0] / group) % simd_w != 0)
971+
|| ((getChildEdgeAt(0)->getShape().getStaticDims()[1] / group) % simd_w != 0))) {
972+
enforceRef = true;
973+
}
974+
969975
size_t inputsNumber = getOriginalInputsNumber();
970976
NodeConfig config;
971977
config.dynBatchSupport = false;
@@ -986,19 +992,20 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
986992
config.outConfs[0].inPlace = -1;
987993

988994
impl_desc_type impl_type;
989-
// if (mayiuse(cpu::x64::avx512_common)) {
990-
// impl_type = impl_desc_type::jit_avx512;
991-
// } else if (mayiuse(cpu::x64::avx2)) {
992-
// impl_type = impl_desc_type::jit_avx2;
993-
// } else if (mayiuse(cpu::x64::sse41)) {
994-
// impl_type = impl_desc_type::jit_sse42;
995-
// } else {
996-
// impl_type = impl_desc_type::ref;
997-
// }
998-
impl_type = impl_desc_type::ref;
999-
1000-
if (false && mayiuse(cpu::x64::sse41)) {
1001-
// optimzed implementation
995+
if (enforceRef) {
996+
impl_type = impl_desc_type::ref;
997+
} else if (mayiuse(cpu::x64::avx512_common)) {
998+
impl_type = impl_desc_type::jit_avx512;
999+
} else if (mayiuse(cpu::x64::avx2)) {
1000+
impl_type = impl_desc_type::jit_avx2;
1001+
} else if (mayiuse(cpu::x64::sse41)) {
1002+
impl_type = impl_desc_type::jit_sse42;
1003+
} else {
1004+
impl_type = impl_desc_type::ref;
1005+
}
1006+
1007+
if (!enforceRef && mayiuse(cpu::x64::sse41)) {
1008+
// optimized implementation
10021009
auto dataFormat = memory::format_tag::nhwc;
10031010
auto offFormat = memory::format_tag::nchw;
10041011
auto weiFormat = group > 1 ? mayiuse(avx512_common) ? memory::format_tag::gOIhw16i16o : memory::format_tag::gOIhw8i8o
@@ -1107,13 +1114,15 @@ void MKLDNNDeformableConvolutionNode::createPrimitive() {
11071114

11081115
jcp.nthr = dnnl_get_max_threads();
11091116

1110-
// if (mayiuse(cpu::x64::avx512_common)) {
1111-
// def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1112-
// } else if (mayiuse(cpu::x64::avx2)) {
1113-
// def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1114-
// } else if (mayiuse(cpu::x64::sse41)) {
1115-
// def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1116-
// }
1117+
if (enforceRef) {
1118+
return;
1119+
} else if (mayiuse(cpu::x64::avx512_common)) {
1120+
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx512_common>(jcp));
1121+
} else if (mayiuse(cpu::x64::avx2)) {
1122+
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::avx2>(jcp));
1123+
} else if (mayiuse(cpu::x64::sse41)) {
1124+
def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32<cpu::x64::sse41>(jcp));
1125+
}
11171126

11181127
if (def_conv_kernel)
11191128
def_conv_kernel->create_ker();

inference-engine/src/mkldnn_plugin/nodes/mkldnn_def_conv_node.h

+1
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ class MKLDNNDeformableConvolutionNode : public MKLDNNNode {
7979
bool canBeInPlace() const override {
8080
return false;
8181
}
82+
bool enforceRef = false;
8283

8384
InferenceEngine::Precision getRuntimePrecision() const override;
8485

0 commit comments

Comments
 (0)