File tree Expand file tree Collapse file tree 1 file changed +10
-10
lines changed Expand file tree Collapse file tree 1 file changed +10
-10
lines changed Original file line number Diff line number Diff line change 115115 VarDesc .VarType .FP64 ,
116116]
117117
118+ _supported_int_like_types_lt_int64 = {
119+ DataType .BOOL ,
120+ DataType .INT8 ,
121+ DataType .INT16 ,
122+ DataType .INT32 ,
123+ DataType .UINT8 ,
124+ }
125+
118126
119127def _get_reduce_axis (axis , x ):
120128 """
@@ -4727,16 +4735,8 @@ def cumprod(
47274735 if x .dtype != target_dtype :
47284736 x = cast (x , target_dtype )
47294737 else :
4730- converted_x_dtype = convert_dtype (x .dtype )
4731- # use the default platform integer when integer dtype with a precision less than that of the default platform integer
4732- if converted_x_dtype in {
4733- "bool" ,
4734- "uint16" ,
4735- "int8" ,
4736- "int16" ,
4737- "int32" ,
4738- "uint8" ,
4739- }:
4738+ if x .dtype in _supported_int_like_types_lt_int64 :
4739+ # use the default platform integer when integer dtype with a precision less than that of the default platform integer
47404740 x = cast (x , "int64" )
47414741
47424742 if in_dynamic_or_pir_mode ():
You can’t perform that action at this time.
0 commit comments