Skip to content

Commit

Permalink
graph: backend: compiler: fusion: update infer slice logic for matmul…
Browse files Browse the repository at this point in the history
… with explicit broadcast semantics
  • Loading branch information
Yun-Fly authored and vpirogov committed Nov 2, 2023
1 parent a2ec0a0 commit 5476ef7
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
38 changes: 32 additions & 6 deletions src/graph/backend/graph_compiler/core/src/ops/matmul_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,7 @@ void matmul_core_op_t::infer_slice_ranges(
}

if (!known_ranges_map[0].empty() && known_ranges_map[1].empty()) {
// implicit broadcast semantic
if (inp_plain_size < wei_plain_size) {
stat_map.append_ops_by_status(this, infer_status_code::RETRY);
return;
Expand All @@ -1151,16 +1152,29 @@ void matmul_core_op_t::infer_slice_ranges(
if (std::find(
blocking_axis.B_bs.begin(), blocking_axis.B_bs.end(), i)
!= blocking_axis.B_bs.end()) {
wei_slice[i]
= inp_slice[blocking_axis.A_bs[blocking_axis.A_bs.size()
- 1 - bs_cnt]];
int bs_idx_inp
= blocking_axis
.A_bs[blocking_axis.A_bs.size() - 1 - bs_cnt];
// explicit broadcast semantic
if (inp_dims[bs_idx_inp] < wei_dims[i]) {
stat_map.append_ops_by_status(
this, infer_status_code::RETRY);
return;
} else if (inp_dims[bs_idx_inp] == wei_dims[i]) {
wei_slice[i] = inp_slice[bs_idx_inp];
} else {
COMPILE_ASSERT(
wei_dims[i] == 1, "broadcast weight is expected")
wei_slice[i] = std::make_pair(expr(0), expr(1));
}
bs_cnt++;
} else {
wei_slice[i] = std::make_pair(expr(0), wei_dims_expr[i]);
}
}
}
if (known_ranges_map[0].empty() && !known_ranges_map[1].empty()) {
// implicit broadcast semantic
if (inp_plain_size > wei_plain_size) {
stat_map.append_ops_by_status(this, infer_status_code::RETRY);
return;
Expand All @@ -1172,9 +1186,21 @@ void matmul_core_op_t::infer_slice_ranges(
if (std::find(
blocking_axis.A_bs.begin(), blocking_axis.A_bs.end(), i)
!= blocking_axis.A_bs.end()) {
inp_slice[i]
= wei_slice[blocking_axis.B_bs[blocking_axis.B_bs.size()
- 1 - bs_cnt]];
int bs_idx_wei
= blocking_axis
.B_bs[blocking_axis.B_bs.size() - 1 - bs_cnt];
// explicit broadcast semantic
if (wei_dims[bs_idx_wei] < inp_dims[i]) {
stat_map.append_ops_by_status(
this, infer_status_code::RETRY);
return;
} else if (wei_dims[bs_idx_wei] == inp_dims[i]) {
inp_slice[i] = wei_slice[bs_idx_wei];
} else {
COMPILE_ASSERT(
inp_dims[i] == 1, "broadcast input is expected")
inp_slice[i] = std::make_pair(expr(0), expr(1));
}
bs_cnt++;
} else {
inp_slice[i] = std::make_pair(expr(0), inp_dims_expr[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2502,3 +2502,39 @@ TEST(GCCore_CPU_graph_mixed_partition_cpp,
)";
EXPECT_EQ(ss.str(), expected_str_spr);
}

// Regression test: batched matmul where the weight input broadcasts along a
// batch dimension (explicit broadcast semantics). Verifies that slice
// inference keeps the matmul from being fused into the weight-side reorder.
TEST(GCCore_CPU_graph_mixed_partition_cpp, InferSliceForBMMWithBroadcast) {
    SET_THREADS_OR_SKIP(8);

    sc_graph_t graph;
    // Activation: full batch dims [8, 16, ...].
    auto data_in = graph.make_input({graph_tensor::make(
            {8, 16, 64, 128}, sc_data_format_t(), sc_data_type_t::u8())});
    // Weight: second batch dim is 1, i.e. explicitly broadcast against the
    // activation's 16.
    auto wei_in = graph.make_input({graph_tensor::make(
            {8, 1, 128, 64}, sc_data_format_t(), sc_data_type_t::s8())});

    any_map_t mm_attrs({{"transpose_a", false}, {"transpose_b", false},
            {"output2d", false}, {"use_mmm", false}});
    // Build the batched matmul and pin a fixed blocking config so the
    // resulting graph layout is deterministic for the string comparison below.
    auto mm_op = graph.make("matmul_core",
            {data_in->get_outputs()[0], wei_in->get_outputs()[0]}, {},
            mm_attrs);
    ops::matmul_core_config_t mm_cfg = {32, 32, 32};
    mm_op->dyn_cast<ops::matmul_core_op_t>()->set_config(
            reflection::general_object_t::make(mm_cfg));
    graph.make_output({mm_op->get_outputs()[0]});

    auto test_ctx = std::make_shared<context_t>(*get_test_ctx());
    test_ctx->flags_.use_cost_model_ = true;

    graph_driver_before_fusion(graph, test_ctx);
    mixed_partition(graph, test_ctx);

    std::stringstream out;
    print_graph(graph, out, true);
    // Expected: the weight reorder stays a separate op — the matmul must not
    // be fused into it along the broadcast weight side.
    std::string ref_graph
            = R"(graph(v0: u8[8, 16, 64, 128], v1: s8[8, 1, 128, 64]) -> [v2: s32[8, 16, 64, 64]] {
  [v3: s8[8, 1, 2, 4, 8, 32, 4]] = reorder(v1)
  [v2: s32[8, 16, 64, 64]] = matmul_core(v0, v3)
}
)";
    EXPECT_EQ(out.str(), ref_graph);
}

0 comments on commit 5476ef7

Please sign in to comment.