diff --git a/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 1d7d06bc25c1..3e5a0346ed60 100644 --- a/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -8,9 +8,16 @@ set(LLVM_TARGET_DEFINITIONS NVVMOps.td) mlir_tablegen(NVVMOps.h.inc -gen-op-decls) mlir_tablegen(NVVMOps.cpp.inc -gen-op-defs) add_public_tablegen_target(MLIRNVVMOpsIncGen) +set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) +mlir_tablegen(ROCDLOps.h.inc -gen-op-decls) +mlir_tablegen(ROCDLOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRROCDLOpsIncGen) set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRLLVMConversionsIncGen) set(LLVM_TARGET_DEFINITIONS NVVMOps.td) mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRNVVMConversionsIncGen) +set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) +mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRROCDLConversionsIncGen) diff --git a/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/include/mlir/Dialect/LLVMIR/ROCDLDialect.h new file mode 100644 index 000000000000..a34c11223f3d --- /dev/null +++ b/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -0,0 +1,54 @@ +//===- ROCDLDialect.h - MLIR ROCDL IR dialect -------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the ROCDL dialect in MLIR, containing ROCDL operations +// and ROCDL specific extensions to the LLVM type system. +// +// Unfortunately there does not exists a formal definition of ROCDL IR that be +// pointed to here. However the following links contain more information about +// ROCDL (ROCm-Device-Library) +// +// https://github.com/RadeonOpenCompute/ROCm-Device-Libs/blob/master/doc/OCML.md +// https://github.com/RadeonOpenCompute/ROCm-Device-Libs/blob/master/doc/OCKL.md +// https://llvm.org/docs/AMDGPUUsage.html +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace ROCDL { + +///// Ops ///// +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc" + +class ROCDLDialect : public Dialect { +public: + explicit ROCDLDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "rocdl"; } +}; + +} // namespace ROCDL +} // namespace mlir + +#endif /* MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ */ diff --git a/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/include/mlir/Dialect/LLVMIR/ROCDLOps.td new file mode 100644 index 000000000000..fb078a70bffd --- /dev/null +++ b/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -0,0 +1,102 @@ +//===-- ROCDLOps.td - ROCDL IR dialect op definition file --*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the ROCDL IR operation definition file. +// +//===----------------------------------------------------------------------===// + +#ifdef ROCDLIR_OPS +#else +#define ROCDLIR_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// ROCDL dialect definitions +//===----------------------------------------------------------------------===// + +def ROCDL_Dialect : Dialect { + let name = "rocdl"; + let cppNamespace = "ROCDL"; +} + +//===----------------------------------------------------------------------===// +// ROCDL op definitions +//===----------------------------------------------------------------------===// + +class ROCDL_Op traits = []> : + LLVM_OpBase { +} + +//===----------------------------------------------------------------------===// +// ROCDL special register op definitions +//===----------------------------------------------------------------------===// + +class ROCDL_SpecialRegisterOp traits = []> : + ROCDL_Op, + Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { + string llvmBuilder = "$res = createIntrinsicCall(builder," + # "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) # ");"; + let parser = [{ return parseROCDLOp(parser, result); }]; + let printer = [{ printROCDLOp(p, this->getOperation()); }]; +} + +class ROCDL_DeviceFunctionOp traits = []> : + ROCDL_Op, + Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { + string llvmBuilder = "$res = createDeviceFunctionCall(builder, \"" + # device_function # "\", " # parameter # ");"; + let parser = [{ return parseROCDLOp(parser, result); }]; + let printer = [{ printROCDLOp(p, this->getOperation()); }]; +} + +//===----------------------------------------------------------------------===// +// Thread index and Block index + +def ROCDL_ThreadIdXOp : ROCDL_SpecialRegisterOp<"workitem.id.x">; +def ROCDL_ThreadIdYOp : ROCDL_SpecialRegisterOp<"workitem.id.y">; +def ROCDL_ThreadIdZOp : ROCDL_SpecialRegisterOp<"workitem.id.z">; + +def ROCDL_BlockIdXOp : ROCDL_SpecialRegisterOp<"workgroup.id.x">; +def ROCDL_BlockIdYOp : ROCDL_SpecialRegisterOp<"workgroup.id.y">; +def ROCDL_BlockIdZOp : ROCDL_SpecialRegisterOp<"workgroup.id.z">; + +//===----------------------------------------------------------------------===// +// Thread range and Block range + +def ROCDL_BlockDimXOp : ROCDL_DeviceFunctionOp<"workgroup.dim.x", + "__ockl_get_local_size", 0>; + +def ROCDL_BlockDimYOp : ROCDL_DeviceFunctionOp<"workgroup.dim.y", + "__ockl_get_local_size", 1>; + +def ROCDL_BlockDimZOp : ROCDL_DeviceFunctionOp<"workgroup.dim.z", + "__ockl_get_local_size", 2>; + +def ROCDL_GridDimXOp : ROCDL_DeviceFunctionOp<"grid.dim.x", + "__ockl_get_global_size", 0>; + +def ROCDL_GridDimYOp : ROCDL_DeviceFunctionOp<"grid.dim.y", + "__ockl_get_global_size", 1>; + +def ROCDL_GridDimZOp : ROCDL_DeviceFunctionOp<"grid.dim.z", + "__ockl_get_global_size", 2>; + + +#endif // ROCDLIR_OPS diff --git a/include/mlir/Target/ROCDLIR.h b/include/mlir/Target/ROCDLIR.h new file mode 100644 index 000000000000..9fe151e89d64 --- /dev/null +++ b/include/mlir/Target/ROCDLIR.h @@ -0,0 +1,45 @@ +//===- ROCDLIR.h - MLIR to LLVM + ROCDL IR conversion -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file declares the entry point for the MLIR to LLVM + ROCDL IR +// conversion. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_ROCDLIR_H +#define MLIR_TARGET_ROCDLIR_H + +#include + +// Forward-declare LLVM classses. +namespace llvm { +class Module; +} // namespace llvm + +namespace mlir { +class ModuleOp; + +/// Convert the given MLIR module into ROCDL IR. This conversion requires the +/// registration of the LLVM IR dialect and will extract the LLVM context +/// from the registered LLVM IR dialect. In case of error, report it +/// to the error handler registered with the MLIR context, if any (obtained from +/// the MLIR module), and return `nullptr`. +std::unique_ptr translateModuleToROCDLIR(ModuleOp m); + +} // namespace mlir + +#endif // MLIR_TARGET_ROCDLIR_H diff --git a/lib/Dialect/LLVMIR/CMakeLists.txt b/lib/Dialect/LLVMIR/CMakeLists.txt index 4469e7606d39..40bcb572e56d 100644 --- a/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/lib/Dialect/LLVMIR/CMakeLists.txt @@ -15,3 +15,12 @@ add_llvm_library(MLIRNVVMIR ) add_dependencies(MLIRNVVMIR MLIRNVVMOpsIncGen MLIRNVVMConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport) target_link_libraries(MLIRNVVMIR LLVMAsmParser LLVMCore LLVMSupport) + +add_llvm_library(MLIRROCDLIR + IR/ROCDLDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + ) +add_dependencies(MLIRROCDLIR MLIRROCDLOpsIncGen MLIRROCDLConversionsIncGen LLVMAsmParser LLVMCore LLVMSupport) +target_link_libraries(MLIRROCDLIR LLVMAsmParser LLVMCore LLVMSupport) diff --git a/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp new file mode 100644 index 000000000000..075e01d1dfcb --- /dev/null +++ b/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -0,0 +1,82 @@ +//===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines the types and operation details for the ROCDL IR dialect in +// MLIR, and the LLVM IR dialect. It also registers the dialect. +// +// The ROCDL dialect only contains GPU specific additions on top of the general +// LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace ROCDL { + +//===----------------------------------------------------------------------===// +// Printing/parsing for ROCDL ops +//===----------------------------------------------------------------------===// + +static void printROCDLOp(OpAsmPrinter &p, Operation *op) { + p << op->getName() << " "; + p.printOperands(op->getOperands()); + if (op->getNumResults() > 0) + interleaveComma(op->getResultTypes(), p << " : "); +} + +// ::= `rocdl.XYZ` : type +static ParseResult parseROCDLOp(OpAsmParser &parser, OperationState &result) { + Type type; + return failure(parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColonType(type) || + parser.addTypeToList(type, result.types)); +} + +//===----------------------------------------------------------------------===// +// ROCDLDialect initialization, type parsing, and registration. +//===----------------------------------------------------------------------===// + +// TODO(herhut): This should be the llvm.rocdl dialect once this is supported. +ROCDLDialect::ROCDLDialect(MLIRContext *context) : Dialect("rocdl", context) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc" + >(); + + // Support unknown operations because not all ROCDL operations are registered. + allowUnknownOperations(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc" + +static DialectRegistration rocdlDialect; + +} // namespace ROCDL +} // namespace mlir diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt index 9f49b813336b..111e2f673137 100644 --- a/lib/Target/CMakeLists.txt +++ b/lib/Target/CMakeLists.txt @@ -28,3 +28,17 @@ target_link_libraries(MLIRTargetNVVMIR MLIRNVVMIR MLIRTargetLLVMIRModuleTranslation ) +add_llvm_library(MLIRTargetROCDLIR + LLVMIR/ConvertToROCDLIR.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + DEPENDS + intrinsics_gen + ) +target_link_libraries(MLIRTargetROCDLIR + MLIRGPU + MLIRIR + MLIRROCDLIR + MLIRTargetLLVMIRModuleTranslation + ) diff --git a/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/lib/Target/LLVMIR/ConvertToROCDLIR.cpp new file mode 100644 index 000000000000..2c2d1169a3dd --- /dev/null +++ b/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -0,0 +1,119 @@ +//===- ConvertToROCDLIR.cpp - MLIR to LLVM IR conversion ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a translation between the MLIR LLVM + ROCDL dialects and +// LLVM IR with ROCDL intrinsics and metadata. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/ROCDLIR.h" + +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Translation.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/ToolOutputFile.h" + +#include + +using namespace mlir; + +namespace { +// Create a call to llvm intrisic +static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, + llvm::Intrinsic::ID intrinsic, + ArrayRef args = {}) { + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic); + return builder.CreateCall(fn, args); +} + +// Create a call to ROCm-Device-Library function +// Currently this routine will work only for calling ROCDL functions that +// take a single int32 argument. It is likely that the interface of this +// function will change to make it more generic. +static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder, + StringRef fn_name, int parameter) { + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::FunctionType *function_type = llvm::FunctionType::get( + llvm::Type::getInt32Ty(module->getContext()), // return type. + llvm::Type::getInt32Ty(module->getContext()), // parameter type. + false); // no variadic arguments. + llvm::Function *fn = llvm::dyn_cast( + module->getOrInsertFunction(fn_name, function_type).getCallee()); + llvm::Value *fn_op0 = llvm::ConstantInt::get( + llvm::Type::getInt32Ty(module->getContext()), parameter); + return builder.CreateCall(fn, llvm::ArrayRef(fn_op0)); +} + +class ModuleTranslation : public LLVM::ModuleTranslation { + +public: + explicit ModuleTranslation(ModuleOp module) + : LLVM::ModuleTranslation(module) {} + ~ModuleTranslation() override {} + +protected: + LogicalResult convertOperation(Operation &opInst, + llvm::IRBuilder<> &builder) override { + +#include "mlir/Dialect/LLVMIR/ROCDLConversions.inc" + + return LLVM::ModuleTranslation::convertOperation(opInst, builder); + } +}; +} // namespace + +std::unique_ptr mlir::translateModuleToROCDLIR(ModuleOp m) { + ModuleTranslation translation(m); + + // lower MLIR (with RODL Dialect) to LLVM IR (with ROCDL intrinsics) + auto llvmModule = + LLVM::ModuleTranslation::translateModule(m); + + // foreach GPU kernel + // 1. Insert AMDGPU_KERNEL calling convention. + // 2. Insert amdgpu-flat-workgroup-size(1, 1024) attribute. + for (FuncOp func : m.getOps()) { + if (!func.getAttrOfType(gpu::GPUDialect::getKernelFuncAttrName())) + continue; + + auto *llvmFunc = llvmModule->getFunction(func.getName()); + + llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); + + llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1, 1024"); + } + + return llvmModule; +} + +static TranslateFromMLIRRegistration + registration("mlir-to-rocdlir", + [](ModuleOp module, llvm::raw_ostream &output) { + auto llvmModule = mlir::translateModuleToROCDLIR(module); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }); diff --git a/test/Dialect/LLVMIR/rocdl.mlir b/test/Dialect/LLVMIR/rocdl.mlir new file mode 100644 index 000000000000..5a7178030843 --- /dev/null +++ b/test/Dialect/LLVMIR/rocdl.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s | FileCheck %s + +func @rocdl_special_regs() -> !llvm.i32 { + // CHECK-LABEL: rocdl_special_regs + // CHECK: rocdl.workitem.id.x : !llvm.i32 + %0 = rocdl.workitem.id.x : !llvm.i32 + // CHECK: rocdl.workitem.id.y : !llvm.i32 + %1 = rocdl.workitem.id.y : !llvm.i32 + // CHECK: rocdl.workitem.id.z : !llvm.i32 + %2 = rocdl.workitem.id.z : !llvm.i32 + // CHECK: rocdl.workgroup.id.x : !llvm.i32 + %3 = rocdl.workgroup.id.x : !llvm.i32 + // CHECK: rocdl.workgroup.id.y : !llvm.i32 + %4 = rocdl.workgroup.id.y : !llvm.i32 + // CHECK: rocdl.workgroup.id.z : !llvm.i32 + %5 = rocdl.workgroup.id.z : !llvm.i32 + // CHECK: rocdl.workgroup.dim.x : !llvm.i32 + %6 = rocdl.workgroup.dim.x : !llvm.i32 + // CHECK: rocdl.workgroup.dim.y : !llvm.i32 + %7 = rocdl.workgroup.dim.y : !llvm.i32 + // CHECK: rocdl.workgroup.dim.z : !llvm.i32 + %8 = rocdl.workgroup.dim.z : !llvm.i32 + // CHECK: rocdl.grid.dim.x : !llvm.i32 + %9 = rocdl.grid.dim.x : !llvm.i32 + // CHECK: rocdl.grid.dim.y : !llvm.i32 + %10 = rocdl.grid.dim.y : !llvm.i32 + // CHECK: rocdl.grid.dim.z : !llvm.i32 + %11 = rocdl.grid.dim.z : !llvm.i32 + llvm.return %0 : !llvm.i32 +} diff --git a/test/Target/rocdl.mlir b/test/Target/rocdl.mlir new file mode 100644 index 000000000000..5665b7156e85 --- /dev/null +++ b/test/Target/rocdl.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-translate -mlir-to-rocdlir %s | FileCheck %s + +func @rocdl_special_regs() -> !llvm.i32 { + // CHECK-LABEL: rocdl_special_regs + // CHECK: call i32 @llvm.amdgcn.workitem.id.x() + %1 = rocdl.workitem.id.x : !llvm.i32 + // CHECK: call i32 @llvm.amdgcn.workitem.id.y() + %2 = rocdl.workitem.id.y : !llvm.i32 + // CHECK: call i32 @llvm.amdgcn.workitem.id.z() + %3 = rocdl.workitem.id.z : !llvm.i32 + // CHECK: call i32 @llvm.amdgcn.workgroup.id.x() + %4 = rocdl.workgroup.id.x : !llvm.i32 + // CHECK: call i32 @llvm.amdgcn.workgroup.id.y() + %5 = rocdl.workgroup.id.y : !llvm.i32 + // CHECK: call i32 @llvm.amdgcn.workgroup.id.z() + %6 = rocdl.workgroup.id.z : !llvm.i32 + // CHECK: call i32 @__ockl_get_local_size(i32 0) + %7 = rocdl.workgroup.dim.x : !llvm.i32 + // CHECK: call i32 @__ockl_get_local_size(i32 1) + %8 = rocdl.workgroup.dim.y : !llvm.i32 + // CHECK: call i32 @__ockl_get_local_size(i32 2) + %9 = rocdl.workgroup.dim.z : !llvm.i32 + // CHECK: call i32 @__ockl_get_global_size(i32 0) + %10 = rocdl.grid.dim.x : !llvm.i32 + // CHECK: call i32 @__ockl_get_global_size(i32 1) + %11 = rocdl.grid.dim.y : !llvm.i32 + // CHECK: call i32 @__ockl_get_global_size(i32 2) + %12 = rocdl.grid.dim.z : !llvm.i32 + llvm.return %1 : !llvm.i32 +} + +func @kernel_func() attributes {gpu.kernel} { + // CHECK-LABEL: amdgpu_kernel void @kernel_func + llvm.return +} diff --git a/tools/mlir-opt/CMakeLists.txt b/tools/mlir-opt/CMakeLists.txt index 2f57eb02f18d..196edd83b050 100644 --- a/tools/mlir-opt/CMakeLists.txt +++ b/tools/mlir-opt/CMakeLists.txt @@ -34,6 +34,7 @@ set(LIBS MLIRPass MLIRQuantizerTransforms MLIRQuantOps + MLIRROCDLIR MLIRSPIRV MLIRSPIRVConversion MLIRStandardOps diff --git a/tools/mlir-translate/CMakeLists.txt b/tools/mlir-translate/CMakeLists.txt index 50df9de8cae7..8f03de449579 100644 --- a/tools/mlir-translate/CMakeLists.txt +++ b/tools/mlir-translate/CMakeLists.txt @@ -9,6 +9,7 @@ set(LIBS MLIRStandardOps MLIRTargetLLVMIR MLIRTargetNVVMIR + MLIRTargetROCDLIR MLIRTransforms MLIRTranslation MLIRSupport