Skip to content

Commit 7320f14

Browse files
fix to_tensor bug (#76000) (#76067)
Co-authored-by: wanghuancoder <[email protected]>
1 parent e3bde83 commit 7320f14

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

python/paddle/tensor/creation.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,24 @@ def _handle_tensor_dtype(
713713
if np.isscalar(data) and not isinstance(data, str):
714714
data = np.array(data)
715715
elif isinstance(data, (list, tuple)):
716-
data = np.array(data)
716+
has_tensor = False
717+
for d in data:
718+
if isinstance(d, paddle.Tensor):
719+
has_tensor = True
720+
break
721+
if has_tensor:
722+
if (
723+
len(data) == 1
724+
and isinstance(data[0], paddle.Tensor)
725+
and data[0].dtype == paddle.bfloat16
726+
):
727+
data = np.array([data[0].numpy()])
728+
else:
729+
data = np.array(data)
730+
if not dtype:
731+
dtype = data.dtype
732+
else:
733+
data = np.array(data)
717734
if data.dtype == np.object_:
718735
raise ValueError(
719736
"\n\tFailed to convert input data to a regular ndarray :\n\t - Usually "

test/legacy_test/test_eager_tensor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2129,5 +2129,39 @@ def test_set_dynamic_attribute_to_eager_tensor_instance_create_via_to_pyobject(
21292129
self.assertEqual(tensor_instance.__dict__["_custom_flag"], True)
21302130

21312131

2132+
class TestListToTensor(unittest.TestCase):
2133+
def test_list_to_tensor_bfloat16(self):
2134+
a = [paddle.to_tensor(2, dtype=paddle.bfloat16)]
2135+
b = paddle.to_tensor(a)
2136+
self.assertEqual(b.dtype, paddle.bfloat16)
2137+
self.assertEqual(b[0], 2.0)
2138+
2139+
def test_list_to_tensor_float16(self):
2140+
a = [paddle.to_tensor(2, dtype=paddle.float16)]
2141+
b = paddle.to_tensor(a)
2142+
self.assertEqual(b.dtype, paddle.float16)
2143+
self.assertEqual(b[0], 2.0)
2144+
2145+
def test_list_to_tensor_bfloat16_float32(self):
2146+
a = [
2147+
paddle.to_tensor(2, dtype=paddle.bfloat16),
2148+
paddle.to_tensor(2, dtype=paddle.float32),
2149+
]
2150+
b = paddle.to_tensor(a)
2151+
self.assertEqual(b.dtype, paddle.float32)
2152+
self.assertEqual(b[0], 2.0)
2153+
self.assertEqual(b[1], 2.0)
2154+
2155+
def test_list_to_tensor_float16_float32(self):
2156+
a = [
2157+
paddle.to_tensor(2, dtype=paddle.float16),
2158+
paddle.to_tensor(2, dtype=paddle.float32),
2159+
]
2160+
b = paddle.to_tensor(a)
2161+
self.assertEqual(b.dtype, paddle.float32)
2162+
self.assertEqual(b[0], 2.0)
2163+
self.assertEqual(b[1], 2.0)
2164+
2165+
21322166
if __name__ == "__main__":
21332167
unittest.main()

0 commit comments

Comments
 (0)