Skip to content

Commit 711a007

Browse files
committed
[TIR] Add VNNI dot product intrinsic for TIR
1 parent 534205b commit 711a007

File tree

3 files changed

+92
-5
lines changed

3 files changed

+92
-5
lines changed

python/tvm/script/tir/special_stmt.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@
2525

2626
import tvm.tir
2727
from tvm.runtime import Object, String
28-
from tvm import te
2928
from tvm.target import Target
3029
from tvm.ir import Span
31-
from tvm.tir import IntImm, IterVar
30+
from tvm.tir import IntImm, IterVar, Var
3231

3332
from .node import BufferSlice
3433
from .utils import buffer_slice_to_region
@@ -800,7 +799,7 @@ def var(dtype, span):
800799
self.context.report_error(
801800
f"VarDef expected assign to only one var, but got {names}", span
802801
)
803-
v = te.var(names[0], dtype, span=span)
802+
v = Var(names[0], dtype, span=span)
804803
self.context.update_symbol(v.name, v, self.node)
805804

806805
super().__init__(var, def_symbol=True)
@@ -821,7 +820,7 @@ def buffer_var(dtype, storage_scope, span):
821820
f"VarDef expected assign to only one var, but got {names}", span
822821
)
823822
ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
824-
v = te.var(names[0], ptr_type, span=span)
823+
v = Var(names[0], ptr_type, span=span)
825824
self.context.update_symbol(v.name, v, self.node)
826825

827826
super().__init__(buffer_var, def_symbol=True)
@@ -841,7 +840,7 @@ def env_thread(env_name, span):
841840
self.context.report_error(
842841
f"VarDef expected assign to only one var, but got {names}", span
843842
)
844-
v = te.var(names[0], span=span)
843+
v = Var(names[0], span=span)
845844
self.context.func_var_env_dict[v] = env_name
846845
self.context.update_symbol(v.name, v, self.node)
847846

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=unused-import
18+
"""Intrinsics for tensorization."""
19+
from . import vnni
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from .. import TensorIntrin
18+
from tvm.script import tir as T
19+
20+
21+
# Computation description for the 16-lane, 4-element dot product.
#
# Buffers (as declared by the match_buffer calls below):
#   A: 4 uint8 values
#   B: 16 x 4 int8 values
#   C: 16 int32 accumulators
#
# For every i in [0, 16) and k in [0, 4) the body performs
#   C[i] = C[i] + int32(A[k]) * int32(B[i, k])
# with C[i] set to 0 in the T.init() section.
@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        # C appears in both reads and writes because it is accumulated into.
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            with T.init():
                C[i] = T.int32(0)
            for k in T.serial(0, 4):
                with T.block("update"):
                    # "SR": i is bound as a spatial axis, k as a reduction axis.
                    vi, vk = T.axis.remap("SR", [i, k])
                    # Both operands are widened to int32 before multiplying.
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
37+
38+
39+
# Implementation of the 16-lane, 4-element dot product using the LLVM
# intrinsic "llvm.x86.avx512.vpdpbusd.512".
#
# The buffer declarations and the read/write regions are the same as in
# dot_product_desc: A is 4 x uint8, B is 16 x 4 x int8, C is 16 x int32.
@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # Load the 4 uint8 values of A as one vector and view those 32 bits
        # as a single int32 scalar.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # Load all 16 * 4 = 64 int8 values of B as one vector and view them
        # as 16 int32 lanes (4 int8 values per lane).
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        # The ramp index addresses C[0], C[1], ..., C[15] as one 16-lane vector.
        C[
            T.ramp(T.int32(0), 1, 16)
        ] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            # NOTE(review): the meaning of this uint32 argument is not visible
            # here; presumably an argument-count field expected by
            # call_llvm_pure_intrin -- confirm against the TVM API.
            T.uint32(0),
            # An all-zero int32x16 operand is passed to the intrinsic; the
            # existing contents of C are added by the surrounding "+=".
            T.int32x16(0),
            # The packed A value is replicated across all 16 lanes.
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
65+
66+
67+
# Register dot_product_desc and dot_product_intrin together under a single
# intrinsic name. This runs as a side effect of importing the module.
TensorIntrin.register(
    "dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin
)

0 commit comments

Comments
 (0)