Skip to content

Commit 6ad7e11

Browse files
Update ndarray.py
1 parent 148737b commit 6ad7e11

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

python/tvm/runtime/ndarray.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,12 @@ def numpy(self):
233233
if dtype == "int4":
234234
dtype = "int8"
235235
if dtype == "bfloat16":
236-
dtype = "uint16"
236+
if ml_dtypes is not None:
237+
dtype = ml_dtypes.bfloat16
238+
else:
239+
raise RuntimeError(
240+
"ml_dtypes is not installed, cannot convert bfloat16 array to numpy."
241+
)
237242
if dtype == "float8_e4m3fn":
238243
if ml_dtypes is not None:
239244
dtype = ml_dtypes.float8_e4m3fn

0 commit comments

Comments
 (0)