diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 5e799785d204db..e483e5b197b18f 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -527,6 +527,8 @@ def amp_guard(
         raise ValueError("level should be O0, OD, O1 or O2.")

     # check amp_dtype: float16 or bfloat16
+    if isinstance(dtype, paddle.base.core.DataType):
+        dtype = dtype.name
     dtype = dtype.lower()
     if enable:
         if dtype not in ['float16', 'bfloat16']:
diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py
index fd8d986fb27e9a..c942de04c87632 100644
--- a/python/paddle/base/framework.py
+++ b/python/paddle/base/framework.py
@@ -1422,51 +1422,64 @@ def convert_np_dtype_to_proto_type(

     """
     # Convert the data type string to numpy data type.
-    if isinstance(np_dtype, str) and np_dtype == "bfloat16":
-        dtype = np.uint16
-    elif isinstance(np_dtype, str) and np_dtype == "float8_e4m3fn":
-        dtype = 'float8_e4m3fn'
-    elif isinstance(np_dtype, str) and np_dtype == "float8_e5m2":
-        dtype = 'float8_e5m2'
-    else:
-        dtype = np.dtype(np_dtype)
-
-    if dtype == np.float32:
-        return core.VarDesc.VarType.FP32
-    elif dtype == np.float64:
-        return core.VarDesc.VarType.FP64
-    elif dtype == 'float8_e4m3fn':
-        return core.VarDesc.VarType.FP8_E4M3FN
-    elif dtype == 'float8_e5m2':
-        return core.VarDesc.VarType.FP8_E5M2
-    elif dtype == np.float16:
-        return core.VarDesc.VarType.FP16
-    elif dtype == np.int32:
-        return core.VarDesc.VarType.INT32
-    elif dtype == np.int16:
-        return core.VarDesc.VarType.INT16
-    elif dtype == np.int64:
-        return core.VarDesc.VarType.INT64
-    elif dtype == np.bool_:
-        return core.VarDesc.VarType.BOOL
-    elif dtype == np.uint16:
-        # since there is still no support for bfloat16 in NumPy,
-        # uint16 is used for casting bfloat16
-        return core.VarDesc.VarType.BF16
-    elif dtype == np.uint8:
-        return core.VarDesc.VarType.UINT8
-    elif dtype == np.int8:
-        return core.VarDesc.VarType.INT8
-    elif dtype == np.complex64:
-        return core.VarDesc.VarType.COMPLEX64
-    elif dtype == np.complex128:
-        return core.VarDesc.VarType.COMPLEX128
+
+    str_to_var_type = {
+        'float32': core.VarDesc.VarType.FP32,
+        'float64': core.VarDesc.VarType.FP64,
+        'float16': core.VarDesc.VarType.FP16,
+        'int32': core.VarDesc.VarType.INT32,
+        'int16': core.VarDesc.VarType.INT16,
+        'int64': core.VarDesc.VarType.INT64,
+        'bool': core.VarDesc.VarType.BOOL,
+        'uint8': core.VarDesc.VarType.UINT8,
+        'int8': core.VarDesc.VarType.INT8,
+        'complex64': core.VarDesc.VarType.COMPLEX64,
+        'complex128': core.VarDesc.VarType.COMPLEX128,
+        'bfloat16': core.VarDesc.VarType.BF16,
+        'float8_e4m3fn': core.VarDesc.VarType.FP8_E4M3FN,
+        'float8_e5m2': core.VarDesc.VarType.FP8_E5M2,
+    }
+
+    np_dtype_to_var_type = {
+        np.dtype("float32"): core.VarDesc.VarType.FP32,
+        np.dtype("float64"): core.VarDesc.VarType.FP64,
+        np.dtype("float16"): core.VarDesc.VarType.FP16,
+        np.dtype("int32"): core.VarDesc.VarType.INT32,
+        np.dtype("int16"): core.VarDesc.VarType.INT16,
+        np.dtype("int64"): core.VarDesc.VarType.INT64,
+        np.dtype("bool_"): core.VarDesc.VarType.BOOL,
+        np.dtype("uint16"): core.VarDesc.VarType.BF16,
+        np.dtype("uint8"): core.VarDesc.VarType.UINT8,
+        np.dtype("int8"): core.VarDesc.VarType.INT8,
+        np.dtype("complex64"): core.VarDesc.VarType.COMPLEX64,
+        np.dtype("complex128"): core.VarDesc.VarType.COMPLEX128,
+        np.float32: core.VarDesc.VarType.FP32,
+        np.float64: core.VarDesc.VarType.FP64,
+        np.float16: core.VarDesc.VarType.FP16,
+        np.int32: core.VarDesc.VarType.INT32,
+        np.int16: core.VarDesc.VarType.INT16,
+        np.int64: core.VarDesc.VarType.INT64,
+        np.bool_: core.VarDesc.VarType.BOOL,
+        np.uint8: core.VarDesc.VarType.UINT8,
+        np.int8: core.VarDesc.VarType.INT8,
+        np.uint16: core.VarDesc.VarType.BF16,
+        np.complex64: core.VarDesc.VarType.COMPLEX64,
+        np.complex128: core.VarDesc.VarType.COMPLEX128,
+    }
+
+    if isinstance(np_dtype, str):
+        if np_dtype in str_to_var_type:
+            return str_to_var_type[np_dtype]
+    dtype = np.dtype(np_dtype)
+
+    if dtype in np_dtype_to_var_type:
+        return np_dtype_to_var_type[dtype]
     else:
         raise ValueError(f"Not supported numpy dtype {dtype}")


 def convert_np_dtype_to_dtype_(
-    np_dtype: np.dtype | str,
+    np_dtype: np.dtype | str | core.VarDesc.VarType | core.DataType,
 ) -> core.VarDesc.VarType | core.DataType:
     """
     Convert the data type in numpy to the data type in Paddle.
@@ -1480,8 +1493,12 @@ def convert_np_dtype_to_dtype_(

     """
     if use_pir_api():
+        if isinstance(np_dtype, core.DataType):
+            return np_dtype
         return pir.core.convert_np_dtype_to_dtype_(np_dtype)

+    if isinstance(np_dtype, core.VarDesc.VarType):
+        return np_dtype
     return convert_np_dtype_to_proto_type(np_dtype)
diff --git a/python/paddle/incubate/nn/functional/int_bincount.py b/python/paddle/incubate/nn/functional/int_bincount.py
index 9497658786a14c..eae65b25f301d7 100644
--- a/python/paddle/incubate/nn/functional/int_bincount.py
+++ b/python/paddle/incubate/nn/functional/int_bincount.py
@@ -15,7 +15,11 @@
 import paddle
 from paddle import _C_ops
 from paddle.base.data_feeder import convert_dtype
-from paddle.base.framework import in_dynamic_or_pir_mode
+from paddle.base.framework import (
+    convert_np_dtype_to_dtype_,
+    core,
+    in_dynamic_or_pir_mode,
+)
 from paddle.base.layer_helper import LayerHelper
@@ -77,6 +81,9 @@ def math_int_bincount(x, low, high, dtype):

 def int_bincount(x, low, high, dtype=None, name=None):
     if in_dynamic_or_pir_mode():
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
+            dtype = convert_np_dtype_to_dtype_(dtype)
+
         if paddle.is_compiled_with_xpu():
             return math_int_bincount(x, low, high, dtype)
         else:
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 23a4539183ae85..61d68bb4c9f048 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1587,12 +1587,11 @@ def log_softmax(
              [-12.31326640, -1.31326640 , -0.31326640 , -15.31326640],
              [-3.44018970 , -2.44018970 , -1.44018970 , -0.44018970 ]]])
     """
-
     if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
         dtype = convert_np_dtype_to_dtype_(dtype)

     if in_dynamic_or_pir_mode():
-        if dtype is not None:
+        if dtype is not None and x.dtype != dtype:
             x = _C_ops.cast(x, dtype)
         return _C_ops.log_softmax(x, axis)
     else:
diff --git a/python/paddle/pir/core.py b/python/paddle/pir/core.py
index 5debf18d990726..01bfcb983c3750 100644
--- a/python/paddle/pir/core.py
+++ b/python/paddle/pir/core.py
@@ -100,6 +100,26 @@
 }


+str_to_paddle_type = {
+    "float32": DataType.FLOAT32,
+    "float64": DataType.FLOAT64,
+    "float16": DataType.FLOAT16,
+    "int32": DataType.INT32,
+    "int16": DataType.INT16,
+    "int64": DataType.INT64,
+    "bool": DataType.BOOL,
+    "bool_": DataType.BOOL,
+    "uint16": DataType.BFLOAT16,
+    "uint8": DataType.UINT8,
+    "int8": DataType.INT8,
+    "complex64": DataType.COMPLEX64,
+    "complex128": DataType.COMPLEX128,
+    "bfloat16": DataType.BFLOAT16,
+    "float8_e4m3fn": DataType.FLOAT8_E4M3FN,
+    "float8_e5m2": DataType.FLOAT8_E5M2,
+}
+
+
 def convert_np_dtype_to_dtype_(np_dtype) -> DataType:
     """
     Convert the data type in numpy to the data type in Paddle.
@@ -113,17 +133,11 @@ def convert_np_dtype_to_dtype_(np_dtype) -> DataType:

     """
     # Convert the data type string to numpy data type.
-    if isinstance(np_dtype, str) and np_dtype == "bfloat16":
-        # since there is still no support for bfloat16 in NumPy,
-        # uint16 is used for casting bfloat16
-        dtype = np.dtype("uint16")
-    elif isinstance(np_dtype, str) and np_dtype == "float8_e4m3fn":
-        dtype = 'float8_e4m3fn'
-    elif isinstance(np_dtype, str) and np_dtype == "float8_e5m2":
-        dtype = 'float8_e5m2'
-    else:
-        dtype = np.dtype(np_dtype)
-
+    if isinstance(np_dtype, str):
+        key = np_dtype.lower().strip()
+        if key in str_to_paddle_type:
+            return str_to_paddle_type[key]
+    dtype = np.dtype(np_dtype)
     if dtype in np_type_to_paddle_type:
         return np_type_to_paddle_type[dtype]
     else:
diff --git a/python/paddle/sparse/unary.py b/python/paddle/sparse/unary.py
index 2e1ff02ef0aea0..7d4eb96bda9c73 100644
--- a/python/paddle/sparse/unary.py
+++ b/python/paddle/sparse/unary.py
@@ -623,9 +623,13 @@ def cast(
     assert in_dynamic_or_pir_mode(), (
         "Currently, Sparse API only support dynamic mode or pir mode."
     )
-    if index_dtype and not isinstance(index_dtype, core.VarDesc.VarType):
+    if index_dtype and not isinstance(
+        index_dtype, (core.VarDesc.VarType, core.DataType)
+    ):
         index_dtype = convert_np_dtype_to_dtype_(index_dtype)
-    if value_dtype and not isinstance(value_dtype, core.VarDesc.VarType):
+    if value_dtype and not isinstance(
+        value_dtype, (core.VarDesc.VarType, core.DataType)
+    ):
         value_dtype = convert_np_dtype_to_dtype_(value_dtype)
     return _C_ops.sparse_cast(x, index_dtype, value_dtype)
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index ddd11fede69694..1186a05d9e772f 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -3519,9 +3519,8 @@ def tril_indices(
             [[1, 2, 2, 3, 3, 3],
              [0, 0, 1, 0, 1, 2]])
     """
-    if not isinstance(dtype, core.VarDesc.VarType):
+    if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
         dtype = convert_np_dtype_to_dtype_(dtype)
-
     if not isinstance(row, int) or row < 0:
         raise TypeError("row should be a non-negative int")
@@ -3600,7 +3599,8 @@ def triu_indices(
             [[0 0 0 0 1 1 1 1 2 2 2 3 3]
              [0 1 2 3 0 1 2 3 1 2 3 2 3]]
     """
-    if not isinstance(dtype, core.VarDesc.VarType):
+
+    if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
         dtype = convert_np_dtype_to_dtype_(dtype)

     if not isinstance(row, int) or row < 0:
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 909cba7ae2bca6..355d40d8911113 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -4283,8 +4283,11 @@ def cumsum_(
         flatten = True
     else:
         flatten = False
-    if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
-        x = cast_(x, dtype)
+    if dtype is not None:
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
+            dtype = convert_np_dtype_to_dtype_(dtype)
+        if x.dtype != dtype:
+            x = cast_(x, dtype)

     if in_dynamic_mode():
         if axis is None:
@@ -4641,8 +4644,11 @@ def cumprod(
         dim = -1
         x = x.flatten(0, len(x.shape) - 1)

-    if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
-        x = cast(x, dtype)
+    if dtype is not None:
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
+            dtype = convert_np_dtype_to_dtype_(dtype)
+        if x.dtype != dtype:
+            x = cast(x, dtype)

     if in_dynamic_or_pir_mode():
         return _C_ops.cumprod(x, dim, False, False)
@@ -4689,9 +4695,13 @@ def cumprod_(
     if dim is None:
         dim = -1
         x = _C_ops.flatten_(x, 0, len(x.shape) - 1)
-
-    if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
-        x = cast_(x, dtype)
+    if dtype is None:
+        dtype = x.dtype
+    else:
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
+            dtype = convert_np_dtype_to_dtype_(dtype)
+        if x.dtype != dtype:
+            x = cast_(x, dtype)

     if in_dynamic_mode():
         return _C_ops.cumprod_(x, dim, False, False)
@@ -4782,7 +4792,16 @@ def prod(
         check_dtype(
             dtype,
             'dtype',
-            ['float32', 'float64', 'int32', 'int64', "float16", "uint16"],
+            [
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                "float16",
+                "uint16",
+                "complex64",
+                "complex128",
+            ],
             'prod',
         )
         if x.dtype != convert_np_dtype_to_dtype_(dtype):
diff --git a/test/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/test/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
index f785a5878a3215..43bb4ef8aa1d24 100644
--- a/test/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
+++ b/test/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
@@ -61,12 +61,15 @@ def amp_guard_white_op(self):
             data = paddle.to_tensor(data)
             with paddle.amp.amp_guard(True):
                 out_fp16 = conv2d(data)
+            with paddle.amp.amp_guard(True, dtype=paddle.float16):
+                out_fp16_ = conv2d(data)

             with paddle.amp.amp_guard(False):
                 out_fp32 = conv2d(data)

         self.assertTrue(data.dtype == paddle.float32)
         self.assertTrue(out_fp16.dtype == paddle.float16)
+        self.assertTrue(out_fp16_.dtype == paddle.float16)
         self.assertTrue(out_fp32.dtype == paddle.float32)

     def test_amp_guard_white_op(self):
diff --git a/test/legacy_test/test_incubate_int_bincount.py b/test/legacy_test/test_incubate_int_bincount.py
index 46f43cf791c35b..1d3cf9f69f3ba3 100644
--- a/test/legacy_test/test_incubate_int_bincount.py
+++ b/test/legacy_test/test_incubate_int_bincount.py
@@ -30,6 +30,12 @@ def test_basic(self):
         expected = np.array([2, 2, 2, 0])
         np.testing.assert_array_equal(out.numpy(), expected)

+    def test_basic_2(self):
+        x = paddle.to_tensor([1, 2, 3, 1, 2, 3], dtype=paddle.int32)
+        out = int_bincount(x, low=1, high=4, dtype="int32")
+        expected = np.array([2, 2, 2, 0])
+        np.testing.assert_array_equal(out.numpy(), expected)
+
     def test_empty_input(self):
         x = paddle.to_tensor([], dtype=paddle.int32)
         out = int_bincount(x, low=0, high=10, dtype=paddle.int32)
diff --git a/test/legacy_test/test_variable.py b/test/legacy_test/test_variable.py
index e93e1ebdc823d4..aca3dc0b72cfe0 100644
--- a/test/legacy_test/test_variable.py
+++ b/test/legacy_test/test_variable.py
@@ -45,6 +45,15 @@ def test_np_dtype_convert(self):
         self.assertEqual(paddle.bool, convert("bool"))
         self.assertEqual(paddle.int8, convert("int8"))
         self.assertEqual(paddle.uint8, convert("uint8"))
+        self.assertEqual(paddle.float32, convert(paddle.float32))
+        self.assertEqual(paddle.float16, convert(paddle.float16))
+        self.assertEqual(paddle.float64, convert(paddle.float64))
+        self.assertEqual(paddle.int32, convert(paddle.int32))
+        self.assertEqual(paddle.int16, convert(paddle.int16))
+        self.assertEqual(paddle.int64, convert(paddle.int64))
+        self.assertEqual(paddle.bool, convert(paddle.bool))
+        self.assertEqual(paddle.int8, convert(paddle.int8))
+        self.assertEqual(paddle.uint8, convert(paddle.uint8))

     def test_var(self):
         b = default_main_program().current_block()
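
Below is a minimal usage sketch, not part of the patch itself, showing the dtype handling these changes enable. It assumes a Paddle build with this diff applied; the tensors, shapes, and values are illustrative only.

    import paddle
    import paddle.nn.functional as F
    from paddle.base.framework import convert_np_dtype_to_dtype_

    # Paddle dtype objects now pass through the converter unchanged
    # instead of being rejected (see the test_variable.py additions).
    assert convert_np_dtype_to_dtype_(paddle.float32) == paddle.float32

    # amp_guard accepts a Paddle dtype as well as the string form.
    with paddle.amp.amp_guard(True, dtype=paddle.float16):
        pass

    # dtype arguments of the touched APIs accept either spelling.
    x = paddle.rand([3, 4])
    out = F.log_softmax(x, dtype=paddle.float64)
    idx = paddle.tril_indices(4, 4, 0, dtype=paddle.int64)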