-
Notifications
You must be signed in to change notification settings - Fork 798
[SYCL] Additional support for SYCL_DEVICE_ALLOWLIST #2483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
5c30ab7
dc5c005
aff799f
49e41f7
267ad4f
cfb40d7
55b0857
650d475
5fb7393
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| #include <algorithm> | ||
| #include <cstring> | ||
| #include <regex> | ||
| #include <string> | ||
|
|
||
| __SYCL_INLINE_NAMESPACE(cl) { | ||
| namespace sycl { | ||
|
|
@@ -120,96 +121,147 @@ vector_class<platform> platform_impl::get_platforms() { | |
| return Platforms; | ||
| } | ||
|
|
||
| struct DevDescT { | ||
| const char *devName = nullptr; | ||
| int devNameSize = 0; | ||
| const char *devDriverVer = nullptr; | ||
| int devDriverVerSize = 0; | ||
| std::string getValue(/*const*/ std::string &AllowList, size_t &Pos, | ||
s-kanaev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| unsigned long int Size) { | ||
| size_t Prev = Pos; | ||
| if ((Pos = AllowList.find("{{", Pos)) == std::string::npos) { | ||
| throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST", | ||
| PI_INVALID_VALUE); | ||
| } | ||
| if (Pos > Prev + Size) { | ||
| throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST", | ||
| PI_INVALID_VALUE); | ||
| } | ||
|
|
||
| const char *platformName = nullptr; | ||
| int platformNameSize = 0; | ||
| Pos = Pos + 2; | ||
| size_t Start = Pos; | ||
| if ((Pos = AllowList.find("}}", Pos)) == std::string::npos) { | ||
| throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST", | ||
| PI_INVALID_VALUE); | ||
| } | ||
| std::string Value = AllowList.substr(Start, Pos - Start); | ||
| Pos = Pos + 2; | ||
| return Value; | ||
| } | ||
|
|
||
| const char *platformVer = nullptr; | ||
| int platformVerSize = 0; | ||
| struct DevDescT { | ||
| std::string DevName; | ||
| std::string DevDriverVer; | ||
| std::string PlatName; | ||
| std::string PlatVer; | ||
| }; | ||
|
|
||
| static std::vector<DevDescT> getAllowListDesc() { | ||
| const char *str = SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get(); | ||
| if (!str) | ||
| std::string AllowList(SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get()); | ||
| if (AllowList.empty()) | ||
| return {}; | ||
|
|
||
| std::vector<DevDescT> decDescs; | ||
| const char devNameStr[] = "DeviceName"; | ||
| const char driverVerStr[] = "DriverVersion"; | ||
| const char platformNameStr[] = "PlatformName"; | ||
| const char platformVerStr[] = "PlatformVersion"; | ||
| decDescs.emplace_back(); | ||
| while ('\0' != *str) { | ||
| const char **valuePtr = nullptr; | ||
| int *size = nullptr; | ||
|
|
||
| // -1 to avoid comparing null terminator | ||
| if (0 == strncmp(devNameStr, str, sizeof(devNameStr) - 1)) { | ||
| valuePtr = &decDescs.back().devName; | ||
| size = &decDescs.back().devNameSize; | ||
| str += sizeof(devNameStr) - 1; | ||
| } else if (0 == | ||
| strncmp(platformNameStr, str, sizeof(platformNameStr) - 1)) { | ||
| valuePtr = &decDescs.back().platformName; | ||
| size = &decDescs.back().platformNameSize; | ||
| str += sizeof(platformNameStr) - 1; | ||
| } else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) { | ||
| valuePtr = &decDescs.back().platformVer; | ||
| size = &decDescs.back().platformVerSize; | ||
| str += sizeof(platformVerStr) - 1; | ||
| } else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) { | ||
| valuePtr = &decDescs.back().devDriverVer; | ||
| size = &decDescs.back().devDriverVerSize; | ||
| str += sizeof(driverVerStr) - 1; | ||
| } else { | ||
| throw sycl::runtime_error("Unrecognized key in device allowlist", | ||
| PI_INVALID_VALUE); | ||
| std::string DeviceName("DeviceName:"); | ||
| std::string DriverVersion("DriverVersion:"); | ||
| std::string PlatformName("PlatformName:"); | ||
| std::string PlatformVersion("PlatformVersion:"); | ||
| std::vector<DevDescT> DecDescs; | ||
| DecDescs.emplace_back(); | ||
|
|
||
| size_t Pos = 0; | ||
| while (Pos < AllowList.size()) { | ||
| if ((AllowList.compare(Pos, DeviceName.size(), DeviceName)) == 0) { | ||
| DecDescs.back().DevName = getValue(AllowList, Pos, DeviceName.size()); | ||
| if (AllowList[Pos] == ',') { | ||
| Pos++; | ||
| } | ||
| } | ||
|
|
||
| if (':' != *str) | ||
| throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE); | ||
|
|
||
| // Skip ':' | ||
| str += 1; | ||
|
|
||
| if ('{' != *str || '{' != *(str + 1)) | ||
| throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE); | ||
|
|
||
| // Skip opening sequence "{{" | ||
| str += 2; | ||
|
|
||
| *valuePtr = str; | ||
| else if ((AllowList.compare(Pos, DriverVersion.size(), DriverVersion)) == | ||
| 0) { | ||
| DecDescs.back().DevDriverVer = | ||
| getValue(AllowList, Pos, DriverVersion.size()); | ||
| if (AllowList[Pos] == ',') { | ||
| Pos++; | ||
| } | ||
| } | ||
|
|
||
| // Increment until closing sequence is encountered | ||
| while (('\0' != *str) && ('}' != *str || '}' != *(str + 1))) | ||
| ++str; | ||
| else if ((AllowList.compare(Pos, PlatformName.size(), PlatformName)) == 0) { | ||
| DecDescs.back().PlatName = getValue(AllowList, Pos, PlatformName.size()); | ||
| if (AllowList[Pos] == ',') { | ||
| Pos++; | ||
| } | ||
| } | ||
|
|
||
| if ('\0' == *str) | ||
| throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE); | ||
| else if ((AllowList.compare(Pos, PlatformVersion.size(), | ||
| PlatformVersion)) == 0) { | ||
| DecDescs.back().PlatVer = | ||
| getValue(AllowList, Pos, PlatformVersion.size()); | ||
| } else if (AllowList.find('|', Pos) != std::string::npos) { | ||
| Pos = AllowList.find('|') + 1; | ||
| while (AllowList[Pos] == ' ') { | ||
| Pos++; | ||
| } | ||
| DecDescs.emplace_back(); | ||
| } | ||
|
|
||
| *size = str - *valuePtr; | ||
| else { | ||
| throw sycl::runtime_error("Unrecognized key in device allowlist", | ||
| PI_INVALID_VALUE); | ||
| } | ||
| } // while (Pos <= AllowList.size()) | ||
| return DecDescs; | ||
| } | ||
|
|
||
| // Skip closing sequence "}}" | ||
| str += 2; | ||
| std::vector<int> convertVersionString(std::string Version) { | ||
| // version string format is xx.yy.zzzzz.ww WW is optional | ||
| std::vector<int> Values; | ||
| size_t Pos = 0; | ||
| size_t Start = Pos; | ||
| if ((Pos = Version.find(".", Pos)) == std::string::npos) { | ||
| throw sycl::runtime_error("Malformed syntax in version string", | ||
| PI_INVALID_VALUE); | ||
| } | ||
| Values.push_back(std::stoi(Version.substr(Start, Pos - Start))); | ||
| Pos++; | ||
| Start = Pos; | ||
| if ((Pos = Version.find(".", Pos)) == std::string::npos) { | ||
| throw sycl::runtime_error("Malformed syntax in version string", | ||
| PI_INVALID_VALUE); | ||
| } | ||
| Values.push_back(std::stoi(Version.substr(Start, Pos - Start))); | ||
| Pos++; | ||
| size_t Prev = Pos; | ||
| if ((Pos = Version.find(".", Pos)) == std::string::npos) { | ||
| Values.push_back(std::stoi(Version.substr(Prev))); | ||
| } else { | ||
| Values.push_back(std::stoi(Version.substr(Start, Pos - Start))); | ||
| Pos++; | ||
| Values.push_back(std::stoi(Version.substr(Pos))); | ||
| } | ||
| return Values; | ||
| } | ||
|
|
||
| if ('\0' == *str) | ||
| break; | ||
| enum MatchState { UNKNOWN, MATCH, NOMATCH }; | ||
|
|
||
| // '|' means that the is another filter | ||
| if ('|' == *str) | ||
| decDescs.emplace_back(); | ||
| else if (',' != *str) | ||
| throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE); | ||
| MatchState matchVersions(std::string Version1, std::string Version2) { | ||
|
||
| std::vector<int> V1 = convertVersionString(Version1); | ||
| std::vector<int> V2 = convertVersionString(Version2); | ||
|
|
||
| ++str; | ||
| if (V1.size() != V2.size()) { | ||
| return MatchState::NOMATCH; | ||
| } | ||
|
||
|
|
||
| return decDescs; | ||
| if (V1[0] > V2[0]) { | ||
| return MatchState::MATCH; | ||
| } | ||
| if ((V1[0] == V2[0]) && (V1[1] >= V2[1])) { | ||
| return MatchState::MATCH; | ||
| } | ||
| if ((V1[0] == V2[0]) && (V1[1] == V2[1]) && (V1[2] >= V2[2])) { | ||
| return MatchState::MATCH; | ||
| } | ||
| if (V1.size() == 4) { | ||
| if ((V1[0] == V2[0]) && (V1[1] == V2[1]) && (V1[2] == V2[2]) && | ||
| (V1[3] >= V2[3])) { | ||
| return MatchState::MATCH; | ||
| } | ||
| } | ||
| return MatchState::NOMATCH; | ||
| } | ||
|
|
||
| static void filterAllowList(vector_class<RT::PiDevice> &PiDevices, | ||
|
|
@@ -218,6 +270,11 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices, | |
| if (AllowList.empty()) | ||
| return; | ||
|
|
||
| MatchState DevNameState = UNKNOWN; | ||
| MatchState DevVerState = UNKNOWN; | ||
| MatchState PlatNameState = UNKNOWN; | ||
| MatchState PlatVerState = UNKNOWN; | ||
|
|
||
| const string_class PlatformName = | ||
| sycl::detail::get_platform_info<string_class, info::platform::name>::get( | ||
| PiPlatform, Plugin); | ||
|
|
@@ -237,33 +294,57 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices, | |
| string_class, info::device::driver_version>::get(Device, Plugin); | ||
|
|
||
| for (const DevDescT &Desc : AllowList) { | ||
| if (nullptr != Desc.platformName && | ||
| !std::regex_match(PlatformName, | ||
| std::regex(std::string(Desc.platformName, | ||
| Desc.platformNameSize)))) | ||
| continue; | ||
|
|
||
| if (nullptr != Desc.platformVer && | ||
| !std::regex_match( | ||
| PlatformVer, | ||
| std::regex(std::string(Desc.platformVer, Desc.platformVerSize)))) | ||
| continue; | ||
|
|
||
| if (nullptr != Desc.devName && | ||
| !std::regex_match(DeviceName, std::regex(std::string( | ||
| Desc.devName, Desc.devNameSize)))) | ||
| continue; | ||
|
|
||
| if (nullptr != Desc.devDriverVer && | ||
| !std::regex_match(DeviceDriverVer, | ||
| std::regex(std::string(Desc.devDriverVer, | ||
| Desc.devDriverVerSize)))) | ||
| continue; | ||
| if (!Desc.PlatName.empty()) { | ||
| if (!std::regex_match(PlatformName, std::regex(Desc.PlatName))) { | ||
| PlatNameState = MatchState::NOMATCH; | ||
| continue; | ||
| } else { | ||
| PlatNameState = MatchState::MATCH; | ||
| } | ||
| } | ||
|
|
||
| if (!Desc.PlatVer.empty()) { | ||
| if (!std::regex_match(PlatformVer, std::regex(Desc.PlatVer))) { | ||
| PlatVerState = MatchState::NOMATCH; | ||
| continue; | ||
| } else { | ||
| PlatVerState = MatchState::MATCH; | ||
| } | ||
| } | ||
|
|
||
| if (!Desc.DevName.empty()) { | ||
| if (!std::regex_match(DeviceName, std::regex(Desc.DevName))) { | ||
| DevNameState = MatchState::NOMATCH; | ||
| continue; | ||
| } else { | ||
| DevNameState = MatchState::MATCH; | ||
| } | ||
| } | ||
|
|
||
| if (!Desc.DevDriverVer.empty()) { | ||
| if (!std::regex_match(DeviceDriverVer, std::regex(Desc.DevDriverVer))) { | ||
| DevVerState = matchVersions(DeviceDriverVer, Desc.DevDriverVer); | ||
| if (DevVerState == MatchState::NOMATCH) { | ||
| continue; | ||
| } | ||
| } else { | ||
| DevVerState = MatchState::MATCH; | ||
| } | ||
| } | ||
|
|
||
| PiDevices[InsertIDx++] = Device; | ||
| break; | ||
| } | ||
| } | ||
| if (DevNameState == MatchState::MATCH && DevVerState == MatchState::NOMATCH) { | ||
| throw sycl::runtime_error("Requested SYCL device not found", | ||
| PI_DEVICE_NOT_FOUND); | ||
| } | ||
| if (PlatNameState == MatchState::MATCH && | ||
| PlatVerState == MatchState::NOMATCH) { | ||
| throw sycl::runtime_error("Requested SYCL platform not found", | ||
| PI_DEVICE_NOT_FOUND); | ||
| } | ||
| PiDevices.resize(InsertIDx); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this mean '.' cannot be used in the regexp? Please add.