From fa4364057891fdec528d9442c88d0715306bff2d Mon Sep 17 00:00:00 2001 From: Andrey Kalinin Date: Wed, 25 Oct 2023 17:36:39 -0700 Subject: [PATCH] x64: brgemm unrolled kernel: update output prefetching --- src/cpu/x64/brgemm/brgemm.cpp | 3 +++ src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp | 1 - src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp | 2 +- src/cpu/x64/jit_brgemm_conv_utils.cpp | 8 ++++---- src/cpu/x64/jit_brgemm_inner_product_utils.cpp | 2 +- src/cpu/x64/matmul/brgemm_matmul.cpp | 2 +- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index ae1680c7146..dccc657032f 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -451,6 +451,9 @@ status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) { if (brgattr.hint_innermost_loop != brgemm_innermost_undef) brg->innermost_loop = brgattr.hint_innermost_loop; + if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf0 + && brg->prfC.dist0 < 0) + brg->prfC.dist0 = 0; if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf1 && brg->prfC.dist1 < 0) brg->prfC.dist1 = 0; diff --git a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp index 382ea4029d8..64fcfb7c00b 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp @@ -1014,7 +1014,6 @@ void jit_brgemm_amx_uker_base_t::uni_prefetch( if (for_write) { switch (pft) { case brgemm_prf0: prefetchw(addr); break; - case brgemm_prf1: prefetchwt1(addr); break; default: break; } } else { diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp index 3fae80b7322..ba65c960b27 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp @@ -1892,7 +1892,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, if (is_amx(isa) && (/* heuristic */ jcp.kw_sets == 1 && jcp.iw < 256)) { jcp.use_M_mask = 0; - jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; // assuming 2x2 decomposition in amx brgemm kernel // and overlap of input by kw diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index a735a8fe134..971aa73b78b 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -2032,7 +2032,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, bool use_inversion, jcp.use_M_mask = jcp.is_os_blocking ? 2 : 0; jcp.use_uker = true; jcp.use_interleave_stores = true; - jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; // assuming 2x2 decomposition in amx brgemm kernel // and overlap of input by kw const auto bd_blocking = 2 * jcp.amx_h; @@ -2067,7 +2067,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, bool use_inversion, if (is_amx(isa) && jcp.ow < (8 * 1024)) { jcp.use_uker = true; jcp.use_interleave_stores = true; - jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; } try_exec_type_res = try_exec_type(); @@ -2339,7 +2339,7 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa, return status::unimplemented; if (jcp.use_uker) - jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; if (!jcp.wei_plain) CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md)); CHECK(attr.set_default_formats(&dst_md)); @@ -3073,7 +3073,7 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp, jcp.od_block = utils::saturate(1, jcp.od, od_block_limit); jcp.use_interleave_stores = false; - jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + jcp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; jcp.amx_tile_load_xx = false; if (one_of(jcp.harness, harness_2d_reduction, harness_3d_reduction)) { diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index f9c21197f4b..8958d64b33d 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -1417,7 +1417,7 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, jbgp.use_uker = true; jbgp.use_interleave_stores = jbgp.use_uker; if (jbgp.use_uker) - jbgp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + jbgp.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; CHECK(set_or_check_tags()); CHECK(attr.set_default_formats(&dst_md)); diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp index 82ede24f36b..d65614b8154 100644 --- a/src/cpu/x64/matmul/brgemm_matmul.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul.cpp @@ -174,7 +174,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) { brgattr.hint_expected_B_size = vN * vK * bs; brgattr.hint_expected_C_size = vM * vN * bs; brgattr.hint_innermost_loop = brgemm_innermost_undef; - brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; + brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf0; } CHECK(brgemm_desc_set_attr(&brg, brgattr));