-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Metaschedule] Auto tensorization for CPU / GPU dot product #11088
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
93190dd
264dde7
b1430f9
5d6baa9
cf6c9a7
b0e2b21
d380964
bda570d
b29a303
6742810
9377eaa
07d0457
e7483eb
fcd35e3
b97b56c
9b9855a
b90f8ee
9e10cf9
3d773f9
fda3d83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """A postprocessor that tensorize related components.""" | ||
|
|
||
| from tvm._ffi.registry import register_object | ||
| from .. import _ffi_api | ||
| from .postproc import Postproc | ||
|
|
||
|
|
||
| @register_object("meta_schedule.RewriteTensorize") | ||
| class RewriteTensorize(Postproc): | ||
| """A postprocessor that applies tensorization to annotated blocks. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| vectorize_init_loop : bool | ||
| Whether or not vectorize the initialization loop produced by DecomposeReduction | ||
| """ | ||
|
|
||
| def __init__(self, vectorize_init_loop=False) -> None: | ||
| self.__init_handle_by_constructor__( | ||
| _ffi_api.PostprocRewriteTensorize, # type: ignore # pylint: disable=no-member | ||
| vectorize_init_loop, | ||
| ) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,105 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||||||||||||||
| * Licensed to the Apache Software Foundation (ASF) under one | ||||||||||||||||||||||||||||||||||||||||||||||
| * or more contributor license agreements. See the NOTICE file | ||||||||||||||||||||||||||||||||||||||||||||||
| * distributed with this work for additional information | ||||||||||||||||||||||||||||||||||||||||||||||
| * regarding copyright ownership. The ASF licenses this file | ||||||||||||||||||||||||||||||||||||||||||||||
| * to you under the Apache License, Version 2.0 (the | ||||||||||||||||||||||||||||||||||||||||||||||
| * "License"); you may not use this file except in compliance | ||||||||||||||||||||||||||||||||||||||||||||||
| * with the License. You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||
| * Unless required by applicable law or agreed to in writing, | ||||||||||||||||||||||||||||||||||||||||||||||
| * software distributed under the License is distributed on an | ||||||||||||||||||||||||||||||||||||||||||||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||||||||||||||||||||||||||||||||||||||||||||||
| * KIND, either express or implied. See the License for the | ||||||||||||||||||||||||||||||||||||||||||||||
| * specific language governing permissions and limitations | ||||||||||||||||||||||||||||||||||||||||||||||
| * under the License. | ||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||
| #include <tvm/meta_schedule/postproc.h> | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #include <algorithm> | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #include "../utils.h" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| namespace tvm { | ||||||||||||||||||||||||||||||||||||||||||||||
| namespace meta_schedule { | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| using tir::BlockRV; | ||||||||||||||||||||||||||||||||||||||||||||||
| using tir::LoopRV; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| void ApplyTensorization(const tir::Schedule& sch, const String& func_name, | ||||||||||||||||||||||||||||||||||||||||||||||
| const tir::PrimFuncNode* func, bool vectorize_init_loop) { | ||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { | ||||||||||||||||||||||||||||||||||||||||||||||
| if (const auto* block = obj.as<tir::BlockNode>()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| tir::StmtSRef block_sref = sch->GetSRef(block); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (Optional<String> intrin_name = | ||||||||||||||||||||||||||||||||||||||||||||||
| tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) { | ||||||||||||||||||||||||||||||||||||||||||||||
| std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (block_name.find("init") == std::string::npos) { | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are target-specific handling here, ideally we can make the init block behavior configurable in meta schedule rule, it is fine for now |
||||||||||||||||||||||||||||||||||||||||||||||
| jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) { | ||||||||||||||||||||||||||||||||||||||||||||||
| try { | ||||||||||||||||||||||||||||||||||||||||||||||
| sch->Tensorize(block, intrin_name.value()); | ||||||||||||||||||||||||||||||||||||||||||||||
| } catch (const std::exception& e) { | ||||||||||||||||||||||||||||||||||||||||||||||
| LOG(WARNING) << "Tensorize failed with error " << e.what(); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||||
| } else if (vectorize_init_loop) { | ||||||||||||||||||||||||||||||||||||||||||||||
| jobs.emplace_back(block_name, [sch](tir::BlockRV block) { | ||||||||||||||||||||||||||||||||||||||||||||||
| Array<BlockRV> child_blocks = sch->GetChildBlocks(block); | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(child_blocks.size() == 1); | ||||||||||||||||||||||||||||||||||||||||||||||
| Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]); | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(init_loops.size() == 1); | ||||||||||||||||||||||||||||||||||||||||||||||
| sch->Vectorize(init_loops[0]); | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Related to above, since
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hope it would, but it doesn't. Also since parallelization etc is supposed to be applied before I'd prefer vectoring in the init loop right after we run
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting! What’s the order of post-processors being applied now? Perhaps we should reflect this order by adding this post-processor to tune.py tvm/python/tvm/meta_schedule/tune.py Lines 159 to 170 in effc23d
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue in question is vectorization for CPU targets. I'm using the default postprocs in tvm/python/tvm/meta_schedule/tune.py Lines 96 to 103 in effc23d
Since loop parallelization or vectorization checks for the "compact dataflow" constraint, tvm/src/tir/schedule/primitive/for_kind.cc Line 160 in 0ddaaa6
DecomposeReduction in RewriteReductionBlock(). So having RewriteParallelVectorizeUnroll before RewriteReductionBlock() in the default postprocs makes sense.
However, this is not sufficient to vectorize the init loop of reduction block, since it is generated during
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case I want to tensorize the reduction block. So before
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. So the block we want to tensorize wasn’t applied by the schedule rule
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah yes (otherwise tensorize pattern matching fails, because an intrin desc is always serial), I'm not exactly sure what prevents
(after tiling the inner loop nests to be tensorized) is helping?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quite interesting.. So here the case is, on one hand we don’t want the block being annotated by rule Since before decomposition the block wasn’t annotated by For upstreaming, it might be okay to do manual vectorization in
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Exactly.
That's a great question! Until recently, vectorization of the init loop after Yeah, the ideally all outer loop parallelizations and inner loop vectorization can be done by one pass of |
||||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| for (auto kv : jobs) { | ||||||||||||||||||||||||||||||||||||||||||||||
| tir::BlockRV block = sch->GetBlock(kv.first, func_name); | ||||||||||||||||||||||||||||||||||||||||||||||
| sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize); | ||||||||||||||||||||||||||||||||||||||||||||||
| kv.second(block); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| class RewriteTensorizeNode : public PostprocNode { | ||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||
| void InitializeWithTuneContext(const TuneContext& context) final {} | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| bool Apply(const tir::Schedule& sch) final; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| void VisitAttrs(tvm::AttrVisitor* v) {} | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| bool vectorize_init_loop = false; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| static constexpr const char* _type_key = "meta_schedule.RewriteTensorize"; | ||||||||||||||||||||||||||||||||||||||||||||||
| TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode); | ||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) { | ||||||||||||||||||||||||||||||||||||||||||||||
| for (const auto& kv : sch->mod()->functions) { | ||||||||||||||||||||||||||||||||||||||||||||||
| GlobalVar g_var = kv.first; | ||||||||||||||||||||||||||||||||||||||||||||||
| BaseFunc base_func = kv.second; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| ApplyTensorization(sch, g_var->name_hint, prim_func, vectorize_init_loop); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| return true; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { | ||||||||||||||||||||||||||||||||||||||||||||||
| ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| n->vectorize_init_loop = vectorize_init_loop; | ||||||||||||||||||||||||||||||||||||||||||||||
| return Postproc(n); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode); | ||||||||||||||||||||||||||||||||||||||||||||||
| TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") | ||||||||||||||||||||||||||||||||||||||||||||||
| .set_body_typed(Postproc::RewriteTensorize); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace meta_schedule | ||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace tvm | ||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.