From 5272da502de45a70d0998aef107d2d803a06f948 Mon Sep 17 00:00:00 2001 From: Vincent Wells Date: Wed, 20 Nov 2024 10:27:16 -0600 Subject: [PATCH] Add basic conversion between ttir and linalg --- .../Conversion/TTIRToLinAlg/TTIRToLinAlg.h | 20 +++ lib/Conversion/TTIRToLinAlg/TTIRToLinAlg.cpp | 165 ++++++++++++++++++ .../TTIRToLinAlg/TTIRToLinAlgPass.cpp | 63 +++++++ 3 files changed, 248 insertions(+) create mode 100644 include/ttmlir/Conversion/TTIRToLinAlg/TTIRToLinAlg.h create mode 100644 lib/Conversion/TTIRToLinAlg/TTIRToLinAlg.cpp create mode 100644 lib/Conversion/TTIRToLinAlg/TTIRToLinAlgPass.cpp diff --git a/include/ttmlir/Conversion/TTIRToLinAlg/TTIRToLinAlg.h b/include/ttmlir/Conversion/TTIRToLinAlg/TTIRToLinAlg.h new file mode 100644 index 0000000000..8c4f9b6d34 --- /dev/null +++ b/include/ttmlir/Conversion/TTIRToLinAlg/TTIRToLinAlg.h @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_CONVERSION_TTIRTOLINALG_TTIRTOLINALG_H +#define TTMLIR_CONVERSION_TTIRTOLINALG_TTIRTOLINALG_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::tt { + +void populateTTIRToLinAlgPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter); + +std::unique_ptr> createConvertTTIRToLinAlgPass(); + +} // namespace mlir::tt + +#endif // TTMLIR_CONVERSION_TTIRTOLINALG_TTIRTOLINALG_H diff --git a/lib/Conversion/TTIRToLinAlg/TTIRToLinAlg.cpp b/lib/Conversion/TTIRToLinAlg/TTIRToLinAlg.cpp new file mode 100644 index 0000000000..e490ec9468 --- /dev/null +++ b/lib/Conversion/TTIRToLinAlg/TTIRToLinAlg.cpp @@ -0,0 +1,165 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Conversion/TTIRToLinAlg/TTIRToLinAlg.h" + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace { +template +class ElementwiseOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TTIROpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector resultTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes))) { + return failure(); + } + + rewriter.replaceOpWithNewOp( + op, resultTypes, adaptor.getInputs(), adaptor.getOutputs()); + return success(); + } +}; + +class SubtractOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(ttir::SubtractOp srcOp, ttir::SubtractOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType lhsType = + mlir::cast(adaptor.getInputs().front().getType()); + RankedTensorType rhsType = + mlir::cast(adaptor.getInputs().back().getType()); + + if (lhsType.getShape() == rhsType.getShape()) { + rewriter.replaceOpWithNewOp( + srcOp, adaptor.getInputs().front(), adaptor.getInputs().back(), + adaptor.getOutputs().front()); + + // Broadcast for rhs operand require the operation to be commutative to + // allow switching the order of operands. To allow this conversion, the + // following conversion is applied to SubtractOp: subtractOp(lhs,rhs) -> + // addOp(lhs, negOp(rhs)) + + } else { + Value device = getOrInsertDevice(rewriter, srcOp); + tensor::EmptyOp negEmptyOp = rewriter.create( + srcOp.getLoc(), this->getTypeConverter()->convertType(rhsType), + device); + linalg::NegOp negOp = rewriter.create( + srcOp.getLoc(), adaptor.getInputs().back(), negEmptyOp); + + rewriter.replaceOpWithNewOp( + srcOp, adaptor.getInputs().front(), negOp.getResults().front(), + adaptor.getOutputs().front()); + } + + return success(); + } +}; + +} // namespace + +namespace mlir::tt { + +void populateTTIRToLinAlgPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + // clang-format off + // ANCHOR: op_rewriter_pattern_set + patterns + .add< + // TensorEmptyConversionPattern, + // ToLayoutOpConversionPattern, + // ElementwiseOpConversionPattern, + ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseOpConversionPattern, + // ElementwiseUnaryWithFloatParameterOpConversionPattern, + // ReductionOpConversionPattern, + // ReductionOpConversionPattern, + // ReductionOpConversionPattern, + // BroadcastOpConversionPattern, + // EmbeddingOpConversionPattern, + // SoftmaxOpConversionPattern, + // TransposeOpConversionPattern, + // TypecastOpConversionPattern, + // ClampOpConversionPattern, + // ConcatOpConversionPattern, + // ReshapeOpConversionPattern, + // SliceOpConversionPattern, + // SqueezeOpConversionPattern, + // UnsqueezeOpConversionPattern, + // ConstantOpConversionPattern, + // MatmulOpConversionPattern, + // Conv2dOpConversionPattern, + // MaxPool2dOpConversionPattern, + SubtractOpConversionPattern + // AllGatherOpConversionPattern + >(typeConverter, ctx); + // ANCHOR_END: op_rewriter_pattern_set + // clang-format on +} + +} // namespace mlir::tt diff --git a/lib/Conversion/TTIRToLinAlg/TTIRToLinAlgPass.cpp b/lib/Conversion/TTIRToLinAlg/TTIRToLinAlgPass.cpp new file mode 100644 index 0000000000..ddb898ca99 --- /dev/null +++ b/lib/Conversion/TTIRToLinAlg/TTIRToLinAlgPass.cpp @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Conversion/TTIRToLinAlg/TTIRToLinAlg.h" + +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" +#include + +using namespace mlir; +using namespace mlir::tt; + +namespace mlir::tt::ttir { + +#define GEN_PASS_DEF_CONVERTTTIRTOLINALG +#include "ttmlir/Conversion/Passes.h.inc" + +} // namespace mlir::tt::ttir + +namespace { + +struct ConvertTTIRToLinAlgPass + : public ttir::impl::ConvertTTIRToLinAlgBase { + void runOnOperation() final { + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalDialect(); + + TypeConverter typeConverter; + // All types map 1:1. + typeConverter.addConversion([](Type type) { return type; }); + + RewritePatternSet patterns(&getContext()); + populateTTIRToLinAlgPatterns(&getContext(), patterns, typeConverter); + + // Apply full conversion + // + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +namespace mlir::tt { + +std::unique_ptr> createConvertTTIRToLinAlgPass() { + return std::make_unique(); +} + +} // namespace mlir::tt