Skip to content

Commit

Permalink
Try to fix Python builds
Browse files Browse the repository at this point in the history
  • Loading branch information
vwellsTT committed Jan 13, 2025
1 parent 2fd1b2c commit 7097494
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 62 deletions.
25 changes: 25 additions & 0 deletions include/ttmlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,31 @@ def ConvertTTKernelToEmitC : Pass<"convert-ttkernel-to-emitc", "::func::FuncOp">

// Pass registration for "convert-ttir-to-linalg": lowers TTIR elementwise ops
// to Linalg-on-tensors, inserting tensor.collapse_shape and linalg.broadcast
// ops where operand shapes require broadcasting (see the example below).
def ConvertTTIRToLinalg: Pass<"convert-ttir-to-linalg", "::mlir::ModuleOp"> {
let summary = "Convert TTIR dialect to Linalg dialect.";
let description = [{
Conversion pass to convert TTIR ops with defined conversion pattern into linalg ops, with broadcast and collapse tensor ops as needed.
Example:
Input:
func.func @add_with_broadcast(
%arg0: tensor<32x32xf32>,
%arg1: tensor<32x1xf32>,
%arg2: tensor<32x32xf32>
) -> tensor<32x32xf32> {
%1 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xf32>, tensor<32x1xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
return %1 : tensor<32x32xf32>
}
Output:
func.func @add_with_broadcast(
%arg0: tensor<32x32xf32>,
%arg1: tensor<32x1xf32>,
%arg2: tensor<32x32xf32>
) -> tensor<32x32xf32> {
%collapsed = tensor.collapse_shape %arg1 [[0, 1]] : tensor<32x1xf32> into tensor<32xf32>
%0 = tensor.empty() : tensor<32x32xf32>
%broadcasted = linalg.broadcast ins(%collapsed : tensor<32xf32>) outs(%0 : tensor<32x32xf32>) dimensions = [1]
%1 = linalg.add ins(%arg0, %broadcasted : tensor<32x32xf32>, tensor<32x32xf32>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
return %1 : tensor<32x32xf32>
}
}];
let constructor = "createConvertTTIRToLinalgPass()";
let dependentDialects = ["mlir::tt::ttir::TTIRDialect", "mlir::linalg::LinalgDialect"];
}
Expand Down
76 changes: 14 additions & 62 deletions lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
Expand All @@ -19,54 +20,15 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#include <cstdint>

using namespace mlir;
using namespace mlir::tt;

namespace {

using TensorRanks = SmallVector<int64_t, 2>;

// Computes the shape that all `inputs` broadcast to, using standard
// (numpy-style) broadcasting rules: shapes are aligned at the trailing
// dimension, and each result dim is the max of the corresponding input dims,
// which must be equal or 1.
//
// Returns failure() if any input is not a RankedTensorType, or if some pair of
// corresponding dims is incompatible (both non-1 and unequal). On success,
// `broadcastedShape` holds the common shape (rank == max input rank).
static LogicalResult computeBroadcastedShape(ArrayRef<Value> inputs,
                                             TensorRanks &broadcastedShape) {
  broadcastedShape.clear();

  // First find the maximum rank, validating operand types along the way.
  int64_t maxRank = 0;
  for (Value input : inputs) {
    auto type = dyn_cast<RankedTensorType>(input.getType());
    if (!type) {
      return failure();
    }
    maxRank = std::max(maxRank, type.getRank());
  }

  // Initialize broadcastedShape to the right size, one-filled.
  broadcastedShape = TensorRanks(maxRank, 1);

  // From right-to-left, replace target dims with any non-1 values we encounter
  // in inputs, returning failure on incompatible dims.
  for (Value input : inputs) {
    // Types were validated above, so an unchecked cast is safe here.
    const ArrayRef<int64_t> shape =
        cast<RankedTensorType>(input.getType()).getShape();
    const int64_t rank = static_cast<int64_t>(shape.size());

    for (int64_t i = 0; i < maxRank; ++i) {
      // Dims beyond this input's rank implicitly broadcast as 1. Guard with
      // `i < rank` explicitly instead of relying on unsigned wraparound of
      // `shape.size() - 1 - i` as the previous version did.
      const int64_t inputDim = (i < rank) ? shape[rank - 1 - i] : 1;
      int64_t &targetDim = broadcastedShape[maxRank - 1 - i];

      if (targetDim != inputDim && targetDim != 1 && inputDim != 1) {
        return failure();
      }
      targetDim = std::max(targetDim, inputDim);
    }
  }
  return success();
}

// Helper func to check which dims need to be broadcast and which need to be
// collapsed. Assumes that inputShape is broadcast-able to targetShape.
static void getDimsToBroadcastAndCollapse(
Expand All @@ -82,8 +44,6 @@ static void getDimsToBroadcastAndCollapse(

while (targetIdx >= 0) {
if (inputIdx >= 0) {
llvm::outs() << inputShape[inputIdx] << " vs " << targetShape[targetIdx]
<< "\n";
// This should be impossible since we verify input while computing
// targetShape.
assert(
Expand All @@ -101,12 +61,6 @@ static void getDimsToBroadcastAndCollapse(
targetIdx--;
}

llvm::outs() << "Found dims to broadcast: ";
for (const auto dim : broadcastDims) {
llvm::outs() << dim << " ";
}
llvm::outs() << "\n";

// Group non-broadcast dimensions together for collapse.
TensorRanks currentGroup;
size_t nextBroadcastDimIdx = 0;
Expand All @@ -132,9 +86,10 @@ static void getDimsToBroadcastAndCollapse(
}
}

// Conversion pattern of operations which have exactly 2 input and 1 output operands.
template <typename TTIROpTy, typename LinalgOpTy,
typename OpAdaptor = typename TTIROpTy::Adaptor>
class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {
class ElementwiseBinaryOpConversionPattern : public OpConversionPattern<TTIROpTy> {
public:
using OpConversionPattern<TTIROpTy>::OpConversionPattern;

Expand All @@ -145,17 +100,14 @@ class ElementwiseOpConversionPattern : public OpConversionPattern<TTIROpTy> {

// First, compute broadcasted shape from operands.
SmallVector<Value, 3> inputs = adaptor.getInputs();
llvm::outs() << "wtf\n";
TensorRanks broadcastedShape;
if (failed(computeBroadcastedShape(inputs, broadcastedShape))) {
return rewriter.notifyMatchFailure(op, "Operands are not broadcastable");
}
assert(inputs.size() == 2 && "binary element-wise operations must have 2 inputs!");
ArrayRef<int64_t> input0Shape = dyn_cast<RankedTensorType>(inputs[0].getType()).getShape();
ArrayRef<int64_t> input1Shape = dyn_cast<RankedTensorType>(inputs[1].getType()).getShape();

llvm::outs() << "target rank = [";
for (const auto rank : broadcastedShape) {
llvm::outs() << rank << " ";
SmallVector<int64_t, 4> broadcastedShape;
if (!OpTrait::util::getBroadcastedShape(input0Shape, input1Shape, broadcastedShape)) {
return rewriter.notifyMatchFailure(op, "Operands are not broadcastable--this should be impossible!");
}
llvm::outs() << "]\n";

// Replace any inputs which aren't in target shape with broadcast results
// which are.
Expand Down Expand Up @@ -214,9 +166,9 @@ namespace mlir::tt {

// Registers all TTIR->Linalg conversion patterns on `patterns`.
// Currently covers the binary elementwise ops (add, multiply, subtract),
// each lowered through ElementwiseBinaryOpConversionPattern.
//
// NOTE(review): the scraped diff interleaved the pre- and post-rename
// pattern.add<...> calls, yielding a duplicated, non-compiling statement;
// this is the deduplicated post-commit version.
void populateTTIRToLinalgPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
                                  TypeConverter &typeConverter) {
  patterns
      .add<ElementwiseBinaryOpConversionPattern<ttir::AddOp, linalg::AddOp>,
           ElementwiseBinaryOpConversionPattern<ttir::MultiplyOp, linalg::MulOp>,
           ElementwiseBinaryOpConversionPattern<ttir::SubtractOp, linalg::SubOp>>(
          typeConverter, ctx);
}

Expand Down
6 changes: 6 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main
${translation_libs}
TTMLIRStatic
${extension_libs}
MLIRLinalgTransforms
MLIRArithTransforms
MLIRSCFTransforms
MLIRFuncTransforms
MLIRTensorTransforms
MLIRVectorTransforms
)

set(TTMLIR_PYTHON_SOURCES
Expand Down

0 comments on commit 7097494

Please sign in to comment.