-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes #69192
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: None (fabrizio-indirli) ChangesIn TOSA MLIR dialect, fix the definition of the Clamp op to accept fp16 & bf16 datatypes for the min_fp and max_fp attributes. Full diff: https://github.com/llvm/llvm-project/pull/69192.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index a80111aedfe0b59..59220bffa27ca7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -380,8 +380,8 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
Tosa_Tensor:$input,
I64Attr:$min_int,
I64Attr:$max_int,
- F32Attr:$min_fp,
- F32Attr:$max_fp
+ Tosa_FloatAttr:$min_fp,
+ Tosa_FloatAttr:$max_fp
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index e39f4662e791918..34dd30f37bf4d62 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -197,6 +197,17 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
+def Tosa_FloatAttr : Attr<And<[CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
+ Or<[CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isBF16()">,
+ CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF16()">,
+ CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF32()">,
+ ]>
+ ]>,
+ "BF16/F16/F32 float attribute"> {
+ let storageType = [{ ::mlir::FloatAttr }];
+ let returnType = [{ ::mlir::APFloat }];
+}
+
//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..e18943a0fd613d5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -124,6 +124,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
+ %0 = tosa.clamp %arg0 {min_fp = 0.0 : f16, max_fp = 1.0: f16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
+ return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: clamp_bf16
+func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
+ %0 = tosa.clamp %arg0 {min_fp = 0.0 : bf16, max_fp = 1.0: bf16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
+ return %0 : tensor<13x21x3xbf16>
+}
+
// -----
// CHECK-LABEL: sigmoid
func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
|
@eric-k256 @GeorgeARM |
Looks good. My only thought is if this is too restrictive. I know that tosa defines only |
A verifier would be nice to have to check attribute/in/out alignment. I believe that the issue with adding all float types is that bfloat16 doesn't show as a normal float type. Maybe all float + bf16 would work. |
Thank you both for the reviews!
Any thoughts on which solution would be better? |
15b7398
to
0ef8ec4
Compare
@GeorgeARM @eric-k256
|
In TOSA MLIR dialect, fix the definition of the Clamp op to accept fp16 & bf16 datatype for the min_fp and max_fp attributes. Add ClampOp verifier to check attributes types compatibility. Add related test cases in Tosa/ops.mlir. Signed-off-by: Fabrizio Indirli <[email protected]>
0ef8ec4
to
96f8d12
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
In TOSA MLIR dialect, fix the definition of the Clamp op to
accept fp16 & bf16 datatype for the min_fp and max_fp attributes.
Add ClampOp verifier to check attributes types compatibility.
Add related test cases in Tosa/ops.mlir.