diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp index f3e9291e0ee..fd4c28ed9b3 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp @@ -1566,14 +1566,19 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, && everyone_is(128, jcp.oh, jcp.ow) && everyone_is(3, jcp.kh, jcp.kw) && everyone_is(2, jcp.stride_h, jcp.stride_w)) - || (jcp.ic == 256 && jcp.oc == 512 - && everyone_is(129, jcp.ih, jcp.iw) - && everyone_is(64, jcp.oh, jcp.ow) - && everyone_is(3, jcp.kh, jcp.kw) - && everyone_is(2, jcp.stride_h, jcp.stride_w)) || (jcp.ic == 256 && jcp.oc == 512 && jcp.ih == 49 && jcp.iw == 41 && jcp.oh == 23 && jcp.ow == 19 && everyone_is(5, jcp.kh, jcp.kw) + && everyone_is(2, jcp.stride_h, jcp.stride_w)) + || (jcp.ic == 64 && jcp.oc == 128 + && everyone_is(14, jcp.ih, jcp.iw) + && everyone_is(7, jcp.oh, jcp.ow) + && everyone_is(4, jcp.kh, jcp.kw) + && everyone_is(2, jcp.stride_h, jcp.stride_w)) + || (jcp.ic == 1 && jcp.oc == 64 + && everyone_is(28, jcp.ih, jcp.iw) + && everyone_is(14, jcp.oh, jcp.ow) + && everyone_is(4, jcp.kh, jcp.kw) && everyone_is(2, jcp.stride_h, jcp.stride_w))); VDISPATCH_CONV_IC(!(is_f32 && is_regression_shape), "implementation skipped due to low performance");