diff --git a/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp index 273e134c86ae6f..4d54950fed49b2 100644 --- a/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/reshape_sequence_fusion.hpp @@ -18,11 +18,11 @@ class TRANSFORMATIONS_API ReshapeSequenceFusion; /** * @ingroup ie_transformation_common_api - * @brief ReshpaeSequenceFusion fuses sequence of Reshape operation into single Reshape + * @brief ReshapeSequenceFusion fuses sequence of Reshape operation into single Reshape or eliminates full redundant sequence */ class ngraph::pass::ReshapeSequenceFusion: public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; - ReshapeSequenceFusion(); + ReshapeSequenceFusion(bool use_shape_for_elimination = true); }; diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index d4a760f845069c..0b8258c3255a46 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -153,7 +153,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptradd_matcher(); common_fusions->add_matcher(); common_fusions->add_matcher(); - common_fusions->add_matcher(); + common_fusions->add_matcher(m_use_shapes); common_fusions->set_name("ngraph::pass::CommonFusions"); manager.register_pass(); diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index 3d6d8c52f03385..22fb21b074a8e0 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -102,16 +102,25 @@ static bool eliminate_reshape_v1(const std::shared_ptr& node) { if (ov::as_type_ptr(input_node) || ov::as_type_ptr(input_node) || ov::as_type_ptr(input_node)) { + if (input_node->get_output_target_inputs(0).size() != 1) + return false; + auto shape = node->get_output_shape(0); - std::vector vi; - vi.assign(shape.begin(), shape.end()); - auto pat = opset3::Constant::create(element::i64, Shape{vi.size()}, vi); - auto new_reshape = - make_shared(input.get_node()->input_value(0), pat, false); - new_reshape->set_friendly_name(node->get_friendly_name()); - copy_runtime_info({input_node, node}, new_reshape); - replace_node(node, new_reshape); - return true; + + // remove interchangeable nodes + if (input_node->get_input_partial_shape(0).is_static() && input_node->get_input_shape(0) == shape) { + return replace_output_update_name(node->output(0), input_node->input_value(0)); + } else { + std::vector vi; + vi.assign(shape.begin(), shape.end()); + auto pat = opset3::Constant::create(element::i64, Shape{vi.size()}, vi); + auto new_reshape = + make_shared(input.get_node()->input_value(0), pat, false); + new_reshape->set_friendly_name(node->get_friendly_name()); + copy_runtime_info({input_node, node}, new_reshape); + replace_node(node, new_reshape); + return true; + } } return false; diff --git a/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp index 6516e14eca6016..f95adb43f1011d 100644 --- a/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/reshape_sequence_fusion.cpp @@ -55,7 +55,7 @@ bool has_valid_pattern(const ov::Output& node_out) { } } // namespace -ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() { +ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion(bool use_shape_for_elimination) { MATCHER_SCOPE(ReshapeSequenceFusion); auto reshape_input = pattern::any_input(); auto reshape_a_pattern = pattern::wrap_type(); @@ -87,9 +87,21 @@ ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() { input = node->input_value(0); } - reshape->input(0).replace_source_output(input); - copy_runtime_info(nodes, reshape); - return false; + // remove redundant reshapes + bool replaced = false; + if (use_shape_for_elimination && input.get_partial_shape().is_static() && reshape->get_output_partial_shape(0).is_static() && + input.get_shape() == reshape->get_output_shape(0)) { + // in case if elimination is not allowed we still can eliminate all transposes except last one + replaced = replace_output_update_name(reshape->output(0), input); + } + + if (!replaced) { + reshape->input(0).replace_source_output(input); + copy_runtime_info(nodes, reshape); + return false; // because root node wasn't replaced + } + + return true; }; auto m = std::make_shared(reshape_b, matcher_name); diff --git a/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp b/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp index e3da16039f88aa..75a7809321b824 100644 --- a/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp +++ b/src/plugins/intel_cpu/src/ngraph_transformations/convert_to_cpu_specific_opset.hpp @@ -18,6 +18,7 @@ #include "transformations/convert_precision.hpp" #include "transformations/utils/utils.hpp" #include "rnn_sequences_optimization.hpp" +#include "transformations/common_optimizations/reshape_sequence_fusion.hpp" namespace MKLDNNPlugin { @@ -34,6 +35,8 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr &nGraphF if (!ngraph::op::util::has_op_with_type(nGraphFunc)) { manager.register_pass(); } + // after transformation "MoveEltwiseUpThroughDataMov" there can be Reshape sequences that should be eliminated or fused + manager.register_pass(); manager.register_pass(); manager.register_pass(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }}); diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp index d8244a4239b8e7..b38ab845172c8d 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp @@ -140,17 +140,48 @@ TEST(nop_elimination, squeeze_reshape_elimination_check_info) { pass_manager.register_pass(); pass_manager.run_passes(f); - bool reshape_is_missing = true; + bool movement_are_missing = true; for (auto node : f->get_ops()) { - if (node->get_friendly_name() == "reshape") { - reshape_is_missing = false; - ASSERT_TRUE(std::dynamic_pointer_cast(node)); - auto original_names = ngraph::getFusedNamesVector(node); - sort(original_names.begin(), original_names.end()); - ASSERT_EQ(original_names, std::vector({"reshape", "squeeze"})); + if (node->get_friendly_name() == "reshape" || node->get_friendly_name() == "squeeze") { + movement_are_missing = false; } } - ASSERT_FALSE(reshape_is_missing); + ASSERT_TRUE(movement_are_missing); +} + +TEST(nop_elimination, squeeze_unsqueeze_elimination) { + std::shared_ptr f; + { + auto arg = std::make_shared(element::f32, PartialShape{8, 16, 1, 3}); + + auto relu = std::make_shared(arg); + relu->set_friendly_name("relu"); + + auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2}); + auto squeeze = std::make_shared(relu, squeeze_axes); + squeeze->set_friendly_name("squeeze"); + + auto unsqueeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2}); + auto unsqueeze = std::make_shared(squeeze, unsqueeze_axes); + unsqueeze->set_friendly_name("unsqueeze"); + + auto abs = std::make_shared(unsqueeze); + + f = std::make_shared(NodeVector{abs}, ParameterVector{arg}); + } + + pass::Manager pass_manager; + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.run_passes(f); + + bool movement_are_missing = true; + for (auto node : f->get_ops()) { + if (node->get_friendly_name() == "squeeze" || node->get_friendly_name() == "unsqueeze") { + movement_are_missing = false; + } + } + ASSERT_TRUE(movement_are_missing); } TEST(nop_elimination, reshape_elimination_v1_dynamic) { @@ -165,6 +196,33 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) { ASSERT_TRUE(count_ops_of_type(f) == 1); } +TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) { + std::shared_ptr f; + { + auto arg = std::make_shared(element::f32, PartialShape{8, 16, 1, 3}); + + auto reshape_1_shape = opset4::Constant::create(element::i64, Shape{2}, {128, 3}); + auto reshape_1 = std::make_shared(arg, reshape_1_shape, false); + reshape_1->set_friendly_name("reshape_1"); + + auto reshape_2_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3}); + auto reshape_2 = std::make_shared(reshape_1, reshape_2_shape, false); + reshape_2->set_friendly_name("reshape_2"); + + auto relu = std::make_shared(reshape_1); + relu->set_friendly_name("relu"); + + f = std::make_shared(NodeVector{reshape_2, relu}, ParameterVector{arg}); + } + + pass::Manager pass_manager; + pass_manager.register_pass(); + pass_manager.register_pass(); + pass_manager.run_passes(f); + + ASSERT_TRUE(count_ops_of_type(f) == 2); +} + TEST(nop_elimination, concat_elimination_single_node) { int64_t a = 0; auto A = make_shared(element::f32, Shape{2, 3}); diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp index 01ced7b3ac47ba..9c2c271b9bf03d 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/reshape_sequence_fusion.cpp @@ -305,3 +305,21 @@ TEST_F(TransformationTestsF, ReshapeSequenceFusionNeg5_special_zero_false) { manager.register_pass(); } } + +TEST_F(TransformationTestsF, ReshapeSequenceFusionEliminate) { + { + auto data = std::make_shared(element::f32, Shape{1, 2, 3}); + auto relu = std::make_shared(data); + auto a = reshape(relu, {2, 3}); + auto b = reshape(a, {1, 2, 3}); + function = std::make_shared(OutputVector{b}, ParameterVector{data}); + + manager.register_pass(); + } + + { + auto data = std::make_shared(element::f32, Shape{1, 2, 3}); + auto relu = std::make_shared(data); + function_ref = std::make_shared(OutputVector{relu}, ParameterVector{data}); + } +}