Skip to content

Commit 9f15746

Browse files
committed
[LLVM][Codegen] Enable SVE/VLA for RISCV targets
1 parent beca091 commit 9f15746

File tree

11 files changed

+255
-141
lines changed

11 files changed

+255
-141
lines changed

src/arith/analyzer.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,17 +231,18 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
231231
// Current analysis may not be powerful enough to prove expressions containing
232232
// the same symbolic value multiple times. However, when the symbolic values are
233233
// "T.vscale" and the compile target uses a scalable architecture extension like
234-
// SVE, we can make some assumptions about the value of vscale and iterate over a
234+
// VLA, we can make some assumptions about the value of vscale and iterate over a
235235
// space of pre-defined values to attempt to prove the expression.
236236
Target curr_target = Target::Current();
237237
if (ContainsVscaleCall(simplified)) {
238-
if (TargetHasSVE(curr_target)) {
239-
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
238+
if (TargetHasVLA(curr_target)) {
239+
auto kVScaleValues = GetVScaleValues(curr_target);
240+
return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues);
240241
}
241242
LOG(WARNING)
242243
<< "The expression contains scalable values. An attempt to prove by substituting "
243244
"with known values of vscale was not performed. This proof currently only supports "
244-
"AArch64 SVE targets, but the target was "
245+
"VLA targets, but the target was "
245246
<< curr_target;
246247
}
247248
return false;

src/arith/const_int_bound.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,15 +364,16 @@ class ConstIntBoundAnalyzer::Impl
364364
// only special handle >> and & which can be
365365
// used for index calculation.
366366

367+
auto curr_target = Target::Current();
367368
if (op->op.same_as(tir::builtin::shift_right())) {
368369
return VisitRightShift(op);
369370
} else if (op->op.same_as(tir::builtin::shift_left())) {
370371
return VisitLeftShift(op);
371372
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
372373
return VisitBitwiseAnd(op);
373-
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) {
374-
unsigned int max_val =
375-
*std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end());
374+
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) {
375+
auto kVScaleValues = GetVScaleValues(curr_target);
376+
unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
376377
return MakeBound(1, max_val);
377378
} else {
378379
return Everything(op->dtype);

src/arith/scalable_expression.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,45 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
8686
return can_prove_expr;
8787
}
8888

89-
bool TargetHasSVE(Optional<Target> target) {
89+
/*!
 * \brief Report whether a target advertises a scalable-vector (VLA) feature.
 *
 * Two sources are consulted and OR-ed together: the target's "has_sve" feature
 * flag (aarch64) and the registered `target.target_has_feature` function queried
 * for the "v" feature (riscv{32,64}).
 *
 * \param target The target to inspect. When undefined, Target::Current() is used.
 * \return true if either source reports the feature; false if no target is available.
 */
bool TargetHasVLA(Optional<Target> target) {
  if (!target.defined()) {
    target = Target::Current();
  }
  // No explicit target and no current target: nothing to query.
  if (!target.defined()) {
    return false;
  }

  // aarch64
  const bool sve_enabled =
      Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));

  // riscv{32,64}
  static const PackedFunc* target_has_feature_fn_ptr =
      runtime::Registry::Get("target.target_has_feature");
  ICHECK(target_has_feature_fn_ptr != nullptr)
      << "The `target.target_has_feature` func is not in tvm registry.";
  const bool v_enabled = static_cast<bool>((*target_has_feature_fn_ptr)("v", target));

  return sve_enabled || v_enabled;
}
106+
107+
/*!
 * \brief Build the list of candidate vscale values to try for a VLA target.
 *
 * The candidates are the powers of two from 1 up to (vector width / 8), where the
 * vector width is obtained from the registered `target.llvm_get_vector_width`
 * function. vscale == 1 is always included, so the returned list is never empty:
 * callers take std::max_element() of the result and dereference it, which would be
 * undefined behaviour on an empty range.
 *
 * \param target The target to query. When undefined, Target::Current() is used.
 * \return The candidate vscale values in ascending order (always contains 1).
 */
const std::vector<unsigned int> GetVScaleValues(Optional<Target> target) {
  if (!target.defined()) {
    target = Target::Current();
  }

  // Lower bound of 1 is kept when no target is available or the width is unusable.
  unsigned int max_vscale = 1;
  if (target.defined()) {
    static const PackedFunc* llvm_get_vector_width_fn_ptr =
        runtime::Registry::Get("target.llvm_get_vector_width");
    ICHECK(llvm_get_vector_width_fn_ptr != nullptr)
        << "The `target.llvm_get_vector_width` func is not in tvm registry.";
    // Keep the value signed until it has been validated: storing a negative result
    // directly into an unsigned would wrap around to a huge upper bound.
    // NOTE(review): presumably a non-positive width means "unknown" - confirm.
    const int vector_width = static_cast<int>((*llvm_get_vector_width_fn_ptr)(target));
    if (vector_width >= 8) {
      max_vscale = static_cast<unsigned int>(vector_width) / 8;
    }
  }

  // Scale list with powers of two, using exact integer arithmetic instead of std::pow.
  // max_vscale <= INT_MAX / 8, so doubling `power` cannot overflow before the loop exits.
  std::vector<unsigned int> vscale_values;
  for (unsigned int power = 1; power <= max_vscale; power *= 2) {
    vscale_values.push_back(power);
  }
  return vscale_values;
}
98129

99130
} // namespace arith

src/arith/scalable_expression.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@
3535
namespace tvm {
3636
namespace arith {
3737

38-
/*! \brief A list of known vscale values to try for an AArch64 SVE target. */
39-
static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2, 4, 8, 16};
40-
4138
/*!
4239
* \brief Check if an expr is a call to the vscale intrinsic.
4340
* \param expr The expr to check
@@ -80,10 +77,18 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
8077

8178
/*!
8279
* \brief Check whether the compilation target supports SVE
80+
* \brief Check whether the compilation target supports VLA
81+
* \param target The target to check.
82+
* \return Whether VLA is supported
83+
*/
84+
bool TargetHasVLA(Optional<Target> target = NullOpt);
85+
86+
/*!
87+
* \brief Get a list of known vscale values to try for a VLA target.
8388
* \param target The target to check.
84-
* \return Whether SVE is supported
89+
* \return A list of vscale values as std::vector<unsigned int>
8590
*/
86-
bool TargetHasSVE(Optional<Target> target = std::nullopt);
91+
const std::vector<unsigned int> GetVScaleValues(Optional<Target> target = std::nullopt);
8792

8893
} // namespace arith
8994
} // namespace tvm

src/target/llvm/codegen_aarch64.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
5757
#if TVM_LLVM_VERSION >= 130
5858
// Add vscale_range() function attribute when appropriate.
5959
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
60-
unsigned int max_val =
61-
*std::max_element(arith::kAArch64VScaleValues.begin(), arith::kAArch64VScaleValues.end());
60+
auto kVScaleValues = arith::GetVScaleValues(Target::Current());
61+
unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
6262
func->addFnAttr(
6363
llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val));
6464
}

src/tir/transforms/vectorize_loop.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ bool EnableBufferLevelPredication(Target target) {
8080
return enable_buffer_predication.value();
8181
}
8282

83-
// Use buffer-level predication by default for AArch64 SVE targets
84-
return arith::TargetHasSVE(target);
83+
// Use buffer-level predication by default for VLA targets
84+
return arith::TargetHasVLA(target);
8585
}
8686

8787
/*!
@@ -972,7 +972,7 @@ class LoopVectorizer : public StmtMutator {
972972

973973
if (!extent_as_int || extent_as_int->value < 1) {
974974
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
975-
ICHECK(is_scalable_expr && arith::TargetHasSVE(target_))
975+
ICHECK(is_scalable_expr && arith::TargetHasVLA(target_))
976976
<< "Failed to vectorize loop with extent " << op->extent << " for target " << target_;
977977
}
978978
ICHECK(is_zero(op->min));

tests/python/arith/test_arith_simplify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_simplify_vscale_comparison_without_sve_target(capfd):
113113
warning_msg = (
114114
"Warning: The expression contains scalable values. An attempt to prove by substituting "
115115
"with known values of vscale was not performed. This proof currently only supports "
116-
"AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu"
116+
"VLA targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu"
117117
)
118118
capture = capfd.readouterr().err
119119
assert warning_msg in capture

0 commit comments

Comments
 (0)