
Commit 86bbd49

Add ARM intrin
1 parent 120fd96 commit 86bbd49

3 files changed: +161 -12 lines
python/tvm/tir/tensor_intrin/arm_cpu.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from .. import TensorIntrin
+from tvm.script import tir as T
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_desc(
+    A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+        for i in T.serial(0, 4):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_neon(
+    A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_int8 = A.vload([0], "int8x4")
+        re_int32 = T.reinterpret(A_int8, dtype="int32")
+        vec_ai32 = T.broadcast(re_int32, 2)
+        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")
+
+        vec_b = B.vload([0, 0], dtype="int8x8")
+
+        multiply = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
+            T.uint32(2),
+            vec_a,
+            vec_b,
+            dtype="int16x8",
+        )
+
+        pair1 = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
+            T.uint32(1),
+            multiply,
+            dtype="int32x4",
+        )
+
+        vec_b_2 = B.vload([2, 0], dtype="int8x8")
+
+        multiply_2 = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
+            T.uint32(2),
+            vec_a,
+            vec_b_2,
+            dtype="int16x8",
+        )
+
+        pair2 = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
+            T.uint32(1),
+            multiply_2,
+            dtype="int32x4",
+        )
+
+        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
+            T.uint32(2),
+            pair1,
+            pair2,
+            dtype="int32x4",
+        )
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_sdot(
+    A: T.Buffer[(4,), "int8"], B: T.Buffer[(4, 4), "int8"], C: T.Buffer[(4,), "int32"]
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_i8x4 = A.vload([0], "int8x4")
+        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
+        vec_ai32 = T.broadcast(A_i32, 4)
+        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
+
+        vec_b = B.vload([0, 0], dtype="int8x16")
+
+        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
+            T.uint32(3),
+            T.int32x4(0),
+            vec_a,
+            vec_b,
+            dtype="int32x4",
+        )
+
+
+ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
+ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
+
+TensorIntrin.register(
+    ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
+)
+
+TensorIntrin.register(
+    ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
+)
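For reference, the NEON implementation above builds the 4x4 dot product out of three LLVM intrinsic calls: smull performs a widening int8x8 multiply of the broadcast A vector against two rows of B at a time, saddlp pairwise-adds the int16 products into int32 partial sums, and addp combines the two partial-sum vectors into the four accumulators added to C[0:4]. A rough NumPy model of that dataflow, shown only as an illustration (none of the code below is part of the committed file):

import numpy as np

# Model of the dataflow in dot_product_4x4_i8i8i32_neon (illustrative only).
A = np.random.randint(-128, 128, (4,), dtype=np.int8)
B = np.random.randint(-128, 128, (4, 4), dtype=np.int8)

vec_a = np.tile(A, 2).astype(np.int16)                    # A broadcast to 8 lanes (the int8x8 reinterpret)
mul01 = vec_a * B[0:2].reshape(8).astype(np.int16)        # smull: widening multiply against rows 0 and 1
mul23 = vec_a * B[2:4].reshape(8).astype(np.int16)        # smull: rows 2 and 3
pair1 = mul01.reshape(4, 2).sum(axis=1, dtype=np.int32)   # saddlp: pairwise add to int32 partial sums
pair2 = mul23.reshape(4, 2).sum(axis=1, dtype=np.int32)
result = np.concatenate([pair1, pair2]).reshape(4, 2).sum(axis=1)  # addp: combine partial sums

assert np.array_equal(result, B.astype(np.int32) @ A.astype(np.int32))

The sdot variant collapses the same computation into a single llvm.aarch64.neon.sdot call, which computes four 4-way int8 dot products in one instruction.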

python/tvm/tir/tensor_intrin/x86.py

Lines changed: 2 additions & 6 deletions
@@ -39,13 +39,9 @@ def dot_product_16x4_u8i8i32_desc(


 @T.prim_func
-def dot_product_16x4_u8i8i32_vnni_impl(
+def dot_product_16x4_u8i8i32_vnni(
     A: T.Buffer[(4,), "uint8"], B: T.Buffer[(16, 4), "int8"], C: T.Buffer[(16,), "int32"]
 ) -> None:
-    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
-    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
-    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
-
     with T.block("root"):
         T.reads(C[0:16], A[0:4], B[0:16, 0:4])
         T.writes(C[0:16])
@@ -69,5 +65,5 @@ def dot_product_16x4_u8i8i32_vnni_impl(
 VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"

 TensorIntrin.register(
-    VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni_impl
+    VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
 )

tests/python/unittest/test_tir_schedule_tensorize.py

Lines changed: 32 additions & 6 deletions
@@ -23,6 +23,7 @@
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+from tvm.tir.tensor_intrin.arm_cpu import ARM_DOT_4x4_i8_NEON_INTRIN, ARM_DOT_4x4_i8_SDOT_INTRIN

 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -532,10 +533,9 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)


-def test_tensorize_vnni():
-    n, m, k = 128, 128, 128
-    X = te.placeholder((m, k), name="X", dtype="uint8")
-    packed_W = te.placeholder((n // 16, k // 4, 16, 4), name="packedW", dtype="int8")
+def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
+    X = te.placeholder((m, k), name="X", dtype=lhs_type)
+    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")

     ak = te.reduce_axis((0, k), name="k")
     matmul = te.compute(
@@ -550,7 +550,13 @@ def test_tensorize_vnni():
         name="compute",
     )

-    func = te.create_prim_func([X, packed_W, matmul])
+    return te.create_prim_func([X, packed_W, matmul])
+
+
+def test_tensorize_vnni():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "uint8", 16)

     sch = tir.Schedule(func, debug_mask="all")
     block = sch.get_block("compute")
@@ -566,6 +572,26 @@ def test_tensorize_vnni():
     verify_trace_roundtrip(sch=sch, mod=func)


+def test_tensorize_arm_dot():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "int8", 4)
+
+    for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]:
+        sch = tir.Schedule(func, debug_mask="all")
+        block = sch.get_block("compute")
+        _, j, k = sch.get_loops(block)
+
+        _, ji = sch.split(j, factors=[None, 4])
+        ko, ki = sch.split(k, factors=[None, 4])
+        sch.reorder(ko, ji, ki)
+
+        sch.decompose_reduction(block, ko)
+        sch.tensorize(ji, intrin)
+
+        verify_trace_roundtrip(sch=sch, mod=func)
+
+
 if __name__ == "__main__":
     # sys.exit(pytest.main([__file__] + sys.argv[1:]))
-    test_tensorize_vnni()
+    test_tensorize_arm_dot()
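As a usage note, a schedule tensorized with one of these intrinsics can be lowered and compiled like any other TIR module. The sketch below mirrors the test above; the compute lambda is an assumption standing in for the packed matmul that get_matmul_packed builds (its full body is not shown in this diff), and the target string is likewise an assumption: the sdot path needs the Armv8.2-A dot-product extension, hence +dotprod.

import tvm
from tvm import te, tir
from tvm.tir.tensor_intrin.arm_cpu import ARM_DOT_4x4_i8_SDOT_INTRIN

# Packed int8 matmul with an (n // 4, k // 4, 4, 4) weight layout, analogous to
# get_matmul_packed(m, n, k, "int8", 4) in the test (stand-in, not the test's exact code).
m, n, k = 128, 128, 128
X = te.placeholder((m, k), name="X", dtype="int8")
packed_W = te.placeholder((n // 4, k // 4, 4, 4), name="packedW", dtype="int8")
ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
    (m, n),
    lambda i, j: te.sum(
        X[i, ak].astype("int32") * packed_W[j // 4, ak // 4, j % 4, ak % 4].astype("int32"),
        axis=ak,
    ),
    name="compute",
)
func = te.create_prim_func([X, packed_W, matmul])

# Tile to the 4x4 shape of the descriptor, then tensorize, as in test_tensorize_arm_dot.
sch = tir.Schedule(func)
block = sch.get_block("compute")
_, j, kk = sch.get_loops(block)
_, ji = sch.split(j, factors=[None, 4])
ko, ki = sch.split(kk, factors=[None, 4])
sch.reorder(ko, ji, ki)
sch.decompose_reduction(block, ko)
sch.tensorize(ji, ARM_DOT_4x4_i8_SDOT_INTRIN)

# Cross-compile for AArch64 with the dot-product extension enabled (assumed target string).
lib = tvm.build(sch.mod, target="llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod")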
