[InstCombine] Split the FMul with reassoc into a helper function, NFC #71493

Merged
1 commit merged into llvm:main on Nov 7, 2023

Conversation

vfdff
Contributor

@vfdff vfdff commented Nov 7, 2023

The reassoc check is hard to find because the branch that handles it is too large, so split it out into a helper function.
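
For reference, the change is purely structural (NFC): the folds that were nested under the reassoc check in visitFMul move verbatim into a new helper. Below is a condensed sketch of the resulting shape, with the function name and call site taken from the diff further down and the bodies elided, so it is not compilable on its own:

// All folds that require the 'reassoc' fast-math flag now live in one place.
// Returns the replacement instruction, or nullptr when no fold applies.
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
  // ... the fdiv/sqrt/pow/powi/exp/exp2 reassociation folds, moved
  // unchanged from visitFMul (the +177 lines in InstCombineMulDivRem.cpp) ...
  return nullptr;
}

Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
  // ... unchanged handling that does not depend on reassoc ...

  // The former ~170-line `if (I.hasAllowReassoc()) { ... }` block shrinks
  // to a guarded call into the helper:
  if (I.hasAllowReassoc())
    if (Instruction *FoldedMul = foldFMulReassoc(I))
      return FoldedMul;

  // ... remaining folds, e.g. the I.isFast() log2 transform ...
}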

The reassoc check is hard to find because the branch that handles it
is too large, so split it out into a helper function.
@vfdff vfdff requested a review from arsenm November 7, 2023 07:10
@vfdff vfdff requested a review from nikic as a code owner November 7, 2023 07:10
@llvmbot
Member

llvmbot commented Nov 7, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Allen (vfdff)

Changes

The reassoc check is hard to find because the branch that handles it is too large, so split it out into a helper function.


Full diff: https://github.com/llvm/llvm-project/pull/71493.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+177-170)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 34b10220ec88aba..68a8fb676d8d909 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -98,6 +98,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *visitSub(BinaryOperator &I);
   Instruction *visitFSub(BinaryOperator &I);
   Instruction *visitMul(BinaryOperator &I);
+  Instruction *foldFMulReassoc(BinaryOperator &I);
   Instruction *visitFMul(BinaryOperator &I);
   Instruction *visitURem(BinaryOperator &I);
   Instruction *visitSRem(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index bc784390c23be49..db0804380855e3a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -560,6 +560,180 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) {
   return nullptr;
 }
 
+Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
+  Value *Op0 = I.getOperand(0);
+  Value *Op1 = I.getOperand(1);
+  Value *X, *Y;
+  Constant *C;
+
+  // Reassociate constant RHS with another constant to form constant
+  // expression.
+  if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
+    Constant *C1;
+    if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
+      // (C1 / X) * C --> (C * C1) / X
+      Constant *CC1 =
+          ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
+      if (CC1 && CC1->isNormalFP())
+        return BinaryOperator::CreateFDivFMF(CC1, X, &I);
+    }
+    if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
+      // (X / C1) * C --> X * (C / C1)
+      Constant *CDivC1 =
+          ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
+      if (CDivC1 && CDivC1->isNormalFP())
+        return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
+
+      // If the constant was a denormal, try reassociating differently.
+      // (X / C1) * C --> X / (C1 / C)
+      Constant *C1DivC =
+          ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
+      if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
+        return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
+    }
+
+    // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
+    // canonicalized to 'fadd X, C'. Distributing the multiply may allow
+    // further folds and (X * C) + C2 is 'fma'.
+    if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
+      // (X + C1) * C --> (X * C) + (C * C1)
+      if (Constant *CC1 =
+              ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+        Value *XC = Builder.CreateFMulFMF(X, C, &I);
+        return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
+      }
+    }
+    if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
+      // (C1 - X) * C --> (C * C1) - (X * C)
+      if (Constant *CC1 =
+              ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
+        Value *XC = Builder.CreateFMulFMF(X, C, &I);
+        return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
+      }
+    }
+  }
+
+  Value *Z;
+  if (match(&I,
+            m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) {
+    // Sink division: (X / Y) * Z --> (X * Z) / Y
+    Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
+    return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
+  }
+
+  // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
+  // nnan disallows the possibility of returning a number if both operands are
+  // negative (in that case, we should return NaN).
+  if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
+      match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
+    Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+    Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
+    return replaceInstUsesWith(I, Sqrt);
+  }
+
+  // The following transforms are done irrespective of the number of uses
+  // for the expression "1.0/sqrt(X)".
+  //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
+  //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
+  // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
+  // has the necessary (reassoc) fast-math-flags.
+  if (I.hasNoSignedZeros() &&
+      match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+      match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
+    return BinaryOperator::CreateFDivFMF(X, Y, &I);
+  if (I.hasNoSignedZeros() &&
+      match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+      match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
+    return BinaryOperator::CreateFDivFMF(X, Y, &I);
+
+  // Like the similar transform in instsimplify, this requires 'nsz' because
+  // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
+  if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
+    // Peek through fdiv to find squaring of square root:
+    // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
+    if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
+      Value *XX = Builder.CreateFMulFMF(X, X, &I);
+      return BinaryOperator::CreateFDivFMF(XX, Y, &I);
+    }
+    // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
+    if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
+      Value *XX = Builder.CreateFMulFMF(X, X, &I);
+      return BinaryOperator::CreateFDivFMF(Y, XX, &I);
+    }
+  }
+
+  // pow(X, Y) * X --> pow(X, Y+1)
+  // X * pow(X, Y) --> pow(X, Y+1)
+  if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
+                                                              m_Value(Y))),
+                         m_Deferred(X)))) {
+    Value *Y1 = Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
+    Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
+    return replaceInstUsesWith(I, Pow);
+  }
+
+  if (I.isOnlyUserOfAnyOperand()) {
+    // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
+    if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
+      auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
+      auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+    // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
+      auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
+      auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+
+    // powi(x, y) * powi(x, z) -> powi(x, y + z)
+    if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
+        match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
+        Y->getType() == Z->getType()) {
+      auto *YZ = Builder.CreateAdd(Y, Z);
+      auto *NewPow = Builder.CreateIntrinsic(
+          Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
+      return replaceInstUsesWith(I, NewPow);
+    }
+
+    // exp(X) * exp(Y) -> exp(X + Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
+        match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
+      Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+      Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
+      return replaceInstUsesWith(I, Exp);
+    }
+
+    // exp2(X) * exp2(Y) -> exp2(X + Y)
+    if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
+        match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
+      Value *XY = Builder.CreateFAddFMF(X, Y, &I);
+      Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
+      return replaceInstUsesWith(I, Exp2);
+    }
+  }
+
+  // (X*Y) * X => (X*X) * Y where Y != X
+  //  The purpose is two-fold:
+  //   1) to form a power expression (of X).
+  //   2) potentially shorten the critical path: After transformation, the
+  //  latency of the instruction Y is amortized by the expression of X*X,
+  //  and therefore Y is in a "less critical" position compared to what it
+  //  was before the transformation.
+  if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && Op1 != Y) {
+    Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
+    return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+  }
+  if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && Op0 != Y) {
+    Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
+    return BinaryOperator::CreateFMulFMF(XX, Y, &I);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1),
                                   I.getFastMathFlags(),
@@ -607,176 +781,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
   if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
     return replaceInstUsesWith(I, V);
 
-  if (I.hasAllowReassoc()) {
-    // Reassociate constant RHS with another constant to form constant
-    // expression.
-    if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
-      Constant *C1;
-      if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
-        // (C1 / X) * C --> (C * C1) / X
-        Constant *CC1 =
-            ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
-        if (CC1 && CC1->isNormalFP())
-          return BinaryOperator::CreateFDivFMF(CC1, X, &I);
-      }
-      if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
-        // (X / C1) * C --> X * (C / C1)
-        Constant *CDivC1 =
-            ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
-        if (CDivC1 && CDivC1->isNormalFP())
-          return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
-
-        // If the constant was a denormal, try reassociating differently.
-        // (X / C1) * C --> X / (C1 / C)
-        Constant *C1DivC =
-            ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
-        if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
-          return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
-      }
-
-      // We do not need to match 'fadd C, X' and 'fsub X, C' because they are
-      // canonicalized to 'fadd X, C'. Distributing the multiply may allow
-      // further folds and (X * C) + C2 is 'fma'.
-      if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) {
-        // (X + C1) * C --> (X * C) + (C * C1)
-        if (Constant *CC1 = ConstantFoldBinaryOpOperands(
-                Instruction::FMul, C, C1, DL)) {
-          Value *XC = Builder.CreateFMulFMF(X, C, &I);
-          return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
-        }
-      }
-      if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
-        // (C1 - X) * C --> (C * C1) - (X * C)
-        if (Constant *CC1 = ConstantFoldBinaryOpOperands(
-                Instruction::FMul, C, C1, DL)) {
-          Value *XC = Builder.CreateFMulFMF(X, C, &I);
-          return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
-        }
-      }
-    }
-
-    Value *Z;
-    if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))),
-                           m_Value(Z)))) {
-      // Sink division: (X / Y) * Z --> (X * Z) / Y
-      Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
-      return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
-    }
-
-    // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
-    // nnan disallows the possibility of returning a number if both operands are
-    // negative (in that case, we should return NaN).
-    if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) &&
-        match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) {
-      Value *XY = Builder.CreateFMulFMF(X, Y, &I);
-      Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I);
-      return replaceInstUsesWith(I, Sqrt);
-    }
-
-    // The following transforms are done irrespective of the number of uses
-    // for the expression "1.0/sqrt(X)".
-    //  1) 1.0/sqrt(X) * X -> X/sqrt(X)
-    //  2) X * 1.0/sqrt(X) -> X/sqrt(X)
-    // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
-    // has the necessary (reassoc) fast-math-flags.
-    if (I.hasNoSignedZeros() &&
-        match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-        match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
-      return BinaryOperator::CreateFDivFMF(X, Y, &I);
-    if (I.hasNoSignedZeros() &&
-        match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
-        match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
-      return BinaryOperator::CreateFDivFMF(X, Y, &I);
-
-    // Like the similar transform in instsimplify, this requires 'nsz' because
-    // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
-    if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
-        Op0->hasNUses(2)) {
-      // Peek through fdiv to find squaring of square root:
-      // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
-      if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
-        Value *XX = Builder.CreateFMulFMF(X, X, &I);
-        return BinaryOperator::CreateFDivFMF(XX, Y, &I);
-      }
-      // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X)
-      if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) {
-        Value *XX = Builder.CreateFMulFMF(X, X, &I);
-        return BinaryOperator::CreateFDivFMF(Y, XX, &I);
-      }
-    }
-
-    // pow(X, Y) * X --> pow(X, Y+1)
-    // X * pow(X, Y) --> pow(X, Y+1)
-    if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
-                                                                m_Value(Y))),
-                           m_Deferred(X)))) {
-      Value *Y1 =
-          Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
-      Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
-      return replaceInstUsesWith(I, Pow);
-    }
-
-    if (I.isOnlyUserOfAnyOperand()) {
-      // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
-      if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
-        auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
-        auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-      // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
-        auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
-        auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-
-      // powi(x, y) * powi(x, z) -> powi(x, y + z)
-      if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
-          match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) &&
-          Y->getType() == Z->getType()) {
-        auto *YZ = Builder.CreateAdd(Y, Z);
-        auto *NewPow = Builder.CreateIntrinsic(
-            Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I);
-        return replaceInstUsesWith(I, NewPow);
-      }
-
-      // exp(X) * exp(Y) -> exp(X + Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) &&
-          match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) {
-        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
-        Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I);
-        return replaceInstUsesWith(I, Exp);
-      }
-
-      // exp2(X) * exp2(Y) -> exp2(X + Y)
-      if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) &&
-          match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) {
-        Value *XY = Builder.CreateFAddFMF(X, Y, &I);
-        Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I);
-        return replaceInstUsesWith(I, Exp2);
-      }
-    }
-
-    // (X*Y) * X => (X*X) * Y where Y != X
-    //  The purpose is two-fold:
-    //   1) to form a power expression (of X).
-    //   2) potentially shorten the critical path: After transformation, the
-    //  latency of the instruction Y is amortized by the expression of X*X,
-    //  and therefore Y is in a "less critical" position compared to what it
-    //  was before the transformation.
-    if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) &&
-        Op1 != Y) {
-      Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I);
-      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
-    }
-    if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) &&
-        Op0 != Y) {
-      Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I);
-      return BinaryOperator::CreateFMulFMF(XX, Y, &I);
-    }
-  }
+  if (I.hasAllowReassoc())
+    if (Instruction *FoldedMul = foldFMulReassoc(I))
+      return FoldedMul;
 
   // log2(X * 0.5) * Y = log2(X) * Y - Y
   if (I.isFast()) {

@vfdff vfdff merged commit a0cd626 into llvm:main Nov 7, 2023