Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ TRANSFORMATIONS_API bool is_constant_and_all_values_equal_int(const Output<Node>

TRANSFORMATIONS_API bool is_on_constant_path(const ov::Output<ov::Node>& output);

TRANSFORMATIONS_API bool is_on_constant_or_param_path(const ov::Output<ov::Node>& output);

TRANSFORMATIONS_API bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node);

TRANSFORMATIONS_API std::tuple<std::shared_ptr<ov::Node>, // result
Expand Down
38 changes: 38 additions & 0 deletions src/common/transformations/src/transformations/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,44 @@ bool is_on_constant_path(const ov::Output<ov::Node>& output) {
return status;
}

bool is_on_constant_or_param_path(const ov::Output<ov::Node>& output) {
auto status = true;

auto root_node = output.get_node();
if (!root_node || root_node->get_output_size() == 0) {
return false;
}
std::deque<ov::Node*> nodes_to_calculate = {root_node};

std::unordered_set<ov::Node*> visited;
while (status && !nodes_to_calculate.empty()) {
auto current_node = nodes_to_calculate.front();
nodes_to_calculate.pop_front();
if (visited.count(current_node)) {
continue;
}
visited.insert(current_node);
// RandomUniform output changes during runtime, so we should not consider it as a constant
if (current_node->get_type_info() == ov::op::v8::RandomUniform::get_type_info_static()) {
return false;
}

if (current_node->get_input_size() == 0 &&
!(ov::is_type<ov::op::v0::Constant>(current_node) || ov::is_type<ov::op::v0::Parameter>(current_node))) {
status = false;
} else {
// not a leaf - continue to search
for (const auto& input_value : current_node->input_values()) {
const auto& input_node = input_value.get_node();
if (!visited.count(input_node)) {
nodes_to_calculate.push_front(input_node);
}
}
}
}
return status;
}

bool process_subgraph(ov::pass::ModelPass& model_pass, const std::shared_ptr<Node>& node) {
bool changed = false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ using namespace ov::pass::pattern;
((in_ps.size() == 3 && out_ps.size() == 2) || (in_ps.size() == 4 && out_ps.size() == 3));\
};\
\
auto compressed_weights_m = wrap_type<ov::op::v0::Constant>(compressed_constant);\
auto weights_const_m = wrap_type<ov::op::v0::Constant>(compressed_constant);\
auto weights_param_m = wrap_type<ov::op::v0::Parameter>(compressed_constant);\
auto weights_param_reshape_m = wrap_type<ov::op::v1::Reshape>({weights_param_m, any_input()});\
auto compressed_weights_m = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{weights_const_m, weights_param_m, weights_param_reshape_m});\
auto convert_m = wrap_type<ov::op::v0::Convert>({compressed_weights_m});\
\
auto sub_const_m = wrap_type<ov::op::v0::Constant>();\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,15 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon

auto weight_ptr = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(compressed_weights_m).get_node_shared_ptr());
bool weight_u8 = false;
if (weight_ptr->get_element_type() == ov::element::u8 || weight_ptr->get_element_type() == ov::element::i8)
weight_u8 = true;
if (pattern_map.count(weights_const_m)) {
auto weight_ptr = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(weights_const_m).get_node_shared_ptr());
if (weight_ptr->get_element_type() == ov::element::u8 || weight_ptr->get_element_type() == ov::element::i8)
weight_u8 = true;
} else {
auto weight_ptr = ov::as_type_ptr<ov::op::v0::Parameter>(pattern_map.at(weights_param_m).get_node_shared_ptr());
if (weight_ptr->get_element_type() == ov::element::u8 || weight_ptr->get_element_type() == ov::element::i8)
weight_u8 = true;
}

auto reshape_const = [has_transpose, grouped, is_weight_3d](std::shared_ptr<ov::Node> node) {
auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);
Expand All @@ -73,7 +80,7 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
return constant;
else
new_shape = (has_transpose || !grouped) ? ov::Shape{current_shape[0] * current_shape[1], current_shape[2]}
: ov::Shape{current_shape[0], current_shape[1] * current_shape[2]};
: ov::Shape{current_shape[0], current_shape[1] * current_shape[2]};
} else {
OPENVINO_ASSERT(current_shape.size() == 4 && is_weight_3d);
new_shape = (has_transpose || !grouped) ? ov::Shape{current_shape[0], current_shape[1] * current_shape[2], current_shape[3]}
Expand Down Expand Up @@ -102,7 +109,6 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
return result;
};


const ov::Output<Node>& fc_input_a = fc->input(0).get_source_output();
const auto& scale = reshape_const(pattern_map.at(mul_const_m).get_node_shared_ptr());
std::shared_ptr<ov::Node> optional_zero_point = nullptr;
Expand All @@ -112,61 +118,104 @@ ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyCon
optional_zero_point = convert_const_to_u8(reshape_const(pattern_map.at(sub_const_m).get_node_shared_ptr()));
}

std::shared_ptr<ov::Node> fc_input_b = reshape_const(pattern_map.at(compressed_weights_m).get_node_shared_ptr());
std::shared_ptr<ov::Node> fc_input_scale = scale;
std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at(bias_m).get_node_shared_ptr();
std::vector<std::shared_ptr<ov::Node>> result_nodes = {};

if (has_transpose) {
const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
std::shared_ptr<ov::Node> transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
if (ov::shape_size(transpose_const->get_shape()) != fc_input_b->get_output_partial_shape(0).size()) {
std::vector<int32_t> new_order(fc_input_b->get_output_partial_shape(0).size());
std::iota(new_order.begin(), new_order.end(), 0);
std::swap(new_order[new_order.size() - 1], new_order[new_order.size() - 2]);
transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_order.size()}, new_order);
if (pattern_map.count(weights_const_m)) {
std::shared_ptr<ov::Node> fc_input_b = reshape_const(pattern_map.at(weights_const_m).get_node_shared_ptr());
std::shared_ptr<ov::Node> fc_input_scale = scale;
std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at(bias_m).get_node_shared_ptr();
std::vector<std::shared_ptr<ov::Node>> result_nodes = {};

if (has_transpose) {
const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
std::shared_ptr<ov::Node> transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
if (ov::shape_size(transpose_const->get_shape()) != fc_input_b->get_output_partial_shape(0).size()) {
std::vector<int32_t> new_order(fc_input_b->get_output_partial_shape(0).size());
std::iota(new_order.begin(), new_order.end(), 0);
std::swap(new_order[new_order.size() - 1], new_order[new_order.size() - 2]);
transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_order.size()}, new_order);
}

fc_input_b = transpose->clone_with_new_inputs({fc_input_b->output(0), transpose_const});
result_nodes.push_back(fc_input_b);

if (ov::shape_size(scale->output(0).get_shape()) > 1) {
fc_input_scale = transpose->clone_with_new_inputs({scale->output(0), transpose_const});
result_nodes.push_back(fc_input_scale);
}

if (with_zero_point && ov::shape_size(optional_zero_point->output(0).get_shape()) > 1) {
fc_input_zp = transpose->clone_with_new_inputs({optional_zero_point->output(0), transpose_const});
result_nodes.push_back(fc_input_zp);
}
}

fc_input_b = transpose->clone_with_new_inputs({ fc_input_b->output(0), transpose_const });
result_nodes.push_back(fc_input_b);
if (pattern_map.count(mul2_m)) {
auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(mul2_const_m).get_node_shared_ptr());
fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
}

if (ov::shape_size(scale->output(0).get_shape()) > 1) {
fc_input_scale = transpose->clone_with_new_inputs({ scale->output(0), transpose_const });
result_nodes.push_back(fc_input_scale);
std::shared_ptr<ov::Node> new_fc = nullptr;
if (with_zero_point) {
new_fc =
std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc_input_zp, fc->get_output_type());
} else {
new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc->get_output_type());
}

if (with_zero_point && ov::shape_size(optional_zero_point->output(0).get_shape()) > 1) {
fc_input_zp = transpose->clone_with_new_inputs({ optional_zero_point->output(0), transpose_const });
result_nodes.push_back(fc_input_zp);
result_nodes.push_back(new_fc);
new_fc->set_friendly_name(fc->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
ov::replace_node(fc, new_fc);
} else {
std::shared_ptr<ov::Node> fc_input_b = pattern_map.count(weights_param_reshape_m) ? pattern_map.at(weights_param_reshape_m).get_node_shared_ptr()
: pattern_map.at(weights_param_m).get_node_shared_ptr();
std::shared_ptr<ov::Node> fc_input_scale = scale;
std::shared_ptr<ov::Node> fc_input_zp = optional_zero_point;
std::shared_ptr<ov::Node> fc_input_bias = pattern_map.at(bias_m).get_node_shared_ptr();
std::vector<std::shared_ptr<ov::Node>> result_nodes = {};

if (has_transpose) {
const auto& transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
std::shared_ptr<ov::Node> transpose_const = pattern_map.at(transpose_const_m).get_node_shared_ptr();
if (ov::shape_size(transpose_const->get_shape()) != fc_input_b->get_output_partial_shape(0).size()) {
std::vector<int32_t> new_order(fc_input_b->get_output_partial_shape(0).size());
std::iota(new_order.begin(), new_order.end(), 0);
std::swap(new_order[new_order.size() - 1], new_order[new_order.size() - 2]);
transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_order.size()}, new_order);
}

fc_input_b = transpose->clone_with_new_inputs({fc_input_b->output(0), transpose_const});
result_nodes.push_back(fc_input_b);

if (ov::shape_size(scale->output(0).get_shape()) > 1) {
fc_input_scale = transpose->clone_with_new_inputs({scale->output(0), transpose_const});
result_nodes.push_back(fc_input_scale);
}

if (with_zero_point && ov::shape_size(optional_zero_point->output(0).get_shape()) > 1) {
fc_input_zp = transpose->clone_with_new_inputs({optional_zero_point->output(0), transpose_const});
result_nodes.push_back(fc_input_zp);
}
}
}

if (pattern_map.count(mul2_m)) {
auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(mul2_const_m).get_node_shared_ptr());
fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
}
if (pattern_map.count(mul2_m)) {
auto mul2_op_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(mul2_const_m).get_node_shared_ptr());
fc_input_scale = ov::op::util::make_try_fold<ov::op::v1::Multiply>(fc_input_scale, mul2_op_const);
}

std::shared_ptr<ov::Node> new_fc = nullptr;
if (with_zero_point) {
new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
fc_input_b,
fc_input_bias,
fc_input_scale,
fc_input_zp,
fc->get_output_type());
} else {
new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a,
fc_input_b,
fc_input_bias,
fc_input_scale,
fc->get_output_type());
}
std::shared_ptr<ov::Node> new_fc = nullptr;
if (with_zero_point) {
new_fc =
std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc_input_zp, fc->get_output_type());
} else {
new_fc = std::make_shared<op::FullyConnectedCompressed>(fc_input_a, fc_input_b, fc_input_bias, fc_input_scale, fc->get_output_type());
}

result_nodes.push_back(new_fc);
new_fc->set_friendly_name(fc->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
ov::replace_node(fc, new_fc);
result_nodes.push_back(new_fc);
new_fc->set_friendly_name(fc->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), result_nodes);
ov::replace_node(fc, new_fc);
}

return true;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ConvertMatMulToFullyConnected::ConvertMatMulToFullyConnected(bool supports_immad
};
auto weights_path = [&static_rank_gt_1](const ov::Output<ov::Node>& output) {
const auto& pshape = output.get_partial_shape();
return ov::op::util::is_on_constant_path(output) &&
return ov::op::util::is_on_constant_or_param_path(output) &&
static_rank_gt_1(output) &&
pshape.is_static();
};
Expand Down