
Commit 0674ff8

spectrometerHBH authored and junrushao committed
[TIR] Allow sync threads inside condition (apache#16345)
Previously, syncing threads inside a conditional (`while`, `if`) was not allowed. This PR introduces the `tvm_thread_invariant` op to annotate a condition as thread-id invariant, which lets such code get around the check.
1 parent 0499bd8 commit 0674ff8
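
For illustration, here is a minimal TVMScript sketch of the new annotation, adapted from the CUDA test added in this commit: wrapping the branch condition in `T.tvm_thread_invariant` asserts that every thread in the block evaluates it to the same value, so the shared-memory synchronization lowered inside the `if` now passes the storage-access check.

# Sketch adapted from tests/python/codegen/test_target_codegen_cuda.py below.
import tvm
from tvm.script import tir as T

@T.prim_func
def func(A: T.Buffer((4, 4), "float32")) -> None:
    A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
    for bx in T.thread_binding(1, "blockIdx.x"):
        for tx in T.thread_binding(32, "threadIdx.x"):
            # The condition is marked thread invariant: all 32 threads agree on it,
            # so the sync emitted for A_shared is legal inside the branch.
            if T.tvm_thread_invariant(A[0, 0] > 1.0):
                for i, j in T.grid(4, 4):
                    A_shared[i, j] = A[i, j]
                for i, j in T.grid(4, 4):
                    A[i, j] = A_shared[i, j] + 1.0

tvm.build(tvm.IRModule({"main": func}), target="cuda")  # requires a CUDA toolchain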

File tree: 7 files changed (+105, -4 lines)

include/tvm/tir/builtin.h

Lines changed: 6 additions & 0 deletions
@@ -411,6 +411,12 @@ TVM_DLL const Op& tvm_check_return();
  */
 TVM_DLL const Op& tvm_thread_context();
 
+/*!
+ * \brief Mark a condition to be thread invariant.
+ * This means the condition must be the same for all threads.
+ */
+TVM_DLL const Op& tvm_thread_invariant();
+
 /*!
  * \brief Lowered version of call packed, the space of value and
  * type codes are explicitly allocated.

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 0 deletions
@@ -1832,6 +1832,7 @@ def wrapped(*args, **kwargs):
 tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
 tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
 tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_invariant = _op_wrapper(_tir_op.tvm_thread_invariant)
 tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
 tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
 tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
@@ -2108,6 +2109,7 @@ def wrapped(*args, **kwargs):
     "tvm_tuple",
     "tvm_struct_set",
     "tvm_struct_get",
+    "tvm_thread_invariant",
     "tvm_thread_allreduce",
     "tvm_load_matrix_sync",
     "tvm_mma_sync",

python/tvm/tir/op.py

Lines changed: 17 additions & 0 deletions
@@ -602,6 +602,23 @@ def tvm_thread_allreduce(*freduce_args):
     return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
 
 
+def tvm_thread_invariant(cond):
+    """Mark condition as thread invariant.
+
+    Parameters
+    ----------
+    cond : Expr
+        The condition.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    assert isinstance(cond, PrimExpr)
+    return call_intrin(cond.dtype, "tir.tvm_thread_invariant", cond)
+
+
 def tvm_storage_sync(storage_scope):
     """Perform synchronization in specified scope.
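
As a quick illustration (not part of the diff), the new helper can be called directly on a boolean PrimExpr; it simply wraps the condition in a `tir.tvm_thread_invariant` call with the same dtype:

# Sketch: wrapping a boolean PrimExpr with the helper added above.
from tvm import tir
from tvm.tir.op import tvm_thread_invariant

x = tir.Var("x", "int32")
cond = x > 0                           # a boolean PrimExpr
marked = tvm_thread_invariant(cond)    # Call node for "tir.tvm_thread_invariant"
assert marked.dtype == cond.dtype      # the condition's dtype is preserved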

src/target/source/codegen_c.cc

Lines changed: 4 additions & 0 deletions
@@ -669,6 +669,10 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
     const StringImmNode* str = op->args[0].as<StringImmNode>();
     ICHECK(str != nullptr);
     os << "__tvm_param__" << str->value;
+  } else if (op->op.same_as(builtin::tvm_thread_invariant())) {
+    os << "(";
+    this->PrintExpr(op->args[0], os);
+    os << ")";
   } else {
     LOG(FATAL) << "Unresolved call " << op->op;
   }

src/tir/op/builtin.cc

Lines changed: 4 additions & 0 deletions
@@ -211,6 +211,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(tvm_thread_invariant)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+
 TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", String("call_packed_lowered"),

src/tir/transforms/storage_access.cc

Lines changed: 26 additions & 4 deletions
@@ -170,8 +170,23 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) {
   }
 }
 
+bool IsThreadInvariant(const PrimExpr& cond) {
+  if (auto call = cond.as<CallNode>()) {
+    if (auto opt_call_op = call->op.as<Op>()) {
+      auto call_op = opt_call_op.value();
+      if (call_op.same_as(builtin::tvm_thread_invariant())) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
-  ++condition_counter_;
+  bool is_thread_invariant = IsThreadInvariant(op->condition);
+  if (!is_thread_invariant) {
+    ++condition_counter_;
+  }
   this->VisitExpr(op->condition);
   scope_.push_back(std::vector<StmtEntry>());
   this->VisitStmt(op->then_case);
@@ -187,11 +202,16 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
     s.access.insert(s.access.end(), v.begin(), v.end());
   }
   scope_.back().emplace_back(std::move(s));
-  --condition_counter_;
+  if (!is_thread_invariant) {
+    --condition_counter_;
+  }
 }
 
 void StorageAccessVisitor::VisitStmt_(const WhileNode* op) {
-  ++condition_counter_;
+  bool is_thread_invariant = IsThreadInvariant(op->condition);
+  if (!is_thread_invariant) {
+    ++condition_counter_;
+  }
   this->VisitExpr(op->condition);
   scope_.push_back(std::vector<StmtEntry>());
   this->VisitStmt(op->body);
@@ -200,7 +220,9 @@ void StorageAccessVisitor::VisitStmt_(const WhileNode* op) {
   s.access = Summarize(std::move(scope_.back()), nullptr);
   scope_.pop_back();
   scope_.back().emplace_back(std::move(s));
-  --condition_counter_;
+  if (!is_thread_invariant) {
+    --condition_counter_;
+  }
 }
 
 void StorageAccessVisitor::VisitExpr_(const CallNode* op) {

tests/python/codegen/test_target_codegen_cuda.py

Lines changed: 46 additions & 0 deletions
@@ -24,6 +24,7 @@
 from tvm import topi
 from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
 from tvm.contrib import utils
+from tvm.script import tir as T
 import tvm.testing
 import pytest
 
@@ -1068,5 +1069,50 @@ def check_cuda(n, lanes):
     check_cuda(64, 2)
 
 
+def test_cuda_thread_sync_inside_condition():
+    @T.prim_func
+    def func1(A: T.Buffer((4, 4), "float32")) -> None:
+        A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
+        for bx in T.thread_binding(1, "blockIdx.x"):
+            for tx in T.thread_binding(32, "threadIdx.x"):
+                if A[0, 0] > 1.0:
+                    for i, j in T.grid(4, 4):
+                        A_shared[i, j] = A[i, j]
+                    for i, j in T.grid(4, 4):
+                        A[i, j] = A_shared[i, j] + 1.0
+
+    @T.prim_func
+    def func2(A: T.Buffer((4, 4), "float32")) -> None:
+        A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
+        for bx in T.thread_binding(1, "blockIdx.x"):
+            for tx in T.thread_binding(32, "threadIdx.x"):
+                if T.tvm_thread_invariant(A[0, 0] > 1.0):
+                    for i, j in T.grid(4, 4):
+                        A_shared[i, j] = A[i, j]
+                    for i, j in T.grid(4, 4):
+                        A[i, j] = A_shared[i, j] + 1.0
+
+    @T.prim_func
+    def func3(A: T.Buffer((4, 4), "float32")) -> None:
+        A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
+        for bx in T.thread_binding(1, "blockIdx.x"):
+            for tx in T.thread_binding(32, "threadIdx.x"):
+                while T.tvm_thread_invariant(A[0, 0] > 1.0):
+                    for i, j in T.grid(4, 4):
+                        A_shared[i, j] = A[i, j]
+                    for i, j in T.grid(4, 4):
+                        A[i, j] = A_shared[i, j] + 1.0
+
+    mod = tvm.IRModule({"main": func1})
+    with pytest.raises(tvm.error.InternalError):
+        tvm.build(mod, target="cuda")
+
+    mod = tvm.IRModule({"main": func2})
+    tvm.build(mod, target="cuda")
+
+    mod = tvm.IRModule({"main": func3})
+    tvm.build(mod, target="cuda")
+
+
 if __name__ == "__main__":
     tvm.testing.main()
