Skip to content

Commit ea1b57b

Browse files
committed
[LLVM][Codegen] Enable SVE for RISCV targets
1 parent 36f2502 commit ea1b57b

File tree

11 files changed

+270
-184
lines changed

11 files changed

+270
-184
lines changed

src/arith/analyzer.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,13 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
236236
Target curr_target = Target::Current();
237237
if (ContainsVscaleCall(simplified)) {
238238
if (TargetHasSVE(curr_target)) {
239-
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
239+
auto kVScaleValues = GetVScaleValues(curr_target);
240+
return CanProveVscaleExpressionFromKnownValues(this, simplified, kVScaleValues);
240241
}
241242
LOG(WARNING)
242243
<< "The expression contains scalable values. An attempt to prove by substituting "
243244
"with known values of vscale was not performed. This proof currently only supports "
244-
"AArch64 SVE targets, but the target was "
245+
"SVE targets, but the target was "
245246
<< curr_target;
246247
}
247248
return false;

src/arith/const_int_bound.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,15 +364,16 @@ class ConstIntBoundAnalyzer::Impl
364364
// only special handle >> and & which can be
365365
// used for index calculation.
366366

367+
auto curr_target = Target::Current();
367368
if (op->op.same_as(tir::builtin::shift_right())) {
368369
return VisitRightShift(op);
369370
} else if (op->op.same_as(tir::builtin::shift_left())) {
370371
return VisitLeftShift(op);
371372
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
372373
return VisitBitwiseAnd(op);
373-
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(Target::Current())) {
374-
unsigned int max_val =
375-
*std::max_element(kAArch64VScaleValues.begin(), kAArch64VScaleValues.end());
374+
} else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE(curr_target)) {
375+
auto kVScaleValues = GetVScaleValues(curr_target);
376+
unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
376377
return MakeBound(1, max_val);
377378
} else {
378379
return Everything(op->dtype);

src/arith/scalable_expression.cc

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,41 @@ bool TargetHasSVE(Optional<Target> target) {
9090
if (!target.defined()) {
9191
target = Target::Current();
9292
}
93+
bool has_sve{false};
9394
if (target.defined()) {
94-
return Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));
95+
// aarch64
96+
has_sve = Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));
97+
// riscv{32,64}
98+
static const PackedFunc* target_has_feature_fn_ptr =
99+
runtime::Registry::Get("target.target_has_feature");
100+
ICHECK(target_has_feature_fn_ptr != nullptr)
101+
<< "The `target.target_has_feature` func is not in tvm registry.";
102+
has_sve |= static_cast<bool>((*target_has_feature_fn_ptr)("v", target));
95103
}
96-
return false;
104+
return has_sve;
105+
}
106+
107+
const std::vector<unsigned int> GetVScaleValues(Optional<Target> target) {
108+
unsigned int vector_width = 0;
109+
std::vector<unsigned int> kVScaleValues;
110+
if (!target.defined()) {
111+
target = Target::Current();
112+
}
113+
if (target.defined()) {
114+
static const PackedFunc* llvm_get_vector_width_fn_ptr =
115+
runtime::Registry::Get("target.llvm_get_vector_width");
116+
ICHECK(llvm_get_vector_width_fn_ptr != nullptr)
117+
<< "The `target.llvm_get_vector_width` func is not in tvm registry.";
118+
vector_width = static_cast<int>((*llvm_get_vector_width_fn_ptr)(target));
119+
}
120+
// scale list with powers of two
121+
for (unsigned int i = 0;; ++i) {
122+
auto power = static_cast<unsigned int>(std::pow(2, i));
123+
if (power > (vector_width / 8)) break;
124+
kVScaleValues.push_back(power);
125+
}
126+
127+
return kVScaleValues;
97128
}
98129

99130
} // namespace arith

src/arith/scalable_expression.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@
3535
namespace tvm {
3636
namespace arith {
3737

38-
/*! \brief A list of known vscale values to try for an AArch64 SVE target. */
39-
static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2, 4, 8, 16};
40-
4138
/*!
4239
* \brief Check if an expr is a call to the vscale intrinsic.
4340
* \param expr The expr to check
@@ -85,6 +82,13 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
8582
*/
8683
bool TargetHasSVE(Optional<Target> target = NullOpt);
8784

85+
/*!
86+
* \brief Get a list of known vscale values to try for an SVE target.
87+
* \param target The target to check.
88+
* \return A list of vscale values as std::vector<usigned int>
89+
*/
90+
const std::vector<unsigned int> GetVScaleValues(Optional<Target> target = NullOpt);
91+
8892
} // namespace arith
8993
} // namespace tvm
9094

src/target/llvm/codegen_aarch64.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
5757
#if TVM_LLVM_VERSION >= 130
5858
// Add vscale_range() function attribute when appropriate.
5959
if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
60-
unsigned int max_val =
61-
*std::max_element(arith::kAArch64VScaleValues.begin(), arith::kAArch64VScaleValues.end());
60+
auto kVScaleValues = arith::GetVScaleValues(Target::Current());
61+
unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end());
6262
func->addFnAttr(
6363
llvm::Attribute::getWithVScaleRangeArgs(*llvm_target_->GetContext(), 1, max_val));
6464
}

src/target/llvm/llvm_instance.cc

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,6 @@
5353
#include <llvm/Support/raw_ostream.h>
5454
#include <llvm/Target/TargetMachine.h>
5555
#include <llvm/Target/TargetOptions.h>
56-
#if TVM_LLVM_VERSION >= 190
57-
#include <llvm/TargetParser/RISCVISAInfo.h>
58-
#else
59-
#if TVM_LLVM_VERSION >= 140
60-
#include <llvm/Support/RISCVISAInfo.h>
61-
#endif
62-
#endif
63-
#if TVM_LLVM_VERSION >= 160
64-
#include <llvm/TargetParser/RISCVTargetParser.h>
65-
#else
66-
#include <llvm/Support/TargetParser.h>
67-
#endif
6856
#include <tvm/runtime/container/array.h>
6957
#include <tvm/runtime/container/map.h>
7058
#include <tvm/runtime/container/optional.h>
@@ -299,34 +287,25 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target)
299287
// code model
300288
code_model_ = llvm::CodeModel::Medium;
301289
#if TVM_LLVM_VERSION >= 140
302-
// VLEN inference
303-
const auto cpu_name = GetOrCreateTargetMachine(false)->getMCSubtargetInfo()->getCPU();
304-
const auto canon_arch = llvm::RISCV::getMArchFromMcpu(cpu_name);
305-
auto ISAInfo =
306-
llvm::RISCVISAInfo::parseArchString(canon_arch, /*EnableExperimentalExtensions=*/true);
307-
// infer VLEN from LLVM RISCVInfo parser
308-
if (!llvm::errorToBool(ISAInfo.takeError()) && (vector_width_ == 0)) {
309-
vector_width_ = (*ISAInfo)->getMinVLen();
310-
}
311-
// infer VLEN from LLVM options (zvlXXXb override)
312-
for (const auto& attr : attrs_) {
313-
if (attr.find("zvl") != std::string::npos) {
314-
std::string vec;
315-
for (char c : attr) {
316-
if (std::isdigit(c)) vec += c;
290+
// get VLEN from the LLVM backend (zvlXXXb)
291+
Map<String, String> features = GetAllLLVMCpuFeatures();
292+
// check vector ISA
293+
if (features.count("v") > 0) {
294+
vector_width_ = 0;
295+
int zvlbits = 0;
296+
for (const auto& [attr, val] : features) {
297+
if (std::string(attr).find("zvl") != std::string::npos) {
298+
std::string vec;
299+
for (char c : std::string(attr)) {
300+
if (std::isdigit(c)) vec += c;
301+
}
302+
zvlbits = std::stoi(vec);
303+
// max of the multiple zvlXXXb
304+
if (vector_width_ < zvlbits) vector_width_ = zvlbits;
317305
}
318-
vector_width_ = std::stoi(vec);
319306
}
320307
}
321308
#endif
322-
if (vector_width_ > 0) {
323-
// push cl-opt to LLVM
324-
llvm_options_.push_back(
325-
ParseOptionString("-riscv-v-vector-bits-min:int=" + std::to_string(vector_width_)));
326-
} else {
327-
// fallback default (codegen will warn)
328-
llvm_options_.push_back(ParseOptionString("-riscv-v-vector-bits-min:int=256"));
329-
}
330309
}
331310

332311
// Target options
@@ -943,9 +922,7 @@ const int LLVMTargetInfo::GetVectorWidth() {
943922
} else if (arch == llvm::Triple::arm || arch == llvm::Triple::aarch64) {
944923
vector_width_ = 128;
945924
} else if (arch == llvm::Triple::riscv32 || arch == llvm::Triple::riscv64) {
946-
vector_width_ = 256;
947-
LOG(WARNING) << "LLVM RVV VLEN inference failed, "
948-
<< "using 256 bits, set -vector-width=XXX to override";
925+
vector_width_ = 128;
949926
} else {
950927
// fallback default
951928
vector_width_ = 128;

src/tir/transforms/vectorize_loop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ bool EnableBufferLevelPredication(Target target) {
8080
return enable_buffer_predication.value();
8181
}
8282

83-
// Use buffer-level predication by default for AArch64 SVE targets
83+
// Use buffer-level predication by default for SVE targets
8484
return arith::TargetHasSVE(target);
8585
}
8686

0 commit comments

Comments
 (0)