diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index 096510a09e324..3942113880ce5 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,10 @@ bool getConstShapeValues(Operation *op,
 // returns a small vector of int64_t values that attr contains
 SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
                                         const int rank);
+
+// Returns the attribute that stores the constant value of a ConstantLike
+// operation. Prerequisite is `op` to be a `ConstantLike` operation.
+Attribute getConstantAttribute(Operation *op);
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 86f5e9baf4a94..e84c8bcfc11ba 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -23,10 +23,11 @@
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #include <numeric>
 #include <type_traits>
@@ -118,6 +119,71 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                        /*symbolCount=*/0, sourceDims, rewriter.getContext());
 }
 
+static mlir::Value createScalarConstantFromTensor(PatternRewriter &rewriter,
+                                                  Operation *source,
+                                                  Value result) {
+  // Get the constant as the attribute from the constant operation
+  Attribute value = tosa::getConstantAttribute(source);
+  auto attr = dyn_cast<SplatElementsAttr>(value);
+
+  // Ensure the constant is splat so we can convert to a scalar
+  if (!attr) {
+    return Value();
+  }
+
+  // Filter for constants based on Ranked Tensors
+  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
+  if (!resultTy) {
+    return Value();
+  }
+
+  // Create a scalar constant with the same type as the result tensor.
+  // We assume the ResultType follows the TOSA spec, in that it can be an
+  // accumulator type that is same as or larger in bitwidth than the splat
+  // constant.
+  Value scalarValue =
+      llvm::TypeSwitch<Attribute, Value>(attr.getSplatValue<Attribute>())
+          .Case([&](FloatAttr attr) {
+            return rewriter
+                // Create a float constant with the same type as the result
+                // tensor and use the host systems double type as APFloat
+                // checks bitwidths so in the case of different input -> output
+                // types the conversion will fail.
+                .create<arith::ConstantOp>(
+                    source->getLoc(),
+                    FloatAttr::get(resultTy.getElementType(),
+                                   attr.getValue().convertToDouble()))
+                .getResult();
+          })
+          .Case([&](IntegerAttr attr) {
+            // At the moment all profiles are signed, so for the unsigned case
+            // if it does happen bail out.
+            if (resultTy.getElementType().isUnsignedInteger()) {
+              return Value();
+            }
+            // Create a scalar that follows the result type. In the case of i8,
+            // the result can be i32. So we perform the conversion at
+            // compile-time.
+            return rewriter
+                .create<arith::ConstantOp>(
+                    source->getLoc(),
+                    IntegerAttr::get(resultTy.getElementType(),
+                                     attr.getValue().getSExtValue()))
+                .getResult();
+          })
+          .Default([](Attribute) { return Value(); });
+
+  // Could not create a scalar constant due to an unsupported type
+  if (!scalarValue) {
+    return Value();
+  }
+
+  return rewriter
+      .create<linalg::FillOp>(source->getLoc(), ValueRange{scalarValue},
+                              ValueRange{result})
+      .getResult(0);
+}
+
 // Broadcast the source value to all the outer dimensions of the result value.
 // If required, the element type is expanded using an arith.extsi or arith.extf
 // operation as appropriate.
@@ -126,6 +192,17 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                               Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   const int64_t resultRank = resultTy.getRank();
+
+  // Attempt to create a FillOp in linalg if the constant is a splat value.
+  if (source.getDefiningOp() &&
+      matchPattern(source.getDefiningOp(), m_Constant())) {
+    auto scalar = createScalarConstantFromTensor(
+        rewriter, source.getDefiningOp(), result);
+    if (scalar) {
+      return scalar;
+    }
+  }
+
   // Creating maps for the input and output of the broacast-like generic op.
   SmallVector<AffineMap, 2> indexingMaps;
   indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index e1b3be74b50fd..b93793cfeb036 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
   }
   return {};
 }
+
+Attribute mlir::tosa::getConstantAttribute(Operation *op) {
+
+  if (!op || !op->hasTrait<OpTrait::ConstantLike>())
+    return Attribute();
+
+  if (auto constOp = dyn_cast<tosa::ConstOp>(op)) {
+    return constOp.getValues();
+  }
+
+  // TOSA names constants in the operation as "value" while linalg names them
+  // with "values". Here we search for both and find the first.
+  const SmallVector<StringRef, 2> possibleAttributes = {"value", "values"};
+  for (llvm::StringRef name : possibleAttributes) {
+    if (op->hasAttr(name)) {
+      return op->getAttr(name);
+    }
+  }
+  return Attribute();
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a737a8a05bae6..29116dce80868 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -672,6 +672,63 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2
 
 // -----
 
+// CHECK-LABEL: @conv2d_bias_broadcast_f32
+func.func @conv2d_bias_broadcast_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
+  %bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
+  // CHECK: %[[BIAS:.+]] = linalg.fill
+  // CHECK-SAME: ins(%[[CST]]
+  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xf32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME: outs(%[[BIAS]]
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_dynamic_batch_bias_broadcast_f32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x49x42x27xf32>
+func.func @conv2d_dynamic_batch_bias_broadcast_f32(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
+  %bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x49x42x27xf32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x45x40x28xf32>
+  // CHECK: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
+  // CHECK: %[[BIAS:.+]] = linalg.fill
+  // CHECK-SAME: ins(%[[CST]]
+  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<?x45x40x28xf32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME: outs(%[[BIAS]]
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_bias_broadcast_i8_acc_i32
+func.func @conv2d_bias_broadcast_i8_acc_i32(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x3x3x27xi8>) -> () {
+  %bias = "tosa.const"() <{values = dense<42> : tensor<28xi8>}> : () -> tensor<28xi8>
+  // CHECK-DAG: %[[CST:.+]] = arith.constant 42 : i32
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
+  // CHECK: %[[BIAS:.+]] = linalg.fill
+  // CHECK-SAME: ins(%[[CST]]
+  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME: outs(%[[BIAS]]
+  %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xi8>, tensor<28x3x3x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>