diff --git a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h index f922c2501e..83a28b5157 100644 --- a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h +++ b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h @@ -9,6 +9,7 @@ namespace mlir::tt::ttir { +#ifdef TTMLIR_ENABLE_STABLEHLO // Options for the TTIR to TTNN backend pipeline. // struct StableHLOToTTIRPipelineOptions @@ -31,12 +32,28 @@ struct StableHLOToTTIRPipelineOptions // that the TTIR inliner pass may inline the ops. llvm::cl::init(true)}; }; +#endif +struct LinalgToLLVMPipelineOptions + : public PassPipelineOptions { + // TODO: we might want some more options to say lower through affine loops + // instead of scf directly, etc. which could be new options + Option cleanupOutputEnabled{ + *this, "enable-remove-dead-values", + llvm::cl::desc( + "Enable final cleanup passes (canonicalize, SCC, CSE, SymbolDCE)"), + llvm::cl::init(true)}; +}; + +#ifdef TTMLIR_ENABLE_STABLEHLO void createStableHLOToTTIRPipeline( OpPassManager &pm, const StableHLOToTTIRPipelineOptions &options); +#endif + +void createLinalgToLLVMPipeline(OpPassManager &pm, + const LinalgToLLVMPipelineOptions &options); -/// Registers all pipelines for the TTIR dialect. Currently, -/// this includes only the "stablehlo-to-ttir-pipeline". +/// Registers all pipelines for the TTIR dialect. void registerTTIRPipelines(); } // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTIR/Pipelines/CMakeLists.txt b/lib/Dialect/TTIR/Pipelines/CMakeLists.txt index 3296c3f11c..9c7f7bb777 100644 --- a/lib/Dialect/TTIR/Pipelines/CMakeLists.txt +++ b/lib/Dialect/TTIR/Pipelines/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRPipelines ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM LINK_LIBS PUBLIC MLIRTTIRDialect diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp index a092a36e1c..baae6c28fc 100644 --- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp +++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp @@ -4,7 +4,19 @@ #include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" +#include "mlir/Dialect/Bufferization/Pipelines/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" #include "ttmlir/Conversion/Passes.h" @@ -34,6 +46,43 @@ void createStableHLOToTTIRPipeline( } #endif +void createLinalgToLLVMPipeline(OpPassManager &manager, + const LinalgToLLVMPipelineOptions &options) { + manager.addPass(mlir::createCanonicalizerPass()); + + manager.addPass(mlir::createConvertElementwiseToLinalgPass()); + manager.addPass(mlir::createConvertTensorToLinalgPass()); + + // One-shot bufferize, from + // https://mlir.llvm.org/docs/Bufferization/#ownership-based-buffer-deallocation + mlir::bufferization::OneShotBufferizationOptions bufferizationOptions; + bufferizationOptions.bufferizeFunctionBoundaries = true; + manager.addPass( + mlir::bufferization::createOneShotBufferizePass(bufferizationOptions)); + mlir::bufferization::BufferDeallocationPipelineOptions deallocationOptions; + mlir::bufferization::buildBufferDeallocationPipeline(manager, + deallocationOptions); + + manager.addPass(mlir::createConvertLinalgToLoopsPass()); + + // Needed to lower memref.subview + manager.addPass(mlir::memref::createExpandStridedMetadataPass()); + + manager.addPass(mlir::createConvertSCFToCFPass()); + manager.addPass(mlir::createConvertControlFlowToLLVMPass()); + manager.addPass(mlir::createArithToLLVMConversionPass()); + manager.addPass(mlir::createConvertFuncToLLVMPass()); + manager.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + manager.addPass(mlir::createReconcileUnrealizedCastsPass()); + + // Cleanup + if (options.cleanupOutputEnabled) { + manager.addPass(mlir::createCanonicalizerPass()); + manager.addPass(mlir::createSCCPPass()); + manager.addPass(mlir::createCSEPass()); + manager.addPass(mlir::createSymbolDCEPass()); + } +} //===----------------------------------------------------------------------===// // Pipeline registration. //===----------------------------------------------------------------------===// @@ -45,5 +94,8 @@ void registerTTIRPipelines() { "Pipeline lowering stablehlo to ttir dialect.", mlir::tt::ttir::createStableHLOToTTIRPipeline); #endif + mlir::PassPipelineRegistration( + "linalg-to-llvm-pipeline", "Pipeline lowering linalg to llvm dialect.", + mlir::tt::ttir::createLinalgToLLVMPipeline); } } // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTNN/IR/CMakeLists.txt b/lib/Dialect/TTNN/IR/CMakeLists.txt index 2fb004e0f3..90c6811af0 100644 --- a/lib/Dialect/TTNN/IR/CMakeLists.txt +++ b/lib/Dialect/TTNN/IR/CMakeLists.txt @@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRTTNNDialect TTMLIRTTNNUtils MLIRSCFToEmitC MLIRLinalgDialect + MLIRBufferizationDialect MLIRMLProgramDialect TTNNOpModelLib ) diff --git a/lib/RegisterAll.cpp b/lib/RegisterAll.cpp index 9ee4c30e16..e170112fa5 100644 --- a/lib/RegisterAll.cpp +++ b/lib/RegisterAll.cpp @@ -19,23 +19,44 @@ #include "ttmlir/Dialect/TTNN/IR/TTNN.h" #include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" #include "ttmlir/Dialect/TTNN/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" + #ifdef TTMLIR_ENABLE_STABLEHLO #include "stablehlo/dialect/Register.h" #endif void mlir::tt::registerAllDialects(mlir::DialectRegistry ®istry) { - registry - .insert(); + registry.insert< + mlir::tt::TTDialect, mlir::tt::ttir::TTIRDialect, + mlir::tt::ttnn::TTNNDialect, mlir::tt::ttmetal::TTMetalDialect, + mlir::tt::ttkernel::TTKernelDialect, mlir::func::FuncDialect, + mlir::arith::ArithDialect, mlir::ml_program::MLProgramDialect, + mlir::tensor::TensorDialect, mlir::linalg::LinalgDialect, + mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect, + mlir::tosa::TosaDialect, mlir::vector::VectorDialect, + mlir::emitc::EmitCDialect, mlir::bufferization::BufferizationDialect>(); #if TTMLIR_ENABLE_STABLEHLO mlir::stablehlo::registerAllDialects(registry); #endif + arith::registerBufferizableOpInterfaceExternalModels(registry); + linalg::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); + bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + vector::registerBufferizableOpInterfaceExternalModels(registry); } void mlir::tt::registerAllExtensions(mlir::DialectRegistry ®istry) {