Skip to content

Commit

Permalink
Support test_numpy_bridge and thread_local_has_grad (#38835)
Browse files Browse the repository at this point in the history
  • Loading branch information
veyron95 authored Jan 11, 2022
1 parent d3ba189 commit 29c211e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import numpy as np
import paddle.fluid as fluid
import warnings
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode


class TestImperativeNumpyBridge(unittest.TestCase):
def test_tensor_from_numpy(self):
def func_tensor_from_numpy(self):
data_np = np.array([[2, 3, 1]]).astype('float32')
with fluid.dygraph.guard(fluid.CPUPlace()):
with warnings.catch_warnings(record=True) as w:
Expand All @@ -39,9 +40,18 @@ def test_tensor_from_numpy(self):
self.assertTrue(np.array_equal(var2.numpy(), data_np))
data_np[0][0] = -1
self.assertEqual(data_np[0][0], -1)
self.assertNotEqual(var2[0][0].numpy()[0], -1)
if _in_eager_mode():
# eager_mode, var2 is EagerTensor, is not subscriptable
self.assertNotEqual(var2.numpy()[0][0], -1)
else:
self.assertNotEqual(var2[0][0].numpy()[0], -1)
self.assertFalse(np.array_equal(var2.numpy(), data_np))

def test_func_tensor_from_numpy(self):
with _test_eager_guard():
self.func_tensor_from_numpy()
self.func_tensor_from_numpy()


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import paddle.nn as nn
import numpy as np
import threading
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode


class SimpleNet(nn.Layer):
Expand All @@ -44,7 +45,7 @@ def thread_2_main(self):
x = net(x)
self.assertFalse(x.stop_gradient)

def test_main(self):
def func_main(self):
threads = []
for _ in range(10):
threads.append(threading.Thread(target=self.thread_1_main))
Expand All @@ -54,6 +55,11 @@ def test_main(self):
for t in threads:
t.join()

def test_main(self):
with _test_eager_guard():
self.func_main()
self.func_main()


if __name__ == "__main__":
unittest.main()

0 comments on commit 29c211e

Please sign in to comment.