Skip to content

Commit 7666cd7

Browse files
committed
add DP4A intrin
1 parent 7086bdb commit 7666cd7

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818
"""Intrinsics for tensorization."""
1919
from .x86 import *
2020
from .arm_cpu import *
21+
from .dot_product_common import *
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name,missing-function-docstring
18+
"""Dot product related intrinsics."""
19+
from tvm.script import tir as T
20+
from .. import TensorIntrin
21+
22+
23+
@T.prim_func
24+
def dp4a_desc(
25+
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
26+
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
27+
C: T.Buffer((), "int32", offset_factor=1, align=4, scope="local"),
28+
) -> None:
29+
with T.block("root"):
30+
T.reads(C[()], A[0:4], B[0:4])
31+
T.writes(C[()])
32+
for i in range(0, 4):
33+
with T.block("update"):
34+
vi = T.axis.remap("R", [i])
35+
C[()] = C[()] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
36+
37+
38+
@T.prim_func
39+
def dp4a_impl(
40+
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
41+
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
42+
C: T.Buffer((), "int32", offset_factor=1, align=4, scope="local"),
43+
) -> None:
44+
with T.block("root"):
45+
T.reads(C[()], A[0:4], B[0:4])
46+
T.writes(C[()])
47+
48+
A_i8x4 = B.vload([0], "int8x4")
49+
B_i8x4 = B.vload([0], "int8x4")
50+
51+
T.evaluate(T.call_pure_extern("__dp4a", A_i8x4, B_i8x4, T.int32(0), dtype="int32"))
52+
53+
54+
DP4A_INTRIN = "dp4a"
55+
56+
TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)

0 commit comments

Comments
 (0)