Skip to content

Commit 8b5bd55

Browse files
authored
[Target][CUDA] Allow non-numeric arch as needed for latest gpu (#16736)
* [Target][CUDA] Allow non-numeric arch as needed for latest gpu
* Fix parsing in nvcc
* fix
1 parent 95ec38b commit 8b5bd55

File tree

3 files changed

+16
-33
lines changed

3 files changed

+16
-33
lines changed

python/tvm/contrib/nvcc.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -291,8 +291,14 @@ def get_target_compute_version(target=None):
291291
# 2. Target.current()
292292
target = target or Target.current()
293293
if target and target.arch:
294-
major, minor = target.arch.split("_")[1]
295-
return major + "." + minor
294+
arch = target.arch.split("_")[1]
295+
if len(arch) == 2:
296+
major, minor = arch
297+
return major + "." + minor
298+
elif len(arch) == 3:
299+
# This is for arch like "sm_90a"
300+
major, minor, suffix = arch
301+
return major + "." + minor + "." + suffix
296302

297303
# 3. GPU compute version
298304
if tvm.cuda(0).exist:

src/target/tag.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768);
155155
TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768);
156156
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536)
157157
.with_config("l2_cache_size_bytes", Integer(41943040));
158-
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90", 49152, 65536)
158+
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536)
159159
.with_config("l2_cache_size_bytes", Integer(52428800));
160160
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536);
161161
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536);

src/target/target_kind.cc

Lines changed: 7 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
#include <algorithm>
3131

3232
#include "../node/attr_registry.h"
33+
#include "../support/utils.h"
3334
#include "./parsers/cpu.h"
3435

3536
namespace tvm {
@@ -81,30 +82,6 @@ Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {
8182

8283
/********** Utility functions **********/
8384

84-
/*!
85-
* \brief Extract a number from the string with the given prefix.
86-
* For example, when `str` is "sm_20" and `prefix` is "sm_".
87-
* This function first checks if `str` starts with `prefix`,
88-
* then return the integer 20 after the `prefix`
89-
* \param str The string to be extracted
90-
* \param prefix The prefix to be checked
91-
* \return An integer, the extracted number. -1 if the check fails
92-
*/
93-
static int ExtractIntWithPrefix(const std::string& str, const std::string& prefix) {
94-
if (str.substr(0, prefix.size()) != prefix) {
95-
return -1;
96-
}
97-
int result = 0;
98-
for (size_t i = prefix.size(); i < str.size(); ++i) {
99-
char c = str[i];
100-
if (!isdigit(c)) {
101-
return -1;
102-
}
103-
result = result * 10 + c - '0';
104-
}
105-
return result;
106-
}
107-
10885
/*!
10986
* \brief Extract a string from the string with the given prefix.
11087
* For example, when `str` is "sm_20" and `prefix` is "sm_".
@@ -168,14 +145,14 @@ void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const Str
168145
*/
169146
TargetJSON UpdateCUDAAttrs(TargetJSON target) {
170147
// Update -arch=sm_xx
171-
int archInt;
172148
if (target.count("arch")) {
173149
// If -arch has been specified, validate the correctness
174150
String archStr = Downcast<String>(target.at("arch"));
175-
archInt = ExtractIntWithPrefix(archStr, "sm_");
176-
ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
151+
ICHECK(support::StartsWith(archStr, "sm_"))
152+
<< "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
177153
} else {
178154
// Use the compute version of the first CUDA GPU instead
155+
int archInt;
179156
TVMRetValue version;
180157
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
181158
LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_50\" instead";
@@ -196,14 +173,14 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) {
196173
TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
197174
CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
198175
// Update -mcpu=sm_xx
199-
int arch;
200176
if (target.count("mcpu")) {
201177
// If -mcpu has been specified, validate the correctness
202178
String mcpu = Downcast<String>(target.at("mcpu"));
203-
arch = ExtractIntWithPrefix(mcpu, "sm_");
204-
ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
179+
ICHECK(support::StartsWith(mcpu, "sm_"))
180+
<< "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
205181
} else {
206182
// Use the compute version of the first CUDA GPU instead
183+
int arch;
207184
TVMRetValue version;
208185
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
209186
LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead";

0 commit comments

Comments
 (0)