Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4722,8 +4722,22 @@ def cumprod(

"""

if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)
if dtype is not None:
target_dtype = convert_np_dtype_to_dtype_(dtype)
if x.dtype != target_dtype:
x = cast(x, target_dtype)
else:
converted_x_dtype = convert_dtype(x.dtype)
# use the default platform integer when integer dtype with a precision less than that of the default platform integer
if converted_x_dtype in {
"bool",
"uint16",
"int8",
"int16",
"int32",
"uint8",
}:
Copy link
Contributor

@lshpku lshpku May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码涉及几次dict查找,比较耗时,过不了API性能监控(这个CI查得很严,python性能也要管),建议你可以在文件的适当地方定义一个纯DataType的set,像这样:

_supported_int_like_types = {
    DataType.BOOL,
    DataType.INT8,
    DataType.INT16,
    DataType.INT32,
    DataType.INT64,
    DataType.UINT8
}

然后在这里直接判断:

if x.dtype in _supported_int_like_types:
    x = cast(x, "int64")

当然能不能过CI还要你自己试一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!我试一下~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样改了发现还是不能过耶,其实我们也可以通过手动设置 dtype 来达到 paddle 和 torch api 之间的切换,这样看起来也不算是精度问题吧,属于 api 设计的考量?目前可能去 paddleAPITest 仓库进行修改更加合适~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

提交了 PaddleAPITest 的修改 pr : PFCCLab/PaddleAPITest#212

x = cast(x, "int64")

if in_dynamic_or_pir_mode():
return _C_ops.cumprod(x, dim, False, False)
Expand Down
34 changes: 30 additions & 4 deletions test/legacy_test/test_cumprod_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def init_dtype(self):

def setUp(self):
paddle.enable_static()
self.target_dtype = None
self.init_dtype()
self.x = (np.random.rand(2, 3, 10, 10) + 0.5).astype(self.dtype)
self.place = []
Expand All @@ -281,10 +282,10 @@ def test_static_api(self):
def run(place):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('X', self.shape, dtype=self.dtype)
out = paddle.cumprod(x, -2)
out = paddle.cumprod(x, -2, self.target_dtype)
exe = paddle.static.Executor(place)
res = exe.run(feed={'X': self.x}, fetch_list=[out])
out_ref = np.cumprod(self.x, -2)
out_ref = np.cumprod(self.x, -2, self.target_dtype)

for r in res:
np.testing.assert_allclose(out_ref, r, rtol=1e-05)
Expand All @@ -297,15 +298,40 @@ def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
out = paddle.cumprod(x, 1)
out_ref = np.cumprod(self.x, 1)
out = paddle.cumprod(x, 1, self.target_dtype)
out_ref = np.cumprod(self.x, 1, self.target_dtype)
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
paddle.enable_static()

for place in self.place:
run(place)


class TestCumprodAPICase1(TestCumprodAPI):
def init_dtype(self):
self.dtype = 'int32'
self.shape = [2, 3, 10, 10]


class TestCumprodAPICase2(TestCumprodAPI):
def init_dtype(self):
self.dtype = 'bool'
self.shape = [2, 3, 10, 10]


class TestCumprodAPICase3(TestCumprodAPI):
def init_dtype(self):
self.dtype = 'int64'
self.shape = [2, 3, 10, 10]


class TestCumprodAPICase4(TestCumprodAPI):
def init_dtype(self):
self.dtype = 'float32'
self.shape = [2, 3, 10, 10]
self.target_dtype = 'float64'


# test function.
class TestCumprodReverse(TestCumprod):
def init_dtype(self):
Expand Down
Loading