Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,12 @@ def numpy(self):
if dtype == "int4":
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
if ml_dtypes is not None:
dtype = ml_dtypes.bfloat16
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert bfloat16 array to numpy."
)
if dtype == "float8_e4m3fn":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
Expand Down
Loading