diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index fc3efb072d57b..7a75fcfee5f7b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -18774,7 +18774,7 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N, static SDValue performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *ST) { - if (DCI.isBeforeLegalize()) + if (DCI.isBeforeLegalize() && !DCI.isBeforeLegalizeOps()) /* NOTE(review): this guard looks unreachable if isBeforeLegalize() implies isBeforeLegalizeOps(), as in TargetLowering::DAGCombinerInfo - confirm the intended DAG-combine phase and regenerate the test checks if it changes. */ return SDValue(); if (SDValue While = optimizeIncrementingWhile(N, DCI.DAG, /*IsSigned=*/false, @@ -18785,21 +18785,27 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming()))) return SDValue(); - unsigned NumUses = N->use_size(); + // Count only the users that are ISD::EXTRACT_SUBVECTOR nodes. + // The only other valid users for this combine are ptest_first + // and reinterpret_cast. + unsigned NumExts = count_if(N->users(), [](SDNode *Use) { + return Use->getOpcode() == ISD::EXTRACT_SUBVECTOR; + }); + auto MaskEC = N->getValueType(0).getVectorElementCount(); - if (!MaskEC.isKnownMultipleOf(NumUses)) + if (NumExts == 0 || !MaskEC.isKnownMultipleOf(NumExts)) return SDValue(); - ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses); + ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumExts); if (ExtMinEC.getKnownMinValue() < 2) return SDValue(); - SmallVector Extracts(NumUses, nullptr); + SmallVector Extracts(NumExts, nullptr); for (SDNode *Use : N->users()) { if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR) - return SDValue(); + continue; - // Ensure the extract type is correct (e.g. if NumUses is 4 and + // Ensure the extract type is correct (e.g. if NumExts is 4 and // the mask return type is nxv8i1, each extract should be nxv2i1. 
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC) return SDValue(); @@ -18832,20 +18838,23 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); DCI.CombineTo(Extracts[0], R.getValue(0)); DCI.CombineTo(Extracts[1], R.getValue(1)); + SmallVector Results = {R.getValue(0), R.getValue(1)}; - if (NumUses == 2) - return SDValue(N, 0); + if (NumExts == 2) + return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results); auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2); - for (unsigned I = 2; I < NumUses; I += 2) { + for (unsigned I = 2; I < NumExts; I += 2) { // After the first whilelo_x2, we need to increment the starting value. Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts); R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC}); DCI.CombineTo(Extracts[I], R.getValue(0)); DCI.CombineTo(Extracts[I + 1], R.getValue(1)); + Results.push_back(R.getValue(0)); + Results.push_back(R.getValue(1)); } - return SDValue(N, 0); + return DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0), Results); } // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce diff --git a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll index 5e01612e3881a..0531fdf5c35cf 100644 --- a/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll +++ b/llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll @@ -294,12 +294,13 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) # ; CHECK-SVE2p1-SME2-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count: ; CHECK-SVE2p1-SME2: // %bb.0: ; CHECK-SVE2p1-SME2-NEXT: rdvl x8, #2 -; CHECK-SVE2p1-SME2-NEXT: mov w9, w1 -; CHECK-SVE2p1-SME2-NEXT: mov w10, w0 -; CHECK-SVE2p1-SME2-NEXT: adds w8, w0, w8 -; CHECK-SVE2p1-SME2-NEXT: csinv w8, w8, wzr, lo -; CHECK-SVE2p1-SME2-NEXT: 
whilelo { p0.b, p1.b }, x10, x9 -; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.b, p3.b }, x8, x9 +; CHECK-SVE2p1-SME2-NEXT: mov w9, w0 +; CHECK-SVE2p1-SME2-NEXT: mov w10, w1 +; CHECK-SVE2p1-SME2-NEXT: mrs x8, NZCV +; CHECK-SVE2p1-SME2-NEXT: adds x8, x9, x8 +; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.b, p1.b }, x9, x10 +; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.b, p3.b }, x8, x10 ; CHECK-SVE2p1-SME2-NEXT: b use %r = call @llvm.get.active.lane.mask.nxv64i1.i32(i32 %i, i32 %n) %v0 = call @llvm.vector.extract.nxv16i1.nxv64i1.i64( %r, i64 0) @@ -310,6 +311,187 @@ define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) # ret void } +; Extra use of the get_active_lane_mask from an extractelement, which is replaced with ptest_first. + +define void @test_2x8bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) { +; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_ptest: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: whilelo p1.b, x0, x1 +; CHECK-SVE-NEXT: b.pl .LBB11_2 +; CHECK-SVE-NEXT: // %bb.1: // %if.then +; CHECK-SVE-NEXT: punpklo p0.h, p1.b +; CHECK-SVE-NEXT: punpkhi p1.h, p1.b +; CHECK-SVE-NEXT: b use +; CHECK-SVE-NEXT: .LBB11_2: // %if.end +; CHECK-SVE-NEXT: ret +; +; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_ptest: +; CHECK-SVE2p1-SME2: // %bb.0: // %entry +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1 +; CHECK-SVE2p1-SME2-NEXT: uzp1 p2.b, p0.b, p1.b +; CHECK-SVE2p1-SME2-NEXT: mov z0.b, p2/z, #1 // =0x1 +; CHECK-SVE2p1-SME2-NEXT: fmov w8, s0 +; CHECK-SVE2p1-SME2-NEXT: tbz w8, #0, .LBB11_2 +; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then +; CHECK-SVE2p1-SME2-NEXT: b use +; CHECK-SVE2p1-SME2-NEXT: .LBB11_2: // %if.end +; CHECK-SVE2p1-SME2-NEXT: ret +entry: + %r = call @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n) + %v0 = call @llvm.vector.extract.nxv8i1.nxv16i1.i64( %r, i64 0) + %v1 = call @llvm.vector.extract.nxv8i1.nxv16i1.i64( %r, i64 8) + %elt0 = extractelement %r, 
i32 0 + br i1 %elt0, label %if.then, label %if.end + +if.then: + tail call void @use( %v0, %v1) + br label %if.end + +if.end: + ret void +} + +; Extra use of the get_active_lane_mask from an extractelement, which is +; replaced with ptest_first and reinterpret_casts because the extract is not nxv16i1. + +define void @test_2x8bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) { +; CHECK-SVE-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: whilelo p1.h, x0, x1 +; CHECK-SVE-NEXT: b.pl .LBB12_2 +; CHECK-SVE-NEXT: // %bb.1: // %if.then +; CHECK-SVE-NEXT: punpklo p0.h, p1.b +; CHECK-SVE-NEXT: punpkhi p1.h, p1.b +; CHECK-SVE-NEXT: b use +; CHECK-SVE-NEXT: .LBB12_2: // %if.end +; CHECK-SVE-NEXT: ret +; +; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_extracts_and_reinterpret_casts: +; CHECK-SVE2p1-SME2: // %bb.0: // %entry +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1 +; CHECK-SVE2p1-SME2-NEXT: uzp1 p2.h, p0.h, p1.h +; CHECK-SVE2p1-SME2-NEXT: mov z0.h, p2/z, #1 // =0x1 +; CHECK-SVE2p1-SME2-NEXT: fmov w8, s0 +; CHECK-SVE2p1-SME2-NEXT: tbz w8, #0, .LBB12_2 +; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then +; CHECK-SVE2p1-SME2-NEXT: b use +; CHECK-SVE2p1-SME2-NEXT: .LBB12_2: // %if.end +; CHECK-SVE2p1-SME2-NEXT: ret +entry: + %r = call @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n) + %v0 = tail call @llvm.vector.extract.nxv4i1.nxv8i1( %r, i64 0) + %v1 = tail call @llvm.vector.extract.nxv4i1.nxv8i1( %r, i64 4) + %elt0 = extractelement %r, i64 0 + br i1 %elt0, label %if.then, label %if.end + +if.then: + tail call void @use( %v0, %v1) + br label %if.end + +if.end: + ret void +} + +define void @test_4x4bit_mask_with_extracts_and_ptest(i64 %i, i64 %n) { +; CHECK-SVE-LABEL: test_4x4bit_mask_with_extracts_and_ptest: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: whilelo p0.b, x0, x1 +; CHECK-SVE-NEXT: b.pl .LBB13_2 +; CHECK-SVE-NEXT: // %bb.1: // %if.then +; CHECK-SVE-NEXT: punpklo 
p1.h, p0.b +; CHECK-SVE-NEXT: punpkhi p3.h, p0.b +; CHECK-SVE-NEXT: punpklo p0.h, p1.b +; CHECK-SVE-NEXT: punpkhi p1.h, p1.b +; CHECK-SVE-NEXT: punpklo p2.h, p3.b +; CHECK-SVE-NEXT: punpkhi p3.h, p3.b +; CHECK-SVE-NEXT: b use +; CHECK-SVE-NEXT: .LBB13_2: // %if.end +; CHECK-SVE-NEXT: ret +; +; CHECK-SVE2p1-SME2-LABEL: test_4x4bit_mask_with_extracts_and_ptest: +; CHECK-SVE2p1-SME2: // %bb.0: // %entry +; CHECK-SVE2p1-SME2-NEXT: cnth x8 +; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8 +; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.s, p1.s }, x0, x1 +; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.s, p3.s }, x8, x1 +; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.h, p0.h, p1.h +; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p2.h, p3.h +; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.b, p5.b, p4.b +; CHECK-SVE2p1-SME2-NEXT: mov z0.b, p4/z, #1 // =0x1 +; CHECK-SVE2p1-SME2-NEXT: fmov w8, s0 +; CHECK-SVE2p1-SME2-NEXT: tbz w8, #0, .LBB13_2 +; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then +; CHECK-SVE2p1-SME2-NEXT: b use +; CHECK-SVE2p1-SME2-NEXT: .LBB13_2: // %if.end +; CHECK-SVE2p1-SME2-NEXT: ret +entry: + %r = call @llvm.get.active.lane.mask.nxv16i1.i32(i64 %i, i64 %n) + %v0 = call @llvm.vector.extract.nxv4i1.nxv16i1.i64( %r, i64 0) + %v1 = call @llvm.vector.extract.nxv4i1.nxv16i1.i64( %r, i64 4) + %v2 = call @llvm.vector.extract.nxv4i1.nxv16i1.i64( %r, i64 8) + %v3 = call @llvm.vector.extract.nxv4i1.nxv16i1.i64( %r, i64 12) + %elt0 = extractelement %r, i32 0 + br i1 %elt0, label %if.then, label %if.end + +if.then: + tail call void @use( %v0, %v1, %v2, %v3) + br label %if.end + +if.end: + ret void +} + +define void @test_4x2bit_mask_with_extracts_and_reinterpret_casts(i64 %i, i64 %n) { +; CHECK-SVE-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts: +; CHECK-SVE: // %bb.0: // %entry +; CHECK-SVE-NEXT: whilelo p0.h, x0, x1 +; CHECK-SVE-NEXT: b.pl .LBB14_2 +; CHECK-SVE-NEXT: // %bb.1: // %if.then +; CHECK-SVE-NEXT: punpklo p1.h, p0.b +; CHECK-SVE-NEXT: punpkhi p3.h, p0.b +; 
CHECK-SVE-NEXT: punpklo p0.h, p1.b +; CHECK-SVE-NEXT: punpkhi p1.h, p1.b +; CHECK-SVE-NEXT: punpklo p2.h, p3.b +; CHECK-SVE-NEXT: punpkhi p3.h, p3.b +; CHECK-SVE-NEXT: b use +; CHECK-SVE-NEXT: .LBB14_2: // %if.end +; CHECK-SVE-NEXT: ret +; +; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_mask_with_extracts_and_reinterpret_casts: +; CHECK-SVE2p1-SME2: // %bb.0: // %entry +; CHECK-SVE2p1-SME2-NEXT: cntw x8 +; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8 +; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo +; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1 +; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1 +; CHECK-SVE2p1-SME2-NEXT: uzp1 p5.s, p0.s, p1.s +; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.s, p2.s, p3.s +; CHECK-SVE2p1-SME2-NEXT: uzp1 p4.h, p5.h, p4.h +; CHECK-SVE2p1-SME2-NEXT: mov z0.h, p4/z, #1 // =0x1 +; CHECK-SVE2p1-SME2-NEXT: fmov w8, s0 +; CHECK-SVE2p1-SME2-NEXT: tbz w8, #0, .LBB14_2 +; CHECK-SVE2p1-SME2-NEXT: // %bb.1: // %if.then +; CHECK-SVE2p1-SME2-NEXT: b use +; CHECK-SVE2p1-SME2-NEXT: .LBB14_2: // %if.end +; CHECK-SVE2p1-SME2-NEXT: ret +entry: + %r = call @llvm.get.active.lane.mask.nxv8i1.i32(i64 %i, i64 %n) + %v0 = call @llvm.vector.extract.nxv2i1.nxv8i1.i64( %r, i64 0) + %v1 = call @llvm.vector.extract.nxv2i1.nxv8i1.i64( %r, i64 2) + %v2 = call @llvm.vector.extract.nxv2i1.nxv8i1.i64( %r, i64 4) + %v3 = call @llvm.vector.extract.nxv2i1.nxv8i1.i64( %r, i64 6) + %elt0 = extractelement %r, i32 0 + br i1 %elt0, label %if.then, label %if.end + +if.then: + tail call void @use( %v0, %v1, %v2, %v3) + br label %if.end + +if.end: + ret void +} + declare void @use(...) attributes #0 = { nounwind }