Skip to content

Commit 7c4aa0d

Browse files
committed
Fix llvm codegen bfloat16 test
1 parent 6ad7e11 commit 7c4aa0d

File tree

1 file changed

+6
-35
lines changed

1 file changed

+6
-35
lines changed

tests/python/codegen/test_target_codegen_llvm.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -758,33 +758,6 @@ def check_llvm_ir():
758758
check_llvm_ir()
759759

760760

761-
def np_float2np_bf16(arr):
762-
"""Convert a numpy array of float to a numpy array
763-
of bf16 in uint16"""
764-
orig = arr.view("<u4")
765-
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
766-
return np.right_shift(orig + bias, 16).astype("uint16")
767-
768-
769-
def np_float2tvm_bf16(arr):
770-
"""Convert a numpy array of float to a TVM array
771-
of bf16"""
772-
nparr = np_float2np_bf16(arr)
773-
return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr)
774-
775-
776-
def np_bf162np_float(arr):
777-
"""Convert a numpy array of bf16 (uint16) to a numpy array
778-
of float"""
779-
u32 = np.left_shift(arr.astype("uint32"), 16)
780-
return u32.view("<f4")
781-
782-
783-
def np_bf16_cast_and_cast_back(arr):
784-
"""Convert a numpy array of float to bf16 and cast back"""
785-
return np_bf162np_float(np_float2np_bf16(arr))
786-
787-
788761
@tvm.testing.requires_llvm
789762
def test_llvm_bf16():
790763
def dotest(do_vectorize):
@@ -806,16 +779,14 @@ def dotest(do_vectorize):
806779
sch.vectorize(loop)
807780

808781
module = tvm.compile(sch.mod, target="llvm")
809-
npa = np.random.rand(32).astype("float32")
810-
npb = np.random.rand(32).astype("float32")
811-
va = np_bf16_cast_and_cast_back(npa)
812-
vb = np_bf16_cast_and_cast_back(npb)
813-
res = np_bf16_cast_and_cast_back(va + vb)
814-
a_ = np_float2tvm_bf16(npa)
815-
b_ = np_float2tvm_bf16(npb)
782+
npa = np.random.rand(32).astype("bfloat16")
783+
npb = np.random.rand(32).astype("bfloat16")
784+
res = npa + npb
785+
a_ = tvm.nd.array(npa)
786+
b_ = tvm.nd.array(npb)
816787
c_ = tvm.nd.empty((32,), "bfloat16")
817788
module(a_, b_, c_)
818-
tvm.testing.assert_allclose(np_bf162np_float(c_.numpy()), res)
789+
tvm.testing.assert_allclose(c_.numpy(), res)
819790

820791
dotest(True)
821792
dotest(False)

0 commit comments

Comments
 (0)