2323
2424
2525@T .prim_func
26- def dot_product_16x4_desc (a : T .handle , b : T .handle , c : T .handle ) -> None :
27- A = T .match_buffer (a , (4 ,), "uint8" , offset_factor = 1 )
28- B = T .match_buffer (b , (16 , 4 ), "int8" , offset_factor = 1 )
29- C = T .match_buffer (c , (16 ,), "int32" , offset_factor = 1 )
30-
26+ def dot_product_16x4_u8i8i32_desc (
27+ A : T .Buffer [(4 ,), "uint8" ], B : T .Buffer [(16 , 4 ), "int8" ], C : T .Buffer [(16 ,), "int32" ]
28+ ) -> None :
3129 with T .block ("root" ):
3230 T .reads (C [0 :16 ], A [0 :4 ], B [0 :16 , 0 :4 ])
3331 T .writes (C [0 :16 ])
@@ -41,7 +39,9 @@ def dot_product_16x4_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
4139
4240
4341@T .prim_func
44- def dot_product_16x4_vnni_impl (a : T .handle , b : T .handle , c : T .handle ) -> None :
42+ def dot_product_16x4_u8i8i32_vnni_impl (
43+ A : T .Buffer [(4 ,), "uint8" ], B : T .Buffer [(16 , 4 ), "int8" ], C : T .Buffer [(16 ,), "int32" ]
44+ ) -> None :
4545 A = T .match_buffer (a , (4 ,), "uint8" , offset_factor = 1 )
4646 B = T .match_buffer (b , (16 , 4 ), "int8" , offset_factor = 1 )
4747 C = T .match_buffer (c , (16 ,), "int32" , offset_factor = 1 )
@@ -66,6 +66,8 @@ def dot_product_16x4_vnni_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
6666 )
6767
6868
69- VNNI_INTRIN = "dot_16x4_vnni"
69+ VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"
7070
71- TensorIntrin .register (VNNI_INTRIN , dot_product_16x4_desc , dot_product_16x4_vnni_impl )
71+ TensorIntrin .register (
72+ VNNI_DOT_16x4_INTRIN , dot_product_16x4_u8i8i32_desc , dot_product_16x4_u8i8i32_vnni_impl
73+ )
0 commit comments