Skip to content

Commit

Permalink
RemoveConverts: supported a case with several consumers
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Feb 26, 2024
1 parent d68f463 commit af70da2
Showing 1 changed file with 13 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,26 @@
#include "snippets/op/convert_saturation.hpp"

ov::intel_cpu::pass::RemoveConverts::RemoveConverts() {
using namespace ov::pass::pattern;
MATCHER_SCOPE(RemoveConverts);
auto parent_convert_wrap = ov::pass::pattern::wrap_type<snippets::op::ConvertSaturation>();
auto child_convert_wrap = ov::pass::pattern::wrap_type<snippets::op::ConvertSaturation>({ parent_convert_wrap });
auto input_m = any_input(type_matches(ov::element::f32));
auto parent_convert_m = wrap_type<snippets::op::ConvertSaturation>({input_m}, type_matches(ov::element::bf16));
auto child_convert_wrap = wrap_type<snippets::op::ConvertSaturation>({parent_convert_m}, type_matches(ov::element::f32));

auto callback = [=](ov::pass::pattern::Matcher& m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::RemoveConverts")
const auto& pm = m.get_pattern_value_map();
const auto parent_convert = pm.at(parent_convert_wrap).get_node_shared_ptr();
const auto parent_convert = pm.at(parent_convert_m).get_node_shared_ptr();
const auto child_convert = pm.at(child_convert_wrap).get_node_shared_ptr();
if (
(parent_convert->get_input_element_type(0) != element::f32) ||
(parent_convert->get_output_target_inputs(0).size() != 1ull) ||
(parent_convert->get_output_element_type(0) != element::bf16) ||
(child_convert->get_output_element_type(0) != element::f32)) {
return false;
}

replace_output_update_name(child_convert->output(0), parent_convert->get_input_source_output(0));
const auto& parent_convert_consumers = parent_convert->get_output_target_inputs(0);
for (const auto& input : parent_convert_consumers) {
const auto node = input.get_node();
if (ov::is_type<snippets::op::ConvertSaturation>(node) &&
node->get_output_element_type(0) == child_convert->get_output_element_type(0)) {
replace_output_update_name(node->output(0), parent_convert->input_value(0));
}
}
return true;
};

Expand Down

0 comments on commit af70da2

Please sign in to comment.