Skip to content

Commit

Permalink
linalg to llvm conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
vwellsTT committed Dec 10, 2024
1 parent cf44b15 commit 8550c24
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 11 deletions.
21 changes: 19 additions & 2 deletions include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

namespace mlir::tt::ttir {

#ifdef TTMLIR_ENABLE_STABLEHLO
// Options for the TTIR to TTNN backend pipeline.
//
struct StableHLOToTTIRPipelineOptions
Expand All @@ -31,12 +32,28 @@ struct StableHLOToTTIRPipelineOptions
// that the TTIR inliner pass may inline the ops.
llvm::cl::init(true)};
};
#endif

struct LinalgToLLVMPipelineOptions
: public PassPipelineOptions<LinalgToLLVMPipelineOptions> {
// TODO: we might want some more options to say lower through affine loops
// instead of scf directly, etc. which could be new options
Option<bool> cleanupOutputEnabled{
*this, "enable-remove-dead-values",
llvm::cl::desc(
"Enable final cleanup passes (canonicalize, SCC, CSE, SymbolDCE)"),
llvm::cl::init(true)};
};

#ifdef TTMLIR_ENABLE_STABLEHLO
void createStableHLOToTTIRPipeline(
OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options);
#endif

void createLinalgToLLVMPipeline(OpPassManager &pm,
const LinalgToLLVMPipelineOptions &options);

/// Registers all pipelines for the TTIR dialect. Currently,
/// this includes only the "stablehlo-to-ttir-pipeline".
/// Registers all pipelines for the TTIR dialect.
void registerTTIRPipelines();
} // namespace mlir::tt::ttir

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Pipelines/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRPipelines

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM

LINK_LIBS PUBLIC
MLIRTTIRDialect
Expand Down
52 changes: 52 additions & 0 deletions lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@

#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"

#include "ttmlir/Conversion/Passes.h"
Expand Down Expand Up @@ -34,6 +46,43 @@ void createStableHLOToTTIRPipeline(
}
#endif

void createLinalgToLLVMPipeline(OpPassManager &manager,
const LinalgToLLVMPipelineOptions &options) {
manager.addPass(mlir::createCanonicalizerPass());

manager.addPass(mlir::createConvertElementwiseToLinalgPass());
manager.addPass(mlir::createConvertTensorToLinalgPass());

// One-shot bufferize, from
// https://mlir.llvm.org/docs/Bufferization/#ownership-based-buffer-deallocation
mlir::bufferization::OneShotBufferizationOptions bufferizationOptions;
bufferizationOptions.bufferizeFunctionBoundaries = true;
manager.addPass(
mlir::bufferization::createOneShotBufferizePass(bufferizationOptions));
mlir::bufferization::BufferDeallocationPipelineOptions deallocationOptions;
mlir::bufferization::buildBufferDeallocationPipeline(manager,
deallocationOptions);

manager.addPass(mlir::createConvertLinalgToLoopsPass());

// Needed to lower memref.subview
manager.addPass(mlir::memref::createExpandStridedMetadataPass());

manager.addPass(mlir::createConvertSCFToCFPass());
manager.addPass(mlir::createConvertControlFlowToLLVMPass());
manager.addPass(mlir::createArithToLLVMConversionPass());
manager.addPass(mlir::createConvertFuncToLLVMPass());
manager.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
manager.addPass(mlir::createReconcileUnrealizedCastsPass());

// Cleanup
if (options.cleanupOutputEnabled) {
manager.addPass(mlir::createCanonicalizerPass());
manager.addPass(mlir::createSCCPPass());
manager.addPass(mlir::createCSEPass());
manager.addPass(mlir::createSymbolDCEPass());
}
}
//===----------------------------------------------------------------------===//
// Pipeline registration.
//===----------------------------------------------------------------------===//
Expand All @@ -45,5 +94,8 @@ void registerTTIRPipelines() {
"Pipeline lowering stablehlo to ttir dialect.",
mlir::tt::ttir::createStableHLOToTTIRPipeline);
#endif
mlir::PassPipelineRegistration<LinalgToLLVMPipelineOptions>(
"linalg-to-llvm-pipeline", "Pipeline lowering linalg to llvm dialect.",
mlir::tt::ttir::createLinalgToLLVMPipeline);
}
} // namespace mlir::tt::ttir
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRTTNNDialect
TTMLIRTTNNUtils
MLIRSCFToEmitC
MLIRLinalgDialect
MLIRBufferizationDialect
MLIRMLProgramDialect
TTNNOpModelLib
)
39 changes: 30 additions & 9 deletions lib/RegisterAll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,44 @@
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"

#ifdef TTMLIR_ENABLE_STABLEHLO
#include "stablehlo/dialect/Register.h"
#endif

void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
registry
.insert<mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect,
mlir::tt::ttnn::TTNNDialect, mlir::tt::ttmetal::TTMetalDialect,
mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect,
mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::emitc::EmitCDialect>();
registry.insert<
mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect,
mlir::tt::ttnn::TTNNDialect, mlir::tt::ttmetal::TTMetalDialect,
mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect,
mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect,
mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect,
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
mlir::tosa::TosaDialect, mlir::vector::VectorDialect,
mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect>();
#if TTMLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerAllDialects(registry);
#endif
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
}

void mlir::tt::registerAllExtensions(mlir::DialectRegistry &registry) {
Expand Down

0 comments on commit 8550c24

Please sign in to comment.