Skip to content

Commit 7ca1009

Browse files
committed
[TVMScript] Sugar T.env_thread + T.launch_thread
This PR introduces syntactic sugar that combines T.env_thread and T.launch_thread. Previously, an AttrStmt that specifies a thread extent or a virtual thread had to be written in two steps:

```python
bx = T.env_thread("blockIdx.x")  # creates an IterVar
with T.launch_thread(bx, 128):  # specifies the iter domain
    ...
```

With this PR, the two steps can be merged into a single line:

```python
with T.launch_thread("blockIdx.x", 128) as bx:
    ...
```
1 parent 012d6a7 commit 7ca1009

File tree

6 files changed

+101
-27
lines changed

6 files changed

+101
-27
lines changed

include/tvm/script/ir_builder/tir/ir.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,14 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
390390
*/
391391
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
392392

393+
/*!
394+
* \brief Launch a new thread.
395+
* \param thread_tag The thread type tag.
396+
* \param extent The extent of environment thread.
397+
* \return The result LaunchThreadFrame.
398+
*/
399+
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
400+
393401
/*!
394402
* \brief Bind a var to thread env.
395403
* \param thread_tag The thread type tag.

python/tvm/script/ir_builder/tir/frame.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,6 @@ def __enter__(self) -> Buffer:
115115

116116
@_register_object("script.ir_builder.tir.LaunchThreadFrame")
117117
class LaunchThreadFrame(TIRFrame):
118-
...
118+
def __enter__(self) -> Var:
119+
super().__enter__()
120+
return self.iter_var.var

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from tvm import tir
3232
from tvm.ir import Range, Type
3333
from tvm.ir.base import deprecated
34-
from tvm.runtime import convert, ndarray
34+
from tvm.runtime import String, convert, ndarray
3535
from tvm.target import Target
3636

3737
# pylint: disable=unused-import
@@ -1185,14 +1185,14 @@ def decl_buffer(
11851185

11861186

11871187
def launch_thread(
1188-
iter_var: IterVar, # pylint: disable=redefined-outer-name
1188+
thread: Union[IterVar, str], # pylint: disable=redefined-outer-name
11891189
extent: PrimExpr,
11901190
) -> frame.LaunchThreadFrame:
11911191
"""Launch a thread.
11921192
11931193
Parameters
11941194
----------
1195-
iter_var : IterVar
1195+
thread : Union[IterVar, str]
11961196
The iteration variable, or a thread tag (e.g. "blockIdx.x") from which the environment thread is created.
11971197
11981198
extent : PrimExpr
@@ -1213,11 +1213,14 @@ def launch_thread(
12131213
T.launch_thread(brow, 1)
12141214
12151215
"""
1216-
return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member
1216+
1217+
if isinstance(thread, str):
1218+
thread = String(thread)
1219+
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member
12171220

12181221

12191222
def env_thread(thread_tag: str) -> IterVar:
1220-
"""Bind a var to thread env"
1223+
"""Bind a var to thread env
12211224
12221225
Parameters
12231226
----------

src/script/ir_builder/tir/ir.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,10 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
427427
return LaunchThreadFrame(n);
428428
}
429429

430+
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
431+
return LaunchThread(EnvThread(thread_tag), extent);
432+
}
433+
430434
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
431435
PrimExpr condition) {
432436
ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
@@ -658,7 +662,18 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
658662
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
659663
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
660664
TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
661-
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
665+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread")
666+
.set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) {
667+
if (const auto* var = thread_tag_or_var.as<tvm::tir::VarNode>()) {
668+
return LaunchThread(GetRef<tvm::tir::Var>(var), extent);
669+
} else if (const auto* str = thread_tag_or_var.as<StringObj>()) {
670+
return LaunchThread(GetRef<String>(str), extent);
671+
} else {
672+
LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: "
673+
<< thread_tag_or_var->GetTypeKey();
674+
throw;
675+
}
676+
});
662677
TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
663678

664679
TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);

src/script/printer/tir/stmt.cc

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) {
4545
LOG(FATAL) << "NotImplementedError: fragment printing";
4646
}
4747

48+
bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) {
49+
if (!d->common_prefix.count(var.get())) {
50+
return false;
51+
}
52+
const std::vector<const Object*>& path = d->common_prefix.at(var.get());
53+
for (auto it = path.rbegin(); it != path.rend(); ++it) {
54+
if (*it == node.get()) {
55+
return true;
56+
}
57+
}
58+
return false;
59+
}
60+
4861
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
4962
.set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc {
5063
ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
@@ -322,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDo
322335
return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
323336
}
324337

338+
void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p,
339+
const IRDocsifier& d) {
340+
Frame f = FindLowestVarDef(iter_var->var, d).value();
341+
DefineVar(iter_var->var, f, d);
342+
ExprDoc rhs = TIR(d, "env_thread")
343+
->Call({LiteralDoc::Str(iter_var->thread_tag, //
344+
iter_var_p->Attr("thread_tag"))});
345+
ExprDoc lhs = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
346+
f->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
347+
}
348+
349+
ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p,
350+
Optional<tir::Var>* define_var, const IRDocsifier& d) {
351+
tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node);
352+
ObjectPath iter_var_p = attr_stmt_p->Attr("node");
353+
354+
ExprDoc var_doc{nullptr};
355+
if (d->IsVarDefined(iter_var->var)) {
356+
var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
357+
} else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) {
358+
var_doc = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"));
359+
*define_var = iter_var->var;
360+
} else {
361+
InsertEnvThread(iter_var, iter_var_p, d);
362+
var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
363+
}
364+
return TIR(d, "launch_thread")
365+
->Call({
366+
var_doc,
367+
d->AsDoc<ExprDoc>(attr_stmt->value, attr_stmt_p->Attr("value")),
368+
});
369+
}
370+
325371
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
326372
.set_dispatch<tir::BufferRealize>( //
327373
"", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
@@ -336,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
336382
.set_dispatch<tir::AttrStmt>( //
337383
"", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
338384
bool concise = AllowConciseScoping(d);
385+
Optional<ExprDoc> lhs = NullOpt;
339386
Optional<ExprDoc> rhs = NullOpt;
387+
Optional<tir::Var> define_var = NullOpt;
340388
tir::Stmt body = stmt->body;
341389
ObjectPath body_p = stmt_p->Attr("body");
342390
if (stmt->attr_key == "realize_scope") {
@@ -347,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
347395
/*value=*/d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
348396
/*p=*/stmt_p->Attr("body"), d);
349397
body = realize->body;
350-
body_p = body_p->Attr("body");
398+
body_p = stmt_p->Attr("body")->Attr("body");
351399
}
352400
}
353401
}
354402
if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
355-
if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
356-
if (!d->IsVarDefined(iter_var->var)) {
357-
// `DefineVar` is not used here because a more specific name is desirable
358-
ObjectPath iter_var_p = stmt_p->Attr("node");
359-
Frame f = FindLowestVarDef(iter_var->var, d).value();
360-
DefineVar(iter_var->var, f, d);
361-
f->stmts.push_back(
362-
AssignDoc(d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var")),
363-
TIR(d, "env_thread")
364-
->Call({LiteralDoc::Str(iter_var->thread_tag,
365-
iter_var_p->Attr("thread_tag"))}), //
366-
NullOpt));
367-
}
368-
rhs = TIR(d, "launch_thread")
369-
->Call({
370-
d->AsDoc<ExprDoc>(iter_var->var, stmt_p->Attr("node")),
371-
d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
372-
});
403+
if (stmt->node->IsInstance<tir::IterVarNode>()) {
404+
rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
373405
}
374406
}
375407
if (!rhs.defined()) {
@@ -380,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
380412
});
381413
}
382414
With<TIRFrame> f(d, stmt);
415+
if (define_var.defined()) {
416+
lhs = DefineVar(define_var.value(), *f, d);
417+
}
383418
AsDocBody(body, body_p, f->get(), d);
384-
return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
419+
return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
385420
});
386421

387422
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,16 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
171171
return Module
172172

173173

174+
def launch_env_thread():
175+
@T.prim_func
176+
def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None:
177+
bx = T.launch_thread("blockIdx.x", 64)
178+
for i, j in T.grid(2, 4):
179+
T.evaluate(inputs[bx, i, j])
180+
181+
return main
182+
183+
174184
def opt_gemm_mod_host():
175185
@tvm.script.ir_module
176186
class Module:
@@ -3563,6 +3573,7 @@ def func():
35633573

35643574

35653575
ir_generator = tvm.testing.parameter(
3576+
launch_env_thread,
35663577
opt_gemm_normalize,
35673578
opt_gemm_lower,
35683579
opt_gemm_mod_host,

0 commit comments

Comments
 (0)