Skip to content

Commit

Permalink
Merge pull request #12 from nosovmik/test_partShapefix
Browse files Browse the repository at this point in the history
Correction of partial_shape tests
  • Loading branch information
nosovmik authored Apr 16, 2021
2 parents f1d5e82 + da73518 commit 5310f92
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
25 changes: 23 additions & 2 deletions ngraph/test/frontend/paddlepaddle/partial_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,34 @@ static PartShape getTestShape_2in_2out() {
res.m_modelName = "2in_2out/2in_2out.pdmodel";
res.m_tensorName = "inputX1";
res.m_oldPartialShape = {1, 1, 3, 3};
res.m_newPartialShape = {2, 2, 4, 4};
res.m_newPartialShape = {2, 1, 3, 3};
return res;
}

static PartShape getTestShape_conv2d() {
PartShape res;
res.m_modelName = "conv2d_s/conv2d.pdmodel";
res.m_tensorName = "x";
res.m_oldPartialShape = {1, 3, 4, 4};
res.m_newPartialShape = {1, 3, 8, 8};
return res;
}

static PartShape getTestShape_conv2d_relu() {
PartShape res;
res.m_modelName = "conv2d_relu/conv2d_relu.pdmodel";
res.m_tensorName = "xxx";
res.m_oldPartialShape = {1, 3, 4, 4};
res.m_newPartialShape = {5, 3, 5, 5};
return res;
}

INSTANTIATE_TEST_CASE_P(PDPDPartialShapeTest, FrontEndPartialShapeTest,
::testing::Combine(
::testing::Values(BaseFEParam { PDPD, PATH_TO_MODELS }),
::testing::ValuesIn({ getTestShape_2in_2out() })
::testing::ValuesIn({ getTestShape_2in_2out(),
getTestShape_conv2d_relu(),
getTestShape_conv2d()
})
),
FrontEndPartialShapeTest::getTestCaseName);
12 changes: 9 additions & 3 deletions ngraph/test/frontend/shared/src/basic_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ TEST_P(FrontEndBasicTest, testInputModel_getInputsOutputs)
EXPECT_EQ(placesSet.size(), places.size());
std::for_each(places.begin(), places.end(), [&](Place::Ptr place) {
ASSERT_NE(place, nullptr);
EXPECT_GT(place->getNames().size(), 0);
std::vector<std::string> names;
ASSERT_NO_THROW(names = place->getNames());
EXPECT_GT(names.size(), 0);
cb(place);
});
};
Expand All @@ -82,7 +84,9 @@ TEST_P(FrontEndBasicTest, testInputModel_getPlaceByTensorName)
EXPECT_GT(places.size(), 0);
for (auto place : places) {
ASSERT_NE(place, nullptr);
for (auto name : place->getNames()) {
std::vector<std::string> names;
ASSERT_NO_THROW(names = place->getNames());
for (auto name : names) {
EXPECT_NE(name, std::string());
Place::Ptr placeByName;
ASSERT_NO_THROW(placeByName = m_inputModel->getPlaceByTensorName(name));
Expand Down Expand Up @@ -144,7 +148,9 @@ TEST_P(FrontEndBasicTest, testInputModel_overrideAll_empty)
ASSERT_NO_THROW(newPlaces = getCB());
ASSERT_EQ(newPlaces.size(), 0);
std::for_each(places.begin(), places.end(), [&](Place::Ptr place) {
for (auto name : place->getNames()) {
std::vector<std::string> names;
ASSERT_NO_THROW(names = place->getNames());
for (auto name : names) {
customCB(name);
}
});
Expand Down
8 changes: 5 additions & 3 deletions ngraph/test/frontend/shared/src/cut_specific_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ TEST_P(FrontEndCutModelTest, testOverrideInputs)
ASSERT_NO_THROW(m_inputModel->getInputs());
EXPECT_EQ(m_param.m_newInputs.size(), m_inputModel->getInputs().size());
for (auto newInput : m_inputModel->getInputs()) {
auto names = newInput->getNames();
std::vector<std::string> names;
ASSERT_NO_THROW(names = newInput->getNames());
bool found = false;
for (const auto& name: m_param.m_newInputs) {
if (std::find(names.begin(), names.begin(), name) != names.end()) {
Expand All @@ -93,7 +94,8 @@ TEST_P(FrontEndCutModelTest, testOverrideOutputs)
ASSERT_NO_THROW(m_inputModel->getOutputs());
EXPECT_EQ(m_param.m_newOutputs.size(), m_inputModel->getOutputs().size());
for (auto newOutput : m_inputModel->getOutputs()) {
auto names = newOutput->getNames();
std::vector<std::string> names;
ASSERT_NO_THROW(names = newOutput->getNames());
bool found = false;
for (const auto& name: m_param.m_newOutputs) {
if (std::find(names.begin(), names.begin(), name) != names.end()) {
Expand Down Expand Up @@ -165,7 +167,7 @@ TEST_P(FrontEndCutModelTest, testNewOutputs_func) {
ASSERT_NO_THROW(doLoadFromFile());
std::vector<Place::Ptr> newPlaces;
ASSERT_NO_THROW(newPlaces = constructNewOutputs());
ASSERT_NO_THROW(m_inputModel->overrideAllInputs(newPlaces));
ASSERT_NO_THROW(m_inputModel->overrideAllOutputs(newPlaces));

std::shared_ptr<ngraph::Function> function;
ASSERT_NO_THROW(function = m_frontEnd->convert(m_inputModel));
Expand Down

0 comments on commit 5310f92

Please sign in to comment.