Skip to content

Commit 7960404

Browse files
committed
apply review
1 parent a6a2f67 commit 7960404

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

python/paddle/tensor/math.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@
115115
VarDesc.VarType.FP64,
116116
]
117117

118+
_supported_int_like_types_lt_int64 = {
119+
DataType.BOOL,
120+
DataType.INT8,
121+
DataType.INT16,
122+
DataType.INT32,
123+
DataType.UINT8,
124+
}
125+
118126

119127
def _get_reduce_axis(axis, x):
120128
"""
@@ -4727,16 +4735,8 @@ def cumprod(
47274735
if x.dtype != target_dtype:
47284736
x = cast(x, target_dtype)
47294737
else:
4730-
converted_x_dtype = convert_dtype(x.dtype)
4731-
# use the default platform integer when integer dtype with a precision less than that of the default platform integer
4732-
if converted_x_dtype in {
4733-
"bool",
4734-
"uint16",
4735-
"int8",
4736-
"int16",
4737-
"int32",
4738-
"uint8",
4739-
}:
4738+
if x.dtype in _supported_int_like_types_lt_int64:
4739+
# use the default platform integer when integer dtype with a precision less than that of the default platform integer
47404740
x = cast(x, "int64")
47414741

47424742
if in_dynamic_or_pir_mode():

0 commit comments

Comments
 (0)