Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/target/tag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768);
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536)
.with_config("l2_cache_size_bytes", Integer(41943040));
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90", 49152, 65536)
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536)
.with_config("l2_cache_size_bytes", Integer(52428800));
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536);
TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536);
Expand Down
37 changes: 7 additions & 30 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <algorithm>
#include <cctype>
#include <limits>

#include "../node/attr_registry.h"
#include "../support/utils.h"
#include "./parsers/cpu.h"

namespace tvm {
Expand Down Expand Up @@ -81,30 +82,6 @@ Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {

/********** Utility functions **********/

/*!
 * \brief Extract a non-negative decimal number that follows a given prefix.
 *  For example, when `str` is "sm_20" and `prefix` is "sm_", the result is 20.
 *  Everything after `prefix` must be decimal digits, so "sm_90a" is rejected.
 * \param str The string to extract the number from
 * \param prefix The prefix that `str` must start with
 * \return The extracted number. -1 if `str` does not start with `prefix`, if nothing or a
 *  non-digit character follows the prefix, or if the number does not fit in an `int`
 */
static int ExtractIntWithPrefix(const std::string& str, const std::string& prefix) {
  // `compare` checks the prefix in place, without the temporary that `substr` would create.
  if (str.compare(0, prefix.size(), prefix) != 0) {
    return -1;
  }
  // An empty remainder (e.g. "sm_") carries no number at all; do not report it as 0.
  if (str.size() == prefix.size()) {
    return -1;
  }
  int result = 0;
  for (size_t i = prefix.size(); i < str.size(); ++i) {
    // std::isdigit has undefined behaviour for negative `char` values, hence the cast.
    const unsigned char c = static_cast<unsigned char>(str[i]);
    if (!std::isdigit(c)) {
      return -1;
    }
    const int digit = c - '0';
    // Reject values that would overflow `int`; signed overflow is undefined behaviour.
    if (result > (std::numeric_limits<int>::max() - digit) / 10) {
      return -1;
    }
    result = result * 10 + digit;
  }
  return result;
}

/*!
* \brief Extract a string from the string with the given prefix.
* For example, when `str` is "sm_20" and `prefix` is "sm_".
Expand Down Expand Up @@ -168,14 +145,14 @@ void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const Str
*/
TargetJSON UpdateCUDAAttrs(TargetJSON target) {
// Update -arch=sm_xx
int archInt;
if (target.count("arch")) {
// If -arch has been specified, validate the correctness
String archStr = Downcast<String>(target.at("arch"));
archInt = ExtractIntWithPrefix(archStr, "sm_");
ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
ICHECK(support::StartsWith(archStr, "sm_"))
<< "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
} else {
// Use the compute version of the first CUDA GPU instead
int archInt;
TVMRetValue version;
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_50\" instead";
Expand All @@ -196,14 +173,14 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) {
TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
// Update -mcpu=sm_xx
int arch;
if (target.count("mcpu")) {
// If -mcpu has been specified, validate the correctness
String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractIntWithPrefix(mcpu, "sm_");
ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
ICHECK(support::StartsWith(mcpu, "sm_"))
<< "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
} else {
// Use the compute version of the first CUDA GPU instead
int arch;
TVMRetValue version;
if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead";
Expand Down