diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td index 43ca5337067ca..3d73d00ecfdd7 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td @@ -47,8 +47,9 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> { meant to be used for passing additional options that are not in the attribute. }], "::mlir::Attribute", "createObject", - (ins "const ::llvm::SmallVector&":$object, - "const ::mlir::gpu::TargetOptions&":$options)> + (ins "::mlir::Operation *":$module, + "const ::llvm::SmallVector &":$object, + "const ::mlir::gpu::TargetOptions &":$options)> ]; } diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp index adae3bef763ff..86a3b4780e88c 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp @@ -99,7 +99,8 @@ LogicalResult moduleSerializer(GPUModuleOp op, return failure(); } - Attribute object = target.createObject(*serializedModule, targetOptions); + Attribute object = + target.createObject(op, *serializedModule, targetOptions); if (!object) { op.emitError("An error happened while creating the object."); return failure(); diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index e608d26e8d2ec..a75b7f92ed8dc 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -49,7 +49,7 @@ class NVVMTargetAttrImpl serializeToObject(Attribute attribute, Operation *module, const gpu::TargetOptions &options) const; - Attribute createObject(Attribute attribute, + Attribute createObject(Attribute attribute, Operation *module, const SmallVector &object, const gpu::TargetOptions &options) const; }; @@ -591,7 +591,7 @@ NVVMTargetAttrImpl::serializeToObject(Attribute attribute, Operation *module, } Attribute -NVVMTargetAttrImpl::createObject(Attribute attribute, +NVVMTargetAttrImpl::createObject(Attribute attribute, Operation *module, const SmallVector &object, const gpu::TargetOptions &options) const { auto target = cast(attribute); diff --git a/mlir/lib/Target/LLVM/ROCDL/Target.cpp b/mlir/lib/Target/LLVM/ROCDL/Target.cpp index 4d23f987eb05e..e32a0c7e14e85 100644 --- a/mlir/lib/Target/LLVM/ROCDL/Target.cpp +++ b/mlir/lib/Target/LLVM/ROCDL/Target.cpp @@ -59,7 +59,7 @@ class ROCDLTargetAttrImpl serializeToObject(Attribute attribute, Operation *module, const gpu::TargetOptions &options) const; - Attribute createObject(Attribute attribute, + Attribute createObject(Attribute attribute, Operation *module, const SmallVector &object, const gpu::TargetOptions &options) const; }; @@ -500,7 +500,7 @@ std::optional> ROCDLTargetAttrImpl::serializeToObject( } Attribute -ROCDLTargetAttrImpl::createObject(Attribute attribute, +ROCDLTargetAttrImpl::createObject(Attribute attribute, Operation *module, const SmallVector &object, const gpu::TargetOptions &options) const { gpu::CompilationTarget format = options.getCompilationTarget(); diff --git a/mlir/lib/Target/SPIRV/Target.cpp b/mlir/lib/Target/SPIRV/Target.cpp index 4c416abe71cac..d48548bf9709c 100644 --- a/mlir/lib/Target/SPIRV/Target.cpp +++ b/mlir/lib/Target/SPIRV/Target.cpp @@ -34,7 +34,7 @@ class SPIRVTargetAttrImpl serializeToObject(Attribute attribute, Operation *module, const gpu::TargetOptions &options) const; - Attribute createObject(Attribute attribute, + Attribute createObject(Attribute attribute, Operation *module, const SmallVector &object, const gpu::TargetOptions &options) const; }; @@ -89,7 +89,7 @@ std::optional> SPIRVTargetAttrImpl::serializeToObject( // Prepare Attribute for gpu.binary with serialized kernel object Attribute -SPIRVTargetAttrImpl::createObject(Attribute attribute, +SPIRVTargetAttrImpl::createObject(Attribute attribute, Operation *module, const SmallVector &object, const gpu::TargetOptions &options) const { gpu::CompilationTarget format = options.getCompilationTarget(); diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp index 96828828fe3f4..37dbfe6203687 100644 --- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp +++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" @@ -30,24 +32,46 @@ using namespace mlir; #define SKIP_WITHOUT_NATIVE(x) x #endif +namespace { +// Dummy interface for testing. +class TargetAttrImpl + : public gpu::TargetAttrInterface::FallbackModel { +public: + std::optional> + serializeToObject(Attribute attribute, Operation *module, + const gpu::TargetOptions &options) const; + + Attribute createObject(Attribute attribute, Operation *module, + const SmallVector &object, + const gpu::TargetOptions &options) const; +}; +} // namespace + class MLIRTargetLLVM : public ::testing::Test { protected: void SetUp() override { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) { + IntegerAttr::attachInterface(*ctx); + }); + registerBuiltinDialectTranslation(registry); + registerLLVMDialectTranslation(registry); + registry.insert(); } -}; -TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) { + // Dialect registry. + DialectRegistry registry; + + // MLIR module used for the tests. std::string moduleStr = R"mlir( llvm.func @foo(%arg0 : i32) { llvm.return } )mlir"; +}; - DialectRegistry registry; - registerBuiltinDialectTranslation(registry); - registerLLVMDialectTranslation(registry); +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) { MLIRContext context(registry); OwningOpRef module = @@ -74,3 +98,52 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(SerializeToLLVMBitcode)) { // Check that it has a function named `foo`. ASSERT_TRUE((*llvmModule)->getFunction("foo") != nullptr); } + +std::optional> +TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module, + const gpu::TargetOptions &options) const { + module->setAttr("serialize_attr", UnitAttr::get(module->getContext())); + std::string targetTriple = llvm::sys::getProcessTriple(); + LLVM::ModuleToObject serializer(*module, targetTriple, "", ""); + return serializer.run(); +} + +Attribute +TargetAttrImpl::createObject(Attribute attribute, Operation *module, + const SmallVector &object, + const gpu::TargetOptions &options) const { + return gpu::ObjectAttr::get( + module->getContext(), attribute, gpu::CompilationTarget::Offload, + StringAttr::get(module->getContext(), + StringRef(object.data(), object.size())), + module->getAttrDictionary()); +} + +TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) { + MLIRContext context(registry); + context.loadAllAvailableDialects(); + + OwningOpRef module = + parseSourceString(moduleStr, &context); + ASSERT_TRUE(!!module); + Builder builder(&context); + IntegerAttr target = builder.getI32IntegerAttr(0); + auto targetAttr = dyn_cast(target); + // Check the attribute holds the interface. + ASSERT_TRUE(!!targetAttr); + gpu::TargetOptions opts; + std::optional> serializedBinary = + targetAttr.serializeToObject(*module, opts); + // Check the serialized string. + ASSERT_TRUE(!!serializedBinary); + ASSERT_TRUE(!serializedBinary->empty()); + // Create the object attribute. + auto object = cast( + targetAttr.createObject(*module, *serializedBinary, opts)); + // Check the object has properties. + DictionaryAttr properties = object.getProperties(); + ASSERT_TRUE(!!properties); + // Check that it contains the attribute added to the module in + // `serializeToObject`. + ASSERT_TRUE(properties.contains("serialize_attr")); +}