diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index e36a0f008821..a13337b122c3 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -329,6 +329,7 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { case mkldnn_nchw: case mkldnn_nhwc: case mkldnn_chwn: + case mkldnn_nChw4c: case mkldnn_nChw8c: case mkldnn_nChw16c: return mkldnn_nchw; @@ -338,6 +339,7 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { case mkldnn_iohw: case mkldnn_oIhw8i: case mkldnn_oIhw16i: + case mkldnn_OIhw4i4o: case mkldnn_OIhw8i8o: case mkldnn_hwio_s8s8: case mkldnn_OIhw16i16o: @@ -376,6 +378,7 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { case mkldnn_giohw: case mkldnn_hwigo: case mkldnn_hwigo_s8s8: + case mkldnn_gOIhw4i4o: case mkldnn_gOIhw8i8o: case mkldnn_gOIhw16i16o: case mkldnn_gOIhw4i16o4i: @@ -383,6 +386,7 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) { case mkldnn_gOIhw8i16o2i: case mkldnn_gOIhw8o16i2o: case mkldnn_gOIhw8o8i: + case mkldnn_gOIhw4o4i: case mkldnn_gOIhw16o16i: case mkldnn_gIOhw16o16i: case mkldnn_gOihw8o: