-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend #102971
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Jan Leyonberg (jsjodin) Changes: This patch removes patterns for a few operations, which allows the mathToLLVM conversion to convert the operations into LLVM intrinsics instead, since they are supported directly by the AMDGPU backend. Full diff: https://github.com/llvm/llvm-project/pull/102971.diff — 3 files affected:
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 7de6971ba2ee72..fd4eab0e10d67e 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
+ // Handled by mathToLLVM: math::AbsFIOp
// Handled by mathToLLVM: math::CopySignOp
// Handled by mathToLLVM: math::CountLeadingZerosOp
// Handled by mathToLLVM: math::CountTrailingZerosOp
// Handled by mathToLLVM: math::CgPopOp
+ // Handled by mathToLLVM: math::ExpOp
// Handled by mathToLLVM: math::FmaOp
+ // Handled by mathToLLVM: math::LogOp
// FIXME: math::IPowIOp
// FIXME: math::FPowIOp
// Handled by mathToLLVM: math::RoundEvenOp
// Handled by mathToLLVM: math::RoundOp
+ // Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
- populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
"__ocml_acos_f64");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
"__ocml_cosh_f64");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
"__ocml_sinh_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
- "__ocml_exp_f64");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
"__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
- "__ocml_log_f64");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
"__ocml_log10_f64");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +104,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
"__ocml_rsqrt_f64");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
"__ocml_sin_f64");
- populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
- "__ocml_sqrt_f64");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
"__ocml_tanh_f64");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index bf49a42a115775..4f1f26e8794d9e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -131,21 +131,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_fabs
- func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.absf %arg_f32 : f32
- // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.absf %arg_f64 : f64
- // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -206,23 +191,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %exp_f32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.exp %exp_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -239,21 +207,20 @@ gpu.module @test_module {
}
// -----
-
// Test that we handled properly operation with SymbolTable other than module op
gpu.module @test_module {
"test.symbol_scope"() ({
// CHECK: test.symbol_scope
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %exp_f32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.exp %exp_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+ // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+ // CHECK-LABEL: func @gpu_sin
+ func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %sin_f32 = math.sin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result32 = math.sin %sin_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
func.return %result32, %result64 : f32, f64
}
"test.finish" () : () -> ()
@@ -279,21 +246,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_log
- func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.log %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.log %arg_f64 : f64
- // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
@@ -359,26 +311,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_sqrt
- func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
- -> (f16, f32, f64) {
- %result16 = math.sqrt %arg_f16 : f16
- // CHECK: llvm.fpext %{{.*}} : f16 to f32
- // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
- %result32 = math.sqrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sqrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +404,15 @@ gpu.module @test_module {
gpu.module @test_module {
// CHECK-LABEL: func @gpu_unroll
func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
- %result = math.exp %arg0 : vector<4xf32>
+ %result = math.sin %arg0 : vector<4xf32>
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
// CHECK: return %[[V4]]
func.return %result : vector<4xf32>
@@ -526,9 +458,9 @@ gpu.module @test_module {
gpu.module @module {
// CHECK-LABEL: @spirv_exp
-// CHECK: llvm.call @__ocml_exp_f32
+// CHECK: llvm.call @__ocml_sin_f32
spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
- %0 = math.exp %arg0 : vector<4xf32>
+ %0 = math.sin %arg0 : vector<4xf32>
spirv.ReturnValue %0 : vector<4xf32>
}
}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index a406ec45a7f109..9a05a94f9f1ac7 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -15,21 +15,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
- // CHECK-LABEL: func @math_absf
- func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.absf %arg_f32 : f32
- // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.absf %arg_f64 : f64
- // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -210,21 +195,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @math_exp
- func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -270,21 +240,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
- // CHECK-LABEL: func @math_log
- func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.log %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.log %arg_f64 : f64
- // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
@@ -360,21 +315,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
- // CHECK-LABEL: func @math_sqrt
- func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.sqrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sqrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine, but I want to double-check that nothing'll go wrong with double-precision exp
and log
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                 "__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                  "__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                  "__ocml_floor_f64");
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same note re f64 log
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                 "__ocml_cosh_f64");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
                                 "__ocml_sinh_f64");
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having disassembled OCML, the double-precision exp isn't actually a direct wrapper around the relevant intrinsic, but I figure that's probably fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. We only directly handle the f32 (and I think f16) versions. The f64 versions of the hard operations do not work. We do directly handle llvm.sqrt.f64 as an exception.
Also, none of the trig functions are directly handled (correctly). We do codegen the f32 versions but probably shouldn't
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So what do we do in the case of f32 being handled but f64 not, should we still just call ocml for both or modify the lowering to handle just one of them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modify the lowering to handle just one. The operation + type should be treated like different operations, so emit the working f32 intrinsics and the calls for the nonworking f64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I put the lowering for f64 back for log and exp and added the tests back for them as well.
…AMDGPU backend This patch removes patterns for a few operations which allows mathToLLVM conversion to convert the operations into LLVM intrinsics instead since they are supported directly by the AMDGPU backend.
8fe4b0e
to
58f0fc6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved
(If it's possible, could you get the half-precision functions hooked up in a later PR?)
Sure, that should be pretty easy to do.
LLVM::FAbsOp and LLVM::SqrtOp are legal after #102971
LLVM::FAbsOp and LLVM::SqrtOp are legal after llvm#102971
LLVM::FAbsOp and LLVM::SqrtOp are legal after llvm#102971
This patch removes patterns for a few operations which allows mathToLLVM conversion to convert the operations into LLVM intrinsics instead since they are supported directly by the AMDGPU backend.