diff --git a/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py b/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py index 772dd913e4d20..4f3089baffdd3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py @@ -16,10 +16,11 @@ import numpy as np import paddle.fluid as fluid import warnings +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode class TestImperativeNumpyBridge(unittest.TestCase): - def test_tensor_from_numpy(self): + def func_tensor_from_numpy(self): data_np = np.array([[2, 3, 1]]).astype('float32') with fluid.dygraph.guard(fluid.CPUPlace()): with warnings.catch_warnings(record=True) as w: @@ -39,9 +40,18 @@ def test_tensor_from_numpy(self): self.assertTrue(np.array_equal(var2.numpy(), data_np)) data_np[0][0] = -1 self.assertEqual(data_np[0][0], -1) - self.assertNotEqual(var2[0][0].numpy()[0], -1) + if _in_eager_mode(): + # eager_mode, var2 is EagerTensor, is not subscriptable + self.assertNotEqual(var2.numpy()[0][0], -1) + else: + self.assertNotEqual(var2[0][0].numpy()[0], -1) self.assertFalse(np.array_equal(var2.numpy(), data_np)) + def test_func_tensor_from_numpy(self): + with _test_eager_guard(): + self.func_tensor_from_numpy() + self.func_tensor_from_numpy() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py index d81849725d75a..f54e50953f131 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py @@ -18,6 +18,7 @@ import paddle.nn as nn import numpy as np import threading +from paddle.fluid.framework import _test_eager_guard, _in_eager_mode class SimpleNet(nn.Layer): @@ -44,7 +45,7 @@ def thread_2_main(self): x = net(x) self.assertFalse(x.stop_gradient) - def test_main(self): + def func_main(self): threads = [] for _ in range(10): threads.append(threading.Thread(target=self.thread_1_main)) @@ -54,6 +55,11 @@ def test_main(self): for t in threads: t.join() + def test_main(self): + with _test_eager_guard(): + self.func_main() + self.func_main() + if __name__ == "__main__": unittest.main()