Skip to content

Commit 726a141

Browse files
authored
[Target] Use LLVM target parser for determining Arm(R) A-Profile Architecture features (#16425)
Currently, target features are determined by a set of fixed checks on the target string. This works well for checking support of a small number of simple features, but it doesn't scale. Some problems include: - There are many non-trivial conditions for which a feature may(not) be available. It is easy to miss these with the current implementation. - The inclusion of some features in a target string can imply other features. For example, "+sve" implies "+neon". This currently isn't taken into account. - The tests in tests/cpp/target/parsers/aprofile_test.c suggest that targets such as "llvm -mcpu=cortex-a+neon" and "llvm -mattr=+noneon" are supported target strings. The features will be correctly parsed in TVM, however, they are not valid in LLVM. Therefore, it's possible that TVM and LLVM have different understanding of the features available. This commit uses the more robust LLVM target parser to determine support for the features in TVM. It leverages previous infrastructure added to TVM for obtaining a list of all supported features given an input target, and uses this to check the existance of certain features we're interested in. It should be trivial to grow this list over time. As a result of this change, the problems mentioned above are solved. In the current form, this commit drops support for target strings such as "llvm -mcpu=cortex-a+neon" and "llvm -mattr=+noneon". A scan of the codebase suggests this functionality is not in use (only in test cases). Should we feel the need to support them, or have a smoother migration for downstream users of TVM we can add a translator to the parser to convert these into LLVM compatible targets.
1 parent 1891b4d commit 726a141

File tree

8 files changed

+282
-223
lines changed

8 files changed

+282
-223
lines changed

python/tvm/target/codegen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def llvm_get_cpu_features(target=None):
183183
List of available CPU features.
184184
"""
185185
assert isinstance(target, Target) or target is None
186-
return _ffi_api.llvm_get_cpu_features(target)
186+
feature_map = _ffi_api.llvm_get_cpu_features(target)
187+
return set(feature_map.keys())
187188

188189

189190
def llvm_cpu_has_features(cpu_features, target=None):

src/target/llvm/llvm_instance.cc

Lines changed: 43 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -199,32 +199,37 @@ std::ostream& operator<<(std::ostream& os, const LLVMTargetInfo::Option& opt) {
199199
return os;
200200
}
201201

202-
LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
203-
triple_ = target->GetAttr<String>("mtriple").value_or("default");
202+
LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target)
203+
: LLVMTargetInfo(instance, target->Export()) {}
204204

205+
LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) {
206+
triple_ = Downcast<String>(target.Get("mtriple").value_or(String("default")));
205207
if (triple_.empty() || triple_ == "default") {
206208
triple_ = llvm::sys::getDefaultTargetTriple();
207209
}
208-
cpu_ = target->GetAttr<String>("mcpu").value_or(defaults::cpu);
210+
cpu_ = Downcast<String>(target.Get("mcpu").value_or(String(defaults::cpu)));
209211

210-
if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("mattr")) {
212+
if (const auto& v = Downcast<Optional<Array<String>>>(target.Get("mattr"))) {
211213
for (const String& s : v.value()) {
212214
attrs_.push_back(s);
213215
}
214216
}
215217
// llvm module target
216-
if (target->kind->name == "llvm") {
218+
if (Downcast<String>(target.Get("kind")) == "llvm") {
217219
// legalize -mcpu with the target -mtriple
218220
auto arches = GetAllLLVMTargetArches();
219221
bool has_arch =
220222
std::any_of(arches.begin(), arches.end(), [&](const auto& var) { return var == cpu_; });
221223
if (!has_arch) {
222-
LOG(FATAL) << "LLVM cpu architecture `-mcpu=" << cpu_
223-
<< "` is not valid in `-mtriple=" << triple_ << "`";
224+
// Flag an error, but don't abort. This mimicks the behaviour of 'llc' to
225+
// give the code a chance to run with a less-specific target.
226+
LOG(ERROR) << "LLVM cpu architecture `-mcpu=" << cpu_
227+
<< "` is not valid in `-mtriple=" << triple_ << "`"
228+
<< ", using default `-mcpu=" << String(defaults::cpu) << "`";
224229
}
225230
}
226231

227-
if (const Optional<Array<String>>& v = target->GetAttr<Array<String>>("cl-opt")) {
232+
if (const auto& v = Downcast<Optional<Array<String>>>(target.Get("cl-opt"))) {
228233
llvm::StringMap<llvm::cl::Option*>& options = llvm::cl::getRegisteredOptions();
229234
bool parse_error = false;
230235
for (const String& s : v.value()) {
@@ -245,7 +250,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
245250
}
246251

247252
llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default;
248-
if (const Optional<String>& v = target->GetAttr<String>("mfloat-abi")) {
253+
if (const auto& v = Downcast<Optional<String>>(target.Get("mfloat-abi"))) {
249254
String value = v.value();
250255
if (value == "hard") {
251256
float_abi = llvm::FloatABI::Hard;
@@ -257,7 +262,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
257262
}
258263

259264
// LLVM JIT engine options
260-
if (const Optional<String>& v = target->GetAttr<String>("jit")) {
265+
if (const auto& v = Downcast<Optional<String>>(target.Get("jit"))) {
261266
String value = v.value();
262267
if ((value == "mcjit") || (value == "orcjit")) {
263268
jit_engine_ = value;
@@ -283,14 +288,14 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
283288
target_options_.NoInfsFPMath = false;
284289
target_options_.NoNaNsFPMath = true;
285290
target_options_.FloatABIType = float_abi;
286-
if (const Optional<String>& v = target->GetAttr<String>("mabi")) {
287-
target_options_.MCOptions.ABIName = v.value();
291+
if (target.find("mabi") != target.end()) {
292+
target_options_.MCOptions.ABIName = Downcast<String>(target.Get("mabi"));
288293
}
289294

290-
auto maybe_level = target->GetAttr<Integer>("opt-level");
295+
auto maybe_level = Downcast<Integer>(target.Get("opt-level"));
291296
#if TVM_LLVM_VERSION <= 170
292297
if (maybe_level.defined()) {
293-
int level = maybe_level.value()->value;
298+
int level = maybe_level->value;
294299
if (level <= 0) {
295300
opt_level_ = llvm::CodeGenOpt::None;
296301
} else if (level == 1) {
@@ -327,7 +332,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) {
327332
// Fast math options
328333

329334
auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool {
330-
return target->GetAttr<Bool>(flag.str()).value_or(Bool(false));
335+
return Downcast<Bool>(target.Get(flag.str()).value_or(Bool(false)));
331336
};
332337
if (GetBoolFlag("fast-math")) {
333338
#if TVM_LLVM_VERSION >= 60
@@ -381,52 +386,31 @@ static const llvm::Target* CreateLLVMTargetInstance(const std::string triple,
381386
return llvm_instance;
382387
}
383388

384-
static llvm::TargetMachine* CreateLLVMTargetMachine(
389+
static std::unique_ptr<llvm::TargetMachine> CreateLLVMTargetMachine(
385390
const llvm::Target* llvm_instance, const std::string& triple, const std::string& cpu,
386-
const std::string& features, const llvm::TargetOptions& target_options,
387-
const llvm::Reloc::Model& reloc_model, const llvm::CodeModel::Model& code_model,
391+
const std::string& features, const llvm::TargetOptions& target_options = {},
392+
const llvm::Reloc::Model& reloc_model = llvm::Reloc::Static,
393+
const llvm::CodeModel::Model& code_model = llvm::CodeModel::Small,
388394
#if TVM_LLVM_VERSION <= 170
389-
const llvm::CodeGenOpt::Level& opt_level) {
395+
const llvm::CodeGenOpt::Level& opt_level = llvm::CodeGenOpt::Level(0)) {
390396
#else
391-
const llvm::CodeGenOptLevel& opt_level) {
397+
const llvm::CodeGenOptLevel& opt_level = llvm::CodeGenOptLevel(0)) {
392398
#endif
393399
llvm::TargetMachine* tm = llvm_instance->createTargetMachine(
394400
triple, cpu, features, target_options, reloc_model, code_model, opt_level);
395401
ICHECK(tm != nullptr);
396402

397-
return tm;
398-
}
399-
400-
static const llvm::MCSubtargetInfo* GetLLVMSubtargetInfo(const std::string& triple,
401-
const std::string& cpu_name,
402-
const std::string& feats) {
403-
// create a LLVM instance
404-
auto llvm_instance = CreateLLVMTargetInstance(triple, true);
405-
// create a target machine
406-
// required minimum: llvm::InitializeAllTargetMCs()
407-
llvm::TargetOptions target_options;
408-
auto tm = CreateLLVMTargetMachine(llvm_instance, triple, cpu_name, feats, target_options,
409-
llvm::Reloc::Static, llvm::CodeModel::Small,
410-
#if TVM_LLVM_VERSION <= 170
411-
llvm::CodeGenOpt::Level(0));
412-
#else
413-
llvm::CodeGenOptLevel(0));
414-
#endif
415-
// create subtarget info module
416-
const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo();
417-
418-
return MCInfo;
403+
return std::unique_ptr<llvm::TargetMachine>(tm);
419404
}
420405

421406
llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing) {
422407
if (target_machine_) return target_machine_.get();
423408

424409
std::string error;
425410
if (const llvm::Target* llvm_instance = CreateLLVMTargetInstance(triple_, allow_missing)) {
426-
llvm::TargetMachine* tm =
411+
target_machine_ =
427412
CreateLLVMTargetMachine(llvm_instance, triple_, cpu_, GetTargetFeatureString(),
428413
target_options_, reloc_model_, code_model_, opt_level_);
429-
target_machine_ = std::unique_ptr<llvm::TargetMachine>(tm);
430414
}
431415
ICHECK(target_machine_ != nullptr);
432416
return target_machine_.get();
@@ -832,7 +816,11 @@ const Array<String> LLVMTargetInfo::GetAllLLVMTargets() const {
832816
const Array<String> LLVMTargetInfo::GetAllLLVMTargetArches() const {
833817
Array<String> cpu_arches;
834818
// get the subtarget info module
835-
const auto MCInfo = GetLLVMSubtargetInfo(triple_, "", "");
819+
auto llvm_instance = CreateLLVMTargetInstance(triple_, true);
820+
std::unique_ptr<llvm::TargetMachine> target_machine =
821+
CreateLLVMTargetMachine(llvm_instance, triple_, "", "");
822+
const auto MCInfo = target_machine->getMCSubtargetInfo();
823+
836824
if (!MCInfo) {
837825
return cpu_arches;
838826
}
@@ -850,24 +838,29 @@ const Array<String> LLVMTargetInfo::GetAllLLVMTargetArches() const {
850838
return cpu_arches;
851839
}
852840

853-
const Array<String> LLVMTargetInfo::GetAllLLVMCpuFeatures() const {
841+
const Map<String, String> LLVMTargetInfo::GetAllLLVMCpuFeatures() const {
854842
std::string feats = "";
855843
for (const auto& attr : attrs_) {
856844
feats += feats.empty() ? attr : ("," + attr);
857845
}
858846
// get the subtarget info module
859-
const auto MCInfo = GetLLVMSubtargetInfo(triple_, cpu_.c_str(), feats);
847+
auto llvm_instance = CreateLLVMTargetInstance(triple_, true);
848+
std::unique_ptr<llvm::TargetMachine> target_machine =
849+
CreateLLVMTargetMachine(llvm_instance, triple_, cpu_.c_str(), feats);
850+
const auto MCInfo = target_machine->getMCSubtargetInfo();
851+
860852
// get all features for CPU
861853
llvm::ArrayRef<llvm::SubtargetFeatureKV> llvm_features =
862854
#if TVM_LLVM_VERSION < 180
863855
llvm::featViewer(*(const llvm::MCSubtargetInfo*)MCInfo);
864856
#else
865857
MCInfo->getAllProcessorFeatures();
866858
#endif
867-
Array<String> cpu_features;
859+
// TVM doesn't have an FFI friendly Set, so use a Map instead for now
860+
Map<String, String> cpu_features;
868861
for (const auto& feat : llvm_features) {
869862
if (MCInfo->checkFeatures("+" + std::string(feat.Key))) {
870-
cpu_features.push_back(feat.Key);
863+
cpu_features.Set(feat.Key, "");
871864
}
872865
}
873866

@@ -877,9 +870,7 @@ const Array<String> LLVMTargetInfo::GetAllLLVMCpuFeatures() const {
877870
const bool LLVMTargetInfo::TargetHasCPUFeature(const std::string& feature) const {
878871
// lookup features for `-mcpu`
879872
auto feats = GetAllLLVMCpuFeatures();
880-
bool has_feature =
881-
std::any_of(feats.begin(), feats.end(), [&](const auto& var) { return var == feature; });
882-
873+
bool has_feature = feats.find(feature) != feats.end();
883874
return has_feature;
884875
}
885876

src/target/llvm/llvm_instance.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,14 @@ class LLVMTargetInfo {
156156
*/
157157
// NOLINTNEXTLINE(runtime/references)
158158
LLVMTargetInfo(LLVMInstance& scope, const std::string& target_str);
159+
/*!
160+
* \brief Constructs LLVMTargetInfo from `Target`
161+
* \param scope LLVMInstance object
162+
* \param target TVM JSON Target object for target "llvm"
163+
*/
164+
// NOLINTNEXTLINE(runtime/references)
165+
LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target);
166+
159167
/*!
160168
* \brief Destroys LLVMTargetInfo object
161169
*/
@@ -290,11 +298,12 @@ class LLVMTargetInfo {
290298

291299
/*!
292300
* \brief Get all CPU features from target
293-
* \return list with all valid cpu features
301+
* \return Map with all valid cpu features as keys and empty string as value. The Map
302+
* is intended to be used as a Set, which TVM does not currently support.
294303
* \note The features are fetched from the LLVM backend using the target `-mtriple`
295304
* and the `-mcpu` architecture, but also consider the `-mattr` attributes.
296305
*/
297-
const Array<String> GetAllLLVMCpuFeatures() const;
306+
const Map<String, String> GetAllLLVMCpuFeatures() const;
298307

299308
/*!
300309
* \brief Check the target if has a specific cpu feature

src/target/llvm/llvm_module.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -697,12 +697,12 @@ TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist")
697697
});
698698

699699
TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features")
700-
.set_body_typed([](const Target& target) -> Array<String> {
700+
.set_body_typed([](const Target& target) -> Map<String, String> {
701701
auto use_target = target.defined() ? target : Target::Current(false);
702702
// ignore non "llvm" target
703703
if (target.defined()) {
704704
if (target->kind->name != "llvm") {
705-
return Array<String>{};
705+
return {};
706706
}
707707
}
708708
auto llvm_instance = std::make_unique<LLVMInstance>();
@@ -722,8 +722,7 @@ TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature")
722722
auto llvm_instance = std::make_unique<LLVMInstance>();
723723
LLVMTargetInfo llvm_backend(*llvm_instance, use_target);
724724
auto cpu_features = llvm_backend.GetAllLLVMCpuFeatures();
725-
bool has_feature = std::any_of(cpu_features.begin(), cpu_features.end(),
726-
[&](auto& var) { return var == feature; });
725+
bool has_feature = cpu_features.find(feature) != cpu_features.end();
727726
return has_feature;
728727
});
729728

src/target/parsers/aprofile.cc

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424

2525
#include "aprofile.h"
2626

27+
#include <memory>
2728
#include <string>
2829

2930
#include "../../support/utils.h"
31+
#include "../llvm/llvm_instance.h"
3032

3133
namespace tvm {
3234
namespace target {
@@ -52,33 +54,6 @@ double GetArchVersion(Optional<Array<String>> attr) {
5254
return GetArchVersion(attr.value());
5355
}
5456

55-
static inline bool HasFlag(String attr, std::string flag) {
56-
std::string attr_str = attr;
57-
return attr_str.find(flag) != std::string::npos;
58-
}
59-
60-
static inline bool HasFlag(Optional<String> attr, std::string flag) {
61-
if (!attr) {
62-
return false;
63-
}
64-
return HasFlag(attr.value(), flag);
65-
}
66-
67-
static inline bool HasFlag(Optional<Array<String>> attr, std::string flag) {
68-
if (!attr) {
69-
return false;
70-
}
71-
Array<String> attr_array = attr.value();
72-
73-
auto matching_attr = std::find_if(attr_array.begin(), attr_array.end(),
74-
[flag](String attr_str) { return HasFlag(attr_str, flag); });
75-
return matching_attr != attr_array.end();
76-
}
77-
78-
static bool HasFlag(Optional<String> mcpu, Optional<Array<String>> mattr, std::string flag) {
79-
return HasFlag(mcpu, flag) || HasFlag(mattr, flag);
80-
}
81-
8257
bool IsAArch32(Optional<String> mtriple, Optional<String> mcpu) {
8358
if (mtriple) {
8459
bool is_mprofile = mcpu && support::StartsWith(mcpu.value(), "cortex-m");
@@ -101,39 +76,46 @@ bool IsArch(TargetJSON attrs) {
10176
return IsAArch32(mtriple, mcpu) || IsAArch64(mtriple);
10277
}
10378

104-
static TargetFeatures GetFeatures(TargetJSON target) {
105-
Optional<String> mcpu = Downcast<Optional<String>>(target.Get("mcpu"));
106-
Optional<String> mtriple = Downcast<Optional<String>>(target.Get("mtriple"));
107-
Optional<Array<String>> mattr = Downcast<Optional<Array<String>>>(target.Get("mattr"));
79+
bool CheckContains(Array<String> array, String predicate) {
80+
return std::any_of(array.begin(), array.end(), [&](String var) { return var == predicate; });
81+
}
10882

109-
const double arch_version = GetArchVersion(mattr);
83+
static TargetFeatures GetFeatures(TargetJSON target) {
84+
#ifdef TVM_LLVM_VERSION
85+
String kind = Downcast<String>(target.Get("kind"));
86+
ICHECK_EQ(kind, "llvm") << "Expected target kind 'llvm', but got '" << kind << "'";
11087

111-
const bool is_aarch64 = IsAArch64(mtriple);
88+
Optional<String> mtriple = Downcast<Optional<String>>(target.Get("mtriple"));
89+
Optional<String> mcpu = Downcast<Optional<String>>(target.Get("mcpu"));
11290

113-
const bool simd_flag = HasFlag(mcpu, mattr, "+neon") || HasFlag(mcpu, mattr, "+simd");
114-
const bool has_asimd = is_aarch64 || simd_flag;
115-
const bool has_sve = HasFlag(mcpu, mattr, "+sve");
91+
// Check that LLVM has been compiled with the correct target support
92+
auto llvm_instance = std::make_unique<codegen::LLVMInstance>();
93+
codegen::LLVMTargetInfo llvm_backend(*llvm_instance, {{"kind", String("llvm")}});
94+
Array<String> targets = llvm_backend.GetAllLLVMTargets();
95+
if ((IsAArch64(mtriple) && !CheckContains(targets, "aarch64")) ||
96+
(IsAArch32(mtriple, mcpu) && !CheckContains(targets, "arm"))) {
97+
LOG(WARNING) << "Cannot parse target features. LLVM was not compiled with support for "
98+
"Arm(R)-based targets.";
99+
return {};
100+
}
116101

117-
const bool i8mm_flag = HasFlag(mcpu, mattr, "+i8mm");
118-
const bool i8mm_disable = HasFlag(mcpu, mattr, "+noi8mm");
119-
const bool i8mm_default = arch_version >= 8.6;
120-
const bool i8mm_support = arch_version >= 8.2 && arch_version <= 8.5;
121-
const bool has_i8mm = (i8mm_default && !i8mm_disable) || (i8mm_support && i8mm_flag);
102+
codegen::LLVMTargetInfo llvm_target(*llvm_instance, target);
103+
Map<String, String> features = llvm_target.GetAllLLVMCpuFeatures();
122104

123-
const bool dotprod_flag = HasFlag(mcpu, mattr, "+dotprod");
124-
const bool dotprod_disable = HasFlag(mcpu, mattr, "+nodotprod");
125-
const bool dotprod_default = arch_version >= 8.4;
126-
const bool dotprod_support = arch_version >= 8.2 && arch_version <= 8.3;
127-
const bool has_dotprod =
128-
(dotprod_default && !dotprod_disable) || (dotprod_support && dotprod_flag);
105+
auto has_feature = [features](const String& feature) {
106+
return features.find(feature) != features.end();
107+
};
129108

130-
const bool fp16_flag = HasFlag(mcpu, mattr, "+fullfp16");
131-
const bool fp16_support = arch_version >= 8.2;
132-
const bool has_fp16_simd = fp16_support && (fp16_flag || has_sve);
109+
return {{"is_aarch64", Bool(IsAArch64(mtriple))},
110+
{"has_asimd", Bool(has_feature("neon"))},
111+
{"has_sve", Bool(has_feature("sve"))},
112+
{"has_dotprod", Bool(has_feature("dotprod"))},
113+
{"has_matmul_i8", Bool(has_feature("i8mm"))},
114+
{"has_fp16_simd", Bool(has_feature("fullfp16"))}};
115+
#endif
133116

134-
return {{"is_aarch64", Bool(is_aarch64)}, {"has_asimd", Bool(has_asimd)},
135-
{"has_sve", Bool(has_sve)}, {"has_dotprod", Bool(has_dotprod)},
136-
{"has_matmul_i8", Bool(has_i8mm)}, {"has_fp16_simd", Bool(has_fp16_simd)}};
117+
LOG(WARNING) << "Cannot parse Arm(R)-based target features without LLVM support.";
118+
return {};
137119
}
138120

139121
static Array<String> MergeKeys(Optional<Array<String>> existing_keys) {

0 commit comments

Comments
 (0)