
Commit eb2a405

DrRyanHuang authored and Luckycheng222 committed
[BUG Fix] Fix cumsum dtype bug (PaddlePaddle#74830)
1 parent 9e4df03 commit eb2a405

File tree

2 files changed: +29 −26 lines changed


python/paddle/tensor/math.py

Lines changed: 13 additions & 9 deletions
@@ -4230,15 +4230,19 @@ def cumsum(
         flatten = True
     else:
         flatten = False
-    if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
-        x = cast(x, dtype)
-    elif isinstance(x, paddle.Tensor) and x.dtype in [
-        paddle.uint8,
-        paddle.int8,
-        paddle.int16,
-        paddle.int32,
-    ]:
-        x = cast(x, "int64")
+
+    if dtype is None:
+        if x.dtype in [
+            paddle.uint8,
+            paddle.int8,
+            paddle.int16,
+            paddle.int32,
+        ]:
+            x = cast(x, "int64")
+    else:
+        dtype = convert_np_dtype_to_dtype_(dtype)
+        if x.dtype != dtype:
+            x = cast(x, dtype)

     if in_dynamic_or_pir_mode():
         if axis is None:
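What this hunk changes: previously, when the caller passed a dtype equal to x.dtype for a small integer tensor (uint8/int8/int16/int32), the first branch did not fire and the elif silently widened the result to int64, overriding the explicit dtype. The new code first checks whether a dtype was given: if not, small integer inputs are widened to int64 as before; if it was, the input is cast to exactly that dtype. A minimal sketch of the resulting behavior, assuming a Paddle build that includes this commit:

import paddle

x = paddle.arange(6, dtype='int32')

# Explicit dtype is now honored even when it matches x.dtype; before the
# fix this call fell through to the widening branch and returned int64.
y = paddle.cumsum(x, dtype='int32')
assert y.dtype == paddle.int32

# Without a dtype argument, small integer types are still widened to
# int64 to avoid overflow during accumulation.
z = paddle.cumsum(x)
assert z.dtype == paddle.int64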

test/legacy_test/test_cumsum_op.py

Lines changed: 16 additions & 17 deletions
@@ -28,6 +28,7 @@
 import paddle.inference as paddle_infer
 from paddle import base
 from paddle.base import core
+from paddle.framework import convert_np_dtype_to_dtype_


 class TestCumsumOp(unittest.TestCase):
@@ -273,6 +274,13 @@ def run_cases(self):
         z = np.cumsum(data_np, axis=-2)
         np.testing.assert_array_equal(z, y.numpy())

+        # test data type
+        data_np = np.arange(12).reshape(3, 4).astype(np.int16)
+        data = paddle.to_tensor(data_np)
+        y = paddle.cumsum(data, axis=0, dtype='int32')
+        z = np.cumsum(data_np, axis=0, dtype="int32")
+        np.testing.assert_equal(convert_np_dtype_to_dtype_(z.dtype), y.dtype)
+
     def run_static_uint8(self, use_gpu=False):
         with paddle.static.program_guard(paddle.static.Program()):
             data_np = np.random.random((100, 100)).astype(np.uint8)
@@ -281,6 +289,7 @@ def run_static_uint8(self, use_gpu=False):
             y2 = paddle.cumsum(x, axis=0)
             y3 = paddle.cumsum(x, axis=-1)
             y4 = paddle.cumsum(x, axis=-2)
+            y5 = paddle.cumsum(x, axis=-1, dtype='int32')
             place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
             exe = base.Executor(place)
             exe.run(paddle.static.default_startup_program())
@@ -291,6 +300,7 @@ def run_static_uint8(self, use_gpu=False):
                     y2,
                     y3,
                     y4,
+                    y5,
                 ],
             )
             z = np.cumsum(data_np)
@@ -301,6 +311,8 @@ def run_static_uint8(self, use_gpu=False):
             np.testing.assert_allclose(z, out[2], rtol=1e-05)
             z = np.cumsum(data_np, axis=-2)
             np.testing.assert_allclose(z, out[3], rtol=1e-05)
+            z = np.cumsum(data_np, axis=-1, dtype="int32")
+            np.testing.assert_equal(z.dtype, out[4].dtype)

     def run_static_int8(self, use_gpu=False):
         with paddle.static.program_guard(paddle.static.Program()):
@@ -310,7 +322,7 @@ def run_static_int8(self, use_gpu=False):
             y2 = paddle.cumsum(x, axis=0)
             y3 = paddle.cumsum(x, axis=-1)
             y4 = paddle.cumsum(x, axis=-2)
-
+            y5 = paddle.cumsum(x, axis=-1, dtype='int16')
             place = base.CUDAPlace(0) if use_gpu else base.CPUPlace()
             exe = base.Executor(place)
             exe.run(paddle.static.default_startup_program())
@@ -321,6 +333,7 @@ def run_static_int8(self, use_gpu=False):
                     y2,
                     y3,
                     y4,
+                    y5,
                 ],
             )
             z = np.cumsum(data_np)
@@ -331,6 +344,8 @@ def run_static_int8(self, use_gpu=False):
             np.testing.assert_allclose(z, out[2], rtol=1e-05)
             z = np.cumsum(data_np, axis=-2)
             np.testing.assert_allclose(z, out[3], rtol=1e-05)
+            z = np.cumsum(data_np, axis=-1, dtype="int16")
+            np.testing.assert_equal(z.dtype, out[4].dtype)

     def run_static_int16(self, use_gpu=False):
         with paddle.static.program_guard(paddle.static.Program()):
@@ -883,22 +898,6 @@ def test_check_grad(self):
 create_test_bf16_class(TestSumOpReverseExclusive)


-class BadInputTest(unittest.TestCase):
-    def test_error(self):
-        paddle.enable_static()
-        with paddle.static.program_guard(
-            paddle.static.Program(), paddle.static.Program()
-        ):
-
-            def test_bad_x():
-                data = [1, 2, 4]
-                result = paddle.cumsum(data, axis=0)
-
-            with self.assertRaises(TypeError):
-                test_bad_x()
-            paddle.disable_static()
-
-
 class TestTensorAxis(unittest.TestCase):
     def setUp(self):
         paddle.seed(2022)
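The new tests pin the dtype behavior down in both dynamic and static graph mode. A standalone sketch of the dynamic-graph check that was added (not part of the commit, just the same pattern outside the test class):

import numpy as np
import paddle

data_np = np.arange(12).reshape(3, 4).astype(np.int16)
data = paddle.to_tensor(data_np)

# cumsum over an int16 tensor with an explicit int32 dtype should match
# NumPy's values and come back as int32 rather than being widened to int64.
y = paddle.cumsum(data, axis=0, dtype='int32')
z = np.cumsum(data_np, axis=0, dtype="int32")

np.testing.assert_array_equal(z, y.numpy())
assert y.dtype == paddle.int32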
