From 6ad7e11f4b8438375f57f5606065bd5ca84761ce Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Tue, 11 Mar 2025 14:02:02 -0400 Subject: [PATCH 1/2] Update ndarray.py --- python/tvm/runtime/ndarray.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 349b7d2d546f..d001b671fc57 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -233,7 +233,12 @@ def numpy(self): if dtype == "int4": dtype = "int8" if dtype == "bfloat16": - dtype = "uint16" + if ml_dtypes is not None: + dtype = ml_dtypes.bfloat16 + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert bfloat16 array to numpy." + ) if dtype == "float8_e4m3fn": if ml_dtypes is not None: dtype = ml_dtypes.float8_e4m3fn From 1223e29cc0fdfa714eac18eb667d4349eaa7f869 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 11 Mar 2025 20:13:58 -0400 Subject: [PATCH 2/2] Fix llvm codegen bfloat16 test --- .../codegen/test_target_codegen_llvm.py | 43 ++++--------------- 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index e4a28453184d..cea0bc2d9318 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -758,33 +758,6 @@ def check_llvm_ir(): check_llvm_ir() -def np_float2np_bf16(arr): - """Convert a numpy array of float to a numpy array - of bf16 in uint16""" - orig = arr.view("