diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/remove_converts.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/remove_converts.cpp index 64885731c8ec7e..dca472d2e86e3d 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/remove_converts.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/remove_converts.cpp @@ -10,24 +10,26 @@ #include "snippets/op/convert_saturation.hpp" ov::intel_cpu::pass::RemoveConverts::RemoveConverts() { + using namespace ov::pass::pattern; MATCHER_SCOPE(RemoveConverts); - auto parent_convert_wrap = ov::pass::pattern::wrap_type(); - auto child_convert_wrap = ov::pass::pattern::wrap_type({ parent_convert_wrap }); + auto input_m = any_input(type_matches(ov::element::f32)); + auto parent_convert_m = wrap_type({input_m}, type_matches(ov::element::bf16)); + auto child_convert_wrap = wrap_type({parent_convert_m}, type_matches(ov::element::f32)); auto callback = [=](ov::pass::pattern::Matcher& m) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::RemoveConverts") const auto& pm = m.get_pattern_value_map(); - const auto parent_convert = pm.at(parent_convert_wrap).get_node_shared_ptr(); + const auto parent_convert = pm.at(parent_convert_m).get_node_shared_ptr(); const auto child_convert = pm.at(child_convert_wrap).get_node_shared_ptr(); - if ( - (parent_convert->get_input_element_type(0) != element::f32) || - (parent_convert->get_output_target_inputs(0).size() != 1ull) || - (parent_convert->get_output_element_type(0) != element::bf16) || - (child_convert->get_output_element_type(0) != element::f32)) { - return false; - } - replace_output_update_name(child_convert->output(0), parent_convert->get_input_source_output(0)); + const auto& parent_convert_consumers = parent_convert->get_output_target_inputs(0); + for (const auto& input : parent_convert_consumers) { + const auto node = input.get_node(); + if (ov::is_type(node) && + node->get_output_element_type(0) == child_convert->get_output_element_type(0)) { + replace_output_update_name(node->output(0), parent_convert->input_value(0)); + } + } return true; };