[AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine #159360
Conversation
@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

The combine replaces a get_active_lane_mask used by two extract subvectors with a single paired whilelo intrinsic. When the instruction is used for control flow in a vector loop, an additional extract of element 0 may introduce other uses of the intrinsic such as ptest and reinterpret cast, which is currently not supported.

This patch changes performActiveLaneMaskCombine to count the number of extract subvectors using the mask instead of the total number of uses, and allows other uses by these additional operations.

Full diff: https://github.com/llvm/llvm-project/pull/159360.diff

3 Files Affected:
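For context before the diff itself: the "paired whilelo intrinsic" is llvm.aarch64.sve.whilelo.x2, emitted as a two-result DAG node. A minimal sketch of that emission, simplified from performActiveLaneMaskCombine (ID, Idx, TC and ExtVT are the names used in the patch; the exact construction of ID may differ):

    // One whilelo_x2 yields both predicate halves at once, replacing two
    // extract_subvectors of a wider get_active_lane_mask.
    SDValue ID =
        DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
    SDValue R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT},
                            {ID, Idx, TC});
    // R.getValue(0) is the low half of the mask, R.getValue(1) the high half.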
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9a756da0078d..9c7ecf944e763 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18693,21 +18693,31 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
return SDValue();
- unsigned NumUses = N->use_size();
+ // Count the number of users which are extract_vectors
+ // The only other valid users for this combine are ptest_first
+ // and reinterpret_cast.
+ unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
+ return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
+ });
+
auto MaskEC = N->getValueType(0).getVectorElementCount();
- if (!MaskEC.isKnownMultipleOf(NumUses))
+ if (!MaskEC.isKnownMultipleOf(NumExts))
return SDValue();
- ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+ ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
if (ExtMinEC.getKnownMinValue() < 2)
return SDValue();
- SmallVector<SDNode *> Extracts(NumUses, nullptr);
+ SmallVector<SDNode *> Extracts(NumExts, nullptr);
for (SDNode *Use : N->users()) {
+ if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
+ Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ continue;
+
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
return SDValue();
- // Ensure the extract type is correct (e.g. if NumUses is 4 and
+ // Ensure the extract type is correct (e.g. if NumExts is 4 and
// the mask return type is nxv8i1, each extract should be nxv2i1.
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
return SDValue();
@@ -18741,11 +18751,13 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
- if (NumUses == 2)
- return SDValue(N, 0);
+ if (NumExts == 2) {
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
+ return SDValue(SDValue(N, 0));
+ }
auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
- for (unsigned I = 2; I < NumUses; I += 2) {
+ for (unsigned I = 2; I < NumExts; I += 2) {
// After the first whilelo_x2, we need to increment the starting value.
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
@@ -18753,6 +18765,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
}
+ DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
return SDValue(N, 0);
}
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bf3d47ac43607..069d08663fdea 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1495,13 +1495,19 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
return PredOpcode;
- // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
- // redundant since WHILE performs an implicit PTEST with an all active
- // mask.
- if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
- getElementSizeForOpcode(MaskOpcode) ==
- getElementSizeForOpcode(PredOpcode))
- return PredOpcode;
+ if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
+ auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
+ if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
+ PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
+ return PredOpcode;
+
+ // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+ // redundant since WHILE performs an implicit PTEST with an all active
+ // mask.
+ if (getElementSizeForOpcode(MaskOpcode) ==
+ getElementSizeForOpcode(PredOpcode))
+ return PredOpcode;
+ }
return {};
}
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5e01612e3881a..3b18008605413 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -310,6 +310,81 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
ret void
}
+; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
+
+define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB11_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB11_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB11_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
+ %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+ %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+ %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: whilelo p1.h, x0, x1
+; CHECK-SVE-NEXT: b.pl .LBB12_2
+; CHECK-SVE-NEXT: // %bb.1: // %if.then
+; CHECK-SVE-NEXT: punpklo p0.h, p1.b
+; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT: b use
+; CHECK-SVE-NEXT: .LBB12_2: // %if.end
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2: // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT: b.pl .LBB12_2
+; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT: b use
+; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT: ret
+entry:
+ %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+ %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
+ %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
+ %elt0 = extractelement <vscale x 8 x i1> %r, i64 0
+ br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+ tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
+ br label %if.end
+
+if.end:
+ ret void
+}
+
declare void @use(...)
attributes #0 = { nounwind }
Force-pushed ea8a053 to 02a75b7.
     DCI.CombineTo(Extracts[I + 1], R.getValue(1));
   }

+  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
Similar to the above, but in this case there's the extra problem that you're replacing N with the first result of the last instance of the emitted while_pair, which is even more wrong if the uses happened to be a PTEST_FIRST.
+    if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
+        Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+      continue;
I don't think it's worth trying to special case based on the users beyond verifying the presence of the relevant ISD::EXTRACT_SUBVECTOR to prove the value of using the while_pair instructions. Even if the original while remains, the resulting code might be better because multiple extracts have been replaced by a single while_pair?
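A minimal sketch of the weaker check being suggested here — count the extract_subvector users and tolerate any other users of the mask, rather than whitelisting them (illustrative only, not the committed code):

    // Only require that paired extract_subvectors exist; other users of the
    // mask are left in place instead of rejecting the combine outright.
    unsigned NumExts = count_if(N->users(), [](SDNode *U) {
      return U->getOpcode() == ISD::EXTRACT_SUBVECTOR;
    });
    if (NumExts < 2)
      return SDValue();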
     return SDValue();

-  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
As discussed, the precommit failure might be due to NumExts being zero, which isKnownMultipleOf should probably reject.
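A sketch of the guard this implies, assuming ElementCount::isKnownMultipleOf does not itself handle a zero divisor:

    // Bail out when no extract_subvector users were found; dividing (or
    // taking a remainder) by NumExts == 0 below would be invalid.
    if (NumExts == 0 || !MaskEC.isKnownMultipleOf(NumExts))
      return SDValue();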
-  if (NumUses == 2)
-    return SDValue(N, 0);
+  if (NumExts == 2) {
+    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
Not sure if I'm misunderstanding something, but I don't know how this works because SDValue(N, 0) and R.getValue(0) are going to have different result types? So the post-combine DAG is likely to be broken.
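To make the objection concrete, a sketch of the mismatch in the nxv16i1 test case (types are illustrative):

    EVT MaskVT = N->getValueType(0);           // nxv16i1: the full mask
    EVT HalfVT = R.getValue(0).getValueType(); // nxv8i1: one whilelo_x2 half
    // MaskVT != HalfVT, so replacing all uses of SDValue(N, 0) with
    // R.getValue(0) leaves ill-typed uses (e.g. the ptest) in the DAG.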
   DCI.CombineTo(Extracts[0], R.getValue(0));
   DCI.CombineTo(Extracts[1], R.getValue(1));
I think this is the part that needs to change. Rather than thinking along the lines of the current combine, which is replacing the tree of get_active_lane_mask and the extracts from it, you should instead focus on replacing the get_active_lane_mask in isolation.

For the simple case of NumExts == 2 this means replacing:

    DCI.CombineTo(Extracts[0], R.getValue(0));
    DCI.CombineTo(Extracts[1], R.getValue(1));
    if (NumUses == 2)
      return SDValue(N, 0);

with something like:

    if (NumExts == 2)
      return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0),
                         R.getValue(0), R.getValue(1));

Then see what the generated code looks like. If the original get_active_lane_mask remains then I'd hope it's just a case of needing separate combines to look through the concat.

The NumExts > 2 case is likely to be more complicated because this is post-legalisation, so I doubt you'll be able to emit a single concat but will instead need a tree of concats. Whilst more awkward to implement, I don't think this will change the follow-on work much.
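One way that concat tree might look — a sketch under the assumptions that NumExts is a power of two and Parts holds the whilelo_x2 results in lane order; not the committed code:

    while (Parts.size() > 1) {
      SmallVector<SDValue> Next;
      for (unsigned I = 0; I < Parts.size(); I += 2) {
        // Each level concatenates adjacent halves into a predicate of twice
        // the element count, ending at N's original result type.
        EVT PairVT = Parts[I].getValueType().getDoubleNumVectorElementsVT(
            *DAG.getContext());
        Next.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, PairVT, Parts[I],
                                   Parts[I + 1]));
      }
      Parts = std::move(Next);
    }
    return Parts[0]; // Same type as N's result, so safe to return.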
As we discussed earlier, I've changed this PR to focus only on the changes to performActiveLaneMaskCombine as these are enough to use the paired whilelo in the test cases. The combine is now using ISD::CONCAT_VECTORS as described above. Improving codegen of those tests (i.e. removing the ptest) will be addressed separately.
+  if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
+    auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
+    if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
+        PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
+      return PredOpcode;
+
+    // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+    // redundant since WHILE performs an implicit PTEST with an all active
+    // mask.
+    if (getElementSizeForOpcode(MaskOpcode) ==
+        getElementSizeForOpcode(PredOpcode))
+      return PredOpcode;
+  }
As discussed, I'm kind of hoping this is just fallout from the DAG being broken and will not be necessary once fixed.
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
It's worth adding a similar test for the NumExts != 2 case, if only to see if that better exposes the issues I believe exist in the PR as it stands today.
[AArch64][SVE2p1] Allow more uses of mask in performActiveLaneMaskCombine

The combine replaces a get_active_lane_mask used by two extract subvectors with a single paired whilelo intrinsic. When the instruction is used for control flow in a vector loop, an additional extract of element 0 may introduce other uses of the intrinsic such as ptest and reinterpret cast, which is currently not supported. This patch changes performActiveLaneMaskCombine to count the number of extract subvectors using the mask instead of the total number of uses, and allows other uses by these additional operations.
…performActiveLaneMaskCombine

- Add tests for the 4 extracts case which will use ptest & reinterpret_cast
- Remove changes to canRemovePTestInstr
Force-pushed 02a75b7 to 70adaf7.
 performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                              const AArch64Subtarget *ST) {
-  if (DCI.isBeforeLegalize())
+  if (DCI.isBeforeLegalize() && !!DCI.isBeforeLegalizeOps())
nit: The !! looks a little odd. Is it possible to just use DCI.isBeforeLegalizeOps()?
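For reference, a sketch of the suggested form (assuming the intent is simply to bail out before operation legalization):

    if (DCI.isBeforeLegalizeOps())
      return SDValue();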
The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic such as ptest and reinterpret cast, which
is currently not supported.
This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and returns the concatenated results of get_active_lane_mask.