Skip to content

Commit 998403f

Browse files
committed
fix outer accuracy
1 parent 4aee08b commit 998403f

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

python/paddle/tensor/math.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2953,6 +2953,16 @@ def outer(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
29532953
else:
29542954
ny = y.reshape((1, -1))
29552955

2956+
if x.dtype in [
2957+
paddle.int32,
2958+
paddle.int64,
2959+
DataType.INT32,
2960+
DataType.INT64,
2961+
VarDesc.VarType.INT32,
2962+
VarDesc.VarType.INT64,
2963+
]:
2964+
return nx * ny
2965+
29562966
if in_dynamic_mode():
29572967
return _C_ops.matmul(nx, ny, False, False)
29582968

test/legacy_test/test_outer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ def test_multiply_static(self):
8787
res = self._run_static_graph_case(x_data, y_data)
8888
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
8989

90+
# test static computation graph: 3-d int32 big array
91+
x_data = np.random.randint(-80000, 80000, [5, 10, 10]).astype(np.int32)
92+
y_data = np.random.randint(-80000, 80000, [2, 10]).astype(np.int32)
93+
res = self._run_static_graph_case(x_data, y_data)
94+
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
95+
96+
# test static computation graph: 3-d int64 big array
97+
x_data = np.random.randint(-80000, 80000, [5, 10, 10]).astype(np.int64)
98+
y_data = np.random.randint(-80000, 80000, [2, 10]).astype(np.int64)
99+
res = self._run_static_graph_case(x_data, y_data)
100+
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
101+
90102
def test_multiply_dynamic(self):
91103
# test dynamic computation graph: 3-d array
92104
x_data = np.random.rand(5, 10, 10).astype(np.float64)
@@ -138,6 +150,18 @@ def test_multiply_dynamic(self):
138150
res = self._run_dynamic_graph_case(x_data, y_data)
139151
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
140152

153+
# test dynamic computation graph: 3-d int32 big array
154+
x_data = np.random.randint(-80000, 80000, [5, 10, 10]).astype(np.int32)
155+
y_data = np.random.randint(-80000, 80000, [2, 10]).astype(np.int32)
156+
res = self._run_dynamic_graph_case(x_data, y_data)
157+
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
158+
159+
# test dynamic computation graph: 3-d int64 big array
160+
x_data = np.random.randint(-80000, 80000, [5, 10, 10]).astype(np.int64)
161+
y_data = np.random.randint(-80000, 80000, [2, 10]).astype(np.int64)
162+
res = self._run_dynamic_graph_case(x_data, y_data)
163+
np.testing.assert_allclose(res, np.outer(x_data, y_data), rtol=1e-05)
164+
141165

142166
class TestMultiplyError(unittest.TestCase):
143167

0 commit comments

Comments
 (0)