From 128d2c5d9e4c13d0c4b656012434dff82b0270fd Mon Sep 17 00:00:00 2001 From: huangjiyi <947613776@qq.com> Date: Fri, 31 Oct 2025 17:15:24 +0800 Subject: [PATCH 1/2] [API Compatibility] Support tensor dtype compare using `is` --- paddle/fluid/pybind/eager_utils.cc | 11 ++++++++--- test/legacy_test/test_eager_tensor.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index e679052bab5415..60113686f24403 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1087,9 +1087,14 @@ PyObject* ToPyObject(const phi::DenseTensor* value) { } PyObject* ToPyObject(const phi::DataType& dtype) { - auto obj = ::pybind11::cast(dtype); - obj.inc_ref(); - return obj.ptr(); + static const std::vector dtype_names = { + "UNDEFINED", "BOOL", "UINT8", "INT8", "UINT16", + "INT16", "UINT32", "INT32", "UINT64", "INT64", + "FLOAT32", "FLOAT64", "COMPLEX64", "COMPLEX128", "FLOAT16", + "BFLOAT16", "FLOAT8_E4M3FN", "FLOAT8_E5M2", "PSTRING", + }; + return PyObject_GetAttrString(reinterpret_cast(g_data_type_pytype), + dtype_names[static_cast(dtype)].c_str()); } PyObject* ToPyObject(const std::vector& dtypes) { diff --git a/test/legacy_test/test_eager_tensor.py b/test/legacy_test/test_eager_tensor.py index 8768de64169d98..4869466608d977 100644 --- a/test/legacy_test/test_eager_tensor.py +++ b/test/legacy_test/test_eager_tensor.py @@ -1283,6 +1283,20 @@ def test_print_tensor_dtype(self): self.assertEqual(a_str, expected) + def test_tensor_dtype_compare(self): + a = paddle.randn([2], dtype="float32") + b = paddle.randn([2], dtype="float32") + c = paddle.randn([2], dtype="float64") + + self.assertTrue(a.dtype == paddle.float32) + self.assertTrue(a.dtype == b.dtype) + self.assertTrue(a.dtype != paddle.float64) + self.assertTrue(a.dtype != c.dtype) + self.assertTrue(a.dtype is paddle.float32) + self.assertTrue(a.dtype is b.dtype) + self.assertTrue(a.dtype is not paddle.float64) + self.assertTrue(a.dtype is not c.dtype) + def test___cuda_array_interface__(self): """test Tensor.__cuda_array_interface__""" with dygraph_guard(): From 9af150a54ec663b30f722151d5054698f0ee6b72 Mon Sep 17 00:00:00 2001 From: huangjiyi <947613776@qq.com> Date: Mon, 3 Nov 2025 21:17:06 +0800 Subject: [PATCH 2/2] fix --- paddle/fluid/pybind/eager_utils.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 60113686f24403..5911940e1d0169 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1088,10 +1088,10 @@ PyObject* ToPyObject(const phi::DenseTensor* value) { PyObject* ToPyObject(const phi::DataType& dtype) { static const std::vector dtype_names = { - "UNDEFINED", "BOOL", "UINT8", "INT8", "UINT16", - "INT16", "UINT32", "INT32", "UINT64", "INT64", - "FLOAT32", "FLOAT64", "COMPLEX64", "COMPLEX128", "FLOAT16", - "BFLOAT16", "FLOAT8_E4M3FN", "FLOAT8_E5M2", "PSTRING", + "UNDEFINED", "BOOL", "UINT8", "INT8", "UINT16", + "INT16", "UINT32", "INT32", "UINT64", "INT64", + "FLOAT32", "FLOAT64", "COMPLEX64", "COMPLEX128", "PSTRING", + "FLOAT16", "BFLOAT16", "FLOAT8_E4M3FN", "FLOAT8_E5M2", }; return PyObject_GetAttrString(reinterpret_cast(g_data_type_pytype), dtype_names[static_cast(dtype)].c_str());