[Analysis] Only fold trunc X to (X & Mask) if Mask == getLowBitsSet(bit width of original operands)#176589
[Analysis] Only fold trunc X to (X & Mask) if Mask == getLowBitsSet(bit width of original operands)#176589
trunc X to (X & Mask) if Mask == getLowBitsSet(bit width of original operands)#176589Conversation
|
@llvm/pr-subscribers-llvm-analysis Author: Tirthankar Mazumder (wermos) ChangesImplements the refinement mentioned in #171195 (review). We only perform the fold if the resultant mask is equal to Full diff: https://github.com/llvm/llvm-project/pull/176589.diff 1 Files Affected:
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 880006c0fcfac..5e434f3e711fe 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/CmpInstAnalysis.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
@@ -164,9 +165,16 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
// Try to convert (trunc X) eq/ne C into (X & Mask) eq/ne C
if (LookThroughTrunc && isa<TruncInst>(LHS)) {
- Result.Pred = Pred;
- Result.Mask = APInt::getAllOnes(C.getBitWidth());
- Result.C = C;
+ auto *TI = dyn_cast<TruncInst>(LHS);
+ unsigned SrcBW = TI->getSrcTy()->getScalarSizeInBits(),
+ DstBW = TI->getDestTy()->getScalarSizeInBits();
+ APInt DesiredMask = APInt::getLowBitsSet(SrcBW, DstBW);
+ APInt Mask = APInt::getAllOnes(C.getBitWidth()).zext(SrcBW);
+ if (Mask == DesiredMask) {
+ Result.Pred = Pred;
+ Result.Mask = Mask;
+ Result.C = C;
+ }
break;
}
|
trunc X to (X & Mask) if Mask == getLowBitsSet(bit width of original operands)trunc X to (X & Mask) if Mask == getLowBitsSet(bit width of original operands)
|
@dtcxzyw I'm not exactly sure how #171195 (comment) is supposed to be handled. Do you have any suggestions? Also, I think we can run |
|
@andjo403 In #171195 you suggested writing the llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp Lines 165 to 171 in ae425ab but I'm not sure if this is correct. Imagine we have a snippet like %v1 = trunc i64 %x to i32
%v2 = icmp eq i32 %v1, 0The I think the way it was written before is more correct, because the mask needs to be |
|
Sorry I do not understand this change as it will compare two numbers that are always equal. llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp Lines 184 to 190 in 0ece357 for your example the APInt::getAllOnes(C.getBitWidth()) will be a i32 with all bits set to one and that is then zext to a i64 with 32bits set at the code from the link above. |
|
Ah, I see. Your suggestion makes sense to me now.
What you're saying makes sense. Did I misinterpret the comment in #171195 (review)?
|
|
@dtcxzyw could you shed some light on this matter? |
|
I mean changing the code here: llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp Lines 6291 to 6298 in d7cbc7f This should allow patterns like However, I double checked and found that it didn't happen, since it was already guarded by the following check ( llvm-project/llvm/lib/Analysis/CmpInstAnalysis.cpp Lines 83 to 85 in 010f6c8 We need more time to investigate the root cause of regressions, as demonstrated in https://github.com/dtcxzyw/llvm-opt-benchmark/pull/3316/files (See bench/llvm/optimized/SemaType.ll, bench/libigl/optimized/bijective_composite_harmonic_mapping.ll, and bench/cpython/optimized/_ssl.ll). With #171195, new masking instructions are introduced in multiple IR files. So I guess there are some optimizations converting truncs into ands unconditionally. |
|
Thanks for the reply. I'm going to mark this PR as a draft for now, while I try to investigate the root cause of those regressions |
|
I'm closing this PR for now. I'll open a new one when I get around to investigating the regressions. |
Implements the refinement mentioned in #171195 (review).
We only perform the fold if the resultant mask is equal to
getLowBitsSet(bit width of original operands).