
[MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend #102971

Merged: 5 commits merged into llvm:main from jleyonberg/math-rocdl-update on Sep 4, 2024

Conversation

jsjodin (Contributor) commented Aug 12, 2024

This patch removes the MathToROCDL patterns for a few operations, which allows the mathToLLVM conversion to lower them to LLVM intrinsics instead, since these operations are supported directly by the AMDGPU backend.
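
For illustration, here is a minimal sketch of how the lowering changes for one affected op (math.sqrt on f32); the function and value names are hypothetical and the exact textual IR may vary by MLIR version:

```mlir
// Input IR (hypothetical example).
func.func @f(%x: f32) -> f32 {
  %0 = math.sqrt %x : f32
  func.return %0 : f32
}

// Before this patch, MathToROCDL rewrote the op into an OCML library call:
//   %0 = llvm.call @__ocml_sqrt_f32(%x) : (f32) -> f32
// After this patch, MathToLLVM lowers it to the LLVM intrinsic op, which the
// AMDGPU backend handles directly:
//   %0 = llvm.intr.sqrt(%x) : (f32) -> f32
```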


llvmbot (Collaborator) commented Aug 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Jan Leyonberg (jsjodin)

Changes

This patch removes the MathToROCDL patterns for a few operations, which allows the mathToLLVM conversion to lower them to LLVM intrinsics instead, since these operations are supported directly by the AMDGPU backend.


Full diff: https://github.com/llvm/llvm-project/pull/102971.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+4-8)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+17-85)
  • (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (-60)
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 7de6971ba2ee72..fd4eab0e10d67e 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   // Handled by mathToLLVM: math::AbsIOp
+  // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
   // Handled by mathToLLVM: math::CountLeadingZerosOp
   // Handled by mathToLLVM: math::CountTrailingZerosOp
 // Handled by mathToLLVM: math::CtPopOp
+  // Handled by mathToLLVM: math::ExpOp
   // Handled by mathToLLVM: math::FmaOp
+  // Handled by mathToLLVM: math::LogOp
   // FIXME: math::IPowIOp
   // FIXME: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
+  // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
-  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
-                                   "__ocml_fabs_f64");
   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
                                    "__ocml_acos_f64");
   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                    "__ocml_cosh_f64");
   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
                                    "__ocml_sinh_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
-                                  "__ocml_exp_f64");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                    "__ocml_exp2_f64");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                     "__ocml_expm1_f64");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                     "__ocml_floor_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
-                                  "__ocml_log_f64");
   populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
                                     "__ocml_log10_f64");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +104,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                     "__ocml_rsqrt_f64");
   populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
                                   "__ocml_sin_f64");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
-                                   "__ocml_sqrt_f64");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
                                    "__ocml_tanh_f64");
   populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index bf49a42a115775..4f1f26e8794d9e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -131,21 +131,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_fabs
-  func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.absf %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.absf %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -206,23 +191,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_exp
-  func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %exp_f32 = math.exp %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result32 = math.exp %exp_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -239,21 +207,20 @@ gpu.module @test_module {
 }
 
 // -----
-
 // Test that we handled properly operation with SymbolTable other than module op
 gpu.module @test_module {
   "test.symbol_scope"() ({
     // CHECK: test.symbol_scope
-    // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-    // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-    // CHECK-LABEL: func @gpu_exp
-    func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-      %exp_f32 = math.exp %arg_f32 : f32
-      // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-      %result32 = math.exp %exp_f32 : f32
-      // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-      %result64 = math.exp %arg_f64 : f64
-      // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+    // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+    // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+    // CHECK-LABEL: func @gpu_sin
+    func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+      %sin_f32 = math.sin %arg_f32 : f32
+      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+      %result32 = math.sin %sin_f32 : f32
+      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+      %result64 = math.sin %arg_f64 : f64
+      // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
       func.return %result32, %result64 : f32, f64
     }
     "test.finish" () : () -> ()
@@ -279,21 +246,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_log
-  func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.log %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.log %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
@@ -359,26 +311,6 @@ gpu.module @test_module {
 
 // -----
 
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_sqrt
-  func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
-      -> (f16, f32, f64) {
-    %result16 = math.sqrt %arg_f16 : f16
-    // CHECK: llvm.fpext %{{.*}} : f16 to f32
-    // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
-    %result32 = math.sqrt %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.sqrt %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result16, %result32, %result64 : f16, f32, f64
-  }
-}
-
-// -----
-
 gpu.module @test_module {
   // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +404,15 @@ gpu.module @test_module {
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_unroll
   func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
-    %result = math.exp %arg0 : vector<4xf32>
+    %result = math.sin %arg0 : vector<4xf32>
     // CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
-    // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     // CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
     // CHECK: return %[[V4]]
     func.return %result : vector<4xf32>
@@ -526,9 +458,9 @@ gpu.module @test_module {
 
 gpu.module @module {
 // CHECK-LABEL: @spirv_exp
-// CHECK: llvm.call @__ocml_exp_f32
+// CHECK: llvm.call @__ocml_sin_f32
   spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
-    %0 = math.exp %arg0 : vector<4xf32>
+    %0 = math.sin %arg0 : vector<4xf32>
     spirv.ReturnValue %0 : vector<4xf32>
   }
 }
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index a406ec45a7f109..9a05a94f9f1ac7 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -15,21 +15,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
-  // CHECK-LABEL: func @math_absf
-  func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.absf %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.absf %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -210,21 +195,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @math_exp
-  func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.exp %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -270,21 +240,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
-  // CHECK-LABEL: func @math_log
-  func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.log %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.log %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
@@ -360,21 +315,6 @@ module @test_module {
 
 // -----
 
-module @test_module {
-  // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
-  // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
-  // CHECK-LABEL: func @math_sqrt
-  func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = math.sqrt %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
-    %result64 = math.sqrt %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
 module @test_module {
   // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64

krzysz00 (Contributor) left a comment

I think this is fine, but I want to double-check that nothing'll go wrong with double-precision exp and log

  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                   "__ocml_exp2_f64");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                    "__ocml_expm1_f64");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                    "__ocml_floor_f64");
  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
Same note re f64 log

@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                   "__ocml_cosh_f64");
  populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
                                   "__ocml_sinh_f64");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
Having disassembled OCML, the double-precision exp isn't actually a direct wrapper around the relevant intrinsic, but I figure that's probably fine

arsenm (Contributor) commented Aug 13, 2024


Correct. We only directly handle the f32 (and I think f16) versions. The f64 versions of the hard operations do not work. We do directly handle llvm.sqrt.f64 as an exception.

Also, none of the trig functions are directly handled (correctly). We do codegen the f32 versions but probably shouldn't

jsjodin (Contributor, Author)

So what do we do in the case where f32 is handled but f64 is not? Should we still just call OCML for both, or modify the lowering to handle only one of them?

Contributor

Modify the lowering to handle just one. The operation + type combination should be treated like a different operation, so emit the working f32 intrinsics and keep the calls for the non-working f64.
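
A hedged sketch of that intended outcome (an illustrative fragment, not the exact test output; value names are hypothetical): the same math op lowers differently depending on its element type.

```mlir
// f32: the backend intrinsic is known to work, so emit it directly.
%e32 = llvm.intr.exp(%x32) : (f32) -> f32
// f64: the backend cannot handle the intrinsic, so keep the OCML call.
%e64 = llvm.call @__ocml_exp_f64(%x64) : (f64) -> f64
```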

jsjodin (Contributor, Author)

Okay, I put the lowering for f64 back for log and exp and added the tests back for them as well.

…AMDGPU backend

This patch removes patterns for a few operations which allows mathToLLVM
conversion to convert the operations into LLVM intrinsics instead since they
are supported directly by the AMDGPU backend.
krzysz00 (Contributor) left a comment


Approved

(If it's possible, could you get the half-precision functions hooked up in a later PR?)


jsjodin (Contributor, Author) commented Sep 4, 2024

> Approved
>
> (If it's possible, could you get the half-precision functions hooked up in a later PR?)

Sure, that should be pretty easy to do.
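
As a rough sketch of what such a follow-up might produce (hypothetical; not part of this PR), math.exp on f16 would lower straight to the intrinsic instead of being widened for an OCML call, analogous to the f16 sqrt handling shown in the tests above:

```mlir
// Hypothetical f16 lowering in a later PR:
%r = llvm.intr.exp(%h) : (f16) -> f16
// rather than an fpext / __ocml_exp_f32 / fptrunc sequence.
```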

jsjodin merged commit 3ebd797 into llvm:main on Sep 4, 2024
8 checks passed
nirvedhmeshram added a commit that referenced this pull request on Sep 11, 2024:
LLVM::FAbsOp and LLVM::SqrtOp are legal after #102971

nirvedhmeshram added a commit to iree-org/llvm-project that referenced this pull request on Sep 11, 2024:
LLVM::FAbsOp and LLVM::SqrtOp are legal after llvm#102971

VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request on Sep 12, 2024

nirvedhmeshram added a commit that referenced this pull request on Sep 12, 2024:
…108302)
Similar to #108266. After #102971 it is legal to generate `LLVM::ExpOp` and `LLVM::LogOp` if the type is a float16 or float32.
jsjodin deleted the jleyonberg/math-rocdl-update branch on October 1, 2024