@@ -758,33 +758,6 @@ def check_llvm_ir():
758758 check_llvm_ir ()
759759
760760
761- def np_float2np_bf16 (arr ):
762- """Convert a numpy array of float to a numpy array
763- of bf16 in uint16"""
764- orig = arr .view ("<u4" )
765- bias = np .bitwise_and (np .right_shift (orig , 16 ), 1 ) + 0x7FFF
766- return np .right_shift (orig + bias , 16 ).astype ("uint16" )
767-
768-
769- def np_float2tvm_bf16 (arr ):
770- """Convert a numpy array of float to a TVM array
771- of bf16"""
772- nparr = np_float2np_bf16 (arr )
773- return tvm .nd .empty (nparr .shape , "bfloat16" ).copyfrom (nparr )
774-
775-
776- def np_bf162np_float (arr ):
777- """Convert a numpy array of bf16 (uint16) to a numpy array
778- of float"""
779- u32 = np .left_shift (arr .astype ("uint32" ), 16 )
780- return u32 .view ("<f4" )
781-
782-
783- def np_bf16_cast_and_cast_back (arr ):
784- """Convert a numpy array of float to bf16 and cast back"""
785- return np_bf162np_float (np_float2np_bf16 (arr ))
786-
787-
788761@tvm .testing .requires_llvm
789762def test_llvm_bf16 ():
790763 def dotest (do_vectorize ):
@@ -806,16 +779,14 @@ def dotest(do_vectorize):
806779 sch .vectorize (loop )
807780
808781 module = tvm .compile (sch .mod , target = "llvm" )
809- npa = np .random .rand (32 ).astype ("float32" )
810- npb = np .random .rand (32 ).astype ("float32" )
811- va = np_bf16_cast_and_cast_back (npa )
812- vb = np_bf16_cast_and_cast_back (npb )
813- res = np_bf16_cast_and_cast_back (va + vb )
814- a_ = np_float2tvm_bf16 (npa )
815- b_ = np_float2tvm_bf16 (npb )
782+ npa = np .random .rand (32 ).astype ("bfloat16" )
783+ npb = np .random .rand (32 ).astype ("bfloat16" )
784+ res = npa + npb
785+ a_ = tvm .nd .array (npa )
786+ b_ = tvm .nd .array (npb )
816787 c_ = tvm .nd .empty ((32 ,), "bfloat16" )
817788 module (a_ , b_ , c_ )
818- tvm .testing .assert_allclose (np_bf162np_float ( c_ .numpy () ), res )
789+ tvm .testing .assert_allclose (c_ .numpy (), res )
819790
820791 dotest (True )
821792 dotest (False )
0 commit comments