[mlir][tosa] Implement dynamic shape support for tosa.max_pool2d lowering (#87538)

The existing lowering for tosa.max_pool2d only supports dynamic
dimensions when the dynamic dimension is the batch dimension. This
change updates the lowering to support arbitrary dynamic dimensions on
the inputs and outputs of the tosa.max_pool2d operation.

This change also fixes a bug in the implementation of implicit
broadcasting in the tosa-to-linalg pass, which was introducing uses of
constant ops that violated dominance requirements.
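
For illustration, this is the kind of fully dynamic op the updated lowering now handles (it is the exact op exercised by the new test below); previously only the batch dimension (dimension 0) could be dynamic:

  %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>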
sabauma authored Apr 16, 2024
1 parent ac1f2de commit 1c076b4
Showing 6 changed files with 251 additions and 37 deletions.
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -130,11 +130,11 @@ def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
// to not include any remaining unranked tensors.
def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;

-def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>]>;
-def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>]>;
+def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
+def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;

// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
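Besides the textual summaries, the new cppClassName argument ("::mlir::TensorType") types the ODS-generated operand accessors, which is what lets the updated converter below bind op.getInput() to a TypedValue<TensorType>. The summary strings ("1-d tensor", and so on) surface in generated verifier diagnostics; a hypothetical example of the kind of message they produce (op name and types invented here for illustration):

  error: 'tosa.example' op operand #0 must be 2-d tensor, but got 'tensor<1x2x3xf32>'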
12 changes: 8 additions & 4 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,

// Emit 'then' region of 'scf.if'
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+// It is not safe to cache constants across regions.
+// New constants could potentially violate dominance requirements.
+IndexPool localPool;
+
// Emit 'tensor.empty' op
SmallVector<OpFoldResult> outputTensorShape;
for (auto index : llvm::seq<int64_t>(0, rank)) {
auto size = index == dim ? targetSize
-: getOrFoldTensorDim(rewriter, loc, indexPool,
+: getOrFoldTensorDim(rewriter, loc, localPool,
operand, index);
outputTensorShape.push_back(size);
}
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value operand,
ArrayRef<OpFoldResult> targetShape,
ArrayRef<Value> masterOperands) {
-size_t rank = operand.getType().cast<RankedTensorType>().getRank();
-assert(targetShape.size() == rank);
-assert(masterOperands.size() == rank);
+int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+assert((int64_t)targetShape.size() == rank);
+assert((int64_t)masterOperands.size() == rank);
for (auto index : llvm::seq<int64_t>(0, rank))
operand =
broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
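The localPool fix above is easiest to see in IR terms. A minimal sketch (function and value names invented for illustration) of what the per-region pool guarantees: each scf.if region materializes its own index constants instead of reusing a cached constant whose definition would not dominate the use:

  func.func @broadcast_sketch(%cond: i1, %arg: tensor<?xf32>) -> index {
    %r = scf.if %cond -> (index) {
      %c0 = arith.constant 0 : index  // created inside the 'then' region
      %d = tensor.dim %arg, %c0 : tensor<?xf32>
      scf.yield %d : index
    } else {
      %c0_0 = arith.constant 0 : index  // re-materialized, not pulled from an outer cache
      %d_0 = tensor.dim %arg, %c0_0 : tensor<?xf32>
      scf.yield %d_0 : index
    }
    return %r : index
  }

Had the 'then' region cached %c0 into the shared pool, a later use of that cached value from the 'else' region (or from after the scf.if) would fail MLIR's dominance verification.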
88 changes: 64 additions & 24 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +28,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Interfaces/InferTypeOpInterface.h"

#include <numeric>
#include <type_traits>

@@ -34,7 +36,7 @@ using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
TypedAttr padAttr, OpBuilder &rewriter) {
-// Input should be padded if necessary.
+// Input should be padded only if necessary.
if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
return input;

@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
SmallVector<int64_t, 4> paddedShape;
SmallVector<OpFoldResult, 8> lowIndices;
SmallVector<OpFoldResult, 8> highIndices;
-for (int i = 0, s = inputShape.size(); i < s; i++) {
+for (size_t i : llvm::seq(inputShape.size())) {
auto lowPad = pad[i * 2];
auto highPad = pad[i * 2 + 1];
if (ShapedType::isDynamic(inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,

static mlir::Value reifyConstantDim(int64_t attr,
ImplicitLocOpBuilder &builder) {
-return builder.createOrFold<arith::IndexCastOp>(
-builder.getIndexType(),
-builder.create<arith::ConstantOp>(builder.getI64IntegerAttr(attr)));
+return builder.create<arith::ConstantIndexOp>(attr);
}

// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1

-static mlir::Value getConvOutputDim(Location loc, Value inputDim,
-int64_t padBeforeAttr, int64_t padAfterAttr,
-Value kernelDim, int64_t strideAttr,
-int64_t dilationAttr, Type inputETy,
-OpBuilder &rewriter) {
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+int64_t padBeforeAttr,
+int64_t padAfterAttr, Value kernelDim,
+int64_t strideAttr,
+int64_t dilationAttr,
+OpBuilder &rewriter) {
ImplicitLocOpBuilder builder(loc, rewriter);
auto one = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(inputDim.getType(), 1));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
ShapedType inputTy = cast<ShapedType>(input.getType());
-Type inputETy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();

SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@
rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
// H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
dynDims[inputDim] =
-getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
-stride, dilation, inputETy, rewriter);
+getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
+kernelDynDim, stride, dilation, rewriter);
}
}

@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

+// Compute the dynamic output sizes of the maxpool operation.
+static SmallVector<Value>
+computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+TensorType resultTy = op.getType();
+Location loc = op.getLoc();
+
+TypedValue<TensorType> input = op.getInput();
+ArrayRef<int64_t> kernel = op.getKernel();
+ArrayRef<int64_t> pad = op.getPad();
+ArrayRef<int64_t> stride = op.getStride();
+
+SmallVector<Value> dynamicDims;
+
+// Batch dimension
+if (resultTy.isDynamicDim(0))
+dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+// Height/width dimensions
+for (int64_t dim : {1, 2}) {
+if (!resultTy.isDynamicDim(dim))
+continue;
+
+// Index into the attribute arrays
+int64_t index = dim - 1;
+
+// Input height/width
+Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+
+// Kernel height/width
+Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+
+// Output height/width
+Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
+pad[index * 2 + 1], khw, stride[index],
+/*dilationAttr=*/1, rewriter);
+dynamicDims.push_back(ohw);
+}
+
+// Channel dimension
+if (resultTy.isDynamicDim(3))
+dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+
+return dynamicDims;
+}
+
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
-Value input = op.getInput();
-ShapedType inputTy = cast<ShapedType>(input.getType());
+TypedValue<TensorType> input = op.getInput();
+ShapedType inputTy = input.getType();

-ShapedType resultTy = cast<ShapedType>(op.getType());
+ShapedType resultTy = op.getType();
Type resultETy = inputTy.getElementType();

-auto dynamicDimsOr =
-checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-if (!dynamicDimsOr.has_value())
-return failure();
-SmallVector<Value> dynamicDims = *dynamicDimsOr;
+SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);

// Determine what the initial value needs to be for the max pool op.
TypedAttr initialAttr;
@@ -721,6 +762,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
pad.resize(2, 0);
llvm::append_range(pad, op.getPad());
pad.resize(pad.size() + 2, 0);

Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

Value filledEmptyTensor =
-rewriter
-.create<linalg::FillOp>(loc, ValueRange{initialValue},
-ValueRange{emptyTensor})
+rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
.result();

Value fakeWindowDims =
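As a worked instance of the output-size formula used by getConvOrPoolOutputDim, take the pooling attributes from the new test below (kernel = [2, 5], pad = [0, 0, 2, 2], stride = [1, 1], dilation = 1):

  OH = ((IH + 0 + 0 - (1*(2-1)+1)) / 1) + 1 = IH - 1
  OW = ((IW + 2 + 2 - (1*(5-1)+1)) / 1) + 1 = IW

which is exactly the index arithmetic the CHECK-CSE lines in the test verify.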
54 changes: 54 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s

// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -215,6 +216,59 @@ func.func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
return
}

+// CHECK-CSE-LABEL: @max_pool_all_dynamic
+func.func @max_pool_all_dynamic(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// Batch size
+// CHECK-CSE: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-CSE: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x?x?x?xf32>
+
+// Compute output height
+// CHECK-CSE: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-CSE: %[[IH:.+]] = tensor.dim %arg0, %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-CSE: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IH]], %[[C0]] : index
+// CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C0]] : index
+// CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C2]], %[[C1]] : index
+// CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+// CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+// CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+// CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+// CHECK-CSE: %[[HEIGHT:.+]] = arith.addi %[[DIVIDE]], %[[C1]] : index
+
+// Compute output width
+// CHECK-CSE: %[[IW:.+]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK-CSE: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-CSE: %[[PADDED_BEFORE:.+]] = arith.addi %[[IW]], %[[C2]] : index
+// CHECK-CSE: %[[PADDED_AFTER:.+]] = arith.addi %[[PADDED_BEFORE]], %[[C2]] : index
+// CHECK-CSE: %[[SUB_ONE:.+]] = arith.subi %[[C5]], %[[C1]] : index
+// CHECK-CSE: %[[DILATED:.+]] = arith.muli %[[C1]], %[[SUB_ONE]] : index
+// CHECK-CSE: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], %[[C1]] : index
+// CHECK-CSE: %[[SUBTRACT:.+]] = arith.subi %[[PADDED_AFTER]], %[[ADD_ONE]] : index
+// CHECK-CSE: %[[DIVIDE:.+]] = arith.divui %[[SUBTRACT]], %[[C1]] : index
+// CHECK-CSE: %[[WIDTH:.+]] = arith.addi %[[DIVIDE]], %[[C1]] : index
+
+// Channel size
+// CHECK-CSE: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-CSE: %[[CHANNEL:.+]] = tensor.dim %arg0, %[[C3]] : tensor<?x?x?x?xf32>
+
+// Pad the input
+// CHECK-CSE: %[[FLOAT_MIN:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-CSE: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 0, 2, 0] high[0, 0, 2, 0] {
+// CHECK-CSE: tensor.yield %[[FLOAT_MIN]] : f32
+
+// Allocate the output and fill with minimum value
+// CHECK-CSE: %[[INIT:.+]] = tensor.empty(%[[BATCH]], %[[HEIGHT]], %[[WIDTH]], %[[CHANNEL]]) : tensor<?x?x?x?xf32>
+// CHECK-CSE: %[[FILL:.+]] = linalg.fill ins(%[[FLOAT_MIN]] : f32) outs(%[[INIT]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK-CSE: %[[FAKE_WINDOW:.+]] = tensor.empty() : tensor<2x5xf32>
+
+// Compute max pool
+// CHECK-CSE: %[[OUT:.+]] = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PADDED]], %[[FAKE_WINDOW]] : tensor<?x?x?x?xf32>, tensor<2x5xf32>) outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK-CSE: return %[[OUT]]
+
+%0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 5>, pad = array<i64: 0, 0, 2, 2>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+return %0 : tensor<?x?x?x?xf32>
+}
+
// -----

// CHECK-LABEL: @avg_pool_f32
12 changes: 8 additions & 4 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -270,7 +270,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_0:.*]] = tensor.dim %[[ARG0]], %[[CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_1:.*]] = arith.cmpi eq, %[[VAL_0]], %[[CONST1]] : index
// CHECK: %[[ARG0_DIM0_BROADCAST:.*]] = scf.if %[[VAL_1]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[CONST1]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_2]]) : tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x?xf32>) outs(%[[VAL_3]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: f32):
@@ -284,7 +285,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_7:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[CONST1]] : index
// CHECK: %[[ARG0_DIM1_BROADCAST:.*]] = scf.if %[[VAL_8]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = tensor.dim %[[ARG0_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_10:.*]] = tensor.empty(%[[VAL_9]], %[[MAX_DIM1]]) : tensor<?x?xf32>
// CHECK: %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_10]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_12:.*]]: f32, %[[VAL_13:.*]]: f32):
@@ -298,7 +300,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_14:.*]] = tensor.dim %[[ARG1]], %[[CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_14]], %[[CONST1]] : index
// CHECK: %[[ARG1_DIM0_BROADCAST:.*]] = scf.if %[[VAL_15]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[CONST1]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_16:.*]] = tensor.dim %[[ARG1]], %[[LOCAL_CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_17:.*]] = tensor.empty(%[[MAX_DIM0]], %[[VAL_16]]) : tensor<?x?xf32>
// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[VAL_17]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32):
@@ -312,7 +315,8 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK: %[[VAL_21:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST1]] : tensor<?x?xf32>
// CHECK: %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[CONST1]] : index
// CHECK: %[[ARG1_DIM1_BROADCAST:.*]] = scf.if %[[VAL_22]] -> (tensor<?x?xf32>) {
-// CHECK: %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[CONST0]] : tensor<?x?xf32>
+// CHECK: %[[LOCAL_CONST0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_23:.*]] = tensor.dim %[[ARG1_DIM0_BROADCAST]], %[[LOCAL_CONST0]] : tensor<?x?xf32>
// CHECK: %[[VAL_24:.*]] = tensor.empty(%[[VAL_23]], %[[MAX_DIM1]]) : tensor<?x?xf32>
// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG1_DIM0_BROADCAST]] : tensor<?x?xf32>) outs(%[[VAL_24]] : tensor<?x?xf32>) {
// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
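For reference, the checks above verify the lowering of a single broadcastable add on fully dynamic operands. The test body itself is outside this diff; a minimal sketch consistent with the function signature shown in the hunk headers:

  func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = tosa.add %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }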

