Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
#include <vector>

#include "../../arith/pattern_match.h"
#include "../../arith/scalable_expression.h"
#include "../build_common.h"
#include "../func_registry_generator.h"
#include "codegen_params.h"
Expand Down Expand Up @@ -1127,6 +1128,13 @@ void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) {
if (!features.empty()) {
func->addFnAttr("target-features", features);
}
#if TVM_LLVM_VERSION >= 130
// Add vscale_range() function attribute when appropriate.
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
func->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
*llvm_target_->GetContext(), 1, tvm::arith::kAArch64VScaleValues.size()));
}
#endif
}

void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::ArrayRef<llvm::Type*> arg_types);
/*!
* \brief Set target-related attributes on the LLVM function \p func. This
* includes "target-cpu" and "target-features" if present.
* includes "target-cpu", "target-features" and "vscale_range()" if present.
*
* \param func The function to set attributes on.
*/
Expand Down
38 changes: 38 additions & 0 deletions tests/python/codegen/test_target_codegen_aarch64.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,44 @@ def my_func(a: T.handle):
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."


@pytest.mark.skipif(
llvm_version_major() < 13,
reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
)
@pytest.mark.parametrize(
"mattr,expect_attr",
[
("+neon", False),
("+sve", True),
("+v9a", True),
("+sme", True),
],
)
def test_vscale_range_function_attribute(mattr, expect_attr):
target = f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}"

m = te.var("m")
A = te.placeholder(m, dtype="float32", name="A")
C = te.compute((m), lambda i: A[i] + 1, name="C")
s = te.create_schedule([C.op])

with tvm.target.Target(target) as target:
f = tvm.build(s, [A, C], target)

# Check if the vscale_range() attribute exists
ll = f.get_source("ll")
attr = re.findall(rf".*vscale_range\(\d+,\d+\)*.", ll)

if expect_attr:
assert (
len(attr) > 0
), f"Function attribute vscale_range() was not found in generated LLVM IR"
else:
assert (
len(attr) == 0
), f"Unexpected function attribute vscale_range() was found in generated LLVM IR"


@pytest.mark.skipif(
llvm_version_major() < 16, reason="Test requires an LLVM version of at least 16 to target SME"
)
Expand Down