Skip to content

Refactor RewriteBitwiseOps #677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/smack/RewriteBitwiseOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,33 @@

#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include <boost/optional.hpp>

namespace smack {

class RewriteBitwiseOps : public llvm::ModulePass {
public:
enum class SpecialCase {
Shr, // Right shift
Shl, // Left shift
AndPoTMO, // And power of 2 minus 1
AndNPoT // And negative power of 2
};
static char ID; // Pass identification, replacement for typeid
RewriteBitwiseOps() : llvm::ModulePass(ID) {}
virtual llvm::StringRef getPassName() const;
virtual bool runOnModule(llvm::Module &m);

private:
boost::optional<SpecialCase> getSpecialCase(llvm::Instruction *i);
static llvm::Instruction *rewriteShift(llvm::Instruction::BinaryOps,
llvm::Instruction *);
static llvm::Instruction *rewriteShr(llvm::Instruction *);
static llvm::Instruction *rewriteShl(llvm::Instruction *);
static llvm::Instruction *rewriteAndPoTMO(llvm::Instruction *);
static llvm::Instruction *rewriteAndNPot(llvm::Instruction *);
llvm::Instruction *rewriteSpecialCase(SpecialCase, llvm::Instruction *);
boost::optional<llvm::Instruction *> rewriteGeneralCase(llvm::Instruction *);
};
} // namespace smack

Expand Down
238 changes: 125 additions & 113 deletions lib/smack/RewriteBitwiseOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,143 +18,155 @@
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <string>
#include <unordered_map>
#include <utility>

namespace smack {

using namespace llvm;
using SC = RewriteBitwiseOps::SpecialCase;

unsigned getOpBitWidth(const Value *V) {
const Type *T = V->getType();
const IntegerType *iType = cast<IntegerType>(T);
const IntegerType *iType = cast<const IntegerType>(T);
return iType->getBitWidth();
}

Optional<APInt> getConstantIntValue(const Value *V) {
if (const ConstantInt *ci = dyn_cast<ConstantInt>(V)) {
const APInt &value = ci->getValue();
return Optional<APInt>(value);
} else {
return Optional<APInt>();
int getConstOpId(const Instruction *i) {
return llvm::isa<ConstantInt>(i->getOperand(1))
? 1
: (llvm::isa<ConstantInt>(i->getOperand(0)) ? 0 : -1);
}

void copyDbgMetadata(const Instruction *src, Instruction *dest) {
dest->setMetadata("dbg", src->getMetadata("dbg"));
}

Instruction *RewriteBitwiseOps::rewriteShift(Instruction::BinaryOps newOp,
Instruction *i) {
auto ci = llvm::cast<ConstantInt>(i->getOperand(1));
auto lhs = i->getOperand(0);
unsigned bitWidth = getOpBitWidth(lhs);
APInt rhsVal = APInt(bitWidth, "1", 10);
const APInt &value = ci->getValue();
rhsVal <<= value;
Value *rhs = ConstantInt::get(ci->getType(), rhsVal);
return BinaryOperator::Create(newOp, lhs, rhs, "", i);
}

Instruction *RewriteBitwiseOps::rewriteShr(Instruction *i) {
// Shifting right by a constant amount is equivalent to dividing by
// 2^amount
return rewriteShift(Instruction::SDiv, i);
}

Instruction *RewriteBitwiseOps::rewriteShl(Instruction *i) {
// Shifting left by a constant amount is equivalent to dividing by
// 2^amount
return rewriteShift(Instruction::Mul, i);
}

Instruction *RewriteBitwiseOps::rewriteAndPoTMO(Instruction *i) {
// If the operation is a bit-wise `and' and the mask variable is
// constant, it may be possible to replace this operation with a
// remainder operation. If one argument has only ones, and they're only
// in the least significant bit, then the mask is 2^(number of ones)
// - 1. This is equivalent to the remainder when dividing by 2^(number
// of ones).
// Zvonimir: Right-side constants do not work well on SVCOMP
// } else if (isConstantInt(args[1])) {
// maskOperand = 1;
// }
auto constOpId = getConstOpId(i);
auto constVal =
(llvm::cast<ConstantInt>(i->getOperand(constOpId)))->getValue();
auto lhs = i->getOperand(1 - constOpId);
Value *rhs = ConstantInt::get(i->getContext(), constVal + 1);
return BinaryOperator::Create(Instruction::URem, lhs, rhs, "", i);
}

Instruction *RewriteBitwiseOps::rewriteAndNPot(Instruction *i) {
// E.g., rewrite ``x && -32'' into ``x - x % 32''
auto constOpId = getConstOpId(i);
auto constVal =
(llvm::cast<ConstantInt>(i->getOperand(constOpId)))->getValue();
auto lhs = i->getOperand(1 - constOpId);
auto rem = BinaryOperator::Create(
Instruction::URem, lhs, ConstantInt::get(i->getContext(), 0 - constVal),
"", i);
copyDbgMetadata(i, rem);
return BinaryOperator::Create(Instruction::Sub, lhs, rem, "", i);
}

boost::optional<SC> RewriteBitwiseOps::getSpecialCase(Instruction *i) {
if (auto bi = dyn_cast<BinaryOperator>(i)) {
auto opcode = bi->getOpcode();
if (bi->isShift() && llvm::isa<ConstantInt>(bi->getOperand(1))) {
if (opcode == Instruction::AShr || opcode == Instruction::LShr)
return SC::Shr;
else if (opcode == Instruction::Shl)
return SC::Shl;
} else if (opcode == Instruction::And) {
auto constOpId = getConstOpId(i);
if (constOpId >= 0) {
auto val =
(llvm::cast<ConstantInt>(i->getOperand(constOpId)))->getValue();
if ((val + 1).isPowerOf2())
return SC::AndPoTMO;
else if ((0 - val).isPowerOf2())
return SC::AndNPoT;
}
}
}
return boost::none;
}

Instruction *RewriteBitwiseOps::rewriteSpecialCase(SC sc, Instruction *i) {
static std::unordered_map<SC, Instruction *(*)(Instruction *)> table{
{SC::Shr, &rewriteShr},
{SC::Shl, &rewriteShl},
{SC::AndPoTMO, &rewriteAndPoTMO},
{SC::AndNPoT, &rewriteAndNPot}};
return table[sc](i);
}

bool isConstantInt(const Value *V) {
Optional<APInt> constantValue = getConstantIntValue(V);
return constantValue.hasValue() && (*constantValue + 1).isPowerOf2();
boost::optional<Instruction *>
RewriteBitwiseOps::rewriteGeneralCase(Instruction *i) {
auto opCode = i->getOpcode();
if (opCode != Instruction::And && opCode != Instruction::Or)
return boost::none;
unsigned bitWidth = getOpBitWidth(i);
if (bitWidth > 64 || bitWidth < 8 || (bitWidth & (bitWidth - 1)))
return boost::none;
std::string func =
std::string("__SMACK_") + i->getOpcodeName() + std::to_string(bitWidth);
Function *co = i->getModule()->getFunction(func);
assert(co != NULL && "Function __SMACK_[and|or] should be present.");
std::vector<Value *> args;
args.push_back(i->getOperand(0));
args.push_back(i->getOperand(1));
return CallInst::Create(co, args, "", i);
}

bool RewriteBitwiseOps::runOnModule(Module &m) {
std::vector<Instruction *> instsFrom;
std::vector<Instruction *> instsTo;
std::vector<std::pair<Instruction *, Instruction *>> insts;

for (auto &F : m) {
if (Naming::isSmackName(F.getName()))
continue;
for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
if (I->isShift()) {
BinaryOperator *bi = cast<BinaryOperator>(&*I);
Value *amount = bi->getOperand(1);

if (ConstantInt *ci = dyn_cast<ConstantInt>(amount)) {
unsigned opcode = bi->getOpcode();
Instruction::BinaryOps op;
if (opcode == Instruction::AShr || opcode == Instruction::LShr) {
// Shifting right by a constant amount is equivalent to dividing by
// 2^amount
op = Instruction::SDiv;
} else if (opcode == Instruction::Shl) {
// Shifting left by a constant amount is equivalent to dividing by
// 2^amount
op = Instruction::Mul;
}

auto lhs = bi->getOperand(0);
unsigned bitWidth = getOpBitWidth(lhs);
APInt rhsVal = APInt(bitWidth, "1", 10);
const APInt &value = ci->getValue();
rhsVal <<= value;
Value *rhs = ConstantInt::get(ci->getType(), rhsVal);
Instruction *replacement = BinaryOperator::Create(op, lhs, rhs, "");
instsFrom.push_back(&*I);
instsTo.push_back(replacement);
}
} else if (I->isBitwiseLogicOp()) {
// If the operation is a bit-wise `and' and the mask variable is
// constant, it may be possible to replace this operation with a
// remainder operation. If one argument has only ones, and they're only
// in the least significant bit, then the mask is 2^(number of ones)
// - 1. This is equivalent to the remainder when dividing by 2^(number
// of ones).
BinaryOperator *bi = cast<BinaryOperator>(&*I);
unsigned opcode = bi->getOpcode();
if (opcode == Instruction::And) {
Value *args[] = {bi->getOperand(0), bi->getOperand(1)};
int maskOperand = -1;

if (isConstantInt(args[0])) {
maskOperand = 0;
}
// Zvonimir: Right-side constants do not work well on SVCOMP
// } else if (isConstantInt(args[1])) {
// maskOperand = 1;
// }

if (maskOperand == -1) {
unsigned bitWidth = getOpBitWidth(&*I);
Function *co;
if (bitWidth == 64) {
co = m.getFunction("__SMACK_and64");
} else if (bitWidth == 32) {
co = m.getFunction("__SMACK_and32");
} else if (bitWidth == 16) {
co = m.getFunction("__SMACK_and16");
} else if (bitWidth == 8) {
co = m.getFunction("__SMACK_and8");
} else {
continue;
}
assert(co != NULL && "Function __SMACK_and should be present.");
std::vector<Value *> args;
args.push_back(bi->getOperand(0));
args.push_back(bi->getOperand(1));
instsFrom.push_back(&*I);
instsTo.push_back(CallInst::Create(co, args, ""));
} else {
auto lhs = args[1 - maskOperand];
Value *rhs = ConstantInt::get(
m.getContext(), *getConstantIntValue(args[maskOperand]) + 1);
Instruction *replacement =
BinaryOperator::Create(Instruction::URem, lhs, rhs, "");
instsFrom.push_back(&*I);
instsTo.push_back(replacement);
}
} else if (opcode == Instruction::Or) {
unsigned bitWidth = getOpBitWidth(&*I);
Function *co;
if (bitWidth == 64) {
co = m.getFunction("__SMACK_or64");
} else if (bitWidth == 32) {
co = m.getFunction("__SMACK_or32");
} else if (bitWidth == 16) {
co = m.getFunction("__SMACK_or16");
} else if (bitWidth == 8) {
co = m.getFunction("__SMACK_or8");
} else {
continue;
}
assert(co != NULL && "Function __SMACK_or should be present.");
std::vector<Value *> args;
args.push_back(bi->getOperand(0));
args.push_back(bi->getOperand(1));
instsFrom.push_back(&*I);
instsTo.push_back(CallInst::Create(co, args, ""));
}
}
Instruction *i = &*I;
if (auto sc = getSpecialCase(i))
insts.push_back(std::make_pair(i, rewriteSpecialCase(*sc, i)));
else if (auto oi = rewriteGeneralCase(i))
insts.push_back(std::make_pair(i, *oi));
}
}

for (size_t i = 0; i < instsFrom.size(); ++i) {
ReplaceInstWithInst(instsFrom[i], instsTo[i]);
for (auto &p : insts) {
copyDbgMetadata(p.first, p.second);
p.first->replaceAllUsesWith(p.second);
p.first->removeFromParent();
}

return true;
Expand Down
64 changes: 64 additions & 0 deletions test/c/special/bitwise_constant_and.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "smack.h"

// @expect verified
// @flag --rewrite-bitwise-ops

int main(void) {
unsigned char x = __VERIFIER_nondet_unsigned_char();
// 2**0
assert((x & 0) == 0);

// 2**1
assert((x & 1) == x % 2);
assert((1 & x) == x % 2);

assert((x & -2) == x - x % 2);
assert((-2 & x) == x - x % 2);

// 2**2
assert((x & 3) == x % 4);
assert((3 & x) == x % 4);

assert((x & -4) == x - x % 4);
assert((-4 & x) == x - x % 4);

// 2**3
assert((x & 7) == x % 8);
assert((7 & x) == x % 8);

assert((x & -8) == x - x % 8);
assert((-8 & x) == x - x % 8);

// 2**4
assert((x & 15) == x % 16);
assert((15 & x) == x % 16);

assert((x & -16) == x - x % 16);
assert((-16 & x) == x - x % 16);

// 2**5
assert((x & 31) == x % 32);
assert((31 & x) == x % 32);

assert((x & -32) == x - x % 32);
assert((-32 & x) == x - x % 32);

// 2**6
assert((x & 63) == x % 64);
assert((63 & x) == x % 64);

assert((x & -64) == x - x % 64);
assert((-64 & x) == x - x % 64);

// 2**7
assert((x & 127) == x % 128);
assert((127 & x) == x % 128);

assert((x & -128) == x - x % 128);
assert((-128 & x) == x - x % 128);

// 2**8
assert((x & 255) == x);

return 0;
}
Loading