Skip to content

Commit a47a00d

Browse files
spectrometerHBHMasterJH5574
authored andcommitted
[BF16] Support ndarray.asnumpy() to bfloat16 tensor natively using ml_dtypes (apache#17739)
* Update ndarray.py * Fix llvm codegen bfloat16 test --------- Co-authored-by: Ruihang Lai <[email protected]>
1 parent 323c39f commit a47a00d

File tree

2 files changed

+14
-36
lines changed

2 files changed

+14
-36
lines changed

python/tvm/runtime/ndarray.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,12 @@ def numpy(self):
233233
if dtype == "int4":
234234
dtype = "int8"
235235
if dtype == "bfloat16":
236-
dtype = "uint16"
236+
if ml_dtypes is not None:
237+
dtype = ml_dtypes.bfloat16
238+
else:
239+
raise RuntimeError(
240+
"ml_dtypes is not installed, cannot convert bfloat16 array to numpy."
241+
)
237242
if dtype == "float8_e4m3fn":
238243
if ml_dtypes is not None:
239244
dtype = ml_dtypes.float8_e4m3fn

tests/python/codegen/test_target_codegen_llvm.py

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -758,33 +758,6 @@ def check_llvm_ir():
758758
check_llvm_ir()
759759

760760

761-
def np_float2np_bf16(arr):
762-
"""Convert a numpy array of float to a numpy array
763-
of bf16 in uint16"""
764-
orig = arr.view("<u4")
765-
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
766-
return np.right_shift(orig + bias, 16).astype("uint16")
767-
768-
769-
def np_float2tvm_bf16(arr):
770-
"""Convert a numpy array of float to a TVM array
771-
of bf16"""
772-
nparr = np_float2np_bf16(arr)
773-
return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr)
774-
775-
776-
def np_bf162np_float(arr):
777-
"""Convert a numpy array of bf16 (uint16) to a numpy array
778-
of float"""
779-
u32 = np.left_shift(arr.astype("uint32"), 16)
780-
return u32.view("<f4")
781-
782-
783-
def np_bf16_cast_and_cast_back(arr):
784-
"""Convert a numpy array of float to bf16 and cast back"""
785-
return np_bf162np_float(np_float2np_bf16(arr))
786-
787-
788761
@tvm.testing.requires_llvm
789762
def test_llvm_bf16():
790763
def dotest(do_vectorize):
@@ -806,16 +779,16 @@ def dotest(do_vectorize):
806779
sch.vectorize(loop)
807780

808781
module = tvm.compile(sch.mod, target="llvm")
809-
npa = np.random.rand(32).astype("float32")
810-
npb = np.random.rand(32).astype("float32")
811-
va = np_bf16_cast_and_cast_back(npa)
812-
vb = np_bf16_cast_and_cast_back(npb)
813-
res = np_bf16_cast_and_cast_back(va + vb)
814-
a_ = np_float2tvm_bf16(npa)
815-
b_ = np_float2tvm_bf16(npb)
782+
npa = np.random.rand(32).astype("bfloat16")
783+
npb = np.random.rand(32).astype("bfloat16")
784+
res = npa + npb
785+
a_ = tvm.nd.array(npa)
786+
b_ = tvm.nd.array(npb)
816787
c_ = tvm.nd.empty((32,), "bfloat16")
817788
module(a_, b_, c_)
818-
tvm.testing.assert_allclose(np_bf162np_float(c_.numpy()), res)
789+
# Note: directly compare without casting to float32 should work with the
790+
# latest numpy version.
791+
tvm.testing.assert_allclose(c_.numpy().astype("float32"), res.astype("float32"))
819792

820793
dotest(True)
821794
dotest(False)

0 commit comments

Comments
 (0)