Skip to content

Commit 0d0f7cc

Browse files
author
LeiWang199
committed
merge pr from apache#16560
1 parent 043f8a2 commit 0d0f7cc

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

src/tir/schedule/primitive/blockize_tensorize.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include <tvm/tir/data_type_rewriter.h>
2020

2121
#include <functional>
22-
22+
#include "../../transforms/simplify.h"
2323
#include "../ir_comparator.h"
2424
#include "../utils.h"
2525

@@ -530,7 +530,8 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
530530
<< GetRef<Stmt>(sref->stmt);
531531
throw;
532532
}
533-
PrimFunc intrin_desc = intrin->desc;
533+
arith::Analyzer analyzer;
534+
PrimFunc intrin_desc = Simplify(intrin->desc, &analyzer);
534535
PrimFunc intrin_impl = DeepCopy(intrin->impl);
535536

536537
int index_dtype_bits = -1;

src/tir/transforms/simplify.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#include <tvm/tir/transform.h>
3131

3232
#include <optional>
33-
33+
#include "../../tir/transforms/simplify.h"
3434
#include "../../arith/ir_mutator_with_analyzer.h"
3535
#include "../../tir/analysis/control_flow_graph.h"
3636

@@ -270,6 +270,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
270270

271271
namespace tir {
272272

273+
274+
PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) {
275+
auto* n = func.CopyOnWrite();
276+
n->body = arith::StmtSimplifier::Apply(std::move(n->body), analyzer);
277+
return func;
278+
}
279+
273280
Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
274281
return arith::StmtSimplifier::Apply(stmt, analyzer);
275282
}

src/tir/transforms/simplify.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
namespace tvm {
3131
namespace tir {
3232

33+
/* \brief Simplifies the prim func
34+
*
35+
* Applies the same behavior as the tir.transform.Simplify pass.
36+
*/
37+
PrimFunc Simplify(PrimFunc stmt, arith::Analyzer* analyzer);
38+
3339
/* \brief Simplifies the statement
3440
*
3541
* Applies the same behavior as the tir.transform.Simplify pass, but

0 commit comments

Comments
 (0)