Skip to content

Commit

Permalink
Review comments applied
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 20, 2025
1 parent e5906de commit 03ebc69
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions src/common/low_precision_transformations/src/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,34 +87,37 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
const bool scalar_equal_constants_requested = concat_out_shape[axis].is_dynamic();

auto adaptConstForConcatenation = [scalar_equal_constants_requested](
std::shared_ptr<opset1::Constant> constant,
const Shape targetShape) -> std::shared_ptr<ov::Node> {
const std::shared_ptr<opset1::Constant>& constant,
const Shape& targetShape) {
if (scalar_equal_constants_requested) {
OPENVINO_ASSERT(targetShape.empty(), "scalar_equal_constants_requested implies targetShape is empty");
return std::make_shared<opset1::Constant>(*constant, ov::Shape{});
} else {
auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape);
auto broadcast = fold<ov::opset1::Broadcast>(constant, targetShapeConst);
return broadcast;
auto bcastedConst = ov::as_type_ptr<opset1::Constant>(fold<ov::opset1::Broadcast>(constant, targetShapeConst));
OPENVINO_ASSERT(bcastedConst, "adaptConstForConcatenation must return constant");
return bcastedConst;
}
};

bool someDqInLowPrecision = std::any_of(
const bool someDqInLowPrecision = std::any_of(
layerDequantizations.begin(),
layerDequantizations.end(),
[](const FakeQuantizeDequantization& value) { return value.isLowPrecision(); });

bool someDqInFpPrecision = std::any_of(
const bool someDqInFpPrecision = std::any_of(
layerDequantizations.begin(),
layerDequantizations.end(),
[](const FakeQuantizeDequantization& value) { return !value.isLowPrecision(); });

bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision;
const bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision;

OutputVector dataNodes;
NodeVector convertNodes;
NodeVector subConstants;
NodeVector mulConstants;

using ConstVector = std::vector<std::shared_ptr<opset1::Constant>>;
ConstVector subConstants;
ConstVector mulConstants;
std::shared_ptr<opset1::Convert> subtractConvert = nullptr;
for (size_t i = 0; i < layerDequantizations.size(); ++i) {
const auto& dequantization = layerDequantizations[i];
Expand Down Expand Up @@ -151,7 +154,9 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
subtractConvert = dequantization.subtractConvert;
}
} else if (dequantization.subtractConvert != nullptr) {
subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type());
const auto& dstType = dequantization.subtractConvert->get_convert_element_type();
subtractInput = ov::as_type_ptr<opset1::Constant>(foldConvert(subtractInput, dstType));
OPENVINO_ASSERT(subtractInput, "foldConvert must finish successfully for the concatenated subtract constant");
NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput);
}
subConstants.push_back(subtractInput);
Expand All @@ -175,23 +180,19 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
lastDequantization = convert;
}

auto concat_constants_if_needed = [&](const NodeVector& constants) {
auto concat_constants_if_needed = [&](const ConstVector& constants) -> std::shared_ptr<ov::Node> {
OPENVINO_ASSERT(!constants.empty(), "concat_constants_if_needed expects non empty constants vec");
if (constants.size() == 1ul) {
return constants[0];
}
if (scalar_equal_constants_requested) {
if (ov::shape_size(constants[0]->get_output_shape(0)) == 1) {
const auto ref_value = ov::as_type_ptr<ov::op::v0::Constant>(constants[0])->cast_vector<float>();
bool all_constants_are_equal = true;
for (size_t i = 1ul; i < constants.size(); i++) {
const auto cur_value = ov::as_type_ptr<ov::op::v0::Constant>(constants[i])->cast_vector<float>();
if (ref_value != cur_value) {
all_constants_are_equal = false;
}
}
if (all_constants_are_equal)
if (std::all_of(constants.cbegin() + 1, constants.cend(), [&ref_value](const auto& constant) {
return constant->cast_vector<float>() == ref_value;
})) {
return constants[0];
}
}
OPENVINO_THROW("in case of dynamic concatenation dim all constants must be scalar and equal");
}
Expand Down

0 comments on commit 03ebc69

Please sign in to comment.