Skip to content

Commit

Permalink
[CPU] MergeTransposeReorder extending
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 16, 2024
1 parent a240ae8 commit f3802b1
Show file tree
Hide file tree
Showing 7 changed files with 635 additions and 206 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,14 @@ TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::s
TRANSFORMATIONS_API std::shared_ptr<Node> node_to_get_shape_value_of_indices_from_shape_node(
const std::shared_ptr<Node>& shape_node,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {});
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {},
const ov::element::Type& shape_path_precision = ov::element::i64);

TRANSFORMATIONS_API std::shared_ptr<Node> node_to_get_shape_value_of_indices_from_shape_source(
const Output<Node>& shape_source,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {});
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {},
const ov::element::Type& shape_path_precision = ov::element::i64);

TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output<Node>& node);

Expand Down
17 changes: 11 additions & 6 deletions src/common/transformations/src/transformations/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,10 @@ std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& nod
std::shared_ptr<ov::Node> node_to_get_shape_value_of_indices_from_shape_node(
const std::shared_ptr<ov::Node>& shape_node,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from) {
const auto& indices_op = v0::Constant::create(ov::element::i64, {indices.size()}, indices);
const auto& axis_op = v0::Constant::create(ov::element::i64, {}, {0});
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from,
const ov::element::Type& shape_path_precision) {
const auto& indices_op = v0::Constant::create(shape_path_precision, {indices.size()}, indices);
const auto& axis_op = v0::Constant::create(shape_path_precision, {}, {0});
auto op = make_try_fold<v7::Gather>(shape_node, indices_op, axis_op);
if (!copy_rt_info_from.empty())
ov::copy_runtime_info(copy_rt_info_from, {op, indices_op, axis_op});
Expand All @@ -252,11 +253,15 @@ std::shared_ptr<ov::Node> node_to_get_shape_value_of_indices_from_shape_node(
std::shared_ptr<ov::Node> node_to_get_shape_value_of_indices_from_shape_source(
const ov::Output<ov::Node>& shape_source,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from) {
const auto& shape_node = make_try_fold<v3::ShapeOf>(shape_source);
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from,
const ov::element::Type& shape_path_precision) {
const auto& shape_node = make_try_fold<v3::ShapeOf>(shape_source, shape_path_precision);
if (!copy_rt_info_from.empty())
ov::copy_runtime_info(copy_rt_info_from, shape_node);
return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices, copy_rt_info_from);
return node_to_get_shape_value_of_indices_from_shape_node(shape_node,
indices,
copy_rt_info_from,
shape_path_precision);
}

bool shapes_equal_except_dynamic_expected_batch(const ov::PartialShape& expected, const ov::PartialShape& actual) {
Expand Down
468 changes: 325 additions & 143 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions src/plugins/intel_cpu/src/graph_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,33 @@ class GraphOptimizer {
void FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph);
void FuseClampAndFakeQuantize(Graph &graph);
void MergeTransposeAndReorder(Graph &graph);
void MergeReorderAndTranspose(Graph &graph);
void reshapeRnnSeq(Graph &graph);
void RemoveSameConvert(Graph &graph);
void RemoveMemoryInputConvert(Graph &graph);
void RemoveConvertMemoryOutput(Graph &graph);
void MatchSdpaKvCache(Graph &graph);

// Checks that after the sequential execution of the Transpose and Reorder nodes
// the order of the elements in memory (the physical layout) does not change.
// NOTE(review): the exact meaning of the four order vectors is not visible from this declaration —
// presumably transposeOrder is the Transpose permutation, layoutOrder is the order of the layout
// the Transpose works on, and reorderInOrder/reorderOutOrder are the orders of the Reorder
// input/output descriptors; confirm against the definition in graph_optimizer.cpp.
bool checkAscendingSummaryOrder(const VectorDims& transposeOrder,
const VectorDims& layoutOrder,
const VectorDims& reorderInOrder,
const VectorDims& reorderOutOrder);
// Merges Transpose -> Reshape(optional) -> Reorder sequences whose permutations are opposite to each other.
// The reverse sequence, Reorder -> Reshape(optional) -> Transpose, is supported too.
// Reshape support has the following limitations:
// - direct order: only a Reshape that splits one dimension into 2 consecutive ones is supported
// - reverse order: only a Reshape that fuses 2 consecutive dimensions into one is supported
// Example:
// chain [physical layout: NCHW, logical layout: NCHW] -> Transpose(order=0312) -> [physical layout: NWCH, logical layout: NCHW] ->
// Reorder(nchw->nhwc) -> [physical layout: NCHW, logical layout: NHWC] can be replaced with Reorder(nchw->nhwc; isOptimized=true)
// which just reinterprets the layout without physically changing the memory.
// NOTE(review): presumably reshapeNode is null when the sequence has no Reshape, and reverseOrder == true
// selects the Reorder -> Reshape(optional) -> Transpose direction — confirm against the definition.
void mergeTransposeReshapeReorder(Graph& graph,
const NodePtr& transposeNode,
const NodePtr& reshapeNode,
const NodePtr& reorderNode,
const bool reverseOrder);
};

} // namespace intel_cpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,15 @@ INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseTransposeAndReorderTest1, fuseTranspos
|Input | |Input |
--------- ---------
| |
| -------------
--------- | ----------- |
|Reorder| | |Transpose| |
--------- | ----------- |
| | | |
--------- | ----------- |
|Transpose| | |Reorder| |
--------- | ----------- |
| |-------------|
|------------ | |-------------|
| ----------- | | ----------- |
| |Reorder| | | |Transpose| |
| ----------- | | ----------- |
| | | | | |
| ----------- | | ----------- |
| |Transpose| | | |Reorder| |
| ----------- | | ----------- |
|------------ | |-------------|
| |
-------- --------
| |
Expand Down Expand Up @@ -223,7 +223,7 @@ void FuseTransposeAndReorderTest2::create_model() {

TEST_P(FuseTransposeAndReorderTest2, CompareWithRefs) {
run();
check_transpose_count(1);
check_transpose_count(0);
}

INSTANTIATE_TEST_SUITE_P(smoke_Basic, FuseTransposeAndReorderTest2, fuseTransposeAndReorderCommonParams, FuseTransposeAndReorderTest::getTestCaseName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "common_test_utils/common_utils.hpp"
#include "openvino/opsets/opset10.hpp"
#include "ov_models/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "transformations/utils/utils.hpp"

using namespace CPUTestUtils;
using namespace ov::test;

namespace CPUSubgraphTestsDefinitions {
/// Creates a `NodeType` operation from the forwarded constructor arguments and
/// returns an Add node that sums its output with a bias constant.
///
/// The bias constant has the element type of the created node's first output and an
/// empty (scalar) shape. The data vector passed to makeConstant is empty and the last
/// argument is `true` — presumably this asks the builder to generate the values itself;
/// confirm against ngraph::builder::makeConstant.
///
/// The created `NodeType` node stays reachable as input 0 of the returned Add.
template <typename NodeType, typename... Args>
static std::shared_ptr<ov::Node> make_layer_with_bias(Args&&... args) {
    const auto layer = std::make_shared<NodeType>(std::forward<Args>(args)...);
    const auto bias_value = ngraph::builder::makeConstant(layer->get_output_element_type(0),
                                                          ov::Shape{},
                                                          std::vector<float>{},
                                                          true);
    return std::make_shared<ov::opset10::Add>(layer, bias_value);
}

/*
Parameter(4D)
|
Reshape(3D)
|
Transpose(0, 2, 1)
|
MatMul
|
Transpose(0, 2, 1)
|
Reshape(4D)
|
GroupConvolution
|
Reshape(3D)
|
Transpose(0, 2, 1)
|
MatMul
*/
// Test parameters: input shape, element type of the model, and the number of
// Reshape+Transpose pairs expected to be replaced by an optimized ("_fake") Reorder.
using MergeTransposeReorderTestParams = std::tuple<InputShape, ElementType, size_t>;

// Builds a Reshape/Transpose/MatMul/GroupConvolution model (see SetUp) and, after inference,
// checks in the execution graph how many Reshape+Transpose pairs were merged into Reorder nodes.
class MergeTransposeReorderCPUTest : public testing::WithParamInterface<MergeTransposeReorderTestParams>,
                                     virtual public SubgraphBaseTest,
                                     public CPUTestsBase {
public:
    // Produces the readable test-case name from the parameter tuple.
    static std::string getTestCaseName(const testing::TestParamInfo<MergeTransposeReorderTestParams>& obj) {
        InputShape input_shape;
        ElementType precision;
        size_t optimized_reorders_count;
        std::tie(input_shape, precision, optimized_reorders_count) = obj.param;

        std::ostringstream results;
        results << "IS=(" << ov::test::utils::partialShape2str({input_shape.first}) << "_";
        results << ")_TS=(";
        for (const auto& static_shape : input_shape.second) {
            results << ov::test::utils::vec2str(static_shape) << "_";
        }
        results << ")_precision=" << precision;
        // No ')' here: the "TS=(" group has already been closed above.
        results << "_optimized_reorders_count=" << optimized_reorders_count;
        return results.str();
    }

protected:
    // Creates the tested model:
    // Parameter(4D) -> Reshape(3D) -> Transpose -> MatMul+bias -> Transpose -> Reshape(4D) ->
    // GroupConvolution+bias -> Reshape(3D) -> Transpose -> MatMul+bias.
    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_CPU;
        InputShape input_shape;
        ElementType precision;
        std::tie(input_shape, precision, m_optimized_reorders_count) = this->GetParam();
        init_input_shapes({input_shape});

        // Precision used for every constant/ShapeOf on the shape-computation paths.
        const auto shapeof_subgraph_prc = ov::element::i32;
        OPENVINO_ASSERT(inputDynamicShapes[0].rank().is_static() && inputDynamicShapes[0].size() == 4,
                        "SetUp: only 4D shapes are supported");
        OPENVINO_ASSERT(inputDynamicShapes[0][1].is_static(), "SetUp: only static channels dim is supported");

        const auto param = std::make_shared<ov::opset10::Parameter>(precision, inputDynamicShapes[0]);
        const auto reshape_const_1 = ov::opset10::Constant::create(shapeof_subgraph_prc, {3}, {0, 0, -1});
        const auto reshape_1 = std::make_shared<ov::opset10::Reshape>(param, reshape_const_1, true);

        const auto transpose_const_1 = ov::opset10::Constant::create(shapeof_subgraph_prc, {3}, {0, 2, 1});
        const auto transpose_1 = std::make_shared<ov::opset10::Transpose>(reshape_1, transpose_const_1);

        // The channels dim was asserted to be static above, so get_length() is valid here.
        const auto channels = static_cast<size_t>(inputDynamicShapes[0][1].get_length());
        const size_t fc_out_channels = 512;
        const auto fc_weights_1 = ngraph::builder::makeConstant(precision,
                                                                ov::Shape{fc_out_channels, channels},
                                                                std::vector<float>{},
                                                                true);
        const auto fc_1 = make_layer_with_bias<ov::opset10::MatMul>(transpose_1, fc_weights_1, false, true);

        const auto transpose_const_2 = ov::opset10::Constant::create(shapeof_subgraph_prc, {3}, {0, 2, 1});
        const auto transpose_2 = std::make_shared<ov::opset10::Transpose>(fc_1, transpose_const_2);
        // The target shape of the 4D Reshape takes the spatial dims (indices 2 and 3) from the model input.
        const auto spatial_dims =
            ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(param, {2, 3}, {}, shapeof_subgraph_prc);
        const auto unchangable_dims = ov::opset10::Constant::create(shapeof_subgraph_prc, {2}, {0, 0});
        const auto reshape_const_2 =
            ov::op::util::make_try_fold<ov::opset10::Concat>(ov::OutputVector{unchangable_dims, spatial_dims}, 0);
        const auto reshape_2 = std::make_shared<ov::opset10::Reshape>(transpose_2, reshape_const_2, true);

        const auto conv_weights = ngraph::builder::makeConstant(precision,
                                                                ov::Shape{fc_out_channels, 1, 1, 3, 3},
                                                                std::vector<float>{},
                                                                true);
        const auto conv_with_bias = make_layer_with_bias<ov::opset10::GroupConvolution>(reshape_2,
                                                                                        conv_weights,
                                                                                        ov::Strides{1, 1},
                                                                                        ov::CoordinateDiff{1, 1},
                                                                                        ov::CoordinateDiff{1, 1},
                                                                                        ov::Strides{1, 1});
        // It's necessary to force acdb layout to be sure that the reorder, which changes dims order, will be inserted
        // (by default acdb layout is chosen only on >= AVX512 platforms)
        // Input 0 of the bias Add is the GroupConvolution itself.
        const auto conv = conv_with_bias->get_input_node_shared_ptr(0);
        const auto acdb_format = CPUTestUtils::cpu_memory_format_t::acdb;
        conv->get_rt_info() = makeCPUInfo({acdb_format}, {acdb_format}, {});

        // The target shape of the last Reshape fuses H and W into a single dim (H * W).
        const auto dim_h =
            ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(param, {2}, {}, shapeof_subgraph_prc);
        const auto dim_w =
            ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(param, {3}, {}, shapeof_subgraph_prc);
        const auto fused_spatial_dims = ov::op::util::make_try_fold<ov::opset10::Multiply>(dim_h, dim_w);
        const auto reshape_const_3 =
            ov::op::util::make_try_fold<ov::opset10::Concat>(ov::OutputVector{unchangable_dims, fused_spatial_dims}, 0);
        const auto reshape_3 = std::make_shared<ov::opset10::Reshape>(conv_with_bias, reshape_const_3, true);
        const auto transpose_const_3 = ov::opset10::Constant::create(shapeof_subgraph_prc, {3}, {0, 2, 1});
        const auto transpose_3 = std::make_shared<ov::opset10::Transpose>(reshape_3, transpose_const_3);

        const auto fc_weights_2 = ngraph::builder::makeConstant(precision,
                                                                ov::Shape{channels, fc_out_channels},
                                                                std::vector<float>{},
                                                                true);
        const auto fc_2 = make_layer_with_bias<ov::opset10::MatMul>(transpose_3, fc_weights_2, false, true);
        function = std::make_shared<ov::Model>(fc_2, ov::ParameterVector{param}, "MergeTransposeReorderModel");
    }

    // Checks the execution graph of the compiled model:
    // - the numbers of remaining Transpose and Reshape nodes both equal 3 - m_optimized_reorders_count;
    // - the number of Reorder nodes whose friendly name contains "_fake" equals m_optimized_reorders_count.
    void validate_exec_graph() {
        const size_t original_reshape_transpose_count = 3;
        // Guard the unsigned subtraction below against wrap-around on a wrong test parameter.
        OPENVINO_ASSERT(m_optimized_reorders_count <= original_reshape_transpose_count,
                        "validate_exec_graph: expected optimized reorders count (",
                        m_optimized_reorders_count,
                        ") exceeds the number of Reshape+Transpose pairs in the model (",
                        original_reshape_transpose_count,
                        ")");
        const size_t non_optimized_reshape_transpose_count =
            original_reshape_transpose_count - m_optimized_reorders_count;
        CheckNumberOfNodesWithType(compiledModel, "Transpose", non_optimized_reshape_transpose_count);
        CheckNumberOfNodesWithType(compiledModel, "Reshape", non_optimized_reshape_transpose_count);

        size_t fake_reorder_count = 0;
        for (const auto& node : compiledModel.get_runtime_model()->get_ops()) {
            const auto& rtInfo = node->get_rt_info();
            const auto it = rtInfo.find(ExecGraphInfoSerialization::LAYER_TYPE);
            OPENVINO_ASSERT(rtInfo.end() != it, "validate_exec_graph: runtime node has no layer type info");
            if (it->second.as<std::string>() == "Reorder" &&
                node->get_friendly_name().find("_fake") != std::string::npos) {
                fake_reorder_count++;
            }
        }
        ASSERT_EQ(fake_reorder_count, m_optimized_reorders_count);
    }

private:
    // Expected number of Reshape+Transpose pairs replaced by optimized Reorders (third test parameter).
    size_t m_optimized_reorders_count = 0;
};

// Runs the inherited run() flow first and only then inspects the execution graph:
// validate_exec_graph() reads compiledModel, which is presumably populated by run()
// (run() is not defined in this file — confirm in SubgraphBaseTest).
TEST_P(MergeTransposeReorderCPUTest, CompareWithRefs) {
run();
validate_exec_graph();
}

namespace {
// Static input shape: 2 of the 3 Reshape+Transpose pairs of the model are expected
// to be replaced by optimized Reorders (third test parameter).
const std::vector<InputShape> static_shapes = {
    InputShape{{}, {{1, 32, 16, 16}}},
};

INSTANTIATE_TEST_SUITE_P(smoke_MergeTransposeReorder_static,
                         MergeTransposeReorderCPUTest,
                         ::testing::Combine(::testing::ValuesIn(static_shapes),
                                            ::testing::Values(ElementType::f32),
                                            ::testing::Values(2)),
                         MergeTransposeReorderCPUTest::getTestCaseName);

// Dynamic input shapes (dynamic batch, with either dynamic or static spatial dims):
// no optimized Reorders are expected, i.e. all 3 Reshape+Transpose pairs must remain.
const std::vector<InputShape> dynamic_shapes = {
    InputShape{{-1, 32, -1, -1}, {{1, 32, 16, 16}}},
    InputShape{{-1, 32, 16, 16}, {{1, 32, 16, 16}}},
};

INSTANTIATE_TEST_SUITE_P(smoke_MergeTransposeReorder_dynamic,
                         MergeTransposeReorderCPUTest,
                         ::testing::Combine(::testing::ValuesIn(dynamic_shapes),
                                            ::testing::Values(ElementType::f32),
                                            ::testing::Values(0)),
                         MergeTransposeReorderCPUTest::getTestCaseName);
}  // namespace
} // namespace CPUSubgraphTestsDefinitions
Loading

0 comments on commit f3802b1

Please sign in to comment.