diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e6116605f8a2..10e5b462d1d1 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -909,6 +909,12 @@ TVM_DLL const Op& anylist_setitem_call_packed(); */ TVM_DLL const Op& anylist_setitem_call_cpacked(); +/*! + * \brief Get the target's vscale value. It will be lowered to llvm.vscale intrinsic + * (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic) + */ +TVM_DLL const Op& vscale(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 74b0bd2ba4e1..8a93537f7707 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1862,6 +1862,7 @@ def wrapped(*args, **kwargs): anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem) anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed) anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked) +vscale = _op_wrapper(_tir_op.vscale) def _dtype_forward(func): @@ -2199,4 +2200,5 @@ def wrapped(*args, **kwargs): "IterVar", "CommReducer", "Range", + "vscale", ] diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f0500290b888..1723804388b9 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,6 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic +from .op import vscale from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4735a2644e83..8816880e7b52 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3338,6 +3338,17 @@ def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): ) +def vscale(): + """Get the target's vscale value. It will be lowered to llvm.vscale intrinsic + (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic) + Returns + ------- + call : PrimExpr + Call to the vscale intrinsic + """ + return call_intrin("int32", "tir.vscale") + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9701a299f1d1..a689183d8f48 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1478,6 +1478,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateAssumption(cond); } else if (op->op.same_as(builtin::tvm_thread_invariant())) { return MakeValue(op->args[0]); +#if TVM_LLVM_VERSION >= 110 + } else if (op->op.same_as(builtin::vscale())) { + llvm::Intrinsic::ID id = llvm::Intrinsic::vscale; + llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {}); + return builder_->CreateCall(f); +#endif } else { LOG(FATAL) << "unknown intrinsic " << op->op; } diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index fb92463c3c32..fbe31c890dad 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -394,6 +394,9 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed) TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); } // namespace builtin } // namespace tir } // namespace tvm diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index e873bce52bdf..4e75f916d9b2 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -16,10 +16,8 @@ # under the License. import tvm from tvm import te -from tvm.script import tir as TIR +from tvm.script import tir as T import re -import os -import ctypes import pytest from tvm.target.codegen import llvm_version_major @@ -476,5 +474,23 @@ def check_correct_assembly(type): check_correct_assembly(type=dtype) +@pytest.mark.skipif( + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" +) +def test_codegen_vscale(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + vscale = tvm.tir.vscale() + + @T.prim_func + def main(A: T.Buffer((5,), "int32")): + for i in range(5): + A[i] = 2 * vscale + + build_mod = tvm.build(main, target=target) + llvm = build_mod.get_source() + + assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." + + if __name__ == "__main__": tvm.testing.main()