Skip to content

Commit

Permalink
[CodeGen][Common] Switch to new pass generation tablegen definitions. (
Browse files Browse the repository at this point in the history
…#18166)

This is not a NFC change because there are non-trivial changes in three
passes. The rest are just switching to the new tablegen definition
style.

The revision applies few cleanups during refactoring, which includes:

- Make all the passes have `iree-codegen-*` prefix.
- Make filename match the pass name.
- Switch namespaces to the single-line syntax.

The non-trivial changes happen in `IREEComprehensiveBufferizePass`,
`TileAndDistributeToWorkgroupsPass`, and
`TransformDialectInterpreterPass`.

In the `IREEComprehensiveBufferizePass`, the default allocationFn and
memcpyFn values are no longer `std::nullopt`. They are the default
functions that implemented in the file. It matches the behavior. The
only difference is that it does not rely on users to pass `std::nullopt`
to trigger the "default" behavior. Instead, it initialize the default
value of the class members be them. This saves one level of nesting
logic.

In the `TileAndDistributeToWorkgroupsPass`, it switches to use
`clEnumValN`. Without the change, the values in `.td` are magic numbers.
Now we explicitly tie them to linalg::DistributionMethod.

In `TransformDialectInterpreterPass`, it no longer uses
`--iree-codegen-transform-dialect-library` to get the library path.
Instead, users should use the predefined option (i.e.,
`library-file-name`) to specify the path. The old behavior was a hack
because the pass itself should not access external variable if possible.
The `--iree-codegen-transform-dialect-library` is still valid, but the
lit test does not use it anymore. Because it is very tricky to teach
tablegen to generate such behavior.

---------

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Aug 12, 2024
1 parent 49198a9 commit 3483893
Show file tree
Hide file tree
Showing 72 changed files with 535 additions and 959 deletions.
20 changes: 9 additions & 11 deletions compiler/src/iree/compiler/Codegen/Common/AddFastMathFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
Expand All @@ -13,8 +12,10 @@

#define DEBUG_TYPE "iree-codegen-add-fast-math-flags"

using namespace mlir;
using namespace mlir::iree_compiler;
namespace mlir::iree_compiler {

#define GEN_PASS_DEF_ADDFASTMATHFLAGSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

/// Add `contract` FMF to operations that support it.
static void addContractFMF(Operation *op) {
Expand All @@ -29,19 +30,16 @@ namespace {
/// Add the corresponding fast-math flags to operations given a floating-point
/// optimization mode.
// TODO: For now we only allow default flags, such as arithmetic reassociation.
struct AddFastMathFlagsPass
: public AddFastMathFlagsBase<AddFastMathFlagsPass> {
struct AddFastMathFlagsPass final
: impl::AddFastMathFlagsPassBase<AddFastMathFlagsPass> {
public:
using AddFastMathFlagsBase::AddFastMathFlagsBase;
using impl::AddFastMathFlagsPassBase<
AddFastMathFlagsPass>::AddFastMathFlagsPassBase;

void runOnOperation() override {
getOperation()->walk([](Operation *op) { addContractFMF(op); });
}
};

} // namespace

std::unique_ptr<OperationPass<LLVM::LLVMFuncOp>>
mlir::iree_compiler::createAddFastMathFlagsPass() {
return std::make_unique<AddFastMathFlagsPass>();
}
} // namespace mlir::iree_compiler
5 changes: 2 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ iree_gentbl_cc_library(
iree_compiler_cc_library(
name = "PassHeaders",
hdrs = [
"PassDetail.h",
"Passes.h",
"Passes.h.inc",
],
Expand Down Expand Up @@ -133,9 +132,10 @@ iree_compiler_cc_library(
"PropagateReshapesByExpansion.cpp",
"ReconcileTranslationInfo.cpp",
"RematerializeParallelOps.cpp",
"RemoveTrivialLoops.cpp",
"RemoveSingleIterationLoop.cpp",
"ReplaceSlowMinMaxOps.cpp",
"SplitFullPartialTransferPass.cpp",
"TensorToVectorVectorizePad.cpp",
"TestExecutablePreprocessing.cpp",
"TestPartitionableLoopsInterface.cpp",
"TileAndDistributeToWorkgroupsPass.cpp",
Expand All @@ -144,7 +144,6 @@ iree_compiler_cc_library(
"TypePropagationPass.cpp",
"UserConfig.cpp",
"VectorizeMemrefCopy.cpp",
"VectorizePad.cpp",
],
hdrs = [
"BufferizationAnalysis.h",
Expand Down
13 changes: 5 additions & 8 deletions compiler/src/iree/compiler/Codegen/Common/BubbleUpOrdinalOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_BUBBLEUPORDINALOPSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Replace the following sequence
Expand Down Expand Up @@ -62,8 +64,8 @@ struct BubbleUpAcrossCastOp
}
};

struct BubbleUpOrdinalOpsPass
: public BubbleUpOrdinalOpsBase<BubbleUpOrdinalOpsPass> {
struct BubbleUpOrdinalOpsPass final
: impl::BubbleUpOrdinalOpsPassBase<BubbleUpOrdinalOpsPass> {
void runOnOperation() override;
};
} // namespace
Expand All @@ -77,9 +79,4 @@ void BubbleUpOrdinalOpsPass::runOnOperation() {
return signalPassFailure();
}
}

std::unique_ptr<Pass> createBubbleUpOrdinalOpsPass() {
return std::make_unique<BubbleUpOrdinalOpsPass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/PassUtils.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
Expand All @@ -30,16 +29,18 @@

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_BUFFERIZECOPYONLYDISPATCHESPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Pass to bufferize early copy-only dispatches. This allows backends
/// to use the `linalg.generic` operation generated for lowering the dispatch.
struct BufferizeCopyOnlyDispatchesPass
: public BufferizeCopyOnlyDispatchesBase<BufferizeCopyOnlyDispatchesPass> {
BufferizeCopyOnlyDispatchesPass() = default;
BufferizeCopyOnlyDispatchesPass(const BufferizeCopyOnlyDispatchesPass &pass) {
}

struct BufferizeCopyOnlyDispatchesPass final
: impl::BufferizeCopyOnlyDispatchesPassBase<
BufferizeCopyOnlyDispatchesPass> {
using impl::BufferizeCopyOnlyDispatchesPassBase<
BufferizeCopyOnlyDispatchesPass>::BufferizeCopyOnlyDispatchesPassBase;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, bufferization::BufferizationDialect,
IREE::Flow::FlowDialect, linalg::LinalgDialect,
Expand Down Expand Up @@ -109,9 +110,4 @@ void BufferizeCopyOnlyDispatchesPass::runOnOperation() {
}
}

std::unique_ptr<InterfacePass<FunctionOpInterface>>
createBufferizeCopyOnlyDispatchesPass() {
return std::make_unique<BufferizeCopyOnlyDispatchesPass>();
}

} // namespace mlir::iree_compiler
5 changes: 2 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ iree_cc_library(
NAME
PassHeaders
HDRS
"PassDetail.h"
"Passes.h"
"Passes.h.inc"
DEPS
Expand Down Expand Up @@ -124,9 +123,10 @@ iree_cc_library(
"PropagateReshapesByExpansion.cpp"
"ReconcileTranslationInfo.cpp"
"RematerializeParallelOps.cpp"
"RemoveTrivialLoops.cpp"
"RemoveSingleIterationLoop.cpp"
"ReplaceSlowMinMaxOps.cpp"
"SplitFullPartialTransferPass.cpp"
"TensorToVectorVectorizePad.cpp"
"TestExecutablePreprocessing.cpp"
"TestPartitionableLoopsInterface.cpp"
"TileAndDistributeToWorkgroupsPass.cpp"
Expand All @@ -135,7 +135,6 @@ iree_cc_library(
"TypePropagationPass.cpp"
"UserConfig.cpp"
"VectorizeMemrefCopy.cpp"
"VectorizePad.cpp"
DEPS
::PassHeaders
::PassesIncGen
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
Expand All @@ -26,11 +25,14 @@

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CLEANUPBUFFERALLOCVIEWPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Runs canonicalization patterns on interface load/store ops.
struct CleanupBufferAllocViewPass
: public CleanupBufferAllocViewBase<CleanupBufferAllocViewPass> {
struct CleanupBufferAllocViewPass final
: impl::CleanupBufferAllocViewPassBase<CleanupBufferAllocViewPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateReshapeToInterfaceTensorPatterns(patterns);
Expand All @@ -43,10 +45,4 @@ struct CleanupBufferAllocViewPass
};

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createCleanupBufferAllocViewPass() {
return std::make_unique<CleanupBufferAllocViewPass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "llvm/Support/Debug.h"
Expand All @@ -24,6 +23,9 @@

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONCRETIZEPADRESULTSHAPEPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

/// Gets the given `attrOrValue` as an index value by creating constant ops
/// for attributes.
static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
Expand Down Expand Up @@ -126,12 +128,9 @@ struct ConcretizePadResultShape final : public OpRewritePattern<tensor::PadOp> {
};

class ConcretizePadResultShapePass final
: public ConcretizePadResultShapeBase<ConcretizePadResultShapePass> {
: public impl::ConcretizePadResultShapePassBase<
ConcretizePadResultShapePass> {
public:
ConcretizePadResultShapePass() = default;
ConcretizePadResultShapePass(const ConcretizePadResultShapePass &pass) =
default;

void runOnOperation() override {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
Expand All @@ -155,11 +154,6 @@ class ConcretizePadResultShapePass final

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConcretizePadResultShapePass() {
return std::make_unique<ConcretizePadResultShapePass>();
}

void populateConcretizePadResultShapePatterns(RewritePatternSet &patterns,
ArrayRef<int64_t> numWorkgroups) {
MLIRContext *context = patterns.getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <memory>
#include <utility>

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
Expand All @@ -39,6 +38,9 @@

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTBF16ARITHTOF32PASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

Value convertRankedFloat(OpBuilder &builder, Type type, ValueRange inputs,
Expand Down Expand Up @@ -235,9 +237,10 @@ struct PromoteBF16ToF32Converter
}
};

struct ConvertBf16ArithToF32Pass
: public ConvertBf16ArithToF32Base<ConvertBf16ArithToF32Pass> {
using ConvertBf16ArithToF32Base::ConvertBf16ArithToF32Base;
struct ConvertBf16ArithToF32Pass final
: impl::ConvertBf16ArithToF32PassBase<ConvertBf16ArithToF32Pass> {
using impl::ConvertBf16ArithToF32PassBase<
ConvertBf16ArithToF32Pass>::ConvertBf16ArithToF32PassBase;
void runOnOperation() override {
MLIRContext *context = &this->getContext();
RewritePatternSet patterns(context);
Expand Down Expand Up @@ -312,9 +315,4 @@ struct ConvertBf16ArithToF32Pass
};

} // namespace

std::unique_ptr<OperationPass<>> createConvertBf16ArithToF32Pass() {
return std::make_unique<ConvertBf16ArithToF32Pass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
Expand All @@ -24,6 +23,7 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
Expand All @@ -34,10 +34,13 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-spirv-emulate-bf16"
#define DEBUG_TYPE "iree-codegen-convert-bf16-to-uint16-buffers"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTBF16TOUINT16BUFFERSPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

class Bf16EmulationConverter : public TypeConverter {
Expand Down Expand Up @@ -249,7 +252,7 @@ static void populateIreeBf16EmulationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//

struct ConvertBf16ToUInt16BuffersPass final
: public ConvertBf16ToUInt16BuffersBase<ConvertBf16ToUInt16BuffersPass> {
: impl::ConvertBf16ToUInt16BuffersPassBase<ConvertBf16ToUInt16BuffersPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}
Expand Down Expand Up @@ -312,13 +315,4 @@ struct ConvertBf16ToUInt16BuffersPass final
};

} // namespace

//===----------------------------------------------------------------------===//
// Public interface
//===----------------------------------------------------------------------===//

std::unique_ptr<OperationPass<>> createConvertBf16ToUInt16BuffersPass() {
return std::make_unique<ConvertBf16ToUInt16BuffersPass>();
}

} // namespace mlir::iree_compiler
Loading

0 comments on commit 3483893

Please sign in to comment.