Conversation

kmclaughlin-arm
Contributor

@kmclaughlin-arm kmclaughlin-arm commented Sep 17, 2025

The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic, such as ptest and reinterpret cast, which
are currently not supported.

This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and to return the concatenated results of the paired whilelo in place of
get_active_lane_mask.
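For illustration, the counting change can be modelled in a few lines of Python. This is only a sketch: the opcode strings and the `count_extract_users` helper are stand-ins for the SelectionDAG node kinds and the `count_if` over `N->users()`, not a real API.

```python
# Hypothetical stand-ins for the SelectionDAG opcodes involved.
EXTRACT_SUBVECTOR = "extract_subvector"
PTEST_FIRST = "ptest_first"
REINTERPRET_CAST = "reinterpret_cast"

def count_extract_users(user_opcodes):
    """Count only the extract_subvector users of the mask; any ptest or
    reinterpret users introduced by the element-0 extract are ignored
    rather than blocking the combine."""
    return sum(1 for op in user_opcodes if op == EXTRACT_SUBVECTOR)

# A mask with two subvector extracts plus control-flow uses still counts as 2.
assert count_extract_users(
    [EXTRACT_SUBVECTOR, EXTRACT_SUBVECTOR, PTEST_FIRST, REINTERPRET_CAST]) == 2
```

Counting the total number of uses, as the old code did, would return 4 for the same node and reject the combine.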

@llvmbot
Member

llvmbot commented Sep 17, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Kerry McLaughlin (kmclaughlin-arm)

Changes

The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic, such as ptest and reinterpret cast, which
are currently not supported.

This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and allows other uses of the mask by these additional operations.


Full diff: https://github.com/llvm/llvm-project/pull/159360.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+21-8)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+13-7)
  • (modified) llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll (+75)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9a756da0078d..9c7ecf944e763 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18693,21 +18693,31 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
       (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
     return SDValue();
 
-  unsigned NumUses = N->use_size();
+  // Count the number of users which are extract_vectors
+  // The only other valid users for this combine are ptest_first
+  // and reinterpret_cast.
+  unsigned NumExts = count_if(N->users(), [](SDNode *Use) {
+    return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR;
+  });
+
   auto MaskEC = N->getValueType(0).getVectorElementCount();
-  if (!MaskEC.isKnownMultipleOf(NumUses))
+  if (!MaskEC.isKnownMultipleOf(NumExts))
     return SDValue();
 
-  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
+  ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
   if (ExtMinEC.getKnownMinValue() < 2)
     return SDValue();
 
-  SmallVector<SDNode *> Extracts(NumUses, nullptr);
+  SmallVector<SDNode *> Extracts(NumExts, nullptr);
   for (SDNode *Use : N->users()) {
+    if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
+        Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
+      continue;
+
     if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
       return SDValue();
 
-    // Ensure the extract type is correct (e.g. if NumUses is 4 and
+    // Ensure the extract type is correct (e.g. if NumExts is 4 and
     // the mask return type is nxv8i1, each extract should be nxv2i1.
     if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
       return SDValue();
@@ -18741,11 +18751,13 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   DCI.CombineTo(Extracts[0], R.getValue(0));
   DCI.CombineTo(Extracts[1], R.getValue(1));
 
-  if (NumUses == 2)
-    return SDValue(N, 0);
+  if (NumExts == 2) {
+    DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
+    return SDValue(SDValue(N, 0));
+  }
 
   auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
-  for (unsigned I = 2; I < NumUses; I += 2) {
+  for (unsigned I = 2; I < NumExts; I += 2) {
     // After the first whilelo_x2, we need to increment the starting value.
     Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
     R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
@@ -18753,6 +18765,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
     DCI.CombineTo(Extracts[I + 1], R.getValue(1));
   }
 
+  DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
   return SDValue(N, 0);
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bf3d47ac43607..069d08663fdea 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1495,13 +1495,19 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
     if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
       return PredOpcode;
 
-    // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
-    // redundant since WHILE performs an implicit PTEST with an all active
-    // mask.
-    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
-        getElementSizeForOpcode(MaskOpcode) ==
-            getElementSizeForOpcode(PredOpcode))
-      return PredOpcode;
+    if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
+      auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
+      if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
+          PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
+        return PredOpcode;
+
+      // For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
+      // redundant since WHILE performs an implicit PTEST with an all active
+      // mask.
+      if (getElementSizeForOpcode(MaskOpcode) ==
+          getElementSizeForOpcode(PredOpcode))
+        return PredOpcode;
+    }
 
     return {};
   }
diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
index 5e01612e3881a..3b18008605413 100644
--- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
+++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -310,6 +310,81 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #
   ret void
 }
 
+; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first.
+
+define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.b, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB11_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.h, p1.h }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB11_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB11_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n)
+    %v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
+    %v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
+    %elt0 = extractelement <vscale x 16 x i1> %r, i32 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 8 x i1> %v0, <vscale x 8 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
+; Extra use of the get_active_lane_mask from an extractelement, which is
+; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.
+
+define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
+; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE:       // %bb.0: // %entry
+; CHECK-SVE-NEXT:    whilelo p1.h, x0, x1
+; CHECK-SVE-NEXT:    b.pl .LBB12_2
+; CHECK-SVE-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE-NEXT:    punpklo p0.h, p1.b
+; CHECK-SVE-NEXT:    punpkhi p1.h, p1.b
+; CHECK-SVE-NEXT:    b use
+; CHECK-SVE-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE-NEXT:    ret
+;
+; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts:
+; CHECK-SVE2p1-SME2:       // %bb.0: // %entry
+; CHECK-SVE2p1-SME2-NEXT:    whilelo { p0.s, p1.s }, x0, x1
+; CHECK-SVE2p1-SME2-NEXT:    b.pl .LBB12_2
+; CHECK-SVE2p1-SME2-NEXT:  // %bb.1: // %if.then
+; CHECK-SVE2p1-SME2-NEXT:    b use
+; CHECK-SVE2p1-SME2-NEXT:  .LBB12_2: // %if.end
+; CHECK-SVE2p1-SME2-NEXT:    ret
+entry:
+    %r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
+    %v0 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 0)
+    %v1 = tail call <vscale x 4 x i1> @llvm.vector.extract.nxv4i1.nxv8i1(<vscale x 8 x i1> %r, i64 4)
+    %elt0 = extractelement <vscale x 8 x i1> %r, i64 0
+    br i1 %elt0, label %if.then, label %if.end
+
+if.then:
+    tail call void @use(<vscale x 4 x i1> %v0, <vscale x 4 x i1> %v1)
+    br label %if.end
+
+if.end:
+    ret void
+}
+
 declare void @use(...)
 
 attributes #0 = { nounwind }

DCI.CombineTo(Extracts[I + 1], R.getValue(1));
}

DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
Collaborator

Similar to the above, but in this case there's the extra problem that you're replacing N with the first result of the last instance of the emitted while_pair, which is even more wrong if one of the uses happened to be a PTEST_FIRST.

Comment on lines 18805 to 18807
if (Use->getOpcode() == AArch64ISD::PTEST_FIRST ||
Use->getOpcode() == AArch64ISD::REINTERPRET_CAST)
continue;
Collaborator

I don't think it's worth trying to special case based on the users beyond verifying the presence of the relevant ISD::EXTRACT_SUBVECTOR to prove the value of using the while_pair instructions. Even if the original while remains, the resulting code might be better because multiple extracts have been replaced by a single while_pair?

return SDValue();

ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts);
Collaborator

As discussed, the precommit failure might be due to NumExts being zero, which isKnownMultipleOf should probably reject.
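The zero-count hazard can be modelled directly. This is a sketch under the assumption that `isKnownMultipleOf` for a scalable element count reduces to a modulo check on the known-minimum value; the guard makes the zero case explicit instead of tripping the division.

```python
def is_known_multiple_of(known_min: int, factor: int) -> bool:
    """Model of ElementCount::isKnownMultipleOf for this discussion.
    A factor of zero (no extract_subvector users found) must be rejected
    up front; the modulo below would otherwise raise ZeroDivisionError,
    and the combine has nothing to do in that case anyway."""
    if factor == 0:
        return False
    return known_min % factor == 0

# nxv16i1 split across two extracts is fine; zero extracts is not.
assert is_known_multiple_of(16, 2)
assert not is_known_multiple_of(16, 0)
```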

if (NumUses == 2)
return SDValue(N, 0);
if (NumExts == 2) {
DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), R.getValue(0));
Collaborator

Not sure if I'm misunderstanding something, but I don't see how this works: SDValue(N, 0) and R.getValue(0) are going to have different result types, so the post-combine DAG is likely to be broken.

Comment on lines 18843 to 18840
DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));
Collaborator

I think this is the part that needs to change. Rather than thinking along the lines of the current combine, which is replacing the tree of get_active_lane_mask and the extracts from it, you should instead focus on replacing the get_active_lane_mask in isolation.

For the simple case of NumExts == 2 this means replacing:

DCI.CombineTo(Extracts[0], R.getValue(0));
DCI.CombineTo(Extracts[1], R.getValue(1));

if (NumUses == 2)
  return SDValue(N, 0);

with something like:

if (NumExts == 2)
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0),
                     R.getValue(0), R.getValue(1));

Then see what the generated code looks like. If the original get_active_lane_mask remains then I'd hope it's just a case of needing separate combines to look through the concat.

The NumUses > 2 case is likely to be more complicated because this is post legalisation, so I doubt you'll be able to emit a single concat but will instead need a tree of concats. Whilst more awkward to implement, I don't think this will change the follow-on work much.
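The tree of concats described above can be sketched in Python, using list concatenation as a stand-in for a two-way CONCAT_VECTORS node; the helper name and the power-of-two restriction are assumptions for the sketch, not taken from the patch.

```python
def concat_tree(parts):
    """Reduce a list of equal-length pieces to one value using only
    pairwise concatenations, mirroring a tree of two-way CONCAT_VECTORS."""
    assert parts and (len(parts) & (len(parts) - 1)) == 0, \
        "expected a power-of-two number of pieces"
    while len(parts) > 1:
        # Each round halves the number of pieces by joining adjacent pairs.
        parts = [parts[i] + parts[i + 1] for i in range(0, len(parts), 2)]
    return parts[0]

# Four whilelo results combine as ((p0 ++ p1) ++ (p2 ++ p3)).
assert concat_tree([[0], [1], [2], [3]]) == [0, 1, 2, 3]
```

The NumExts == 2 case then falls out as a tree of depth one.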

Contributor Author

As we discussed earlier, I've changed this PR to focus only on the changes to performActiveLaneMaskCombine as these are enough to use the paired whilelo in the test cases. The combine is now using ISD::CONCAT_VECTORS as described above.

Improving codegen of those tests (i.e. removing the ptest) will be addressed separately.

Comment on lines 1498 to 1510
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31) {
auto PTestOp = MRI->getUniqueVRegDef(PTest->getOperand(1).getReg());
if (PTest->getOpcode() == AArch64::PTEST_PP_FIRST && PTestOp->isCopy() &&
PTestOp->getOperand(1).getSubReg() == AArch64::psub0)
return PredOpcode;

// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
// redundant since WHILE performs an implicit PTEST with an all active
// mask.
if (getElementSizeForOpcode(MaskOpcode) ==
getElementSizeForOpcode(PredOpcode))
return PredOpcode;
}
Collaborator

As discussed, I'm kind of hoping this is just fallout from the DAG being broken and will not be necessary once fixed.

; Extra use of the get_active_lane_mask from an extractelement, which is
; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1.

define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) {
Collaborator

It's worth adding a similar test for the NumExts != 2 case, if only to see if that better exposes the issues I believe exist in the PR as it stands today.

…bine

The combine replaces a get_active_lane_mask used by two extract subvectors with
a single paired whilelo intrinsic. When the instruction is used for control
flow in a vector loop, an additional extract of element 0 may introduce
other uses of the intrinsic such as ptest and reinterpret cast, which
is currently not supported.

This patch changes performActiveLaneMaskCombine to count the number of
extract subvectors using the mask instead of the total number of uses,
and allows other uses by these additional operations.
…performActiveLaneMaskCombine

- Add tests for the 4 extracts case which will use ptest & reinterpret_cast
- Remove changes to canRemovePTestInstr
performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *ST) {
if (DCI.isBeforeLegalize())
if (DCI.isBeforeLegalize() && !!DCI.isBeforeLegalizeOps())
Contributor

nit: The !! looks a little odd. Is it possible to just use DCI.isBeforeLegalizeOps()?
