Skip to content

Commit 86baa31

Browse files
committed
Ported auto-tensorization code
1 parent 534205b commit 86baa31

File tree

8 files changed

+431
-4
lines changed

8 files changed

+431
-4
lines changed

include/tvm/tir/stmt.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
15091509
/*! \brief Mark auto-unroll setting on the block. */
15101510
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
15111511

1512+
/*!
 * \brief Mark that the block should be further rewritten using tensorization.
 * \note The RewriteVNNI postprocessor reads the annotation value as a String and
 *  passes it to Schedule::Tensorize as the name of the tensor intrinsic.
 */
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1516+
15121517
/*!
15131518
* \brief Check if attr_key is a pragma key extension
15141519
* \param attr_key The attr key to be compared

include/tvm/tir/stmt_functor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
412412
* children of the node
413413
*/
414414
TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
415-
const std::function<bool(const ObjectRef&)>& fvisit);
415+
const std::function<bool(const ObjectRef&)>& fvisit, bool visit_init_block=true);
416416
} // namespace tir
417417
} // namespace tvm
418418

python/tvm/meta_schedule/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@
3939
tune_tir,
4040
)
4141
from .tune_context import TuneContext
42+
from . import tensor_intrin
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from . import vnni
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from tvm import tir
18+
from tvm.script import tir as T
19+
from tvm.script.registry import register
20+
21+
22+
# Description half of the tensor intrinsic: the plain loop nest that the
# intrinsic implementation stands in for. It accumulates
#   C[i] += int32(A[k]) * int32(B[i, k])   for i in [0, 16), k in [0, 4).
@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Operands: 4 uint8 values, a 16x4 int8 tile, and 16 int32 accumulators.
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        # C is listed as both read and written because the update accumulates into it.
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            # NOTE(review): this zero-initializes C[i], whereas dot_product_intrin
            # only accumulates (+=) and never resets C -- confirm this is intended.
            with T.init():
                C[i] = T.int32(0)
            for k in T.serial(0, 4):
                with T.block("update"):
                    # "SR": i is bound as a spatial axis, k as a reduction axis.
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
38+
39+
40+
# Implementation half of the tensor intrinsic: the same operands as
# dot_product_desc, computed with a single call to the LLVM intrinsic
# "llvm.x86.avx512.vpdpbusd.512" and accumulated into C.
@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Buffer declarations mirror dot_product_desc exactly.
    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
    C = T.match_buffer(c, (16,), "int32", offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # Load the 4 uint8 values of A as one vector and view them as a single int32.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # Load the whole 16x4 int8 tile of B and view it as 16 int32 lanes.
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        # NOTE(review): T.uint32(0) and T.int32x16(0) precede the two data operands;
        # presumably an argument-count slot and a zero accumulator input -- confirm
        # against T.call_llvm_pure_intrin and the vpdpbusd signature.
        C[
            T.ramp(T.int32(0), 1, 16)
        ] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
66+
67+
68+
# Register the (description, implementation) pair under a string name; that name is
# what a block's auto-tensorize annotation refers to when the intrinsic is applied.
tir.TensorIntrin.register(
    "dot_16x1x16_uint8_int8_int32_cascadelake", dot_product_desc, dot_product_intrin
)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace meta_schedule {
23+
24+
using tir::BlockRV;
25+
using tir::LoopRV;
26+
27+
/*!
 * \brief Identifies a block to tensorize, as the tuple
 *  (block name, name of the enclosing function, tensor intrinsic name).
 */
using BlockPosition = std::tuple<String, String, String>;
28+
29+
class RewriteVNNINode : public PostprocNode {
30+
public:
31+
// Inherited from PostprocNode
32+
void InitializeWithTuneContext(const TuneContext& context) final {}
33+
34+
// Inherited from PostprocNode
35+
bool Apply(const tir::Schedule& sch) final;
36+
37+
void VisitAttrs(tvm::AttrVisitor* v) {}
38+
39+
static constexpr const char* _type_key = "meta_schedule.RewriteVNNI";
40+
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteVNNINode, PostprocNode);
41+
};
42+
43+
/*!
 * \brief Walk a PrimFunc body and record every block that carries the
 *  `meta_schedule_auto_tensorize` annotation.
 * \param sch The schedule used to look up the sref of each visited block.
 * \param func_name The name of the function being scanned; stored with each record.
 * \param func The function whose body is traversed.
 * \param tasks Output list; one (block name, func_name, intrin name) entry is appended
 *  per annotated block.
 */
void CollectTensorized(const tir::Schedule& sch, const String& func_name,
                       const tir::PrimFuncNode* func, std::vector<BlockPosition>& tasks) {
  auto record_if_annotated = [&](const ObjectRef& node) -> bool {
    const auto* block = node.as<tir::BlockNode>();
    if (block == nullptr) {
      // Not a block: nothing to record, keep descending.
      return true;
    }
    tir::StmtSRef block_sref = sch->GetSRef(block);
    Optional<String> intrin_name =
        tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize);
    if (intrin_name.defined()) {
      tasks.emplace_back(block_sref->StmtAs<tir::BlockNode>()->name_hint, func_name,
                         intrin_name.value());
    }
    // Always continue into the children of the current node.
    return true;
  };
  tir::PreOrderVisit(func->body, record_if_annotated, /*visit_init_block=*/false);
}
60+
61+
/*!
 * \brief Tensorize every annotated block in the scheduled module.
 *
 * Works in two phases: first all annotated blocks are recorded by name, then each one
 * is looked up again through the schedule, has its annotation removed, and is tensorized.
 *
 * \param sch The schedule to rewrite.
 * \return Always true.
 */
bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
  // Phase 1: gather (block name, function name, intrin name) from every PrimFunc.
  std::vector<BlockPosition> pending;
  for (const auto& entry : sch->mod()->functions) {
    const tir::PrimFuncNode* prim_func = entry.second.as<tir::PrimFuncNode>();
    if (prim_func == nullptr) {
      continue;
    }
    CollectTensorized(sch, entry.first->name_hint, prim_func, pending);
  }

  // Phase 2: retrieve each recorded block by name and rewrite it.
  for (const BlockPosition& position : pending) {
    const String& block_name = std::get<0>(position);
    const String& func_name = std::get<1>(position);
    const String& intrin_name = std::get<2>(position);

    BlockRV target = sch->GetBlock(block_name, func_name);
    sch->Unannotate(target, tir::attr::meta_schedule_auto_tensorize);
    sch->Tensorize(target, intrin_name);
  }
  return true;
}
79+
80+
/*!
 * \brief Create a RewriteVNNI postprocessor.
 * \return A Postproc handle wrapping a freshly allocated RewriteVNNINode.
 */
Postproc RewriteVNNI() { return Postproc(make_object<RewriteVNNINode>()); }
84+
85+
// Register the node type, and expose the RewriteVNNI factory as the global function
// "meta_schedule.PostprocRewriteVNNI".
TVM_REGISTER_NODE_TYPE(RewriteVNNINode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteVNNI")
    .set_body_typed(RewriteVNNI);
88+
89+
} // namespace meta_schedule
90+
} // namespace tvm

0 commit comments

Comments
 (0)