Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,12 @@ TVM_DLL const Op& tvm_check_return();
*/
TVM_DLL const Op& tvm_thread_context();

/*!
* \brief Mark a condition to be thread invariant.
* This means the condition must be the same for all threads.
*/
TVM_DLL const Op& tvm_thread_invariant();

/*!
* \brief Lowered version of call packed, the space of value and
* type codes are explicitly allocated.
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,7 @@ def wrapped(*args, **kwargs):
tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
tvm_struct_get = _tir_op.tvm_struct_get
tvm_thread_invariant = _op_wrapper(_tir_op.tvm_thread_invariant)
tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
Expand Down Expand Up @@ -2104,6 +2105,7 @@ def wrapped(*args, **kwargs):
"tvm_tuple",
"tvm_struct_set",
"tvm_struct_get",
"tvm_thread_invariant",
"tvm_thread_allreduce",
"tvm_load_matrix_sync",
"tvm_mma_sync",
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,23 @@ def tvm_thread_allreduce(*freduce_args):
return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)


def tvm_thread_invariant(cond):
"""Mark condition as thread invariant.

Parameters
----------
cond : Expr
The condition.

Returns
-------
call : PrimExpr
The call expression.
"""
assert isinstance(cond, PrimExpr)
return call_intrin(cond.dtype, "tir.tvm_thread_invariant", cond)


def tvm_storage_sync(storage_scope):
"""Perform synchronization in specified scope.

Expand Down
4 changes: 4 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
const StringImmNode* str = op->args[0].as<StringImmNode>();
ICHECK(str != nullptr);
os << "__tvm_param__" << str->value;
} else if (op->op.same_as(builtin::tvm_thread_invariant())) {
os << "(";
this->PrintExpr(op->args[0], os);
os << ")";
} else {
LOG(FATAL) << "Unresolved call " << op->op;
}
Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(tvm_thread_invariant)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
.set_attr<TScriptPrinterName>("TScriptPrinterName", String("call_packed_lowered"),
Expand Down
30 changes: 26 additions & 4 deletions src/tir/transforms/storage_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,23 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) {
}
}

bool IsThreadInvariant(const PrimExpr& cond) {
if (auto call = cond.as<CallNode>()) {
if (auto opt_call_op = call->op.as<Op>()) {
auto call_op = opt_call_op.value();
if (call_op.same_as(builtin::tvm_thread_invariant())) {
return true;
}
}
}
return false;
}

void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
++condition_counter_;
bool is_thread_invariant = IsThreadInvariant(op->condition);
if (!is_thread_invariant) {
++condition_counter_;
}
this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->then_case);
Expand All @@ -187,11 +202,16 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
s.access.insert(s.access.end(), v.begin(), v.end());
}
scope_.back().emplace_back(std::move(s));
--condition_counter_;
if (!is_thread_invariant) {
--condition_counter_;
}
}

void StorageAccessVisitor::VisitStmt_(const WhileNode* op) {
++condition_counter_;
bool is_thread_invariant = IsThreadInvariant(op->condition);
if (!is_thread_invariant) {
++condition_counter_;
}
this->VisitExpr(op->condition);
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->body);
Expand All @@ -200,7 +220,9 @@ void StorageAccessVisitor::VisitStmt_(const WhileNode* op) {
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
scope_.back().emplace_back(std::move(s));
--condition_counter_;
if (!is_thread_invariant) {
--condition_counter_;
}
}

void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
Expand Down
46 changes: 46 additions & 0 deletions tests/python/codegen/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import topi
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import utils
from tvm.script import tir as T
import tvm.testing
import pytest

Expand Down Expand Up @@ -1068,5 +1069,50 @@ def check_cuda(n, lanes):
check_cuda(64, 2)


def test_cuda_thread_sync_inside_condition():
@T.prim_func
def func1(A: T.Buffer((4, 4), "float32")) -> None:
A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
for bx in T.thread_binding(1, "blockIdx.x"):
for tx in T.thread_binding(32, "threadIdx.x"):
if A[0, 0] > 1.0:
for i, j in T.grid(4, 4):
A_shared[i, j] = A[i, j]
for i, j in T.grid(4, 4):
A[i, j] = A_shared[i, j] + 1.0

@T.prim_func
def func2(A: T.Buffer((4, 4), "float32")) -> None:
A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
for bx in T.thread_binding(1, "blockIdx.x"):
for tx in T.thread_binding(32, "threadIdx.x"):
if T.tvm_thread_invariant(A[0, 0] > 1.0):
for i, j in T.grid(4, 4):
A_shared[i, j] = A[i, j]
for i, j in T.grid(4, 4):
A[i, j] = A_shared[i, j] + 1.0

@T.prim_func
def func3(A: T.Buffer((4, 4), "float32")) -> None:
A_shared = T.alloc_buffer((4, 4), "float32", scope="shared")
for bx in T.thread_binding(1, "blockIdx.x"):
for tx in T.thread_binding(32, "threadIdx.x"):
while T.tvm_thread_invariant(A[0, 0] > 1.0):
for i, j in T.grid(4, 4):
A_shared[i, j] = A[i, j]
for i, j in T.grid(4, 4):
A[i, j] = A_shared[i, j] + 1.0

mod = tvm.IRModule({"main": func1})
with pytest.raises(tvm.error.InternalError):
tvm.build(mod, target="cuda")

mod = tvm.IRModule({"main": func2})
tvm.build(mod, target="cuda")

mod = tvm.IRModule({"main": func3})
tvm.build(mod, target="cuda")


if __name__ == "__main__":
tvm.testing.main()