Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
op->op.same_as(builtin::end_profile_intrinsic())) {
LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op;
return nullptr;
} else if (op->op.same_as(builtin::assume())) {
llvm::Value* cond = MakeValue(op->args[0]);
return builder_->CreateAssumption(cond);
} else {
LOG(FATAL) << "unknown intrinsic " << op->op;
}
Expand Down
24 changes: 24 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,5 +978,29 @@ def test_llvm_target_attributes():
assert n in functions_with_target


@tvm.testing.requires_llvm
def test_llvm_assume():
"""
Check that LLVM does not error out when generating code with tir.assume.
Verifying for llvm.assume being generated is not easy as the intrinsic and its
related instructions get removed during optimizations
"""

@T.prim_func
def tir_assume_func(A: T.Buffer((4, 4), "int32"), B: T.Buffer((14,), "int32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_1 = T.Buffer((16,), "int32", data=A.data)
for axis0, axis1 in T.grid(4, 4):
T.assume(axis0 < 3 or axis1 < 2 or A_1[axis0 * 4 + axis1] == 0)
for i in range(14):
B_1 = T.Buffer((14,), "int32", data=B.data)
B_1[i] = A_1[i] * 2

mod = tvm.IRModule.from_expr(tir_assume_func)
inp = te.placeholder((4, 4), name="A", dtype="int32")
out = te.placeholder((14,), name="B", dtype="int32")
m = tvm.build(mod, [inp, out], target="llvm")


if __name__ == "__main__":
tvm.testing.main()