Skip to content

Commit b724c87

Browse files
authored
[MetaSchedule][ARM] Enable ARM CPU intrinsic for MetaSchedule (#14209)
1 parent 7831a79 commit b724c87

File tree

6 files changed

+291
-35
lines changed

6 files changed

+291
-35
lines changed

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ class ScheduleRule : public runtime::ObjectRef {
300300
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
301301
/*! \brief Create default schedule rules for Micro */
302302
TVM_DLL static Array<ScheduleRule, void> DefaultMicro();
303+
/*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
304+
TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
303305

304306
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
305307
};

include/tvm/runtime/container/array.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,36 @@ class Array : public ObjectRef {
580580
}
581581
}
582582

583+
template <typename... Args>
584+
static size_t CalcCapacityImpl() {
585+
return 0;
586+
}
587+
588+
template <typename... Args>
589+
static size_t CalcCapacityImpl(Array<T> value, Args... args) {
590+
return value.size() + CalcCapacityImpl(args...);
591+
}
592+
593+
template <typename... Args>
594+
static size_t CalcCapacityImpl(T value, Args... args) {
595+
return 1 + CalcCapacityImpl(args...);
596+
}
597+
598+
template <typename... Args>
599+
static void AgregateImpl(Array<T>& dest) {} // NOLINT(*)
600+
601+
template <typename... Args>
602+
static void AgregateImpl(Array<T>& dest, Array<T> value, Args... args) { // NOLINT(*)
603+
dest.insert(dest.end(), value.begin(), value.end());
604+
AgregateImpl(dest, args...);
605+
}
606+
607+
template <typename... Args>
608+
static void AgregateImpl(Array<T>& dest, T value, Args... args) { // NOLINT(*)
609+
dest.push_back(value);
610+
AgregateImpl(dest, args...);
611+
}
612+
583613
public:
584614
// Array's own methods
585615

@@ -680,6 +710,19 @@ class Array : public ObjectRef {
680710
/*! \brief specify container node */
681711
using ContainerType = ArrayNode;
682712

713+
/*!
714+
* \brief Agregate arguments into a single Array<T>
715+
* \param args sequence of T or Array<T> elements
716+
* \return Agregated Array<T>
717+
*/
718+
template <typename... Args>
719+
static Array<T> Agregate(Args... args) {
720+
Array<T> result;
721+
result.reserve(CalcCapacityImpl(args...));
722+
AgregateImpl(result, args...);
723+
return result;
724+
}
725+
683726
private:
684727
/*!
685728
* \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements.

python/tvm/tir/tensor_intrin/arm_cpu.py

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
@T.prim_func
29-
def dot_product_4x4_i8i8i32_desc(
29+
def neon_4x4_i8i8i32_desc(
3030
A: T.Buffer((4,), "int8", offset_factor=1),
3131
B: T.Buffer((4, 4), "int8", offset_factor=1),
3232
C: T.Buffer((4,), "int32", offset_factor=1),
@@ -42,7 +42,7 @@ def dot_product_4x4_i8i8i32_desc(
4242

4343

4444
@T.prim_func
45-
def dot_product_4x4_i8i8i32_neon(
45+
def neon_4x4_i8i8i32_impl(
4646
A: T.Buffer((4,), "int8", offset_factor=1),
4747
B: T.Buffer((4, 4), "int8", offset_factor=1),
4848
C: T.Buffer((4,), "int32", offset_factor=1),
@@ -102,42 +102,71 @@ def dot_product_4x4_i8i8i32_neon(
102102
)
103103

104104

105-
@T.prim_func
106-
def dot_product_4x4_i8i8i32_sdot(
107-
A: T.Buffer((4,), "int8", offset_factor=1),
108-
B: T.Buffer((4, 4), "int8", offset_factor=1),
109-
C: T.Buffer((4,), "int32", offset_factor=1),
110-
) -> None:
111-
with T.block("root"):
112-
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
113-
T.writes(C[0:4])
114-
115-
A_i8x4 = A.vload([0], "int8x4")
116-
A_i32 = T.reinterpret(A_i8x4, dtype="int32")
117-
vec_ai32 = T.broadcast(A_i32, 4)
118-
vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
119-
120-
vec_b = B.vload([0, 0], dtype="int8x16")
121-
122-
vec_c = C.vload([0], dtype="int32x4")
123-
124-
C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
125-
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
126-
T.uint32(3),
127-
vec_c,
128-
vec_a,
129-
vec_b,
130-
dtype="int32x4",
131-
)
105+
def get_dotprod_intrin(in_dtype, out_dtype):
106+
if in_dtype == "uint8":
107+
instr = "udot.v4u32.v16u8"
108+
else: # if in_dtype == "int8"
109+
instr = "sdot.v4i32.v16i8"
110+
111+
in_dtype_x4 = "{TYPE}x4".format(TYPE=in_dtype)
112+
out_dtype_x4 = "{TYPE}x4".format(TYPE=out_dtype)
113+
in_dtype_x16 = "{TYPE}x16".format(TYPE=in_dtype)
114+
115+
@T.prim_func
116+
def dot_prod_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
117+
A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
118+
B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
119+
C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
120+
with T.block("root"):
121+
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
122+
T.writes(C[0:4])
123+
for i in T.serial(0, 4):
124+
for k in T.serial(0, 4):
125+
with T.block("update"):
126+
vi, vk = T.axis.remap("SR", [i, k])
127+
C[vi] = C[vi] + T.cast(A[vk], dtype=out_dtype) * T.cast(
128+
B[vi, vk], dtype=out_dtype
129+
)
130+
131+
@T.prim_func
132+
def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
133+
A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
134+
B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
135+
C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
136+
with T.block("root"):
137+
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
138+
T.writes(C[0:4])
139+
140+
A_i8x4 = A.vload([0], in_dtype_x4)
141+
A_i32 = T.reinterpret(A_i8x4, dtype=out_dtype)
142+
vec_ai32 = T.broadcast(A_i32, 4)
143+
vec_a = T.reinterpret(vec_ai32, dtype=in_dtype_x16)
144+
145+
vec_b = B.vload([0, 0], dtype=in_dtype_x16)
146+
147+
vec_c = C.vload([0], dtype=out_dtype_x4)
148+
149+
C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
150+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.{INSTR}".format(INSTR=instr)),
151+
T.uint32(3),
152+
vec_c,
153+
vec_a,
154+
vec_b,
155+
dtype=out_dtype_x4,
156+
)
157+
158+
return dot_prod_desc, dot_prod_impl
132159

133160

134161
ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
135162
ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
163+
ARM_DOT_4x4_u8_UDOT_INTRIN = "dot_4x4_u8u8u32_udot"
164+
ARM_DOT_4x4_u8_HDOT_INTRIN = "dot_4x4_u8u8i32_hdot"
165+
166+
TensorIntrin.register(ARM_DOT_4x4_i8_NEON_INTRIN, neon_4x4_i8i8i32_desc, neon_4x4_i8i8i32_impl)
167+
168+
TensorIntrin.register(ARM_DOT_4x4_i8_SDOT_INTRIN, *get_dotprod_intrin("int8", "int32"))
136169

137-
TensorIntrin.register(
138-
ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
139-
)
170+
TensorIntrin.register(ARM_DOT_4x4_u8_UDOT_INTRIN, *get_dotprod_intrin("uint8", "uint32"))
140171

141-
TensorIntrin.register(
142-
ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
143-
)
172+
TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8", "int32"))

src/meta_schedule/schedule_rule/schedule_rule.cc

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,94 @@ Array<ScheduleRule> ScheduleRule::DefaultMicro() {
295295
};
296296
}
297297

298+
Array<ScheduleRule> GetNeonSpecificRules() {
299+
return {
300+
ScheduleRule::MultiLevelTilingWithIntrin(
301+
/*intrin_name=*/String("dot_4x4_i8i8s32_neon"),
302+
/*structure=*/"SSRSRS",
303+
/*tile_binds=*/NullOpt,
304+
/*max_innermost_factor=*/Integer(32),
305+
/*vector_load_lens=*/NullOpt,
306+
/*reuse_read=*/NullOpt,
307+
/*reuse_write=*/
308+
Map<String, ObjectRef>{{"req", String("may")},
309+
{"levels", Array<Integer>{1, 2}},
310+
{"scope", String("global")}}),
311+
};
312+
}
313+
314+
Array<ScheduleRule> GetDotprodSpecificRules() {
315+
return {
316+
ScheduleRule::MultiLevelTilingWithIntrin(
317+
/*intrin_name=*/String("dot_4x4_i8i8s32_sdot"),
318+
/*structure=*/"SSRSRS",
319+
/*tile_binds=*/NullOpt,
320+
/*max_innermost_factor=*/Integer(32),
321+
/*vector_load_lens=*/NullOpt,
322+
/*reuse_read=*/NullOpt,
323+
/*reuse_write=*/
324+
Map<String, ObjectRef>{{"req", String("may")},
325+
{"levels", Array<Integer>{1, 2}},
326+
{"scope", String("global")}}),
327+
ScheduleRule::MultiLevelTilingWithIntrin(
328+
/*intrin_name=*/String("dot_4x4_u8u8u32_udot"),
329+
/*structure=*/"SSRSRS",
330+
/*tile_binds=*/NullOpt,
331+
/*max_innermost_factor=*/Integer(32),
332+
/*vector_load_lens=*/NullOpt,
333+
/*reuse_read=*/NullOpt,
334+
/*reuse_write=*/
335+
Map<String, ObjectRef>{{"req", String("may")},
336+
{"levels", Array<Integer>{1, 2}},
337+
{"scope", String("global")}}),
338+
ScheduleRule::MultiLevelTilingWithIntrin(
339+
/*intrin_name=*/String("dot_4x4_u8u8i32_hdot"),
340+
/*structure=*/"SSRSRS",
341+
/*tile_binds=*/NullOpt,
342+
/*max_innermost_factor=*/Integer(32),
343+
/*vector_load_lens=*/NullOpt,
344+
/*reuse_read=*/NullOpt,
345+
/*reuse_write=*/
346+
Map<String, ObjectRef>{{"req", String("may")},
347+
{"levels", Array<Integer>{1, 2}},
348+
{"scope", String("global")}}),
349+
};
350+
}
351+
352+
Array<ScheduleRule> ScheduleRule::DefaultARM(const String& type) {
353+
return Array<ScheduleRule>::Agregate(
354+
ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(),
355+
ScheduleRule::AutoInline(
356+
/*into_producer=*/false,
357+
/*into_consumer=*/true,
358+
/*inline_const_tensor=*/true,
359+
/*disallow_if_then_else=*/true,
360+
/*require_injective=*/true,
361+
/*require_ordered=*/true,
362+
/*disallow_op=*/Array<String>{"tir.exp"}),
363+
ScheduleRule::AddRFactor(
364+
/*max_jobs_per_core=*/8,
365+
/*max_innermost_factor=*/Integer(32)),
366+
"neon" == type ? GetNeonSpecificRules() : Array<ScheduleRule>{},
367+
"dotprod" == type ? GetDotprodSpecificRules() : Array<ScheduleRule>{},
368+
ScheduleRule::MultiLevelTiling(
369+
/*structure=*/"SSRSRS",
370+
/*tile_binds=*/NullOpt,
371+
/*max_innermost_factor=*/Integer(32),
372+
/*vector_load_lens=*/NullOpt,
373+
/*reuse_read=*/NullOpt,
374+
/*reuse_write=*/
375+
Map<String, ObjectRef>{{"req", String("may")},
376+
{"levels", Array<Integer>{1, 2}},
377+
{"scope", String("global")}}),
378+
ScheduleRule::ParallelizeVectorizeUnroll(
379+
/*max_jobs_per_core=*/8,
380+
/*max_vectorize_extent=*/32,
381+
/*unroll_max_steps=*/Array<Integer>{0, 8, 32, 256},
382+
/*unroll_explicit=*/true),
383+
ScheduleRule::RandomComputeLocation());
384+
}
385+
298386
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
299387
.set_dispatch<PyScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) {
300388
const auto* self = n.as<PyScheduleRuleNode>();
@@ -325,6 +413,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon")
325413
.set_body_typed(ScheduleRule::DefaultHexagon);
326414
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro")
327415
.set_body_typed(ScheduleRule::DefaultMicro);
416+
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM")
417+
.set_body_typed(ScheduleRule::DefaultARM);
328418

329419
} // namespace meta_schedule
330420
} // namespace tvm

src/meta_schedule/space_generator/space_generator.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19+
#include "../../target/parsers/aprofile.h"
1920
#include "../utils.h"
2021

2122
namespace tvm {
@@ -38,6 +39,16 @@ String GetRuleKindFromTarget(const Target& target) {
3839
return "avx512";
3940
}
4041
}
42+
43+
TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export());
44+
TargetFeatures afeatures = Downcast<TargetFeatures>(target_json.at("features"));
45+
46+
if (Downcast<Bool>(afeatures.at("has_dotprod"))) {
47+
return "dotprod";
48+
}
49+
if (Downcast<Bool>(afeatures.at("has_asimd"))) {
50+
return "asimd";
51+
}
4152
return "llvm";
4253
}
4354
if (target->kind->name == "hexagon") {
@@ -110,6 +121,14 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
110121
default_sch_rules = ScheduleRule::DefaultMicro();
111122
default_postprocs = Postproc::DefaultMicro();
112123
default_mutator_probs = Mutator::DefaultMicro();
124+
} else if (kind == "asimd") {
125+
default_sch_rules = ScheduleRule::DefaultARM("neon");
126+
default_postprocs = Postproc::DefaultCPUTensorization();
127+
default_mutator_probs = Mutator::DefaultLLVM();
128+
} else if (kind == "dotprod") {
129+
default_sch_rules = ScheduleRule::DefaultARM("dotprod");
130+
default_postprocs = Postproc::DefaultCPUTensorization();
131+
default_mutator_probs = Mutator::DefaultLLVM();
113132
} else {
114133
LOG(FATAL) << "Unsupported kind: " << kind;
115134
throw;

0 commit comments

Comments
 (0)