Skip to content

Commit 92d6f9b

Browse files
committed
[LLVM][METASCHEDULE] Add RISCV V-extension v1.0 kernels to metaschedule
1 parent 585d6d2 commit 92d6f9b

File tree

8 files changed

+334
-1
lines changed

8 files changed

+334
-1
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class Postproc : public runtime::ObjectRef {
166166
TVM_DLL static Array<Postproc, void> DefaultLLVM();
167167
/*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
168168
TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
169+
/*! \brief Create default postprocessors for RISCV */
170+
TVM_DLL static Array<Postproc, void> DefaultRISCV();
169171
/*! \brief Create default postprocessors for CUDA */
170172
TVM_DLL static Array<Postproc, void> DefaultCUDA();
171173
/*! \brief Create default postprocessors for CUDA with TensorCore */

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ class ScheduleRule : public runtime::ObjectRef {
301301
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
302302
/*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
303303
TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
304+
/*! \brief Create default schedule rules for RISCV CPU (RVV) */
305+
TVM_DLL static Array<ScheduleRule, void> DefaultRISCV(int vlen);
304306

305307
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
306308
};

python/tvm/target/target.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,14 @@ def riscv_cpu(model="sifive-u54", options=None):
637637
"-mabi=lp64d",
638638
# cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d -mcpu=sifive-u74
639639
],
640+
"licheepi3a": [
641+
"-num-cores=8",
642+
"-mtriple=riscv64-unknown-linux-gnu",
643+
"-mcpu=spacemit-x60",
644+
"-mfloat-abi=hard",
645+
"-mabi=lp64d",
646+
# cc: riscv64-unknown-linux-gnu-g++ -march=rv64gcv -mabi=lp64d -mcpu=spacemit-x60
647+
],
640648
}
641649
pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
642650

python/tvm/tir/tensor_intrin/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
from . import cuda
2121

2222
if enabled("llvm"):
23-
from . import arm_cpu, x86, rocm, hexagon
23+
from . import arm_cpu, x86, rocm, hexagon, riscv_cpu
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name,line-too-long
18+
"""Intrinsics for RISCV tensorization"""
19+
20+
import logging
21+
from tvm.ffi import register_func
22+
from tvm.runtime import DataType
23+
from tvm.script import tir as T
24+
from tvm.target.codegen import Target
25+
from tvm.target.codegen import llvm_get_vector_width, target_has_features
26+
from .. import TensorIntrin
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
def get_max_elems(vlen: int, lmul: int, sew: int) -> int:
    """Count the elements of width SEW that fit in a group of LMUL vector registers.

    Args:
        vlen (int): VLEN, vector register length in bits
        lmul (int): LMUL, vector length multiplier
        sew (int): SEW, element width in bits

    Returns:
        int: Number of elements, computed as ``(vlen // sew) * lmul``
    """
    # elements held by a single (ungrouped) vector register
    elems_per_register = vlen // sew
    return elems_per_register * lmul
44+
45+
46+
def rvv_vec_dot_product_kernels(
    n_elems: int,
    n_lanes: int,
    data_dtype: str,
    weight_dtype: str,
    out_dtype: str,
    lmul: int,
):
    """
    Dot product of vector and matrix rows using RISC-V vector instructions.

    These kernels take two arrays A[ELEMS] and B[LANES][ELEMS] and compute the
    dot product of A[ELEMS] with each row of B[LANES], accumulating results
    in C[LANES].

    The pseudo code is as follows:

    .. code-block:: c

        void vec_dot_prod(A[ELEMS], B[LANES][ELEMS], C[LANES]){
            for (j = 0; j < LANES; j++) {
                for (k = 0; k < ELEMS; k++) {
                    C[j] += A[k] * B[j][k]
                }
            }
        }

    Args:
        n_elems (int): Length of A and of every row of B (reduction extent)
        n_lanes (int): Number of rows of B and number of outputs in C
        data_dtype (str): Element type of A
        weight_dtype (str): Element type of B
        out_dtype (str): Element type of C
        lmul (int): Register grouping factor used to size the scalable vector types

    Returns:
        tuple: ``(desc, impl)`` pair of TIR prim funcs, the computation
        description and its implementation based on LLVM RISC-V intrinsics
    """

    @T.prim_func
    def rvv_vec_dot_prod_desc(
        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
            T.writes(C[0:n_lanes])
            for j in T.serial(0, n_lanes):
                for k in T.serial(0, n_elems):
                    with T.block("update"):
                        vj, vk = T.axis.remap("SR", [j, k])
                        C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype)

    # LLVM only supports ELEN=32 or ELEN=64
    # https://llvm.org/docs/RISCV/RISCVVectorExtension.html
    d_dtype_lanes = (64 // DataType(data_dtype).bits) * lmul
    w_dtype_lanes = (64 // DataType(weight_dtype).bits) * lmul
    # reduction lanes narrow
    o_dtype_lanes = (64 // DataType(out_dtype).bits) * lmul // n_lanes
    # data type widening case
    o_dtype_lanes = max(o_dtype_lanes, 2)

    # Integer kernels pass no extra operand; non-integer kernels pass one extra uint64(7).
    # NOTE(review): presumably the rounding-mode operand of the floating-point
    # intrinsics (despite the `mask` name) - confirm against the LLVM definitions.
    mask_args = () if data_dtype[0] in ("i", "u") else (T.uint64(7),)

    # When the output is wider than the input, the element-wise product uses the
    # input's type name with twice the bit width (e.g. int8 -> int16).
    wide_dtype = out_dtype
    if DataType(out_dtype).bits > DataType(data_dtype).bits:
        wide_dtype = "".join(c for c in data_dtype if not c.isdigit())
        wide_dtype += str(DataType(data_dtype).bits * 2)

    # fmt: off
    @T.prim_func
    def rvv_vec_dot_prod_impl(
        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
            T.writes(C[0:n_lanes])

            # A is loaded once and reused for every row of B
            vec_A = T.call_llvm_intrin(
                f"{data_dtype}xvscalex{d_dtype_lanes}",
                "llvm.riscv.vle",
                T.broadcast(T.Cast(data_dtype, 0), T.vscale() * d_dtype_lanes),
                T.tvm_access_ptr(T.type_annotation(data_dtype), A.data, 0, n_elems, 1),
                T.int64(n_elems))

            for i in range(n_lanes):
                with T.block("update"):
                    T.reads(B[i, 0:n_elems])
                    T.writes(C[i])

                    # the pass-through operand carries the same element type as the
                    # loaded vector (weight_dtype), matching the other intrinsic calls
                    vec_B_row = T.call_llvm_intrin(
                        f"{weight_dtype}xvscalex{w_dtype_lanes}",
                        "llvm.riscv.vle",
                        T.broadcast(T.Cast(weight_dtype, 0), T.vscale() * w_dtype_lanes),
                        T.tvm_access_ptr(T.type_annotation(weight_dtype), B.data, i * n_elems, n_elems, 1),
                        T.int64(n_elems))

                    # element-wise product of row i of B with A
                    product = T.call_llvm_intrin(
                        f"{wide_dtype}xvscalex{w_dtype_lanes}",
                        "llvm.riscv.vfmul" if out_dtype[0] == "f" else \
                        "llvm.riscv.vwmulsu" if (data_dtype[0] != weight_dtype[0]) else \
                        "llvm.riscv.vwmul",
                        T.broadcast(T.Cast(wide_dtype, 0), T.vscale() * w_dtype_lanes),
                        vec_B_row,
                        vec_A,
                        *mask_args,
                        T.uint64(n_elems))

                    # current value of C[i], loaded as the initial accumulator
                    ini_acc = T.call_llvm_intrin(
                        f"{out_dtype}xvscalex{o_dtype_lanes}",
                        "llvm.riscv.vle",
                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
                        T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, i, 1, 1),
                        T.int64(1))

                    # reduce the products on top of the initial accumulator
                    red_sum = T.call_llvm_intrin(
                        f"{out_dtype}xvscalex{o_dtype_lanes}",
                        "llvm.riscv.vfredusum" if out_dtype[0] == "f" else \
                        "llvm.riscv.vwredsum",
                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
                        product,
                        ini_acc,
                        *mask_args,
                        T.uint64(n_elems))

                    # move the scalar result out of the vector register into C[i]
                    C[i] = T.call_llvm_intrin(
                        out_dtype,
                        "llvm.riscv.vfmv.f.s" if out_dtype[0] == "f" else \
                        "llvm.riscv.vmv.x.s",
                        red_sum)
    # fmt: on

    return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl
169+
170+
171+
@register_func("tir.tensor_intrin.register_rvv_isa_intrisics")
def register_rvv_isa_intrisics(target: Target, only_inventory=False) -> dict:
    """Register RISCV V (vector) intrinsics

    Implementation follows version 1.0 of the vector specification:
    https://github.com/riscvarchive/riscv-v-spec/releases/tag/v1.0

    Args:
        target (Target): TVM target
        only_inventory (bool): When True, only build the inventory and do not
            register any kernel

    Returns:
        dict: A catalog mapping every kernel name to its number of elements

    Raises:
        RuntimeError: If the target does not support the `v` extension
    """
    if not target_has_features("v", target):
        raise RuntimeError("Current target does not support `v` extension.")

    vlen = llvm_get_vector_width(target)
    # get maximum reduction lanes (without grouping)
    n_lanes = get_max_elems(vlen, lmul=1, sew=32)

    data_dtype = ["uint8", "int8", "float16", "float32"]
    weight_dtype = ["int8", "int8", "float16", "float32"]
    output_dtype = ["int32", "int32", "float16", "float32"]

    kernel_inventory = {}

    for d_dtype, w_dtype, o_dtype in zip(data_dtype, weight_dtype, output_dtype):
        # the data types do not change while n_elems shrinks, so build them once
        dt = DataType(d_dtype)
        wt = DataType(w_dtype)
        ot = DataType(o_dtype)

        # max elements to grouped registers
        max_elems = get_max_elems(vlen, lmul=8, sew=dt.bits)
        # data widening halves available vector registers
        if ot.bits > dt.bits:
            max_elems //= 2
        # compute optimal LMUL for full load
        lmul = max_elems // (vlen // dt.bits)

        # one kernel per power-of-two fraction of max_elems, down to 4 elements
        n_elems = max_elems
        while n_elems >= 4:
            kernel_name = "rvv_dot"
            kernel_name += f"_{n_elems}{dt[0]}{dt.bits}"
            kernel_name += f"_{n_lanes}x{n_elems}{wt[0]}{wt.bits}"
            kernel_name += f"_{n_lanes}{ot[0]}{ot.bits}"
            kernel_inventory[kernel_name] = n_elems

            if not only_inventory:
                logger.debug("Registering kernel %s", kernel_name)
                desc, impl = rvv_vec_dot_product_kernels(
                    n_elems, n_lanes, d_dtype, w_dtype, o_dtype, lmul
                )
                TensorIntrin.register(kernel_name, desc, impl, override=True)

            n_elems //= 2

    return kernel_inventory
229+
230+
231+
def register_riscv_intrinsics(target: Target):
    """Register the RISC-V tensor intrinsics for the given target.

    Currently this covers the kernels of the `v` (vector) extension only.

    Args:
        target (Target): TVM target
    """
    # RISC-V `v` extension ISA; the returned kernel inventory is not needed here
    register_rvv_isa_intrisics(target)

    logger.debug("Finished registering riscv intrinsics.")

src/meta_schedule/postproc/postproc.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ Array<Postproc> Postproc::DefaultCPUTensorization() {
6969
};
7070
}
7171

72+
/*! \brief Build the default list of postprocessors used for RISCV targets. */
Array<Postproc> Postproc::DefaultRISCV() {
  // Appended in the same order as they appear in the returned array.
  Array<Postproc> postprocs;
  postprocs.push_back(Postproc::DisallowDynamicLoop());
  postprocs.push_back(Postproc::RewriteParallelVectorizeUnroll());
  postprocs.push_back(Postproc::RewriteReductionBlock());
  postprocs.push_back(Postproc::RewriteTensorize(/*vectorize_init_loop=*/false));
  postprocs.push_back(Postproc::RewriteLayout());
  return postprocs;
}
79+
7280
Array<Postproc> Postproc::DefaultCUDA() {
7381
return Array<Postproc>{
7482
Postproc::DisallowDynamicLoop(),

src/meta_schedule/schedule_rule/schedule_rule.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* under the License.
1818
*/
1919
#include <tvm/ffi/reflection/registry.h>
20+
#include <tvm/runtime/data_type.h>
2021

2122
#include "../utils.h"
2223

@@ -304,6 +305,62 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
304305
};
305306
}
306307

308+
/*!
 * \brief Build the default list of schedule rules for RISCV CPU (RVV).
 * \param vlen Vector length.
 *        NOTE(review): `vlen` is not read in this body; the kernel inventory is obtained from the
 *        globally registered function, which receives the current target - confirm this is
 *        intended.
 * \return The schedule rules, including one tiling rule per RVV kernel in the inventory.
 */
Array<ScheduleRule> ScheduleRule::DefaultRISCV(const int vlen) {
  // Produces a fresh copy of the write-reuse configuration shared by all tiling rules below.
  const auto global_reuse_write = []() {
    return Map<String, ffi::Any>{{"req", String("may")},
                                 {"levels", Array<Integer>{1, 2}},
                                 {"scope", String("global")}};
  };

  Array<ScheduleRule> rules;
  rules.push_back(ScheduleRule::ApplyCustomRule());
  rules.push_back(ScheduleRule::InlineConstantScalars());
  rules.push_back(ScheduleRule::AutoInline(
      /*into_producer=*/false,
      /*into_consumer=*/true,
      /*inline_const_tensor=*/true,
      /*disallow_if_then_else=*/true,
      /*require_injective=*/true,
      /*require_ordered=*/true,
      /*disallow_op=*/Array<String>{"tir.exp"}));
  rules.push_back(ScheduleRule::AddRFactor(
      /*max_jobs_per_core=*/16,
      /*max_innermost_factor=*/Integer(64)));

  // Ask the registered function for the inventory only (kernel name -> integer value).
  auto target = tvm::Target::Current();
  const auto register_rvv =
      tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrisics");
  const auto inventory =
      register_rvv(target, /* only_inventory */ true).cast<Map<String, int>>();

  for (const auto& kernel : inventory) {
    const String& kernel_name = kernel.first;
    if (!tir::TensorIntrin::Get(kernel_name, /*allow_missing*/ true)) {
      // on demand intrinsic register
      register_rvv(target, /* only_inventory */ false);
    }
    // The inventory value bounds the innermost tiling factor of this kernel's rule.
    rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin(
        /*intrin_name=*/kernel_name,
        /*structure=*/"SSRSRS",
        /*tile_binds=*/std::nullopt,
        /*max_innermost_factor=*/Integer(kernel.second),
        /*vector_load_lens=*/std::nullopt,
        /*reuse_read=*/std::nullopt,
        /*reuse_write=*/global_reuse_write()));
  }

  // Generic tiling rule, added after the intrinsic-based ones.
  rules.push_back(ScheduleRule::MultiLevelTiling(
      /*structure=*/"SSRSRS",
      /*tile_binds=*/std::nullopt,
      /*max_innermost_factor=*/Integer(64),
      /*vector_load_lens=*/std::nullopt,
      /*reuse_read=*/std::nullopt,
      /*reuse_write=*/global_reuse_write()));
  rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll(
      /*max_jobs_per_core=*/16,
      /*max_vectorize_extent=*/64,
      /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
      /*unroll_explicit=*/true));
  rules.push_back(ScheduleRule::RandomComputeLocation());

  return rules;
}
363+
307364
Array<ScheduleRule> GetARMNeonSpecificRules() {
308365
return {
309366
ScheduleRule::MultiLevelTilingWithIntrin(

0 commit comments

Comments
 (0)