Skip to content

Commit dba2b31

Browse files
committed
update postproc.h
1 parent ba1f6b8 commit dba2b31

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,11 @@ class Postproc : public runtime::ObjectRef {
149149
*/
150150
TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblock);
151151
/*!
152-
* \brief Create a postprocessor that tensorize Tensor Core related components
152+
* \brief Create a postprocessor that applies tensorization to annotated blocks
153+
* \param Whether or not vectorize the initialization loop produced by DecomposeReduction
153154
* \return The postprocessor created.
154155
*/
155-
TVM_DLL static Postproc RewriteTensorCore();
156+
TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);
156157

157158
/*!
158159
* \brief Creates a postprocessor that verifies if the GPU code is correct

python/tvm/meta_schedule/postproc/rewrite_tensorize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class RewriteTensorize(Postproc):
2828
Parameters
2929
----------
3030
vectorize_init_loop : bool
31-
Whether or not vectorize the initialization loop produced by decompose_reduction
31+
Whether or not vectorize the initialization loop produced by DecomposeReduction
3232
3333
"""
3434

src/meta_schedule/postproc/rewrite_tensorize.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
#include <tvm/runtime/container/base.h>
19+
#include <tvm/meta_schedule/postproc.h>
2020

2121
#include <algorithm>
2222

@@ -91,14 +91,15 @@ bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
9191
return true;
9292
}
9393

94-
Postproc RewriteTensorize(bool vectorize_init_loop) {
94+
Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) {
9595
ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>();
9696
n->vectorize_init_loop = vectorize_init_loop;
9797
return Postproc(n);
9898
}
9999

100100
TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
101-
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize").set_body_typed(RewriteTensorize);
101+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize")
102+
.set_body_typed(Postproc::RewriteTensorize);
102103

103104
} // namespace meta_schedule
104105
} // namespace tvm

0 commit comments

Comments
 (0)