[AArch64] Use SVE fdot for partial.reduce.fadd for NEON types. #167856
sdesmalen-arm merged 1 commit into main
We only seem to use the SVE fdot for fixed-length vector types when they are larger than 128 bits, whereas we can also use it for 128-bit vectors if SVE2p1/SME2 is available.
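To illustrate why the two lowerings in the test are interchangeable, here is a rough C++ sketch of the lane arithmetic. It assumes that a partial reduction is free to choose which input lanes feed which result lane, and that fdot accumulates adjacent pairs of products into each 32-bit lane. The function names are hypothetical, `float` stands in for `half`, and the rounding behaviour of the real instructions is not modelled.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using V4 = std::array<float, 4>;
using V8 = std::array<float, 8>;

// Models the NEON sequence in the SVE2 check lines (fcvtl/fcvtl2, fmul, fadd):
// result lane i accumulates the products from input lanes i and i + 4.
V4 partialReduceLowHigh(V4 acc, const V8 &a, const V8 &b) {
  for (std::size_t i = 0; i < 4; ++i) {
    acc[i] += a[i] * b[i];
    acc[i] += a[i + 4] * b[i + 4];
  }
  return acc;
}

// Models the assumed fdot grouping: result lane i accumulates the products
// from the adjacent input lanes 2i and 2i + 1.
V4 partialReducePairwise(V4 acc, const V8 &a, const V8 &b) {
  for (std::size_t i = 0; i < 4; ++i)
    acc[i] += a[2 * i] * b[2 * i] + a[2 * i + 1] * b[2 * i + 1];
  return acc;
}

// Sum of all result lanes, i.e. the value a final full reduction would give.
float horizontalSum(const V4 &v) {
  float sum = 0.0f;
  for (float x : v)
    sum += x;
  return sum;
}

// Hypothetical helper: the individual lanes differ between the two groupings,
// but the total over all lanes is the same for exactly representable inputs.
bool groupingsAgree(const V4 &acc, const V8 &a, const V8 &b) {
  return horizontalSum(partialReduceLowHigh(acc, a, b)) ==
         horizontalSum(partialReducePairwise(acc, a, b));
}
```

With small integer inputs both groupings give the same total even though the individual lanes differ, which is the property a later full reduction relies on. For general floating-point inputs the results can differ by rounding, since the additions are reassociated.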
Pull Request Overview
This PR enables the use of SVE2p1/SME2 fdot instructions for 128-bit NEON vector types during partial reduction operations. Previously, fdot was only used for fixed-length vector types larger than 128 bits.
Key Changes:
- Added custom legalization to allow SVE2p1 `fdot` instruction usage for fixed-length v4f32/v8f16 vector operations
- Added test coverage for the new 128-bit vector optimization path
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | Added custom action for PARTIAL_REDUCE_FMLA with v4f32/v8f16 types when SVE2p1/SME2 is available |
| llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll | Added fdot_v4f32 test and updated fixed_fdot_wide test to verify SVE2p1 code generation for 128-bit vectors |
@llvm/pr-subscribers-backend-aarch64
Author: Sander de Smalen (sdesmalen-arm)
Changes
We only seem to use the SVE fdot for fixed-length vector types when they are larger than 128 bits, whereas we can also use it for 128-bit vectors if SVE2p1/SME2 is available.
Full diff: https://github.com/llvm/llvm-project/pull/167856.diff
2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8457f6178fdc2..e36396c7bdf2b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1921,6 +1921,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32,
MVT::nxv8f16, Legal);
+ // We can use SVE2p1 fdot to emulate the fixed-length variant.
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::v4f32,
+ MVT::v8f16, Custom);
}
}
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
index 89216ce2cb72b..864c66caf5f6c 100644
--- a/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
@@ -4,6 +4,43 @@
target triple = "aarch64-linux-gnu"
+define void @fdot_v4f32(ptr %accptr, ptr %aptr, ptr %bptr) {
+; SVE2-LABEL: fdot_v4f32:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: ldr q0, [x1]
+; SVE2-NEXT: ldr q1, [x2]
+; SVE2-NEXT: fcvtl v2.4s, v0.4h
+; SVE2-NEXT: fcvtl v3.4s, v1.4h
+; SVE2-NEXT: fcvtl2 v0.4s, v0.8h
+; SVE2-NEXT: fcvtl2 v1.4s, v1.8h
+; SVE2-NEXT: fmul v2.4s, v2.4s, v3.4s
+; SVE2-NEXT: ldr q3, [x0]
+; SVE2-NEXT: fmul v0.4s, v0.4s, v1.4s
+; SVE2-NEXT: fadd v1.4s, v3.4s, v2.4s
+; SVE2-NEXT: fadd v0.4s, v1.4s, v0.4s
+; SVE2-NEXT: str q0, [x0]
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fdot_v4f32:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: ldr q0, [x0]
+; SVE2P1-NEXT: ldr q1, [x1]
+; SVE2P1-NEXT: ldr q2, [x2]
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: str q0, [x0]
+; SVE2P1-NEXT: ret
+entry:
+ %acc = load <4 x float>, ptr %accptr
+ %a = load <8 x half>, ptr %aptr
+ %b = load <8 x half>, ptr %bptr
+ %a.wide = fpext <8 x half> %a to <8 x float>
+ %b.wide = fpext <8 x half> %b to <8 x float>
+ %mult = fmul <8 x float> %a.wide, %b.wide
+ %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+ store <4 x float> %partial.reduce, ptr %accptr
+ ret void
+}
+
define void @fdot_wide_v8f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,0) {
; SVE2-LABEL: fdot_wide_v8f32:
; SVE2: // %bb.0: // %entry
@@ -177,17 +214,26 @@ entry:
}
define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
-; CHECK-LABEL: fixed_fdot_wide:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: fcvtl v3.4s, v1.4h
-; CHECK-NEXT: fcvtl v4.4s, v2.4h
-; CHECK-NEXT: fcvtl2 v1.4s, v1.8h
-; CHECK-NEXT: fcvtl2 v2.4s, v2.8h
-; CHECK-NEXT: fmul v3.4s, v3.4s, v4.4s
-; CHECK-NEXT: fmul v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: fadd v0.4s, v0.4s, v3.4s
-; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT: ret
+; SVE2-LABEL: fixed_fdot_wide:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: fcvtl v3.4s, v1.4h
+; SVE2-NEXT: fcvtl v4.4s, v2.4h
+; SVE2-NEXT: fcvtl2 v1.4s, v1.8h
+; SVE2-NEXT: fcvtl2 v2.4s, v2.8h
+; SVE2-NEXT: fmul v3.4s, v3.4s, v4.4s
+; SVE2-NEXT: fmul v1.4s, v1.4s, v2.4s
+; SVE2-NEXT: fadd v0.4s, v0.4s, v3.4s
+; SVE2-NEXT: fadd v0.4s, v0.4s, v1.4s
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fixed_fdot_wide:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: // kill: def $q0 killed $q0 def $z0
+; SVE2P1-NEXT: // kill: def $q2 killed $q2 def $z2
+; SVE2P1-NEXT: // kill: def $q1 killed $q1 def $z1
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: // kill: def $q0 killed $q0 killed $z0
+; SVE2P1-NEXT: ret
entry:
%a.wide = fpext <8 x half> %a to <8 x float>
%b.wide = fpext <8 x half> %b to <8 x float>
LGTM