Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,19 @@ struct ol_program_impl_t {
DeviceImage(DeviceImage) {}
plugin::DeviceImageTy *Image;
std::unique_ptr<llvm::MemoryBuffer> ImageData;
std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
llvm::SmallVector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
std::mutex SymbolListMutex;
__tgt_device_image DeviceImage;
};

struct ol_symbol_impl_t {
ol_symbol_impl_t(GenericKernelTy *Kernel)
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
ol_symbol_impl_t(GlobalTy &&Global)
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {}
ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
ol_symbol_kind_t Kind;
llvm::StringRef Name;
};

namespace llvm {
Expand Down Expand Up @@ -714,6 +716,17 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
auto &Device = Program->Image->getDevice();

std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};

// If it already exists, return an existing handle
auto Check = llvm::find_if(Program->Symbols, [&](auto &Sym) {
return Sym->Kind == Kind && Sym->Name == Name;
});
if (Check != Program->Symbols.end()) {
*Symbol = Check->get();
return Error::success();
}

switch (Kind) {
case OL_SYMBOL_KIND_KERNEL: {
auto KernelImpl = Device.constructKernel(Name);
Expand All @@ -723,10 +736,10 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
if (auto Err = KernelImpl->init(Device, *Program->Image))
return Err;

*Symbol =
Program->Symbols
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
.get();
*Symbol = Program->Symbols
.emplace_back(std::make_unique<ol_symbol_impl_t>(
KernelImpl->getName(), &*KernelImpl))
.get();
return Error::success();
}
case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
Expand All @@ -736,8 +749,8 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
return Res;

*Symbol = Program->Symbols
.emplace_back(
std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj)))
.emplace_back(std::make_unique<ol_symbol_impl_t>(
GlobalObj.getName().c_str(), std::move(GlobalObj)))
.get();

return Error::success();
Expand Down
18 changes: 18 additions & 0 deletions offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ TEST_P(olGetSymbolKernelTest, Success) {
ASSERT_NE(Kernel, nullptr);
}

TEST_P(olGetSymbolKernelTest, SuccessSamePtr) {
ol_symbol_handle_t KernelA = nullptr;
ol_symbol_handle_t KernelB = nullptr;
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelA));
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelB));
ASSERT_EQ(KernelA, KernelB);
}

TEST_P(olGetSymbolKernelTest, InvalidNullProgram) {
ol_symbol_handle_t Kernel = nullptr;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
Expand Down Expand Up @@ -72,6 +80,16 @@ TEST_P(olGetSymbolGlobalTest, Success) {
ASSERT_NE(Global, nullptr);
}

TEST_P(olGetSymbolGlobalTest, SuccessSamePtr) {
ol_symbol_handle_t GlobalA = nullptr;
ol_symbol_handle_t GlobalB = nullptr;
ASSERT_SUCCESS(
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalA));
ASSERT_SUCCESS(
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalB));
ASSERT_EQ(GlobalA, GlobalB);
}

TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) {
ol_symbol_handle_t Global = nullptr;
ASSERT_ERROR(
Expand Down
Loading