Skip to content

Commit 4c55f8c

Browse files
authored
fix: some bugs..
1 parent 4dfc3db commit 4c55f8c

File tree

1 file changed

+53
-49
lines changed

1 file changed

+53
-49
lines changed

test/legacy_test/test_gather_op.py

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -101,29 +101,6 @@ def config_dtype(self):
101101
self.x_type = "float16"
102102

103103

104-
@unittest.skipIf(
105-
not (core.is_compiled_with_cuda() or is_custom_device()),
106-
"only support compiled with CUDA.",
107-
)
108-
class TestGatherGPUCPUConsistency(unittest.TestCase):
109-
def test_gpu_cpu_consistency(self):
110-
with paddle.base.dygraph.guard():
111-
np.random.seed(42)
112-
x = np.random.rand(1000, 128).astype("float32")
113-
index = np.random.randint(0, 1000, size=(100,))
114-
cpu_out = paddle.gather(
115-
paddle.to_tensor(x, place=paddle.CPUPlace()),
116-
paddle.to_tensor(index),
117-
)
118-
gpu_out = paddle.gather(
119-
paddle.to_tensor(x, place=paddle.CUDAPlace(0)),
120-
paddle.to_tensor(index),
121-
)
122-
np.testing.assert_allclose(
123-
cpu_out.numpy(), gpu_out.numpy(), rtol=1e-6
124-
)
125-
126-
127104
@unittest.skipIf(
128105
not (core.is_compiled_with_cuda() or is_custom_device())
129106
or core.cudnn_version() < 8100
@@ -770,38 +747,65 @@ def test_out2(self):
770747
np.testing.assert_allclose(result, expected_output, rtol=1e-05)
771748

772749

750+
@unittest.skipIf(
751+
not (core.is_compiled_with_cuda() or is_custom_device()),
752+
"only support compiled with CUDA.",
753+
)
754+
class TestGatherGPUCPUConsistency(unittest.TestCase):
755+
def test_gpu_cpu_consistency(self):
756+
paddle.disable_static()
757+
np.random.seed(42)
758+
x = np.random.rand(1000, 128).astype("float32")
759+
index = np.random.randint(0, 1000, size=(100,))
760+
cpu_out = paddle.gather(
761+
paddle.to_tensor(x, place=paddle.CPUPlace()),
762+
paddle.to_tensor(index),
763+
)
764+
gpu_out = paddle.gather(
765+
paddle.to_tensor(x, place=paddle.CUDAPlace(0)),
766+
paddle.to_tensor(index),
767+
)
768+
np.testing.assert_allclose(
769+
cpu_out.numpy(), gpu_out.numpy(), rtol=1e-6
770+
)
771+
paddle.enable_static()
772+
773+
773774
class API_TestDygraphGather(unittest.TestCase):
774775
def test_out1(self):
775-
with paddle.base.dygraph.guard():
776-
input_1 = np.array([[1, 2], [3, 4], [5, 6]])
777-
index_1 = np.array([1, 2])
778-
input = paddle.to_tensor(input_1)
779-
index = paddle.to_tensor(index_1)
780-
output = paddle.gather(input, index)
781-
output_np = output.numpy()
782-
expected_output = np.array([[3, 4], [5, 6]])
783-
np.testing.assert_allclose(output_np, expected_output, rtol=1e-05)
776+
paddle.disable_static()
777+
input_1 = np.array([[1, 2], [3, 4], [5, 6]])
778+
index_1 = np.array([1, 2])
779+
input = paddle.to_tensor(input_1)
780+
index = paddle.to_tensor(index_1)
781+
output = paddle.gather(input, index)
782+
output_np = output.numpy()
783+
expected_output = np.array([[3, 4], [5, 6]])
784+
np.testing.assert_allclose(output_np, expected_output, rtol=1e-05)
785+
paddle.enable_static()
784786

785787
def test_out12(self):
786-
with paddle.base.dygraph.guard():
787-
input_1 = np.array([[1, 2], [3, 4], [5, 6]])
788-
index_1 = np.array([1, 2])
789-
x = paddle.to_tensor(input_1)
790-
index = paddle.to_tensor(index_1)
791-
output = paddle.gather(x, index, axis=0)
792-
output_np = output.numpy()
793-
expected_output = gather_numpy(input_1, index_1, axis=0)
794-
np.testing.assert_allclose(output_np, expected_output, rtol=1e-05)
788+
paddle.disable_static()
789+
input_1 = np.array([[1, 2], [3, 4], [5, 6]])
790+
index_1 = np.array([1, 2])
791+
x = paddle.to_tensor(input_1)
792+
index = paddle.to_tensor(index_1)
793+
output = paddle.gather(x, index, axis=0)
794+
output_np = output.numpy()
795+
expected_output = gather_numpy(input_1, index_1, axis=0)
796+
np.testing.assert_allclose(output_np, expected_output, rtol=1e-05)
797+
paddle.enable_static()
795798

796799
def test_zero_index(self):
797-
with paddle.base.dygraph.guard():
798-
x = paddle.to_tensor([[1, 2], [3, 4]])
799-
index = paddle.to_tensor(np.array([]).astype('int64'))
800-
for axis in range(len(x.shape)):
801-
out = paddle.gather(x, index, axis)
802-
expected_shape = list(x.shape)
803-
expected_shape[axis] = 0
804-
self.assertEqual(list(out.shape), expected_shape)
800+
paddle.disable_static()
801+
x = paddle.to_tensor([[1, 2], [3, 4]])
802+
index = paddle.to_tensor(np.array([]).astype('int64'))
803+
for axis in range(len(x.shape)):
804+
out = paddle.gather(x, index, axis)
805+
expected_shape = list(x.shape)
806+
expected_shape[axis] = 0
807+
self.assertEqual(list(out.shape), expected_shape)
808+
paddle.enable_static()
805809

806810
def test_large_data(self):
807811
if not (paddle.is_compiled_with_cuda() or is_custom_device()):

0 commit comments

Comments
 (0)