2 changes: 2 additions & 0 deletions python/paddle/amp/auto_cast.py
@@ -527,6 +527,8 @@ def amp_guard(
raise ValueError("level should be O0, OD, O1 or O2.")

# check amp_dtype: float16 or bfloat16
if isinstance(dtype, paddle.base.core.DataType):
dtype = dtype.name
dtype = dtype.lower()
if enable:
if dtype not in ['float16', 'bfloat16']:
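A minimal usage sketch of what this hunk enables (hypothetical snippet, assuming an AMP-capable build): amp_guard now accepts a paddle dtype object such as paddle.float16 in addition to the string form, since core.DataType values are normalized to their lowercase names before validation.

import paddle

conv2d = paddle.nn.Conv2D(3, 8, 3)
data = paddle.rand([1, 3, 32, 32])
# the dtype-object and string forms are now equivalent
with paddle.amp.amp_guard(True, dtype=paddle.float16):
    out = conv2d(data)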
97 changes: 57 additions & 40 deletions python/paddle/base/framework.py
@@ -1422,51 +1422,64 @@ def convert_np_dtype_to_proto_type(
"""

# Convert the data type string to numpy data type.
if isinstance(np_dtype, str) and np_dtype == "bfloat16":
dtype = np.uint16
elif isinstance(np_dtype, str) and np_dtype == "float8_e4m3fn":
dtype = 'float8_e4m3fn'
elif isinstance(np_dtype, str) and np_dtype == "float8_e5m2":
dtype = 'float8_e5m2'
else:
dtype = np.dtype(np_dtype)

if dtype == np.float32:
return core.VarDesc.VarType.FP32
elif dtype == np.float64:
return core.VarDesc.VarType.FP64
elif dtype == 'float8_e4m3fn':
return core.VarDesc.VarType.FP8_E4M3FN
elif dtype == 'float8_e5m2':
return core.VarDesc.VarType.FP8_E5M2
elif dtype == np.float16:
return core.VarDesc.VarType.FP16
elif dtype == np.int32:
return core.VarDesc.VarType.INT32
elif dtype == np.int16:
return core.VarDesc.VarType.INT16
elif dtype == np.int64:
return core.VarDesc.VarType.INT64
elif dtype == np.bool_:
return core.VarDesc.VarType.BOOL
elif dtype == np.uint16:
# since there is still no support for bfloat16 in NumPy,
# uint16 is used for casting bfloat16
return core.VarDesc.VarType.BF16
elif dtype == np.uint8:
return core.VarDesc.VarType.UINT8
elif dtype == np.int8:
return core.VarDesc.VarType.INT8
elif dtype == np.complex64:
return core.VarDesc.VarType.COMPLEX64
elif dtype == np.complex128:
return core.VarDesc.VarType.COMPLEX128

str_to_var_type = {
'float32': core.VarDesc.VarType.FP32,
'float64': core.VarDesc.VarType.FP64,
'float16': core.VarDesc.VarType.FP16,
'int32': core.VarDesc.VarType.INT32,
'int16': core.VarDesc.VarType.INT16,
'int64': core.VarDesc.VarType.INT64,
'bool': core.VarDesc.VarType.BOOL,
'uint8': core.VarDesc.VarType.UINT8,
'int8': core.VarDesc.VarType.INT8,
'complex64': core.VarDesc.VarType.COMPLEX64,
'complex128': core.VarDesc.VarType.COMPLEX128,
'bfloat16': core.VarDesc.VarType.BF16,
'float8_e4m3fn': core.VarDesc.VarType.FP8_E4M3FN,
'float8_e5m2': core.VarDesc.VarType.FP8_E5M2,
}

np_dtype_to_var_type = {
np.dtype("float32"): core.VarDesc.VarType.FP32,
np.dtype("float64"): core.VarDesc.VarType.FP64,
np.dtype("float16"): core.VarDesc.VarType.FP16,
np.dtype("int32"): core.VarDesc.VarType.INT32,
np.dtype("int16"): core.VarDesc.VarType.INT16,
np.dtype("int64"): core.VarDesc.VarType.INT64,
np.dtype("bool_"): core.VarDesc.VarType.BOOL,
np.dtype("uint16"): core.VarDesc.VarType.BF16,
np.dtype("uint8"): core.VarDesc.VarType.UINT8,
np.dtype("int8"): core.VarDesc.VarType.INT8,
np.dtype("complex64"): core.VarDesc.VarType.COMPLEX64,
np.dtype("complex128"): core.VarDesc.VarType.COMPLEX128,
np.float32: core.VarDesc.VarType.FP32,
np.float64: core.VarDesc.VarType.FP64,
np.float16: core.VarDesc.VarType.FP16,
np.int32: core.VarDesc.VarType.INT32,
np.int16: core.VarDesc.VarType.INT16,
np.int64: core.VarDesc.VarType.INT64,
np.bool_: core.VarDesc.VarType.BOOL,
np.uint8: core.VarDesc.VarType.UINT8,
np.int8: core.VarDesc.VarType.INT8,
np.uint16: core.VarDesc.VarType.BF16,
np.complex64: core.VarDesc.VarType.COMPLEX64,
np.complex128: core.VarDesc.VarType.COMPLEX128,
}

if isinstance(np_dtype, str):
if np_dtype in str_to_var_type:
return str_to_var_type[np_dtype]
dtype = np.dtype(np_dtype)

if dtype in np_dtype_to_var_type:
return np_dtype_to_var_type[dtype]
else:
raise ValueError(f"Not supported numpy dtype {dtype}")


def convert_np_dtype_to_dtype_(
np_dtype: np.dtype | str,
np_dtype: np.dtype | str | core.VarDesc.VarType | core.DataType,
) -> core.VarDesc.VarType | core.DataType:
"""
Convert the data type in numpy to the data type in Paddle.
@@ -1480,8 +1493,12 @@ def convert_np_dtype_to_dtype_(

"""
if use_pir_api():
if isinstance(np_dtype, core.DataType):
Contributor:

Is this check still needed here? It's generally already done by the caller, e.g.:

if isinstance(np_dtype, (core.DataType, core.VarDesc.VarType)):
      ...

return np_dtype
return pir.core.convert_np_dtype_to_dtype_(np_dtype)

if isinstance(np_dtype, core.VarDesc.VarType):
return np_dtype
return convert_np_dtype_to_proto_type(np_dtype)
Contributor:

The body of convert_np_dtype_to_proto_type can be replaced with a map-based implementation to improve performance:

[screenshot: infoflow 2025-08-25 15-39-07]

Strings can then also be mapped directly to core.VarDesc.VarType, rather than first converting to np.dtype and then to core.VarDesc.VarType.
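A quick sanity sketch of the widened contract (mirroring the tests added to test_variable.py below): strings, NumPy dtypes, and paddle dtypes all resolve through the new lookup tables, and paddle dtypes pass through unchanged.

import numpy as np
import paddle
from paddle.base.framework import convert_np_dtype_to_dtype_ as convert

# string, np.dtype, and paddle dtype inputs all agree
assert convert("float32") == convert(np.float32) == convert(paddle.float32)
# paddle dtypes are returned as-is
assert convert(paddle.bfloat16) == convert("bfloat16")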
9 changes: 8 additions & 1 deletion python/paddle/incubate/nn/functional/int_bincount.py
@@ -15,7 +15,11 @@
import paddle
from paddle import _C_ops
from paddle.base.data_feeder import convert_dtype
from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.base.framework import (
convert_np_dtype_to_dtype_,
core,
in_dynamic_or_pir_mode,
)
from paddle.base.layer_helper import LayerHelper


@@ -77,6 +81,9 @@ def math_int_bincount(x, low, high, dtype):

def int_bincount(x, low, high, dtype=None, name=None):
if in_dynamic_or_pir_mode():
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)

if paddle.is_compiled_with_xpu():
return math_int_bincount(x, low, high, dtype)
else:
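A minimal sketch of what the normalization buys (it mirrors test_basic_2 below): a string dtype now behaves the same as the corresponding paddle dtype.

import paddle
from paddle.incubate.nn.functional import int_bincount

x = paddle.to_tensor([1, 2, 3, 1, 2, 3], dtype=paddle.int32)
# "int32" is converted internally via convert_np_dtype_to_dtype_
out = int_bincount(x, low=1, high=4, dtype="int32")  # same as dtype=paddle.int32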
3 changes: 1 addition & 2 deletions python/paddle/nn/functional/activation.py
@@ -1587,12 +1587,11 @@ def log_softmax(
[-12.31326640, -1.31326640 , -0.31326640 , -15.31326640],
[-3.44018970 , -2.44018970 , -1.44018970 , -0.44018970 ]]])
"""

if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_or_pir_mode():
if dtype is not None:
if dtype is not None and x.dtype != dtype:
Contributor:

There is no case where dtype is None at this point.

x = _C_ops.cast(x, dtype)
return _C_ops.log_softmax(x, axis)
else:
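A short sketch of the resulting fast path in dynamic mode: the cast is skipped when dtype already matches x.dtype and applied only when it differs.

import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 3], dtype="float32")
y = F.log_softmax(x, dtype="float32")  # dtype matches x.dtype: no cast inserted
z = F.log_softmax(x, dtype="float64")  # dtype differs: x is cast first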
36 changes: 25 additions & 11 deletions python/paddle/pir/core.py
@@ -100,6 +100,26 @@
}


str_to_paddle_type = {
"float32": DataType.FLOAT32,
"float64": DataType.FLOAT64,
"float16": DataType.FLOAT16,
"int32": DataType.INT32,
"int16": DataType.INT16,
"int64": DataType.INT64,
"bool": DataType.BOOL,
"bool_": DataType.BOOL,
"uint16": DataType.BFLOAT16,
"uint8": DataType.UINT8,
"int8": DataType.INT8,
"complex64": DataType.COMPLEX64,
"complex128": DataType.COMPLEX128,
"bfloat16": DataType.BFLOAT16,
"float8_e4m3fn": DataType.FLOAT8_E4M3FN,
"float8_e5m2": DataType.FLOAT8_E5M2,
}


def convert_np_dtype_to_dtype_(np_dtype) -> DataType:
"""
Convert the data type in numpy to the data type in Paddle.
@@ -113,17 +133,11 @@ def convert_np_dtype_to_dtype_(np_dtype) -> DataType:

"""
# Convert the data type string to numpy data type.
if isinstance(np_dtype, str) and np_dtype == "bfloat16":
# since there is still no support for bfloat16 in NumPy,
# uint16 is used for casting bfloat16
dtype = np.dtype("uint16")
elif isinstance(np_dtype, str) and np_dtype == "float8_e4m3fn":
dtype = 'float8_e4m3fn'
elif isinstance(np_dtype, str) and np_dtype == "float8_e5m2":
dtype = 'float8_e5m2'
else:
dtype = np.dtype(np_dtype)

if isinstance(np_dtype, str):
key = np_dtype.lower().strip()
if key in str_to_paddle_type:
return str_to_paddle_type[key]
dtype = np.dtype(np_dtype)
if dtype in np_type_to_paddle_type:
return np_type_to_paddle_type[dtype]
else:
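A minimal sketch of the string fast path: names that NumPy cannot represent (bfloat16 and the float8 variants) now resolve directly through str_to_paddle_type, with no np.dtype round trip.

from paddle.base import core
from paddle.pir.core import convert_np_dtype_to_dtype_

assert convert_np_dtype_to_dtype_("bfloat16") == core.DataType.BFLOAT16
assert convert_np_dtype_to_dtype_("float8_e5m2") == core.DataType.FLOAT8_E5M2
# "uint16" still aliases bfloat16, matching the old numpy-based path
assert convert_np_dtype_to_dtype_("uint16") == core.DataType.BFLOAT16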
8 changes: 6 additions & 2 deletions python/paddle/sparse/unary.py
@@ -623,9 +623,13 @@ def cast(
assert in_dynamic_or_pir_mode(), (
"Currently, Sparse API only support dynamic mode or pir mode."
)
if index_dtype and not isinstance(index_dtype, core.VarDesc.VarType):
if index_dtype and not isinstance(
index_dtype, (core.VarDesc.VarType, core.DataType)
):
index_dtype = convert_np_dtype_to_dtype_(index_dtype)
if value_dtype and not isinstance(value_dtype, core.VarDesc.VarType):
if value_dtype and not isinstance(
value_dtype, (core.VarDesc.VarType, core.DataType)
):
value_dtype = convert_np_dtype_to_dtype_(value_dtype)
return _C_ops.sparse_cast(x, index_dtype, value_dtype)

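A usage sketch under dynamic mode: both dtype arguments may now be paddle dtype objects (core.DataType) as well as strings or VarType values.

import paddle

x = paddle.to_tensor([[0.0, 1.5], [2.5, 0.0]]).to_sparse_coo(2)
y = paddle.sparse.cast(x, index_dtype=paddle.int32, value_dtype=paddle.float64)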
6 changes: 3 additions & 3 deletions python/paddle/tensor/creation.py
@@ -3519,9 +3519,8 @@ def tril_indices(
[[1, 2, 2, 3, 3, 3],
[0, 0, 1, 0, 1, 2]])
"""
if not isinstance(dtype, core.VarDesc.VarType):
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)

if not isinstance(row, int) or row < 0:
raise TypeError("row should be a non-negative int")

@@ -3600,7 +3599,8 @@ def triu_indices(
[[0 0 0 0 1 1 1 1 2 2 2 3 3]
[0 1 2 3 0 1 2 3 1 2 3 2 3]]
"""
if not isinstance(dtype, core.VarDesc.VarType):

if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)

if not isinstance(row, int) or row < 0:
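A minimal sketch of the widened check: paddle dtype objects are now accepted directly, alongside the string and VarType forms.

import paddle

lower = paddle.tril_indices(4, 4, 0, dtype=paddle.int64)  # DataType accepted
upper = paddle.triu_indices(4, 4, 0, dtype="int64")       # strings still work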
35 changes: 27 additions & 8 deletions python/paddle/tensor/math.py
@@ -4283,8 +4283,11 @@ def cumsum_(
flatten = True
else:
flatten = False
if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast_(x, dtype)
if dtype is not None:
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if x.dtype != dtype:
x = cast_(x, dtype)

if in_dynamic_mode():
if axis is None:
@@ -4641,8 +4644,11 @@ def cumprod(
dim = -1
x = x.flatten(0, len(x.shape) - 1)

if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast(x, dtype)
if dtype is not None:
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if x.dtype != dtype:
x = cast(x, dtype)

if in_dynamic_or_pir_mode():
return _C_ops.cumprod(x, dim, False, False)
@@ -4689,9 +4695,13 @@ def cumprod_(
if dim is None:
dim = -1
x = _C_ops.flatten_(x, 0, len(x.shape) - 1)

if dtype is not None and x.dtype != convert_np_dtype_to_dtype_(dtype):
x = cast_(x, dtype)
if dtype is None:
dtype = x.dtype
else:
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if x.dtype != dtype:
x = cast_(x, dtype)

if in_dynamic_mode():
return _C_ops.cumprod_(x, dim, False, False)
@@ -4782,7 +4792,16 @@ def prod(
check_dtype(
dtype,
'dtype',
['float32', 'float64', 'int32', 'int64', "float16", "uint16"],
[
'float32',
'float64',
'int32',
'int64',
"float16",
"uint16",
"complex64",
"complex128",
],
'prod',
)
if x.dtype != convert_np_dtype_to_dtype_(dtype):
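A short sketch of the normalized handling, assuming dynamic mode: dtype is converted to a paddle dtype once, and the cast is applied only when x does not already have that dtype.

import paddle

x = paddle.arange(1, 5, dtype="float32")
a = paddle.cumsum(x, dtype="float64")           # dtype differs: cast first
b = paddle.cumprod(x, dim=0, dtype="float32")   # dtype matches: no cast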
@@ -61,12 +61,15 @@ def amp_guard_white_op(self):
data = paddle.to_tensor(data)
with paddle.amp.amp_guard(True):
out_fp16 = conv2d(data)
with paddle.amp.amp_guard(True, dtype=paddle.float16):
out_fp16_ = conv2d(data)

with paddle.amp.amp_guard(False):
out_fp32 = conv2d(data)

self.assertTrue(data.dtype == paddle.float32)
self.assertTrue(out_fp16.dtype == paddle.float16)
self.assertTrue(out_fp16_.dtype == paddle.float16)
self.assertTrue(out_fp32.dtype == paddle.float32)

def test_amp_guard_white_op(self):
6 changes: 6 additions & 0 deletions test/legacy_test/test_incubate_int_bincount.py
@@ -30,6 +30,12 @@ def test_basic(self):
expected = np.array([2, 2, 2, 0])
np.testing.assert_array_equal(out.numpy(), expected)

def test_basic_2(self):
x = paddle.to_tensor([1, 2, 3, 1, 2, 3], dtype=paddle.int32)
out = int_bincount(x, low=1, high=4, dtype="int32")
expected = np.array([2, 2, 2, 0])
np.testing.assert_array_equal(out.numpy(), expected)

def test_empty_input(self):
x = paddle.to_tensor([], dtype=paddle.int32)
out = int_bincount(x, low=0, high=10, dtype=paddle.int32)
9 changes: 9 additions & 0 deletions test/legacy_test/test_variable.py
@@ -45,6 +45,15 @@ def test_np_dtype_convert(self):
self.assertEqual(paddle.bool, convert("bool"))
self.assertEqual(paddle.int8, convert("int8"))
self.assertEqual(paddle.uint8, convert("uint8"))
self.assertEqual(paddle.float32, convert(paddle.float32))
self.assertEqual(paddle.float16, convert(paddle.float16))
self.assertEqual(paddle.float64, convert(paddle.float64))
self.assertEqual(paddle.int32, convert(paddle.int32))
self.assertEqual(paddle.int16, convert(paddle.int16))
self.assertEqual(paddle.int64, convert(paddle.int64))
self.assertEqual(paddle.bool, convert(paddle.bool))
self.assertEqual(paddle.int8, convert(paddle.int8))
self.assertEqual(paddle.uint8, convert(paddle.uint8))

def test_var(self):
b = default_main_program().current_block()