Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Fold Ext(i1) Pred shr(A, BW - 1) => i1 Pred A s< 0 #68244

Merged
merged 2 commits into from
Oct 13, 2023

Conversation

XChy
Copy link
Member

@XChy XChy commented Oct 4, 2023

Resolves #67916 .
This patch folds Ext(icmp (A, xxx)) Pred shr(A, BW - 1) into i1 Pred A s< 0.
Alive2.

@llvmbot
Copy link

llvmbot commented Oct 4, 2023

@llvm/pr-subscribers-llvm-transforms

Changes

Resolves #67916 .
This patch extends foldICmpEquality to fold zext(icmp (A, xxx)) == shr(A, BW - 1) into not(trunc(xor(zext(icmp), shl))).
Here I think xor would be better for i1 type than eq.
Alive2.


Full diff: https://github.com/llvm/llvm-project/pull/68244.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+29-14)
  • (modified) llvm/test/Transforms/InstCombine/icmp-shr.ll (+3-4)
  • (modified) llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll (+80)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 9f034aba874a8c4..a0a45c73695f5c9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5311,11 +5311,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
       return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
   }
 
-  // Test if 2 values have different or same signbits:
-  // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
-  // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
-  // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0
-  // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1
+  // Signbit test
   Instruction *ExtI;
   if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) &&
       (Op0->hasOneUse() || Op1->hasOneUse())) {
@@ -5325,17 +5321,36 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
     ICmpInst::Predicate Pred2;
     if (match(Op0, m_CombineAnd(m_Instruction(ShiftI),
                                 m_Shr(m_Value(X),
-                                      m_SpecificIntAllowUndef(OpWidth - 1)))) &&
-        match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) &&
-        Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) {
+                                      m_SpecificIntAllowUndef(OpWidth - 1))))) {
+      // Test if 2 values have different or same signbits:
+      // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
+      // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
+      // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0
+      // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1
       unsigned ExtOpc = ExtI->getOpcode();
       unsigned ShiftOpc = ShiftI->getOpcode();
-      if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) ||
-          (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) {
-        Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
-        Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor)
-                                               : Builder.CreateIsNotNeg(Xor);
-        return replaceInstUsesWith(I, R);
+
+      if (match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) &&
+          Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) {
+        if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) ||
+            (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) {
+          Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
+          Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor)
+                                                 : Builder.CreateIsNotNeg(Xor);
+          return replaceInstUsesWith(I, R);
+        }
+      }
+
+      // Transform (X < 0 ==/!= icmp(X)) into (not) xor(X < 0, icmp(X))
+      if (match(A, m_c_ICmp(Pred2, m_Value(X), m_Value())) &&
+          ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) {
+
+        Value *Xor = Builder.CreateXor(Op0, Op1, "xor.ne");
+        Value *Trunc = Builder.CreateSExtOrTrunc(Xor, A->getType(), "eq.trunc");
+        Value *Not = (Pred == ICmpInst::ICMP_EQ)
+                         ? Builder.CreateNot(Trunc, "eq.not")
+                         : Trunc;
+        return replaceInstUsesWith(I, Not);
       }
     }
   }
diff --git a/llvm/test/Transforms/InstCombine/icmp-shr.ll b/llvm/test/Transforms/InstCombine/icmp-shr.ll
index f4dfa2edfa17710..b0ecd5ad6a01b2f 100644
--- a/llvm/test/Transforms/InstCombine/icmp-shr.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-shr.ll
@@ -1397,11 +1397,10 @@ define <2 x i1> @same_signbit_poison_elts(<2 x i8> %x, <2 x i8> %y) {
 
 define i1 @same_signbit_wrong_type(i8 %x, i32 %y) {
 ; CHECK-LABEL: @same_signbit_wrong_type(
-; CHECK-NEXT:    [[XSIGN:%.*]] = lshr i8 [[X:%.*]], 7
 ; CHECK-NEXT:    [[YPOS:%.*]] = icmp sgt i32 [[Y:%.*]], -1
-; CHECK-NEXT:    [[YPOSZ:%.*]] = zext i1 [[YPOS]] to i8
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[XSIGN]], [[YPOSZ]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i1 [[TMP1]], [[YPOS]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %xsign = lshr i8 %x, 7
   %ypos = icmp sgt i32 %y, -1
diff --git a/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll b/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll
index 29a18ebbdd94e16..f4286023779a5a7 100644
--- a/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-xor-signbit.ll
@@ -217,3 +217,83 @@ define <2 x i1> @negative_simplify_splat(<4 x i8> %x) {
   ret <2 x i1> %c
 }
 
+
+define i1 @slt_zero_eq_ne_0(i32 %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[A:%.*]], 1
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 31
+  %cmp2 = icmp eq i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define i1 @slt_zero_ne_ne_0(i32 %a) {
+; CHECK-LABEL: @slt_zero_ne_ne_0(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt i32 [[A:%.*]], 0
+; CHECK-NEXT:    ret i1 [[TMP1]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 31
+  %cmp2 = icmp ne i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define <4 x i1> @slt_zero_eq_ne_0_vec(<4 x i32> %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0_vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <4 x i32> [[A:%.*]], <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    ret <4 x i1> [[TMP1]]
+;
+  %cmp = icmp ne <4 x i32> %a, zeroinitializer
+  %conv = zext <4 x i1> %cmp to <4 x i32>
+  %cmp1 = lshr <4 x i32> %a, <i32 31, i32 31, i32 31, i32 31>
+  %cmp2 = icmp eq <4 x i32> %conv, %cmp1
+  ret <4 x i1> %cmp2
+}
+
+define i1 @slt_zero_ne_ne_b(i32 %a, i32 %b) {
+; CHECK-LABEL: @slt_zero_ne_ne_b(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i32 [[A]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i1 [[TMP1]], [[CMP]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
+;
+  %cmp = icmp ne i32 %a, %b
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 31
+  %cmp2 = icmp ne i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define i1 @slt_zero_eq_ne_0_fail1(i32 %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0_fail1(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A:%.*]], 0
+; CHECK-NEXT:    [[CONV:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP1:%.*]] = ashr i32 [[A]], 31
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[CMP1]], [[CONV]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = ashr i32 %a, 31
+  %cmp2 = icmp eq i32 %conv, %cmp1
+  ret i1 %cmp2
+}
+
+define i1 @slt_zero_eq_ne_0_fail2(i32 %a) {
+; CHECK-LABEL: @slt_zero_eq_ne_0_fail2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[A:%.*]], 0
+; CHECK-NEXT:    [[CONV:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP1:%.*]] = lshr i32 [[A]], 30
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[CMP1]], [[CONV]]
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %cmp = icmp ne i32 %a, 0
+  %conv = zext i1 %cmp to i32
+  %cmp1 = lshr i32 %a, 30
+  %cmp2 = icmp eq i32 %conv, %cmp1
+  ret i1 %cmp2
+}

@XChy XChy force-pushed the instcompare branch 2 times, most recently from c381062 to f4c691e Compare October 6, 2023 12:41
@goldsteinn
Copy link
Contributor

Second commit message needs to be updated.

@XChy XChy changed the title [InstCombine] Fold zext(icmp (A, xxx)) == shr(A, BW - 1) => not(trunc(xor(zext(icmp), shl))) [InstCombine] Fold zext(i1) == lshr(A, BW - 1) => i1 == A s< 0 Oct 7, 2023
@sftlbcn
Copy link

sftlbcn commented Oct 12, 2023

worth handling other predicates?
https://alive2.llvm.org/ce/z/EJAH6Z

define i1 @src(i1 %x, i32 %y) {
  %zx = zext i1 %x to i32
  %cmp1 = lshr i32 %y, 31
  %cmp2 = icmp ule i32 %cmp1, %zx
  ret i1 %cmp2
}
define i1 @tgt(i1 %x, i32 %y) {
  %cmp1 = icmp slt i32 %y, 0
  %cmp2 = icmp ule i1 %cmp1, %x
  ret i1 %cmp2
}

@XChy
Copy link
Member Author

XChy commented Oct 12, 2023

worth handling other predicates? https://alive2.llvm.org/ce/z/EJAH6Z

Thanks for reminder, I think that's worthwhile and convenient to implement here.
zext(i1) pred lshr(A, BW - 1) -> i1 pred (A s< 0)
The pattern will be added later.

@goldsteinn
Copy link
Contributor

LGTM.

@nikic
Copy link
Contributor

nikic commented Oct 12, 2023

Can you please update the patch description and alive proof? It doesn't seem to match what is actually implemented.

@nikic
Copy link
Contributor

nikic commented Oct 12, 2023

And I think the tests need an update too? It doesn't look like the most basic case is tested at all (where the zext i1 and lshr are not correlated).

@XChy
Copy link
Member Author

XChy commented Oct 13, 2023

New proof: alive2

@XChy XChy changed the title [InstCombine] Fold zext(i1) == lshr(A, BW - 1) => i1 == A s< 0 [InstCombine] Fold Zext(i1) Pred lshr(A, BW - 1) => i1 Pred A s< 0 Oct 13, 2023
@XChy XChy changed the title [InstCombine] Fold Zext(i1) Pred lshr(A, BW - 1) => i1 Pred A s< 0 [InstCombine] Fold ZExt(i1) Pred lshr(A, BW - 1) => i1 Pred A s< 0 Oct 13, 2023
@XChy XChy changed the title [InstCombine] Fold ZExt(i1) Pred lshr(A, BW - 1) => i1 Pred A s< 0 [InstCombine] Fold Ext(i1) Pred shr(A, BW - 1) => i1 Pred A s< 0 Oct 13, 2023
@XChy
Copy link
Member Author

XChy commented Oct 13, 2023

Tests about ashr-sext pair will be added later.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[InstCombine] !a == (a < 0) is not optimized
5 participants