diff --git a/llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp b/llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp index bc024cb7ac2a8..8ceb66d9c8d09 100644 --- a/llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp +++ b/llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp @@ -13,6 +13,7 @@ #include "llvm/SYCLLowerIR/ESIMD/ESIMDVerifier.h" #include "llvm/Demangle/Demangle.h" +#include "llvm/Demangle/ItaniumDemangle.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -22,17 +23,64 @@ #include "llvm/Support/Regex.h" using namespace llvm; +namespace id = itanium_demangle; #define DEBUG_TYPE "esimd-verifier" -// A list of unsupported functions in ESIMD context. -static const char *IllegalFunctions[] = { - "^cl::sycl::multi_ptr<.+> cl::sycl::accessor<.+>::get_pointer<.+>\\(\\) " - "const", - " cl::sycl::accessor<.+>::operator\\[\\]<.+>\\(.+\\) const"}; +// A list of SYCL functions (regexps) allowed for use in ESIMD context. +static const char *LegalSYCLFunctions[] = { + "^cl::sycl::accessor<.+>::accessor", + "^cl::sycl::accessor<.+>::~accessor", + "^cl::sycl::accessor<.+>::getNativeImageObj", + "^cl::sycl::accessor<.+>::__init_esimd", + "^cl::sycl::id<.+>::.+", + "^cl::sycl::item<.+>::.+", + "^cl::sycl::nd_item<.+>::.+", + "^cl::sycl::group<.+>::.+", + "^cl::sycl::sub_group<.+>::.+", + "^cl::sycl::range<.+>::.+", + "^cl::sycl::kernel_handler::.+", + "^cl::sycl::cos<.+>", + "^cl::sycl::sin<.+>", + "^cl::sycl::log<.+>", + "^cl::sycl::exp<.+>", + "^cl::sycl::operator.+<.+>", + "^cl::sycl::ext::oneapi::sub_group::.+", + "^cl::sycl::ext::oneapi::experimental::spec_constant<.+>::.+", + "^cl::sycl::ext::oneapi::experimental::this_sub_group"}; namespace { +// Simplest possible implementation of an allocator for the Itanium demangler +class SimpleAllocator { +protected: + SmallVector Ptrs; + +public: + void reset() { + for (void *Ptr : Ptrs) { + // Destructors are not called, but that is OK for the + // itanium_demangle::Node subclasses + std::free(Ptr); + } + Ptrs.resize(0); + } + + template T *makeNode(Args &&...args) { + void *Ptr = std::calloc(1, sizeof(T)); + Ptrs.push_back(Ptr); + return new (Ptr) T(std::forward(args)...); + } + + void *allocateNodeArray(size_t sz) { + void *Ptr = std::calloc(sz, sizeof(id::Node *)); + Ptrs.push_back(Ptr); + return Ptr; + } + + ~SimpleAllocator() { reset(); } +}; + class ESIMDVerifierImpl { const Module &M; @@ -63,22 +111,49 @@ class ESIMDVerifierImpl { if (!Callee) continue; - // Demangle called function name and check if it matches any illegal - // function name. Report an error if there is a match. - std::string DemangledName = demangle(Callee->getName().str()); - for (const char *Name : IllegalFunctions) { - Regex NameRE(Name); - assert(NameRE.isValid() && "invalid function name regex"); - if (NameRE.match(DemangledName)) { - std::string ErrorMsg = std::string("function '") + DemangledName + - "' is not supported in ESIMD context"; - F->getContext().emitError(&I, ErrorMsg); - } - } - // Add callee to the list to be analyzed if it is not a declaration. if (!Callee->isDeclaration()) Add2Worklist(Callee); + + // Demangle called function name and check if it is legal to use this + // function in ESIMD context. + StringRef MangledName = Callee->getName(); + id::ManglingParser Parser(MangledName.begin(), + MangledName.end()); + id::Node *AST = Parser.parse(); + if (!AST || AST->getKind() != id::Node::KFunctionEncoding) + continue; + + auto *FE = static_cast(AST); + const id::Node *NameNode = FE->getName(); + if (!NameNode) // Can it be null? + continue; + + id::OutputBuffer NameBuf; + NameNode->print(NameBuf); + StringRef Name(NameBuf.getBuffer(), NameBuf.getCurrentPosition()); + + // We are interested in functions defined in SYCL namespace, but + // outside of ESIMD namespaces. + if (!Name.startswith("cl::sycl::") || + Name.startswith("cl::sycl::detail::") || + Name.startswith("cl::sycl::ext::intel::esimd::") || + Name.startswith("cl::sycl::ext::intel::experimental::esimd::")) + continue; + + // Check if function name matches any allowed SYCL function name. + if (any_of(LegalSYCLFunctions, [Name](const char *LegalName) { + Regex LegalNameRE(LegalName); + assert(LegalNameRE.isValid() && "invalid function name regex"); + return LegalNameRE.match(Name); + })) + continue; + + // If not, report an error. + std::string ErrorMsg = std::string("function '") + + demangle(MangledName.str()) + + "' is not supported in ESIMD context"; + F->getContext().emitError(&I, ErrorMsg); } } } diff --git a/sycl/test/esimd/esimd_verify.cpp b/sycl/test/esimd/esimd_verify.cpp index df940ad0f22a6..d0cd342d00e5c 100644 --- a/sycl/test/esimd/esimd_verify.cpp +++ b/sycl/test/esimd/esimd_verify.cpp @@ -3,6 +3,7 @@ // RUN: not %clangxx -fsycl -fsycl-device-only -flegacy-pass-manager -O0 -S %s -o /dev/null 2>&1 | FileCheck %s // RUN: not %clangxx -fsycl -fsycl-device-only -fno-legacy-pass-manager -O0 -S %s -o /dev/null 2>&1 | FileCheck %s +#include #include using namespace cl::sycl; @@ -10,6 +11,7 @@ using namespace sycl::ext::intel::esimd; // CHECK-DAG: error: function 'cl::sycl::multi_ptr<{{.+}}> cl::sycl::accessor<{{.+}}>::get_pointer<{{.+}}>() const' is not supported in ESIMD context // CHECK-DAG: error: function '{{.+}} cl::sycl::accessor<{{.+}}>::operator[]<{{.+}}>({{.+}}) const' is not supported in ESIMD context +// CHECK-DAG: error: function 'cl::sycl::ext::oneapi::detail::reducer, void>::combine(int const&)' is not supported in ESIMD context SYCL_EXTERNAL auto test(accessor &acc) @@ -22,3 +24,10 @@ test1(accessor &acc) SYCL_ESIMD_FUNCTION { acc[0] = 0; } + +void test2(sycl::handler &cgh, int *buf) { + auto reduction = sycl::reduction(buf, sycl::plus()); + cgh.parallel_for(sycl::range<1>(1), reduction, + [=](sycl::id<1>, auto &reducer) + SYCL_ESIMD_KERNEL { reducer.combine(15); }); +}