diff --git a/mlir/include/mlir/Target/LLVM/Offload.h b/mlir/include/mlir/Target/LLVM/Offload.h new file mode 100644 index 0000000000000..7b705667d477d --- /dev/null +++ b/mlir/include/mlir/Target/LLVM/Offload.h @@ -0,0 +1,61 @@ +//===- Offload.h - LLVM Target Offload --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares LLVM target offload utility classes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVM_OFFLOAD_H +#define MLIR_TARGET_LLVM_OFFLOAD_H + +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class Constant; +class GlobalVariable; +class Module; +} // namespace llvm + +namespace mlir { +namespace LLVM { +/// `OffloadHandler` is a utility class for creating LLVM offload entries. LLVM +/// offload entries hold information on offload symbols; for example, for a GPU +/// kernel, this includes its host address to identify the kernel and the kernel +/// identifier in the binary. Arrays of offload entries can be used to register +/// functions within the CUDA/HIP runtime. Libomptarget also uses these entries +/// to register OMP target offload kernels and variables. +class OffloadHandler { +public: + using OffloadEntryArray = + std::pair; + OffloadHandler(llvm::Module &module) : module(module) {} + + /// Returns the begin symbol name used in the entry array. + static std::string getBeginSymbol(StringRef suffix); + + /// Returns the end symbol name used in the entry array. + static std::string getEndSymbol(StringRef suffix); + + /// Returns the entry array if it exists or a pair of null pointers. + OffloadEntryArray getEntryArray(StringRef suffix); + + /// Emits an empty array of offloading entries. + OffloadEntryArray emitEmptyEntryArray(StringRef suffix); + + /// Inserts an offloading entry into an existing entry array. This method + /// returns failure if the entry array hasn't been declared. + LogicalResult insertOffloadEntry(StringRef suffix, llvm::Constant *entry); + +protected: + llvm::Module &module; +}; +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_TARGET_LLVM_OFFLOAD_H diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt index cc2c3a00a02ea..241a6c64dd868 100644 --- a/mlir/lib/Target/LLVM/CMakeLists.txt +++ b/mlir/lib/Target/LLVM/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(MLIRTargetLLVM ModuleToObject.cpp + Offload.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVM @@ -16,6 +17,7 @@ add_mlir_library(MLIRTargetLLVM Passes Support Target + FrontendOffloading LINK_LIBS PUBLIC MLIRExecutionEngineUtils MLIRTargetLLVMIRExport diff --git a/mlir/lib/Target/LLVM/Offload.cpp b/mlir/lib/Target/LLVM/Offload.cpp new file mode 100644 index 0000000000000..81ba12403bfb9 --- /dev/null +++ b/mlir/lib/Target/LLVM/Offload.cpp @@ -0,0 +1,111 @@ +//===- Offload.cpp - LLVM Target Offload ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines LLVM target offload utility classes. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVM/Offload.h" +#include "llvm/Frontend/Offloading/Utility.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" + +using namespace mlir; +using namespace mlir::LLVM; + +std::string OffloadHandler::getBeginSymbol(StringRef suffix) { + return ("__begin_offload_" + suffix).str(); +} + +std::string OffloadHandler::getEndSymbol(StringRef suffix) { + return ("__end_offload_" + suffix).str(); +} + +namespace { +/// Returns the type of the entry array. +llvm::ArrayType *getEntryArrayType(llvm::Module &module, size_t numElems) { + return llvm::ArrayType::get(llvm::offloading::getEntryTy(module), numElems); +} + +/// Creates the initializer of the entry array. +llvm::Constant *getEntryArrayBegin(llvm::Module &module, + ArrayRef entries) { + // If there are no entries return a constant zero initializer. + llvm::ArrayType *arrayTy = getEntryArrayType(module, entries.size()); + return entries.empty() ? llvm::ConstantAggregateZero::get(arrayTy) + : llvm::ConstantArray::get(arrayTy, entries); +} + +/// Computes the end position of the entry array. +llvm::Constant *getEntryArrayEnd(llvm::Module &module, + llvm::GlobalVariable *begin, size_t numElems) { + llvm::Type *intTy = module.getDataLayout().getIntPtrType(module.getContext()); + return llvm::ConstantExpr::getGetElementPtr( + llvm::offloading::getEntryTy(module), begin, + ArrayRef({llvm::ConstantInt::get(intTy, numElems)}), + true); +} +} // namespace + +OffloadHandler::OffloadEntryArray +OffloadHandler::getEntryArray(StringRef suffix) { + llvm::GlobalVariable *beginGV = + module.getGlobalVariable(getBeginSymbol(suffix), true); + llvm::GlobalVariable *endGV = + module.getGlobalVariable(getEndSymbol(suffix), true); + return {beginGV, endGV}; +} + +OffloadHandler::OffloadEntryArray +OffloadHandler::emitEmptyEntryArray(StringRef suffix) { + llvm::ArrayType *arrayTy = getEntryArrayType(module, 0); + auto *beginGV = new llvm::GlobalVariable( + module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage, + getEntryArrayBegin(module, {}), getBeginSymbol(suffix)); + auto *endGV = new llvm::GlobalVariable( + module, llvm::PointerType::get(module.getContext(), 0), + /*isConstant=*/true, llvm::GlobalValue::InternalLinkage, + getEntryArrayEnd(module, beginGV, 0), getEndSymbol(suffix)); + return {beginGV, endGV}; +} + +LogicalResult OffloadHandler::insertOffloadEntry(StringRef suffix, + llvm::Constant *entry) { + // Get the begin and end symbols to the entry array. + std::string beginSymId = getBeginSymbol(suffix); + llvm::GlobalVariable *beginGV = module.getGlobalVariable(beginSymId, true); + llvm::GlobalVariable *endGV = + module.getGlobalVariable(getEndSymbol(suffix), true); + // Fail if the symbols are missing. + if (!beginGV || !endGV) + return failure(); + // Create the entry initializer. + assert(beginGV->getInitializer() && "entry array initializer is missing."); + // Add existing entries into the new entry array. + SmallVector entries; + if (auto beginInit = dyn_cast_or_null( + beginGV->getInitializer())) { + for (unsigned i = 0; i < beginInit->getNumOperands(); ++i) + entries.push_back(beginInit->getOperand(i)); + } + // Add the new entry. + entries.push_back(entry); + // Create a global holding the new updated set of entries. + auto *arrayTy = llvm::ArrayType::get(llvm::offloading::getEntryTy(module), + entries.size()); + auto *entryArr = new llvm::GlobalVariable( + module, arrayTy, /*isConstant=*/true, llvm::GlobalValue::InternalLinkage, + getEntryArrayBegin(module, entries), beginSymId, endGV); + // Replace the old entry array variable withe new one. + beginGV->replaceAllUsesWith(entryArr); + beginGV->eraseFromParent(); + entryArr->setName(beginSymId); + // Update the end symbol. + endGV->setInitializer(getEntryArrayEnd(module, entryArr, entries.size())); + return success(); +} diff --git a/mlir/unittests/Target/LLVM/CMakeLists.txt b/mlir/unittests/Target/LLVM/CMakeLists.txt index 6d612548a94c0..d04f38ddddfac 100644 --- a/mlir/unittests/Target/LLVM/CMakeLists.txt +++ b/mlir/unittests/Target/LLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_unittest(MLIRTargetLLVMTests + Offload.cpp SerializeNVVMTarget.cpp SerializeROCDLTarget.cpp SerializeToLLVMBitcode.cpp diff --git a/mlir/unittests/Target/LLVM/Offload.cpp b/mlir/unittests/Target/LLVM/Offload.cpp new file mode 100644 index 0000000000000..375edc2e9614d --- /dev/null +++ b/mlir/unittests/Target/LLVM/Offload.cpp @@ -0,0 +1,49 @@ +//===- Offload.cpp ----------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVM/Offload.h" +#include "llvm/Frontend/Offloading/Utility.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" + +#include "gmock/gmock.h" + +using namespace llvm; + +TEST(MLIRTarget, OffloadAPI) { + using OffloadEntryArray = mlir::LLVM::OffloadHandler::OffloadEntryArray; + LLVMContext llvmContext; + Module llvmModule("offload", llvmContext); + mlir::LLVM::OffloadHandler handler(llvmModule); + StringRef suffix = ".mlir"; + // Check there's no entry array with `.mlir` suffix. + OffloadEntryArray entryArray = handler.getEntryArray(suffix); + EXPECT_EQ(entryArray, OffloadEntryArray()); + // Emit the entry array. + handler.emitEmptyEntryArray(suffix); + // Check there's an entry array with `.mlir` suffix. + entryArray = handler.getEntryArray(suffix); + ASSERT_NE(entryArray.first, nullptr); + ASSERT_NE(entryArray.second, nullptr); + // Check the array contains no entries. + auto *zeroInitializer = dyn_cast_or_null( + entryArray.first->getInitializer()); + ASSERT_NE(zeroInitializer, nullptr); + // Insert an empty entries. + auto emptyEntry = + ConstantAggregateZero::get(offloading::getEntryTy(llvmModule)); + ASSERT_TRUE(succeeded(handler.insertOffloadEntry(suffix, emptyEntry))); + // Check there's an entry in the entry array with `.mlir` suffix. + entryArray = handler.getEntryArray(suffix); + ASSERT_NE(entryArray.first, nullptr); + Constant *arrayInitializer = entryArray.first->getInitializer(); + ASSERT_NE(arrayInitializer, nullptr); + auto *arrayTy = dyn_cast_or_null(arrayInitializer->getType()); + ASSERT_NE(arrayTy, nullptr); + EXPECT_EQ(arrayTy->getNumElements(), 1u); +}