@@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) {
4545 LOG (FATAL) << " NotImplementedError: fragment printing" ;
4646}
4747
48+ bool IsAncestorOfAllVarUse (const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) {
49+ if (!d->common_prefix .count (var.get ())) {
50+ return false ;
51+ }
52+ const std::vector<const Object*>& path = d->common_prefix .at (var.get ());
53+ for (auto it = path.rbegin (); it != path.rend (); ++it) {
54+ if (*it == node.get ()) {
55+ return true ;
56+ }
57+ }
58+ return false ;
59+ }
60+
4861TVM_STATIC_IR_FUNCTOR (IRDocsifier, vtable)
4962 .set_dispatch<tir::Evaluate>(" " , [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc {
5063 ExprDoc value = d->AsDoc <ExprDoc>(eval->value , p->Attr (" value" ));
@@ -322,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDo
322335 return TIR (d, " realize" )->Call (args, kwargs_keys, kwargs_values);
323336}
324337
338+ void InsertEnvThread (const tir::IterVar& iter_var, const ObjectPath& iter_var_p,
339+ const IRDocsifier& d) {
340+ Frame f = FindLowestVarDef (iter_var->var , d).value ();
341+ DefineVar (iter_var->var , f, d);
342+ ExprDoc rhs = TIR (d, " env_thread" )
343+ ->Call ({LiteralDoc::Str (iter_var->thread_tag , //
344+ iter_var_p->Attr (" thread_tag" ))});
345+ ExprDoc lhs = d->AsDoc <ExprDoc>(iter_var->var , iter_var_p->Attr (" var" ));
346+ f->stmts .push_back (AssignDoc (lhs, rhs, NullOpt));
347+ }
348+
349+ ExprDoc DocsifyLaunchThread (const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p,
350+ Optional<tir::Var>* define_var, const IRDocsifier& d) {
351+ tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node );
352+ ObjectPath iter_var_p = attr_stmt_p->Attr (" node" );
353+
354+ ExprDoc var_doc{nullptr };
355+ if (d->IsVarDefined (iter_var->var )) {
356+ var_doc = d->AsDoc <ExprDoc>(iter_var->var , iter_var_p->Attr (" var" ));
357+ } else if (IsAncestorOfAllVarUse (attr_stmt, iter_var->var , d)) {
358+ var_doc = LiteralDoc::Str (iter_var->thread_tag , iter_var_p->Attr (" thread_tag" ));
359+ *define_var = iter_var->var ;
360+ } else {
361+ InsertEnvThread (iter_var, iter_var_p, d);
362+ var_doc = d->AsDoc <ExprDoc>(iter_var->var , iter_var_p->Attr (" var" ));
363+ }
364+ return TIR (d, " launch_thread" )
365+ ->Call ({
366+ var_doc,
367+ d->AsDoc <ExprDoc>(attr_stmt->value , attr_stmt_p->Attr (" value" )),
368+ });
369+ }
370+
325371TVM_STATIC_IR_FUNCTOR (IRDocsifier, vtable)
326372 .set_dispatch<tir::BufferRealize>( //
327373 " " , [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
@@ -336,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
336382 .set_dispatch<tir::AttrStmt>( //
337383 " " , [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
338384 bool concise = AllowConciseScoping (d);
385+ Optional<ExprDoc> lhs = NullOpt;
339386 Optional<ExprDoc> rhs = NullOpt;
387+ Optional<tir::Var> define_var = NullOpt;
340388 tir::Stmt body = stmt->body ;
341389 ObjectPath body_p = stmt_p->Attr (" body" );
342390 if (stmt->attr_key == " realize_scope" ) {
@@ -347,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
347395 /* value=*/ d->AsDoc <ExprDoc>(stmt->value , stmt_p->Attr (" value" )),
348396 /* p=*/ stmt_p->Attr (" body" ), d);
349397 body = realize->body ;
350- body_p = body_p ->Attr (" body" );
398+ body_p = stmt_p-> Attr ( " body " ) ->Attr (" body" );
351399 }
352400 }
353401 }
354402 if (stmt->attr_key == " thread_extent" || stmt->attr_key == " virtual_thread" ) {
355- if (const auto * iter_var = stmt->node .as <tir::IterVarNode>()) {
356- if (!d->IsVarDefined (iter_var->var )) {
357- // `DefineVar` is not used here because a more specific name is desirable
358- ObjectPath iter_var_p = stmt_p->Attr (" node" );
359- Frame f = FindLowestVarDef (iter_var->var , d).value ();
360- DefineVar (iter_var->var , f, d);
361- f->stmts .push_back (
362- AssignDoc (d->AsDoc <ExprDoc>(iter_var->var , iter_var_p->Attr (" var" )),
363- TIR (d, " env_thread" )
364- ->Call ({LiteralDoc::Str (iter_var->thread_tag ,
365- iter_var_p->Attr (" thread_tag" ))}), //
366- NullOpt));
367- }
368- rhs = TIR (d, " launch_thread" )
369- ->Call ({
370- d->AsDoc <ExprDoc>(iter_var->var , stmt_p->Attr (" node" )),
371- d->AsDoc <ExprDoc>(stmt->value , stmt_p->Attr (" value" )),
372- });
403+ if (stmt->node ->IsInstance <tir::IterVarNode>()) {
404+ rhs = DocsifyLaunchThread (stmt, stmt_p, &define_var, d);
373405 }
374406 }
375407 if (!rhs.defined ()) {
@@ -380,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
380412 });
381413 }
382414 With<TIRFrame> f (d, stmt);
415+ if (define_var.defined ()) {
416+ lhs = DefineVar (define_var.value (), *f, d);
417+ }
383418 AsDocBody (body, body_p, f->get (), d);
384- return DoConciseScoping (NullOpt , rhs.value (), &(*f)->stmts , concise);
419+ return DoConciseScoping (lhs , rhs.value (), &(*f)->stmts , concise);
385420 });
386421
387422TVM_STATIC_IR_FUNCTOR (IRDocsifier, vtable)
0 commit comments