Skip to content

Commit

Permalink
node_to_get_shape_value_of_indices_from_shape_source: add shape_path_…
Browse files Browse the repository at this point in the history
…precision parameter
  • Loading branch information
v-Golubev committed Nov 7, 2023
1 parent c240f88 commit b6e600f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,14 @@ TRANSFORMATIONS_API std::vector<Input<Node>> get_node_target_inputs(const std::s
TRANSFORMATIONS_API std::shared_ptr<Node> node_to_get_shape_value_of_indices_from_shape_node(
const std::shared_ptr<Node>& shape_node,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {});
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {},
const ov::element::Type& shape_path_precision = ov::element::i64);

TRANSFORMATIONS_API std::shared_ptr<Node> node_to_get_shape_value_of_indices_from_shape_source(
const Output<Node>& shape_source,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {});
const std::vector<std::shared_ptr<Node>>& copy_rt_info_from = {},
const ov::element::Type& shape_path_precision = ov::element::i64);

TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output<Node>& node);

Expand Down
17 changes: 11 additions & 6 deletions src/common/transformations/src/transformations/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,10 @@ std::vector<Input<Node>> get_node_target_inputs(const std::shared_ptr<Node>& nod
std::shared_ptr<ov::Node> node_to_get_shape_value_of_indices_from_shape_node(
const std::shared_ptr<ov::Node>& shape_node,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from) {
const auto& indices_op = v0::Constant::create(ov::element::i64, {indices.size()}, indices);
const auto& axis_op = v0::Constant::create(ov::element::i64, {}, {0});
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from,
const ov::element::Type& shape_path_precision) {
const auto& indices_op = v0::Constant::create(shape_path_precision, {indices.size()}, indices);
const auto& axis_op = v0::Constant::create(shape_path_precision, {}, {0});
auto op = make_try_fold<v7::Gather>(shape_node, indices_op, axis_op);
if (!copy_rt_info_from.empty())
ov::copy_runtime_info(copy_rt_info_from, {op, indices_op, axis_op});
Expand All @@ -224,11 +225,15 @@ std::shared_ptr<ov::Node> node_to_get_shape_value_of_indices_from_shape_node(
std::shared_ptr<ov::Node> node_to_get_shape_value_of_indices_from_shape_source(
const ov::Output<ov::Node>& shape_source,
const std::vector<size_t>& indices,
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from) {
const auto& shape_node = make_try_fold<v3::ShapeOf>(shape_source);
const std::vector<std::shared_ptr<ov::Node>>& copy_rt_info_from,
const ov::element::Type& shape_path_precision) {
const auto& shape_node = make_try_fold<v3::ShapeOf>(shape_source, shape_path_precision);
if (!copy_rt_info_from.empty())
ov::copy_runtime_info(copy_rt_info_from, shape_node);
return node_to_get_shape_value_of_indices_from_shape_node(shape_node, indices, copy_rt_info_from);
return node_to_get_shape_value_of_indices_from_shape_node(shape_node,
indices,
copy_rt_info_from,
shape_path_precision);
}

bool shapes_equal_except_dynamic_expected_batch(const ov::PartialShape& expected, const ov::PartialShape& actual) {
Expand Down

0 comments on commit b6e600f

Please sign in to comment.