From 67837ee38a07017aead193a7695bbe051ba52990 Mon Sep 17 00:00:00 2001 From: CtfGo Date: Wed, 16 Jun 2021 06:41:32 +0000 Subject: [PATCH] fix ci ut converage --- .../tests/unittests/test_share_data_op.py | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_share_data_op.py b/python/paddle/fluid/tests/unittests/test_share_data_op.py index 295c45578949a..bf4be9c975849 100644 --- a/python/paddle/fluid/tests/unittests/test_share_data_op.py +++ b/python/paddle/fluid/tests/unittests/test_share_data_op.py @@ -37,22 +37,50 @@ def get_places(self): places.append(core.CUDAPlace(0)) return places - def check_with_place(self, place): + def check_with_tensor(self, place): scope = core.Scope() np_array = np.random.rand(2, 3, 5).astype("float32") # initialize input and output variable - x = scope.var('Input').get_tensor() + x = scope.var('X').get_tensor() x.set(np_array, place) out = scope.var("Out").get_tensor() - op = Operator("share_data", Input="Input", Out="Out") + op = Operator("share_data", Input="X", Out="Out") op.run(scope, place) self.assertTrue(np.allclose(np_array, out)) + def check_with_selected_rows(self, place): + scope = core.Scope() + x_rows = [0, 1, 5, 4, 19] + x_height = 20 + row_numel = 2 + np_array = np.ones((len(x_rows), row_numel)).astype("float32") + + # initialize input variable + x = scope.var('X').get_selected_rows() + x.set_rows(x_rows) + x.set_height(x_height) + x_tensor = x.get_tensor() + x_tensor.set(np_array, place) + + # initialize the Out variable + out = scope.var("Out").get_selected_rows() + out_tensor = out.get_tensor() + + op = Operator("share_data", Input="X", Out="Out") + op.run(scope, place) + + out_height = out.height() + out_rows = out.rows() + self.assertTrue(np.allclose(np_array, out_tensor)) + self.assertEqual(x_height, out_height) + self.assertEqual(x_rows, out_rows) + def test_check_output(self): for place in self.get_places(): - self.check_with_place(place) + self.check_with_selected_rows(place) + self.check_with_tensor(place) if __name__ == '__main__':