Skip to content

Commit dfcdb3e

Browse files
committed
[Target] Use LLVM target parser for determining Arm(R) A-Profile Architecture features
Currently, target features are determined by a set of fixed checks on the target string. This works well for checking support of a small number of simple features, but it doesn't scale. Some problems include: - There are many non-trivial conditions for which a feature may(not) be available. It is easy to miss these with the current implementation. - The inclusion of some features in a target string can imply other features. For example, "+sve" implies "+neon". This currently isn't taken into account. - The tests in tests/cpp/target/parsers/aprofile_test.c suggest that targets such as "llvm -mcpu=cortex-a+neon" and "llvm -mattr=+noneon" are supported target strings. The features will be correctly parsed in TVM, however, they are not valid in LLVM. Therefore, it's possible that TVM and LLVM have different understanding of the features available. This commit uses the more robust LLVM target parser to determine support for the features in TVM. It leverages previous infrastructure added to TVM for obtaining a list of all supported features given an input target, and uses this to check the existance of certain features we're interested in. It should be trivial to grow this list over time. As a result of this change, the problems mentioned above are solved. In the current form, this commit drops support for target strings such as "llvm -mcpu=cortex-a+neon" and "llvm -mattr=+noneon". A scan of the codebase suggests this functionality is not in use (only in test cases). Should we feel the need to support them, or have a smoother migration for downstream users of TVM we can add a translator to the parser to convert these into LLVM compatible targets. Change-Id: Ic2bf3b68c8af74025ec388d304bd014624c0c585
1 parent fe9814c commit dfcdb3e

File tree

6 files changed

+148
-149
lines changed

6 files changed

+148
-149
lines changed

src/target/llvm/llvm_instance.cc

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -199,21 +199,23 @@ std::ostream& operator<<(std::ostream& os, const LLVMTargetInfo::Option& opt) {
199199
return os;
200200
}
201201

202-
LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
203-
triple_ = target->GetAttr<String>("mtriple").value_or("default");
202+
LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target)
203+
: LLVMTargetInfo(instance, target->Export()) {}
204204

205+
LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) {
206+
triple_ = Downcast<String>(target.Get("mtriple").value_or(String("default")));
205207
if (triple_.empty() || triple_ == "default") {
206208
triple_ = llvm::sys::getDefaultTargetTriple();
207209
}
208-
cpu_ = target->GetAttr<String>("mcpu").value_or(defaults::cpu);
210+
cpu_ = Downcast<String>(target.Get("mcpu").value_or(String(defaults::cpu)));
209211

210-
if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("mattr")) {
212+
if (const auto& v = Downcast<Optional<Array<String>>>(target.Get("mattr"))) {
211213
for (const String& s : v.value()) {
212214
attrs_.push_back(s);
213215
}
214216
}
215217
// llvm module target
216-
if (target->kind->name == "llvm") {
218+
if (Downcast<String>(target.Get("kind")) == "llvm") {
217219
// legalize -mcpu with the target -mtriple
218220
auto arches = GetAllLLVMTargetArches();
219221
bool has_arch =
@@ -224,7 +226,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
224226
}
225227
}
226228

227-
if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("cl-opt")) {
229+
if (const auto& v = Downcast<Optional<Array<String>>>(target.Get("cl-opt"))) {
228230
llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions();
229231
bool parse_error = false;
230232
for (const String& s : v.value()) {
@@ -245,7 +247,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
245247
}
246248

247249
llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default;
248-
if (const Optional<String>& v = target->GetAttr<String>("mfloat-abi")) {
250+
if (const auto& v = Downcast<Optional<String>>(target.Get("mfloat-abi"))) {
249251
String value = v.value();
250252
if (value == "hard") {
251253
float_abi = llvm::FloatABI::Hard;
@@ -268,14 +270,14 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
268270
target_options_.NoInfsFPMath = false;
269271
target_options_.NoNaNsFPMath = true;
270272
target_options_.FloatABIType = float_abi;
271-
if (const Optional<String>& v = target->GetAttr<String>("mabi")) {
272-
target_options_.MCOptions.ABIName = v.value();
273+
if (target.find("mabi") != target.end()) {
274+
target_options_.MCOptions.ABIName = Downcast<String>(target.Get("mabi"));
273275
}
274276

275-
auto maybe_level = target->GetAttr<Integer>("opt-level");
277+
auto maybe_level = Downcast<Integer>(target.Get("opt-level"));
276278
#if TVM_LLVM_VERSION <= 170
277279
if (maybe_level.defined()) {
278-
int level = maybe_level.value()->value;
280+
int level = maybe_level->value;
279281
if (level <= 0) {
280282
opt_level_ = llvm::CodeGenOpt::None;
281283
} else if (level == 1) {
@@ -312,7 +314,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
312314
// Fast math options
313315

314316
auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool {
315-
return target->GetAttr<Bool>(flag.str()).value_or(Bool(false));
317+
return Downcast<Bool>(target.Get(flag.str()).value_or(Bool(false)));
316318
};
317319
if (GetBoolFlag("fast-math")) {
318320
#if TVM_LLVM_VERSION >= 60
@@ -831,7 +833,7 @@ const Array<String> LLVMTargetInfo::GetAllLLVMTargetArches() const {
831833
return cpu_arches;
832834
}
833835

834-
const Array<String> LLVMTargetInfo::GetAllLLVMCpuFeatures() const {
836+
const Map<String, String> LLVMTargetInfo::GetAllLLVMCpuFeatures() const {
835837
std::string feats = "";
836838
for (const auto& attr : attrs_) {
837839
feats += feats.empty() ? attr : ("," + attr);
@@ -845,10 +847,11 @@ const Array<String> LLVMTargetInfo::GetAllLLVMCpuFeatures() const {
845847
#else
846848
MCInfo->getAllProcessorFeatures();
847849
#endif
848-
Array<String> cpu_features;
850+
// TVM doesn't have an FFI friendly Set, so use a Map instead for now
851+
Map<String, String> cpu_features;
849852
for (const auto& feat : llvm_features) {
850853
if (MCInfo->checkFeatures("+" + std::string(feat.Key))) {
851-
cpu_features.push_back(feat.Key);
854+
cpu_features.Set(feat.Key, "");
852855
}
853856
}
854857

@@ -858,9 +861,7 @@ const Array<String> LLVMTargetInfo::GetAllLLVMCpuFeatures() const {
858861
const bool LLVMTargetInfo::TargetHasCPUFeature(const std::string& feature) const {
859862
// lookup features for `-mcpu`
860863
auto feats = GetAllLLVMCpuFeatures();
861-
bool has_feature =
862-
std::any_of(feats.begin(), feats.end(), [&](const auto& var) { return var == feature; });
863-
864+
bool has_feature = feats.find(feature) != feats.end();
864865
return has_feature;
865866
}
866867

src/target/llvm/llvm_instance.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,14 @@ class LLVMTargetInfo {
156156
*/
157157
// NOLINTNEXTLINE(runtime/references)
158158
LLVMTargetInfo(LLVMInstance& scope, const std::string& target_str);
159+
/*!
160+
* \brief Constructs LLVMTargetInfo from `Target`
161+
* \param scope LLVMInstance object
162+
* \param target TVM JSON Target object for target "llvm"
163+
*/
164+
// NOLINTNEXTLINE(runtime/references)
165+
LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target);
166+
159167
/*!
160168
* \brief Destroys LLVMTargetInfo object
161169
*/
@@ -285,11 +293,12 @@ class LLVMTargetInfo {
285293

286294
/*!
287295
* \brief Get all CPU features from target
288-
* \return list with all valid cpu features
296+
* \return Map with all valid cpu features as keys and empty string as value. The Map
297+
* is intended to be used as a Set, which TVM does not currently support.
289298
* \note The features are fetched from the LLVM backend using the target `-mtriple`
290299
* and the `-mcpu` architecture, but also consider the `-mattr` attributes.
291300
*/
292-
const Array<String> GetAllLLVMCpuFeatures() const;
301+
const Map<String, String> GetAllLLVMCpuFeatures() const;
293302

294303
/*!
295304
* \brief Check the target if has a specific cpu feature

src/target/llvm/llvm_module.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,12 @@ TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist")
545545
});
546546

547547
TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features")
548-
.set_body_typed([](const Target& target) -> Array<String> {
548+
.set_body_typed([](const Target& target) -> Map<String, String> {
549549
auto use_target = target.defined() ? target : Target::Current(false);
550550
// ignore non "llvm" target
551551
if (target.defined()) {
552552
if (target->kind->name != "llvm") {
553-
return Array<String>{};
553+
return {};
554554
}
555555
}
556556
auto llvm_instance = std::make_unique<LLVMInstance>();
@@ -570,8 +570,7 @@ TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature")
570570
auto llvm_instance = std::make_unique<LLVMInstance>();
571571
LLVMTargetInfo llvm_backend(*llvm_instance, use_target);
572572
auto cpu_features = llvm_backend.GetAllLLVMCpuFeatures();
573-
bool has_feature = std::any_of(cpu_features.begin(), cpu_features.end(),
574-
[&](auto& var) { return var == feature; });
573+
bool has_feature = cpu_features.find(feature) != cpu_features.end();
575574
return has_feature;
576575
});
577576

src/target/parsers/aprofile.cc

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424

2525
#include "aprofile.h"
2626

27+
#include <memory>
2728
#include <string>
2829

2930
#include "../../support/utils.h"
31+
#include "../llvm/llvm_instance.h"
3032

3133
namespace tvm {
3234
namespace target {
@@ -52,33 +54,6 @@ double GetArchVersion(Optional<Array<String>> attr) {
5254
return GetArchVersion(attr.value());
5355
}
5456

55-
static inline bool HasFlag(String attr, std::string flag) {
56-
std::string attr_str = attr;
57-
return attr_str.find(flag) != std::string::npos;
58-
}
59-
60-
static inline bool HasFlag(Optional<String> attr, std::string flag) {
61-
if (!attr) {
62-
return false;
63-
}
64-
return HasFlag(attr.value(), flag);
65-
}
66-
67-
static inline bool HasFlag(Optional<Array<String>> attr, std::string flag) {
68-
if (!attr) {
69-
return false;
70-
}
71-
Array<String> attr_array = attr.value();
72-
73-
auto matching_attr = std::find_if(attr_array.begin(), attr_array.end(),
74-
[flag](String attr_str) { return HasFlag(attr_str, flag); });
75-
return matching_attr != attr_array.end();
76-
}
77-
78-
static bool HasFlag(Optional<String> mcpu, Optional<Array<String>> mattr, std::string flag) {
79-
return HasFlag(mcpu, flag) || HasFlag(mattr, flag);
80-
}
81-
8257
bool IsAArch32(Optional<String> mtriple, Optional<String> mcpu) {
8358
if (mtriple) {
8459
bool is_mprofile = mcpu && support::StartsWith(mcpu.value(), "cortex-m");
@@ -102,38 +77,25 @@ bool IsArch(TargetJSON attrs) {
10277
}
10378

10479
static TargetFeatures GetFeatures(TargetJSON target) {
105-
Optional<String> mcpu = Downcast<Optional<String>>(target.Get("mcpu"));
106-
Optional<String> mtriple = Downcast<Optional<String>>(target.Get("mtriple"));
107-
Optional<Array<String>> mattr = Downcast<Optional<Array<String>>>(target.Get("mattr"));
108-
109-
const double arch_version = GetArchVersion(mattr);
110-
111-
const bool is_aarch64 = IsAArch64(mtriple);
80+
String kind = Downcast<String>(target.Get("kind"));
81+
ICHECK_EQ(kind, "llvm") << "Expected target kind 'llvm', but got '" << kind << "'";
11282

113-
const bool simd_flag = HasFlag(mcpu, mattr, "+neon") || HasFlag(mcpu, mattr, "+simd");
114-
const bool has_asimd = is_aarch64 || simd_flag;
115-
const bool has_sve = HasFlag(mcpu, mattr, "+sve");
116-
117-
const bool i8mm_flag = HasFlag(mcpu, mattr, "+i8mm");
118-
const bool i8mm_disable = HasFlag(mcpu, mattr, "+noi8mm");
119-
const bool i8mm_default = arch_version >= 8.6;
120-
const bool i8mm_support = arch_version >= 8.2 && arch_version <= 8.5;
121-
const bool has_i8mm = (i8mm_default && !i8mm_disable) || (i8mm_support && i8mm_flag);
83+
Optional<String> mtriple = Downcast<Optional<String>>(target.Get("mtriple"));
12284

123-
const bool dotprod_flag = HasFlag(mcpu, mattr, "+dotprod");
124-
const bool dotprod_disable = HasFlag(mcpu, mattr, "+nodotprod");
125-
const bool dotprod_default = arch_version >= 8.4;
126-
const bool dotprod_support = arch_version >= 8.2 && arch_version <= 8.3;
127-
const bool has_dotprod =
128-
(dotprod_default && !dotprod_disable) || (dotprod_support && dotprod_flag);
85+
auto llvm_instance = std::make_unique<codegen::LLVMInstance>();
86+
codegen::LLVMTargetInfo llvm_target(*llvm_instance, target);
87+
Map<String, String> features = llvm_target.GetAllLLVMCpuFeatures();
12988

130-
const bool fp16_flag = HasFlag(mcpu, mattr, "+fullfp16");
131-
const bool fp16_support = arch_version >= 8.2;
132-
const bool has_fp16_simd = fp16_support && (fp16_flag || has_sve);
89+
auto has_feature = [features](const String& feature) {
90+
return features.find(feature) != features.end();
91+
};
13392

134-
return {{"is_aarch64", Bool(is_aarch64)}, {"has_asimd", Bool(has_asimd)},
135-
{"has_sve", Bool(has_sve)}, {"has_dotprod", Bool(has_dotprod)},
136-
{"has_matmul_i8", Bool(has_i8mm)}, {"has_fp16_simd", Bool(has_fp16_simd)}};
93+
return {{"is_aarch64", Bool(IsAArch64(mtriple))},
94+
{"has_asimd", Bool(has_feature("neon"))},
95+
{"has_sve", Bool(has_feature("sve"))},
96+
{"has_dotprod", Bool(has_feature("dotprod"))},
97+
{"has_matmul_i8", Bool(has_feature("i8mm"))},
98+
{"has_fp16_simd", Bool(has_feature("fullfp16"))}};
13799
}
138100

139101
static Array<String> MergeKeys(Optional<Array<String>> existing_keys) {

0 commit comments

Comments
 (0)