Skip to content

Commit 4c77f0f

Browse files
vincentcccLufang CHEN 陈橹方
andauthored
[TIR] Extend DP4A tensor intrin (#16293)
* update dp4a tensor intrin * update dp4a tensor intrin * lint --------- Co-authored-by: Lufang CHEN 陈橹方 <[email protected]>
1 parent 8e54a9e commit 4c77f0f

File tree

6 files changed

+154
-56
lines changed

6 files changed

+154
-56
lines changed

python/tvm/tir/tensor_intrin/arm_cpu.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
# pylint: disable=invalid-name,missing-function-docstring
17+
# pylint: disable=invalid-name,missing-function-docstring,unused-import
1818
"""Intrinsics for ARM tensorization."""
1919
from tvm.script import tir as T
2020
from .. import TensorIntrin
21-
from .dot_product_common import DP4A_INTRIN # pylint: disable=unused-import
21+
from .dot_product_common import (
22+
DP4A_S8S8S32_INTRIN,
23+
DP4A_S8U8S32_INTRIN,
24+
DP4A_U8S8S32_INTRIN,
25+
DP4A_U8U8U32_INTRIN,
26+
)
2227

2328

2429
# TODO(masahi): Parametrize the TVMScript description of dot product by

python/tvm/tir/tensor_intrin/dot_product_common.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,52 @@
2020
from .. import TensorIntrin
2121

2222

23-
@T.prim_func
24-
def dp4a_desc(
25-
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
26-
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
27-
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
28-
) -> None:
29-
with T.block("root"):
30-
T.reads(C[0], A[0:4], B[0:4])
31-
T.writes(C[0])
32-
for i in range(0, 4):
33-
with T.block("update"):
34-
vi = T.axis.remap("R", [i])
35-
C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
36-
37-
38-
@T.prim_func
39-
def dp4a_impl(
40-
A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
41-
B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
42-
C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
43-
) -> None:
44-
with T.block("root"):
45-
T.reads(C[0], A[0:4], B[0:4])
46-
T.writes(C[0])
47-
48-
C[0] += T.call_pure_extern(
49-
"__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
50-
)
51-
52-
53-
DP4A_INTRIN = "dp4a"
54-
55-
TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
23+
def get_dp4a_intrin(dtype_a, dtype_b, dtype_c):
24+
if dtype_c == "uint32":
25+
assert dtype_a == dtype_b == "uint8"
26+
vec_type_a = "int8x4" if dtype_a == "int8" else "uint8x4"
27+
vec_type_b = "int8x4" if dtype_b == "int8" else "uint8x4"
28+
29+
@T.prim_func
30+
def dp4a_desc(
31+
A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
32+
B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
33+
C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
34+
) -> None:
35+
with T.block("root"):
36+
T.reads(C[0], A[0:4], B[0:4])
37+
T.writes(C[0])
38+
for i in range(0, 4):
39+
with T.block("update"):
40+
vi = T.axis.remap("R", [i])
41+
C[0] = C[0] + T.cast(A[vi], dtype_c) * T.cast(B[vi], dtype_c)
42+
43+
@T.prim_func
44+
def dp4a_impl(
45+
A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
46+
B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
47+
C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
48+
) -> None:
49+
with T.block("root"):
50+
T.reads(C[0], A[0:4], B[0:4])
51+
T.writes(C[0])
52+
53+
C[0] += T.call_pure_extern(
54+
"__dp4a",
55+
A.vload([0], vec_type_a),
56+
B.vload([0], vec_type_b),
57+
T.uint32(0) if dtype_c == "uint32" else T.int32(0),
58+
dtype=dtype_c,
59+
)
60+
61+
return dp4a_desc, dp4a_impl
62+
63+
64+
DP4A_S8S8S32_INTRIN = "dp4a_s8s8s32"
65+
TensorIntrin.register(DP4A_S8S8S32_INTRIN, *get_dp4a_intrin("int8", "int8", "int32"))
66+
DP4A_U8S8S32_INTRIN = "dp4a_u8s8s32"
67+
TensorIntrin.register(DP4A_U8S8S32_INTRIN, *get_dp4a_intrin("uint8", "int8", "int32"))
68+
DP4A_S8U8S32_INTRIN = "dp4a_s8u8s32"
69+
TensorIntrin.register(DP4A_S8U8S32_INTRIN, *get_dp4a_intrin("int8", "uint8", "int32"))
70+
DP4A_U8U8U32_INTRIN = "dp4a_u8u8u32"
71+
TensorIntrin.register(DP4A_U8U8U32_INTRIN, *get_dp4a_intrin("uint8", "uint8", "uint32"))

python/tvm/tir/tensor_intrin/rocm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from tvm.runtime import convert
2222
from tvm.tir.expr import Cast, IntImm
23-
from .dot_product_common import dp4a_desc
23+
from .dot_product_common import get_dp4a_intrin
2424
from .. import TensorIntrin
2525

2626

@@ -50,6 +50,7 @@ def sdot4(
5050

5151
AMDGPU_SDOT4_INTRIN = "sdot4"
5252

53+
dp4a_desc, _ = get_dp4a_intrin("int8", "int8", "int32")
5354
TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)
5455

5556
WARP_SIZE = 64

src/target/source/codegen_cuda.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
#include "../../tir/transforms/ir_utils.h"
3737
#include "literal/cuda_half_t.h"
38+
#include "literal/cuda_int8_t.h"
3839
#include "ptx.h"
3940

4041
namespace tvm {
@@ -130,6 +131,7 @@ std::string CodeGenCUDA::Finish() {
130131
if (enable_int8_) {
131132
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
132133
decl_stream << "#include <sm_61_intrinsics.h>\n";
134+
decl_stream << _cuda_int8_t_def;
133135
decl_stream << "#endif\n";
134136
}
135137

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file cuda_int8_t.h
22+
* \brief Extra int8 intrisic for cuda codegen.
23+
*/
24+
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
25+
#define TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
26+
27+
static constexpr const char* _cuda_int8_t_def = R"(
28+
29+
#if defined(__CUDACC_RTC__)
30+
#define __SM_61_INTRINSICS_DECL__ __device__
31+
#else /* !__CUDACC_RTC__ */
32+
#define __SM_61_INTRINSICS_DECL__ static __device__ __inline__
33+
#endif /* __CUDACC_RTC__ */
34+
35+
#ifndef __CUDA_ARCH__
36+
#define __DEF_IF_HOST { }
37+
#else /* !__CUDA_ARCH__ */
38+
#define __DEF_IF_HOST ;
39+
#endif /* __CUDA_ARCH__ */
40+
41+
__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) __DEF_IF_HOST
42+
__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) __DEF_IF_HOST
43+
44+
#undef __DEF_IF_HOST
45+
46+
#if !defined(__CUDACC_RTC__) && defined(__CUDA_ARCH__)
47+
__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) {
48+
int ret;
49+
asm volatile ("dp4a.u32.s32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
50+
return ret;
51+
}
52+
53+
__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) {
54+
int ret;
55+
asm volatile ("dp4a.s32.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
56+
return ret;
57+
}
58+
#endif /* !__CUDACC_RTC__ && defined(__CUDA_ARCH__) */
59+
60+
#undef __SM_61_INTRINSICS_DECL__
61+
62+
)";
63+
64+
#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_

tests/python/tir-schedule/test_tir_schedule_tensorize.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
verify_trace_roundtrip,
2727
)
2828
from tvm.tir.tensor_intrin.arm_cpu import (
29-
DP4A_INTRIN,
29+
DP4A_S8S8S32_INTRIN,
30+
DP4A_U8U8U32_INTRIN,
31+
DP4A_U8S8S32_INTRIN,
32+
DP4A_S8U8S32_INTRIN,
3033
ARM_DOT_4x4_i8_NEON_INTRIN,
3134
ARM_DOT_4x4_i8_SDOT_INTRIN,
3235
)
@@ -687,26 +690,25 @@ def test_tensorize_vdmpy():
687690
verify_trace_roundtrip(sch=sch, mod=func)
688691

689692

690-
def test_tensorize_dpa4():
691-
m, n, k = 128, 128, 128
692-
693-
X = te.placeholder((m, k), name="X", dtype="int8")
694-
W = te.placeholder((n, k), name="W", dtype="int8")
695-
ak = te.reduce_axis((0, k), name="k")
696-
697-
matmul = te.compute(
698-
(m, n),
699-
lambda i, j: te.sum(
700-
X[i, ak].astype("int32")
701-
* W[j, ak].astype("int32"),
702-
axis=ak,
703-
),
704-
name="compute",
705-
)
693+
def test_tensorize_dp4a():
694+
# pylint: disable=too-many-locals
695+
def _test_intrin(dtype_a, dtype_b, dtype_c, intrin):
696+
m, n, k = 128, 128, 128
697+
X = te.placeholder((m, k), name="X", dtype=dtype_a)
698+
W = te.placeholder((n, k), name="W", dtype=dtype_b)
699+
ak = te.reduce_axis((0, k), name="k")
700+
701+
matmul = te.compute(
702+
(m, n),
703+
lambda i, j: te.sum(
704+
X[i, ak].astype(dtype_c) * W[j, ak].astype(dtype_c),
705+
axis=ak,
706+
),
707+
name="compute",
708+
)
706709

707-
func = te.create_prim_func([X, W, matmul])
710+
func = te.create_prim_func([X, W, matmul])
708711

709-
for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]:
710712
sch = tir.Schedule(func, debug_mask="all")
711713
block = sch.get_block("compute")
712714
i, j, k = sch.get_loops(block)
@@ -717,7 +719,6 @@ def test_tensorize_dpa4():
717719
ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))
718720

719721
sch.reorder(by, bx, ty, tx, yi, xi)
720-
721722
CC = sch.cache_write(block, 0, "local")
722723
sch.reverse_compute_at(CC, tx)
723724

@@ -734,6 +735,15 @@ def fetch_to_shared(block, idx):
734735

735736
verify_trace_roundtrip(sch=sch, mod=func)
736737

738+
for args in [
739+
("int8", "int8", "int32", AMDGPU_SDOT4_INTRIN),
740+
("int8", "int8", "int32", DP4A_S8S8S32_INTRIN),
741+
("int8", "uint8", "int32", DP4A_S8U8S32_INTRIN),
742+
("uint8", "int8", "int32", DP4A_U8S8S32_INTRIN),
743+
("uint8", "uint8", "uint32", DP4A_U8U8U32_INTRIN),
744+
]:
745+
_test_intrin(*args)
746+
737747

738748
def test_tensor_intrin_look_up():
739749
intrin_name = 'non_existent_intrin'

0 commit comments

Comments
 (0)