From efc5dfb868d9598e20137e96061bad6f5c3f9c02 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Mon, 29 Jan 2024 16:07:20 +0000 Subject: [PATCH 1/2] [SVE] Add vscale builtin Add a vscale builtin and lowering to `llvm.vscale`. This will be used in subsequent patches for expressing scalable vectors in TIR. Co-authored-by: Luke Hutton Co-authored-by: Neil Hickey --- include/tvm/tir/builtin.h | 5 +++++ python/tvm/script/ir_builder/tir/ir.py | 2 ++ python/tvm/tir/__init__.py | 1 + python/tvm/tir/op.py | 10 +++++++++ src/target/llvm/codegen_llvm.cc | 6 +++++ src/tir/op/builtin.cc | 3 +++ .../codegen/test_target_codegen_aarch64.py | 22 ++++++++++++++++--- 7 files changed, 46 insertions(+), 3 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e6116605f8a2..2df6d890e5d4 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -909,6 +909,11 @@ TVM_DLL const Op& anylist_setitem_call_packed(); */ TVM_DLL const Op& anylist_setitem_call_cpacked(); +/*! + * \brief Get the target's vscale value + */ +TVM_DLL const Op& vscale(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 74b0bd2ba4e1..8a93537f7707 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1862,6 +1862,7 @@ def wrapped(*args, **kwargs): anylist_resetitem = _op_wrapper(_tir_op.anylist_resetitem) anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed) anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked) +vscale = _op_wrapper(_tir_op.vscale) def _dtype_forward(func): @@ -2199,4 +2200,5 @@ def wrapped(*args, **kwargs): "IterVar", "CommReducer", "Range", + "vscale", ] diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f0500290b888..1723804388b9 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -88,6 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic +from .op import vscale from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 4735a2644e83..d880644c2d7f 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3338,6 +3338,16 @@ def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): ) +def vscale(): + """Get the target's vscale value + Returns + ------- + call : PrimExpr + Call to the vscale intrinsic + """ + return call_intrin("int32", "tir.vscale") + + # pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum") min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9701a299f1d1..a689183d8f48 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1478,6 +1478,12 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateAssumption(cond); } else if (op->op.same_as(builtin::tvm_thread_invariant())) { return MakeValue(op->args[0]); +#if TVM_LLVM_VERSION >= 110 + } else if (op->op.same_as(builtin::vscale())) { + llvm::Intrinsic::ID id = llvm::Intrinsic::vscale; + llvm::Function* f = GetIntrinsicDecl(id, builder_->getInt32Ty(), {}); + return builder_->CreateCall(f); +#endif } else { LOG(FATAL) << "unknown intrinsic " << op->op; } diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index fb92463c3c32..0756561febe6 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -394,6 +394,9 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_packed) TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); } // namespace builtin } // namespace tir } // namespace tvm diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index e873bce52bdf..921c6d63485c 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -16,10 +16,8 @@ # under the License. import tvm from tvm import te -from tvm.script import tir as TIR +from tvm.script import tir as T import re -import os -import ctypes import pytest from tvm.target.codegen import llvm_version_major @@ -476,5 +474,23 @@ def check_correct_assembly(type): check_correct_assembly(type=dtype) +@pytest.mark.skipif( + llvm_version_major() < 10, reason="Vscale is not supported in earlier versions of LLVM" +) +def test_codegen_vscale(): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve" + vscale = tvm.tir.vscale() + + @T.prim_func + def main(A: T.Buffer((5,), "int32")): + for i in range(5): + A[i] = 2 * vscale + + build_mod = tvm.build(main, target=target) + llvm = build_mod.get_source() + + assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." + + if __name__ == "__main__": tvm.testing.main() From f6a16609e92de964731daf541bd5265ef4863da3 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Tue, 30 Jan 2024 11:23:15 +0000 Subject: [PATCH 2/2] Improve documentation and fix LLVM versioning in a test Change-Id: I5b364a8b50c8622d21d3d6d30b9d44a56e0418db --- include/tvm/tir/builtin.h | 3 ++- python/tvm/tir/op.py | 3 ++- src/tir/op/builtin.cc | 2 +- tests/python/codegen/test_target_codegen_aarch64.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 2df6d890e5d4..10e5b462d1d1 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -910,7 +910,8 @@ TVM_DLL const Op& anylist_setitem_call_packed(); TVM_DLL const Op& anylist_setitem_call_cpacked(); /*! - * \brief Get the target's vscale value + * \brief Get the target's vscale value. It will be lowered to llvm.vscale intrinsic + * (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic) */ TVM_DLL const Op& vscale(); diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index d880644c2d7f..8816880e7b52 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -3339,7 +3339,8 @@ def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): def vscale(): - """Get the target's vscale value + """Get the target's vscale value. It will be lowered to llvm.vscale intrinsic + (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic) Returns ------- call : PrimExpr diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0756561febe6..fbe31c890dad 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -396,7 +396,7 @@ TIR_DEFINE_BUILTIN_FUNC(anylist_setitem_call_cpacked) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); TIR_DEFINE_BUILTIN_FUNC(vscale).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); + Integer(CallEffectKind::kPure)); } // namespace builtin } // namespace tir } // namespace tvm diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 921c6d63485c..4e75f916d9b2 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -475,7 +475,7 @@ def check_correct_assembly(type): @pytest.mark.skipif( - llvm_version_major() < 10, reason="Vscale is not supported in earlier versions of LLVM" + llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM" ) def test_codegen_vscale(): target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"