Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tools] Introduce tblgen-to-irdl tool #66865

Merged
merged 3 commits into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ set(MLIR_TEST_DEPENDS
mlir-tblgen
mlir-translate
tblgen-lsp-server
tblgen-to-irdl
)

# The native target may not be enabled, in this case we won't
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/tblgen-to-irdl/CMathDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=cmath | FileCheck %s

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"

// CHECK-LABEL: irdl.dialect @cmath {
def CMath_Dialect : Dialect {
let name = "cmath";
}

class CMath_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<CMath_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

class CMath_Op<string mnemonic, list<Trait> traits = []>
: Op<CMath_Dialect, mnemonic, traits>;

def f32Orf64Type : Or<[CPred<"::llvm::isa<::mlir::F32>">,
CPred<"::llvm::isa<::mlir::F64>">]>;

def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
let parameters = (ins f32Orf64Type:$elementType);
}

// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
def CMath_IdentityOp : CMath_Op<"identity"> {
let results = (outs CMath_ComplexType:$out);
}

// CHECK: irdl.operation @mul {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: irdl.operands(%0, %1)
// CHECK-NEXT: irdl.results(%2)
// CHECK-NEXT: }
def CMath_MulOp : CMath_Op<"mul"> {
let arguments = (ins CMath_ComplexType:$in1, CMath_ComplexType:$in2);
let results = (outs CMath_ComplexType:$out);
}

// CHECK: irdl.operation @norm {
// CHECK-NEXT: %0 = irdl.c_pred "(true)"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: irdl.operands(%0)
// CHECK-NEXT: irdl.results(%1)
// CHECK-NEXT: }
def CMath_NormOp : CMath_Op<"norm"> {
let arguments = (ins AnyType:$in);
let results = (outs CMath_ComplexType:$out);
}
1 change: 1 addition & 0 deletions mlir/tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_subdirectory(mlir-spirv-cpu-runner)
add_subdirectory(mlir-translate)
add_subdirectory(mlir-vulkan-runner)
add_subdirectory(tblgen-lsp-server)
add_subdirectory(tblgen-to-irdl)

# mlir-cpu-runner requires ExecutionEngine.
if(MLIR_ENABLE_EXECUTION_ENGINE)
Expand Down
20 changes: 20 additions & 0 deletions mlir/tools/tblgen-to-irdl/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
set(LLVM_LINK_COMPONENTS
TableGen
)

add_tablegen(tblgen-to-irdl MLIR
DESTINATION "${MLIR_TOOLS_INSTALL_DIR}"
EXPORT MLIR
tblgen-to-irdl.cpp
OpDefinitionsGen.cpp
)

target_link_libraries(tblgen-to-irdl
PRIVATE
MLIRIR
MLIRIRDL
MLIRTblgenLib
MLIRSupport
)

mlir_check_all_link_libraries(tblgen-to-irdl)
155 changes: 155 additions & 0 deletions mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate IRDL
// definitions for ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/GenNameParser.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

using namespace llvm;
using namespace mlir;
using tblgen::NamedTypeConstraint;

static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect");
llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);

irdl::CPredOp createConstraint(OpBuilder &builder,
NamedTypeConstraint namedConstraint) {
MLIRContext *ctx = builder.getContext();
// Build the constraint as a string.
std::string constraint =
namedConstraint.constraint.getPredicate().getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
return op;
}

/// Returns the name of the operation without the dialect prefix.
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
return opName;
}

/// Extract an operation to IRDL.
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
tblgen::Operator &tblgenOp) {
MLIRContext *ctx = builder.getContext();
StringRef opName = getOperatorName(tblgenOp);

irdl::OperationOp op = builder.create<irdl::OperationOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, opName));

// Add the block in the region.
Block &opBlock = op.getBody().emplaceBlock();
OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock);

auto getValues = [&](tblgen::Operator::const_value_range namedCons) {
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
auto operand = createConstraint(consBuilder, namedCons);
operands.push_back(operand);

irdl::VariadicityAttr var;
if (namedCons.isOptional())
var = consBuilder.getAttr<irdl::VariadicityAttr>(
irdl::Variadicity::optional);
else if (namedCons.isVariadic())
var = consBuilder.getAttr<irdl::VariadicityAttr>(
irdl::Variadicity::variadic);
else
var = consBuilder.getAttr<irdl::VariadicityAttr>(
irdl::Variadicity::single);

variadicity.push_back(var);
}
return std::make_tuple(operands, variadicity);
};

auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());

// Create the operands and results operations.
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
operandVariadicity);
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);

return op;
}

static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
StringAttr::get(ctx, selectedDialect));
}

static std::vector<llvm::Record *>
getOpDefinitions(const RecordKeeper &recordKeeper) {
if (!recordKeeper.getClass("Op"))
return {};
return recordKeeper.getAllDerivedDefinitions("Op");
}

static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
// Initialize.
MLIRContext ctx;
ctx.getOrLoadDialect<irdl::IRDLDialect>();
OpBuilder builder(&ctx);

// Create a module op and set it as the insertion point.
ModuleOp module = builder.create<ModuleOp>(UnknownLoc::get(&ctx));
builder = builder.atBlockBegin(module.getBody());
// Create the dialect and insert it.
irdl::DialectOp dialect = createIRDLDialect(builder);
// Set insertion point to start of DialectOp.
builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());

std::vector<Record *> defs = getOpDefinitions(recordKeeper);
for (auto *def : defs) {
tblgen::Operator tblgenOp(def);
if (tblgenOp.getDialectName() != selectedDialect)
continue;

createIRDLOperation(builder, tblgenOp);
}

// Print the module.
module.print(os);

return false;
}

static mlir::GenRegistration
genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitDialectIRDLDefs(records, os);
});
27 changes: 27 additions & 0 deletions mlir/tools/tblgen-to-irdl/tblgen-to-irdl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the main function for MLIR's TableGen IRDL backend.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/Tools/mlir-tblgen/MlirTblgenMain.h"
#include "llvm/TableGen/Record.h"

using namespace llvm;
using namespace mlir;

// Generator that prints records.
GenRegistration printRecords("print-records", "Print all records to stdout",
[](const RecordKeeper &records, raw_ostream &os) {
os << records;
return false;
});

int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); }
Loading