Skip to content

Commit a2fdb8f

Browse files
committed
TIR debug info
1 parent a61c1ad commit a2fdb8f

19 files changed

+764
-93
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py
268268

269269
# Used in CI to communicate between Python and Jenkins
270270
.docker-image-names/
271+
272+
# Printed TIR code on disk
273+
*.tir

include/tvm/ir/expr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ class IntImm : public PrimExpr {
526526
TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());
527527

528528
TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
529+
TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
529530
};
530531

531532
/*!
@@ -572,6 +573,7 @@ class FloatImm : public PrimExpr {
572573
TVM_DLL FloatImm(DataType dtype, double value, Span span = Span());
573574

574575
TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
576+
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
575577
};
576578

577579
/*!

include/tvm/tir/expr.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class StringImm : public PrimExpr {
7979
public:
8080
TVM_DLL StringImm(String value, Span span = Span());
8181
TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
82+
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
8283
};
8384

8485
/*!
@@ -117,6 +118,7 @@ class Cast : public PrimExpr {
117118
public:
118119
TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
119120
TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode);
121+
TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode);
120122
};
121123

122124
/*!
@@ -165,6 +167,7 @@ class Add : public PrimExpr {
165167
public:
166168
TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
167169
TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode);
170+
TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode);
168171
};
169172

170173
/*! \brief a - b */
@@ -181,6 +184,7 @@ class Sub : public PrimExpr {
181184
public:
182185
TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
183186
TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode);
187+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode);
184188
};
185189

186190
/*! \brief a * b */
@@ -197,6 +201,7 @@ class Mul : public PrimExpr {
197201
public:
198202
TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
199203
TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode);
204+
TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode);
200205
};
201206

202207
/*!
@@ -216,6 +221,7 @@ class Div : public PrimExpr {
216221
public:
217222
TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
218223
TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode);
224+
TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode);
219225
};
220226

221227
/*!
@@ -235,6 +241,7 @@ class Mod : public PrimExpr {
235241
public:
236242
TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
237243
TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode);
244+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode);
238245
};
239246

240247
/*! \brief Floor division, floor(a/b) */
@@ -251,6 +258,7 @@ class FloorDiv : public PrimExpr {
251258
public:
252259
TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
253260
TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode);
261+
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode);
254262
};
255263

256264
/*! \brief The remainder of the floordiv */
@@ -267,6 +275,7 @@ class FloorMod : public PrimExpr {
267275
public:
268276
TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
269277
TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode);
278+
TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode);
270279
};
271280

272281
/*! \brief min(a, b) */
@@ -283,6 +292,7 @@ class Min : public PrimExpr {
283292
public:
284293
TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
285294
TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode);
295+
TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode);
286296
};
287297

288298
/*! \brief max(a, b) */
@@ -299,6 +309,7 @@ class Max : public PrimExpr {
299309
public:
300310
TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
301311
TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode);
312+
TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode);
302313
};
303314

304315
/*!
@@ -347,6 +358,7 @@ class EQ : public PrimExpr {
347358
public:
348359
TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
349360
TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode);
361+
TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode);
350362
};
351363

352364
/*! \brief a != b */
@@ -363,6 +375,7 @@ class NE : public PrimExpr {
363375
public:
364376
TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
365377
TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode);
378+
TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode);
366379
};
367380

368381
/*! \brief a < b */
@@ -379,6 +392,7 @@ class LT : public PrimExpr {
379392
public:
380393
TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
381394
TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode);
395+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode);
382396
};
383397

384398
/*! \brief a <= b */
@@ -395,6 +409,7 @@ class LE : public PrimExpr {
395409
public:
396410
TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
397411
TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode);
412+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode);
398413
};
399414

400415
/*! \brief a > b */
@@ -411,6 +426,7 @@ class GT : public PrimExpr {
411426
public:
412427
TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
413428
TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode);
429+
TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode);
414430
};
415431

416432
/*! \brief a >= b */
@@ -427,6 +443,7 @@ class GE : public PrimExpr {
427443
public:
428444
TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
429445
TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode);
446+
TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode);
430447
};
431448

432449
/*! \brief a && b */
@@ -466,6 +483,7 @@ class And : public PrimExpr {
466483
public:
467484
TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
468485
TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode);
486+
TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode);
469487
};
470488

471489
/*! \brief a || b */
@@ -505,6 +523,7 @@ class Or : public PrimExpr {
505523
public:
506524
TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
507525
TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode);
526+
TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode);
508527
};
509528

510529
/*! \brief !a */
@@ -540,6 +559,7 @@ class Not : public PrimExpr {
540559
public:
541560
TVM_DLL Not(PrimExpr a, Span span = Span());
542561
TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode);
562+
TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode);
543563
};
544564

545565
/*!
@@ -591,6 +611,7 @@ class Select : public PrimExpr {
591611
TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
592612

593613
TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode);
614+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode);
594615
};
595616

596617
/*!
@@ -706,6 +727,7 @@ class ProducerLoad : public PrimExpr {
706727
TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span());
707728

708729
TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
730+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode);
709731
};
710732

711733
/*!
@@ -765,6 +787,7 @@ class Load : public PrimExpr {
765787
TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate,
766788
Span span = Span());
767789
TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
790+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode);
768791
};
769792

770793
/*!
@@ -817,6 +840,7 @@ class Ramp : public PrimExpr {
817840
public:
818841
TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
819842
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
843+
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
820844
};
821845

822846
/*! \brief Create a vector where all the elements are value. */
@@ -856,6 +880,7 @@ class Broadcast : public PrimExpr {
856880
public:
857881
TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
858882
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
883+
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
859884
};
860885

861886
/*!
@@ -902,6 +927,7 @@ class Let : public PrimExpr {
902927
public:
903928
TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
904929
TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
930+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
905931
};
906932

907933
/*!
@@ -948,6 +974,7 @@ class Call : public PrimExpr {
948974
public:
949975
TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span());
950976
TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
977+
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
951978
};
952979

953980
/*!
@@ -995,6 +1022,7 @@ class Shuffle : public PrimExpr {
9951022
TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
9961023

9971024
TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode);
1025+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode);
9981026
};
9991027

10001028
// Reduce operator
@@ -1124,6 +1152,7 @@ class Reduce : public PrimExpr {
11241152
int value_index, Array<PrimExpr> init, Span span = Span());
11251153

11261154
TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
1155+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
11271156
};
11281157

11291158
/*! \brief Any shape. */
@@ -1159,6 +1188,7 @@ class Any : public PrimExpr {
11591188
TVM_DLL Any(Span span = Span());
11601189

11611190
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
1191+
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
11621192
};
11631193

11641194
/*

include/tvm/tir/stmt.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class LetStmt : public Stmt {
102102
TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
103103

104104
TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
105+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode);
105106
};
106107

107108
/*!
@@ -158,6 +159,7 @@ class AttrStmt : public Stmt {
158159
TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
159160

160161
TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
162+
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
161163
};
162164

163165
/*!
@@ -206,6 +208,7 @@ class AssertStmt : public Stmt {
206208
TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
207209

208210
TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
211+
TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
209212
};
210213

211214
/*!
@@ -271,6 +274,7 @@ class Store : public Stmt {
271274
Span span = Span());
272275

273276
TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
277+
TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode);
274278
};
275279

276280
/*!
@@ -442,6 +446,7 @@ class ProducerStore : public Stmt {
442446
Span span = Span());
443447

444448
TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
449+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode);
445450
};
446451

447452
/*!
@@ -505,6 +510,7 @@ class ProducerRealize : public Stmt {
505510
String storage_scope = "", Span span = Span());
506511

507512
TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
513+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode);
508514
};
509515

510516
/*!
@@ -679,6 +685,7 @@ class AllocateConst : public Stmt {
679685
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
680686
Span span = Span());
681687
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
688+
TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode);
682689
};
683690

684691
/*! \brief Declare a buffer that can be used in the body */
@@ -812,6 +819,7 @@ class SeqStmt : public Stmt {
812819
};
813820

814821
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
822+
TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode);
815823
};
816824

817825
/*!
@@ -858,6 +866,7 @@ class IfThenElse : public Stmt {
858866
Span span = Span());
859867

860868
TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
869+
TVM_DEFINE_OBJECT_REF_COW_METHOD(IfThenElseNode);
861870
};
862871

863872
/*!
@@ -897,6 +906,7 @@ class Evaluate : public Stmt {
897906
explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
898907

899908
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
909+
TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
900910
};
901911

902912
/*!
@@ -1054,6 +1064,7 @@ class While : public Stmt {
10541064
TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
10551065

10561066
TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
1067+
TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode);
10571068
};
10581069

10591070
/*!
@@ -1098,6 +1109,7 @@ class Prefetch : public Stmt {
10981109
TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
10991110

11001111
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
1112+
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode);
11011113
};
11021114

11031115
/*!
@@ -1202,6 +1214,7 @@ class MatchBufferRegion : public ObjectRef {
12021214
TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
12031215

12041216
TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
1217+
TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode);
12051218
};
12061219

12071220
/*!

include/tvm/tir/transform.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,13 @@ TVM_DLL Pass LowerAsyncDMA();
499499
*/
500500
TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);
501501

502+
/*!
503+
* \brief Add TIR-printer output as debug information to all ops in the module
504+
* \return The pass.
505+
*/
506+
507+
TVM_DLL Pass InstallDebugSpans();
508+
502509
/*!
503510
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
504511
* "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g.,

0 commit comments

Comments
 (0)