Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::MulOp>(op)) {
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
return nullptr;
}

int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
bool shiftIsConstant = true;
int32_t shift = 0;
if (matchPattern(shiftVal, m_Constant(&shiftElem)))
shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
else
shiftIsConstant = false;

if (isa<FloatType>(elementTy)) {
if (shift != 0) {
Expand All @@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
Value a = args[0];
Value b = args[1];

if (shift > 0) {
auto shiftConst =
arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
if (shift > 0 || !shiftIsConstant) {
Value shiftConst;
if (shiftIsConstant)
shiftConst =
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);

if (!a.getType().isInteger(32))
a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);

if (!b.getType().isInteger(32))
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);

auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
auto result = tosa::ApplyScaleOp::create(
rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
rewriter.getStringAttr("SINGLE_ROUND"));

if (elementTy.isInteger(32))
return result;

return arith::TruncIOp::create(rewriter, loc, elementTy, result);
return result;
}

int aWidth = a.getType().getIntOrFloatBitWidth();
Expand Down Expand Up @@ -918,6 +919,18 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
if (operands.size() == 1)
return operands;

// No need to broadcast for static shape
bool hasDynamic = false;
for (auto op : operands) {
const auto tType = dyn_cast<RankedTensorType>(op.getType());
if (tType && !tType.hasStaticShape()) {
hasDynamic = true;
break;
}
}
if (!hasDynamic)
return operands;

// Broadcast dynamic dimensions operand by operand
return llvm::map_to_vector(operands, [&](Value operand) {
return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
Expand Down Expand Up @@ -990,8 +1003,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
static ValueRange getBroadcastableOperands(Operation *operation,
ValueRange operands) {
// Shift cannot broadcast
if (isa<tosa::MulOp>(operation))
return operands.take_front(2);
if (isa<tosa::MulOp>(operation)) {
DenseElementsAttr shiftElems;
// Shift cannot broadcast when it is constant
if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
return operands.take_front(2);
else
return operands.take_front(3);
}
// Input1_zp and output_zp cannot broadcast
if (isa<tosa::NegateOp>(operation))
return operands.take_front(1);
Expand Down
8 changes: 0 additions & 8 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
// expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
%0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}
11 changes: 11 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {

return
}

// -----

// CHECK-LABEL: @mul_no_const_shift
func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
// CHECK: linalg.generic
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
// CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
%0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}