File tree Expand file tree Collapse file tree 3 files changed +8
-6
lines changed
include/tvm/meta_schedule
python/tvm/meta_schedule/postproc
src/meta_schedule/postproc Expand file tree Collapse file tree 3 files changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -149,10 +149,11 @@ class Postproc : public runtime::ObjectRef {
149149 */
150150 TVM_DLL static Postproc RewriteUnboundBlock (int max_threadblock);
151151 /* !
152- * \brief Create a postprocessor that tensorize Tensor Core related components
152+ * \brief Create a postprocessor that applies tensorization to annotated blocks
153+ * \param Whether or not vectorize the initialization loop produced by DecomposeReduction
153154 * \return The postprocessor created.
154155 */
155- TVM_DLL static Postproc RewriteTensorCore ( );
156+ TVM_DLL static Postproc RewriteTensorize ( bool vectorize_init_loop = false );
156157
157158 /* !
158159 * \brief Creates a postprocessor that verifies if the GPU code is correct
Original file line number Diff line number Diff line change @@ -28,7 +28,7 @@ class RewriteTensorize(Postproc):
2828 Parameters
2929 ----------
3030 vectorize_init_loop : bool
31- Whether or not vectorize the initialization loop produced by decompose_reduction
31+ Whether or not vectorize the initialization loop produced by DecomposeReduction
3232
3333 """
3434
Original file line number Diff line number Diff line change 1616 * specific language governing permissions and limitations
1717 * under the License.
1818 */
19- #include < tvm/runtime/container/base .h>
19+ #include < tvm/meta_schedule/postproc .h>
2020
2121#include < algorithm>
2222
@@ -91,14 +91,15 @@ bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
9191 return true ;
9292}
9393
94- Postproc RewriteTensorize (bool vectorize_init_loop) {
94+ Postproc Postproc:: RewriteTensorize (bool vectorize_init_loop) {
9595 ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>();
9696 n->vectorize_init_loop = vectorize_init_loop;
9797 return Postproc (n);
9898}
9999
100100TVM_REGISTER_NODE_TYPE (RewriteTensorizeNode);
101- TVM_REGISTER_GLOBAL (" meta_schedule.PostprocRewriteTensorize" ).set_body_typed(RewriteTensorize);
101+ TVM_REGISTER_GLOBAL (" meta_schedule.PostprocRewriteTensorize" )
102+ .set_body_typed(Postproc::RewriteTensorize);
102103
103104} // namespace meta_schedule
104105} // namespace tvm
You can’t perform that action at this time.
0 commit comments