2323 */
2424
2525#include < tvm/arith/analyzer.h>
26- #include < tvm/tir/data_type_rewriter.h>
2726#include < tvm/tir/function.h>
2827#include < tvm/tir/op.h>
2928#include < tvm/tir/stmt_functor.h>
3029#include < tvm/tir/transform.h>
3130
32- #include " ../../printer/text_printer.h"
3331#include " ../ir/functor_common.h"
3432#include " ir_utils.h"
3533
3634namespace tvm {
3735namespace tir {
38- class MatchBufferLower : public DataTypeLegalizer {
36+ class MatchBufferLower : public StmtExprMutator {
3937 public:
4038 explicit MatchBufferLower (const PrimFunc& func) {
4139 for (const Var& param : func->params ) {
@@ -190,14 +188,14 @@ class MatchBufferLower : public DataTypeLegalizer {
190188 Array<PrimExpr> buffer_start_indices = source_buffer->ElemOffset (indices);
191189 if (buffer_start_indices.size () == 1 ) {
192190 Bind (buffer->elem_offset , buffer_start_indices[0 ], buffer->name + " .elem_offset" );
193- CHECK (analyzer_.CanProve (truncmod (buffer_start_indices[ 0 ] , buffer->offset_factor ) == 0 ))
191+ CHECK (analyzer_.CanProve (truncmod (buffer-> elem_offset , buffer->offset_factor ) == 0 ))
194192 << " The source elem_offset " << buffer_start_indices[0 ]
195193 << " does not satisfy the offset_factor " << buffer->offset_factor << " ." ;
196194 } else {
197195 // Non-zero elem_offset is ill-defined for non-flat memory.
198196 // If needed in the future, will require `Array<PrimExpr>
199197 // elem_offsets`, with one offset for each flattened index.
200- Bind (buffer->elem_offset , make_zero (buffer-> elem_offset . dtype ()) );
198+ Bind (buffer->elem_offset , 0 );
201199 }
202200 }
203201
@@ -231,7 +229,7 @@ class MatchBufferLower : public DataTypeLegalizer {
231229 }
232230
233231 void Bind (const PrimExpr& arg, PrimExpr value, const std::string& arg_name = " argument" ) {
234- CHECK_EQ (arg.dtype (). code () , value.dtype (). code ())
232+ CHECK_EQ (arg.dtype (), value.dtype ())
235233 << " The data type mismatched: " << arg->dtype << " vs. " << value->dtype ;
236234 // Handle recursive case
237235 value = Substitute (std::move (value), var_map_);
@@ -240,7 +238,7 @@ class MatchBufferLower : public DataTypeLegalizer {
240238 auto it = var_map_.find (v);
241239 if (it == var_map_.end ()) {
242240 var_map_.Set (v, value);
243- // analyzer_.Bind(v, value);
241+ analyzer_.Bind (v, value);
244242 } else {
245243 AssertBinding ((*it).second , value, arg_name);
246244 }
@@ -249,21 +247,10 @@ class MatchBufferLower : public DataTypeLegalizer {
249247 }
250248 }
251249
252- PrimExpr LookUpArgBind (const PrimExpr& arg) {
253- if (arg->IsInstance <VarNode>()) {
254- Var v = Downcast<Var>(arg);
255- if (auto it = var_map_.find (v); it != var_map_.end ()) {
256- return (*it).second ;
257- }
258- }
259- return arg;
260- }
261-
262250 void AssertBinding (const PrimExpr& lhs, const PrimExpr& rhs,
263251 const std::string& arg_name = " argument" ) {
264- CHECK (analyzer_.CanProve (LookUpArgBind (lhs) == rhs))
265- << " The buffer match constraint for " << arg_name << " unmet: " << lhs << " ==" << rhs
266- << " ." ;
252+ CHECK (analyzer_.CanProve (lhs == rhs)) << " The buffer match constraint for " << arg_name
253+ << " unmet: " << lhs << " ==" << rhs << " ." ;
267254 }
268255
269256 private:
@@ -277,9 +264,7 @@ class MatchBufferLower : public DataTypeLegalizer {
277264
278265PrimFunc LowerMatchBuffer (PrimFunc func) {
279266 auto fptr = func.CopyOnWrite ();
280- // LOG(INFO) << "BeforeLMB:\n" << tir::AsTVMScript(func);
281267 fptr->body = MatchBufferLower (func)(std::move (fptr->body ));
282- // LOG(INFO) << "AfterLMB:\n" << tir::AsTVMScript(func);
283268 return func;
284269}
285270
@@ -297,4 +282,4 @@ TVM_REGISTER_GLOBAL("tir.transform.LowerMatchBuffer").set_body_typed(LowerMatchB
297282} // namespace transform
298283
299284} // namespace tir
300- } // namespace tvm
285+ } // namespace tvm
0 commit comments