3434namespace tvm {
3535namespace tir {
3636
37- class StackSizeChecker : public StmtExprVisitor {
37+ // Calculate the statistics of packed function.
38+ // These information are needed during codegen.
39+ class BuiltinLower : public StmtExprMutator {
3840 public:
3941 struct StackSizes {
4042 // If a tvm_stack_make_shape call has no arguments, it is still
@@ -46,159 +48,86 @@ class StackSizeChecker : public StmtExprVisitor {
4648 uint64_t arg_stack{0 };
4749 };
4850
49- static StackSizes Check (Stmt stmt) {
50- StackSizeChecker visitor;
51- visitor.VisitStmt (stmt);
52- return visitor.max_stack_ ;
53- }
54-
55- private:
56- void VisitStmt_ (const ForNode* op) final {
57- if (op->kind == ForKind::kParallel ) {
58- // Parallel for loops have their own stack and allocations, so
59- // stop the recursion here.
60- return ;
61- } else {
62- this ->VisitStmt (op->body );
63- }
64- }
65- void VisitExpr_ (const CallNode* op) final {
66- if (op->op .same_as (builtin::tvm_call_packed ())) {
67- return MakeCallPacked (op, /* use_string_lookup */ true );
68- } else if (op->op .same_as (builtin::tvm_call_cpacked ())) {
69- return MakeCallPacked (op, /* use_string_lookup */ false );
70- } else if (op->op .same_as (builtin::tvm_call_trace_packed ())) {
71- return MakeCallTracePacked (op);
72- } else if (op->op .same_as (builtin::tvm_stack_make_shape ())) {
73- return MakeShape (op);
74- } else if (op->op .same_as (builtin::tvm_stack_make_array ())) {
75- return MakeArray (op);
76- } else {
77- return StmtExprVisitor::VisitExpr_ (op);
78- }
79- }
80- // call shape
81- void MakeShape (const CallNode* op) {
82- // if args.size() == 0, it is still valid and represents a scalar
83- // shape (). Therefore, -1 is used to represent "no shape
84- // arguments exist", while 0 represents "shape arguments exist,
85- // all of which are size 0".
86- if (current_stack_.shape_stack == -1 ) {
87- current_stack_.shape_stack = 0 ;
88- }
89- current_stack_.shape_stack += op->args .size ();
90- StmtExprVisitor::VisitExpr_ (op);
91- }
92- // make array
93- void MakeArray (const CallNode* op) {
94- current_stack_.array_stack += 1 ;
95- StmtExprVisitor::VisitExpr_ (op);
96- }
97- // call packed.
98- void MakeCallPacked (const CallNode* op, bool use_string_lookup) {
99- StackSizes restore_stack = current_stack_;
100-
101- size_t arg_count = op->args .size ();
102-
103- // cpacked expects a resource_handle parameter
104- if (!use_string_lookup) {
105- arg_count--;
106- }
107-
108- current_stack_.arg_stack += arg_count;
109- // Specially handle the buffer packed intrinsic
110- StmtExprVisitor::VisitExpr_ (op);
111- // Record the amount of stack space needed, then reset the stack
112- // position to its previous location.
113- UpdateMaxStack ();
114- current_stack_ = restore_stack;
115- }
116-
117- void MakeCallTracePacked (const CallNode* op) {
118- StackSizes restore_stack = current_stack_;
119-
120- size_t args_size = op->args .size ();
121- ICHECK_GT (args_size, 0 );
122- current_stack_.arg_stack += args_size;
123-
124- StmtExprVisitor::VisitExpr_ (op);
125- // Record the amount of stack space needed, then reset the stack
126- // position to its previous location.
127- UpdateMaxStack ();
128- current_stack_ = restore_stack;
129-
130- // However, the arguments to this CallNode remain on top of the
131- // stack, so we can use more than one packed function's arguments
132- // with the one stack.
133- current_stack_.arg_stack = restore_stack.arg_stack + args_size - 1 ;
134- }
135-
136- void UpdateMaxStack () {
137- max_stack_.arg_stack = std::max (current_stack_.arg_stack , max_stack_.arg_stack );
138- max_stack_.shape_stack = std::max (current_stack_.shape_stack , max_stack_.shape_stack );
139- max_stack_.array_stack = std::max (current_stack_.array_stack , max_stack_.array_stack );
140- }
141-
142- StackSizes current_stack_;
143- StackSizes max_stack_;
144- };
145-
146- // Calculate the statistics of packed function.
147- // These information are needed during codegen.
148- class BuiltinLower : public StmtExprMutator {
149- public:
15051 // Record stack frame for existing scope.
15152 struct AllocaScope {
15253 Buffer stack_shape;
15354 Var stack_array = Var(" stack_array" , DataType::Handle());
15455 Var stack_value = Var(" stack_value" , DataType::Handle());
15556 Buffer stack_tcode;
15657
157- int64_t max_shape_stack{-1 };
158- uint64_t max_array_stack{0 };
159- uint64_t max_arg_stack{0 };
58+ StackSizes max_sizes;
59+ StackSizes run_sizes;
16060
161- int64_t run_shape_stack{-1 };
162- uint64_t run_array_stack{0 };
163- uint64_t run_arg_stack{0 };
61+ void UpdateMax () {
62+ max_sizes.shape_stack = std::max (max_sizes.shape_stack , run_sizes.shape_stack );
63+ max_sizes.array_stack = std::max (max_sizes.array_stack , run_sizes.array_stack );
64+ max_sizes.arg_stack = std::max (max_sizes.arg_stack , run_sizes.arg_stack );
65+ }
66+
67+ void AssertMaxIsValid () const {
68+ ICHECK ((max_sizes.shape_stack >= run_sizes.shape_stack ) ||
69+ (max_sizes.array_stack >= run_sizes.array_stack ) ||
70+ (max_sizes.arg_stack >= run_sizes.arg_stack ));
71+ }
16472 };
16573
16674 Stmt Build (Stmt stmt) { return this ->VisitBodyAndRealizeAlloca (stmt); }
16775
76+ StackSizes GetMaxStack (Stmt stmt) {
77+ BuiltinLower precheck;
78+ precheck.is_precheck_ = true ;
79+ precheck.device_id_ = this ->device_id_ ;
80+ precheck.device_type_ = this ->device_type_ ;
81+
82+ precheck.alloca_scope_ .emplace_back ();
83+ auto & scope = precheck.alloca_scope_ .back ();
84+ scope.stack_shape =
85+ decl_buffer ({IntImm (DataType::Int (64 ), 0 )}, DataType::Int (64 ), " stack_shape" );
86+ scope.stack_tcode =
87+ decl_buffer ({IntImm (DataType::UInt (64 ), 0 )}, DataType::Int (32 ), " stack_tcode" );
88+
89+ precheck.VisitStmt (stmt);
90+
91+ ICHECK_EQ (precheck.alloca_scope_ .size (), 1 );
92+ return precheck.alloca_scope_ [0 ].max_sizes ;
93+ }
94+
16895 // Allcoate stack frames, only at parallel-for or root.
16996 Stmt VisitBodyAndRealizeAlloca (Stmt stmt) {
170- // Initial check to identify maximum stack sizes. These are used
171- // to construct Buffer objects to hold the stack, which are then
172- // used when mutating.
173- auto max_sizes = StackSizeChecker::Check (stmt);
97+ // Only perform the precheck up to the point where we would add a
98+ // new scope.
99+ if (is_precheck_) {
100+ return stmt;
101+ }
174102
175103 alloca_scope_.emplace_back ();
176104 auto & scope = alloca_scope_.back ();
177105
178- if (max_sizes.shape_stack != -1 ) {
179- scope.stack_shape = decl_buffer ({IntImm (DataType::Int (64 ), max_sizes.shape_stack )},
106+ // Initial check to identify maximum stack sizes. These are used
107+ // to construct Buffer objects to hold the stack, which are then
108+ // used when mutating.
109+ scope.max_sizes = GetMaxStack (stmt);
110+
111+ if (scope.max_sizes .shape_stack != -1 ) {
112+ scope.stack_shape = decl_buffer ({IntImm (DataType::Int (64 ), scope.max_sizes .shape_stack )},
180113 DataType::Int (64 ), " stack_shape" );
181- stmt = LetStmt (scope.stack_shape ->data , StackAlloca (" shape" , max_sizes.shape_stack ), stmt);
114+ stmt =
115+ LetStmt (scope.stack_shape ->data , StackAlloca (" shape" , scope.max_sizes .shape_stack ), stmt);
182116 }
183117
184- if (max_sizes.array_stack != 0 ) {
185- stmt = LetStmt (scope.stack_array , StackAlloca (" array" , max_sizes.array_stack ), stmt);
118+ if (scope. max_sizes .array_stack != 0 ) {
119+ stmt = LetStmt (scope.stack_array , StackAlloca (" array" , scope. max_sizes .array_stack ), stmt);
186120 }
187121
188- if (max_sizes.arg_stack != 0 ) {
189- scope.stack_tcode = decl_buffer ({IntImm (DataType::UInt (64 ), max_sizes.arg_stack )},
122+ if (scope. max_sizes .arg_stack != 0 ) {
123+ scope.stack_tcode = decl_buffer ({IntImm (DataType::UInt (64 ), scope. max_sizes .arg_stack )},
190124 DataType::Int (32 ), " stack_tcode" );
191- stmt = LetStmt (scope.stack_value , StackAlloca (" arg_value" , max_sizes.arg_stack ), stmt);
125+ stmt = LetStmt (scope.stack_value , StackAlloca (" arg_value" , scope. max_sizes .arg_stack ), stmt);
192126
193- stmt = LetStmt (scope.stack_tcode ->data , StackAlloca (" arg_tcode" , max_sizes.arg_stack ), stmt);
127+ stmt = LetStmt (scope.stack_tcode ->data , StackAlloca (" arg_tcode" , scope.max_sizes .arg_stack ),
128+ stmt);
194129 }
195130
196- // Copy these values from the earlier search, for use in bounds
197- // checks.
198- scope.max_shape_stack = max_sizes.shape_stack ;
199- scope.max_array_stack = max_sizes.array_stack ;
200- scope.max_arg_stack = max_sizes.arg_stack ;
201-
202131 stmt = this ->VisitStmt (stmt);
203132
204133 ICHECK (!alloca_scope_.empty ());
@@ -213,8 +142,8 @@ class BuiltinLower : public StmtExprMutator {
213142
214143 auto stmt = StmtExprMutator::VisitStmt (s);
215144 auto & scope = alloca_scope_.back ();
216- ICHECK_EQ (scope.run_shape_stack , -1 );
217- ICHECK_EQ (scope.run_array_stack , 0 );
145+ ICHECK_EQ (scope.run_sizes . shape_stack , -1 );
146+ ICHECK_EQ (scope.run_sizes . array_stack , 0 );
218147
219148 auto prep_seq = std::move (prep_seq_stack_.back ());
220149 prep_seq_stack_.pop_back ();
@@ -364,11 +293,11 @@ class BuiltinLower : public StmtExprMutator {
364293 ICHECK (!alloca_scope_.empty ());
365294 auto & scope = alloca_scope_.back ();
366295 auto & prep_seq = prep_seq_stack_.back ();
367- if (scope.run_shape_stack == -1 ) {
368- scope.run_shape_stack = 0 ;
296+ if (scope.run_sizes . shape_stack == -1 ) {
297+ scope.run_sizes . shape_stack = 0 ;
369298 }
370- int64_t stack_begin = scope.run_shape_stack ;
371- scope.run_shape_stack += op->args .size ();
299+ int64_t stack_begin = scope.run_sizes . shape_stack ;
300+ scope.run_sizes . shape_stack += op->args .size ();
372301 PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
373302 op = expr.as <CallNode>();
374303 // no need to perform any store for a scalar shape
@@ -384,8 +313,8 @@ class BuiltinLower : public StmtExprMutator {
384313 auto & scope = alloca_scope_.back ();
385314 auto & prep_seq = prep_seq_stack_.back ();
386315
387- size_t idx = scope.run_array_stack ;
388- scope.run_array_stack += 1 ;
316+ size_t idx = scope.run_sizes . array_stack ;
317+ scope.run_sizes . array_stack += 1 ;
389318 PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
390319 op = expr.as <CallNode>();
391320
@@ -426,9 +355,9 @@ class BuiltinLower : public StmtExprMutator {
426355 auto & scope = alloca_scope_.back ();
427356 auto & prep_seq = prep_seq_stack_.back ();
428357
429- int64_t restore_shape_stack = scope.run_shape_stack ;
430- size_t restore_array_stack = scope.run_array_stack ;
431- size_t arg_stack_begin = scope.run_arg_stack ;
358+ int64_t restore_shape_stack = scope.run_sizes . shape_stack ;
359+ size_t restore_array_stack = scope.run_sizes . array_stack ;
360+ size_t arg_stack_begin = scope.run_sizes . arg_stack ;
432361
433362 size_t arg_count = op->args .size ();
434363
@@ -437,7 +366,7 @@ class BuiltinLower : public StmtExprMutator {
437366 arg_count--;
438367 }
439368
440- scope.run_arg_stack += arg_count;
369+ scope.run_sizes . arg_stack += arg_count;
441370 // Specially handle the buffer packed intrinsic
442371 PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
443372 op = expr.as <CallNode>();
@@ -460,12 +389,14 @@ class BuiltinLower : public StmtExprMutator {
460389 prep_seq.emplace_back (BufferStore (scope.stack_tcode , ConstInt32 (arg_tcode), {stack_index}));
461390 }
462391 // Verify stack size matches earlier value.
463- ICHECK_LE (scope.run_arg_stack , scope.max_arg_stack );
464- ICHECK_LE (scope.run_shape_stack , scope.max_shape_stack );
465- ICHECK_LE (scope.run_array_stack , scope.max_array_stack );
466- scope.run_shape_stack = restore_shape_stack;
467- scope.run_array_stack = restore_array_stack;
468- scope.run_arg_stack = arg_stack_begin;
392+ if (is_precheck_) {
393+ scope.UpdateMax ();
394+ } else {
395+ scope.AssertMaxIsValid ();
396+ }
397+ scope.run_sizes .shape_stack = restore_shape_stack;
398+ scope.run_sizes .array_stack = restore_array_stack;
399+ scope.run_sizes .arg_stack = arg_stack_begin;
469400 Array<PrimExpr> packed_args = {op->args [0 ], scope.stack_value , scope.stack_tcode ->data ,
470401 ConstInt32 (arg_stack_begin),
471402 ConstInt32 (arg_stack_begin + op->args .size () - 1 )};
@@ -486,10 +417,10 @@ class BuiltinLower : public StmtExprMutator {
486417 auto & scope = alloca_scope_.back ();
487418 auto & prep_seq = prep_seq_stack_.back ();
488419
489- int64_t restore_shape_stack = scope.run_shape_stack ;
490- size_t restore_array_stack = scope.run_array_stack ;
491- size_t arg_stack_begin = scope.run_arg_stack ;
492- scope.run_arg_stack += op->args .size ();
420+ int64_t restore_shape_stack = scope.run_sizes . shape_stack ;
421+ size_t restore_array_stack = scope.run_sizes . array_stack ;
422+ size_t arg_stack_begin = scope.run_sizes . arg_stack ;
423+ scope.run_sizes . arg_stack += op->args .size ();
493424 size_t args_size = op->args .size ();
494425 ICHECK_GT (args_size, 0 );
495426 PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
@@ -510,14 +441,16 @@ class BuiltinLower : public StmtExprMutator {
510441 prep_seq.emplace_back (BufferStore (scope.stack_tcode , ConstInt32 (arg_tcode), {stack_index}));
511442 }
512443 // Verify stack size matches earlier value.
513- ICHECK_LE (scope.run_arg_stack , scope.max_arg_stack );
514- ICHECK_LE (scope.run_shape_stack , scope.max_shape_stack );
515- ICHECK_LE (scope.run_array_stack , scope.max_array_stack );
516- scope.run_shape_stack = restore_shape_stack;
517- scope.run_array_stack = restore_array_stack;
444+ if (is_precheck_) {
445+ scope.UpdateMax ();
446+ } else {
447+ scope.AssertMaxIsValid ();
448+ }
449+ scope.run_sizes .shape_stack = restore_shape_stack;
450+ scope.run_sizes .array_stack = restore_array_stack;
518451 // Update the top of the stack, so we can use more than one
519452 // packed function's arguments with the one stack.
520- scope.run_arg_stack = arg_stack_begin + args_size - 1 ;
453+ scope.run_sizes . arg_stack = arg_stack_begin + args_size - 1 ;
521454 Array<PrimExpr> packed_args = {op->args [0 ], scope.stack_value , scope.stack_tcode ->data ,
522455 ConstInt32 (arg_stack_begin),
523456 ConstInt32 (arg_stack_begin + op->args .size () - 1 ),
@@ -575,6 +508,8 @@ class BuiltinLower : public StmtExprMutator {
575508 PrimExpr device_type_;
576509 PrimExpr device_id_;
577510
511+ bool is_precheck_{false };
512+
578513 // Record all stack frames.
579514 std::vector<AllocaScope> alloca_scope_;
580515};
0 commit comments