Skip to content

Commit

Permalink
cpu: x64: matmul: enable binary po bcast per_mb_w
Browse files: browse the repository at this point in the history
  • Loading branch information
tczeszun authored and tprimak committed Dec 9, 2021
1 parent c4dc38a commit be261ab
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
{broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::no_broadcast})))
return status::unimplemented;

Expand Down
9 changes: 7 additions & 2 deletions src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator {
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::no_broadcast};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Xbyak::Zmm(1).getIdx()), this->r14,
Expand All @@ -70,15 +71,18 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator {

using namespace dnnl::impl::cpu::binary_injector_utils;
std::tie(with_binary_per_oc_bcast_, with_binary_per_oc_sp_bcast_,
with_binary_channel_bcast_, with_binary_no_bcast_)
with_binary_channel_bcast_, with_binary_per_mb_w_bcast_,
with_binary_no_bcast_)
= bcast_strategies_present_tup(brg.attr->post_ops_.entry_,
dst_md_wrapper, broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::no_broadcast);
handle_binary_po_offset_ = with_binary_per_oc_bcast_
|| with_binary_per_oc_sp_bcast_
|| with_binary_channel_bcast_ || with_binary_no_bcast_;
|| with_binary_channel_bcast_ || with_binary_per_mb_w_bcast_
|| with_binary_no_bcast_;
}
use_ils = brg.brgattr.use_interleave_stores;
}
Expand Down Expand Up @@ -149,6 +153,7 @@ struct jit_brgemm_amx_uker_base_t : public jit_generator {
bool with_binary_per_oc_bcast_ = false;
bool with_binary_per_oc_sp_bcast_ = false;
bool with_binary_channel_bcast_ = false;
bool with_binary_per_mb_w_bcast_ = false;
bool with_binary_no_bcast_ = false;

size_t reg_b_offset_ = 0;
Expand Down
9 changes: 7 additions & 2 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::no_broadcast};
const binary_injector::rhs_arg_static_params_t rhs_sp {
static_cast<size_t>(Xbyak::Zmm(1).getIdx()), this->r14,
Expand All @@ -75,15 +76,18 @@ struct jit_brgemm_kernel_t : public jit_generator {

using namespace dnnl::impl::cpu::binary_injector_utils;
std::tie(with_binary_per_oc_bcast_, with_binary_per_oc_sp_bcast_,
with_binary_channel_bcast_, with_binary_no_bcast_)
with_binary_channel_bcast_, with_binary_per_mb_w_bcast_,
with_binary_no_bcast_)
= bcast_strategies_present_tup(brg.attr->post_ops_.entry_,
dst_md_wrapper, broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::no_broadcast);
handle_binary_po_offset_ = with_binary_per_oc_bcast_
|| with_binary_per_oc_sp_bcast_
|| with_binary_channel_bcast_ || with_binary_no_bcast_;
|| with_binary_channel_bcast_ || with_binary_per_mb_w_bcast_
|| with_binary_no_bcast_;
}
}

Expand Down Expand Up @@ -195,6 +199,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
bool with_binary_per_oc_bcast_ = false;
bool with_binary_per_oc_sp_bcast_ = false;
bool with_binary_channel_bcast_ = false;
bool with_binary_per_mb_w_bcast_ = false;
bool with_binary_no_bcast_ = false;

Xbyak::Opmask ld_full_mask = Xbyak::Opmask(2);
Expand Down
1 change: 1 addition & 0 deletions src/cpu/x64/jit_gemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ jit_pp_kernel_t<isa>::jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::no_broadcast};
const binary_injector::static_params_t binary_static_params {
reg_binary_inj_param_, enabled_bcast_strategy,
Expand Down
15 changes: 10 additions & 5 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,20 @@ bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
const auto &post_ops = attr.post_ops_;
const auto ndims = dst_d.ndims();

bool is_binary_po_per_oc_sp_bcast = false;
bool is_binary_po_channel_bcast = false;
std::tie(is_binary_po_per_oc_sp_bcast, is_binary_po_channel_bcast)
bool is_binary_po_per_oc_sp_bcast {};
bool is_binary_po_channel_bcast {};
bool is_binary_po_per_mb_w_bcast {};
std::tie(is_binary_po_per_oc_sp_bcast, is_binary_po_channel_bcast,
is_binary_po_per_mb_w_bcast)
= binary_injector_utils::bcast_strategies_present_tup(
post_ops.entry_, dst_d,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_mb_spatial);
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w);
const bool supported_binary_bcast
= IMPLICATION(is_binary_po_per_oc_sp_bcast, ndims < 4)
&& IMPLICATION(is_binary_po_channel_bcast, ndims == 4);
&& IMPLICATION(is_binary_po_channel_bcast, ndims == 4)
&& IMPLICATION(is_binary_po_per_mb_w_bcast, ndims == 4);
return supported_binary_bcast
&& injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(),
{sum, eltwise, binary}, post_ops, &dst_d,
Expand All @@ -66,6 +70,7 @@ bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::no_broadcast}));
}

Expand Down

0 comments on commit be261ab

Please sign in to comment.