Skip to content

Commit 02c4c55

Browse files
authored
[SVE] Add codegen support for vscale_range() function attribute (#16962)
This commit adds support for the `vscale_range()` LLVM function attribute to be generated for SVE and SME targets. Some LLVM optimisation passes make use of the `vscale_range()` function attribute when scalable vectors are present (e.g. BasicAA llvm/llvm-project/pull/80445), so we include it alongside the "target_cpu" and "target-features" attributes.
1 parent 819b002 commit 02c4c55

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

src/target/llvm/codegen_aarch64.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <llvm/Target/TargetMachine.h>
2828
#include <tvm/runtime/registry.h>
2929

30+
#include "../../arith/scalable_expression.h"
3031
#include "codegen_cpu.h"
3132
#include "llvm_instance.h"
3233

@@ -40,6 +41,7 @@ class CodeGenAArch64 final : public CodeGenCPU {
4041

4142
void VisitStmt_(const AttrStmtNode* op);
4243
void AddFunction(const GlobalVar& gvar, const PrimFunc& f);
44+
void SetTargetAttributes(llvm::Function* func);
4345

4446
bool func_has_pstate_sm = false;
4547
bool func_has_pstate_za = false;
@@ -51,6 +53,17 @@ void CodeGenAArch64::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
5153
CodeGenCPU::AddFunction(gvar, f);
5254
}
5355

56+
void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
57+
#if TVM_LLVM_VERSION >= 130
58+
// Add vscale_range() function attribute when appropriate.
59+
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
60+
func->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
61+
*llvm_target_->GetContext(), 1, tvm::arith::kAArch64VScaleValues.size()));
62+
}
63+
#endif
64+
CodeGenCPU::SetTargetAttributes(func);
65+
}
66+
5467
/*!
5568
* \brief Visit and handle AArch64 specific pragmas. To be AArch64 specific,
5669
* the expectation is that they are prepended with "pragma_aarch64".

src/target/llvm/codegen_llvm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
431431
*
432432
* \param func The function to set attributes on.
433433
*/
434-
void SetTargetAttributes(llvm::Function* func);
434+
virtual void SetTargetAttributes(llvm::Function* func);
435435
/*!
436436
* \brief Emit LLVM IR for conversion functions __extendhfsf2 and __truncsfhf2
437437
* into the current llvm::Module.

tests/python/codegen/test_target_codegen_aarch64.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,44 @@ def my_func(a: T.handle):
537537
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
538538

539539

540+
@pytest.mark.skipif(
541+
llvm_version_major() < 13,
542+
reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
543+
)
544+
@pytest.mark.parametrize(
545+
"mattr,expect_attr",
546+
[
547+
("+neon", False),
548+
("+sve", True),
549+
("+v9a", True),
550+
("+sme", True),
551+
],
552+
)
553+
def test_vscale_range_function_attribute(mattr, expect_attr):
554+
target = f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}"
555+
556+
m = te.var("m")
557+
A = te.placeholder(m, dtype="float32", name="A")
558+
C = te.compute((m), lambda i: A[i] + 1, name="C")
559+
s = te.create_schedule([C.op])
560+
561+
with tvm.target.Target(target) as target:
562+
f = tvm.build(s, [A, C], target)
563+
564+
# Check if the vscale_range() attribute exists
565+
ll = f.get_source("ll")
566+
attr = re.findall(rf".*vscale_range\(\d+,\d+\)*.", ll)
567+
568+
if expect_attr:
569+
assert (
570+
len(attr) > 0
571+
), f"Function attribute vscale_range() was not found in generated LLVM IR"
572+
else:
573+
assert (
574+
len(attr) == 0
575+
), f"Unexpected function attribute vscale_range() was found in generated LLVM IR"
576+
577+
540578
@pytest.mark.skipif(
541579
llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
542580
)

0 commit comments

Comments
 (0)