3030#include < algorithm>
3131
3232#include " ../node/attr_registry.h"
33+ #include " ../support/utils.h"
3334#include " ./parsers/cpu.h"
3435
3536namespace tvm {
@@ -81,30 +82,6 @@ Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {
8182
8283/* ********* Utility functions **********/
8384
84- /* !
85- * \brief Extract a number from the string with the given prefix.
86- * For example, when `str` is "sm_20" and `prefix` is "sm_".
87- * This function first checks if `str` starts with `prefix`,
88- * then return the integer 20 after the `prefix`
89- * \param str The string to be extracted
90- * \param prefix The prefix to be checked
91- * \return An integer, the extracted number. -1 if the check fails
92- */
93- static int ExtractIntWithPrefix (const std::string& str, const std::string& prefix) {
94- if (str.substr (0 , prefix.size ()) != prefix) {
95- return -1 ;
96- }
97- int result = 0 ;
98- for (size_t i = prefix.size (); i < str.size (); ++i) {
99- char c = str[i];
100- if (!isdigit (c)) {
101- return -1 ;
102- }
103- result = result * 10 + c - ' 0' ;
104- }
105- return result;
106- }
107-
10885/* !
10986 * \brief Extract a string from the string with the given prefix.
11087 * For example, when `str` is "sm_20" and `prefix` is "sm_".
@@ -168,14 +145,14 @@ void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const Str
168145 */
169146TargetJSON UpdateCUDAAttrs (TargetJSON target) {
170147 // Update -arch=sm_xx
171- int archInt;
172148 if (target.count (" arch" )) {
173149 // If -arch has been specified, validate the correctness
174150 String archStr = Downcast<String>(target.at (" arch" ));
175- archInt = ExtractIntWithPrefix ( archStr, " sm_" );
176- ICHECK (archInt != - 1 ) << " ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
151+ ICHECK ( support::StartsWith ( archStr, " sm_" ))
152+ << " ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
177153 } else {
178154 // Use the compute version of the first CUDA GPU instead
155+ int archInt;
179156 TVMRetValue version;
180157 if (!DetectDeviceFlag ({kDLCUDA , 0 }, runtime::kComputeVersion , &version)) {
181158 LOG (WARNING) << " Unable to detect CUDA version, default to \" -arch=sm_50\" instead" ;
@@ -196,14 +173,14 @@ TargetJSON UpdateCUDAAttrs(TargetJSON target) {
196173TargetJSON UpdateNVPTXAttrs (TargetJSON target) {
197174 CheckOrSetAttr (&target, " mtriple" , " nvptx64-nvidia-cuda" );
198175 // Update -mcpu=sm_xx
199- int arch;
200176 if (target.count (" mcpu" )) {
201177 // If -mcpu has been specified, validate the correctness
202178 String mcpu = Downcast<String>(target.at (" mcpu" ));
203- arch = ExtractIntWithPrefix ( mcpu, " sm_" );
204- ICHECK (arch != - 1 ) << " ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
179+ ICHECK ( support::StartsWith ( mcpu, " sm_" ))
180+ << " ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
205181 } else {
206182 // Use the compute version of the first CUDA GPU instead
183+ int arch;
207184 TVMRetValue version;
208185 if (!DetectDeviceFlag ({kDLCUDA , 0 }, runtime::kComputeVersion , &version)) {
209186 LOG (WARNING) << " Unable to detect CUDA version, default to \" -mcpu=sm_50\" instead" ;
0 commit comments