Skip to content

Commit

Permalink
[LPT][TESTS] ReshapeTransofrmation fix #2 and functional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Aug 25, 2021
1 parent 0a6d11b commit 881455c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,10 @@ void reshapeDequantizationConstant(const std::shared_ptr<opset1::Reshape>& resha
return constant;
}

std::shared_ptr<Node> updatedConstant = constant;
Shape newOperationConstantBroadcastedShape = constant->get_shape();
// add dimensions to broadcast values
if (newOperationConstantBroadcastedShape.size() == 2ul) {
newOperationConstantBroadcastedShape.push_back(dimensionsToBroadcast);

const auto unsqueezeConst = opset1::Constant::create(element::i32, Shape{}, { 2 });
updatedConstant = fold<opset1::Unsqueeze>(constant, unsqueezeConst);
newOperationConstantBroadcastedShape[0] = dimensionsToBroadcast;
} else {
newOperationConstantBroadcastedShape[2] = dimensionsToBroadcast;
}
Expand All @@ -97,7 +93,7 @@ void reshapeDequantizationConstant(const std::shared_ptr<opset1::Reshape>& resha
Shape{ newOperationConstantBroadcastedShape.size() },
newOperationConstantBroadcastedShape);

return fold<opset1::Broadcast>(updatedConstant, targetShapeConstant);
return fold<opset1::Broadcast>(constant, targetShapeConstant);
};

const std::shared_ptr<Node> broadcastedConstant = getBCastedConst(originalConstant, dimensionsToBroadcast);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,30 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
{{0.1f, 0.01f, 0.1f}, ngraph::element::f32, {1, 3}}
}
}
},
// U8: without subtract 2D -> 2D
{
{ Dimension::dynamic(), 2 },
{ -1, 6 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{
{ngraph::element::f32},
{},
{{0.1f, 0.02f}, ngraph::element::f32, {1, 2}}
}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{
{ngraph::element::f32},
{},
{{0.1f, 0.02f, 0.1f, 0.02f, 0.1f, 0.02f}, ngraph::element::f32, {1, 6}}
}
}
}
};

Expand Down

0 comments on commit 881455c

Please sign in to comment.