diff --git a/test/legacy_test/test_normal.py b/test/legacy_test/test_normal.py index 5151a9f9411dc3..0e0a38a2b35fd8 100644 --- a/test/legacy_test/test_normal.py +++ b/test/legacy_test/test_normal.py @@ -64,15 +64,16 @@ def get_dtype(self): return 'float32' def static_api(self): + paddle.enable_static() shape = self.get_shape() ret_all_shape = copy.deepcopy(shape) ret_all_shape.insert(0, self.repeat_num) ret_all = np.zeros(ret_all_shape, self.dtype) main_program = paddle.static.Program() - if isinstance(self.mean, np.ndarray) and isinstance( - self.std, np.ndarray - ): - with paddle.static.program_guard(main_program): + with paddle.static.program_guard(main_program): + if isinstance(self.mean, np.ndarray) and isinstance( + self.std, np.ndarray + ): mean = paddle.static.data( 'Mean', self.mean.shape, self.mean.dtype ) @@ -89,9 +90,7 @@ def static_api(self): fetch_list=[out], ) ret_all[i] = ret[0] - return ret_all - elif isinstance(self.mean, np.ndarray): - with paddle.static.program_guard(main_program): + elif isinstance(self.mean, np.ndarray): mean = paddle.static.data( 'Mean', self.mean.shape, self.mean.dtype ) @@ -101,9 +100,7 @@ def static_api(self): for i in range(self.repeat_num): ret = exe.run(feed={'Mean': self.mean}, fetch_list=[out]) ret_all[i] = ret[0] - return ret_all - elif isinstance(self.std, np.ndarray): - with paddle.static.program_guard(main_program): + elif isinstance(self.std, np.ndarray): std = paddle.static.data('Std', self.std.shape, self.std.dtype) out = paddle.normal(self.mean, std, self.shape) @@ -111,16 +108,15 @@ def static_api(self): for i in range(self.repeat_num): ret = exe.run(feed={'Std': self.std}, fetch_list=[out]) ret_all[i] = ret[0] - return ret_all - else: - with paddle.static.program_guard(main_program): + else: out = paddle.normal(self.mean, self.std, self.shape) exe = paddle.static.Executor(self.place) for i in range(self.repeat_num): ret = exe.run(fetch_list=[out]) ret_all[i] = ret[0] - return ret_all + paddle.disable_static() + return ret_all def dygraph_api(self): paddle.disable_static(self.place) @@ -218,7 +214,6 @@ def test_errors(self): self.assertRaises(TypeError, paddle.normal, mean=1.0, std=std) self.assertRaises(TypeError, paddle.normal, shape=1) - self.assertRaises(TypeError, paddle.normal, shape=[1.0]) shape = paddle.static.data('Shape', [100], 'float32') @@ -261,15 +256,16 @@ def get_dtype(self): return 'complex64' def static_api(self): + paddle.enable_static() shape = self.get_shape() ret_all_shape = copy.deepcopy(shape) ret_all_shape.insert(0, self.repeat_num) ret_all = np.zeros(ret_all_shape, self.dtype) main_program = paddle.static.Program() - if isinstance(self.mean, np.ndarray) and isinstance( - self.std, np.ndarray - ): - with paddle.static.program_guard(main_program): + with paddle.static.program_guard(main_program): + if isinstance(self.mean, np.ndarray) and isinstance( + self.std, np.ndarray + ): mean = paddle.static.data( 'Mean', self.mean.shape, self.mean.dtype ) @@ -286,9 +282,7 @@ def static_api(self): fetch_list=[out], ) ret_all[i] = ret[0] - return ret_all - elif isinstance(self.mean, np.ndarray): - with paddle.static.program_guard(main_program): + elif isinstance(self.mean, np.ndarray): mean = paddle.static.data( 'Mean', self.mean.shape, self.mean.dtype ) @@ -298,9 +292,7 @@ def static_api(self): for i in range(self.repeat_num): ret = exe.run(feed={'Mean': self.mean}, fetch_list=[out]) ret_all[i] = ret[0] - return ret_all - elif isinstance(self.std, np.ndarray): - with paddle.static.program_guard(main_program): + elif isinstance(self.std, np.ndarray): mean = paddle.static.data('Mean', self.std.shape, 'complex128') std = paddle.static.data('Std', self.std.shape, self.std.dtype) out = paddle.normal(mean, std, self.shape) @@ -317,20 +309,18 @@ def static_api(self): fetch_list=[out], ) ret_all[i] = ret[0] - return ret_all - else: - with paddle.static.program_guard(main_program): + else: mean = paddle.static.data('Mean', (), 'complex128') out = paddle.normal(mean, self.std, self.shape) exe = paddle.static.Executor(self.place) for i in range(self.repeat_num): ret = exe.run( - feed={'Mean': np.array(self.mean)}, - fetch_list=[out], + feed={'Mean': np.array(self.mean)}, fetch_list=[out] ) ret_all[i] = ret[0] - return ret_all + paddle.disable_static() + return ret_all def dygraph_api(self): paddle.disable_static(self.place)