Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Tosa] fix fp16/bf16 support for Clamp min/max attributes #69192

Merged
merged 1 commit into from
Oct 20, 2023

Conversation

fabrizio-indirli
Copy link
Contributor

@fabrizio-indirli fabrizio-indirli commented Oct 16, 2023

In TOSA MLIR dialect, fix the definition of the Clamp op to
accept fp16 & bf16 datatype for the min_fp and max_fp attributes.
Add ClampOp verifier to check attributes types compatibility.
Add related test cases in Tosa/ops.mlir.

@llvmbot
Copy link

llvmbot commented Oct 16, 2023

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: None (fabrizio-indirli)

Changes

In TOSA MLIR dialect, fix the definition of the Clamp op to accept fp16 & bf16 datatypes for the min_fp and max_fp attributes.
Add related test cases in Tosa/ops.mlir.


Full diff: https://github.com/llvm/llvm-project/pull/69192.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+11)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index a80111aedfe0b59..59220bffa27ca7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -380,8 +380,8 @@ def Tosa_ClampOp : Tosa_ElementwiseOp<"clamp"> {
     Tosa_Tensor:$input,
     I64Attr:$min_int,
     I64Attr:$max_int,
-    F32Attr:$min_fp,
-    F32Attr:$max_fp
+    Tosa_FloatAttr:$min_fp,
+    Tosa_FloatAttr:$max_fp
   );
 
   let results = (outs
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index e39f4662e791918..34dd30f37bf4d62 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -197,6 +197,17 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
 def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
 def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
 
+def Tosa_FloatAttr : Attr<And<[CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
+                                Or<[CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isBF16()">,
+                                    CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF16()">,
+                                    CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF32()">,
+                                    ]>
+                                  ]>,
+                     "BF16/F16/F32 float attribute"> {
+  let storageType = [{ ::mlir::FloatAttr }];
+  let returnType = [{ ::mlir::APFloat }];
+}
+
 //===----------------------------------------------------------------------===//
 // Iterable attributes.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e62bea515d06baa..e18943a0fd613d5 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -124,6 +124,20 @@ func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: clamp_f16
+func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> {
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : f16, max_fp = 1.0: f16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16>
+  return %0 : tensor<13x21x3xf16>
+}
+
+// -----
+// CHECK-LABEL: clamp_bf16
+func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
+  %0 = tosa.clamp %arg0 {min_fp = 0.0 : bf16, max_fp = 1.0: bf16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16>
+  return %0 : tensor<13x21x3xbf16>
+}
+
 // -----
 // CHECK-LABEL: sigmoid
 func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

@fabrizio-indirli
Copy link
Contributor Author

@eric-k256 @GeorgeARM
Hey, would be glad to get your feedback on this PR when you have time :)
(not sure if I should add someone else as well ...)

@GeorgeARM
Copy link
Contributor

Looks good. My only thought is if this is too restrictive. I know that tosa defines only bf16, fp16, fp32 at the moment but I think that is probably a good idea to allow any float type and just check that the input type, the output type and the attributes type if float match. @eric-k256 what is your take on this?

@eric-k256
Copy link
Contributor

A verifier would be nice to have to check attribute/in/out alignment. I believe that the issue with adding all float types is that bfloat16 doesn't show as a normal float type. Maybe all float + bf16 would work.

@fabrizio-indirli
Copy link
Contributor Author

Thank you both for the reviews!
My understanding is that the storage class mlir::FloatAttr accepts also Bf16, but in MLIR the restricted attribute FloatAttrBase pre-defined in CommonAttrConstraints.td (and used, for example, to define F32Attr etc.) does not support Bf16.
Thus we need to define our own attribute type.
We can create a generic Float attribute if we want, though for now Tosa accepts only Fp16/Bf16/Fp32. We could do this:

  • In the TOSA dialect (e.g. calling it Tosa_FloatAttr), as done here
  • In mlir/IR/CommonAttrConstrainst.td, so that potentially other dialects could use this. However this may require other approvals as we would be modifying the generic IR? Moreover something similar already seems to exist with Builtin_FloatAttr, although I was not able to use it here.

Any thoughts on which solution would be better?

@fabrizio-indirli
Copy link
Contributor Author

fabrizio-indirli commented Oct 19, 2023

@GeorgeARM @eric-k256
I updated the patch so that min_fp & max_fp are generic Float attributes (defined as Tosa_FloatAttr in TosaTypesBase.td ) and added a verifier to check that the types are aligned.
In particular, it checks that:

  • Input and output datatype are the same
  • If datatype is float, checks that the min_fp type == max_fp type, and that this attributes' type is the same of the input/output datatype OR has an higher bitwidth.
    Supporting this last case (min/max_fp with an higher precision than the input/output float type) should allow to retain the interface unchanged and to pass existing tests that had "f32" attributes and "f16"/"bf16" input/output datatype

In TOSA MLIR dialect, fix the definition of the Clamp op to
accept fp16 & bf16 datatype for the min_fp and max_fp attributes.
Add ClampOp verifier to check attributes types compatibility.
Add related test cases in Tosa/ops.mlir.

Signed-off-by: Fabrizio Indirli <[email protected]>
Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants