From e83f9c1bb6c24e028e542c9148a560da02b862a2 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 8 Jul 2022 22:17:08 +0200 Subject: [PATCH 01/57] WIP: Z3 all the things --- include/rellic/AST/CondBasedRefine.h | 7 +- include/rellic/AST/GenerateAST.h | 27 +--- include/rellic/AST/IRToASTVisitor.h | 1 + include/rellic/AST/MaterializeConds.h | 38 +++++ include/rellic/AST/ReachBasedRefine.h | 5 +- include/rellic/AST/Util.h | 21 +++ include/rellic/AST/Z3CondSimplify.h | 38 ----- lib/AST/CondBasedRefine.cpp | 98 +++++++----- lib/AST/DeadStmtElim.cpp | 12 +- lib/AST/GenerateAST.cpp | 222 ++++++++++++-------------- lib/AST/IRToASTVisitor.cpp | 60 +++++++ lib/AST/LoopRefine.cpp | 87 ++++++---- lib/AST/MaterializeConds.cpp | 53 ++++++ lib/AST/NestedCondProp.cpp | 121 +++++++++----- lib/AST/NestedScopeCombine.cpp | 17 +- lib/AST/ReachBasedRefine.cpp | 186 ++++++++++----------- lib/AST/Util.cpp | 12 ++ lib/AST/Z3CondSimplify.cpp | 113 ++----------- lib/CMakeLists.txt | 2 + lib/Decompiler.cpp | 6 + tools/repl/Repl.cpp | 34 ++-- tools/xref/Xref.cpp | 60 ++++--- tools/xref/www/main.js | 20 ++- 23 files changed, 660 insertions(+), 580 deletions(-) create mode 100644 include/rellic/AST/MaterializeConds.h create mode 100644 lib/AST/MaterializeConds.cpp diff --git a/include/rellic/AST/CondBasedRefine.h b/include/rellic/AST/CondBasedRefine.h index e7d0f686..8f910048 100644 --- a/include/rellic/AST/CondBasedRefine.h +++ b/include/rellic/AST/CondBasedRefine.h @@ -35,12 +35,9 @@ namespace rellic { */ class CondBasedRefine : public TransformVisitor { private: - std::unique_ptr z3_ctx; - std::unique_ptr z3_gen; + using IfStmtVec = std::vector; - z3::tactic z3_solver; - - z3::expr GetZ3Cond(clang::IfStmt *ifstmt); + void CreateIfThenElseStmts(IfStmtVec stmts); protected: void RunImpl() override; diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index 7f31b997..3af1df12 100644 --- a/include/rellic/AST/GenerateAST.h +++ 
b/include/rellic/AST/GenerateAST.h @@ -25,28 +25,15 @@ class GenerateAST : public llvm::AnalysisInfoMixin { friend llvm::AnalysisInfoMixin; static llvm::AnalysisKey Key; + constexpr static unsigned POISON_IDX = static_cast(-1); + // Need to use `map` with these instead of `unordered_map`, because // `std::pair` doesn't have a default hash implementation - using BBEdge = std::pair; - using BrEdge = std::pair; - using SwEdge = std::pair; clang::ASTUnit &unit; clang::ASTContext *ast_ctx; rellic::IRToASTVisitor ast_gen; rellic::ASTBuilder ast; - z3::context *z_ctx; - Provenance &provenance; - - z3::expr_vector z_exprs; - std::unordered_map z_br_edges_inv; - std::map z_br_edges; - - std::unordered_map z_sw_edges_inv; - std::map z_sw_edges; - - std::map z_edges; - std::unordered_map reaching_conds; bool reaching_conds_changed{true}; std::unordered_map block_stmts; std::unordered_map region_stmts; @@ -57,16 +44,14 @@ class GenerateAST : public llvm::AnalysisInfoMixin { std::vector rpo_walk; - z3::expr GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond); - z3::expr GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, + unsigned GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond); + unsigned GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c); - z3::expr GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to); - z3::expr GetReachingCond(llvm::BasicBlock *block); + unsigned GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to); + unsigned GetReachingCond(llvm::BasicBlock *block); void CreateReachingCond(llvm::BasicBlock *block); - clang::Expr *ConvertExpr(z3::expr expr); - std::vector CreateBasicBlockStmts(llvm::BasicBlock *block); std::vector CreateRegionStmts(llvm::Region *region); diff --git a/include/rellic/AST/IRToASTVisitor.h b/include/rellic/AST/IRToASTVisitor.h index 720b6d2c..3e8436af 100644 --- a/include/rellic/AST/IRToASTVisitor.h +++ b/include/rellic/AST/IRToASTVisitor.h @@ -41,6 +41,7 @@ class IRToASTVisitor { 
clang::Expr *CreateOperandExpr(llvm::Use &val); clang::Expr *CreateConstantExpr(llvm::Constant *constant); + clang::Expr *ConvertExpr(z3::expr expr); void VisitGlobalVar(llvm::GlobalVariable &var); void VisitFunctionDecl(llvm::Function &func); diff --git a/include/rellic/AST/MaterializeConds.h b/include/rellic/AST/MaterializeConds.h new file mode 100644 index 00000000..88f944f9 --- /dev/null +++ b/include/rellic/AST/MaterializeConds.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "rellic/AST/IRToASTVisitor.h" +#include "rellic/AST/TransformVisitor.h" +#include "rellic/AST/Util.h" + +namespace rellic { + +/* + * This pass replaces the placeholder (marker) conditions of `if`, `while` and + * `do` statements with Clang expressions converted from their Z3 conditions + */ +class MaterializeConds : public TransformVisitor { + private: + IRToASTVisitor ast_gen; + + protected: + void RunImpl() override; + + public: + MaterializeConds(Provenance &provenance, clang::ASTUnit &unit); + + bool VisitIfStmt(clang::IfStmt *stmt); + bool VisitWhileStmt(clang::WhileStmt *loop); + bool VisitDoStmt(clang::DoStmt *loop); +}; + +} // namespace rellic diff --git a/include/rellic/AST/ReachBasedRefine.h b/include/rellic/AST/ReachBasedRefine.h index 0bb97fb2..7ef9dbf0 100644 --- a/include/rellic/AST/ReachBasedRefine.h +++ b/include/rellic/AST/ReachBasedRefine.h @@ -38,10 +38,9 @@ namespace rellic { */ class ReachBasedRefine : public TransformVisitor { private: - std::unique_ptr z3_ctx; - std::unique_ptr z3_gen; + using IfStmtVec = std::vector; - z3::expr GetZ3Cond(clang::IfStmt *ifstmt); + void CreateIfElseStmts(IfStmtVec stmts); protected: void RunImpl() override; diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index f79ae3dc..4f60e096 100644 --- a/include/rellic/AST/Util.h +++
b/include/rellic/AST/Util.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -58,6 +59,11 @@ using IRToStmtMap = std::unordered_map; using ArgToTempMap = std::unordered_map; using BlockToUsesMap = std::unordered_map>; +using Z3CondMap = std::unordered_map; + +using BBEdge = std::pair; +using BrEdge = std::pair; +using SwEdge = std::pair; struct Provenance { StmtToIRMap stmt_provenance; ExprToUseMap use_provenance; @@ -65,6 +71,20 @@ struct Provenance { IRToValDeclMap value_decls; ArgToTempMap temp_decls; BlockToUsesMap outgoing_uses; + z3::context z3_ctx; + z3::expr_vector z3_exprs{z3_ctx}; + Z3CondMap conds; + + clang::Expr *marker_expr; + + std::unordered_map z3_br_edges_inv; + std::map z3_br_edges; + + std::unordered_map z3_sw_edges_inv; + std::map z3_sw_edges; + + std::map z3_edges; + std::unordered_map reaching_conds; size_t num_literal_structs = 0; size_t num_declared_structs = 0; @@ -88,4 +108,5 @@ z3::goal ApplyTactic(z3::context &ctx, const z3::tactic &tactic, z3::expr expr); bool Prove(z3::context &ctx, z3::expr expr); +z3::expr HeavySimplify(z3::context &ctx, z3::expr expr); } // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/Z3CondSimplify.h b/include/rellic/AST/Z3CondSimplify.h index e851c035..2ebf39a6 100644 --- a/include/rellic/AST/Z3CondSimplify.h +++ b/include/rellic/AST/Z3CondSimplify.h @@ -12,7 +12,6 @@ #include "rellic/AST/TransformVisitor.h" #include "rellic/AST/Util.h" -#include "rellic/AST/Z3ConvVisitor.h" namespace rellic { @@ -22,49 +21,12 @@ namespace rellic { */ class Z3CondSimplify : public TransformVisitor { private: - std::unique_ptr z_ctx; - std::unique_ptr z_gen; - - z3::expr ToZ3(clang::Expr *e); - - std::unordered_map hashes; - - struct Hash { - clang::ASTContext &ctx; - std::unordered_map &hashes; - std::size_t operator()(clang::Expr *e) const noexcept { - auto &hash{hashes[e]}; - if (!hash) { - hash = GetHash(ctx, e); - } - return hash; - } - }; - - struct KeyEqual { - bool 
operator()(clang::Expr *a, clang::Expr *b) const noexcept { - return IsEquivalent(a, b); - } - }; - Hash hash_adaptor; - KeyEqual ke_adaptor; - - std::unordered_map proven_true; - std::unordered_map proven_false; - - bool IsProvenTrue(clang::Expr *e); - bool IsProvenFalse(clang::Expr *e); - - clang::Expr *Simplify(clang::Expr *e); - protected: void RunImpl() override; public: Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit); - z3::context &GetZ3Context() { return *z_ctx; } - bool VisitIfStmt(clang::IfStmt *stmt); bool VisitWhileStmt(clang::WhileStmt *loop); bool VisitDoStmt(clang::DoStmt *loop); diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index c597e527..b89ecca5 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -18,53 +18,67 @@ namespace rellic { CondBasedRefine::CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit), - z3_ctx(new z3::context()), - z3_gen(new rellic::Z3ConvVisitor(unit, z3_ctx.get())), - z3_solver(*z3_ctx, "sat") {} - -z3::expr CondBasedRefine::GetZ3Cond(clang::IfStmt *ifstmt) { - auto cond = ifstmt->getCond(); - auto expr = z3_gen->Z3BoolCast(z3_gen->GetOrCreateZ3Expr(cond)); - return expr.simplify(); -} + : TransformVisitor(provenance, unit) {} -bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { - std::vector body{compound->body_begin(), compound->body_end()}; - bool did_something{false}; - for (size_t i{0}; i + 1 < body.size() && !Stopped(); ++i) { - auto if_a{clang::dyn_cast(body[i])}; - auto if_b{clang::dyn_cast(body[i + 1])}; +void CondBasedRefine::CreateIfThenElseStmts(IfStmtVec worklist) { + auto RemoveFromWorkList = [&worklist](clang::Stmt *stmt) { + auto it = std::find(worklist.begin(), worklist.end(), stmt); + if (it != worklist.end()) { + worklist.erase(it); + } + }; + + auto ThenTest = [this](z3::expr lhs, z3::expr rhs) { + return Prove(provenance.z3_ctx, lhs == rhs); + }; + + auto ElseTest = 
[this](z3::expr lhs, z3::expr rhs) { + return Prove(provenance.z3_ctx, lhs == !rhs); + }; + + auto CombineTest = [this](z3::expr lhs, z3::expr rhs) { + return Prove(provenance.z3_ctx, z3::implies(rhs, lhs)); + }; + + while (!worklist.empty()) { + auto lhs = *worklist.begin(); + RemoveFromWorkList(lhs); + // Fetch the Z3 condition of `lhs`; the statements still + // in the worklist are clustered relative to this + // condition. + auto lcond = provenance.z3_exprs[provenance.conds[lhs]]; + // Get branch candidates wrt `lcond` + std::vector thens({lhs}), elses; + for (auto rhs : worklist) { + auto rcond = provenance.z3_exprs[provenance.conds[rhs]]; + if (ThenTest(lcond, rcond) || CombineTest(lcond, rcond)) { + thens.push_back(rhs); + } else if (ElseTest(lcond, rcond) || CombineTest(!lcond, rcond)) { + elses.push_back(rhs); + } + } - if (!if_a || !if_b) { + // Check if we have enough statements to work with + if (thens.size() + elses.size() < 2) { continue; } - auto then_a{if_a->getThen()}; - auto then_b{if_b->getThen()}; - - auto else_a{if_a->getElse()}; - auto else_b{if_b->getElse()}; - - auto cond_a{GetZ3Cond(if_a)}; - auto cond_b{GetZ3Cond(if_b)}; - - clang::IfStmt *new_if{}; - std::vector new_then_body{then_a}; - if (Prove(*z3_ctx, cond_a == cond_b)) { - new_then_body.push_back(then_b); - auto new_then{ast.CreateCompoundStmt(new_then_body)}; - new_if = ast.CreateIf(if_a->getCond(), new_then); - - if (else_a || else_b) { - std::vector new_else_body; - if (else_a) { - new_else_body.push_back(else_a); - } - if (else_b) { - new_else_body.push_back(else_b); - } - new_if->setElse(ast.CreateCompoundStmt(new_else_body)); + // Erase then statements from the AST and `worklist` + for (auto stmt : thens) { + RemoveFromWorkList(stmt); + substitutions[stmt] = nullptr; + } + // Create our new if-then + auto sub = + ast.CreateIf(provenance.marker_expr, ast.CreateCompoundStmt(thens)); + provenance.conds[sub] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(lcond); +
// Create an else branch if possible + if (!elses.empty()) { + // Erase else statements from the AST and `worklist` + for (auto stmt : elses) { + RemoveFromWorkList(stmt); + substitutions[stmt] = nullptr; } } else if (Prove(*z3_ctx, cond_a == !cond_b)) { if (else_b) { diff --git a/lib/AST/DeadStmtElim.cpp b/lib/AST/DeadStmtElim.cpp index 2309b1ca..68c8577a 100644 --- a/lib/AST/DeadStmtElim.cpp +++ b/lib/AST/DeadStmtElim.cpp @@ -17,17 +17,15 @@ DeadStmtElim::DeadStmtElim(Provenance &provenance, clang::ASTUnit &unit) bool DeadStmtElim::VisitIfStmt(clang::IfStmt *ifstmt) { // DLOG(INFO) << "VisitIfStmt"; - bool expr_bool_value = false; - auto if_const_expr = ifstmt->getCond()->getIntegerConstantExpr(ast_ctx); - - bool is_const = if_const_expr.hasValue(); - if (is_const) { - expr_bool_value = if_const_expr->getBoolValue(); + bool can_delete = false; + if (ifstmt->getCond() == provenance.marker_expr) { + can_delete = Prove(provenance.z3_ctx, + !provenance.z3_exprs[provenance.conds[ifstmt]]); } auto compound = clang::dyn_cast(ifstmt->getThen()); bool is_empty = compound ? 
compound->body_empty() : false; - if ((is_const && !expr_bool_value) || is_empty) { + if (can_delete || is_empty) { substitutions[ifstmt] = nullptr; } return true; diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index e05cab26..2b22b6e1 100644 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -136,76 +136,97 @@ static std::string GetName(llvm::Value *v) { return s; } -z3::expr GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, +unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond) { - if (z_br_edges.find({inst, cond}) == z_br_edges.end()) { + auto ToExpr = [&](unsigned idx) { + if (idx == POISON_IDX) { + return provenance.z3_ctx.bool_val(false); + } + return provenance.z3_exprs[idx]; + }; + if (provenance.z3_br_edges.find({inst, cond}) == + provenance.z3_br_edges.end()) { if (cond) { auto name{GetName(inst)}; - auto edge{z_ctx->bool_const(name.c_str())}; - z_br_edges[{inst, cond}] = z_exprs.size(); - z_exprs.push_back(edge); - z_br_edges_inv[edge.id()] = {inst, true}; + auto edge{provenance.z3_ctx.bool_const(name.c_str())}; + provenance.z3_br_edges[{inst, cond}] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(edge); + provenance.z3_br_edges_inv[edge.id()] = {inst, true}; } else { - auto edge{!(GetOrCreateEdgeForBranch(inst, true))}; - z_br_edges[{inst, cond}] = z_exprs.size(); - z_exprs.push_back(edge); + auto edge{!(ToExpr(GetOrCreateEdgeForBranch(inst, true)))}; + provenance.z3_br_edges[{inst, cond}] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(edge); } } - return z_exprs[z_br_edges[{inst, cond}]]; + return provenance.z3_br_edges[{inst, cond}]; } -z3::expr GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, +unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c) { - if (z_sw_edges.find({inst, c}) == z_sw_edges.end()) { + auto ToExpr = [&](unsigned idx) { + if (idx == POISON_IDX) { + return 
provenance.z3_ctx.bool_val(false); + } + return provenance.z3_exprs[idx]; + }; + if (provenance.z3_sw_edges.find({inst, c}) == provenance.z3_sw_edges.end()) { if (c) { auto name{GetName(inst) + GetName(inst->findCaseValue(c)->getCaseSuccessor())}; - auto edge{z_ctx->bool_const(name.c_str())}; - z_sw_edges_inv[edge.id()] = {inst, c}; + auto edge{provenance.z3_ctx.bool_const(name.c_str())}; + provenance.z3_sw_edges_inv[edge.id()] = {inst, c}; - z_sw_edges[{inst, c}] = z_exprs.size(); - z_exprs.push_back(edge); + provenance.z3_sw_edges[{inst, c}] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(edge); } else { // Default case - auto edge{z_ctx->bool_val(true)}; + auto edge{provenance.z3_ctx.bool_val(true)}; for (auto sw_case : inst->cases()) { - edge = edge && !GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()); + edge = edge && + !ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue())); } edge = edge.simplify(); - z_sw_edges[{inst, c}] = z_exprs.size(); - z_exprs.push_back(edge); + provenance.z3_sw_edges[{inst, c}] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(edge); } } - return z_exprs[z_sw_edges[{inst, c}]]; + return provenance.z3_sw_edges[{inst, c}]; } -z3::expr GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, +unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to) { - if (z_edges.find({from, to}) == z_edges.end()) { + auto ToExpr = [&](unsigned idx) { + if (idx == POISON_IDX) { + return provenance.z3_ctx.bool_val(false); + } + return provenance.z3_exprs[idx]; + }; + if (provenance.z3_edges.find({from, to}) == provenance.z3_edges.end()) { // Construct the edge condition for CFG edge `(from, to)` - auto result{z_ctx->bool_val(true)}; + auto result{provenance.z3_ctx.bool_val(true)}; auto term = from->getTerminator(); switch (term->getOpcode()) { // Conditional branches case llvm::Instruction::Br: { auto br = llvm::cast(term); if (br->isConditional()) { - result = 
GetOrCreateEdgeForBranch(br, to == br->getSuccessor(0)); + result = + ToExpr(GetOrCreateEdgeForBranch(br, to == br->getSuccessor(0))); } } break; // Switches case llvm::Instruction::Switch: { auto sw{llvm::cast(term)}; if (to == sw->getDefaultDest()) { - result = GetOrCreateEdgeForSwitch(sw, nullptr); + result = ToExpr(GetOrCreateEdgeForSwitch(sw, nullptr)); } else { - result = z_ctx->bool_val(false); + result = provenance.z3_ctx.bool_val(false); for (auto sw_case : sw->cases()) { if (sw_case.getCaseSuccessor() == to) { - result = result || - GetOrCreateEdgeForSwitch(sw, sw_case.getCaseValue()); + result = result || ToExpr(GetOrCreateEdgeForSwitch( + sw, sw_case.getCaseValue())); } } } @@ -229,60 +250,57 @@ z3::expr GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, break; } - z_edges[{from, to}] = z_exprs.size(); - z_exprs.push_back(result.simplify()); + provenance.z3_edges[{from, to}] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(result.simplify()); } - return z_exprs[z_edges[{from, to}]]; + return provenance.z3_edges[{from, to}]; } -z3::expr GenerateAST::GetReachingCond(llvm::BasicBlock *block) { - if (reaching_conds.find(block) == reaching_conds.end()) { - return z_ctx->bool_val(false); +unsigned GenerateAST::GetReachingCond(llvm::BasicBlock *block) { + if (provenance.reaching_conds.find(block) == + provenance.reaching_conds.end()) { + return POISON_IDX; } - return z_exprs[reaching_conds[block]]; + return provenance.reaching_conds[block]; } void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { - auto Simplify{[this](z3::expr expr) { - if (Prove(*z_ctx, expr)) { - return z_ctx->bool_val(true); + auto ToExpr = [&](unsigned idx) { + if (idx == POISON_IDX) { + return provenance.z3_ctx.bool_val(false); } + return provenance.z3_exprs[idx]; + }; - z3::tactic aig(*z_ctx, "aig"); - z3::tactic simplify(*z_ctx, "simplify"); - z3::tactic ctx_solver_simplify(*z_ctx, "ctx-solver-simplify"); - auto tactic{simplify & aig & ctx_solver_simplify}; 
- return ApplyTactic(*z_ctx, tactic, expr).as_expr(); - }}; - - auto old_cond{GetReachingCond(block)}; + auto old_cond{ToExpr(GetReachingCond(block))}; if (block->hasNPredecessorsOrMore(1)) { // Gather reaching conditions from predecessors of the block - auto cond{z_ctx->bool_val(false)}; + auto cond{provenance.z3_ctx.bool_val(false)}; for (auto pred : llvm::predecessors(block)) { - auto pred_cond{GetReachingCond(pred)}; - auto edge_cond{GetOrCreateEdgeCond(pred, block)}; + auto pred_cond{ToExpr(GetReachingCond(pred))}; + auto edge_cond{ToExpr(GetOrCreateEdgeCond(pred, block))}; // Construct reaching condition from `pred` to `block` as // `reach_cond[pred] && edge_cond(pred, block)` or one of // the two if the other one is missing. - auto conj_cond{Simplify(pred_cond && edge_cond)}; + auto conj_cond{HeavySimplify(provenance.z3_ctx, pred_cond && edge_cond)}; // Append `conj_cond` to reaching conditions of other // predecessors via an `||`. Use `conj_cond` if there // is no `cond` yet. - cond = Simplify(cond || conj_cond); + cond = HeavySimplify(provenance.z3_ctx, cond || conj_cond); } - if (!Prove(*z_ctx, old_cond == cond)) { - reaching_conds[block] = z_exprs.size(); - z_exprs.push_back(cond); + if (!Prove(provenance.z3_ctx, old_cond == cond)) { + provenance.reaching_conds[block] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(cond); reaching_conds_changed = true; } } else { - if (reaching_conds.find(block) == reaching_conds.end()) { - reaching_conds[block] = z_exprs.size(); - z_exprs.push_back(z_ctx->bool_val(true)); + if (provenance.reaching_conds.find(block) == + provenance.reaching_conds.end()) { + provenance.reaching_conds[block] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(provenance.z3_ctx.bool_val(true)); reaching_conds_changed = true; } } @@ -294,65 +312,13 @@ StmtVec GenerateAST::CreateBasicBlockStmts(llvm::BasicBlock *block) { return result; } -clang::Expr *GenerateAST::ConvertExpr(z3::expr expr) { - auto hash{expr.id()}; - 
if (z_br_edges_inv.find(hash) != z_br_edges_inv.end()) { - auto edge{z_br_edges_inv[hash]}; - CHECK(edge.second) << "Inverse map should only be populated for branches " - "taken when condition is true"; - return ast_gen.CreateOperandExpr(*(edge.first->op_end() - 3)); - } - - if (z_sw_edges_inv.find(hash) != z_sw_edges_inv.end()) { - auto edge{z_sw_edges_inv[hash]}; - CHECK(edge.second) - << "Inverse map should only be populated for not-default switch cases"; - - auto opnd{ast_gen.CreateOperandExpr(edge.first->getOperandUse(0))}; - return ast.CreateEQ(opnd, ast_gen.CreateConstantExpr(edge.second)); - } - - std::vector args; - for (auto i{0U}; i < expr.num_args(); ++i) { - args.push_back(ConvertExpr(expr.arg(i))); - } - - switch (expr.decl().decl_kind()) { - case Z3_OP_TRUE: - CHECK_EQ(args.size(), 0) << "True cannot have arguments"; - return ast.CreateTrue(); - case Z3_OP_FALSE: - CHECK_EQ(args.size(), 0) << "False cannot have arguments"; - return ast.CreateFalse(); - case Z3_OP_AND: { - CHECK_GE(args.size(), 2) << "And must have at least 2 arguments"; - clang::Expr *res{args[0]}; - for (auto i{1U}; i < args.size(); ++i) { - res = ast.CreateLAnd(res, args[i]); - } - return res; - } - case Z3_OP_OR: { - CHECK_GE(args.size(), 2) << "Or must have at least 2 arguments"; - clang::Expr *res{args[0]}; - for (auto i{1U}; i < args.size(); ++i) { - res = ast.CreateLOr(res, args[i]); - } - return res; - } - case Z3_OP_NOT: { - CHECK_EQ(args.size(), 1) << "Not must have one argument"; - auto neg{ast.CreateLNot(args[0])}; - CopyProvenance(args[0], neg, provenance.use_provenance); - return neg; - } - default: - LOG(FATAL) << "Invalid z3 op"; - } - return nullptr; -} - StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { + auto ToExpr = [&](unsigned idx) { + if (idx == POISON_IDX) { + return provenance.z3_ctx.bool_val(false); + } + return provenance.z3_exprs[idx]; + }; StmtVec result; for (auto block : rpo_walk) { // Check if the block is a subregion entry @@ -375,7 
+341,8 @@ StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { } // Gate the compound behind a reaching condition auto z_expr{GetReachingCond(block)}; - block_stmts[block] = ast.CreateIf(ConvertExpr(z_expr), compound); + block_stmts[block] = ast.CreateIf(provenance.marker_expr, compound); + provenance.conds[block_stmts[block]] = z_expr; // Store the compound result.push_back(block_stmts[block]); } @@ -437,6 +404,12 @@ clang::CompoundStmt *GenerateAST::StructureAcyclicRegion(llvm::Region *region) { } clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { + auto ToExpr = [&](unsigned idx) { + if (idx == POISON_IDX) { + return provenance.z3_ctx.bool_val(false); + } + return provenance.z3_exprs[idx]; + }; DLOG(INFO) << "Region " << GetRegionNameStr(region) << " is cyclic"; auto region_body = CreateRegionStmts(region); // Get the loop for which the entry block of the region is a header @@ -476,15 +449,18 @@ clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { for (auto edge : exits) { auto from = edge.first; auto to = edge.second; - // Create edge condition - auto z_expr{GetReachingCond(from) && GetOrCreateEdgeCond(from, to)}; - auto cond{ConvertExpr(z_expr.simplify())}; // Find the statement corresponding to the exiting block auto it = std::find(loop_body.begin(), loop_body.end(), block_stmts[from]); CHECK(it != loop_body.end()); // Create a loop exiting `break` statement StmtVec break_stmt({ast.CreateBreak()}); - auto exit_stmt = ast.CreateIf(cond, ast.CreateCompoundStmt(break_stmt)); + auto exit_stmt = ast.CreateIf(provenance.marker_expr, + ast.CreateCompoundStmt(break_stmt)); + provenance.conds[exit_stmt] = provenance.z3_exprs.size(); + // Create edge condition + provenance.z3_exprs.push_back( + (ToExpr(GetReachingCond(from)) && ToExpr(GetOrCreateEdgeCond(from, to))) + .simplify()); // Insert it after the exiting block statement loop_body.insert(std::next(it), exit_stmt); } @@ -572,9 +548,7 @@ 
GenerateAST::GenerateAST(Provenance &provenance, clang::ASTUnit &unit) unit(unit), provenance(provenance), ast_gen(unit, provenance), - ast(unit), - z_ctx(new z3::context()), - z_exprs(*z_ctx) {} + ast(unit) {} GenerateAST::Result GenerateAST::run(llvm::Module &module, llvm::ModuleAnalysisManager &MAM) { diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index da7eb92e..d89cb93a 100644 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -61,6 +61,66 @@ class ExprGen : public llvm::InstVisitor { clang::Expr *visitUnaryOperator(llvm::UnaryOperator &inst); }; +clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { + auto hash{expr.id()}; + if (provenance.z3_br_edges_inv.find(hash) != + provenance.z3_br_edges_inv.end()) { + auto edge{provenance.z3_br_edges_inv[hash]}; + CHECK(edge.second) << "Inverse map should only be populated for branches " + "taken when condition is true"; + return CreateOperandExpr(*(edge.first->op_end() - 3)); + } + + if (provenance.z3_sw_edges_inv.find(hash) != + provenance.z3_sw_edges_inv.end()) { + auto edge{provenance.z3_sw_edges_inv[hash]}; + CHECK(edge.second) + << "Inverse map should only be populated for not-default switch cases"; + + auto opnd{CreateOperandExpr(edge.first->getOperandUse(0))}; + return ast.CreateEQ(opnd, CreateConstantExpr(edge.second)); + } + + std::vector args; + for (auto i{0U}; i < expr.num_args(); ++i) { + args.push_back(ConvertExpr(expr.arg(i))); + } + + switch (expr.decl().decl_kind()) { + case Z3_OP_TRUE: + CHECK_EQ(args.size(), 0) << "True cannot have arguments"; + return ast.CreateTrue(); + case Z3_OP_FALSE: + CHECK_EQ(args.size(), 0) << "False cannot have arguments"; + return ast.CreateFalse(); + case Z3_OP_AND: { + CHECK_GE(args.size(), 2) << "And must have at least 2 arguments"; + clang::Expr *res{args[0]}; + for (auto i{1U}; i < args.size(); ++i) { + res = ast.CreateLAnd(res, args[i]); + } + return res; + } + case Z3_OP_OR: { + CHECK_GE(args.size(), 2) << "Or must have 
at least 2 arguments"; + clang::Expr *res{args[0]}; + for (auto i{1U}; i < args.size(); ++i) { + res = ast.CreateLOr(res, args[i]); + } + return res; + } + case Z3_OP_NOT: { + CHECK_EQ(args.size(), 1) << "Not must have one argument"; + auto neg{ast.CreateLNot(args[0])}; + CopyProvenance(args[0], neg, provenance.use_provenance); + return neg; + } + default: + LOG(FATAL) << "Invalid z3 op"; + } + return nullptr; +} + void ExprGen::VisitGlobalVar(llvm::GlobalVariable &gvar) { DLOG(INFO) << "VisitGlobalVar: " << LLVMThingToString(&gvar); auto &var{provenance.value_decls[&gvar]}; diff --git a/lib/AST/LoopRefine.cpp b/lib/AST/LoopRefine.cpp index 76118a37..e39a596d 100644 --- a/lib/AST/LoopRefine.cpp +++ b/lib/AST/LoopRefine.cpp @@ -57,7 +57,6 @@ class WhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_front())}; - auto cond{ifstmt->getCond()}; std::vector new_body; if (auto else_stmt = ifstmt->getElse()) { new_body.push_back(else_stmt); @@ -65,9 +64,12 @@ class WhileRule : public InferenceRule { std::copy(comp->body_begin() + 1, comp->body_end(), std::back_inserter(new_body)); ASTBuilder ast(unit); - auto new_cond{ast.CreateLNot(cond)}; - CopyProvenance(cond, new_cond, provenance.use_provenance); - return ast.CreateWhile(new_cond, ast.CreateCompoundStmt(new_body)); + auto new_while{ast.CreateWhile(provenance.marker_expr, + ast.CreateCompoundStmt(new_body))}; + provenance.conds[new_while] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back( + !provenance.z3_exprs[provenance.conds[ifstmt]]); + return new_while; } }; @@ -98,13 +100,12 @@ class ElseWhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_front())}; - auto cond{ifstmt->getCond()}; std::vector new_body; - new_body.push_back(ifstmt->getThen()); - std::copy(comp->body_begin() + 1, comp->body_end(), - std::back_inserter(new_body)); ASTBuilder ast(unit); - return ast.CreateWhile(cond, 
ast.CreateCompoundStmt(new_body)); + auto new_while{ast.CreateWhile(provenance.marker_expr, + ast.CreateCompoundStmt(new_body))}; + provenance.conds[new_while] = provenance.conds[ifstmt]; + return new_while; } }; @@ -135,18 +136,20 @@ class DoWhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_back())}; - auto cond{ifstmt->getCond()}; + auto cond{provenance.z3_exprs[provenance.conds[ifstmt]]}; std::vector new_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); if (auto else_stmt = ifstmt->getElse()) { - auto cond_inv{ast.CreateLNot(cond)}; - CopyProvenance(cond, cond_inv, provenance.use_provenance); - new_body.push_back(ast.CreateIf(cond_inv, else_stmt)); + auto new_if{ast.CreateIf(provenance.marker_expr, else_stmt)}; + provenance.conds[new_if] = provenance.z3_exprs.size(); + new_body.push_back(new_if); } - auto cond_inv{ast.CreateLNot(cond)}; - CopyProvenance(cond, cond_inv, provenance.use_provenance); - return ast.CreateDo(cond_inv, ast.CreateCompoundStmt(new_body)); + auto new_do{ + ast.CreateDo(provenance.marker_expr, ast.CreateCompoundStmt(new_body))}; + provenance.conds[new_do] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(!cond); + return new_do; } }; @@ -177,7 +180,6 @@ class ElseDoWhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_back())}; - auto cond{ifstmt->getCond()}; std::vector new_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); @@ -185,8 +187,10 @@ class ElseDoWhileRule : public InferenceRule { ifstmt->setElse(nullptr); new_body.push_back(ifstmt); - auto cond_clone{Clone(unit, cond, provenance.use_provenance)}; - return ast.CreateDo(cond_clone, ast.CreateCompoundStmt(new_body)); + auto new_do{ + ast.CreateDo(provenance.marker_expr, ast.CreateCompoundStmt(new_body))}; + provenance.conds[new_do] = provenance.conds[ifstmt]; + return new_do; } }; @@ -223,23 +227,28 @@ class 
NestedDoWhileRule : public InferenceRule { CHECK(loop && loop == match) << "Substituted WhileStmt is not the matched WhileStmt!"; auto comp{clang::cast(loop->getBody())}; - auto cond{clang::cast(comp->body_back())}; + auto if_stmt{clang::cast(comp->body_back())}; + auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; std::vector do_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); - if (auto else_stmt = cond->getElse()) { - auto cond_inv{ast.CreateLNot(cond->getCond())}; - CopyProvenance(cond->getCond(), cond_inv, provenance.use_provenance); - do_body.push_back(ast.CreateIf(cond_inv, else_stmt)); + if (auto else_stmt = if_stmt->getElse()) { + auto new_if{ast.CreateIf(provenance.marker_expr, else_stmt)}; + provenance.conds[new_if] = provenance.z3_exprs.size(); + do_body.push_back(new_if); } - auto do_cond{ast.CreateLNot(cond->getCond())}; - CopyProvenance(cond->getCond(), do_cond, provenance.use_provenance); - auto do_stmt{ast.CreateDo(do_cond, ast.CreateCompoundStmt(do_body))}; + auto do_stmt{ + ast.CreateDo(provenance.marker_expr, ast.CreateCompoundStmt(do_body))}; + provenance.conds[do_stmt] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(!cond); - std::vector while_body({do_stmt, cond->getThen()}); - return ast.CreateWhile(loop->getCond(), ast.CreateCompoundStmt(while_body)); + std::vector while_body({do_stmt, if_stmt->getThen()}); + auto new_while{ast.CreateWhile(provenance.marker_expr, + ast.CreateCompoundStmt(while_body))}; + provenance.conds[new_while] = provenance.conds[loop]; + return new_while; } }; @@ -330,14 +339,18 @@ class CondToSeqRule : public InferenceRule { ASTBuilder ast(unit); auto body{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(body->body_front())}; - auto inner_loop{ast.CreateWhile(ifstmt->getCond(), ifstmt->getThen())}; + auto inner_loop{ast.CreateWhile(provenance.marker_expr, ifstmt->getThen())}; + provenance.conds[inner_loop] = provenance.conds[ifstmt]; std::vector new_body({inner_loop}); 
if (auto comp = clang::dyn_cast(ifstmt->getElse())) { new_body.insert(new_body.end(), comp->body_begin(), comp->body_end()); } else { new_body.push_back(ifstmt->getElse()); } - return ast.CreateWhile(loop->getCond(), ast.CreateCompoundStmt(new_body)); + auto new_while{ast.CreateWhile(provenance.marker_expr, + ast.CreateCompoundStmt(new_body))}; + provenance.conds[new_while] = provenance.conds[loop]; + return new_while; } }; @@ -365,9 +378,10 @@ class CondToSeqNegRule : public InferenceRule { ASTBuilder ast(unit); auto body{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(body->body_front())}; - auto cond{ast.CreateLNot(ifstmt->getCond())}; - CopyProvenance(ifstmt->getCond(), cond, provenance.use_provenance); - auto inner_loop{ast.CreateWhile(cond, ifstmt->getElse())}; + auto cond{provenance.z3_exprs[provenance.conds[ifstmt]]}; + auto inner_loop{ast.CreateWhile(provenance.marker_expr, ifstmt->getElse())}; + provenance.conds[inner_loop] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(!cond); std::vector new_body({inner_loop}); if (auto comp = clang::dyn_cast(ifstmt->getThen())) { new_body.insert(new_body.end(), comp->body_begin(), comp->body_end()); @@ -375,7 +389,10 @@ class CondToSeqNegRule : public InferenceRule { new_body.push_back(ifstmt->getThen()); } - return ast.CreateWhile(loop->getCond(), ast.CreateCompoundStmt(new_body)); + auto new_while{ast.CreateWhile(provenance.marker_expr, + ast.CreateCompoundStmt(new_body))}; + provenance.conds[new_while] = provenance.conds[loop]; + return new_while; } }; diff --git a/lib/AST/MaterializeConds.cpp b/lib/AST/MaterializeConds.cpp new file mode 100644 index 00000000..fa080364 --- /dev/null +++ b/lib/AST/MaterializeConds.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. 
+ */ + +#include "rellic/AST/MaterializeConds.h" + +#include +#include +#include +#include +#include + +namespace rellic { + +MaterializeConds::MaterializeConds(Provenance &provenance, clang::ASTUnit &unit) + : TransformVisitor(provenance, unit), + ast_gen(unit, provenance) {} + +bool MaterializeConds::VisitIfStmt(clang::IfStmt *stmt) { + auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; + if (stmt->getCond() == provenance.marker_expr) { + stmt->setCond(ast_gen.ConvertExpr(cond)); + } + return true; +} + +bool MaterializeConds::VisitWhileStmt(clang::WhileStmt *stmt) { + auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; + if (stmt->getCond() == provenance.marker_expr) { + stmt->setCond(ast_gen.ConvertExpr(cond)); + } + return true; +} + +bool MaterializeConds::VisitDoStmt(clang::DoStmt *stmt) { + auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; + if (stmt->getCond() == provenance.marker_expr) { + stmt->setCond(ast_gen.ConvertExpr(cond)); + } + return true; +} + +void MaterializeConds::RunImpl() { + LOG(INFO) << "Materializing conditions"; + TransformVisitor::RunImpl(); + TraverseDecl(ast_ctx.getTranslationUnitDecl()); +} + +} // namespace rellic diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 766aac61..15208d04 100644 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include @@ -18,23 +19,35 @@ #include "rellic/AST/Util.h" namespace rellic { -using ExprVec = std::vector; - -static bool isConstant(const clang::ASTContext& ctx, clang::Expr* expr) { - return expr->getIntegerConstantExpr(ctx).hasValue(); -} - class CompoundVisitor - : public clang::StmtVisitor { + : public clang::StmtVisitor { private: + Provenance& provenance; ASTBuilder& ast; clang::ASTContext& ctx; + bool IsConstant(z3::expr expr) { + if (Prove(provenance.z3_ctx, expr)) { + return false; + } + + if (Prove(provenance.z3_ctx, !expr)) { + return false; + } + + return true; + } + + 
z3::expr Simplify(z3::expr expr) { + return HeavySimplify(provenance.z3_ctx, expr); + } + public: - CompoundVisitor(ASTBuilder& ast, clang::ASTContext& ctx) - : ast(ast), ctx(ctx) {} + CompoundVisitor(Provenance& provenance, ASTBuilder& ast, + clang::ASTContext& ctx) + : provenance(provenance), ast(ast), ctx(ctx) {} - bool VisitCompoundStmt(clang::CompoundStmt* compound, ExprVec& true_exprs) { + bool VisitCompoundStmt(clang::CompoundStmt* compound, z3::expr& true_exprs) { bool changed{false}; for (auto stmt : compound->body()) { changed |= Visit(stmt, true_exprs); @@ -43,56 +56,88 @@ class CompoundVisitor return changed; } - bool VisitWhileStmt(clang::WhileStmt* while_stmt, ExprVec& true_exprs) { + bool VisitWhileStmt(clang::WhileStmt* while_stmt, z3::expr& true_exprs) { bool changed{false}; - auto cond{while_stmt->getCond()}; - for (auto true_expr : true_exprs) { - changed |= Replace(true_expr, ast.CreateTrue(), &cond); + if (while_stmt->getCond() == provenance.marker_expr) { + auto old_cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; + auto new_cond{Simplify(old_cond && true_exprs)}; + LOG(INFO) << "known: " << true_exprs.to_string() + << " old: " << old_cond.to_string() + << " new: " << new_cond.to_string(); + if (!z3::eq(old_cond, new_cond)) { + provenance.conds[while_stmt] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(new_cond); + changed = true; + } } - while_stmt->setCond(cond); - ExprVec inner{true_exprs}; - if (!isConstant(ctx, while_stmt->getCond())) { - inner.push_back(while_stmt->getCond()); + auto cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; + z3::expr inner{true_exprs}; + bool isConstant{IsConstant(cond)}; + if (!isConstant) { + inner = Simplify(inner && cond); } changed |= Visit(while_stmt->getBody(), inner); - true_exprs.push_back(Negate(ast, while_stmt->getCond())); + if (!isConstant) { + true_exprs = Simplify(true_exprs && !cond); + } return changed; } - bool VisitDoStmt(clang::DoStmt* do_stmt, ExprVec& 
true_exprs) { + bool VisitDoStmt(clang::DoStmt* do_stmt, z3::expr& true_exprs) { bool changed{false}; - auto cond{do_stmt->getCond()}; - for (auto true_expr : true_exprs) { - changed |= Replace(true_expr, ast.CreateTrue(), &cond); + if (do_stmt->getCond() == provenance.marker_expr) { + auto old_cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; + auto new_cond{Simplify(old_cond && true_exprs)}; + LOG(INFO) << "known: " << true_exprs.to_string() + << " old: " << old_cond.to_string() + << " new: " << new_cond.to_string(); + if (!z3::eq(old_cond, new_cond)) { + provenance.conds[do_stmt] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(new_cond); + changed = true; + } } - do_stmt->setCond(cond); - ExprVec inner{true_exprs}; + auto cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; + auto inner{true_exprs}; Visit(do_stmt->getBody(), inner); - true_exprs.push_back(Negate(ast, do_stmt->getCond())); + if (!IsConstant(cond)) { + true_exprs = Simplify(true_exprs && !cond); + } return changed; } - bool VisitIfStmt(clang::IfStmt* if_stmt, ExprVec& true_exprs) { + bool VisitIfStmt(clang::IfStmt* if_stmt, z3::expr& true_exprs) { bool changed{false}; - auto cond{if_stmt->getCond()}; - for (auto true_expr : true_exprs) { - changed |= Replace(true_expr, ast.CreateTrue(), &cond); + if (if_stmt->getCond() == provenance.marker_expr) { + auto old_cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; + auto new_cond{Simplify(old_cond && true_exprs)}; + LOG(INFO) << "known: " << true_exprs.to_string() + << " old: " << old_cond.to_string() + << " new: " << new_cond.to_string(); + if (!z3::eq(old_cond, new_cond)) { + provenance.conds[if_stmt] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(new_cond); + changed = true; + } } - if_stmt->setCond(cond); - ExprVec inner_then{true_exprs}; - if (!isConstant(ctx, if_stmt->getCond())) { - inner_then.push_back(if_stmt->getCond()); + auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; + auto 
inner_then{true_exprs}; + bool isConstant{IsConstant(cond)}; + if (isConstant) { + inner_then = Simplify(inner_then && cond); } Visit(if_stmt->getThen(), inner_then); if (if_stmt->getElse()) { - ExprVec inner_else{true_exprs}; - inner_else.push_back(Negate(ast, if_stmt->getCond())); + auto inner_else{true_exprs}; + if (!isConstant) { + inner_else = Simplify(inner_else && !cond); + } Visit(if_stmt->getElse(), inner_else); } return changed; @@ -105,12 +150,12 @@ NestedCondProp::NestedCondProp(Provenance& provenance, clang::ASTUnit& unit) void NestedCondProp::RunImpl() { changed = false; ASTBuilder ast{ast_unit}; - CompoundVisitor visitor{ast, ast_ctx}; + CompoundVisitor visitor{provenance, ast, ast_ctx}; for (auto decl : ast_ctx.getTranslationUnitDecl()->decls()) { if (auto fdecl = clang::dyn_cast(decl)) { if (fdecl->hasBody()) { - ExprVec true_exprs; + auto true_exprs{provenance.z3_ctx.bool_val(true)}; changed |= visitor.Visit(fdecl->getBody(), true_exprs); } } diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index 222f97d2..51981b31 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -11,6 +11,9 @@ #include #include +#include "rellic/AST/Compat/Stmt.h" +#include "rellic/AST/Util.h" + namespace rellic { NestedScopeCombine::NestedScopeCombine(Provenance &provenance, @@ -21,17 +24,11 @@ bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { // DLOG(INFO) << "VisitIfStmt"; // Determine whether `cond` is a constant expression that is always true and // `ifstmt` should be replaced by `then` in it's parent nodes. 
- auto if_const_expr = ifstmt->getCond()->getIntegerConstantExpr(ast_ctx); - bool is_const = if_const_expr.hasValue(); - if (is_const && if_const_expr->getBoolValue()) { + auto cond{provenance.z3_exprs[provenance.conds[ifstmt]]}; + if (Prove(provenance.z3_ctx, cond)) { substitutions[ifstmt] = ifstmt->getThen(); - } else if (is_const && !if_const_expr->getBoolValue()) { - if (auto else_stmt = ifstmt->getElse()) { - substitutions[ifstmt] = else_stmt; - } else { - std::vector body; - substitutions[ifstmt] = ast.CreateCompoundStmt(body); - } + } else if (Prove(provenance.z3_ctx, !cond) && ifstmt->getElse()) { + substitutions[ifstmt] = ifstmt->getElse(); } return !Stopped(); } diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp index 794d721b..0f5b94d8 100644 --- a/lib/AST/ReachBasedRefine.cpp +++ b/lib/AST/ReachBasedRefine.cpp @@ -16,24 +16,23 @@ namespace rellic { ReachBasedRefine::ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit), - z3_ctx(new z3::context()), - z3_gen(new rellic::Z3ConvVisitor(unit, z3_ctx.get())) {} - -z3::expr ReachBasedRefine::GetZ3Cond(clang::IfStmt *ifstmt) { - auto cond = ifstmt->getCond(); - auto expr = z3_gen->Z3BoolCast(z3_gen->GetOrCreateZ3Expr(cond)); - return expr.simplify(); -} - -bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { - std::vector body{compound->body_begin(), compound->body_end()}; - std::vector ifs; - z3::expr_vector conds{*z3_ctx}; - - auto ResetChain = [&]() { - ifs.clear(); - conds.resize(0); + : TransformVisitor(provenance, unit) {} + +void ReachBasedRefine::CreateIfElseStmts(IfStmtVec stmts) { + // Else-if candidate IfStmts and their Z3 form + // reaching conditions. + IfStmtVec elifs; + z3::expr_vector conds(provenance.z3_ctx); + // Test that determines if a new IfStmts is not + // reachable from the already gathered IfStmts. 
+ auto IsUnrechable = [this, &conds](z3::expr cond) { + return Prove(provenance.z3_ctx, !(cond && z3::mk_or(conds))); + }; + // Test to determine if we have enough candidate + // IfStmts to form an else-if cascade. + auto IsTautology = [this, &conds] { + return Prove(provenance.z3_ctx, + z3::mk_or(conds) == provenance.z3_ctx.bool_val(true)); }; bool done_something{false}; @@ -43,100 +42,75 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { ResetChain(); continue; } - - ifs.push_back(if_stmt); - auto cond{GetZ3Cond(if_stmt)}; - - if (if_stmt->getElse()) { - // We cannot link `if` statements that contain `else` branches - ResetChain(); - continue; - } - - // Is the current `if` statement unreachable from all the others? - bool is_unreachable{Prove(*z3_ctx, !(cond && z3::mk_or(conds)))}; - - if (!is_unreachable) { - ResetChain(); - continue; + // Clear else-if IfStmts if we find a path among them. + auto cond = provenance.z3_exprs[provenance.conds[stmt]]; + if (stmt->getElse() || !IsUnrechable(cond)) { + conds = z3::expr_vector(provenance.z3_ctx); + elifs.clear(); } + // Add the current if-statement to the else-if candidates. + conds.push_back(cond); + elifs.push_back(stmt); + } - conds.push_back(GetZ3Cond(if_stmt)); - - // Do the collected statements cover all possibilities? - auto is_complete{Prove(*z3_ctx, z3::mk_or(conds))}; - - if (ifs.size() <= 2 || !is_complete) { - // We need to collect more statements - continue; - } + // Check if we have enough statements to work with + if (elifs.size() < 2) { + return; + } - /* - `body` will look like this at this point: - - ... - i - n : ... - i - n + 1: if(cond_1) { } - i - n + 2: if(cond_2) { } - ... - i - 1 : if(cond_n-1) { } - i : if(cond_n) { } - ... - - and we want to chain all of the statements together: - ... - i - n : ... - i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } - i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } - ... 
- i - 1 : if(cond_n-1) { } else if(cond_n) { } - i : if(cond_n) { } - ... - */ - auto last_if{ifs[0]}; - for (auto stmt : ifs) { - if (stmt == ifs.front()) { - continue; - } - if (stmt == ifs.back()) { - last_if->setElse(stmt->getThen()); - } else { - last_if->setElse(stmt); - last_if = stmt; - } + // Create the else-if cascade + clang::IfStmt *sub = nullptr; + for (auto stmt : llvm::make_range(elifs.rbegin(), elifs.rend())) { + auto then = stmt->getThen(); + if (stmt == elifs.back()) { + sub = ast.CreateIf(provenance.marker_expr, then); + provenance.conds[sub] = provenance.conds[stmt]; + substitutions[stmt] = sub; + } else if (stmt == elifs.front()) { + std::vector thens({then}); + sub->setElse(ast.CreateCompoundStmt(thens)); + substitutions[stmt] = nullptr; + } else { + auto elif = ast.CreateIf(provenance.marker_expr, then); + provenance.conds[elif] = provenance.conds[stmt]; + sub->setElse(elif); + sub = elif; + substitutions[stmt] = nullptr; } - - /* - `body` will look like this at this point: - - ... - i - n : ... - i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } - i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } - ... - i - 1 : if(cond_n-1) { } else { } - i : if(cond_n) { } - ... - - but since we chained all of the statements into the first, we want to remove - the others from the body: - - ... - i - n : ... - i - n + 1: if(cond_1) { } else if(cond_2) { } else if ... - ... - */ - size_t start_delete{i - (ifs.size() - 2)}; - size_t end_delete{i}; - body.erase(body.erase(std::next(body.begin(), start_delete), - std::next(body.begin(), end_delete))); - done_something = true; } +} - if (done_something) { - substitutions[compound] = ast.CreateCompoundStmt(body); - } - return !Stopped(); +/* +`body` will look like this at this point: + + ... + i - n : ... + i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } + i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } + ... 
+ i - 1 : if(cond_n-1) { } else { } + i : if(cond_n) { } + ... + +but since we chained all of the statements into the first, we want to remove +the others from the body: + + ... + i - n : ... + i - n + 1: if(cond_1) { } else if(cond_2) { } else if ... + ... +*/ +size_t start_delete{i - (ifs.size() - 2)}; +size_t end_delete{i}; +body.erase(body.erase(std::next(body.begin(), start_delete), + std::next(body.begin(), end_delete))); +done_something = true; +} + +if (done_something) { + substitutions[compound] = ast.CreateCompoundStmt(body); +} +return !Stopped(); } void ReachBasedRefine::RunImpl() { diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index 9bb3ef11..b56cbebf 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -349,4 +349,16 @@ bool Prove(z3::context &ctx, z3::expr expr) { return ApplyTactic(ctx, z3::tactic(ctx, "sat"), !(expr.simplify())) .is_decided_unsat(); } + +z3::expr HeavySimplify(z3::context &ctx, z3::expr expr) { + if (Prove(ctx, expr)) { + return ctx.bool_val(true); + } + + z3::tactic aig(ctx, "aig"); + z3::tactic simplify(ctx, "simplify"); + z3::tactic ctx_solver_simplify(ctx, "ctx-solver-simplify"); + auto tactic{simplify & aig & ctx_solver_simplify}; + return ApplyTactic(ctx, tactic, expr).as_expr(); +} } // namespace rellic diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index 45977c81..6e4d3a71 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -17,119 +17,26 @@ namespace rellic { Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit), - z_ctx(new z3::context()), - z_gen(new Z3ConvVisitor(unit, z_ctx.get())), - hash_adaptor{ast_ctx, hashes}, - proven_true(10, hash_adaptor, ke_adaptor), - proven_false(10, hash_adaptor, ke_adaptor) {} - -bool Z3CondSimplify::IsProvenTrue(clang::Expr *e) { - auto it{proven_true.find(e)}; - if (it == proven_true.end()) { - proven_true[e] = Prove(*z_ctx, ToZ3(e)); - } - return proven_true[e]; -} - -bool 
Z3CondSimplify::IsProvenFalse(clang::Expr *e) { - auto it{proven_false.find(e)}; - if (it == proven_false.end()) { - proven_false[e] = Prove(*z_ctx, !ToZ3(e)); - } - return proven_false[e]; -} - -z3::expr Z3CondSimplify::ToZ3(clang::Expr *e) { - return z_gen->Z3BoolCast(z_gen->GetOrCreateZ3Expr(e)); -} - -clang::Expr *Z3CondSimplify::Simplify(clang::Expr *c_expr) { - if (auto binop = clang::dyn_cast(c_expr)) { - auto lhs{Simplify(binop->getLHS())}; - auto rhs{Simplify(binop->getRHS())}; - - auto opcode{binop->getOpcode()}; - if (opcode == clang::BO_LAnd || opcode == clang::BO_LOr) { - auto lhs_proven{IsProvenTrue(lhs)}; - auto rhs_proven{IsProvenTrue(rhs)}; - auto not_lhs_proven{IsProvenFalse(lhs)}; - auto not_rhs_proven{IsProvenFalse(rhs)}; - if (opcode == clang::BO_LAnd) { - if (lhs_proven && rhs_proven) { - changed = true; - return ast.CreateTrue(); - } else if (not_lhs_proven || not_rhs_proven) { - changed = true; - return ast.CreateFalse(); - } else if (lhs_proven) { - changed = true; - return rhs; - } else if (rhs_proven) { - changed = true; - return lhs; - } - } else { - if (not_lhs_proven && not_rhs_proven) { - changed = true; - return ast.CreateFalse(); - } else if (lhs_proven || rhs_proven) { - changed = true; - return ast.CreateTrue(); - } else if (not_lhs_proven) { - changed = true; - return rhs; - } else if (not_rhs_proven) { - changed = true; - return lhs; - } - } - - binop->setLHS(lhs); - binop->setRHS(rhs); - hashes[binop] = 0; - if (IsProvenTrue(binop)) { - changed = true; - return ast.CreateTrue(); - } else if (IsProvenFalse(binop)) { - changed = true; - return ast.CreateFalse(); - } - } - } else if (auto unop = clang::dyn_cast(c_expr)) { - if (unop->getOpcode() == clang::UO_LNot) { - auto sub{Simplify(unop->getSubExpr())}; - if (IsProvenTrue(sub)) { - changed = true; - return ast.CreateFalse(); - } else if (IsProvenFalse(sub)) { - changed = true; - return ast.CreateTrue(); - } - unop->setSubExpr(sub); - hashes[unop] = 0; - } - } else if (auto 
paren = clang::dyn_cast(c_expr)) { - auto sub{Simplify(paren->getSubExpr())}; - paren->setSubExpr(sub); - hashes[paren] = 0; - } - - return c_expr; -} + : TransformVisitor(provenance, unit) {} bool Z3CondSimplify::VisitIfStmt(clang::IfStmt *stmt) { - stmt->setCond(Simplify(stmt->getCond())); + auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; + provenance.conds[stmt] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(HeavySimplify(provenance.z3_ctx, cond)); return true; } bool Z3CondSimplify::VisitWhileStmt(clang::WhileStmt *loop) { - loop->setCond(Simplify(loop->getCond())); + auto cond{provenance.z3_exprs[provenance.conds[loop]]}; + provenance.conds[loop] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(HeavySimplify(provenance.z3_ctx, cond)); return true; } bool Z3CondSimplify::VisitDoStmt(clang::DoStmt *loop) { - loop->setCond(Simplify(loop->getCond())); + auto cond{provenance.z3_exprs[provenance.conds[loop]]}; + provenance.conds[loop] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(HeavySimplify(provenance.z3_ctx, cond)); return true; } diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 85ac463d..933278ff 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -32,6 +32,7 @@ set(AST_HEADERS "${include_dir}/AST/InferenceRule.h" "${include_dir}/AST/LocalDeclRenamer.h" "${include_dir}/AST/LoopRefine.h" + "${include_dir}/AST/MaterializeConds.h" "${include_dir}/AST/NestedCondProp.h" "${include_dir}/AST/NestedScopeCombine.h" "${include_dir}/AST/NormalizeCond.h" @@ -72,6 +73,7 @@ set(AST_SOURCES AST/IRToASTVisitor.cpp AST/LocalDeclRenamer.cpp AST/LoopRefine.cpp + AST/MaterializeConds.cpp AST/NestedCondProp.cpp AST/NestedScopeCombine.cpp AST/NormalizeCond.cpp diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 38a28b3b..b287138c 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -29,6 +29,7 @@ #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/LocalDeclRenamer.h" #include 
"rellic/AST/LoopRefine.h" +#include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" #include "rellic/AST/NormalizeCond.h" @@ -86,6 +87,9 @@ Result Decompile( rellic::Provenance provenance; rellic::GenerateAST::run(*module, provenance, *ast_unit); + ASTBuilder ast{*ast_unit}; + provenance.marker_expr = + ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); // TODO(surovic): Add llvm::Value* -> clang::Decl* map // Especially for llvm::Argument* and llvm::Function*. @@ -190,6 +194,8 @@ Result Decompile( rellic::CompositeASTPass pass_ec{provenance, *ast_unit}; auto& ec_passes{pass_ec.GetPasses()}; + ec_passes.push_back( + std::make_unique(provenance, *ast_unit)); if (options.expression_combine) { ec_passes.push_back( std::make_unique(provenance, *ast_unit)); diff --git a/tools/repl/Repl.cpp b/tools/repl/Repl.cpp index 24367398..741e0994 100644 --- a/tools/repl/Repl.cpp +++ b/tools/repl/Repl.cpp @@ -41,6 +41,7 @@ #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/LocalDeclRenamer.h" #include "rellic/AST/LoopRefine.h" +#include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" #include "rellic/AST/NormalizeCond.h" @@ -61,7 +62,7 @@ DECLARE_bool(version); llvm::LLVMContext llvm_ctx; std::unique_ptr module{nullptr}; std::unique_ptr ast_unit{nullptr}; -rellic::Provenance provenance; +std::unique_ptr provenance; std::unique_ptr global_pass{nullptr}; static void SetVersion(void) { @@ -147,23 +148,25 @@ class Diff { static std::unique_ptr CreatePass(const std::string& name) { if (name == "cbr") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); } else if (name == "dse") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); } else if (name == "ec") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, 
*ast_unit); } else if (name == "lr") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); + } else if (name == "mc") { + return std::make_unique(*provenance, *ast_unit); } else if (name == "ncp") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); } else if (name == "nsc") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); } else if (name == "nc") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); } else if (name == "rbr") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); } else if (name == "zcs") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*provenance, *ast_unit); } else { return nullptr; } @@ -197,6 +200,7 @@ static void do_help() { << " dse Dead statement elimination\n" << " ec Expression combination\n" << " lr Loop refinement\n" + << " mc Condition materialization\n" << " nc Condition normalization\n" << " ncp Nested condition propagation\n" << " nsc Nested scope combination\n" @@ -301,9 +305,9 @@ static void do_decompile() { rellic::DebugInfoCollector dic; dic.visit(*module); provenance = {}; - rellic::GenerateAST::run(*module, provenance, *ast_unit); - rellic::LocalDeclRenamer ldr{provenance, *ast_unit, dic.GetIRToNameMap()}; - rellic::StructFieldRenamer sfr{provenance, *ast_unit, + rellic::GenerateAST::run(*module, *provenance, *ast_unit); + rellic::LocalDeclRenamer ldr{*provenance, *ast_unit, dic.GetIRToNameMap()}; + rellic::StructFieldRenamer sfr{*provenance, *ast_unit, dic.GetIRTypeToDITypeMap()}; ldr.Run(); sfr.Run(); @@ -320,7 +324,7 @@ static void do_run(std::istream& is) { } global_pass = - std::make_unique(provenance, *ast_unit); + std::make_unique(*provenance, *ast_unit); std::string name; while (is >> name) { auto pass{CreatePass(name)}; @@ -358,7 +362,7 @@ 
static void do_fixpoint(std::istream& is) { } global_pass = - std::make_unique(provenance, *ast_unit); + std::make_unique(*provenance, *ast_unit); std::string name; while (is >> name) { auto pass{CreatePass(name)}; diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index fa4ebff9..518bcea4 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -49,11 +49,13 @@ #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/LocalDeclRenamer.h" #include "rellic/AST/LoopRefine.h" +#include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" #include "rellic/AST/NormalizeCond.h" #include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" +#include "rellic/AST/Util.h" #include "rellic/AST/Z3CondSimplify.h" #include "rellic/BC/Util.h" #include "rellic/Decompiler.h" @@ -110,7 +112,7 @@ struct Session { std::unique_ptr Module; std::unique_ptr Unit; std::unique_ptr Pass; - rellic::Provenance Provenance; + std::unique_ptr Provenance; // Must always be acquired in this order and released all at once std::shared_mutex LoadMutex, MutationMutex; }; @@ -266,18 +268,21 @@ static void Decompile(const httplib::Request& req, httplib::Response& res) { } try { - session.Provenance = {}; + session.Provenance = std::make_unique(); std::vector args{"-Wno-pointer-to-int-cast", "-Wno-pointer-sign", "-target", session.Module->getTargetTriple()}; session.Unit = clang::tooling::buildASTFromCodeWithArgs("", args, "out.c"); + rellic::ASTBuilder ast{*session.Unit}; + session.Provenance->marker_expr = + ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); rellic::DebugInfoCollector dic; dic.visit(*session.Module); - rellic::GenerateAST::run(*session.Module, session.Provenance, + rellic::GenerateAST::run(*session.Module, *session.Provenance, *session.Unit); - rellic::LocalDeclRenamer ldr{session.Provenance, *session.Unit, + rellic::LocalDeclRenamer ldr{*session.Provenance, *session.Unit, 
dic.GetIRToNameMap()}; - rellic::StructFieldRenamer sfr{session.Provenance, *session.Unit, + rellic::StructFieldRenamer sfr{*session.Provenance, *session.Unit, dic.GetIRTypeToDITypeMap()}; ldr.Run(); sfr.Run(); @@ -396,38 +401,42 @@ static std::unique_ptr CreatePass( auto str{name->str()}; if (str == "cbr") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else if (str == "dse") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else if (str == "ec") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else if (str == "lr") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); + } else if (str == "mc") { + return std::make_unique(*session.Provenance, + *session.Unit); } else if (str == "ncp") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else if (str == "nsc") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else if (str == "nc") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else if (str == "rbr") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else if (str == "zcs") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.Provenance, *session.Unit); } else { LOG(ERROR) << "Request contains invalid pass id"; return nullptr; } } else if (auto arr = val.getAsArray()) { - auto fix{std::make_unique(session.Provenance, *session.Unit)}; + auto fix{ + std::make_unique(*session.Provenance, *session.Unit)}; for (auto& pass : *arr) { auto p{CreatePass(session, pass)}; if (!p) { @@ -509,7 +518,7 @@ static void Run(const httplib::Request& req, 
httplib::Response& res) { return; } - auto composite{std::make_unique(session.Provenance, + auto composite{std::make_unique(*session.Provenance, *session.Unit)}; for (auto& obj : *json->getAsArray()) { auto pass{CreatePass(session, obj)}; @@ -579,7 +588,7 @@ static void Fixpoint(const httplib::Request& req, httplib::Response& res) { return; } - auto composite{std::make_unique(session.Provenance, + auto composite{std::make_unique(*session.Provenance, *session.Unit)}; for (auto& obj : *json->getAsArray()) { auto pass{CreatePass(session, obj)}; @@ -712,11 +721,12 @@ static void PrintAST(const httplib::Request& req, httplib::Response& res) { rellic::DecompilationResult::IRToTypeDeclMap type_to_decl_map; rellic::DecompilationResult::TypeDeclToIRMap type_provenance_map; - CopyMap(session.Provenance.stmt_provenance, stmt_provenance_map, + CopyMap(session.Provenance->stmt_provenance, stmt_provenance_map, value_to_stmt_map); - CopyMap(session.Provenance.value_decls, value_to_decl_map, + CopyMap(session.Provenance->value_decls, value_to_decl_map, decl_provenance_map); - CopyMap(session.Provenance.type_decls, type_to_decl_map, type_provenance_map); + CopyMap(session.Provenance->type_decls, type_to_decl_map, + type_provenance_map); std::string s; llvm::raw_string_ostream os(s); @@ -797,31 +807,31 @@ static void PrintProvenance(const httplib::Request& req, } llvm::json::Array stmt_provenance; - for (auto elem : session.Provenance.stmt_provenance) { + for (auto elem : session.Provenance->stmt_provenance) { stmt_provenance.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array type_decls; - for (auto elem : session.Provenance.type_decls) { + for (auto elem : session.Provenance->type_decls) { type_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array value_decls; - for (auto elem : session.Provenance.value_decls) { + for (auto elem : 
session.Provenance->value_decls) { value_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array temp_decls; - for (auto elem : session.Provenance.temp_decls) { + for (auto elem : session.Provenance->temp_decls) { temp_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array use_provenance; - for (auto elem : session.Provenance.use_provenance) { + for (auto elem : session.Provenance->use_provenance) { if (!elem.second) { continue; } diff --git a/tools/xref/www/main.js b/tools/xref/www/main.js index 24f5eea9..848a7c34 100644 --- a/tools/xref/www/main.js +++ b/tools/xref/www/main.js @@ -36,6 +36,10 @@ const nc = { id: "nc", label: "Condition normalization" } +const mc = { + id: "mc", + label: "Materialize conditions" +} Vue.component('list-comp', { props: ["items", "availableCommands", "showDelete", "title"], @@ -106,7 +110,8 @@ const app = new Vue({ rbr, lr, ec, - nc + nc, + mc ], actions: [ { @@ -208,20 +213,20 @@ const app = new Vue({ throw (await res.json()).message } const prov = await res.json() - for(let map in prov) { - for(let [from, to] of prov[map]) { - if(!from || !to) { + for (let map in prov) { + for (let [from, to] of prov[map]) { + if (!from || !to) { continue } const from_hex = from.toString(16) const to_hex = to.toString(16) - if(!this.provenance[from_hex]) { + if (!this.provenance[from_hex]) { this.provenance[from_hex] = [] } this.provenance[from_hex].push(to_hex) - if(!this.provenance[to_hex]) { + if (!this.provenance[to_hex]) { this.provenance[to_hex] = [] } this.provenance[to_hex].push(from_hex) @@ -316,8 +321,7 @@ const app = new Vue({ [zcs, ncp, nsc, cbr, rbr], [lr, nsc], [zcs, ncp, nsc], - ec, - [zcs, ncp, nsc] + mc, ec ] }, openAngha() { From 58d2e34f157c07c44ae168c0078fd09b3561e59d Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 21 Jul 2022 13:13:51 +0200 Subject: [PATCH 02/57] Fix bugs --- 
include/rellic/AST/Util.h | 1 + include/rellic/Decompiler.h | 29 ------ lib/AST/GenerateAST.cpp | 25 +++++- lib/AST/IRToASTVisitor.cpp | 2 - lib/AST/NestedCondProp.cpp | 158 ++++++++++++++++++++++++--------- lib/AST/NestedScopeCombine.cpp | 1 - lib/AST/Util.cpp | 10 +++ lib/Decompiler.cpp | 128 +++++++------------------- tools/decomp/Decomp.cpp | 1 - 9 files changed, 182 insertions(+), 173 deletions(-) mode change 100644 => 100755 lib/AST/GenerateAST.cpp mode change 100644 => 100755 lib/AST/IRToASTVisitor.cpp mode change 100644 => 100755 lib/AST/NestedCondProp.cpp diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 4f60e096..833c878a 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -109,4 +109,5 @@ z3::goal ApplyTactic(z3::context &ctx, const z3::tactic &tactic, z3::expr expr); bool Prove(z3::context &ctx, z3::expr expr); z3::expr HeavySimplify(z3::context &ctx, z3::expr expr); +z3::expr_vector Clone(z3::expr_vector &vec); } // namespace rellic \ No newline at end of file diff --git a/include/rellic/Decompiler.h b/include/rellic/Decompiler.h index a229d5b6..a62fc659 100644 --- a/include/rellic/Decompiler.h +++ b/include/rellic/Decompiler.h @@ -20,35 +20,6 @@ namespace rellic { struct DecompilationOptions { bool lower_switches = false; bool remove_phi_nodes = false; - bool disable_z3 = false; - bool dead_stmt_elimination = true; - struct { - bool z3_cond_simplify = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - bool cond_base_refine = true; - bool reach_based_refine = true; - bool expression_normalize = false; - } condition_based_refinement; - struct { - bool loop_refine = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - bool expression_normalize = false; - } loop_refinement; - struct { - bool z3_cond_simplify = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - bool expression_normalize = false; - } scope_refinement; - bool 
expression_normalize = false; - bool expression_combine = true; - struct { - bool z3_cond_simplify = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - } final_refinement; }; struct DecompilationResult { diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp old mode 100644 new mode 100755 index 2b22b6e1..13f64e32 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -26,6 +26,7 @@ #include #include "rellic/AST/ASTBuilder.h" +#include "rellic/AST/Util.h" #include "rellic/BC/Util.h" #include "rellic/Exception.h" @@ -288,7 +289,29 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { // Append `conj_cond` to reaching conditions of other // predecessors via an `||`. Use `conj_cond` if there // is no `cond` yet. - cond = HeavySimplify(provenance.z3_ctx, cond || conj_cond); + z3::expr_vector left_exprs{provenance.z3_ctx}; + z3::expr_vector right_exprs{provenance.z3_ctx}; + + if (cond.is_app() && cond.decl().decl_kind() == Z3_OP_AND) { + left_exprs = cond.args(); + } else { + left_exprs.push_back(cond); + } + + if (conj_cond.is_app() && conj_cond.decl().decl_kind() == Z3_OP_AND) { + right_exprs = conj_cond.args(); + } else { + right_exprs.push_back(conj_cond); + } + + z3::expr_vector res_cond{provenance.z3_ctx}; + for (auto expr_a : left_exprs) { + for (auto expr_b : right_exprs) { + res_cond.push_back((expr_a || expr_b).simplify()); + } + } + + cond = z3::mk_and(res_cond); } if (!Prove(provenance.z3_ctx, old_cond == cond)) { diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp old mode 100644 new mode 100755 index d89cb93a..ebb97c7c --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -94,7 +94,6 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { CHECK_EQ(args.size(), 0) << "False cannot have arguments"; return ast.CreateFalse(); case Z3_OP_AND: { - CHECK_GE(args.size(), 2) << "And must have at least 2 arguments"; clang::Expr *res{args[0]}; for (auto i{1U}; i < 
args.size(); ++i) { res = ast.CreateLAnd(res, args[i]); @@ -102,7 +101,6 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { return res; } case Z3_OP_OR: { - CHECK_GE(args.size(), 2) << "Or must have at least 2 arguments"; clang::Expr *res{args[0]}; for (auto i{1U}; i < args.size(); ++i) { res = ast.CreateLOr(res, args[i]); diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp old mode 100644 new mode 100755 index 15208d04..2ffb7088 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -20,7 +20,7 @@ namespace rellic { class CompoundVisitor - : public clang::StmtVisitor { + : public clang::StmtVisitor { private: Provenance& provenance; ASTBuilder& ast; @@ -28,26 +28,115 @@ class CompoundVisitor bool IsConstant(z3::expr expr) { if (Prove(provenance.z3_ctx, expr)) { - return false; + return true; } if (Prove(provenance.z3_ctx, !expr)) { - return false; + return true; } - return true; + return false; + } + + void AddExpr(z3::expr expr, z3::expr_vector& vec) { + if (!IsConstant(expr)) { + vec.push_back(expr); + } } z3::expr Simplify(z3::expr expr) { return HeavySimplify(provenance.z3_ctx, expr); } + z3::expr SimplifyWithAssumptions(z3::expr expr, z3::expr_vector& true_exprs) { + auto true_expr{z3::mk_and(true_exprs)}; + auto decl_kind{expr.decl().decl_kind()}; + + if (!expr.is_app() || decl_kind == Z3_OP_UNINTERPRETED || + decl_kind == Z3_OP_NOT) { + if (Prove(provenance.z3_ctx, z3::implies(true_expr, expr))) { + return provenance.z3_ctx.bool_val(true); + } + + if (Prove(provenance.z3_ctx, z3::implies(true_expr, !expr))) { + return provenance.z3_ctx.bool_val(false); + } + + return expr; + } + + if (decl_kind == Z3_OP_TRUE || decl_kind == Z3_OP_FALSE) { + return expr; + } + + if (decl_kind == Z3_OP_OR) { + z3::expr_vector new_or{provenance.z3_ctx}; + for (auto sub : expr.args()) { + if (Prove(provenance.z3_ctx, z3::implies(true_exprs, sub))) { + return provenance.z3_ctx.bool_val(true); + } + + if (!Prove(provenance.z3_ctx, 
z3::implies(true_exprs, !expr))) { + new_or.push_back(sub); + } + } + + return z3::mk_or(new_or).simplify(); + } + + CHECK_EQ(decl_kind, Z3_OP_AND) + << "Unknown expression kind: " << expr.to_string(); + + z3::expr_vector new_conj{provenance.z3_ctx}; + for (auto sub : expr.args()) { + auto sub_decl_kind{sub.decl().decl_kind()}; + + if (sub_decl_kind == Z3_OP_TRUE) { + continue; + } + + if (sub_decl_kind == Z3_OP_FALSE) { + return sub; + } + + if (sub_decl_kind == Z3_OP_OR) { + z3::expr_vector new_disj{provenance.z3_ctx}; + for (auto sub_disj : sub.args()) { + if (Prove(provenance.z3_ctx, z3::implies(true_expr, !sub_disj))) { + continue; + } + + if (Prove(provenance.z3_ctx, z3::implies(true_expr, sub_disj))) { + new_disj.push_back(provenance.z3_ctx.bool_val(true)); + } + + new_disj.push_back(sub_disj); + } + + sub = z3::mk_or(new_disj).simplify(); + } + + if (Prove(provenance.z3_ctx, z3::implies(true_expr, !sub))) { + return provenance.z3_ctx.bool_val(false); + } + + if (Prove(provenance.z3_ctx, z3::implies(true_expr, sub))) { + continue; + } + + new_conj.push_back(sub); + } + + return HeavySimplify(provenance.z3_ctx, z3::mk_and(new_conj)); + } + public: CompoundVisitor(Provenance& provenance, ASTBuilder& ast, clang::ASTContext& ctx) : provenance(provenance), ast(ast), ctx(ctx) {} - bool VisitCompoundStmt(clang::CompoundStmt* compound, z3::expr& true_exprs) { + bool VisitCompoundStmt(clang::CompoundStmt* compound, + z3::expr_vector& true_exprs) { bool changed{false}; for (auto stmt : compound->body()) { changed |= Visit(stmt, true_exprs); @@ -56,14 +145,12 @@ class CompoundVisitor return changed; } - bool VisitWhileStmt(clang::WhileStmt* while_stmt, z3::expr& true_exprs) { + bool VisitWhileStmt(clang::WhileStmt* while_stmt, + z3::expr_vector& true_exprs) { bool changed{false}; if (while_stmt->getCond() == provenance.marker_expr) { auto old_cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; - auto new_cond{Simplify(old_cond && true_exprs)}; - LOG(INFO) << 
"known: " << true_exprs.to_string() - << " old: " << old_cond.to_string() - << " new: " << new_cond.to_string(); + auto new_cond{SimplifyWithAssumptions(old_cond, true_exprs)}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[while_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -72,27 +159,21 @@ class CompoundVisitor } auto cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; - z3::expr inner{true_exprs}; - bool isConstant{IsConstant(cond)}; - if (!isConstant) { - inner = Simplify(inner && cond); - } + + auto inner{Clone(true_exprs)}; + auto simplified_cond{SimplifyWithAssumptions(cond, true_exprs)}; + AddExpr(simplified_cond, inner); changed |= Visit(while_stmt->getBody(), inner); - if (!isConstant) { - true_exprs = Simplify(true_exprs && !cond); - } + AddExpr(!cond, true_exprs); return changed; } - bool VisitDoStmt(clang::DoStmt* do_stmt, z3::expr& true_exprs) { + bool VisitDoStmt(clang::DoStmt* do_stmt, z3::expr_vector& true_exprs) { bool changed{false}; if (do_stmt->getCond() == provenance.marker_expr) { auto old_cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; - auto new_cond{Simplify(old_cond && true_exprs)}; - LOG(INFO) << "known: " << true_exprs.to_string() - << " old: " << old_cond.to_string() - << " new: " << new_cond.to_string(); + auto new_cond{SimplifyWithAssumptions(old_cond, true_exprs)}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[do_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -101,23 +182,20 @@ class CompoundVisitor } auto cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; - auto inner{true_exprs}; + auto inner{Clone(true_exprs)}; + Visit(do_stmt->getBody(), inner); - if (!IsConstant(cond)) { - true_exprs = Simplify(true_exprs && !cond); - } + auto simplified_cond{SimplifyWithAssumptions(cond, true_exprs)}; + AddExpr(simplified_cond, true_exprs); return changed; } - bool VisitIfStmt(clang::IfStmt* if_stmt, z3::expr& true_exprs) { + bool 
VisitIfStmt(clang::IfStmt* if_stmt, z3::expr_vector& true_exprs) { bool changed{false}; if (if_stmt->getCond() == provenance.marker_expr) { auto old_cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; - auto new_cond{Simplify(old_cond && true_exprs)}; - LOG(INFO) << "known: " << true_exprs.to_string() - << " old: " << old_cond.to_string() - << " new: " << new_cond.to_string(); + auto new_cond{SimplifyWithAssumptions(old_cond, true_exprs)}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[if_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -126,18 +204,14 @@ class CompoundVisitor } auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; - auto inner_then{true_exprs}; - bool isConstant{IsConstant(cond)}; - if (isConstant) { - inner_then = Simplify(inner_then && cond); - } + auto inner_then{Clone(true_exprs)}; + auto simplified_cond{SimplifyWithAssumptions(cond, true_exprs)}; + AddExpr(simplified_cond, inner_then); Visit(if_stmt->getThen(), inner_then); if (if_stmt->getElse()) { - auto inner_else{true_exprs}; - if (!isConstant) { - inner_else = Simplify(inner_else && !cond); - } + auto inner_else{Clone(true_exprs)}; + AddExpr(!simplified_cond, inner_else); Visit(if_stmt->getElse(), inner_else); } return changed; @@ -155,7 +229,7 @@ void NestedCondProp::RunImpl() { for (auto decl : ast_ctx.getTranslationUnitDecl()->decls()) { if (auto fdecl = clang::dyn_cast(decl)) { if (fdecl->hasBody()) { - auto true_exprs{provenance.z3_ctx.bool_val(true)}; + z3::expr_vector true_exprs{provenance.z3_ctx}; changed |= visitor.Visit(fdecl->getBody(), true_exprs); } } diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index 51981b31..45fe0bab 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -11,7 +11,6 @@ #include #include -#include "rellic/AST/Compat/Stmt.h" #include "rellic/AST/Util.h" namespace rellic { diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index b56cbebf..3bec3c9f 100644 --- 
a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include "rellic/AST/ASTBuilder.h" #include "rellic/Exception.h" @@ -361,4 +362,13 @@ z3::expr HeavySimplify(z3::context &ctx, z3::expr expr) { auto tactic{simplify & aig & ctx_solver_simplify}; return ApplyTactic(ctx, tactic, expr).as_expr(); } + +z3::expr_vector Clone(z3::expr_vector &vec) { + z3::expr_vector clone{vec.ctx()}; + for (auto expr : vec) { + clone.push_back(expr); + } + + return clone; +} } // namespace rellic diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index b287138c..1297d3cf 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -32,7 +32,6 @@ #include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" -#include "rellic/AST/NormalizeCond.h" #include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" @@ -85,21 +84,19 @@ Result Decompile( module->getTargetTriple()}; auto ast_unit{clang::tooling::buildASTFromCodeWithArgs("", args, "out.c")}; - rellic::Provenance provenance; - rellic::GenerateAST::run(*module, provenance, *ast_unit); ASTBuilder ast{*ast_unit}; + rellic::Provenance provenance; provenance.marker_expr = ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); + rellic::GenerateAST::run(*module, provenance, *ast_unit); // TODO(surovic): Add llvm::Value* -> clang::Decl* map // Especially for llvm::Argument* and llvm::Function*. 
rellic::CompositeASTPass pass_ast(provenance, *ast_unit); auto& ast_passes{pass_ast.GetPasses()}; - if (options.dead_stmt_elimination) { - ast_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } + ast_passes.push_back( + std::make_unique(provenance, *ast_unit)); ast_passes.push_back(std::make_unique( provenance, *ast_unit, dic.GetIRToNameMap())); ast_passes.push_back(std::make_unique( @@ -109,36 +106,18 @@ Result Decompile( rellic::CompositeASTPass pass_cbr(provenance, *ast_unit); auto& cbr_passes{pass_cbr.GetPasses()}; - if (options.condition_based_refinement.expression_normalize) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (!options.disable_z3) { - auto zcs{std::make_unique(provenance, *ast_unit)}; - if (options.condition_based_refinement.z3_cond_simplify) { - cbr_passes.push_back(std::move(zcs)); - } - if (options.condition_based_refinement.nested_cond_propagate) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - } + cbr_passes.push_back( + std::make_unique(provenance, *ast_unit)); + cbr_passes.push_back( + std::make_unique(provenance, *ast_unit)); - if (options.condition_based_refinement.nested_scope_combine) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } + cbr_passes.push_back( + std::make_unique(provenance, *ast_unit)); - if (!options.disable_z3) { - if (options.condition_based_refinement.cond_base_refine) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.condition_based_refinement.reach_based_refine) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - } + cbr_passes.push_back( + std::make_unique(provenance, *ast_unit)); + cbr_passes.push_back( + std::make_unique(provenance, *ast_unit)); while (pass_cbr.Run()) { ; @@ -147,47 +126,27 @@ Result Decompile( rellic::CompositeASTPass pass_loop{provenance, *ast_unit}; auto& loop_passes{pass_loop.GetPasses()}; - if (options.loop_refinement.loop_refine) { - 
loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.loop_refinement.nested_cond_propagate) { - loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.loop_refinement.nested_scope_combine) { - loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.loop_refinement.expression_normalize) { - loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } + loop_passes.push_back( + std::make_unique(provenance, *ast_unit)); + loop_passes.push_back( + std::make_unique(provenance, *ast_unit)); + loop_passes.push_back( + std::make_unique(provenance, *ast_unit)); + while (pass_loop.Run()) { ; } rellic::CompositeASTPass pass_scope{provenance, *ast_unit}; auto& scope_passes{pass_scope.GetPasses()}; - if (!options.disable_z3) { - auto zcs{std::make_unique(provenance, *ast_unit)}; - if (options.condition_based_refinement.z3_cond_simplify) { - scope_passes.push_back(std::move(zcs)); - } - if (options.condition_based_refinement.nested_cond_propagate) { - scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - } + scope_passes.push_back( + std::make_unique(provenance, *ast_unit)); + scope_passes.push_back( + std::make_unique(provenance, *ast_unit)); + + scope_passes.push_back( + std::make_unique(provenance, *ast_unit)); - if (options.scope_refinement.nested_scope_combine) { - scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.scope_refinement.expression_normalize) { - scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } while (pass_scope.Run()) { ; } @@ -196,35 +155,10 @@ Result Decompile( auto& ec_passes{pass_ec.GetPasses()}; ec_passes.push_back( std::make_unique(provenance, *ast_unit)); - if (options.expression_combine) { - ec_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.expression_normalize) { - ec_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - while 
(pass_ec.Run()) { - ; - } + ec_passes.push_back( + std::make_unique(provenance, *ast_unit)); - rellic::CompositeASTPass pass_final{provenance, *ast_unit}; - auto& final_passes{pass_final.GetPasses()}; - if (options.final_refinement.z3_cond_simplify) { - final_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.final_refinement.nested_cond_propagate) { - final_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.final_refinement.nested_scope_combine) { - final_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - while (pass_final.Run()) { - ; - } + pass_ec.Run(); DecompilationResult result{}; result.ast = std::move(ast_unit); diff --git a/tools/decomp/Decomp.cpp b/tools/decomp/Decomp.cpp index 628507b8..f457ad40 100644 --- a/tools/decomp/Decomp.cpp +++ b/tools/decomp/Decomp.cpp @@ -121,7 +121,6 @@ int main(int argc, char* argv[]) { CHECK(!ec) << "Failed to create output file: " << ec.message(); rellic::DecompilationOptions opts{}; - opts.disable_z3 = FLAGS_disable_z3; opts.lower_switches = FLAGS_lower_switch; opts.remove_phi_nodes = FLAGS_remove_phi_nodes; From 5d33d0776293f41e1b5d0f4dfa51a671d19df53c Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 21 Jul 2022 14:42:12 +0200 Subject: [PATCH 03/57] Fix bug --- lib/AST/NestedCondProp.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 2ffb7088..9dd01118 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -187,7 +187,7 @@ class CompoundVisitor Visit(do_stmt->getBody(), inner); auto simplified_cond{SimplifyWithAssumptions(cond, true_exprs)}; - AddExpr(simplified_cond, true_exprs); + AddExpr(!simplified_cond, true_exprs); return changed; } From b9e3264901e16a952b75010ae1378d13a8d97397 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 21 Jul 2022 14:42:54 +0200 Subject: [PATCH 04/57] Combine whiles --- 
include/rellic/AST/NestedScopeCombine.h | 1 + lib/AST/NestedScopeCombine.cpp | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/include/rellic/AST/NestedScopeCombine.h b/include/rellic/AST/NestedScopeCombine.h index 42db6bb8..9281920e 100644 --- a/include/rellic/AST/NestedScopeCombine.h +++ b/include/rellic/AST/NestedScopeCombine.h @@ -38,6 +38,7 @@ class NestedScopeCombine : public TransformVisitor { NestedScopeCombine(Provenance &provenance, clang::ASTUnit &unit); bool VisitIfStmt(clang::IfStmt *ifstmt); + bool VisitWhileStmt(clang::WhileStmt *stmt); bool VisitCompoundStmt(clang::CompoundStmt *compound); }; diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index 45fe0bab..e1db0415 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -32,6 +32,19 @@ bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { return !Stopped(); } +bool NestedScopeCombine::VisitWhileStmt(clang::WhileStmt *stmt) { + auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; + if (Prove(provenance.z3_ctx, cond)) { + auto body{clang::cast(stmt->getBody())}; + if (clang::isa(body->body_back())) { + std::vector new_body{body->body_begin(), + body->body_end() - 1}; + substitutions[stmt] = ast.CreateCompoundStmt(new_body); + } + } + return !Stopped(); +} + bool NestedScopeCombine::VisitCompoundStmt(clang::CompoundStmt *compound) { // DLOG(INFO) << "VisitCompoundStmt"; bool has_compound = false; From 4ed004069544b0cd496baea11d41feb48243745b Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 21 Jul 2022 16:21:27 +0200 Subject: [PATCH 05/57] Improve simplification --- lib/AST/NestedCondProp.cpp | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 9dd01118..0b5557ae 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -52,20 +52,17 @@ class CompoundVisitor auto 
true_expr{z3::mk_and(true_exprs)}; auto decl_kind{expr.decl().decl_kind()}; - if (!expr.is_app() || decl_kind == Z3_OP_UNINTERPRETED || - decl_kind == Z3_OP_NOT) { - if (Prove(provenance.z3_ctx, z3::implies(true_expr, expr))) { - return provenance.z3_ctx.bool_val(true); - } - - if (Prove(provenance.z3_ctx, z3::implies(true_expr, !expr))) { - return provenance.z3_ctx.bool_val(false); - } + if (Prove(provenance.z3_ctx, z3::implies(true_expr, expr))) { + return provenance.z3_ctx.bool_val(true); + } - return expr; + if (Prove(provenance.z3_ctx, z3::implies(true_expr, !expr))) { + return provenance.z3_ctx.bool_val(false); } - if (decl_kind == Z3_OP_TRUE || decl_kind == Z3_OP_FALSE) { + if (!expr.is_app() || decl_kind == Z3_OP_UNINTERPRETED || + decl_kind == Z3_OP_NOT || decl_kind == Z3_OP_TRUE || + decl_kind == Z3_OP_FALSE) { return expr; } From 7739b459839102d0fed7e8d44a539647c8cd47a0 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 22 Jul 2022 08:52:23 +0200 Subject: [PATCH 06/57] Better condition generation. Still slow. 
--- lib/AST/GenerateAST.cpp | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index 13f64e32..5fc494a0 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -147,7 +147,13 @@ unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, }; if (provenance.z3_br_edges.find({inst, cond}) == provenance.z3_br_edges.end()) { - if (cond) { + if (auto constant = + llvm::dyn_cast(inst->getCondition())) { + provenance.z3_br_edges[{inst, cond}] = provenance.z3_exprs.size(); + auto edge{provenance.z3_ctx.bool_val(constant->isOne() == cond)}; + provenance.z3_exprs.push_back(edge); + provenance.z3_br_edges_inv[edge.id()] = {inst, true}; + } else if (cond) { auto name{GetName(inst)}; auto edge{provenance.z3_ctx.bool_const(name.c_str())}; provenance.z3_br_edges[{inst, cond}] = provenance.z3_exprs.size(); @@ -223,13 +229,14 @@ unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, if (to == sw->getDefaultDest()) { result = ToExpr(GetOrCreateEdgeForSwitch(sw, nullptr)); } else { - result = provenance.z3_ctx.bool_val(false); + z3::expr_vector or_vec{provenance.z3_ctx}; for (auto sw_case : sw->cases()) { if (sw_case.getCaseSuccessor() == to) { - result = result || ToExpr(GetOrCreateEdgeForSwitch( - sw, sw_case.getCaseValue())); + or_vec.push_back( + ToExpr(GetOrCreateEdgeForSwitch(sw, sw_case.getCaseValue()))); } } + result = HeavySimplify(provenance.z3_ctx, z3::mk_or(or_vec)); } } break; // Returns @@ -292,13 +299,13 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { z3::expr_vector left_exprs{provenance.z3_ctx}; z3::expr_vector right_exprs{provenance.z3_ctx}; - if (cond.is_app() && cond.decl().decl_kind() == Z3_OP_AND) { + if (cond.decl().decl_kind() == Z3_OP_AND) { left_exprs = cond.args(); } else { left_exprs.push_back(cond); } - if (conj_cond.is_app() && conj_cond.decl().decl_kind() == Z3_OP_AND) { + if (conj_cond.decl().decl_kind() == 
Z3_OP_AND) { right_exprs = conj_cond.args(); } else { right_exprs.push_back(conj_cond); @@ -307,7 +314,8 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { z3::expr_vector res_cond{provenance.z3_ctx}; for (auto expr_a : left_exprs) { for (auto expr_b : right_exprs) { - res_cond.push_back((expr_a || expr_b).simplify()); + res_cond.push_back( + HeavySimplify(provenance.z3_ctx, expr_a || expr_b)); } } From 29f02ed1e3aac68b486ebc57226d0edfb01119c2 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Mon, 1 Aug 2022 16:31:20 +0200 Subject: [PATCH 07/57] Revert to simple substitution --- lib/AST/GenerateAST.cpp | 28 +------- lib/AST/NestedCondProp.cpp | 129 ++++++++++++++----------------------- 2 files changed, 50 insertions(+), 107 deletions(-) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index 5fc494a0..ba296205 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -285,7 +285,7 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { auto old_cond{ToExpr(GetReachingCond(block))}; if (block->hasNPredecessorsOrMore(1)) { // Gather reaching conditions from predecessors of the block - auto cond{provenance.z3_ctx.bool_val(false)}; + z3::expr_vector conds{provenance.z3_ctx}; for (auto pred : llvm::predecessors(block)) { auto pred_cond{ToExpr(GetReachingCond(pred))}; auto edge_cond{ToExpr(GetOrCreateEdgeCond(pred, block))}; @@ -296,32 +296,10 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { // Append `conj_cond` to reaching conditions of other // predecessors via an `||`. Use `conj_cond` if there // is no `cond` yet. 
- z3::expr_vector left_exprs{provenance.z3_ctx}; - z3::expr_vector right_exprs{provenance.z3_ctx}; - - if (cond.decl().decl_kind() == Z3_OP_AND) { - left_exprs = cond.args(); - } else { - left_exprs.push_back(cond); - } - - if (conj_cond.decl().decl_kind() == Z3_OP_AND) { - right_exprs = conj_cond.args(); - } else { - right_exprs.push_back(conj_cond); - } - - z3::expr_vector res_cond{provenance.z3_ctx}; - for (auto expr_a : left_exprs) { - for (auto expr_b : right_exprs) { - res_cond.push_back( - HeavySimplify(provenance.z3_ctx, expr_a || expr_b)); - } - } - - cond = z3::mk_and(res_cond); + conds.push_back(conj_cond); } + auto cond{HeavySimplify(provenance.z3_ctx, z3::mk_or(conds))}; if (!Prove(provenance.z3_ctx, old_cond == cond)) { provenance.reaching_conds[block] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(cond); diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 0b5557ae..e96e055a 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include "rellic/AST/ASTBuilder.h" @@ -39,92 +41,58 @@ class CompoundVisitor } void AddExpr(z3::expr expr, z3::expr_vector& vec) { - if (!IsConstant(expr)) { - vec.push_back(expr); + if (IsConstant(expr)) { + return; } + + if (expr.is_and()) { + for (auto e : expr.args()) { + AddExpr(e, vec); + } + return; + } + + vec.push_back(expr); } z3::expr Simplify(z3::expr expr) { return HeavySimplify(provenance.z3_ctx, expr); } - z3::expr SimplifyWithAssumptions(z3::expr expr, z3::expr_vector& true_exprs) { - auto true_expr{z3::mk_and(true_exprs)}; - auto decl_kind{expr.decl().decl_kind()}; - - if (Prove(provenance.z3_ctx, z3::implies(true_expr, expr))) { - return provenance.z3_ctx.bool_val(true); - } - - if (Prove(provenance.z3_ctx, z3::implies(true_expr, !expr))) { - return provenance.z3_ctx.bool_val(false); - } - - if (!expr.is_app() || decl_kind == Z3_OP_UNINTERPRETED || - decl_kind == Z3_OP_NOT || decl_kind == 
Z3_OP_TRUE || - decl_kind == Z3_OP_FALSE) { - return expr; - } - - if (decl_kind == Z3_OP_OR) { - z3::expr_vector new_or{provenance.z3_ctx}; - for (auto sub : expr.args()) { - if (Prove(provenance.z3_ctx, z3::implies(true_exprs, sub))) { - return provenance.z3_ctx.bool_val(true); - } - - if (!Prove(provenance.z3_ctx, z3::implies(true_exprs, !expr))) { - new_or.push_back(sub); - } - } - - return z3::mk_or(new_or).simplify(); + z3::expr ApplyAssumptions(z3::expr expr, z3::expr_vector& true_exprs) { + z3::expr_vector dest{provenance.z3_ctx}; + for (auto e : true_exprs) { + dest.push_back(provenance.z3_ctx.bool_val(true)); } - CHECK_EQ(decl_kind, Z3_OP_AND) - << "Unknown expression kind: " << expr.to_string(); - - z3::expr_vector new_conj{provenance.z3_ctx}; - for (auto sub : expr.args()) { - auto sub_decl_kind{sub.decl().decl_kind()}; - - if (sub_decl_kind == Z3_OP_TRUE) { - continue; - } - - if (sub_decl_kind == Z3_OP_FALSE) { - return sub; - } - - if (sub_decl_kind == Z3_OP_OR) { - z3::expr_vector new_disj{provenance.z3_ctx}; - for (auto sub_disj : sub.args()) { - if (Prove(provenance.z3_ctx, z3::implies(true_expr, !sub_disj))) { - continue; - } - - if (Prove(provenance.z3_ctx, z3::implies(true_expr, sub_disj))) { - new_disj.push_back(provenance.z3_ctx.bool_val(true)); - } - - new_disj.push_back(sub_disj); - } + return expr.substitute(true_exprs, dest); + } - sub = z3::mk_or(new_disj).simplify(); + z3::expr Sort(z3::expr expr) { + if (expr.is_and() || expr.is_or()) { + auto args{expr.args()}; + std::vector args_indices{args.size()}; + std::iota(args_indices.begin(), args_indices.end(), 0); + std::sort(args_indices.begin(), args_indices.end(), + [&args](unsigned a, unsigned b) { + return args[a].id() < args[b].id(); + }); + z3::expr_vector new_args{provenance.z3_ctx}; + for (auto idx : args_indices) { + new_args.push_back(args[idx]); } - - if (Prove(provenance.z3_ctx, z3::implies(true_expr, !sub))) { - return provenance.z3_ctx.bool_val(false); - } - - if 
(Prove(provenance.z3_ctx, z3::implies(true_expr, sub))) { - continue; + if (expr.is_and()) { + return z3::mk_and(new_args); + } else { + return z3::mk_or(new_args); } + } - new_conj.push_back(sub); + if (expr.is_not()) { + return !Sort(expr.arg(0)); } - return HeavySimplify(provenance.z3_ctx, z3::mk_and(new_conj)); + return expr; } public: @@ -146,8 +114,8 @@ class CompoundVisitor z3::expr_vector& true_exprs) { bool changed{false}; if (while_stmt->getCond() == provenance.marker_expr) { - auto old_cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; - auto new_cond{SimplifyWithAssumptions(old_cond, true_exprs)}; + auto old_cond{Sort(provenance.z3_exprs[provenance.conds[while_stmt]])}; + auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, true_exprs)))}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[while_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -158,8 +126,7 @@ class CompoundVisitor auto cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; auto inner{Clone(true_exprs)}; - auto simplified_cond{SimplifyWithAssumptions(cond, true_exprs)}; - AddExpr(simplified_cond, inner); + AddExpr(cond, inner); changed |= Visit(while_stmt->getBody(), inner); AddExpr(!cond, true_exprs); @@ -170,7 +137,7 @@ class CompoundVisitor bool changed{false}; if (do_stmt->getCond() == provenance.marker_expr) { auto old_cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; - auto new_cond{SimplifyWithAssumptions(old_cond, true_exprs)}; + auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, true_exprs)))}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[do_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -183,8 +150,7 @@ class CompoundVisitor Visit(do_stmt->getBody(), inner); - auto simplified_cond{SimplifyWithAssumptions(cond, true_exprs)}; - AddExpr(!simplified_cond, true_exprs); + AddExpr(!cond, true_exprs); return changed; } @@ -192,7 +158,7 @@ class CompoundVisitor bool changed{false}; if 
(if_stmt->getCond() == provenance.marker_expr) { auto old_cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; - auto new_cond{SimplifyWithAssumptions(old_cond, true_exprs)}; + auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, true_exprs)))}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[if_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -202,13 +168,12 @@ class CompoundVisitor auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; auto inner_then{Clone(true_exprs)}; - auto simplified_cond{SimplifyWithAssumptions(cond, true_exprs)}; - AddExpr(simplified_cond, inner_then); + AddExpr(cond, inner_then); Visit(if_stmt->getThen(), inner_then); if (if_stmt->getElse()) { auto inner_else{Clone(true_exprs)}; - AddExpr(!simplified_cond, inner_else); + AddExpr(!cond, inner_else); Visit(if_stmt->getElse(), inner_else); } return changed; From 1b2d7204d5cb52d1a3b68878ee9e4fead99d2a1d Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Tue, 2 Aug 2022 16:30:47 +0200 Subject: [PATCH 08/57] Refactor ZCS --- include/rellic/AST/Z3CondSimplify.h | 8 ++------ lib/AST/Z3CondSimplify.cpp | 29 +++++------------------------ 2 files changed, 7 insertions(+), 30 deletions(-) diff --git a/include/rellic/AST/Z3CondSimplify.h b/include/rellic/AST/Z3CondSimplify.h index 2ebf39a6..2aa1bfc0 100644 --- a/include/rellic/AST/Z3CondSimplify.h +++ b/include/rellic/AST/Z3CondSimplify.h @@ -10,7 +10,7 @@ #include -#include "rellic/AST/TransformVisitor.h" +#include "rellic/AST/ASTPass.h" #include "rellic/AST/Util.h" namespace rellic { @@ -19,17 +19,13 @@ namespace rellic { * This pass simplifies conditions using Z3 by trying to remove terms that are * trivially true or false */ -class Z3CondSimplify : public TransformVisitor { +class Z3CondSimplify : public ASTPass { private: protected: void RunImpl() override; public: Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit); - - bool VisitIfStmt(clang::IfStmt *stmt); - bool 
VisitWhileStmt(clang::WhileStmt *loop); - bool VisitDoStmt(clang::DoStmt *loop); }; } // namespace rellic diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index 6e4d3a71..2bf66e34 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -17,33 +17,14 @@ namespace rellic { Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} - -bool Z3CondSimplify::VisitIfStmt(clang::IfStmt *stmt) { - auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; - provenance.conds[stmt] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(HeavySimplify(provenance.z3_ctx, cond)); - return true; -} - -bool Z3CondSimplify::VisitWhileStmt(clang::WhileStmt *loop) { - auto cond{provenance.z3_exprs[provenance.conds[loop]]}; - provenance.conds[loop] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(HeavySimplify(provenance.z3_ctx, cond)); - return true; -} - -bool Z3CondSimplify::VisitDoStmt(clang::DoStmt *loop) { - auto cond{provenance.z3_exprs[provenance.conds[loop]]}; - provenance.conds[loop] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(HeavySimplify(provenance.z3_ctx, cond)); - return true; -} + : ASTPass(provenance, unit) {} void Z3CondSimplify::RunImpl() { LOG(INFO) << "Simplifying conditions using Z3"; - TransformVisitor::RunImpl(); - TraverseDecl(ast_ctx.getTranslationUnitDecl()); + for (size_t i{0}; i < provenance.z3_exprs.size() && !Stopped(); ++i) { + provenance.z3_exprs[i] = + HeavySimplify(provenance.z3_ctx, provenance.z3_exprs[i]); + } } } // namespace rellic From 4739b4a5c1062ff21bb4b7f96c5e3ae1c3b46167 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Tue, 2 Aug 2022 16:32:50 +0200 Subject: [PATCH 09/57] Refactor NCP --- lib/AST/NestedCondProp.cpp | 110 ++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 39 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 
e96e055a..2b172ba2 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -21,8 +21,15 @@ #include "rellic/AST/Util.h" namespace rellic { +struct KnownExprs { + z3::expr_vector src; + z3::expr_vector dst; + + KnownExprs Clone() { return {::rellic::Clone(src), ::rellic::Clone(dst)}; } +}; + class CompoundVisitor - : public clang::StmtVisitor { + : public clang::StmtVisitor { private: Provenance& provenance; ASTBuilder& ast; @@ -40,37 +47,62 @@ class CompoundVisitor return false; } - void AddExpr(z3::expr expr, z3::expr_vector& vec) { + void AddExpr(z3::expr expr, bool value, KnownExprs& vec) { if (IsConstant(expr)) { return; } - if (expr.is_and()) { - for (auto e : expr.args()) { - AddExpr(e, vec); - } + if (expr.is_not()) { + AddExpr(expr.arg(0), !value, vec); return; } - vec.push_back(expr); + if (value) { + if (expr.is_and()) { + for (auto e : expr.args()) { + AddExpr(e, true, vec); + } + return; + } + + if (expr.is_or() && expr.num_args() == 1) { + AddExpr(expr.arg(0), true, vec); + return; + } + } else { + if (expr.is_or()) { + for (auto e : expr.args()) { + AddExpr(e, false, vec); + } + return; + } + + if (expr.is_and() && expr.num_args() == 1) { + AddExpr(expr.arg(0), false, vec); + return; + } + } + + vec.src.push_back(expr); + vec.dst.push_back(provenance.z3_ctx.bool_val(value)); } z3::expr Simplify(z3::expr expr) { return HeavySimplify(provenance.z3_ctx, expr); } - z3::expr ApplyAssumptions(z3::expr expr, z3::expr_vector& true_exprs) { - z3::expr_vector dest{provenance.z3_ctx}; - for (auto e : true_exprs) { - dest.push_back(provenance.z3_ctx.bool_val(true)); - } - - return expr.substitute(true_exprs, dest); + z3::expr ApplyAssumptions(z3::expr expr, KnownExprs& known_exprs) { + auto res{expr.substitute(known_exprs.src, known_exprs.dst)}; + return res; } z3::expr Sort(z3::expr expr) { if (expr.is_and() || expr.is_or()) { auto args{expr.args()}; + for (size_t i{0}; i < args.size(); ++i) { + args[i] = Sort(args[i]); + } + std::vector 
args_indices{args.size()}; std::iota(args_indices.begin(), args_indices.end(), 0); std::sort(args_indices.begin(), args_indices.end(), @@ -101,21 +133,20 @@ class CompoundVisitor : provenance(provenance), ast(ast), ctx(ctx) {} bool VisitCompoundStmt(clang::CompoundStmt* compound, - z3::expr_vector& true_exprs) { + KnownExprs& known_exprs) { bool changed{false}; for (auto stmt : compound->body()) { - changed |= Visit(stmt, true_exprs); + changed |= Visit(stmt, known_exprs); } return changed; } - bool VisitWhileStmt(clang::WhileStmt* while_stmt, - z3::expr_vector& true_exprs) { + bool VisitWhileStmt(clang::WhileStmt* while_stmt, KnownExprs& known_exprs) { bool changed{false}; if (while_stmt->getCond() == provenance.marker_expr) { auto old_cond{Sort(provenance.z3_exprs[provenance.conds[while_stmt]])}; - auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, true_exprs)))}; + auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, known_exprs)))}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[while_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -125,19 +156,19 @@ class CompoundVisitor auto cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; - auto inner{Clone(true_exprs)}; - AddExpr(cond, inner); + auto inner{known_exprs.Clone()}; + AddExpr(cond, true, inner); changed |= Visit(while_stmt->getBody(), inner); - AddExpr(!cond, true_exprs); + AddExpr(cond, false, known_exprs); return changed; } - bool VisitDoStmt(clang::DoStmt* do_stmt, z3::expr_vector& true_exprs) { + bool VisitDoStmt(clang::DoStmt* do_stmt, KnownExprs& known_exprs) { bool changed{false}; if (do_stmt->getCond() == provenance.marker_expr) { - auto old_cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; - auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, true_exprs)))}; + auto old_cond{Sort(provenance.z3_exprs[provenance.conds[do_stmt]])}; + auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, known_exprs)))}; if (!z3::eq(old_cond, new_cond)) { 
provenance.conds[do_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -146,19 +177,19 @@ class CompoundVisitor } auto cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; - auto inner{Clone(true_exprs)}; + auto inner{known_exprs.Clone()}; - Visit(do_stmt->getBody(), inner); + changed |= Visit(do_stmt->getBody(), inner); - AddExpr(!cond, true_exprs); + AddExpr(cond, false, known_exprs); return changed; } - bool VisitIfStmt(clang::IfStmt* if_stmt, z3::expr_vector& true_exprs) { + bool VisitIfStmt(clang::IfStmt* if_stmt, KnownExprs& known_exprs) { bool changed{false}; if (if_stmt->getCond() == provenance.marker_expr) { - auto old_cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; - auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, true_exprs)))}; + auto old_cond{Sort(provenance.z3_exprs[provenance.conds[if_stmt]])}; + auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, known_exprs)))}; if (!z3::eq(old_cond, new_cond)) { provenance.conds[if_stmt] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(new_cond); @@ -167,14 +198,14 @@ class CompoundVisitor } auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; - auto inner_then{Clone(true_exprs)}; - AddExpr(cond, inner_then); - Visit(if_stmt->getThen(), inner_then); + auto inner_then{known_exprs.Clone()}; + AddExpr(cond, true, inner_then); + changed |= Visit(if_stmt->getThen(), inner_then); if (if_stmt->getElse()) { - auto inner_else{Clone(true_exprs)}; - AddExpr(!cond, inner_else); - Visit(if_stmt->getElse(), inner_else); + auto inner_else{known_exprs.Clone()}; + AddExpr(cond, false, inner_else); + changed |= Visit(if_stmt->getElse(), inner_else); } return changed; } @@ -191,8 +222,9 @@ void NestedCondProp::RunImpl() { for (auto decl : ast_ctx.getTranslationUnitDecl()->decls()) { if (auto fdecl = clang::dyn_cast(decl)) { if (fdecl->hasBody()) { - z3::expr_vector true_exprs{provenance.z3_ctx}; - changed |= visitor.Visit(fdecl->getBody(), true_exprs); + KnownExprs 
known_exprs{z3::expr_vector{provenance.z3_ctx}, + z3::expr_vector{provenance.z3_ctx}}; + changed |= visitor.Visit(fdecl->getBody(), known_exprs); } } } From 9223a8fe0abe107f755bcd6ba894f7eec244f06d Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 4 Aug 2022 15:18:24 +0200 Subject: [PATCH 10/57] Fix NCP bugs --- lib/AST/NestedCondProp.cpp | 157 +++++++++++++++++-------------------- lib/AST/Z3CondSimplify.cpp | 4 +- 2 files changed, 72 insertions(+), 89 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 2b172ba2..39332bbf 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -25,93 +25,85 @@ struct KnownExprs { z3::expr_vector src; z3::expr_vector dst; + KnownExprs(z3::context& ctx) : src(ctx), dst(ctx) {} + KnownExprs(z3::expr_vector&& src, z3::expr_vector&& dst) + : src(std::move(src)), dst(std::move(dst)) {} KnownExprs Clone() { return {::rellic::Clone(src), ::rellic::Clone(dst)}; } -}; - -class CompoundVisitor - : public clang::StmtVisitor { - private: - Provenance& provenance; - ASTBuilder& ast; - clang::ASTContext& ctx; bool IsConstant(z3::expr expr) { - if (Prove(provenance.z3_ctx, expr)) { + if (Prove(src.ctx(), expr)) { return true; } - if (Prove(provenance.z3_ctx, !expr)) { + if (Prove(src.ctx(), !expr)) { return true; } return false; } - void AddExpr(z3::expr expr, bool value, KnownExprs& vec) { + void AddExpr(z3::expr expr, bool value) { if (IsConstant(expr)) { return; } if (expr.is_not()) { - AddExpr(expr.arg(0), !value, vec); + AddExpr(expr.arg(0), !value); return; } - if (value) { - if (expr.is_and()) { - for (auto e : expr.args()) { - AddExpr(e, true, vec); - } - return; - } + if (expr.num_args() == 1) { + AddExpr(expr.arg(0), value); + return; + } - if (expr.is_or() && expr.num_args() == 1) { - AddExpr(expr.arg(0), true, vec); - return; - } - } else { - if (expr.is_or()) { - for (auto e : expr.args()) { - AddExpr(e, false, vec); - } - return; + if (value && expr.is_and()) { + 
for (auto e : expr.args()) { + AddExpr(e, true); } + return; + } - if (expr.is_and() && expr.num_args() == 1) { - AddExpr(expr.arg(0), false, vec); - return; + if (!value && expr.is_or()) { + for (auto e : expr.args()) { + AddExpr(e, false); } + return; } - vec.src.push_back(expr); - vec.dst.push_back(provenance.z3_ctx.bool_val(value)); + src.push_back(expr); + dst.push_back(dst.ctx().bool_val(value)); + CHECK_EQ(src.size(), dst.size()); } - z3::expr Simplify(z3::expr expr) { - return HeavySimplify(provenance.z3_ctx, expr); + z3::expr ApplyAssumptions(z3::expr expr) { + auto res{expr.substitute(src, dst)}; + return res; } +}; - z3::expr ApplyAssumptions(z3::expr expr, KnownExprs& known_exprs) { - auto res{expr.substitute(known_exprs.src, known_exprs.dst)}; - return res; +class CompoundVisitor + : public clang::StmtVisitor { + private: + Provenance& provenance; + ASTBuilder& ast; + clang::ASTContext& ctx; + + z3::expr Simplify(z3::expr expr) { + return HeavySimplify(provenance.z3_ctx, expr); } z3::expr Sort(z3::expr expr) { if (expr.is_and() || expr.is_or()) { - auto args{expr.args()}; - for (size_t i{0}; i < args.size(); ++i) { - args[i] = Sort(args[i]); - } - - std::vector args_indices{args.size()}; + std::vector args_indices(expr.num_args(), 0); std::iota(args_indices.begin(), args_indices.end(), 0); std::sort(args_indices.begin(), args_indices.end(), - [&args](unsigned a, unsigned b) { - return args[a].id() < args[b].id(); + [&expr](unsigned a, unsigned b) { + return expr.arg(a).id() < expr.arg(b).id(); }); z3::expr_vector new_args{provenance.z3_ctx}; for (auto idx : args_indices) { - new_args.push_back(args[idx]); + new_args.push_back(Sort(expr.arg(idx))); } if (expr.is_and()) { return z3::mk_and(new_args); @@ -143,68 +135,60 @@ class CompoundVisitor } bool VisitWhileStmt(clang::WhileStmt* while_stmt, KnownExprs& known_exprs) { - bool changed{false}; - if (while_stmt->getCond() == provenance.marker_expr) { - auto 
old_cond{Sort(provenance.z3_exprs[provenance.conds[while_stmt]])}; - auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, known_exprs)))}; - if (!z3::eq(old_cond, new_cond)) { - provenance.conds[while_stmt] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(new_cond); - changed = true; - } + auto cond_idx{provenance.conds[while_stmt]}; + auto old_cond{Sort(provenance.z3_exprs[cond_idx])}; + auto new_cond{Sort(known_exprs.ApplyAssumptions(old_cond))}; + if (while_stmt->getCond() != provenance.marker_expr && + !z3::eq(old_cond, new_cond)) { + provenance.z3_exprs.set(cond_idx, new_cond); + return true; } - auto cond{provenance.z3_exprs[provenance.conds[while_stmt]]}; - + bool changed{false}; auto inner{known_exprs.Clone()}; - AddExpr(cond, true, inner); + inner.AddExpr(new_cond, true); changed |= Visit(while_stmt->getBody(), inner); - AddExpr(cond, false, known_exprs); + known_exprs.AddExpr(new_cond, false); return changed; } bool VisitDoStmt(clang::DoStmt* do_stmt, KnownExprs& known_exprs) { - bool changed{false}; - if (do_stmt->getCond() == provenance.marker_expr) { - auto old_cond{Sort(provenance.z3_exprs[provenance.conds[do_stmt]])}; - auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, known_exprs)))}; - if (!z3::eq(old_cond, new_cond)) { - provenance.conds[do_stmt] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(new_cond); - changed = true; - } + auto cond_idx{provenance.conds[do_stmt]}; + auto old_cond{Sort(provenance.z3_exprs[cond_idx])}; + auto new_cond{Sort(known_exprs.ApplyAssumptions(old_cond))}; + if (do_stmt->getCond() == provenance.marker_expr && + !z3::eq(old_cond, new_cond)) { + provenance.z3_exprs.set(cond_idx, new_cond); + return true; } - auto cond{provenance.z3_exprs[provenance.conds[do_stmt]]}; + bool changed{false}; auto inner{known_exprs.Clone()}; - changed |= Visit(do_stmt->getBody(), inner); - AddExpr(cond, false, known_exprs); + known_exprs.AddExpr(new_cond, false); return changed; } bool 
VisitIfStmt(clang::IfStmt* if_stmt, KnownExprs& known_exprs) { - bool changed{false}; - if (if_stmt->getCond() == provenance.marker_expr) { - auto old_cond{Sort(provenance.z3_exprs[provenance.conds[if_stmt]])}; - auto new_cond{Sort(Simplify(ApplyAssumptions(old_cond, known_exprs)))}; - if (!z3::eq(old_cond, new_cond)) { - provenance.conds[if_stmt] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(new_cond); - changed = true; - } + auto cond_idx{provenance.conds[if_stmt]}; + auto old_cond{Sort(provenance.z3_exprs[cond_idx])}; + auto new_cond{Sort(known_exprs.ApplyAssumptions(old_cond))}; + if (if_stmt->getCond() == provenance.marker_expr && + !z3::eq(old_cond, new_cond)) { + provenance.z3_exprs.set(cond_idx, new_cond); + return true; } - auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; + bool changed{false}; auto inner_then{known_exprs.Clone()}; - AddExpr(cond, true, inner_then); + inner_then.AddExpr(new_cond, true); changed |= Visit(if_stmt->getThen(), inner_then); if (if_stmt->getElse()) { auto inner_else{known_exprs.Clone()}; - AddExpr(cond, false, inner_else); + inner_else.AddExpr(new_cond, false); changed |= Visit(if_stmt->getElse(), inner_else); } return changed; @@ -222,8 +206,7 @@ void NestedCondProp::RunImpl() { for (auto decl : ast_ctx.getTranslationUnitDecl()->decls()) { if (auto fdecl = clang::dyn_cast(decl)) { if (fdecl->hasBody()) { - KnownExprs known_exprs{z3::expr_vector{provenance.z3_ctx}, - z3::expr_vector{provenance.z3_ctx}}; + KnownExprs known_exprs{provenance.z3_ctx}; changed |= visitor.Visit(fdecl->getBody(), known_exprs); } } diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index 2bf66e34..529f66c5 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -22,8 +22,8 @@ Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) void Z3CondSimplify::RunImpl() { LOG(INFO) << "Simplifying conditions using Z3"; for (size_t i{0}; i < provenance.z3_exprs.size() && 
!Stopped(); ++i) { - provenance.z3_exprs[i] = - HeavySimplify(provenance.z3_ctx, provenance.z3_exprs[i]); + auto simpl{HeavySimplify(provenance.z3_ctx, provenance.z3_exprs[i])}; + provenance.z3_exprs.set(i, simpl); } } From 61b5a2783a56eba66f231a87a77f570587ebea81 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Mon, 8 Aug 2022 16:12:52 +0200 Subject: [PATCH 11/57] Fix condition sharing bug --- lib/AST/GenerateAST.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index ba296205..83d3be5c 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -351,7 +351,8 @@ StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { // Gate the compound behind a reaching condition auto z_expr{GetReachingCond(block)}; block_stmts[block] = ast.CreateIf(provenance.marker_expr, compound); - provenance.conds[block_stmts[block]] = z_expr; + provenance.conds[block_stmts[block]] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(provenance.z3_exprs[z_expr]); // Store the compound result.push_back(block_stmts[block]); } From 30ba9454edb14d88af1a1f6d364668e41b9dc961 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Mon, 8 Aug 2022 19:06:41 +0200 Subject: [PATCH 12/57] Fix CBR --- include/rellic/AST/CondBasedRefine.h | 4 - lib/AST/CondBasedRefine.cpp | 129 +++++++++++++-------------- 2 files changed, 61 insertions(+), 72 deletions(-) diff --git a/include/rellic/AST/CondBasedRefine.h b/include/rellic/AST/CondBasedRefine.h index 8f910048..aa767d46 100644 --- a/include/rellic/AST/CondBasedRefine.h +++ b/include/rellic/AST/CondBasedRefine.h @@ -35,10 +35,6 @@ namespace rellic { */ class CondBasedRefine : public TransformVisitor { private: - using IfStmtVec = std::vector; - - void CreateIfThenElseStmts(IfStmtVec stmts); - protected: void RunImpl() override; diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index b89ecca5..fce91456 100644 --- 
a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -11,99 +11,92 @@ #include #include -#include - -#include "rellic/AST/Util.h" +#include namespace rellic { CondBasedRefine::CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit) : TransformVisitor(provenance, unit) {} -void CondBasedRefine::CreateIfThenElseStmts(IfStmtVec worklist) { - auto RemoveFromWorkList = [&worklist](clang::Stmt *stmt) { - auto it = std::find(worklist.begin(), worklist.end(), stmt); - if (it != worklist.end()) { - worklist.erase(it); - } - }; - - auto ThenTest = [this](z3::expr lhs, z3::expr rhs) { - return Prove(provenance.z3_ctx, lhs == rhs); - }; - - auto ElseTest = [this](z3::expr lhs, z3::expr rhs) { - return Prove(provenance.z3_ctx, lhs == !rhs); - }; - - auto CombineTest = [this](z3::expr lhs, z3::expr rhs) { - return Prove(provenance.z3_ctx, z3::implies(rhs, lhs)); - }; - - while (!worklist.empty()) { - auto lhs = *worklist.begin(); - RemoveFromWorkList(lhs); - // Prepare conditions according to which we're going to - // cluster statements according to the whole `lhs` - // condition. 
- auto lcond = provenance.z3_exprs[provenance.conds[lhs]]; - // Get branch candidates wrt `clause` - std::vector thens({lhs}), elses; - for (auto rhs : worklist) { - auto rcond = provenance.z3_exprs[provenance.conds[rhs]]; - if (ThenTest(lcond, rcond) || CombineTest(lcond, rcond)) { - thens.push_back(rhs); - } else if (ElseTest(lcond, rcond) || CombineTest(!lcond, rcond)) { - elses.push_back(rhs); - } - } +bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { + std::vector body{compound->body_begin(), compound->body_end()}; + std::vector new_body{body}; + bool did_something{false}; + + for (size_t i{0}; i + 1 < body.size(); ++i) { + auto if_a{clang::dyn_cast(body[i])}; + auto if_b{clang::dyn_cast(body[i + 1])}; - // Check if we have enough statements to work with - if (thens.size() + elses.size() < 2) { + if (!if_a || !if_b) { continue; } - // Erase then statements from the AST and `worklist` - for (auto stmt : thens) { - RemoveFromWorkList(stmt); - substitutions[stmt] = nullptr; - } - // Create our new if-then - auto sub = - ast.CreateIf(provenance.marker_expr, ast.CreateCompoundStmt(thens)); - provenance.conds[sub] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(lcond); - // Create an else branch if possible - if (!elses.empty()) { - // Erase else statements from the AST and `worklist` - for (auto stmt : elses) { - RemoveFromWorkList(stmt); - substitutions[stmt] = nullptr; + auto cond_a{provenance.z3_exprs[provenance.conds[if_a]]}; + auto cond_b{provenance.z3_exprs[provenance.conds[if_b]]}; + + auto then_a{if_a->getThen()}; + auto then_b{if_b->getThen()}; + + auto else_a{if_a->getElse()}; + auto else_b{if_b->getElse()}; + + if (Prove(provenance.z3_ctx, cond_a == cond_b)) { + std::vector new_then_body{then_a, then_b}; + auto new_then{ast.CreateCompoundStmt(new_then_body)}; + + auto new_if{ast.CreateIf(provenance.marker_expr, new_then)}; + + if (else_a || else_b) { + std::vector new_else_body{}; + + if (else_a) { + 
new_else_body.push_back(else_a); + } + + if (else_b) { + new_else_body.push_back(else_b); + } + + auto new_else{ast.CreateCompoundStmt(new_else_body)}; + new_if->setElse(new_else); } - } else if (Prove(*z3_ctx, cond_a == !cond_b)) { + + provenance.conds[new_if] = provenance.conds[if_a]; + new_body[i] = new_if; + new_body.erase(std::next(new_body.begin(), i + 1)); + did_something = true; + break; + } + + if (Prove(provenance.z3_ctx, cond_a == !cond_b)) { + std::vector new_then_body{then_a}; if (else_b) { new_then_body.push_back(else_b); } + auto new_then{ast.CreateCompoundStmt(new_then_body)}; - new_if = ast.CreateIf(if_a->getCond(), new_then); - std::vector new_else_body; + std::vector new_else_body{}; if (else_a) { new_else_body.push_back(else_a); } new_else_body.push_back(then_b); - new_if->setElse(ast.CreateCompoundStmt(new_else_body)); - } - if (new_if) { - body[i] = new_if; - body.erase(std::next(body.begin(), i + 1)); + auto new_if{ast.CreateIf(provenance.marker_expr, new_then)}; + + auto new_else{ast.CreateCompoundStmt(new_else_body)}; + new_if->setElse(new_else); + + provenance.conds[new_if] = provenance.conds[if_a]; + new_body[i] = new_if; + new_body.erase(std::next(new_body.begin(), i + 1)); did_something = true; + break; } } - if (did_something) { - substitutions[compound] = ast.CreateCompoundStmt(body); + auto new_compound{ast.CreateCompoundStmt(new_body)}; + substitutions[compound] = new_compound; } return !Stopped(); } From 06de4fb9058d31ef697c474f3d81bccb80f60ff3 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Mon, 8 Aug 2022 19:13:13 +0200 Subject: [PATCH 13/57] Temporarily disable RBR --- lib/Decompiler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 1297d3cf..04037b62 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -116,8 +116,8 @@ Result Decompile( cbr_passes.push_back( std::make_unique(provenance, *ast_unit)); - cbr_passes.push_back( - 
std::make_unique(provenance, *ast_unit)); + // cbr_passes.push_back( + // std::make_unique(provenance, *ast_unit)); while (pass_cbr.Run()) { ; From 6547998d0ba5b7d76b5ca55ef93e11da48a974ca Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Mon, 8 Aug 2022 19:52:39 +0200 Subject: [PATCH 14/57] Improve `else-if` recognition --- lib/AST/CondBasedRefine.cpp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index fce91456..dd5335af 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -10,6 +10,7 @@ #include #include +#include #include @@ -93,6 +94,37 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { did_something = true; break; } + + if (Prove(provenance.z3_ctx, z3::implies(cond_b, cond_a))) { + std::vector new_then_body{then_a, if_b}; + auto new_then{ast.CreateCompoundStmt(new_then_body)}; + + auto new_if{ast.CreateIf(provenance.marker_expr, new_then)}; + new_if->setElse(else_a); + + provenance.conds[new_if] = provenance.conds[if_a]; + new_body[i] = new_if; + new_body.erase(std::next(new_body.begin(), i + 1)); + did_something = true; + break; + } + + if (Prove(provenance.z3_ctx, !(cond_a && cond_b)) && + !Prove(provenance.z3_ctx, cond_a || cond_b)) { + auto new_if{ast.CreateIf(provenance.marker_expr, then_a)}; + if (else_a) { + std::vector new_else_body{else_a, if_b}; + new_if->setElse(ast.CreateCompoundStmt(new_else_body)); + } else { + new_if->setElse(if_b); + } + + provenance.conds[new_if] = provenance.conds[if_a]; + new_body[i] = new_if; + new_body.erase(std::next(new_body.begin(), i + 1)); + did_something = true; + break; + } } if (did_something) { auto new_compound{ast.CreateCompoundStmt(new_body)}; From f352e0f80b7003176595138078f93f8264b2b150 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Tue, 9 Aug 2022 12:19:44 +0200 Subject: [PATCH 15/57] Improve switch conditions generation --- 
include/rellic/AST/GenerateAST.h | 2 ++ lib/AST/GenerateAST.cpp | 39 ++++++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index 3af1df12..05982551 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -45,6 +45,8 @@ class GenerateAST : public llvm::AnalysisInfoMixin { std::vector rpo_walk; unsigned GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond); + unsigned GetOrCreateVarForSwitch(llvm::SwitchInst *inst, + llvm::ConstantInt *c); unsigned GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c); diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index 83d3be5c..ef08017c 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -169,6 +169,21 @@ unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, return provenance.z3_br_edges[{inst, cond}]; } +unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst, + llvm::ConstantInt *c) { + if (provenance.z3_sw_vars.find({inst, c}) == provenance.z3_sw_vars.end()) { + auto name{GetName(inst) + + GetName(inst->findCaseValue(c)->getCaseSuccessor())}; + auto var{provenance.z3_ctx.bool_const(name.c_str())}; + provenance.z3_sw_vars[{inst, c}] = provenance.z3_exprs.size(); + provenance.z3_exprs.push_back(var); + provenance.z3_sw_edges_inv[var.id()] = {inst, c}; + return provenance.z3_sw_vars[{inst, c}]; + } else { + return provenance.z3_sw_vars[{inst, c}]; + } +} + unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c) { auto ToExpr = [&](unsigned idx) { @@ -179,23 +194,27 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, }; if (provenance.z3_sw_edges.find({inst, c}) == provenance.z3_sw_edges.end()) { if (c) { - auto name{GetName(inst) + - GetName(inst->findCaseValue(c)->getCaseSuccessor())}; - auto edge{provenance.z3_ctx.bool_const(name.c_str())}; - 
provenance.z3_sw_edges_inv[edge.id()] = {inst, c}; + auto edge{ToExpr(GetOrCreateVarForSwitch(inst, c))}; + z3::expr_vector vec{provenance.z3_ctx}; + vec.push_back(edge); + for (auto sw_case : inst->cases()) { + if (c != sw_case.getCaseValue()) { + vec.push_back( + !ToExpr(GetOrCreateVarForSwitch(inst, sw_case.getCaseValue()))); + } + } provenance.z3_sw_edges[{inst, c}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(edge); + provenance.z3_exprs.push_back(z3::mk_and(vec)); } else { // Default case - auto edge{provenance.z3_ctx.bool_val(true)}; + z3::expr_vector vec{provenance.z3_ctx}; for (auto sw_case : inst->cases()) { - edge = edge && - !ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue())); + vec.push_back( + !ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()))); } - edge = edge.simplify(); provenance.z3_sw_edges[{inst, c}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(edge); + provenance.z3_exprs.push_back(z3::mk_and(vec)); } } From e244c008bcd5c8ec749e8c45a5ad50902bfd08d6 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Tue, 9 Aug 2022 12:20:55 +0200 Subject: [PATCH 16/57] Factor out `Sort` --- include/rellic/AST/Util.h | 5 +++++ lib/AST/Util.cpp | 29 ++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 833c878a..fbbba97b 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -82,6 +82,7 @@ struct Provenance { std::unordered_map z3_sw_edges_inv; std::map z3_sw_edges; + std::map z3_sw_vars; std::map z3_edges; std::unordered_map reaching_conds; @@ -110,4 +111,8 @@ bool Prove(z3::context &ctx, z3::expr expr); z3::expr HeavySimplify(z3::context &ctx, z3::expr expr); z3::expr_vector Clone(z3::expr_vector &vec); + +// Tries to keep each subformula sorted by its id so that they don't get +// shuffled around by simplification +z3::expr Sort(z3::context &ctx, z3::expr expr); } // namespace 
rellic \ No newline at end of file diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index 3bec3c9f..04741b5e 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -14,7 +14,8 @@ #include #include #include -#include + +#include #include "rellic/AST/ASTBuilder.h" #include "rellic/Exception.h" @@ -371,4 +372,30 @@ z3::expr_vector Clone(z3::expr_vector &vec) { return clone; } + +z3::expr Sort(z3::context &ctx, z3::expr expr) { + if (expr.is_and() || expr.is_or()) { + std::vector args_indices(expr.num_args(), 0); + std::iota(args_indices.begin(), args_indices.end(), 0); + std::sort(args_indices.begin(), args_indices.end(), + [&expr](unsigned a, unsigned b) { + return expr.arg(a).id() < expr.arg(b).id(); + }); + z3::expr_vector new_args{ctx}; + for (auto idx : args_indices) { + new_args.push_back(Sort(ctx, expr.arg(idx))); + } + if (expr.is_and()) { + return z3::mk_and(new_args); + } else { + return z3::mk_or(new_args); + } + } + + if (expr.is_not()) { + return !Sort(ctx, expr.arg(0)); + } + + return expr; +} } // namespace rellic From 8309f2a50fa800de0b1a1ccd24fa4d9f19c9ea74 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Tue, 9 Aug 2022 12:21:26 +0200 Subject: [PATCH 17/57] Sort during simplification, use faster simplify --- lib/AST/Z3CondSimplify.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index 529f66c5..c0f82d08 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -22,7 +22,7 @@ Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) void Z3CondSimplify::RunImpl() { LOG(INFO) << "Simplifying conditions using Z3"; for (size_t i{0}; i < provenance.z3_exprs.size() && !Stopped(); ++i) { - auto simpl{HeavySimplify(provenance.z3_ctx, provenance.z3_exprs[i])}; + auto simpl{Sort(provenance.z3_ctx, provenance.z3_exprs[i].simplify())}; provenance.z3_exprs.set(i, simpl); } } From e8af6c23904b2f0132183268db393242619a6e70 Mon Sep 17 00:00:00 
2001 From: Francesco Bertolaccini Date: Tue, 9 Aug 2022 12:21:36 +0200 Subject: [PATCH 18/57] Improve NCP --- lib/AST/NestedCondProp.cpp | 125 ++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 59 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 39332bbf..c16e0926 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -11,10 +11,7 @@ #include #include #include -#include -#include -#include #include #include "rellic/AST/ASTBuilder.h" @@ -76,48 +73,44 @@ struct KnownExprs { CHECK_EQ(src.size(), dst.size()); } - z3::expr ApplyAssumptions(z3::expr expr) { - auto res{expr.substitute(src, dst)}; - return res; - } -}; - -class CompoundVisitor - : public clang::StmtVisitor { - private: - Provenance& provenance; - ASTBuilder& ast; - clang::ASTContext& ctx; + z3::expr ApplyAssumptions(z3::expr expr, bool& found) { + if (IsConstant(expr)) { + return expr; + } - z3::expr Simplify(z3::expr expr) { - return HeavySimplify(provenance.z3_ctx, expr); - } + for (unsigned i{0}; i < dst.size(); ++i) { + if (z3::eq(expr, src[i])) { + found = true; + return dst[i]; + } + } - z3::expr Sort(z3::expr expr) { if (expr.is_and() || expr.is_or()) { - std::vector args_indices(expr.num_args(), 0); - std::iota(args_indices.begin(), args_indices.end(), 0); - std::sort(args_indices.begin(), args_indices.end(), - [&expr](unsigned a, unsigned b) { - return expr.arg(a).id() < expr.arg(b).id(); - }); - z3::expr_vector new_args{provenance.z3_ctx}; - for (auto idx : args_indices) { - new_args.push_back(Sort(expr.arg(idx))); + z3::expr_vector args{src.ctx()}; + for (auto arg : expr.args()) { + args.push_back(ApplyAssumptions(arg, found)); } if (expr.is_and()) { - return z3::mk_and(new_args); + return z3::mk_and(args); } else { - return z3::mk_or(new_args); + return z3::mk_or(args); } } if (expr.is_not()) { - return !Sort(expr.arg(0)); + return !ApplyAssumptions(expr.arg(0), found); } return expr; } +}; + +class CompoundVisitor 
+ : public clang::StmtVisitor { + private: + Provenance& provenance; + ASTBuilder& ast; + clang::ASTContext& ctx; public: CompoundVisitor(Provenance& provenance, ASTBuilder& ast, @@ -126,72 +119,79 @@ class CompoundVisitor bool VisitCompoundStmt(clang::CompoundStmt* compound, KnownExprs& known_exprs) { - bool changed{false}; for (auto stmt : compound->body()) { - changed |= Visit(stmt, known_exprs); + if (Visit(stmt, known_exprs)) { + return true; + } } - return changed; + return false; } bool VisitWhileStmt(clang::WhileStmt* while_stmt, KnownExprs& known_exprs) { auto cond_idx{provenance.conds[while_stmt]}; - auto old_cond{Sort(provenance.z3_exprs[cond_idx])}; - auto new_cond{Sort(known_exprs.ApplyAssumptions(old_cond))}; - if (while_stmt->getCond() != provenance.marker_expr && - !z3::eq(old_cond, new_cond)) { + bool changed{false}; + auto old_cond{provenance.z3_exprs[cond_idx]}; + auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; + if (while_stmt->getCond() != provenance.marker_expr && changed) { provenance.z3_exprs.set(cond_idx, new_cond); return true; } - bool changed{false}; auto inner{known_exprs.Clone()}; inner.AddExpr(new_cond, true); - changed |= Visit(while_stmt->getBody(), inner); - known_exprs.AddExpr(new_cond, false); - return changed; + + if (Visit(while_stmt->getBody(), inner)) { + return true; + } + return false; } bool VisitDoStmt(clang::DoStmt* do_stmt, KnownExprs& known_exprs) { auto cond_idx{provenance.conds[do_stmt]}; - auto old_cond{Sort(provenance.z3_exprs[cond_idx])}; - auto new_cond{Sort(known_exprs.ApplyAssumptions(old_cond))}; - if (do_stmt->getCond() == provenance.marker_expr && - !z3::eq(old_cond, new_cond)) { + bool changed{false}; + auto old_cond{provenance.z3_exprs[cond_idx]}; + auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; + if (do_stmt->getCond() == provenance.marker_expr && changed) { provenance.z3_exprs.set(cond_idx, new_cond); return true; } - bool changed{false}; auto inner{known_exprs.Clone()}; 
- changed |= Visit(do_stmt->getBody(), inner); - known_exprs.AddExpr(new_cond, false); - return changed; + + if (Visit(do_stmt->getBody(), inner)) { + return true; + } + + return false; } bool VisitIfStmt(clang::IfStmt* if_stmt, KnownExprs& known_exprs) { auto cond_idx{provenance.conds[if_stmt]}; - auto old_cond{Sort(provenance.z3_exprs[cond_idx])}; - auto new_cond{Sort(known_exprs.ApplyAssumptions(old_cond))}; - if (if_stmt->getCond() == provenance.marker_expr && - !z3::eq(old_cond, new_cond)) { + bool changed{false}; + auto old_cond{provenance.z3_exprs[cond_idx]}; + auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; + if (if_stmt->getCond() == provenance.marker_expr && changed) { provenance.z3_exprs.set(cond_idx, new_cond); return true; } - bool changed{false}; auto inner_then{known_exprs.Clone()}; inner_then.AddExpr(new_cond, true); - changed |= Visit(if_stmt->getThen(), inner_then); + if (Visit(if_stmt->getThen(), inner_then)) { + return true; + } if (if_stmt->getElse()) { auto inner_else{known_exprs.Clone()}; inner_else.AddExpr(new_cond, false); - changed |= Visit(if_stmt->getElse(), inner_else); + if (Visit(if_stmt->getElse(), inner_else)) { + return true; + } } - return changed; + return false; } }; @@ -205,9 +205,16 @@ void NestedCondProp::RunImpl() { for (auto decl : ast_ctx.getTranslationUnitDecl()->decls()) { if (auto fdecl = clang::dyn_cast(decl)) { + if (Stopped()) { + return; + } + if (fdecl->hasBody()) { KnownExprs known_exprs{provenance.z3_ctx}; - changed |= visitor.Visit(fdecl->getBody(), known_exprs); + if (visitor.Visit(fdecl->getBody(), known_exprs)) { + changed = true; + return; + } } } } From 42586867341acf7d2a5d0ed5a44e672248aa3ebc Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Tue, 9 Aug 2022 12:26:40 +0200 Subject: [PATCH 19/57] Remove RBR --- include/rellic/AST/ReachBasedRefine.h | 54 ------------ lib/AST/ReachBasedRefine.cpp | 122 -------------------------- lib/CMakeLists.txt | 2 - lib/Decompiler.cpp | 3 - 
tools/repl/Repl.cpp | 8 +- tools/xref/Xref.cpp | 4 - tools/xref/www/main.js | 7 +- 7 files changed, 3 insertions(+), 197 deletions(-) delete mode 100644 include/rellic/AST/ReachBasedRefine.h delete mode 100644 lib/AST/ReachBasedRefine.cpp diff --git a/include/rellic/AST/ReachBasedRefine.h b/include/rellic/AST/ReachBasedRefine.h deleted file mode 100644 index 7ef9dbf0..00000000 --- a/include/rellic/AST/ReachBasedRefine.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#pragma once - -#include "rellic/AST/TransformVisitor.h" -#include "rellic/AST/Z3ConvVisitor.h" - -namespace rellic { - -/* - * This pass restructures a sequence of if statements that have a shape like - * - * if(cond1 && !cond2 && !cond3) { - * body1; - * } - * if(cond2 && !cond1 && !cond3) { - * body2; - * } - * if(cond3 && !cond1 && !cond2) { - * body3; - * } - * - * into - * - * if(cond1) { - * body1; - * } else if(cond2) { - * body2; - * } else if(cond3) { - * body3; - * } - */ -class ReachBasedRefine : public TransformVisitor { - private: - using IfStmtVec = std::vector; - - void CreateIfElseStmts(IfStmtVec stmts); - - protected: - void RunImpl() override; - - public: - ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit); - - bool VisitCompoundStmt(clang::CompoundStmt *compound); -}; - -} // namespace rellic \ No newline at end of file diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp deleted file mode 100644 index 0f5b94d8..00000000 --- a/lib/AST/ReachBasedRefine.cpp +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2021-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. 
- */ - -#include "rellic/AST/ReachBasedRefine.h" - -#include -#include - -#include "rellic/AST/Util.h" - -namespace rellic { - -ReachBasedRefine::ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} - -void ReachBasedRefine::CreateIfElseStmts(IfStmtVec stmts) { - // Else-if candidate IfStmts and their Z3 form - // reaching conditions. - IfStmtVec elifs; - z3::expr_vector conds(provenance.z3_ctx); - // Test that determines if a new IfStmts is not - // reachable from the already gathered IfStmts. - auto IsUnrechable = [this, &conds](z3::expr cond) { - return Prove(provenance.z3_ctx, !(cond && z3::mk_or(conds))); - }; - // Test to determine if we have enough candidate - // IfStmts to form an else-if cascade. - auto IsTautology = [this, &conds] { - return Prove(provenance.z3_ctx, - z3::mk_or(conds) == provenance.z3_ctx.bool_val(true)); - }; - - bool done_something{false}; - for (size_t i{0}; i < body.size() && !done_something; ++i) { - auto if_stmt{clang::dyn_cast(body[i])}; - if (!if_stmt) { - ResetChain(); - continue; - } - // Clear else-if IfStmts if we find a path among them. - auto cond = provenance.z3_exprs[provenance.conds[stmt]]; - if (stmt->getElse() || !IsUnrechable(cond)) { - conds = z3::expr_vector(provenance.z3_ctx); - elifs.clear(); - } - // Add the current if-statement to the else-if candidates. 
- conds.push_back(cond); - elifs.push_back(stmt); - } - - // Check if we have enough statements to work with - if (elifs.size() < 2) { - return; - } - - // Create the else-if cascade - clang::IfStmt *sub = nullptr; - for (auto stmt : llvm::make_range(elifs.rbegin(), elifs.rend())) { - auto then = stmt->getThen(); - if (stmt == elifs.back()) { - sub = ast.CreateIf(provenance.marker_expr, then); - provenance.conds[sub] = provenance.conds[stmt]; - substitutions[stmt] = sub; - } else if (stmt == elifs.front()) { - std::vector thens({then}); - sub->setElse(ast.CreateCompoundStmt(thens)); - substitutions[stmt] = nullptr; - } else { - auto elif = ast.CreateIf(provenance.marker_expr, then); - provenance.conds[elif] = provenance.conds[stmt]; - sub->setElse(elif); - sub = elif; - substitutions[stmt] = nullptr; - } - } -} - -/* -`body` will look like this at this point: - - ... - i - n : ... - i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } - i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } - ... - i - 1 : if(cond_n-1) { } else { } - i : if(cond_n) { } - ... - -but since we chained all of the statements into the first, we want to remove -the others from the body: - - ... - i - n : ... - i - n + 1: if(cond_1) { } else if(cond_2) { } else if ... - ... 
-*/ -size_t start_delete{i - (ifs.size() - 2)}; -size_t end_delete{i}; -body.erase(body.erase(std::next(body.begin(), start_delete), - std::next(body.begin(), end_delete))); -done_something = true; -} - -if (done_something) { - substitutions[compound] = ast.CreateCompoundStmt(body); -} -return !Stopped(); -} - -void ReachBasedRefine::RunImpl() { - LOG(INFO) << "Reachability-based refinement"; - TransformVisitor::RunImpl(); - TraverseDecl(ast_ctx.getTranslationUnitDecl()); -} - -} // namespace rellic \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 933278ff..8cb7e92e 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -36,7 +36,6 @@ set(AST_HEADERS "${include_dir}/AST/NestedCondProp.h" "${include_dir}/AST/NestedScopeCombine.h" "${include_dir}/AST/NormalizeCond.h" - "${include_dir}/AST/ReachBasedRefine.h" "${include_dir}/AST/StructFieldRenamer.h" "${include_dir}/AST/StructGenerator.h" "${include_dir}/AST/SubprogramGenerator.h" @@ -80,7 +79,6 @@ set(AST_SOURCES AST/Util.cpp AST/Z3CondSimplify.cpp AST/Z3ConvVisitor.cpp - AST/ReachBasedRefine.cpp AST/StructFieldRenamer.cpp AST/StructGenerator.cpp AST/SubprogramGenerator.cpp diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 04037b62..0e3316b9 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -32,7 +32,6 @@ #include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" -#include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" #include "rellic/BC/Util.h" @@ -116,8 +115,6 @@ Result Decompile( cbr_passes.push_back( std::make_unique(provenance, *ast_unit)); - // cbr_passes.push_back( - // std::make_unique(provenance, *ast_unit)); while (pass_cbr.Run()) { ; diff --git a/tools/repl/Repl.cpp b/tools/repl/Repl.cpp index 741e0994..1ed86a79 100644 --- a/tools/repl/Repl.cpp +++ b/tools/repl/Repl.cpp @@ -45,7 +45,6 @@ #include 
"rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" #include "rellic/AST/NormalizeCond.h" -#include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" #include "rellic/BC/Util.h" @@ -94,8 +93,8 @@ static void SetVersion(void) { google::SetVersionString(version.str()); } -static const char* available_passes[] = {"cbr", "dse", "ec", "lr", "ncp", - "nsc", "nc", "rbr", "zcs"}; +static const char* available_passes[] = {"cbr", "dse", "ec", "lr", + "ncp", "nsc", "nc", "zcs"}; static bool diff = false; @@ -163,8 +162,6 @@ static std::unique_ptr CreatePass(const std::string& name) { return std::make_unique(*provenance, *ast_unit); } else if (name == "nc") { return std::make_unique(*provenance, *ast_unit); - } else if (name == "rbr") { - return std::make_unique(*provenance, *ast_unit); } else if (name == "zcs") { return std::make_unique(*provenance, *ast_unit); } else { @@ -204,7 +201,6 @@ static void do_help() { << " nc Condition normalization\n" << " ncp Nested condition propagation\n" << " nsc Nested scope combination\n" - << " rbr Reach-based refinement\n" << " zcs Z3-based condition simplification\n" << std::endl; } diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index 518bcea4..aa379603 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -53,7 +53,6 @@ #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" #include "rellic/AST/NormalizeCond.h" -#include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Util.h" #include "rellic/AST/Z3CondSimplify.h" @@ -424,9 +423,6 @@ static std::unique_ptr CreatePass( } else if (str == "nc") { return std::make_unique(*session.Provenance, *session.Unit); - } else if (str == "rbr") { - return std::make_unique(*session.Provenance, - *session.Unit); } else if (str == "zcs") { return std::make_unique(*session.Provenance, *session.Unit); diff --git 
a/tools/xref/www/main.js b/tools/xref/www/main.js index 848a7c34..2921f632 100644 --- a/tools/xref/www/main.js +++ b/tools/xref/www/main.js @@ -20,10 +20,6 @@ const cbr = { id: "cbr", label: "Condition-based refinement" } -const rbr = { - id: "rbr", - label: "Reach-based refinement" -} const lr = { id: "lr", label: "Loop refinement" @@ -107,7 +103,6 @@ const app = new Vue({ ncp, nsc, cbr, - rbr, lr, ec, nc, @@ -318,7 +313,7 @@ const app = new Vue({ useDefaultChain() { this.commands = [ dse, - [zcs, ncp, nsc, cbr, rbr], + [zcs, ncp, nsc, cbr], [lr, nsc], [zcs, ncp, nsc], mc, ec From 791de8ef363e51df99c8e6302f7075218e09995e Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Tue, 9 Aug 2022 15:51:27 +0200 Subject: [PATCH 20/57] Fix bug --- lib/AST/GenerateAST.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index ef08017c..e400e44b 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -301,7 +301,8 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { return provenance.z3_exprs[idx]; }; - auto old_cond{ToExpr(GetReachingCond(block))}; + auto old_cond_idx{GetReachingCond(block)}; + auto old_cond{ToExpr(old_cond_idx)}; if (block->hasNPredecessorsOrMore(1)) { // Gather reaching conditions from predecessors of the block z3::expr_vector conds{provenance.z3_ctx}; @@ -319,7 +320,8 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { } auto cond{HeavySimplify(provenance.z3_ctx, z3::mk_or(conds))}; - if (!Prove(provenance.z3_ctx, old_cond == cond)) { + if (old_cond_idx == POISON_IDX || + !Prove(provenance.z3_ctx, old_cond == cond)) { provenance.reaching_conds[block] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(cond); reaching_conds_changed = true; From 47ccf7ecf487603729fd2c48460220fabe090987 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 10 Aug 2022 12:28:41 +0200 Subject: [PATCH 21/57] Try improve performance --- 
include/rellic/AST/GenerateAST.h | 3 +- include/rellic/AST/Util.h | 12 ++++---- lib/AST/CondBasedRefine.cpp | 9 +++--- lib/AST/DeadStmtElim.cpp | 3 +- lib/AST/GenerateAST.cpp | 41 ++++++++++-------------- lib/AST/IRToASTVisitor.cpp | 36 +++++++++++++++++----- lib/AST/NestedCondProp.cpp | 53 +++++++++++++++++--------------- lib/AST/NestedScopeCombine.cpp | 6 ++-- lib/AST/Util.cpp | 32 +++++++++---------- lib/AST/Z3CondSimplify.cpp | 2 +- 10 files changed, 104 insertions(+), 93 deletions(-) diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index 05982551..de1fa626 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -45,8 +45,7 @@ class GenerateAST : public llvm::AnalysisInfoMixin { std::vector rpo_walk; unsigned GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond); - unsigned GetOrCreateVarForSwitch(llvm::SwitchInst *inst, - llvm::ConstantInt *c); + unsigned GetOrCreateVarForSwitch(llvm::SwitchInst *inst); unsigned GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c); diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index fbbba97b..308cf563 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -80,9 +80,9 @@ struct Provenance { std::unordered_map z3_br_edges_inv; std::map z3_br_edges; - std::unordered_map z3_sw_edges_inv; + std::unordered_map z3_sw_vars; + std::unordered_map z3_sw_vars_inv; std::map z3_sw_edges; - std::map z3_sw_vars; std::map z3_edges; std::unordered_map reaching_conds; @@ -105,14 +105,14 @@ clang::Expr *Negate(ASTBuilder &ast, clang::Expr *expr); std::string ClangThingToString(const clang::Stmt *stmt); -z3::goal ApplyTactic(z3::context &ctx, const z3::tactic &tactic, z3::expr expr); +z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr); -bool Prove(z3::context &ctx, z3::expr expr); +bool Prove(z3::expr expr); -z3::expr HeavySimplify(z3::context &ctx, z3::expr expr); +z3::expr HeavySimplify(z3::expr expr); 
z3::expr_vector Clone(z3::expr_vector &vec); // Tries to keep each subformula sorted by its id so that they don't get // shuffled around by simplification -z3::expr Sort(z3::context &ctx, z3::expr expr); +z3::expr Sort(z3::expr expr); } // namespace rellic \ No newline at end of file diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index dd5335af..0c37e7ca 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -41,7 +41,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { auto else_a{if_a->getElse()}; auto else_b{if_b->getElse()}; - if (Prove(provenance.z3_ctx, cond_a == cond_b)) { + if (Prove(cond_a == cond_b)) { std::vector new_then_body{then_a, then_b}; auto new_then{ast.CreateCompoundStmt(new_then_body)}; @@ -69,7 +69,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { break; } - if (Prove(provenance.z3_ctx, cond_a == !cond_b)) { + if (Prove(cond_a == !cond_b)) { std::vector new_then_body{then_a}; if (else_b) { new_then_body.push_back(else_b); @@ -95,7 +95,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { break; } - if (Prove(provenance.z3_ctx, z3::implies(cond_b, cond_a))) { + if (Prove(z3::implies(cond_b, cond_a))) { std::vector new_then_body{then_a, if_b}; auto new_then{ast.CreateCompoundStmt(new_then_body)}; @@ -109,8 +109,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { break; } - if (Prove(provenance.z3_ctx, !(cond_a && cond_b)) && - !Prove(provenance.z3_ctx, cond_a || cond_b)) { + if (Prove(!(cond_a && cond_b)) && !Prove(cond_a || cond_b)) { auto new_if{ast.CreateIf(provenance.marker_expr, then_a)}; if (else_a) { std::vector new_else_body{else_a, if_b}; diff --git a/lib/AST/DeadStmtElim.cpp b/lib/AST/DeadStmtElim.cpp index 68c8577a..b3b28664 100644 --- a/lib/AST/DeadStmtElim.cpp +++ b/lib/AST/DeadStmtElim.cpp @@ -19,8 +19,7 @@ bool DeadStmtElim::VisitIfStmt(clang::IfStmt *ifstmt) { // DLOG(INFO) << 
"VisitIfStmt"; bool can_delete = false; if (ifstmt->getCond() == provenance.marker_expr) { - can_delete = Prove(provenance.z3_ctx, - !provenance.z3_exprs[provenance.conds[ifstmt]]); + can_delete = Prove(!provenance.z3_exprs[provenance.conds[ifstmt]]); } auto compound = clang::dyn_cast(ifstmt->getThen()); diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index e400e44b..b2e5ec6d 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -169,18 +169,16 @@ unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, return provenance.z3_br_edges[{inst, cond}]; } -unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst, - llvm::ConstantInt *c) { - if (provenance.z3_sw_vars.find({inst, c}) == provenance.z3_sw_vars.end()) { - auto name{GetName(inst) + - GetName(inst->findCaseValue(c)->getCaseSuccessor())}; - auto var{provenance.z3_ctx.bool_const(name.c_str())}; - provenance.z3_sw_vars[{inst, c}] = provenance.z3_exprs.size(); +unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { + if (provenance.z3_sw_vars.find(inst) == provenance.z3_sw_vars.end()) { + auto name{GetName(inst)}; + auto var{provenance.z3_ctx.int_const(name.c_str())}; + provenance.z3_sw_vars[inst] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(var); - provenance.z3_sw_edges_inv[var.id()] = {inst, c}; - return provenance.z3_sw_vars[{inst, c}]; + provenance.z3_sw_vars_inv[var.id()] = inst; + return provenance.z3_sw_vars[inst]; } else { - return provenance.z3_sw_vars[{inst, c}]; + return provenance.z3_sw_vars[inst]; } } @@ -194,18 +192,12 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, }; if (provenance.z3_sw_edges.find({inst, c}) == provenance.z3_sw_edges.end()) { if (c) { - auto edge{ToExpr(GetOrCreateVarForSwitch(inst, c))}; - z3::expr_vector vec{provenance.z3_ctx}; - vec.push_back(edge); - for (auto sw_case : inst->cases()) { - if (c != sw_case.getCaseValue()) { - vec.push_back( - 
!ToExpr(GetOrCreateVarForSwitch(inst, sw_case.getCaseValue()))); - } - } + auto sw_case{inst->findCaseValue(c)}; + auto var{ToExpr(GetOrCreateVarForSwitch(inst))}; + auto expr{var == provenance.z3_ctx.int_val(sw_case->getCaseIndex())}; provenance.z3_sw_edges[{inst, c}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(z3::mk_and(vec)); + provenance.z3_exprs.push_back(expr); } else { // Default case z3::expr_vector vec{provenance.z3_ctx}; @@ -255,7 +247,7 @@ unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, ToExpr(GetOrCreateEdgeForSwitch(sw, sw_case.getCaseValue()))); } } - result = HeavySimplify(provenance.z3_ctx, z3::mk_or(or_vec)); + result = HeavySimplify(z3::mk_or(or_vec)); } } break; // Returns @@ -312,16 +304,15 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { // Construct reaching condition from `pred` to `block` as // `reach_cond[pred] && edge_cond(pred, block)` or one of // the two if the other one is missing. - auto conj_cond{HeavySimplify(provenance.z3_ctx, pred_cond && edge_cond)}; + auto conj_cond{HeavySimplify(pred_cond && edge_cond)}; // Append `conj_cond` to reaching conditions of other // predecessors via an `||`. Use `conj_cond` if there // is no `cond` yet. 
conds.push_back(conj_cond); } - auto cond{HeavySimplify(provenance.z3_ctx, z3::mk_or(conds))}; - if (old_cond_idx == POISON_IDX || - !Prove(provenance.z3_ctx, old_cond == cond)) { + auto cond{HeavySimplify(z3::mk_or(conds))}; + if (old_cond_idx == POISON_IDX || !Prove(old_cond == cond)) { provenance.reaching_conds[block] = provenance.z3_exprs.size(); provenance.z3_exprs.push_back(cond); reaching_conds_changed = true; diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index ebb97c7c..33349b4f 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -62,6 +62,31 @@ class ExprGen : public llvm::InstVisitor { }; clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { + if (expr.decl().decl_kind() == Z3_OP_EQ) { + CHECK_EQ(expr.num_args(), 2) << "Equalities must have 2 arguments"; + auto a{expr.arg(0)}; + auto b{expr.arg(1)}; + + llvm::SwitchInst *inst{provenance.z3_sw_vars_inv[a.id()]}; + unsigned case_idx{}; + + if (!inst) { + inst = provenance.z3_sw_vars_inv[b.id()]; + case_idx = a.get_numeral_uint(); + } else { + case_idx = b.get_numeral_uint(); + } + + for (auto sw_case : inst->cases()) { + if (sw_case.getCaseIndex() == case_idx) { + return ast.CreateEQ(CreateOperandExpr(inst->getOperandUse(0)), + CreateConstantExpr(sw_case.getCaseValue())); + } + } + + LOG(FATAL) << "Couldn't find switch case"; + } + auto hash{expr.id()}; if (provenance.z3_br_edges_inv.find(hash) != provenance.z3_br_edges_inv.end()) { @@ -71,14 +96,9 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { return CreateOperandExpr(*(edge.first->op_end() - 3)); } - if (provenance.z3_sw_edges_inv.find(hash) != - provenance.z3_sw_edges_inv.end()) { - auto edge{provenance.z3_sw_edges_inv[hash]}; - CHECK(edge.second) - << "Inverse map should only be populated for not-default switch cases"; - - auto opnd{CreateOperandExpr(edge.first->getOperandUse(0))}; - return ast.CreateEQ(opnd, CreateConstantExpr(edge.second)); + if (provenance.z3_sw_vars_inv.find(hash) != 
provenance.z3_sw_vars_inv.end()) { + auto inst{provenance.z3_sw_vars_inv[hash]}; + return CreateOperandExpr(inst->getOperandUse(0)); } std::vector args; diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index c16e0926..02056dfc 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -12,27 +12,36 @@ #include #include +#include #include #include "rellic/AST/ASTBuilder.h" #include "rellic/AST/Util.h" +namespace std { +template <> +struct hash { + size_t operator()(const z3::expr& e) const { return e.hash(); } +}; + +template <> +struct equal_to { + bool operator()(const z3::expr& a, const z3::expr& b) const { + return a.id() == b.id(); + } +}; +} // namespace std + namespace rellic { struct KnownExprs { - z3::expr_vector src; - z3::expr_vector dst; - - KnownExprs(z3::context& ctx) : src(ctx), dst(ctx) {} - KnownExprs(z3::expr_vector&& src, z3::expr_vector&& dst) - : src(std::move(src)), dst(std::move(dst)) {} - KnownExprs Clone() { return {::rellic::Clone(src), ::rellic::Clone(dst)}; } + std::unordered_map values; bool IsConstant(z3::expr expr) { - if (Prove(src.ctx(), expr)) { + if (Prove(expr)) { return true; } - if (Prove(src.ctx(), !expr)) { + if (Prove(!expr)) { return true; } @@ -68,25 +77,20 @@ struct KnownExprs { return; } - src.push_back(expr); - dst.push_back(dst.ctx().bool_val(value)); - CHECK_EQ(src.size(), dst.size()); + values[expr] = value; } z3::expr ApplyAssumptions(z3::expr expr, bool& found) { - if (IsConstant(expr)) { + if (IsConstant(expr) || values.empty()) { return expr; } - for (unsigned i{0}; i < dst.size(); ++i) { - if (z3::eq(expr, src[i])) { - found = true; - return dst[i]; - } + if (values.find(expr) != values.end()) { + return expr.ctx().bool_val(values[expr]); } if (expr.is_and() || expr.is_or()) { - z3::expr_vector args{src.ctx()}; + z3::expr_vector args{expr.ctx()}; for (auto arg : expr.args()) { args.push_back(ApplyAssumptions(arg, found)); } @@ -138,7 +142,7 @@ class CompoundVisitor return 
true; } - auto inner{known_exprs.Clone()}; + auto inner{known_exprs}; inner.AddExpr(new_cond, true); known_exprs.AddExpr(new_cond, false); @@ -158,7 +162,7 @@ class CompoundVisitor return true; } - auto inner{known_exprs.Clone()}; + auto inner{known_exprs}; known_exprs.AddExpr(new_cond, false); if (Visit(do_stmt->getBody(), inner)) { @@ -178,14 +182,14 @@ class CompoundVisitor return true; } - auto inner_then{known_exprs.Clone()}; + auto inner_then{known_exprs}; inner_then.AddExpr(new_cond, true); if (Visit(if_stmt->getThen(), inner_then)) { return true; } if (if_stmt->getElse()) { - auto inner_else{known_exprs.Clone()}; + auto inner_else{known_exprs}; inner_else.AddExpr(new_cond, false); if (Visit(if_stmt->getElse(), inner_else)) { return true; @@ -199,6 +203,7 @@ NestedCondProp::NestedCondProp(Provenance& provenance, clang::ASTUnit& unit) : ASTPass(provenance, unit) {} void NestedCondProp::RunImpl() { + LOG(INFO) << "Propagating conditions"; changed = false; ASTBuilder ast{ast_unit}; CompoundVisitor visitor{provenance, ast, ast_ctx}; @@ -210,7 +215,7 @@ void NestedCondProp::RunImpl() { } if (fdecl->hasBody()) { - KnownExprs known_exprs{provenance.z3_ctx}; + KnownExprs known_exprs{}; if (visitor.Visit(fdecl->getBody(), known_exprs)) { changed = true; return; diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index e1db0415..ca18ae6e 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -24,9 +24,9 @@ bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { // Determine whether `cond` is a constant expression that is always true and // `ifstmt` should be replaced by `then` in it's parent nodes. 
auto cond{provenance.z3_exprs[provenance.conds[ifstmt]]}; - if (Prove(provenance.z3_ctx, cond)) { + if (Prove(cond)) { substitutions[ifstmt] = ifstmt->getThen(); - } else if (Prove(provenance.z3_ctx, !cond) && ifstmt->getElse()) { + } else if (Prove(!cond) && ifstmt->getElse()) { substitutions[ifstmt] = ifstmt->getElse(); } return !Stopped(); @@ -34,7 +34,7 @@ bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { bool NestedScopeCombine::VisitWhileStmt(clang::WhileStmt *stmt) { auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; - if (Prove(provenance.z3_ctx, cond)) { + if (Prove(cond)) { auto body{clang::cast(stmt->getBody())}; if (clang::isa(body->body_back())) { std::vector new_body{body->body_begin(), diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index 04741b5e..41f0d0c6 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -338,30 +338,28 @@ std::string ClangThingToString(const clang::Stmt *stmt) { return s; } -z3::goal ApplyTactic(z3::context &ctx, const z3::tactic &tactic, - z3::expr expr) { - z3::goal goal(ctx); +z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr) { + z3::goal goal(tactic.ctx()); goal.add(expr.simplify()); auto app{tactic(goal)}; CHECK(app.size() == 1) << "Unexpected multiple goals in application!"; return app[0]; } -bool Prove(z3::context &ctx, z3::expr expr) { - return ApplyTactic(ctx, z3::tactic(ctx, "sat"), !(expr.simplify())) - .is_decided_unsat(); +bool Prove(z3::expr expr) { + return ApplyTactic(z3::tactic(expr.ctx(), "sat"), !expr).is_decided_unsat(); } -z3::expr HeavySimplify(z3::context &ctx, z3::expr expr) { - if (Prove(ctx, expr)) { - return ctx.bool_val(true); +z3::expr HeavySimplify(z3::expr expr) { + if (Prove(expr)) { + return expr.ctx().bool_val(true); } - z3::tactic aig(ctx, "aig"); - z3::tactic simplify(ctx, "simplify"); - z3::tactic ctx_solver_simplify(ctx, "ctx-solver-simplify"); + z3::tactic aig(expr.ctx(), "aig"); + z3::tactic simplify(expr.ctx(), "simplify"); + z3::tactic 
ctx_solver_simplify(expr.ctx(), "ctx-solver-simplify"); auto tactic{simplify & aig & ctx_solver_simplify}; - return ApplyTactic(ctx, tactic, expr).as_expr(); + return ApplyTactic(tactic, expr).as_expr(); } z3::expr_vector Clone(z3::expr_vector &vec) { @@ -373,7 +371,7 @@ z3::expr_vector Clone(z3::expr_vector &vec) { return clone; } -z3::expr Sort(z3::context &ctx, z3::expr expr) { +z3::expr Sort(z3::expr expr) { if (expr.is_and() || expr.is_or()) { std::vector args_indices(expr.num_args(), 0); std::iota(args_indices.begin(), args_indices.end(), 0); @@ -381,9 +379,9 @@ z3::expr Sort(z3::context &ctx, z3::expr expr) { [&expr](unsigned a, unsigned b) { return expr.arg(a).id() < expr.arg(b).id(); }); - z3::expr_vector new_args{ctx}; + z3::expr_vector new_args{expr.ctx()}; for (auto idx : args_indices) { - new_args.push_back(Sort(ctx, expr.arg(idx))); + new_args.push_back(Sort(expr.arg(idx))); } if (expr.is_and()) { return z3::mk_and(new_args); @@ -393,7 +391,7 @@ z3::expr Sort(z3::context &ctx, z3::expr expr) { } if (expr.is_not()) { - return !Sort(ctx, expr.arg(0)); + return !Sort(expr.arg(0)); } return expr; diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index c0f82d08..1a20534d 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -22,7 +22,7 @@ Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) void Z3CondSimplify::RunImpl() { LOG(INFO) << "Simplifying conditions using Z3"; for (size_t i{0}; i < provenance.z3_exprs.size() && !Stopped(); ++i) { - auto simpl{Sort(provenance.z3_ctx, provenance.z3_exprs[i].simplify())}; + auto simpl{Sort(provenance.z3_exprs[i].simplify())}; provenance.z3_exprs.set(i, simpl); } } From 5d1c7ee3f80778fc4b192890311a4830ce06fec3 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 11 Aug 2022 11:17:22 +0200 Subject: [PATCH 22/57] Reintroduce RBR --- include/rellic/AST/ReachBasedRefine.h | 50 +++++++++++++++++++ lib/AST/ReachBasedRefine.cpp | 71 
+++++++++++++++++++++++++++ lib/CMakeLists.txt | 2 + lib/Decompiler.cpp | 3 ++ tools/repl/Repl.cpp | 8 ++- tools/xref/Xref.cpp | 4 ++ tools/xref/www/main.js | 7 ++- 7 files changed, 142 insertions(+), 3 deletions(-) create mode 100644 include/rellic/AST/ReachBasedRefine.h create mode 100644 lib/AST/ReachBasedRefine.cpp diff --git a/include/rellic/AST/ReachBasedRefine.h b/include/rellic/AST/ReachBasedRefine.h new file mode 100644 index 00000000..eff7fdb2 --- /dev/null +++ b/include/rellic/AST/ReachBasedRefine.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include "rellic/AST/TransformVisitor.h" +#include "rellic/AST/Z3ConvVisitor.h" + +namespace rellic { + +/* + * This pass restructures a sequence of if statements that have a shape like + * + * if(cond1 && !cond2 && !cond3) { + * body1; + * } + * if(cond2 && !cond1 && !cond3) { + * body2; + * } + * if(cond3 && !cond1 && !cond2) { + * body3; + * } + * + * into + * + * if(cond1) { + * body1; + * } else if(cond2) { + * body2; + * } else if(cond3) { + * body3; + * } + */ +class ReachBasedRefine : public TransformVisitor { + private: + protected: + void RunImpl() override; + + public: + ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit); + + bool VisitCompoundStmt(clang::CompoundStmt *compound); +}; + +} // namespace rellic \ No newline at end of file diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp new file mode 100644 index 00000000..746a3014 --- /dev/null +++ b/lib/AST/ReachBasedRefine.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. 
+ */ + +#include "rellic/AST/ReachBasedRefine.h" + +#include +#include + +namespace rellic { + +ReachBasedRefine::ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit) + : TransformVisitor(provenance, unit) {} + +bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { + std::vector body{compound->body_begin(), compound->body_end()}; + std::vector ifs; + z3::expr_vector conds{provenance.z3_ctx}; + bool done_something{false}; + for (size_t i{0}; i < body.size(); ++i) { + if (auto if_stmt = clang::dyn_cast(body[i])) { + ifs.push_back(if_stmt); + auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; + if (!if_stmt->getElse() && Prove(!(cond && z3::mk_or(conds)))) { + conds.push_back(cond); + + if (Prove(z3::mk_or(conds)) && ifs.size() > 2) { + auto last_if{ifs[0]}; + for (auto stmt : ifs) { + if (stmt == ifs.front()) { + continue; + } + if (stmt == ifs.back()) { + last_if->setElse(stmt->getThen()); + } else { + last_if->setElse(stmt); + last_if = stmt; + } + } + + size_t start_delete{i - (ifs.size() - 2)}; + size_t end_delete{i}; + body.erase(body.erase(std::next(body.begin(), start_delete), + std::next(body.begin(), end_delete))); + done_something = true; + break; + } + } + } + ifs.clear(); + conds.resize(0); + } + + if (done_something) { + substitutions[compound] = ast.CreateCompoundStmt(body); + } + + return !Stopped(); +} + +void ReachBasedRefine::RunImpl() { + LOG(INFO) << "Reachability-based refinement"; + TransformVisitor::RunImpl(); + TraverseDecl(ast_ctx.getTranslationUnitDecl()); +} + +} // namespace rellic \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 8cb7e92e..933278ff 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -36,6 +36,7 @@ set(AST_HEADERS "${include_dir}/AST/NestedCondProp.h" "${include_dir}/AST/NestedScopeCombine.h" "${include_dir}/AST/NormalizeCond.h" + "${include_dir}/AST/ReachBasedRefine.h" "${include_dir}/AST/StructFieldRenamer.h" 
"${include_dir}/AST/StructGenerator.h" "${include_dir}/AST/SubprogramGenerator.h" @@ -79,6 +80,7 @@ set(AST_SOURCES AST/Util.cpp AST/Z3CondSimplify.cpp AST/Z3ConvVisitor.cpp + AST/ReachBasedRefine.cpp AST/StructFieldRenamer.cpp AST/StructGenerator.cpp AST/SubprogramGenerator.cpp diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 0e3316b9..04037b62 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -32,6 +32,7 @@ #include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" +#include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" #include "rellic/BC/Util.h" @@ -115,6 +116,8 @@ Result Decompile( cbr_passes.push_back( std::make_unique(provenance, *ast_unit)); + // cbr_passes.push_back( + // std::make_unique(provenance, *ast_unit)); while (pass_cbr.Run()) { ; diff --git a/tools/repl/Repl.cpp b/tools/repl/Repl.cpp index 1ed86a79..741e0994 100644 --- a/tools/repl/Repl.cpp +++ b/tools/repl/Repl.cpp @@ -45,6 +45,7 @@ #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" #include "rellic/AST/NormalizeCond.h" +#include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" #include "rellic/BC/Util.h" @@ -93,8 +94,8 @@ static void SetVersion(void) { google::SetVersionString(version.str()); } -static const char* available_passes[] = {"cbr", "dse", "ec", "lr", - "ncp", "nsc", "nc", "zcs"}; +static const char* available_passes[] = {"cbr", "dse", "ec", "lr", "ncp", + "nsc", "nc", "rbr", "zcs"}; static bool diff = false; @@ -162,6 +163,8 @@ static std::unique_ptr CreatePass(const std::string& name) { return std::make_unique(*provenance, *ast_unit); } else if (name == "nc") { return std::make_unique(*provenance, *ast_unit); + } else if (name == "rbr") { + return std::make_unique(*provenance, *ast_unit); } else if (name == "zcs") { return 
std::make_unique(*provenance, *ast_unit); } else { @@ -201,6 +204,7 @@ static void do_help() { << " nc Condition normalization\n" << " ncp Nested condition propagation\n" << " nsc Nested scope combination\n" + << " rbr Reach-based refinement\n" << " zcs Z3-based condition simplification\n" << std::endl; } diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index aa379603..518bcea4 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -53,6 +53,7 @@ #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" #include "rellic/AST/NormalizeCond.h" +#include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Util.h" #include "rellic/AST/Z3CondSimplify.h" @@ -423,6 +424,9 @@ static std::unique_ptr CreatePass( } else if (str == "nc") { return std::make_unique(*session.Provenance, *session.Unit); + } else if (str == "rbr") { + return std::make_unique(*session.Provenance, + *session.Unit); } else if (str == "zcs") { return std::make_unique(*session.Provenance, *session.Unit); diff --git a/tools/xref/www/main.js b/tools/xref/www/main.js index 2921f632..848a7c34 100644 --- a/tools/xref/www/main.js +++ b/tools/xref/www/main.js @@ -20,6 +20,10 @@ const cbr = { id: "cbr", label: "Condition-based refinement" } +const rbr = { + id: "rbr", + label: "Reach-based refinement" +} const lr = { id: "lr", label: "Loop refinement" @@ -103,6 +107,7 @@ const app = new Vue({ ncp, nsc, cbr, + rbr, lr, ec, nc, @@ -313,7 +318,7 @@ const app = new Vue({ useDefaultChain() { this.commands = [ dse, - [zcs, ncp, nsc, cbr], + [zcs, ncp, nsc, cbr, rbr], [lr, nsc], [zcs, ncp, nsc], mc, ec From a85e490bc704adfef9c1bd4ae86a9ea59c67e2e0 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 11 Aug 2022 11:33:41 +0200 Subject: [PATCH 23/57] Fix NCP --- lib/AST/NestedCondProp.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 
02056dfc..f543cb31 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -21,7 +21,7 @@ namespace std { template <> struct hash { - size_t operator()(const z3::expr& e) const { return e.hash(); } + size_t operator()(const z3::expr& e) const { return e.id(); } }; template <> @@ -49,8 +49,12 @@ struct KnownExprs { } void AddExpr(z3::expr expr, bool value) { - if (IsConstant(expr)) { - return; + switch (expr.decl().decl_kind()) { + case Z3_OP_TRUE: + case Z3_OP_FALSE: + return; + default: + break; } if (expr.is_not()) { @@ -81,11 +85,12 @@ struct KnownExprs { } z3::expr ApplyAssumptions(z3::expr expr, bool& found) { - if (IsConstant(expr) || values.empty()) { + if (values.empty()) { return expr; } if (values.find(expr) != values.end()) { + found = true; return expr.ctx().bool_val(values[expr]); } From d5d68bf16cdff5aaea899590cbea86916f898ee9 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 11 Aug 2022 11:35:09 +0200 Subject: [PATCH 24/57] Remove `IsConstant` Too slow --- lib/AST/NestedCondProp.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index f543cb31..37e0b0da 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -36,18 +36,6 @@ namespace rellic { struct KnownExprs { std::unordered_map values; - bool IsConstant(z3::expr expr) { - if (Prove(expr)) { - return true; - } - - if (Prove(!expr)) { - return true; - } - - return false; - } - void AddExpr(z3::expr expr, bool value) { switch (expr.decl().decl_kind()) { case Z3_OP_TRUE: From 367a1c68593c8ae7a57e80bac32746a50466e72c Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 11 Aug 2022 11:58:45 +0200 Subject: [PATCH 25/57] Reactivate RBR --- lib/Decompiler.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 04037b62..1297d3cf 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -116,8 +116,8 @@ 
Result Decompile( cbr_passes.push_back( std::make_unique(provenance, *ast_unit)); - // cbr_passes.push_back( - // std::make_unique(provenance, *ast_unit)); + cbr_passes.push_back( + std::make_unique(provenance, *ast_unit)); while (pass_cbr.Run()) { ; From a1de89daf9afe9509acfe9526f6c4c3e7daa90c5 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 11 Aug 2022 12:15:34 +0200 Subject: [PATCH 26/57] Remove ignored tests --- ci/angha_1k_test_settings.json | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ci/angha_1k_test_settings.json b/ci/angha_1k_test_settings.json index 06f7aafe..34b4d485 100644 --- a/ci/angha_1k_test_settings.json +++ b/ci/angha_1k_test_settings.json @@ -1,12 +1,8 @@ { "tests.ignore": [ - "amd64/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "amd64/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc", - "arm64/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "arm64/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc", - "armv7/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "armv7/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc", - "x86/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "x86/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc" ] } \ No newline at end of file From d5f2f7d85c4f9816e11c964e83b58cb7269a33e8 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 12 Aug 2022 14:09:34 +0200 Subject: [PATCH 27/57] Fix CBR & RBR --- lib/AST/CondBasedRefine.cpp | 30 --------- lib/AST/ReachBasedRefine.cpp | 126 ++++++++++++++++++++++++++--------- lib/AST/Util.cpp | 3 +- 3 files changed, 98 insertions(+), 61 deletions(-) diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index 0c37e7ca..2a85d2bd 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -94,36 +94,6 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { did_something = true; break; } - - if (Prove(z3::implies(cond_b, cond_a))) { - std::vector new_then_body{then_a, if_b}; - 
auto new_then{ast.CreateCompoundStmt(new_then_body)}; - - auto new_if{ast.CreateIf(provenance.marker_expr, new_then)}; - new_if->setElse(else_a); - - provenance.conds[new_if] = provenance.conds[if_a]; - new_body[i] = new_if; - new_body.erase(std::next(new_body.begin(), i + 1)); - did_something = true; - break; - } - - if (Prove(!(cond_a && cond_b)) && !Prove(cond_a || cond_b)) { - auto new_if{ast.CreateIf(provenance.marker_expr, then_a)}; - if (else_a) { - std::vector new_else_body{else_a, if_b}; - new_if->setElse(ast.CreateCompoundStmt(new_else_body)); - } else { - new_if->setElse(if_b); - } - - provenance.conds[new_if] = provenance.conds[if_a]; - new_body[i] = new_if; - new_body.erase(std::next(new_body.begin(), i + 1)); - did_something = true; - break; - } } if (did_something) { auto new_compound{ast.CreateCompoundStmt(new_body)}; diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp index 746a3014..c02f516f 100644 --- a/lib/AST/ReachBasedRefine.cpp +++ b/lib/AST/ReachBasedRefine.cpp @@ -20,45 +20,111 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector body{compound->body_begin(), compound->body_end()}; std::vector ifs; z3::expr_vector conds{provenance.z3_ctx}; + bool done_something{false}; for (size_t i{0}; i < body.size(); ++i) { - if (auto if_stmt = clang::dyn_cast(body[i])) { - ifs.push_back(if_stmt); - auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; - if (!if_stmt->getElse() && Prove(!(cond && z3::mk_or(conds)))) { - conds.push_back(cond); - - if (Prove(z3::mk_or(conds)) && ifs.size() > 2) { - auto last_if{ifs[0]}; - for (auto stmt : ifs) { - if (stmt == ifs.front()) { - continue; - } - if (stmt == ifs.back()) { - last_if->setElse(stmt->getThen()); - } else { - last_if->setElse(stmt); - last_if = stmt; - } - } - - size_t start_delete{i - (ifs.size() - 2)}; - size_t end_delete{i}; - body.erase(body.erase(std::next(body.begin(), start_delete), - std::next(body.begin(), end_delete))); - 
done_something = true; - break; - } + auto if_stmt{clang::dyn_cast(body[i])}; + if (!if_stmt) { + ifs.clear(); + conds.resize(0); + continue; + } + + ifs.push_back(if_stmt); + auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; + + if (if_stmt->getElse()) { + // We cannot link `if` statements that contain `else` branches + ifs.clear(); + conds.resize(0); + continue; + } + + // Is the current `if` statement unreachable from all the others? + bool is_unreachable{Prove(!(cond && z3::mk_or(conds)))}; + + if (!is_unreachable) { + ifs.clear(); + conds.resize(0); + continue; + } + + conds.push_back(cond); + + // Do the collected statements cover all possibilities? + auto is_complete{Prove(z3::mk_or(conds))}; + + if (ifs.size() <= 2 || !is_complete) { + // We need to collect more statements + continue; + } + + /* + `body` will look like this at this point: + + ... + i - n : ... + i - n + 1: if(cond_1) { } + i - n + 2: if(cond_2) { } + ... + i - 1 : if(cond_n-1) { } + i : if(cond_n) { } + ... + + and we want to chain all of the statements together: + ... + i - n : ... + i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } + i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } + ... + i - 1 : if(cond_n-1) { } else if(cond_n) { } + i : if(cond_n) { } + ... + */ + auto last_if{ifs[0]}; + for (auto stmt : ifs) { + if (stmt == ifs.front()) { + continue; + } + if (stmt == ifs.back()) { + last_if->setElse(stmt->getThen()); + } else { + last_if->setElse(stmt); + last_if = stmt; } } - ifs.clear(); - conds.resize(0); + + /* + `body` will look like this at this point: + + ... + i - n : ... + i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } + i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } + ... + i - 1 : if(cond_n-1) { } else if(cond_n) { } + i : if(cond_n) { } + ... + + but since we chained all of the statements into the first, we want to remove + the others from the body: + + ... + i - n : ... 
+ i - n + 1: if(cond_1) { } else if(cond_2) { } else if ... + ... + */ + size_t start_delete{i - (ifs.size() - 2)}; + size_t end_delete{i}; + body.erase(body.erase(std::next(body.begin(), start_delete), + std::next(body.begin(), end_delete))); + done_something = true; + break; } if (done_something) { substitutions[compound] = ast.CreateCompoundStmt(body); } - return !Stopped(); } diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index 41f0d0c6..aaeaa814 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -347,7 +347,8 @@ z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr) { } bool Prove(z3::expr expr) { - return ApplyTactic(z3::tactic(expr.ctx(), "sat"), !expr).is_decided_unsat(); + return ApplyTactic(z3::tactic(expr.ctx(), "sat"), !(expr).simplify()) + .is_decided_unsat(); } z3::expr HeavySimplify(z3::expr expr) { From 700505960ee8b7c1400b50e7db1b137533bf9273 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 12 Aug 2022 16:16:24 +0200 Subject: [PATCH 28/57] Factor out RBR lambda --- lib/AST/ReachBasedRefine.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp index c02f516f..c3d0d587 100644 --- a/lib/AST/ReachBasedRefine.cpp +++ b/lib/AST/ReachBasedRefine.cpp @@ -21,12 +21,16 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector ifs; z3::expr_vector conds{provenance.z3_ctx}; + auto ResetChain = [&]() { + ifs.clear(); + conds.resize(0); + }; + bool done_something{false}; for (size_t i{0}; i < body.size(); ++i) { auto if_stmt{clang::dyn_cast(body[i])}; if (!if_stmt) { - ifs.clear(); - conds.resize(0); + ResetChain(); continue; } @@ -35,8 +39,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { if (if_stmt->getElse()) { // We cannot link `if` statements that contain `else` branches - ifs.clear(); - conds.resize(0); + ResetChain(); continue; } @@ -44,8 +47,7 @@ bool 
ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { bool is_unreachable{Prove(!(cond && z3::mk_or(conds)))}; if (!is_unreachable) { - ifs.clear(); - conds.resize(0); + ResetChain(); continue; } @@ -77,7 +79,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } ... - i - 1 : if(cond_n-1) { } else if(cond_n) { } + i - 1 : if(cond_n-1) { } else { } i : if(cond_n) { } ... */ From 3747557fea85489e55c34ac823e840bc59d2df16 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 12 Aug 2022 16:20:29 +0200 Subject: [PATCH 29/57] CBR & RBR review --- lib/AST/CondBasedRefine.cpp | 18 ++++++++---------- lib/AST/ReachBasedRefine.cpp | 3 +-- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index 2a85d2bd..8cfd57a7 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -21,10 +21,9 @@ CondBasedRefine::CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit) bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector body{compound->body_begin(), compound->body_end()}; - std::vector new_body{body}; bool did_something{false}; - for (size_t i{0}; i + 1 < body.size(); ++i) { + for (size_t i{0}; i + 1 < body.size() && !did_something; ++i) { auto if_a{clang::dyn_cast(body[i])}; auto if_b{clang::dyn_cast(body[i + 1])}; @@ -41,8 +40,9 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { auto else_a{if_a->getElse()}; auto else_b{if_b->getElse()}; + std::vector new_then_body{then_a}; if (Prove(cond_a == cond_b)) { - std::vector new_then_body{then_a, then_b}; + new_then_body.push_back(then_b); auto new_then{ast.CreateCompoundStmt(new_then_body)}; auto new_if{ast.CreateIf(provenance.marker_expr, new_then)}; @@ -63,14 +63,13 @@ bool 
CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } provenance.conds[new_if] = provenance.conds[if_a]; - new_body[i] = new_if; - new_body.erase(std::next(new_body.begin(), i + 1)); + body[i] = new_if; + body.erase(std::next(body.begin(), i + 1)); did_something = true; break; } if (Prove(cond_a == !cond_b)) { - std::vector new_then_body{then_a}; if (else_b) { new_then_body.push_back(else_b); } @@ -89,14 +88,13 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { new_if->setElse(new_else); provenance.conds[new_if] = provenance.conds[if_a]; - new_body[i] = new_if; - new_body.erase(std::next(new_body.begin(), i + 1)); + body[i] = new_if; + body.erase(std::next(body.begin(), i + 1)); did_something = true; - break; } } if (did_something) { - auto new_compound{ast.CreateCompoundStmt(new_body)}; + auto new_compound{ast.CreateCompoundStmt(body)}; substitutions[compound] = new_compound; } return !Stopped(); diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp index c3d0d587..95ab1786 100644 --- a/lib/AST/ReachBasedRefine.cpp +++ b/lib/AST/ReachBasedRefine.cpp @@ -27,7 +27,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { }; bool done_something{false}; - for (size_t i{0}; i < body.size(); ++i) { + for (size_t i{0}; i < body.size() && !done_something; ++i) { auto if_stmt{clang::dyn_cast(body[i])}; if (!if_stmt) { ResetChain(); @@ -121,7 +121,6 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { body.erase(body.erase(std::next(body.begin(), start_delete), std::next(body.begin(), end_delete))); done_something = true; - break; } if (done_something) { From 1b8d3d5588920fed1fe50073d02190931203b811 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 19 Aug 2022 09:39:43 +0200 Subject: [PATCH 30/57] Cleanup --- include/rellic/AST/CondBasedRefine.h | 1 - include/rellic/AST/MaterializeConds.h | 5 +- include/rellic/AST/NormalizeCond.h | 33 - 
include/rellic/AST/ReachBasedRefine.h | 1 - include/rellic/AST/Util.h | 3 - include/rellic/AST/Z3CondSimplify.h | 1 - include/rellic/AST/Z3ConvVisitor.h | 106 --- lib/AST/CondBasedRefine.cpp | 1 - lib/AST/MaterializeConds.cpp | 5 +- lib/AST/NormalizeCond.cpp | 207 ----- lib/AST/Util.cpp | 13 - lib/AST/Z3CondSimplify.cpp | 5 +- lib/AST/Z3ConvVisitor.cpp | 1223 ------------------------- lib/CMakeLists.txt | 4 - tools/repl/Repl.cpp | 3 - tools/xref/Xref.cpp | 4 - unittests/CMakeLists.txt | 9 +- 17 files changed, 10 insertions(+), 1614 deletions(-) delete mode 100644 include/rellic/AST/NormalizeCond.h delete mode 100644 include/rellic/AST/Z3ConvVisitor.h delete mode 100644 lib/AST/NormalizeCond.cpp delete mode 100644 lib/AST/Z3ConvVisitor.cpp diff --git a/include/rellic/AST/CondBasedRefine.h b/include/rellic/AST/CondBasedRefine.h index aa767d46..c6a519f7 100644 --- a/include/rellic/AST/CondBasedRefine.h +++ b/include/rellic/AST/CondBasedRefine.h @@ -11,7 +11,6 @@ #include "rellic/AST/ASTPass.h" #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/TransformVisitor.h" -#include "rellic/AST/Z3ConvVisitor.h" namespace rellic { diff --git a/include/rellic/AST/MaterializeConds.h b/include/rellic/AST/MaterializeConds.h index 88f944f9..07513c1c 100644 --- a/include/rellic/AST/MaterializeConds.h +++ b/include/rellic/AST/MaterializeConds.h @@ -12,13 +12,12 @@ #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/TransformVisitor.h" -#include "rellic/AST/Util.h" namespace rellic { /* - * This pass simplifies conditions using Z3 by trying to remove terms that are - * trivially true or false + * This pass substitutes the marker expression in loops and `if` statements for + * their translation from Z3 formulas */ class MaterializeConds : public TransformVisitor { private: diff --git a/include/rellic/AST/NormalizeCond.h b/include/rellic/AST/NormalizeCond.h deleted file mode 100644 index 45bca02b..00000000 --- a/include/rellic/AST/NormalizeCond.h +++ /dev/null @@ -1,33 
+0,0 @@ -/* - * Copyright (c) 2022-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#pragma once - -#include "rellic/AST/TransformVisitor.h" - -namespace rellic { - -/* - * This pass turns conditions into conjunctive normal form (CNF). Warning: this - * has the potential of creating an exponential number of terms, so it's best to - * perform this pass after simplification. - */ -class NormalizeCond : public TransformVisitor { - protected: - void RunImpl() override; - - public: - static char ID; - - NormalizeCond(Provenance &provenance, clang::ASTUnit &unit); - - bool VisitUnaryOperator(clang::UnaryOperator *op); - bool VisitBinaryOperator(clang::BinaryOperator *op); -}; - -} // namespace rellic diff --git a/include/rellic/AST/ReachBasedRefine.h b/include/rellic/AST/ReachBasedRefine.h index eff7fdb2..a8fbf3cc 100644 --- a/include/rellic/AST/ReachBasedRefine.h +++ b/include/rellic/AST/ReachBasedRefine.h @@ -9,7 +9,6 @@ #pragma once #include "rellic/AST/TransformVisitor.h" -#include "rellic/AST/Z3ConvVisitor.h" namespace rellic { diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 308cf563..3c37cea0 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -100,9 +100,6 @@ void CopyProvenance(TKey1 *from, TKey2 *to, clang::Expr *Clone(clang::ASTUnit &unit, clang::Expr *stmt, ExprToUseMap &provenance); -// Negates an expression while stripping parentheses and double negations -clang::Expr *Negate(ASTBuilder &ast, clang::Expr *expr); - std::string ClangThingToString(const clang::Stmt *stmt); z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr); diff --git a/include/rellic/AST/Z3CondSimplify.h b/include/rellic/AST/Z3CondSimplify.h index 2aa1bfc0..835e6ca0 100644 --- a/include/rellic/AST/Z3CondSimplify.h +++ b/include/rellic/AST/Z3CondSimplify.h @@ -11,7 +11,6 @@ 
#include #include "rellic/AST/ASTPass.h" -#include "rellic/AST/Util.h" namespace rellic { diff --git a/include/rellic/AST/Z3ConvVisitor.h b/include/rellic/AST/Z3ConvVisitor.h deleted file mode 100644 index 889976eb..00000000 --- a/include/rellic/AST/Z3ConvVisitor.h +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2021-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -#include - -#include "rellic/AST/ASTBuilder.h" - -namespace rellic { - -class Z3ConvVisitor : public clang::RecursiveASTVisitor { - private: - clang::ASTContext *c_ctx; - ASTBuilder ast; - - z3::context *z_ctx; - - // Expression maps - z3::expr_vector z_expr_vec; - std::unordered_map z_expr_map; - std::unordered_map c_expr_map; - // Declaration maps - z3::func_decl_vector z_decl_vec; - std::unordered_map z_decl_map; - std::unordered_map c_decl_map; - - void InsertZ3Expr(clang::Expr *c_expr, z3::expr z_expr); - z3::expr GetZ3Expr(clang::Expr *c_expr); - - void InsertCExpr(z3::expr z_expr, clang::Expr *c_expr); - clang::Expr *GetCExpr(z3::expr z_expr); - - void InsertZ3Decl(clang::ValueDecl *c_decl, z3::func_decl z_decl); - z3::func_decl GetZ3Decl(clang::ValueDecl *c_decl); - - void InsertCValDecl(z3::func_decl z_decl, clang::ValueDecl *c_decl); - clang::ValueDecl *GetCValDecl(z3::func_decl z_decl); - - z3::sort GetZ3Sort(clang::QualType type); - - clang::Expr *CreateLiteralExpr(z3::expr z_expr); - - z3::expr CreateZ3BitwiseCast(z3::expr expr, size_t src, size_t dst, - bool sign); - - void VisitZ3Expr(z3::expr z_expr); - void VisitZ3Decl(z3::func_decl z_decl); - - template - bool HandleCastExpr(T *c_cast); - clang::Expr *HandleZ3Concat(z3::expr z_op); - clang::Expr *HandleZ3Uninterpreted(z3::expr z_op); - - public: - z3::func_decl GetOrCreateZ3Decl(clang::ValueDecl *c_decl); - z3::expr 
GetOrCreateZ3Expr(clang::Expr *c_expr); - - clang::Expr *GetOrCreateCExpr(z3::expr z_expr); - clang::ValueDecl *GetOrCreateCValDecl(z3::func_decl z_decl); - - Z3ConvVisitor(clang::ASTUnit &unit, z3::context *z_ctx); - bool shouldTraversePostOrder() { return true; } - - z3::expr Z3BoolCast(z3::expr expr); - z3::expr Z3BoolToBVCast(z3::expr expr); - - bool VisitArraySubscriptExpr(clang::ArraySubscriptExpr *sub); - bool VisitImplicitCastExpr(clang::ImplicitCastExpr *cast); - bool VisitCStyleCastExpr(clang::CStyleCastExpr *cast); - bool VisitMemberExpr(clang::MemberExpr *expr); - bool VisitCallExpr(clang::CallExpr *call); - bool VisitParenExpr(clang::ParenExpr *parens); - bool VisitUnaryOperator(clang::UnaryOperator *c_op); - bool VisitBinaryOperator(clang::BinaryOperator *c_op); - bool VisitConditionalOperator(clang::ConditionalOperator *c_op); - bool VisitDeclRefExpr(clang::DeclRefExpr *c_ref); - bool VisitCharacterLiteral(clang::CharacterLiteral *lit); - bool VisitIntegerLiteral(clang::IntegerLiteral *lit); - bool VisitFloatingLiteral(clang::FloatingLiteral *lit); - - bool VisitVarDecl(clang::VarDecl *var); - bool VisitFieldDecl(clang::FieldDecl *field); - bool VisitFunctionDecl(clang::FunctionDecl *func); - - // Do not traverse function bodies - bool TraverseFunctionDecl(clang::FunctionDecl *func) { - WalkUpFromFunctionDecl(func); - return true; - } - - void VisitConstant(z3::expr z_const); - void VisitUnaryApp(z3::expr z_op); - void VisitBinaryApp(z3::expr z_op); - void VisitTernaryApp(z3::expr z_op); -}; - -} // namespace rellic diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index 8cfd57a7..b5d3744b 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -10,7 +10,6 @@ #include #include -#include #include diff --git a/lib/AST/MaterializeConds.cpp b/lib/AST/MaterializeConds.cpp index fa080364..d70cf745 100644 --- a/lib/AST/MaterializeConds.cpp +++ b/lib/AST/MaterializeConds.cpp @@ -8,12 +8,11 @@ #include 
"rellic/AST/MaterializeConds.h" -#include -#include -#include #include #include +#include "rellic/AST/Util.h" + namespace rellic { MaterializeConds::MaterializeConds(Provenance &provenance, clang::ASTUnit &unit) diff --git a/lib/AST/NormalizeCond.cpp b/lib/AST/NormalizeCond.cpp deleted file mode 100644 index 94d217cc..00000000 --- a/lib/AST/NormalizeCond.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) 2022-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#include "rellic/AST/NormalizeCond.h" - -#include -#include -#include -#include -#include -#include - -#include "rellic/AST/ASTBuilder.h" -#include "rellic/AST/InferenceRule.h" - -namespace rellic { - -namespace { - -using namespace clang::ast_matchers; - -static const auto zero_int_lit = integerLiteral(equals(0)); - -static inline std::string GetOperatorName(clang::BinaryOperator::Opcode op) { - return op == clang::BO_LAnd ? 
"&&" : "||"; -} - -class DeMorganRule : public InferenceRule { - clang::BinaryOperator::Opcode to; - - public: - DeMorganRule(clang::BinaryOperator::Opcode from, - clang::BinaryOperator::Opcode to) - : InferenceRule( - unaryOperator(hasOperatorName("!"), - has(ignoringParenImpCasts(binaryOperator( - hasOperatorName(GetOperatorName(from)))))) - .bind("not")), - to(to) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("not"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto unop{clang::cast(stmt)}; - CHECK(unop == match) - << "Substituted UnaryOperator is not the matched UnaryOperator"; - auto binop{clang::cast( - unop->getSubExpr()->IgnoreParenImpCasts())}; - - auto new_lhs{ast.CreateLNot(binop->getLHS())}; - auto new_rhs{ast.CreateLNot(binop->getRHS())}; - CopyProvenance(binop->getLHS(), new_lhs, provenance.use_provenance); - CopyProvenance(binop->getRHS(), new_rhs, provenance.use_provenance); - return ast.CreateBinaryOp(to, new_lhs, new_rhs); - } -}; - -class AssociativeRule : public InferenceRule { - clang::BinaryOperator::Opcode op; - - public: - AssociativeRule(clang::BinaryOperator::Opcode op) - : InferenceRule( - binaryOperator(hasOperatorName(GetOperatorName(op)), - hasRHS(ignoringParenImpCasts(binaryOperator( - hasOperatorName(GetOperatorName(op)))))) - .bind("binop")), - op(op) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("binop"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto outer{clang::cast(stmt)}; - CHECK(outer == match) - << "Substituted BinaryOperator is not the matched BinaryOperator"; - auto inner{clang::cast( - outer->getRHS()->IgnoreParenImpCasts())}; - - 
auto new_lhs{ast.CreateBinaryOp(op, outer->getLHS(), inner->getLHS())}; - auto new_rhs{inner->getRHS()}; - return ast.CreateBinaryOp(op, new_lhs, new_rhs); - } -}; - -class LDistributiveRule : public InferenceRule { - public: - LDistributiveRule() - : InferenceRule( - binaryOperator(hasOperatorName("||"), - hasLHS(ignoringParenImpCasts( - binaryOperator(hasOperatorName("&&"))))) - .bind("binop")) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("binop"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto outer{clang::cast(stmt)}; - CHECK(outer == match) - << "Substituted BinaryOperator is not the matched BinaryOperator"; - auto inner{clang::cast( - outer->getLHS()->IgnoreParenImpCasts())}; - - auto new_lhs{ast.CreateLOr(inner->getLHS(), outer->getRHS())}; - auto new_rhs{ast.CreateLOr(inner->getRHS(), outer->getRHS())}; - return ast.CreateLAnd(new_lhs, new_rhs); - } -}; - -class RDistributiveRule : public InferenceRule { - public: - RDistributiveRule() - : InferenceRule( - binaryOperator(hasOperatorName("||"), - hasRHS(ignoringParenImpCasts( - binaryOperator(hasOperatorName("&&"))))) - .bind("binop")) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("binop"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto outer{clang::cast(stmt)}; - CHECK(outer == match) - << "Substituted BinaryOperator is not the matched BinaryOperator"; - auto inner{clang::cast( - outer->getRHS()->IgnoreParenImpCasts())}; - - auto new_lhs{ast.CreateLOr(outer->getLHS(), inner->getLHS())}; - auto new_rhs{ast.CreateLOr(outer->getLHS(), inner->getRHS())}; - return ast.CreateLAnd(new_lhs, new_rhs); - } -}; - -} // 
namespace - -NormalizeCond::NormalizeCond(Provenance &provenance, clang::ASTUnit &u) - : TransformVisitor(provenance, u) {} - -bool NormalizeCond::VisitUnaryOperator(clang::UnaryOperator *op) { - std::vector> rules; - - rules.emplace_back(new DeMorganRule(clang::BO_LAnd, clang::BO_LOr)); - rules.emplace_back(new DeMorganRule(clang::BO_LOr, clang::BO_LAnd)); - - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; - if (sub != op) { - substitutions[op] = sub; - } - - return !Stopped(); -} - -bool NormalizeCond::VisitBinaryOperator(clang::BinaryOperator *op) { - std::vector> rules; - - rules.emplace_back(new AssociativeRule(clang::BO_LAnd)); - rules.emplace_back(new AssociativeRule(clang::BO_LOr)); - rules.emplace_back(new LDistributiveRule); - rules.emplace_back(new RDistributiveRule); - - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; - if (sub != op) { - substitutions[op] = sub; - } - - return !Stopped(); -} - -void NormalizeCond::RunImpl() { - LOG(INFO) << "Conversion into conjunctive normal form"; - TransformVisitor::RunImpl(); - TraverseDecl(ast_ctx.getTranslationUnitDecl()); -} - -} // namespace rellic \ No newline at end of file diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index aaeaa814..ad7fb6c3 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -318,19 +318,6 @@ clang::Expr *Clone(clang::ASTUnit &unit, clang::Expr *expr, return CHECK_NOTNULL(cloner.Visit(CHECK_NOTNULL(expr))); } -static clang::Expr *ApplyLNot(rellic::ASTBuilder &ast, clang::Expr *expr) { - if (auto unop = clang::dyn_cast(expr)) { - if (unop->getOpcode() == clang::UO_LNot) { - return unop->getSubExpr(); - } - } - return ast.CreateLNot(expr); -} - -clang::Expr *Negate(rellic::ASTBuilder &ast, clang::Expr *expr) { - return ApplyLNot(ast, expr->IgnoreParens())->IgnoreParens(); -} - std::string ClangThingToString(const clang::Stmt *stmt) { std::string s; llvm::raw_string_ostream os(s); diff --git a/lib/AST/Z3CondSimplify.cpp 
b/lib/AST/Z3CondSimplify.cpp index 1a20534d..ff7fe311 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -8,12 +8,11 @@ #include "rellic/AST/Z3CondSimplify.h" -#include -#include -#include #include #include +#include "rellic/AST/Util.h" + namespace rellic { Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) diff --git a/lib/AST/Z3ConvVisitor.cpp b/lib/AST/Z3ConvVisitor.cpp deleted file mode 100644 index ef248197..00000000 --- a/lib/AST/Z3ConvVisitor.cpp +++ /dev/null @@ -1,1223 +0,0 @@ -/* - * Copyright (c) 2021-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#define GOOGLE_STRIP_LOG 1 - -#include "rellic/AST/Z3ConvVisitor.h" - -#include -#include - -#include "rellic/AST/Util.h" -#include "rellic/Exception.h" - -namespace rellic { - -namespace { - -static unsigned GetZ3SortSize(z3::sort sort) { - switch (sort.sort_kind()) { - case Z3_BOOL_SORT: - return 1; - break; - - case Z3_BV_SORT: - return sort.bv_size(); - break; - - case Z3_FLOATING_POINT_SORT: { - auto &ctx{sort.ctx()}; - return Z3_fpa_get_sbits(ctx, sort) + Z3_fpa_get_ebits(ctx, sort); - } break; - - case Z3_UNINTERPRETED_SORT: - return 0; - break; - - default: - LOG(FATAL) << "Unknown Z3 sort: " << sort; - break; - } - // This code is unreachable, but sometimes we need to - // fix 'error: control reaches end of non-void function' on some compilers. - return (unsigned)(-1); -} - -static unsigned GetZ3SortSize(z3::expr expr) { - return GetZ3SortSize(expr.get_sort()); -} - -// Determine if `op` is a `z3::concat(l, r)` that's -// equivalent to a sign extension. This is done by -// checking if `l` is an "all-one" or "all-zero" bit -// value. 
-static bool IsSignExt(z3::expr op) { - if (op.decl().decl_kind() != Z3_OP_CONCAT) { - return false; - } - - auto lhs{op.arg(0)}; - - if (lhs.is_numeral()) { - auto size{GetZ3SortSize(lhs)}; - llvm::APInt val(size, Z3_get_numeral_string(op.ctx(), lhs), 10); - return val.isAllOnesValue() || val.isNullValue(); - } - return false; -} - -static std::string CreateZ3DeclName(clang::NamedDecl *decl) { - std::stringstream ss; - ss << std::hex << decl << std::dec; - ss << '_' << decl->getNameAsString(); - return ss.str(); -} - -} // namespace - -Z3ConvVisitor::Z3ConvVisitor(clang::ASTUnit &unit, z3::context *z_ctx) - : c_ctx(&unit.getASTContext()), - ast(unit), - z_ctx(z_ctx), - z_expr_vec(*z_ctx), - z_decl_vec(*z_ctx) {} - -// Inserts a `clang::Expr` <=> `z3::expr` mapping into -void Z3ConvVisitor::InsertZ3Expr(clang::Expr *c_expr, z3::expr z_expr) { - CHECK(c_expr) << "Inserting null clang::Expr key."; - CHECK(bool(z_expr)) << "Inserting null z3::expr value."; - CHECK(!z_expr_map.count(c_expr)) << "clang::Expr key already exists."; - z_expr_map[c_expr] = z_expr_vec.size(); - z_expr_vec.push_back(z_expr); -} - -// Retrieves a `z3::expr` corresponding to `c_expr`. -// The `z3::expr` needs to be created and inserted by -// `Z3ConvVisistor::InsertZ3Expr` first. -z3::expr Z3ConvVisitor::GetZ3Expr(clang::Expr *c_expr) { - auto iter{z_expr_map.find(c_expr)}; - CHECK(iter != z_expr_map.end()); - return z_expr_vec[iter->second]; -} - -// Inserts a `clang::ValueDecl` <=> `z3::func_decl` mapping into -void Z3ConvVisitor::InsertZ3Decl(clang::ValueDecl *c_decl, - z3::func_decl z_decl) { - CHECK(c_decl) << "Inserting null clang::ValueDecl key."; - CHECK(bool(z_decl)) << "Inserting null z3::func_decl value."; - CHECK(!z_decl_map.count(c_decl)) << "clang::ValueDecl key already exists."; - z_decl_map[c_decl] = z_decl_vec.size(); - z_decl_vec.push_back(z_decl); -} - -// Retrieves a `z3::func_decl` corresponding to `c_decl`. 
-// The `z3::func_decl` needs to be created and inserted by -// `Z3ConvVisistor::InsertZ3Decl` first. -z3::func_decl Z3ConvVisitor::GetZ3Decl(clang::ValueDecl *c_decl) { - auto iter{z_decl_map.find(c_decl)}; - CHECK(iter != z_decl_map.end()); - return z_decl_vec[iter->second]; -} - -// If `expr` is not boolean, returns a `z3::expr` that corresponds -// to the non-boolean to boolean expression cast in C. Otherwise -// returns `expr`. -z3::expr Z3ConvVisitor::Z3BoolCast(z3::expr expr) { - if (expr.is_bool()) { - return expr; - } else { - return expr != z_ctx->num_val(0, expr.get_sort()); - } -} - -z3::expr Z3ConvVisitor::Z3BoolToBVCast(z3::expr expr) { - if (expr.is_bv()) { - return expr; - } - - CHECK(expr.is_bool()); - - auto size{c_ctx->getTypeSize(c_ctx->IntTy)}; - return z3::ite(expr, z_ctx->bv_val(1U, size), z_ctx->bv_val(0U, size)); -} - -void Z3ConvVisitor::InsertCExpr(z3::expr z_expr, clang::Expr *c_expr) { - CHECK(bool(z_expr)) << "Inserting null z3::expr key."; - CHECK(c_expr) << "Inserting null clang::Expr value."; - auto hash{z_expr.hash()}; - CHECK(!c_expr_map.count(hash)) << "z3::expr key already exists."; - c_expr_map[hash] = c_expr; -} - -clang::Expr *Z3ConvVisitor::GetCExpr(z3::expr z_expr) { - auto hash{z_expr.hash()}; - CHECK(c_expr_map.count(hash)) << "No Z3 equivalent for C expression!"; - return c_expr_map[hash]; -} - -void Z3ConvVisitor::InsertCValDecl(z3::func_decl z_decl, - clang::ValueDecl *c_decl) { - CHECK(c_decl) << "Inserting null z3::func_decl key."; - CHECK(bool(z_decl)) << "Inserting null clang::ValueDecl value."; - CHECK(!c_decl_map.count(z_decl.id())) << "z3::func_decl key already exists."; - c_decl_map[z_decl.id()] = c_decl; -} - -clang::ValueDecl *Z3ConvVisitor::GetCValDecl(z3::func_decl z_decl) { - CHECK(c_decl_map.count(z_decl.id())) << "No C equivalent for Z3 declaration!"; - return c_decl_map[z_decl.id()]; -} - -z3::sort Z3ConvVisitor::GetZ3Sort(clang::QualType type) { - // Void - if (type->isVoidType()) { - return 
z_ctx->uninterpreted_sort("void"); - } - // Booleans - if (type->isBooleanType()) { - return z_ctx->bool_sort(); - } - // Structures - if (type->isStructureType()) { - auto decl{clang::cast(type)->getDecl()}; - return z_ctx->uninterpreted_sort(decl->getNameAsString().c_str()); - } - auto bitwidth{c_ctx->getTypeSize(type)}; - // Floating points - if (type->isRealFloatingType()) { - switch (bitwidth) { - case 16: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_16(*z_ctx)); - break; - - case 32: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_32(*z_ctx)); - break; - - case 64: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_64(*z_ctx)); - break; - - case 128: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_128(*z_ctx)); - break; - - default: - LOG(FATAL) << "Unsupported floating-point bitwidth!"; - break; - } - } - // Default to bitvectors - return z3::to_sort(*z_ctx, z_ctx->bv_sort(bitwidth)); -} - -clang::Expr *Z3ConvVisitor::CreateLiteralExpr(z3::expr z_expr) { - DLOG(INFO) << "Creating literal clang::Expr for " << z_expr; - CHECK(z_expr.is_const()) << "Can only create C literal from Z3 constant!"; - auto z_sort{z_expr.get_sort()}; - clang::Expr *result{nullptr}; - switch (z_sort.sort_kind()) { - case Z3_BOOL_SORT: { - auto val{z_expr.bool_value() == Z3_L_TRUE ? 1U : 0U}; - result = ast.CreateIntLit(llvm::APInt(/*BitWidth=*/1U, val)); - } break; - - case Z3_BV_SORT: { - llvm::APInt val(GetZ3SortSize(z_expr), - Z3_get_numeral_string(z_expr.ctx(), z_expr), 10); - // Handle `char` and `short` types separately, because clang - // adds non-standard `Ui8` and `Ui16` suffixes respectively. 
- result = ast.CreateAdjustedIntLit(val); - } break; - - case Z3_FLOATING_POINT_SORT: { - auto size{GetZ3SortSize(z_sort)}; - const llvm::fltSemantics *semantics{nullptr}; - switch (size) { - case 16: - semantics = &llvm::APFloat::IEEEhalf(); - break; - case 32: - semantics = &llvm::APFloat::IEEEsingle(); - break; - case 64: - semantics = &llvm::APFloat::IEEEdouble(); - break; - case 128: - semantics = &llvm::APFloat::IEEEquad(); - break; - default: - THROW() << "Unknown Z3 floating-point sort!"; - break; - } - z3::expr bv(*z_ctx, Z3_mk_fpa_to_ieee_bv(*z_ctx, z_expr)); - auto bits{Z3_get_numeral_string(*z_ctx, bv.simplify())}; - CHECK(std::strlen(bits) > 0) - << "Failed to convert IEEE bitvector to string!"; - llvm::APInt api(size, bits, /*radix=*/10U); - result = ast.CreateFPLit(llvm::APFloat(*semantics, api)); - } break; - - default: - LOG(FATAL) << "Unknown Z3 sort: " << z_sort; - break; - } - - return result; -} - -// Retrieves or creates`z3::expr`s from `clang::Expr`. -z3::expr Z3ConvVisitor::GetOrCreateZ3Expr(clang::Expr *c_expr) { - if (!z_expr_map.count(c_expr)) { - TraverseStmt(c_expr); - } - - return GetZ3Expr(c_expr); -} - -z3::func_decl Z3ConvVisitor::GetOrCreateZ3Decl(clang::ValueDecl *c_decl) { - if (!z_decl_map.count(c_decl)) { - TraverseDecl(c_decl); - } - - auto z_decl{GetZ3Decl(c_decl)}; - if (!c_decl_map.count(z_decl.id())) { - InsertCValDecl(z_decl, c_decl); - } - - return z_decl; -} - -// Retrieves or creates `clang::Expr` from `z3::expr`. -clang::Expr *Z3ConvVisitor::GetOrCreateCExpr(z3::expr z_expr) { - if (!c_expr_map.count(z_expr.hash())) { - VisitZ3Expr(z_expr); - } - return GetCExpr(z_expr); -} - -// Retrieves or creates `clang::ValueDecl` from `z3::func_decl`. 
-clang::ValueDecl *Z3ConvVisitor::GetOrCreateCValDecl(z3::func_decl z_decl) { - if (!c_decl_map.count(z_decl.id())) { - VisitZ3Decl(z_decl); - } - return GetCValDecl(z_decl); -} - -bool Z3ConvVisitor::VisitVarDecl(clang::VarDecl *c_var) { - auto c_name{c_var->getNameAsString()}; - DLOG(INFO) << "VisitVarDecl: " << c_name; - if (z_decl_map.count(c_var)) { - DLOG(INFO) << "Re-declaration of " << c_name << "; Returning."; - return true; - } - - auto z_name{CreateZ3DeclName(c_var)}; - auto z_sort{GetZ3Sort(c_var->getType())}; - auto z_const{z_ctx->constant(z_name.c_str(), z_sort)}; - - InsertZ3Decl(c_var, z_const.decl()); - - return true; -} - -bool Z3ConvVisitor::VisitFieldDecl(clang::FieldDecl *c_field) { - auto c_name{c_field->getNameAsString()}; - DLOG(INFO) << "VisitFieldDecl: " << c_name; - if (z_decl_map.count(c_field)) { - DLOG(INFO) << "Re-declaration of " << c_name << "; Returning."; - return true; - } - - auto z_name{CreateZ3DeclName(c_field->getParent()) + "_" + c_name}; - auto z_sort{GetZ3Sort(c_field->getType())}; - auto z_const{z_ctx->constant(z_name.c_str(), z_sort)}; - - InsertZ3Decl(c_field, z_const.decl()); - - return true; -} - -bool Z3ConvVisitor::VisitFunctionDecl(clang::FunctionDecl *c_func) { - auto c_name{c_func->getNameAsString()}; - DLOG(INFO) << "VisitFunctionDecl: " << c_name; - if (z_decl_map.count(c_func)) { - DLOG(INFO) << "Re-declaration of " << c_name << "; Returning."; - return true; - } - - z3::sort_vector z_domains(*z_ctx); - for (auto c_param : c_func->parameters()) { - z_domains.push_back(GetZ3Sort(c_param->getType())); - } - - auto z_range{GetZ3Sort(c_func->getReturnType())}; - auto z_func{z_ctx->function(c_name.c_str(), z_domains, z_range)}; - - InsertZ3Decl(c_func, z_func); - - return true; -} - -z3::expr Z3ConvVisitor::CreateZ3BitwiseCast(z3::expr expr, size_t src, - size_t dst, bool sign) { - if (expr.is_bool()) { - expr = Z3BoolToBVCast(expr); - } - - // CHECK(expr.is_bv()) << "z3::expr is not a bitvector!"; - 
CHECK_EQ(GetZ3SortSize(expr), src); - - int64_t diff = dst - src; - // extend - if (diff > 0) { - return sign ? z3::sext(expr, diff) : z3::zext(expr, diff); - } - // truncate - if (diff < 0) { - return expr.extract(dst - 1, 0); - } - // nothing - return expr; -} - -template -bool Z3ConvVisitor::HandleCastExpr(T *c_cast) { - CHECK(clang::isa(c_cast) || - clang::isa(c_cast)); - // C exprs - auto c_sub{c_cast->getSubExpr()}; - // C types - auto c_src_ty{c_sub->getType()}; - auto c_dst_ty{c_cast->getType()}; - // C type sizes - auto src_ty_size{c_ctx->getTypeSize(c_src_ty)}; - auto dst_ty_size{c_ctx->getTypeSize(c_dst_ty)}; - // Z3 exprs - auto z_sub{GetZ3Expr(c_sub)}; - auto z_cast{z_sub}; - // Z3 sorts - auto z_src_sort{z_sub.get_sort()}; - auto z_dst_sort{z_ctx->bv_sort(dst_ty_size)}; - switch (c_cast->getCastKind()) { - case clang::CastKind::CK_IntegralCast: - z_cast = CreateZ3BitwiseCast(z_sub, src_ty_size, dst_ty_size, - c_src_ty->isSignedIntegerType()); - break; - - case clang::CastKind::CK_PointerToIntegral: - z_cast = z_ctx->function("PtrToInt", z_src_sort, z_dst_sort)(z_sub); - break; - - case clang::CastKind::CK_NullToPointer: - case clang::CastKind::CK_IntegralToPointer: { - auto c_dst_ty_ptr{reinterpret_cast(c_dst_ty.getAsOpaquePtr())}; - auto z_dst_ty_ptr{z_ctx->bv_val(c_dst_ty_ptr, 8 * sizeof(void *))}; - z_cast = z_ctx->function("IntToPtr", z_dst_ty_ptr.get_sort(), z_src_sort, - z_dst_sort)(z_dst_ty_ptr, z_sub); - } break; - - case clang::CastKind::CK_BitCast: { - auto c_dst_ty_ptr{reinterpret_cast(c_dst_ty.getAsOpaquePtr())}; - auto z_dst_ty_ptr{z_ctx->bv_val(c_dst_ty_ptr, 8 * sizeof(void *))}; - auto z_dst_sort{z_ctx->bv_sort(dst_ty_size)}; - z_cast = z_ctx->function("BitCast", z_dst_ty_ptr.get_sort(), z_src_sort, - z_dst_sort)(z_dst_ty_ptr, z_sub); - } break; - - case clang::CastKind::CK_ArrayToPointerDecay: { - CHECK(z_sub.is_bv()) << "Pointer cast operand is not a bit-vector"; - auto z_ptr_sort{GetZ3Sort(c_cast->getType())}; - auto 
z_arr_sort{z_sub.get_sort()}; - z_cast = z_ctx->function("PtrDecay", z_arr_sort, z_ptr_sort)(z_sub); - } break; - - case clang::CastKind::CK_NoOp: - case clang::CastKind::CK_LValueToRValue: - case clang::CastKind::CK_FunctionToPointerDecay: - // case clang::CastKind::CK_ArrayToPointerDecay: - break; - - default: - THROW() << "Unsupported cast type: " << c_cast->getCastKindName(); - break; - } - // Save - InsertZ3Expr(c_cast, z_cast); - - return true; -} - -bool Z3ConvVisitor::VisitCStyleCastExpr(clang::CStyleCastExpr *c_cast) { - DLOG(INFO) << "VisitCStyleCastExpr"; - if (z_expr_map.count(c_cast)) { - return true; - } - return HandleCastExpr(c_cast); -} - -bool Z3ConvVisitor::VisitImplicitCastExpr(clang::ImplicitCastExpr *c_cast) { - DLOG(INFO) << "VisitImplicitCastExpr"; - if (z_expr_map.count(c_cast)) { - return true; - } - return HandleCastExpr(c_cast); -} - -bool Z3ConvVisitor::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *sub) { - DLOG(INFO) << "VisitArraySubscriptExpr"; - if (z_expr_map.count(sub)) { - return true; - } - // Get base - auto z_base{GetZ3Expr(sub->getBase())}; - auto base_sort{z_base.get_sort()}; - CHECK(base_sort.is_bv()) << "Invalid Z3 sort for base expression"; - // Get index - auto z_idx{GetZ3Expr(sub->getIdx())}; - auto idx_sort{z_idx.get_sort()}; - CHECK(idx_sort.is_bv()) << "Invalid Z3 sort for index expression"; - // Get result - auto elm_sort{GetZ3Sort(sub->getType())}; - // Create a z_function - auto z_arr_sub{z_ctx->function("ArraySub", base_sort, idx_sort, elm_sort)}; - // Create a z3 expression - InsertZ3Expr(sub, z_arr_sub(z_base, z_idx)); - // Done - return true; -} - -bool Z3ConvVisitor::VisitMemberExpr(clang::MemberExpr *expr) { - DLOG(INFO) << "VisitMemberExpr"; - if (z_expr_map.count(expr)) { - return true; - } - - auto z_mem{GetOrCreateZ3Decl(expr->getMemberDecl())()}; - auto z_base{GetZ3Expr(expr->getBase())}; - auto z_mem_expr{z_ctx->function("Member", z_base.get_sort(), z_mem.get_sort(), - z_mem.get_sort())}; - - 
InsertZ3Expr(expr, z_mem_expr(z_base, z_mem)); - - return true; -} - -bool Z3ConvVisitor::VisitCallExpr(clang::CallExpr *c_call) { - DLOG(INFO) << "VisitCallExpr"; - if (z_expr_map.count(c_call)) { - return true; - } - z3::expr_vector z_args(*z_ctx); - // Get call id - z_args.push_back(z_ctx->bv_val(GetHash(*c_ctx, c_call), /*sz=*/64U)); - // Get callee - auto z_callee{GetZ3Expr(c_call->getCallee())}; - z_args.push_back(z_callee); - // Get call arguments - for (auto c_arg : c_call->arguments()) { - z_args.push_back(GetZ3Expr(c_arg)); - } - // Build the z3 call - z3::sort_vector z_domain(*z_ctx); - for (auto z_arg : z_args) { - z_domain.push_back(z_arg.get_sort()); - } - auto z_range{z_callee.is_array() ? z_callee.get_sort().array_range() - : z_callee.get_sort()}; - auto z_call{z_ctx->function("Call", z_domain, z_range)}; - // Insert the call - InsertZ3Expr(c_call, z_call(z_args)); - - return true; -} - -// Translates clang unary operators expressions to Z3 equivalents. -bool Z3ConvVisitor::VisitParenExpr(clang::ParenExpr *parens) { - DLOG(INFO) << "VisitParenExpr"; - if (z_expr_map.count(parens)) { - return true; - } - - InsertZ3Expr(parens, GetZ3Expr(parens->getSubExpr())); - - return true; -} - -// Translates clang unary operators expressions to Z3 equivalents. 
-bool Z3ConvVisitor::VisitUnaryOperator(clang::UnaryOperator *c_op) { - DLOG(INFO) << "VisitUnaryOperator: " - << c_op->getOpcodeStr(c_op->getOpcode()).str(); - if (z_expr_map.count(c_op)) { - return true; - } - // Get operand - auto operand{GetZ3Expr(c_op->getSubExpr())}; - // Conditionally cast operands to a bitvector - auto CondBoolToBVCast{[this, &operand]() { - if (operand.is_bool()) { - operand = Z3BoolToBVCast(operand); - } - }}; - // Create z3 unary op - switch (c_op->getOpcode()) { - case clang::UO_LNot: - InsertZ3Expr(c_op, !Z3BoolCast(operand)); - break; - - case clang::UO_Minus: - InsertZ3Expr(c_op, -operand); - break; - - case clang::UO_Not: - CondBoolToBVCast(); - InsertZ3Expr(c_op, ~operand); - break; - - case clang::UO_AddrOf: { - auto ptr_sort{GetZ3Sort(c_op->getType())}; - auto z_addrof{z_ctx->function("AddrOf", operand.get_sort(), ptr_sort)}; - InsertZ3Expr(c_op, z_addrof(operand)); - } break; - - case clang::UO_Deref: { - auto elm_sort{GetZ3Sort(c_op->getType())}; - auto z_deref{z_ctx->function("Deref", operand.get_sort(), elm_sort)}; - InsertZ3Expr(c_op, z_deref(operand)); - } break; - - default: - THROW() << "Unknown clang::UnaryOperator operation: " - << c_op->getOpcodeStr(c_op->getOpcode()).str(); - break; - } - return true; -} - -// Translates clang binary operators expressions to Z3 equivalents. 
-bool Z3ConvVisitor::VisitBinaryOperator(clang::BinaryOperator *c_op) { - DLOG(INFO) << "VisitBinaryOperator: " << c_op->getOpcodeStr().str(); - if (z_expr_map.count(c_op)) { - return true; - } - // Get operands - auto lhs{GetZ3Expr(c_op->getLHS())}; - auto rhs{GetZ3Expr(c_op->getRHS())}; - // Conditionally cast operands match size to the wider one - auto CondSizeCast{[this, &lhs, &rhs] { - auto lhs_size{GetZ3SortSize(lhs)}; - auto rhs_size{GetZ3SortSize(rhs)}; - if (lhs_size < rhs_size) { - lhs = CreateZ3BitwiseCast(lhs, lhs_size, rhs_size, /*sign=*/true); - } else if (lhs_size > rhs_size) { - rhs = CreateZ3BitwiseCast(rhs, rhs_size, lhs_size, /*sign=*/true); - } - }}; - // Conditionally cast operands to bool - auto CondBoolCast{[this, &lhs, &rhs]() { - if (lhs.is_bool() || rhs.is_bool()) { - lhs = Z3BoolCast(lhs); - rhs = Z3BoolCast(rhs); - } - }}; - // Conditionally cast operands to a bitvector - auto CondBoolToBVCast{[this, &lhs, &rhs]() { - if (lhs.is_bool()) { - CHECK(rhs.is_bv() || rhs.is_bool()); - lhs = Z3BoolToBVCast(lhs); - } - if (rhs.is_bool()) { - CHECK(lhs.is_bv()); - rhs = Z3BoolToBVCast(rhs); - } - }}; - // Create z3 binary op - switch (c_op->getOpcode()) { - case clang::BO_LAnd: - CondBoolCast(); - InsertZ3Expr(c_op, lhs.is_bv() && rhs.is_bv() ? lhs & rhs : lhs && rhs); - break; - - case clang::BO_LOr: - CondBoolCast(); - InsertZ3Expr(c_op, lhs.is_bv() && rhs.is_bv() ? 
lhs | rhs : lhs || rhs); - break; - - case clang::BO_EQ: { - CondBoolCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs == rhs); - } break; - - case clang::BO_NE: - CondBoolCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs != rhs); - break; - - case clang::BO_GE: - CondSizeCast(); - InsertZ3Expr(c_op, lhs >= rhs); - break; - - case clang::BO_GT: - CondSizeCast(); - InsertZ3Expr(c_op, lhs > rhs); - break; - - case clang::BO_LE: - CondSizeCast(); - InsertZ3Expr(c_op, lhs <= rhs); - break; - - case clang::BO_LT: - CondSizeCast(); - InsertZ3Expr(c_op, lhs < rhs); - break; - - case clang::BO_Rem: - CondSizeCast(); - InsertZ3Expr(c_op, z3::srem(lhs, rhs)); - break; - - case clang::BO_Add: - CondSizeCast(); - InsertZ3Expr(c_op, lhs + rhs); - break; - - case clang::BO_Sub: - CondSizeCast(); - InsertZ3Expr(c_op, lhs - rhs); - break; - - case clang::BO_Mul: - CondSizeCast(); - InsertZ3Expr(c_op, lhs * rhs); - break; - - case clang::BO_Div: - CondSizeCast(); - InsertZ3Expr(c_op, lhs / rhs); - break; - - case clang::BO_And: - CondBoolToBVCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs & rhs); - break; - - case clang::BO_Or: - CondBoolToBVCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs | rhs); - break; - - case clang::BO_Xor: - CondSizeCast(); - InsertZ3Expr(c_op, lhs ^ rhs); - break; - - case clang::BO_Shr: - CondBoolToBVCast(); - CondSizeCast(); - InsertZ3Expr(c_op, c_op->getLHS()->getType()->isSignedIntegerType() - ? 
z3::ashr(lhs, rhs) - : z3::lshr(lhs, rhs)); - break; - - case clang::BO_Shl: - CondSizeCast(); - InsertZ3Expr(c_op, z3::shl(lhs, rhs)); - break; - - default: - THROW() << "Unknown clang::BinaryOperator operation: " - << c_op->getOpcodeStr().str(); - break; - } - return true; -} - -bool Z3ConvVisitor::VisitConditionalOperator(clang::ConditionalOperator *c_op) { - DLOG(INFO) << "VisitConditionalOperator"; - if (z_expr_map.count(c_op)) { - return true; - } - - auto z_cond{Z3BoolCast(GetZ3Expr(c_op->getCond()))}; - auto z_then{GetZ3Expr(c_op->getTrueExpr())}; - auto z_else{GetZ3Expr(c_op->getFalseExpr())}; - - auto z_then_size{GetZ3SortSize(z_then)}; - auto z_else_size{GetZ3SortSize(z_else)}; - - if (z_then_size > z_else_size) { - z_else = - CreateZ3BitwiseCast(z_else, z_else_size, z_then_size, /*sign=*/true); - } else if (z_then_size < z_else_size) { - z_then = - CreateZ3BitwiseCast(z_then, z_then_size, z_else_size, /*sign=*/true); - } - - InsertZ3Expr(c_op, z3::ite(z_cond, z_then, z_else)); - - return true; -} - -// Translates clang variable references to Z3 constants. -bool Z3ConvVisitor::VisitDeclRefExpr(clang::DeclRefExpr *c_ref) { - auto c_ref_decl{c_ref->getDecl()}; - auto c_ref_name{c_ref_decl->getNameAsString()}; - DLOG(INFO) << "VisitDeclRefExpr: " << c_ref_name; - if (z_expr_map.count(c_ref)) { - return true; - } - - auto z_decl{GetOrCreateZ3Decl(c_ref_decl)}; - auto z_ref{z_decl.is_const() ? z_decl() : z3::as_array(z_decl)}; - - InsertZ3Expr(c_ref, z_ref); - - return true; -} - -// Translates clang character literals references to Z3 numeral values. 
-bool Z3ConvVisitor::VisitCharacterLiteral(clang::CharacterLiteral *c_lit) { - auto c_val{c_lit->getValue()}; - DLOG(INFO) << "VisitCharacterLiteral: " << c_val; - if (z_expr_map.count(c_lit)) { - return true; - } - - auto z_sort{GetZ3Sort(c_lit->getType())}; - auto z_val{z_ctx->num_val(c_val, z_sort)}; - InsertZ3Expr(c_lit, z_val); - - return true; -} - -// Translates clang integer literal references to Z3 numeral values. -bool Z3ConvVisitor::VisitIntegerLiteral(clang::IntegerLiteral *c_lit) { - auto c_val{c_lit->getValue().getLimitedValue()}; - DLOG(INFO) << "VisitIntegerLiteral: " << c_val; - if (z_expr_map.count(c_lit)) { - return true; - } - - auto z_sort{GetZ3Sort(c_lit->getType())}; - auto z_val{z_sort.is_bool() ? z_ctx->bool_val(c_val != 0) - : z_ctx->num_val(c_val, z_sort)}; - InsertZ3Expr(c_lit, z_val); - - return true; -} - -// Translates clang floating point literal references to Z3 numeral values. -bool Z3ConvVisitor::VisitFloatingLiteral(clang::FloatingLiteral *lit) { - auto api{lit->getValue().bitcastToAPInt()}; - DLOG(INFO) << "VisitFloatingLiteral: " << api.bitsToDouble(); - if (z_expr_map.count(lit)) { - return true; - } - - auto size{api.getBitWidth()}; - llvm::SmallString<64U> bits; - api.toString(bits, /*Radix=*/10U, /*Signed=*/false); - auto sort{GetZ3Sort(lit->getType())}; - auto bv{z_ctx->bv_val(bits.c_str(), size)}; - auto fpa{z3::to_expr(*z_ctx, Z3_mk_fpa_to_fp_bv(*z_ctx, bv, sort))}; - - InsertZ3Expr(lit, fpa); - - return true; -} - -void Z3ConvVisitor::VisitZ3Expr(z3::expr z_expr) { - auto z_decl{z_expr.decl()}; - CHECK(z_expr.is_app()) << "Unexpected Z3 operation: " << z_decl.name(); - // Handle arguments first - for (auto i{0U}; i < z_expr.num_args(); ++i) { - GetOrCreateCExpr(z_expr.arg(i)); - } - // TODO(msurovic): Rework this into a visitor based on - // z_expr.decl().decl_kind() - // Handle bitvector concats - if (z_decl.decl_kind() == Z3_OP_CONCAT) { - InsertCExpr(z_expr, HandleZ3Concat(z_expr)); - return; - } - // Handle 
uninterpreted functions - if (z_decl.decl_kind() == Z3_OP_UNINTERPRETED) { - InsertCExpr(z_expr, HandleZ3Uninterpreted(z_expr)); - return; - } - // Handle the rest - switch (z_decl.arity()) { - case 0: - VisitConstant(z_expr); - break; - - case 1: - VisitUnaryApp(z_expr); - break; - - case 2: - VisitBinaryApp(z_expr); - break; - - case 3: - VisitTernaryApp(z_expr); - break; - - default: - LOG(FATAL) << "Unexpected Z3 operation: " << z_decl.name(); - break; - } -} - -void Z3ConvVisitor::VisitConstant(z3::expr z_const) { - DLOG(INFO) << "VisitConstant: " << z_const; - CHECK(z_const.is_const()) << "Z3 expression is not a constant!"; - // Create C literals and variable references - clang::Expr *c_expr{nullptr}; - switch (z_const.decl().decl_kind()) { - // Boolean literals - case Z3_OP_TRUE: - case Z3_OP_FALSE: - // Arithmetic numerals - case Z3_OP_ANUM: - // Bitvector numerals - case Z3_OP_BNUM: - // Floating-point numerals - case Z3_OP_FPA_NUM: - case Z3_OP_FPA_PLUS_INF: - case Z3_OP_FPA_MINUS_INF: - case Z3_OP_FPA_NAN: - case Z3_OP_FPA_PLUS_ZERO: - case Z3_OP_FPA_MINUS_ZERO: - c_expr = CreateLiteralExpr(z_const); - break; - // Functions-as-array expressions - case Z3_OP_AS_ARRAY: { - z3::func_decl z_decl(*z_ctx, Z3_get_as_array_func_decl(*z_ctx, z_const)); - c_expr = ast.CreateDeclRef(GetOrCreateCValDecl(z_decl)); - } break; - // Internal constants handled by parent Z3 exprs - case Z3_OP_INTERNAL: - break; - // Unknowns - default: - THROW() << "Unknown Z3 constant: " << z_const; - break; - } - InsertCExpr(z_const, c_expr); -} - -clang::Expr *Z3ConvVisitor::HandleZ3Concat(z3::expr z_op) { - auto lhs{GetCExpr(z_op.arg(0U))}; - for (auto i{1U}; i < z_op.num_args(); ++i) { - // Given a `(concat l r)` we generate `((t)l << w) | r` where - // * `w` is the bitwidth of `r` - // - // * `t` is the smallest integer type that can fit the result - // of `(concat l r)` - auto rhs{GetCExpr(z_op.arg(i))}; - auto res_ty{ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - 
/*sign=*/0U)}; - if (!IsSignExt(z_op)) { - auto cast{ast.CreateCStyleCast(res_ty, lhs)}; - auto shl_val{ - ast.CreateIntLit(llvm::APInt(32U, GetZ3SortSize(z_op.arg(1U))))}; - auto shl{ast.CreateShl(cast, shl_val)}; - lhs = ast.CreateOr(shl, rhs); - } else { - lhs = ast.CreateCStyleCast(res_ty, rhs); - } - } - return lhs; -} - -clang::Expr *Z3ConvVisitor::HandleZ3Uninterpreted(z3::expr z_op) { - clang::Expr *c_op{nullptr}; - // Constants - if (z_op.is_const()) { - auto c_decl{GetOrCreateCValDecl(z_op.decl())}; - if (clang::isa(c_decl)) { - c_op = ast.CreateNull(); - } else { - c_op = ast.CreateDeclRef(c_decl); - } - return c_op; - } - // Functions - auto z_decl{z_op.decl()}; - auto z_func_name{z_decl.name().str()}; - auto lhs{[this, &z_op] { return GetCExpr(z_op.arg(0U)); }}; - auto rhs{[this, &z_op] { return GetCExpr(z_op.arg(1U)); }}; - // Get result type for casts - auto GetTypeFromOpaquePtrLiteral{[&lhs] { - auto c_lit{clang::cast(lhs())}; - auto t_dst_opaque_ptr_val{c_lit->getValue().getLimitedValue()}; - auto t_dst_opaque_ptr{reinterpret_cast(t_dst_opaque_ptr_val)}; - return clang::QualType::getFromOpaquePtr(t_dst_opaque_ptr); - }}; - if (z_func_name == "AddrOf") { - c_op = ast.CreateAddrOf(lhs()); - } else if (z_func_name == "Deref") { - c_op = ast.CreateDeref(lhs()); - } else if (z_func_name == "PtrToInt") { - auto s_size{GetZ3SortSize(z_op)}; - auto t_op{ast.GetLeastIntTypeForBitWidth(s_size, /*sign=*/0U)}; - c_op = ast.CreateCStyleCast(t_op, lhs()); - } else if (z_func_name == "PtrDecay") { - c_op = lhs(); - } else if (z_func_name == "ArraySub") { - c_op = ast.CreateArraySub(lhs(), rhs()); - } else if (z_func_name == "Member") { - auto mem{GetOrCreateCValDecl(z_op.arg(1U).decl())}; - auto field{clang::dyn_cast(mem)}; - CHECK(field != nullptr) << "Operand is not a clang::FieldDecl"; - c_op = ast.CreateDot(lhs(), field); - } else if (z_func_name == "IntToPtr" || z_func_name == "BitCast") { - c_op = ast.CreateCStyleCast(GetTypeFromOpaquePtrLiteral(), rhs()); 
- } else if (z_func_name == "Call") { - auto c_callee{rhs()}; - std::vector c_args; - for (auto i{2U}; i < z_op.num_args(); ++i) { - c_args.push_back(GetCExpr(z_op.arg(i))); - } - c_op = ast.CreateCall(c_callee, c_args); - } else { - LOG(FATAL) << "Unknown Z3 uninterpreted function: " << z_func_name; - } - return c_op; -} - -void Z3ConvVisitor::VisitUnaryApp(z3::expr z_op) { - DLOG(INFO) << "VisitUnaryApp: " << z_op; - CHECK(z_op.is_app() && z_op.decl().arity() == 1) - << "Z3 expression is not a unary operator!"; - // Get operand - auto c_sub{GetCExpr(z_op.arg(0U))}; - // Get z3 function declaration - auto z_decl{z_op.decl()}; - // Create C unary operator - clang::Expr *c_op{nullptr}; - switch (z_decl.decl_kind()) { - case Z3_OP_NOT: - c_op = ast.CreateLNot(c_sub); - break; - - case Z3_OP_BNOT: - c_op = ast.CreateNot(c_sub); - break; - // Given a `(extract hi lo o)` we generate `(o >> lo & m)` where: - // - // * `o` is the operand from which we extract a bit sequence - // - // * `hi` is the upper bound of the extracted sequence - // - // * `lo` is the lower bound of the extracted sequence - // - // * `m` is a bitmask integer literal - case Z3_OP_EXTRACT: { - if (z_op.lo() != 0U) { - auto shr_val{ast.CreateIntLit(llvm::APInt(32U, z_op.lo()))}; - auto shr{ast.CreateShr(c_sub, shr_val)}; - auto mask_val{ast.CreateIntLit( - llvm::APInt::getAllOnesValue(GetZ3SortSize(z_op)))}; - c_op = ast.CreateAnd(shr, mask_val); - } else { - c_op = ast.CreateCStyleCast( - ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - /*sign=*/0), - c_sub); - } - - } break; - - case Z3_OP_ZERO_EXT: - c_op = ast.CreateCStyleCast( - ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - /*sign=*/0), - c_sub); - break; - - case Z3_OP_SIGN_EXT: - c_op = ast.CreateCStyleCast( - ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - /*sign=*/1), - c_sub); - break; - - default: - LOG(FATAL) << "Unknown Z3 unary operator: " << z_decl.name(); - break; - } - // Save - InsertCExpr(z_op, c_op); -} - -void 
Z3ConvVisitor::VisitBinaryApp(z3::expr z_op) { - DLOG(INFO) << "VisitBinaryApp: " << z_op; - CHECK(z_op.is_app() && z_op.decl().arity() == 2) - << "Z3 expression is not a binary operator!"; - // Get operands - auto lhs{GetCExpr(z_op.arg(0U))}; - auto rhs{GetCExpr(z_op.arg(1U))}; - // Get z3 function declaration - auto z_decl{z_op.decl()}; - // Create C binary operator - clang::Expr *c_op{nullptr}; - switch (z_decl.decl_kind()) { - // `&&` in z3 can be n-ary, so we create a tree of C binary `&&`. - case Z3_OP_AND: { - c_op = lhs; - for (auto i{1U}; i < z_op.num_args(); ++i) { - rhs = GetCExpr(z_op.arg(i)); - c_op = ast.CreateLAnd(c_op, rhs); - } - } break; - // `||` in z3 can be n-ary, so we create a tree of C binary `||`. - case Z3_OP_OR: { - c_op = lhs; - for (auto i{1U}; i < z_op.num_args(); ++i) { - rhs = GetCExpr(z_op.arg(i)); - c_op = ast.CreateLOr(c_op, rhs); - } - } break; - - case Z3_OP_EQ: - c_op = ast.CreateEQ(lhs, rhs); - break; - - case Z3_OP_SLEQ: - case Z3_OP_ULEQ: - case Z3_OP_FPA_LE: - c_op = ast.CreateLE(lhs, rhs); - break; - - case Z3_OP_SLT: - case Z3_OP_ULT: - case Z3_OP_FPA_LT: - c_op = ast.CreateLT(lhs, rhs); - break; - - case Z3_OP_SGT: - case Z3_OP_UGT: - case Z3_OP_FPA_GT: - c_op = ast.CreateGT(lhs, rhs); - break; - - case Z3_OP_BADD: - c_op = ast.CreateAdd(lhs, rhs); - break; - - case Z3_OP_BLSHR: - c_op = ast.CreateShr(lhs, rhs); - break; - - case Z3_OP_BASHR: { - auto size{c_ctx->getTypeSize(lhs->getType())}; - auto type{ast.GetLeastIntTypeForBitWidth(size, /*sign=*/1U)}; - auto cast{ast.CreateCStyleCast(type, lhs)}; - c_op = ast.CreateShr(cast, rhs); - } break; - - case Z3_OP_BSHL: - c_op = ast.CreateShl(lhs, rhs); - break; - - case Z3_OP_BAND: - c_op = ast.CreateAnd(lhs, rhs); - break; - - case Z3_OP_BOR: - c_op = ast.CreateOr(lhs, rhs); - break; - - case Z3_OP_BXOR: - c_op = ast.CreateXor(lhs, rhs); - break; - - case Z3_OP_BMUL: - c_op = ast.CreateMul(lhs, rhs); - break; - - case Z3_OP_BSDIV: - case Z3_OP_BSDIV_I: - c_op = 
ast.CreateDiv(lhs, rhs); - break; - - case Z3_OP_BSREM: - case Z3_OP_BSREM_I: - c_op = ast.CreateRem(lhs, rhs); - break; - - case Z3_OP_DISTINCT: - c_op = ast.CreateNE(lhs, rhs); - break; - - // Unknowns - default: - LOG(FATAL) << "Unknown Z3 binary operator: " << z_decl.name(); - break; - } - // Save - InsertCExpr(z_op, c_op); -} - -void Z3ConvVisitor::VisitTernaryApp(z3::expr z_op) { - DLOG(INFO) << "VisitTernaryApp: " << z_op; - CHECK(z_op.is_app() && z_op.decl().arity() == 3) - << "Z3 expression is not a ternary operator!"; - // Create C binary operator - clang::Expr *c_op{nullptr}; - auto z_decl{z_op.decl()}; - switch (z_decl.decl_kind()) { - case Z3_OP_ITE: { - auto z_cond{z_op.arg(0U)}; - auto z_then{z_op.arg(1U)}; - auto z_else{z_op.arg(2U)}; - uint64_t then_val{0U}; - uint64_t else_val{1U}; - if (z_then.is_numeral_u64(then_val) && z_else.is_numeral_u64(else_val) && - then_val == 1U && else_val == 0U) { - c_op = GetCExpr(z_cond); - } else { - c_op = ast.CreateConditional(GetCExpr(z_cond), GetCExpr(z_then), - GetCExpr(z_else)); - } - } break; - // Unknowns - default: - LOG(FATAL) << "Unknown Z3 ternary operator: " << z_decl.name(); - break; - } - // Save - InsertCExpr(z_op, c_op); -} - -void Z3ConvVisitor::VisitZ3Decl(z3::func_decl z_decl) { - THROW() << "Unimplemented Z3 declaration visitor!"; - // switch (z_decl.decl_kind()) { - // case Z3_OP_UNINTERPRETED: - // if (!z_decl.is_const()) { - - // } else { - // LOG(FATAL) << "Unimplemented Z3 declaration!"; - // } - // break; - - // default: - // LOG(FATAL) << "Unknown Z3 function declaration: " << z_decl.name(); - // break; - // } -} - -} // namespace rellic \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 933278ff..8d6f590b 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -35,7 +35,6 @@ set(AST_HEADERS "${include_dir}/AST/MaterializeConds.h" "${include_dir}/AST/NestedCondProp.h" "${include_dir}/AST/NestedScopeCombine.h" - "${include_dir}/AST/NormalizeCond.h" 
"${include_dir}/AST/ReachBasedRefine.h" "${include_dir}/AST/StructFieldRenamer.h" "${include_dir}/AST/StructGenerator.h" @@ -43,7 +42,6 @@ set(AST_HEADERS "${include_dir}/AST/TransformVisitor.h" "${include_dir}/AST/Util.h" "${include_dir}/AST/Z3CondSimplify.h" - "${include_dir}/AST/Z3ConvVisitor.h" ) set(BC_HEADERS @@ -76,10 +74,8 @@ set(AST_SOURCES AST/MaterializeConds.cpp AST/NestedCondProp.cpp AST/NestedScopeCombine.cpp - AST/NormalizeCond.cpp AST/Util.cpp AST/Z3CondSimplify.cpp - AST/Z3ConvVisitor.cpp AST/ReachBasedRefine.cpp AST/StructFieldRenamer.cpp AST/StructGenerator.cpp diff --git a/tools/repl/Repl.cpp b/tools/repl/Repl.cpp index 741e0994..e6668d5a 100644 --- a/tools/repl/Repl.cpp +++ b/tools/repl/Repl.cpp @@ -44,7 +44,6 @@ #include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" -#include "rellic/AST/NormalizeCond.h" #include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" @@ -161,8 +160,6 @@ static std::unique_ptr CreatePass(const std::string& name) { return std::make_unique(*provenance, *ast_unit); } else if (name == "nsc") { return std::make_unique(*provenance, *ast_unit); - } else if (name == "nc") { - return std::make_unique(*provenance, *ast_unit); } else if (name == "rbr") { return std::make_unique(*provenance, *ast_unit); } else if (name == "zcs") { diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index 518bcea4..c6411193 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -52,7 +52,6 @@ #include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" -#include "rellic/AST/NormalizeCond.h" #include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Util.h" @@ -421,9 +420,6 @@ static std::unique_ptr CreatePass( } else if (str == "nsc") { return std::make_unique(*session.Provenance, *session.Unit); - } 
else if (str == "nc") { - return std::make_unique(*session.Provenance, - *session.Unit); } else if (str == "rbr") { return std::make_unique(*session.Provenance, *session.Unit); diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index 164abd67..684d0fb8 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -8,15 +8,14 @@ add_executable(${RELLIC_UNITTEST} AST/ASTBuilder.cpp AST/StructGenerator.cpp AST/Util.cpp - AST/Z3ConvVisitor.cpp UnitTest.cpp ) target_link_libraries(${RELLIC_UNITTEST} PRIVATE - "${PROJECT_NAME}_cxx_settings" - "${PROJECT_NAME}" - z3::libz3 - doctest::doctest + "${PROJECT_NAME}_cxx_settings" + "${PROJECT_NAME}" + z3::libz3 + doctest::doctest ) target_compile_options(${RELLIC_UNITTEST} PRIVATE -fexceptions) From 3661cc0f849438876b860f20040415cdbe071a01 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 19 Aug 2022 12:46:19 +0200 Subject: [PATCH 31/57] Remove `nc` from tools --- tools/repl/Repl.cpp | 5 ++--- tools/xref/www/main.js | 5 ----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tools/repl/Repl.cpp b/tools/repl/Repl.cpp index e6668d5a..facf8ac8 100644 --- a/tools/repl/Repl.cpp +++ b/tools/repl/Repl.cpp @@ -93,8 +93,8 @@ static void SetVersion(void) { google::SetVersionString(version.str()); } -static const char* available_passes[] = {"cbr", "dse", "ec", "lr", "ncp", - "nsc", "nc", "rbr", "zcs"}; +static const char* available_passes[] = {"cbr", "dse", "ec", "lr", + "ncp", "nsc", "rbr", "zcs"}; static bool diff = false; @@ -198,7 +198,6 @@ static void do_help() { << " ec Expression combination\n" << " lr Loop refinement\n" << " mc Condition materialization\n" - << " nc Condition normalization\n" << " ncp Nested condition propagation\n" << " nsc Nested scope combination\n" << " rbr Reach-based refinement\n" diff --git a/tools/xref/www/main.js b/tools/xref/www/main.js index 848a7c34..074e87b5 100644 --- a/tools/xref/www/main.js +++ b/tools/xref/www/main.js @@ -32,10 +32,6 @@ const ec = 
{ id: "ec", label: "Expression combination" } -const nc = { - id: "nc", - label: "Condition normalization" -} const mc = { id: "mc", label: "Materialize conditions" @@ -110,7 +106,6 @@ const app = new Vue({ rbr, lr, ec, - nc, mc ], actions: [ From 0526a4a397f2958fc2e91a52d449faaf4801ea61 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 19 Aug 2022 16:10:49 +0200 Subject: [PATCH 32/57] Maybe this time `HeavySimplify` will work? --- lib/AST/ReachBasedRefine.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp index 95ab1786..533be292 100644 --- a/lib/AST/ReachBasedRefine.cpp +++ b/lib/AST/ReachBasedRefine.cpp @@ -11,6 +11,8 @@ #include #include +#include "rellic/AST/Util.h" + namespace rellic { ReachBasedRefine::ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit) @@ -44,7 +46,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } // Is the current `if` statement unreachable from all the others? - bool is_unreachable{Prove(!(cond && z3::mk_or(conds)))}; + bool is_unreachable{Prove(HeavySimplify(!(cond && z3::mk_or(conds))))}; if (!is_unreachable) { ResetChain(); @@ -54,7 +56,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { conds.push_back(cond); // Do the collected statements cover all possibilities? 
- auto is_complete{Prove(z3::mk_or(conds))}; + auto is_complete{Prove(HeavySimplify(z3::mk_or(conds)))}; if (ifs.size() <= 2 || !is_complete) { // We need to collect more statements From 7319504613fc3261727369a36ed999250cd920e5 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Mon, 22 Aug 2022 10:45:47 +0200 Subject: [PATCH 33/57] Rename "Provenance" to "DecompilationContext" --- include/rellic/AST/ASTPass.h | 10 +- include/rellic/AST/CondBasedRefine.h | 2 +- include/rellic/AST/DeadStmtElim.h | 2 +- include/rellic/AST/ExprCombine.h | 2 +- include/rellic/AST/GenerateAST.h | 6 +- include/rellic/AST/IRToASTVisitor.h | 4 +- include/rellic/AST/InferenceRule.h | 4 +- include/rellic/AST/LocalDeclRenamer.h | 2 +- include/rellic/AST/LoopRefine.h | 2 +- include/rellic/AST/MaterializeConds.h | 2 +- include/rellic/AST/NestedCondProp.h | 2 +- include/rellic/AST/NestedScopeCombine.h | 2 +- include/rellic/AST/ReachBasedRefine.h | 2 +- include/rellic/AST/StructFieldRenamer.h | 2 +- include/rellic/AST/TransformVisitor.h | 8 +- include/rellic/AST/Util.h | 2 +- include/rellic/AST/Z3CondSimplify.h | 2 +- lib/AST/CondBasedRefine.cpp | 17 +-- lib/AST/DeadStmtElim.cpp | 8 +- lib/AST/ExprCombine.cpp | 64 +++++------ lib/AST/GenerateAST.cpp | 143 ++++++++++++------------ lib/AST/IRToASTVisitor.cpp | 116 ++++++++++--------- lib/AST/InferenceRule.cpp | 4 +- lib/AST/LocalDeclRenamer.cpp | 8 +- lib/AST/LoopRefine.cpp | 93 ++++++++------- lib/AST/MaterializeConds.cpp | 19 ++-- lib/AST/NestedCondProp.cpp | 37 +++--- lib/AST/NestedScopeCombine.cpp | 8 +- lib/AST/ReachBasedRefine.cpp | 9 +- lib/AST/StructFieldRenamer.cpp | 6 +- lib/AST/Z3CondSimplify.cpp | 11 +- lib/Decompiler.cpp | 58 +++++----- tools/repl/Repl.cpp | 36 +++--- tools/xref/Xref.cpp | 62 +++++----- 34 files changed, 375 insertions(+), 380 deletions(-) diff --git a/include/rellic/AST/ASTPass.h b/include/rellic/AST/ASTPass.h index 54c48128..a20263b1 100644 --- a/include/rellic/AST/ASTPass.h +++ b/include/rellic/AST/ASTPass.h 
@@ -21,7 +21,7 @@ class ASTPass { std::atomic_bool stop{false}; protected: - Provenance& provenance; + DecompilationContext& dec_ctx; clang::ASTUnit& ast_unit; clang::ASTContext& ast_ctx; ASTBuilder ast; @@ -32,8 +32,8 @@ class ASTPass { virtual void StopImpl() {} public: - ASTPass(Provenance& provenance, clang::ASTUnit& ast_unit) - : provenance(provenance), + ASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit) + : dec_ctx(dec_ctx), ast_unit(ast_unit), ast_ctx(ast_unit.getASTContext()), ast(ast_unit) {} @@ -89,8 +89,8 @@ class CompositeASTPass : public ASTPass { } public: - CompositeASTPass(Provenance& provenance, clang::ASTUnit& ast_unit) - : ASTPass(provenance, ast_unit) {} + CompositeASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit) + : ASTPass(dec_ctx, ast_unit) {} std::vector>& GetPasses() { return passes; } }; } // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/CondBasedRefine.h b/include/rellic/AST/CondBasedRefine.h index c6a519f7..28905b7f 100644 --- a/include/rellic/AST/CondBasedRefine.h +++ b/include/rellic/AST/CondBasedRefine.h @@ -38,7 +38,7 @@ class CondBasedRefine : public TransformVisitor { void RunImpl() override; public: - CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit); + CondBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitCompoundStmt(clang::CompoundStmt *compound); }; diff --git a/include/rellic/AST/DeadStmtElim.h b/include/rellic/AST/DeadStmtElim.h index 9d011c30..0eec0eda 100644 --- a/include/rellic/AST/DeadStmtElim.h +++ b/include/rellic/AST/DeadStmtElim.h @@ -22,7 +22,7 @@ class DeadStmtElim : public TransformVisitor { void RunImpl() override; public: - DeadStmtElim(Provenance &provenance, clang::ASTUnit &unit); + DeadStmtElim(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitIfStmt(clang::IfStmt *ifstmt); bool VisitCompoundStmt(clang::CompoundStmt *compound); diff --git a/include/rellic/AST/ExprCombine.h 
b/include/rellic/AST/ExprCombine.h index cd4be916..1fe86dd2 100644 --- a/include/rellic/AST/ExprCombine.h +++ b/include/rellic/AST/ExprCombine.h @@ -21,7 +21,7 @@ class ExprCombine : public TransformVisitor { void RunImpl() override; public: - ExprCombine(Provenance &provenance, clang::ASTUnit &unit); + ExprCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitCStyleCastExpr(clang::CStyleCastExpr *cast); bool VisitUnaryOperator(clang::UnaryOperator *op); diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index de1fa626..e976292c 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -33,7 +33,7 @@ class GenerateAST : public llvm::AnalysisInfoMixin { clang::ASTContext *ast_ctx; rellic::IRToASTVisitor ast_gen; rellic::ASTBuilder ast; - Provenance &provenance; + DecompilationContext &dec_ctx; bool reaching_conds_changed{true}; std::unordered_map block_stmts; std::unordered_map region_stmts; @@ -68,12 +68,12 @@ class GenerateAST : public llvm::AnalysisInfoMixin { public: using Result = llvm::PreservedAnalyses; - GenerateAST(Provenance &provenance, clang::ASTUnit &unit); + GenerateAST(DecompilationContext &dec_ctx, clang::ASTUnit &unit); Result run(llvm::Function &F, llvm::FunctionAnalysisManager &FAM); Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM); - static void run(llvm::Module &M, Provenance &provenance, + static void run(llvm::Module &M, DecompilationContext &dec_ctx, clang::ASTUnit &unit); }; diff --git a/include/rellic/AST/IRToASTVisitor.h b/include/rellic/AST/IRToASTVisitor.h index 3e8436af..5722a0a5 100644 --- a/include/rellic/AST/IRToASTVisitor.h +++ b/include/rellic/AST/IRToASTVisitor.h @@ -32,12 +32,12 @@ class IRToASTVisitor { ASTBuilder ast; - Provenance &provenance; + DecompilationContext &dec_ctx; void VisitArgument(llvm::Argument &arg); public: - IRToASTVisitor(clang::ASTUnit &unit, Provenance &provenance); + IRToASTVisitor(clang::ASTUnit &unit, 
DecompilationContext &dec_ctx); clang::Expr *CreateOperandExpr(llvm::Use &val); clang::Expr *CreateConstantExpr(llvm::Constant *constant); diff --git a/include/rellic/AST/InferenceRule.h b/include/rellic/AST/InferenceRule.h index 73124fd2..a1f89c30 100644 --- a/include/rellic/AST/InferenceRule.h +++ b/include/rellic/AST/InferenceRule.h @@ -34,13 +34,13 @@ class InferenceRule : public clang::ast_matchers::MatchFinder::MatchCallback { return cond; } - virtual clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + virtual clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) = 0; }; clang::Stmt *ApplyFirstMatchingRule( - Provenance &provenance, clang::ASTUnit &unit, clang::Stmt *stmt, + DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt, std::vector> &rules); } // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/LocalDeclRenamer.h b/include/rellic/AST/LocalDeclRenamer.h index fdd83869..8b57be8e 100644 --- a/include/rellic/AST/LocalDeclRenamer.h +++ b/include/rellic/AST/LocalDeclRenamer.h @@ -35,7 +35,7 @@ class LocalDeclRenamer : public TransformVisitor { void RunImpl() override; public: - LocalDeclRenamer(Provenance &provenance, clang::ASTUnit &unit, + LocalDeclRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit, IRToNameMap &names); bool shouldTraversePostOrder() override; diff --git a/include/rellic/AST/LoopRefine.h b/include/rellic/AST/LoopRefine.h index 022690cb..363e7b35 100644 --- a/include/rellic/AST/LoopRefine.h +++ b/include/rellic/AST/LoopRefine.h @@ -34,7 +34,7 @@ class LoopRefine : public TransformVisitor { void RunImpl() override; public: - LoopRefine(Provenance &provenance, clang::ASTUnit &unit); + LoopRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitWhileStmt(clang::WhileStmt *loop); }; diff --git a/include/rellic/AST/MaterializeConds.h b/include/rellic/AST/MaterializeConds.h index 07513c1c..fcd3a587 100644 --- 
a/include/rellic/AST/MaterializeConds.h +++ b/include/rellic/AST/MaterializeConds.h @@ -27,7 +27,7 @@ class MaterializeConds : public TransformVisitor { void RunImpl() override; public: - MaterializeConds(Provenance &provenance, clang::ASTUnit &unit); + MaterializeConds(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitIfStmt(clang::IfStmt *stmt); bool VisitWhileStmt(clang::WhileStmt *loop); diff --git a/include/rellic/AST/NestedCondProp.h b/include/rellic/AST/NestedCondProp.h index f5afe0cb..1534f332 100644 --- a/include/rellic/AST/NestedCondProp.h +++ b/include/rellic/AST/NestedCondProp.h @@ -54,7 +54,7 @@ class NestedCondProp : public ASTPass { void RunImpl() override; public: - NestedCondProp(Provenance &provenance, clang::ASTUnit &unit); + NestedCondProp(DecompilationContext& dec_ctx, clang::ASTUnit& unit); }; } // namespace rellic diff --git a/include/rellic/AST/NestedScopeCombine.h b/include/rellic/AST/NestedScopeCombine.h index 9281920e..b9158cd6 100644 --- a/include/rellic/AST/NestedScopeCombine.h +++ b/include/rellic/AST/NestedScopeCombine.h @@ -35,7 +35,7 @@ class NestedScopeCombine : public TransformVisitor { void RunImpl() override; public: - NestedScopeCombine(Provenance &provenance, clang::ASTUnit &unit); + NestedScopeCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitIfStmt(clang::IfStmt *ifstmt); bool VisitWhileStmt(clang::WhileStmt *stmt); diff --git a/include/rellic/AST/ReachBasedRefine.h b/include/rellic/AST/ReachBasedRefine.h index a8fbf3cc..1781cf7b 100644 --- a/include/rellic/AST/ReachBasedRefine.h +++ b/include/rellic/AST/ReachBasedRefine.h @@ -41,7 +41,7 @@ class ReachBasedRefine : public TransformVisitor { void RunImpl() override; public: - ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit); + ReachBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitCompoundStmt(clang::CompoundStmt *compound); }; diff --git a/include/rellic/AST/StructFieldRenamer.h 
b/include/rellic/AST/StructFieldRenamer.h index 556290b3..9af6040a 100644 --- a/include/rellic/AST/StructFieldRenamer.h +++ b/include/rellic/AST/StructFieldRenamer.h @@ -29,7 +29,7 @@ class StructFieldRenamer void RunImpl() override; public: - StructFieldRenamer(Provenance &provenance, clang::ASTUnit &unit, + StructFieldRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit, IRTypeToDITypeMap &types); bool VisitRecordDecl(clang::RecordDecl *decl); diff --git a/include/rellic/AST/TransformVisitor.h b/include/rellic/AST/TransformVisitor.h index 13b444dd..50e35965 100644 --- a/include/rellic/AST/TransformVisitor.h +++ b/include/rellic/AST/TransformVisitor.h @@ -26,11 +26,11 @@ class TransformVisitor : public ASTPass, StmtSubMap substitutions; void CopyProvenance(clang::Stmt *from, clang::Stmt *to) { - ::rellic::CopyProvenance(from, to, provenance.stmt_provenance); + ::rellic::CopyProvenance(from, to, dec_ctx.stmt_provenance); } void CopyProvenance(clang::Expr *from, clang::Expr *to) { - ::rellic::CopyProvenance(from, to, provenance.use_provenance); + ::rellic::CopyProvenance(from, to, dec_ctx.use_provenance); } bool ReplaceChildren(clang::Stmt *stmt, StmtSubMap &repl_map) { @@ -55,8 +55,8 @@ class TransformVisitor : public ASTPass, void RunImpl() override { substitutions.clear(); } public: - TransformVisitor(Provenance &provenance, clang::ASTUnit &unit) - : ASTPass(provenance, unit) {} + TransformVisitor(DecompilationContext &dec_ctx, clang::ASTUnit &unit) + : ASTPass(dec_ctx, unit) {} virtual bool shouldTraversePostOrder() { return true; } diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 3c37cea0..c5006dd0 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -64,7 +64,7 @@ using Z3CondMap = std::unordered_map; using BBEdge = std::pair; using BrEdge = std::pair; using SwEdge = std::pair; -struct Provenance { +struct DecompilationContext { StmtToIRMap stmt_provenance; ExprToUseMap use_provenance; IRToTypeDeclMap 
type_decls; diff --git a/include/rellic/AST/Z3CondSimplify.h b/include/rellic/AST/Z3CondSimplify.h index 835e6ca0..edd27522 100644 --- a/include/rellic/AST/Z3CondSimplify.h +++ b/include/rellic/AST/Z3CondSimplify.h @@ -24,7 +24,7 @@ class Z3CondSimplify : public ASTPass { void RunImpl() override; public: - Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit); + Z3CondSimplify(DecompilationContext& dec_ctx, clang::ASTUnit& unit); }; } // namespace rellic diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index b5d3744b..977ec3c7 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -15,8 +15,9 @@ namespace rellic { -CondBasedRefine::CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} +CondBasedRefine::CondBasedRefine(DecompilationContext &dec_ctx, + clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit) {} bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector body{compound->body_begin(), compound->body_end()}; @@ -30,8 +31,8 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { continue; } - auto cond_a{provenance.z3_exprs[provenance.conds[if_a]]}; - auto cond_b{provenance.z3_exprs[provenance.conds[if_b]]}; + auto cond_a{dec_ctx.z3_exprs[dec_ctx.conds[if_a]]}; + auto cond_b{dec_ctx.z3_exprs[dec_ctx.conds[if_b]]}; auto then_a{if_a->getThen()}; auto then_b{if_b->getThen()}; @@ -44,7 +45,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { new_then_body.push_back(then_b); auto new_then{ast.CreateCompoundStmt(new_then_body)}; - auto new_if{ast.CreateIf(provenance.marker_expr, new_then)}; + auto new_if{ast.CreateIf(dec_ctx.marker_expr, new_then)}; if (else_a || else_b) { std::vector new_else_body{}; @@ -61,7 +62,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { new_if->setElse(new_else); } - provenance.conds[new_if] = provenance.conds[if_a]; + 
dec_ctx.conds[new_if] = dec_ctx.conds[if_a]; body[i] = new_if; body.erase(std::next(body.begin(), i + 1)); did_something = true; @@ -81,12 +82,12 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } new_else_body.push_back(then_b); - auto new_if{ast.CreateIf(provenance.marker_expr, new_then)}; + auto new_if{ast.CreateIf(dec_ctx.marker_expr, new_then)}; auto new_else{ast.CreateCompoundStmt(new_else_body)}; new_if->setElse(new_else); - provenance.conds[new_if] = provenance.conds[if_a]; + dec_ctx.conds[new_if] = dec_ctx.conds[if_a]; body[i] = new_if; body.erase(std::next(body.begin(), i + 1)); did_something = true; diff --git a/lib/AST/DeadStmtElim.cpp b/lib/AST/DeadStmtElim.cpp index b3b28664..038a5b83 100644 --- a/lib/AST/DeadStmtElim.cpp +++ b/lib/AST/DeadStmtElim.cpp @@ -12,14 +12,14 @@ namespace rellic { -DeadStmtElim::DeadStmtElim(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} +DeadStmtElim::DeadStmtElim(DecompilationContext &dec_ctx, clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit) {} bool DeadStmtElim::VisitIfStmt(clang::IfStmt *ifstmt) { // DLOG(INFO) << "VisitIfStmt"; bool can_delete = false; - if (ifstmt->getCond() == provenance.marker_expr) { - can_delete = Prove(!provenance.z3_exprs[provenance.conds[ifstmt]]); + if (ifstmt->getCond() == dec_ctx.marker_expr) { + can_delete = Prove(!dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]); } auto compound = clang::dyn_cast(ifstmt->getThen()); diff --git a/lib/AST/ExprCombine.cpp b/lib/AST/ExprCombine.cpp index 441848f5..9303d479 100644 --- a/lib/AST/ExprCombine.cpp +++ b/lib/AST/ExprCombine.cpp @@ -40,7 +40,7 @@ class ArraySubscriptAddrOfRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto sub = clang::cast(stmt); @@ -48,7 +48,7 @@ class ArraySubscriptAddrOfRule : public 
InferenceRule { "ArraySubscriptExpr!"; auto paren = clang::cast(sub->getBase()); auto addr_of = clang::cast(paren->getSubExpr()); - CopyProvenance(addr_of, addr_of->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of, addr_of->getSubExpr(), dec_ctx.use_provenance); return addr_of->getSubExpr(); } }; @@ -66,7 +66,7 @@ class AddrOfArraySubscriptRule : public InferenceRule { match = result.Nodes.getNodeAs("addr_of"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto addr_of = clang::cast(stmt); @@ -74,7 +74,7 @@ class AddrOfArraySubscriptRule : public InferenceRule { << "Substituted UnaryOperator is not the matched UnaryOperator!"; auto subexpr = addr_of->getSubExpr()->IgnoreParenImpCasts(); auto sub = clang::cast(subexpr); - CopyProvenance(sub, sub->getBase(), provenance.use_provenance); + CopyProvenance(sub, sub->getBase(), dec_ctx.use_provenance); return sub->getBase(); } }; @@ -91,7 +91,7 @@ class DerefAddrOfRule : public InferenceRule { match = result.Nodes.getNodeAs("deref"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto deref = clang::cast(stmt); @@ -99,7 +99,7 @@ class DerefAddrOfRule : public InferenceRule { << "Substituted UnaryOperator is not the matched UnaryOperator!"; auto subexpr = deref->getSubExpr()->IgnoreParenImpCasts(); auto addr_of = clang::cast(subexpr); - CopyProvenance(addr_of, addr_of->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of, addr_of->getSubExpr(), dec_ctx.use_provenance); return addr_of->getSubExpr(); } }; @@ -120,7 +120,7 @@ class DerefAddrOfConditionalRule : public InferenceRule { match = result.Nodes.getNodeAs("deref"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt 
*GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto deref = clang::cast(stmt); @@ -137,8 +137,8 @@ class DerefAddrOfConditionalRule : public InferenceRule { clang::cast(conditional->getTrueExpr()); auto addr_of2 = clang::cast(conditional->getFalseExpr()); - CopyProvenance(addr_of1, addr_of1->getSubExpr(), provenance.use_provenance); - CopyProvenance(addr_of2, addr_of2->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of1, addr_of1->getSubExpr(), dec_ctx.use_provenance); + CopyProvenance(addr_of2, addr_of2->getSubExpr(), dec_ctx.use_provenance); return ASTBuilder(unit).CreateConditional( conditional->getCond(), addr_of1->getSubExpr(), addr_of2->getSubExpr()); @@ -161,7 +161,7 @@ class NegComparisonRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto op = clang::cast(stmt); @@ -172,7 +172,7 @@ class NegComparisonRule : public InferenceRule { auto opc = clang::BinaryOperator::negateComparisonOp(binop->getOpcode()); auto res = ASTBuilder(unit).CreateBinaryOp(opc, binop->getLHS(), binop->getRHS()); - CopyProvenance(op, res, provenance.stmt_provenance); + CopyProvenance(op, res, dec_ctx.stmt_provenance); return res; } }; @@ -189,7 +189,7 @@ class ParenDeclRefExprStripRule : public InferenceRule { match = result.Nodes.getNodeAs("paren"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto paren = clang::cast(stmt); @@ -209,7 +209,7 @@ class DoubleParenStripRule : public InferenceRule { match = result.Nodes.getNodeAs("paren"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt 
*stmt) override { auto paren = clang::cast(stmt); @@ -235,7 +235,7 @@ class MemberExprAddrOfRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto arrow{clang::cast(stmt)}; @@ -246,7 +246,7 @@ class MemberExprAddrOfRule : public InferenceRule { auto field{clang::dyn_cast(arrow->getMemberDecl())}; CHECK(field != nullptr) << "Substituted MemberExpr is not a structure field access!"; - CopyProvenance(addr_of, addr_of->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of, addr_of->getSubExpr(), dec_ctx.use_provenance); return ASTBuilder(unit).CreateDot(addr_of->getSubExpr(), field); } }; @@ -268,7 +268,7 @@ class MemberExprArraySubRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto dot{clang::cast(stmt)}; @@ -279,7 +279,7 @@ class MemberExprArraySubRule : public InferenceRule { auto field{clang::dyn_cast(dot->getMemberDecl())}; CHECK(field != nullptr) << "Substituted MemberExpr is not a structure field access!"; - CopyProvenance(sub, sub->getBase(), provenance.use_provenance); + CopyProvenance(sub, sub->getBase(), dec_ctx.use_provenance); return ASTBuilder(unit).CreateArrow(sub->getBase(), field); } }; @@ -300,7 +300,7 @@ class AssignCastedExprRule : public InferenceRule { }; } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto assign{clang::cast(stmt)}; @@ -336,7 +336,7 @@ class UnsignedToSignedCStyleCastRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt 
*GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -370,7 +370,7 @@ class TripleCStyleCastElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -409,7 +409,7 @@ class VoidToTypePtrCastElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -444,7 +444,7 @@ class CStyleZeroToPtrCastElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -466,7 +466,7 @@ class CStyleConstElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -485,8 +485,8 @@ class CStyleConstElimRule : public InferenceRule { } // namespace -ExprCombine::ExprCombine(Provenance &provenance, clang::ASTUnit &u) - : TransformVisitor(provenance, u) {} +ExprCombine::ExprCombine(DecompilationContext &dec_ctx, clang::ASTUnit &u) + : TransformVisitor(dec_ctx, u) {} bool ExprCombine::VisitCStyleCastExpr(clang::CStyleCastExpr *cast) { // TODO(frabert): Re-enable nullptr casts simplification @@ -500,7 +500,7 @@ bool ExprCombine::VisitCStyleCastExpr(clang::CStyleCastExpr *cast) { 
pre_rules.emplace_back(new VoidToTypePtrCastElimRule); - auto pre_sub{ApplyFirstMatchingRule(provenance, ast_unit, cast, pre_rules)}; + auto pre_sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, cast, pre_rules)}; if (pre_sub != cast) { substitutions[cast] = pre_sub; return true; @@ -533,7 +533,7 @@ bool ExprCombine::VisitCStyleCastExpr(clang::CStyleCastExpr *cast) { rules.emplace_back(new TripleCStyleCastElimRule); rules.emplace_back(new CStyleConstElimRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, cast, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, cast, rules)}; if (sub != cast) { substitutions[cast] = sub; } @@ -551,7 +551,7 @@ bool ExprCombine::VisitUnaryOperator(clang::UnaryOperator *op) { rules.emplace_back(new DerefAddrOfConditionalRule); rules.emplace_back(new AddrOfArraySubscriptRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, op, rules)}; if (sub != op) { substitutions[op] = sub; } @@ -565,7 +565,7 @@ bool ExprCombine::VisitBinaryOperator(clang::BinaryOperator *op) { rules.emplace_back(new AssignCastedExprRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, op, rules)}; if (sub != op) { substitutions[op] = sub; } @@ -579,7 +579,7 @@ bool ExprCombine::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) { rules.emplace_back(new ArraySubscriptAddrOfRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, expr, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, expr, rules)}; if (sub != expr) { substitutions[expr] = sub; } @@ -594,7 +594,7 @@ bool ExprCombine::VisitMemberExpr(clang::MemberExpr *expr) { rules.emplace_back(new MemberExprAddrOfRule); rules.emplace_back(new MemberExprArraySubRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, expr, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, expr, rules)}; if (sub != expr) { 
substitutions[expr] = sub; } @@ -609,7 +609,7 @@ bool ExprCombine::VisitParenExpr(clang::ParenExpr *expr) { rules.emplace_back(new ParenDeclRefExprStripRule); rules.emplace_back(new DoubleParenStripRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, expr, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, expr, rules)}; if (sub != expr) { substitutions[expr] = sub; } diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index b2e5ec6d..afe2c720 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -141,44 +141,43 @@ unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond) { auto ToExpr = [&](unsigned idx) { if (idx == POISON_IDX) { - return provenance.z3_ctx.bool_val(false); + return dec_ctx.z3_ctx.bool_val(false); } - return provenance.z3_exprs[idx]; + return dec_ctx.z3_exprs[idx]; }; - if (provenance.z3_br_edges.find({inst, cond}) == - provenance.z3_br_edges.end()) { + if (dec_ctx.z3_br_edges.find({inst, cond}) == dec_ctx.z3_br_edges.end()) { if (auto constant = llvm::dyn_cast(inst->getCondition())) { - provenance.z3_br_edges[{inst, cond}] = provenance.z3_exprs.size(); - auto edge{provenance.z3_ctx.bool_val(constant->isOne() == cond)}; - provenance.z3_exprs.push_back(edge); - provenance.z3_br_edges_inv[edge.id()] = {inst, true}; + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); + auto edge{dec_ctx.z3_ctx.bool_val(constant->isOne() == cond)}; + dec_ctx.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; } else if (cond) { auto name{GetName(inst)}; - auto edge{provenance.z3_ctx.bool_const(name.c_str())}; - provenance.z3_br_edges[{inst, cond}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(edge); - provenance.z3_br_edges_inv[edge.id()] = {inst, true}; + auto edge{dec_ctx.z3_ctx.bool_const(name.c_str())}; + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges_inv[edge.id()] = 
{inst, true}; } else { auto edge{!(ToExpr(GetOrCreateEdgeForBranch(inst, true)))}; - provenance.z3_br_edges[{inst, cond}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(edge); } } - return provenance.z3_br_edges[{inst, cond}]; + return dec_ctx.z3_br_edges[{inst, cond}]; } unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { - if (provenance.z3_sw_vars.find(inst) == provenance.z3_sw_vars.end()) { + if (dec_ctx.z3_sw_vars.find(inst) == dec_ctx.z3_sw_vars.end()) { auto name{GetName(inst)}; - auto var{provenance.z3_ctx.int_const(name.c_str())}; - provenance.z3_sw_vars[inst] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(var); - provenance.z3_sw_vars_inv[var.id()] = inst; - return provenance.z3_sw_vars[inst]; + auto var{dec_ctx.z3_ctx.int_const(name.c_str())}; + dec_ctx.z3_sw_vars[inst] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(var); + dec_ctx.z3_sw_vars_inv[var.id()] = inst; + return dec_ctx.z3_sw_vars[inst]; } else { - return provenance.z3_sw_vars[inst]; + return dec_ctx.z3_sw_vars[inst]; } } @@ -186,44 +185,44 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c) { auto ToExpr = [&](unsigned idx) { if (idx == POISON_IDX) { - return provenance.z3_ctx.bool_val(false); + return dec_ctx.z3_ctx.bool_val(false); } - return provenance.z3_exprs[idx]; + return dec_ctx.z3_exprs[idx]; }; - if (provenance.z3_sw_edges.find({inst, c}) == provenance.z3_sw_edges.end()) { + if (dec_ctx.z3_sw_edges.find({inst, c}) == dec_ctx.z3_sw_edges.end()) { if (c) { auto sw_case{inst->findCaseValue(c)}; auto var{ToExpr(GetOrCreateVarForSwitch(inst))}; - auto expr{var == provenance.z3_ctx.int_val(sw_case->getCaseIndex())}; + auto expr{var == dec_ctx.z3_ctx.int_val(sw_case->getCaseIndex())}; - provenance.z3_sw_edges[{inst, c}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(expr); + 
dec_ctx.z3_sw_edges[{inst, c}] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(expr); } else { // Default case - z3::expr_vector vec{provenance.z3_ctx}; + z3::expr_vector vec{dec_ctx.z3_ctx}; for (auto sw_case : inst->cases()) { vec.push_back( !ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()))); } - provenance.z3_sw_edges[{inst, c}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(z3::mk_and(vec)); + dec_ctx.z3_sw_edges[{inst, c}] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(z3::mk_and(vec)); } } - return provenance.z3_sw_edges[{inst, c}]; + return dec_ctx.z3_sw_edges[{inst, c}]; } unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to) { auto ToExpr = [&](unsigned idx) { if (idx == POISON_IDX) { - return provenance.z3_ctx.bool_val(false); + return dec_ctx.z3_ctx.bool_val(false); } - return provenance.z3_exprs[idx]; + return dec_ctx.z3_exprs[idx]; }; - if (provenance.z3_edges.find({from, to}) == provenance.z3_edges.end()) { + if (dec_ctx.z3_edges.find({from, to}) == dec_ctx.z3_edges.end()) { // Construct the edge condition for CFG edge `(from, to)` - auto result{provenance.z3_ctx.bool_val(true)}; + auto result{dec_ctx.z3_ctx.bool_val(true)}; auto term = from->getTerminator(); switch (term->getOpcode()) { // Conditional branches @@ -240,7 +239,7 @@ unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, if (to == sw->getDefaultDest()) { result = ToExpr(GetOrCreateEdgeForSwitch(sw, nullptr)); } else { - z3::expr_vector or_vec{provenance.z3_ctx}; + z3::expr_vector or_vec{dec_ctx.z3_ctx}; for (auto sw_case : sw->cases()) { if (sw_case.getCaseSuccessor() == to) { or_vec.push_back( @@ -269,35 +268,34 @@ unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, break; } - provenance.z3_edges[{from, to}] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(result.simplify()); + dec_ctx.z3_edges[{from, to}] = dec_ctx.z3_exprs.size(); + 
dec_ctx.z3_exprs.push_back(result.simplify()); } - return provenance.z3_edges[{from, to}]; + return dec_ctx.z3_edges[{from, to}]; } unsigned GenerateAST::GetReachingCond(llvm::BasicBlock *block) { - if (provenance.reaching_conds.find(block) == - provenance.reaching_conds.end()) { + if (dec_ctx.reaching_conds.find(block) == dec_ctx.reaching_conds.end()) { return POISON_IDX; } - return provenance.reaching_conds[block]; + return dec_ctx.reaching_conds[block]; } void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { auto ToExpr = [&](unsigned idx) { if (idx == POISON_IDX) { - return provenance.z3_ctx.bool_val(false); + return dec_ctx.z3_ctx.bool_val(false); } - return provenance.z3_exprs[idx]; + return dec_ctx.z3_exprs[idx]; }; auto old_cond_idx{GetReachingCond(block)}; auto old_cond{ToExpr(old_cond_idx)}; if (block->hasNPredecessorsOrMore(1)) { // Gather reaching conditions from predecessors of the block - z3::expr_vector conds{provenance.z3_ctx}; + z3::expr_vector conds{dec_ctx.z3_ctx}; for (auto pred : llvm::predecessors(block)) { auto pred_cond{ToExpr(GetReachingCond(pred))}; auto edge_cond{ToExpr(GetOrCreateEdgeCond(pred, block))}; @@ -313,15 +311,14 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { auto cond{HeavySimplify(z3::mk_or(conds))}; if (old_cond_idx == POISON_IDX || !Prove(old_cond == cond)) { - provenance.reaching_conds[block] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(cond); + dec_ctx.reaching_conds[block] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(cond); reaching_conds_changed = true; } } else { - if (provenance.reaching_conds.find(block) == - provenance.reaching_conds.end()) { - provenance.reaching_conds[block] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(provenance.z3_ctx.bool_val(true)); + if (dec_ctx.reaching_conds.find(block) == dec_ctx.reaching_conds.end()) { + dec_ctx.reaching_conds[block] = dec_ctx.z3_exprs.size(); + 
dec_ctx.z3_exprs.push_back(dec_ctx.z3_ctx.bool_val(true)); reaching_conds_changed = true; } } @@ -336,9 +333,9 @@ StmtVec GenerateAST::CreateBasicBlockStmts(llvm::BasicBlock *block) { StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { auto ToExpr = [&](unsigned idx) { if (idx == POISON_IDX) { - return provenance.z3_ctx.bool_val(false); + return dec_ctx.z3_ctx.bool_val(false); } - return provenance.z3_exprs[idx]; + return dec_ctx.z3_exprs[idx]; }; StmtVec result; for (auto block : rpo_walk) { @@ -362,9 +359,9 @@ StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { } // Gate the compound behind a reaching condition auto z_expr{GetReachingCond(block)}; - block_stmts[block] = ast.CreateIf(provenance.marker_expr, compound); - provenance.conds[block_stmts[block]] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(provenance.z3_exprs[z_expr]); + block_stmts[block] = ast.CreateIf(dec_ctx.marker_expr, compound); + dec_ctx.conds[block_stmts[block]] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(dec_ctx.z3_exprs[z_expr]); // Store the compound result.push_back(block_stmts[block]); } @@ -428,9 +425,9 @@ clang::CompoundStmt *GenerateAST::StructureAcyclicRegion(llvm::Region *region) { clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { auto ToExpr = [&](unsigned idx) { if (idx == POISON_IDX) { - return provenance.z3_ctx.bool_val(false); + return dec_ctx.z3_ctx.bool_val(false); } - return provenance.z3_exprs[idx]; + return dec_ctx.z3_exprs[idx]; }; DLOG(INFO) << "Region " << GetRegionNameStr(region) << " is cyclic"; auto region_body = CreateRegionStmts(region); @@ -476,11 +473,11 @@ clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { CHECK(it != loop_body.end()); // Create a loop exiting `break` statement StmtVec break_stmt({ast.CreateBreak()}); - auto exit_stmt = ast.CreateIf(provenance.marker_expr, - ast.CreateCompoundStmt(break_stmt)); - provenance.conds[exit_stmt] = 
provenance.z3_exprs.size(); + auto exit_stmt = + ast.CreateIf(dec_ctx.marker_expr, ast.CreateCompoundStmt(break_stmt)); + dec_ctx.conds[exit_stmt] = dec_ctx.z3_exprs.size(); // Create edge condition - provenance.z3_exprs.push_back( + dec_ctx.z3_exprs.push_back( (ToExpr(GetReachingCond(from)) && ToExpr(GetOrCreateEdgeCond(from, to))) .simplify()); // Insert it after the exiting block statement @@ -565,11 +562,11 @@ clang::CompoundStmt *GenerateAST::StructureRegion(llvm::Region *region) { llvm::AnalysisKey GenerateAST::Key; -GenerateAST::GenerateAST(Provenance &provenance, clang::ASTUnit &unit) +GenerateAST::GenerateAST(DecompilationContext &dec_ctx, clang::ASTUnit &unit) : ast_ctx(&unit.getASTContext()), unit(unit), - provenance(provenance), - ast_gen(unit, provenance), + dec_ctx(dec_ctx), + ast_gen(unit, dec_ctx), ast(unit) {} GenerateAST::Result GenerateAST::run(llvm::Module &module, @@ -648,13 +645,13 @@ GenerateAST::Result GenerateAST::run(llvm::Function &func, // Call the above declared bad boy POWalkSubRegions(regions->getTopLevelRegion()); // Get the function declaration AST node for `func` - auto fdecl = clang::cast(provenance.value_decls[&func]); + auto fdecl = clang::cast(dec_ctx.value_decls[&func]); // Create a redeclaration of `fdecl` that will serve as a definition auto tudecl = ast_ctx->getTranslationUnitDecl(); auto fdefn = ast.CreateFunctionDecl(tudecl, fdecl->getType(), fdecl->getIdentifier()); fdefn->setPreviousDecl(fdecl); - provenance.value_decls[&func] = fdefn; + dec_ctx.value_decls[&func] = fdefn; tudecl->addDecl(fdefn); // Set parameters to the same as the previous declaration fdefn->setParams(fdecl->parameters()); @@ -676,20 +673,20 @@ GenerateAST::Result GenerateAST::run(llvm::Function &func, return llvm::PreservedAnalyses::all(); } -void GenerateAST::run(llvm::Module &module, Provenance &provenance, +void GenerateAST::run(llvm::Module &module, DecompilationContext &dec_ctx, clang::ASTUnit &unit) { llvm::ModulePassManager mpm; 
llvm::ModuleAnalysisManager mam; llvm::PassBuilder pb; - mam.registerPass([&] { return rellic::GenerateAST(provenance, unit); }); - mpm.addPass(rellic::GenerateAST(provenance, unit)); + mam.registerPass([&] { return rellic::GenerateAST(dec_ctx, unit); }); + mpm.addPass(rellic::GenerateAST(dec_ctx, unit)); pb.registerModuleAnalyses(mam); mpm.run(module, mam); llvm::FunctionPassManager fpm; llvm::FunctionAnalysisManager fam; - fam.registerPass([&] { return rellic::GenerateAST(provenance, unit); }); - fpm.addPass(rellic::GenerateAST(provenance, unit)); + fam.registerPass([&] { return rellic::GenerateAST(dec_ctx, unit); }); + fpm.addPass(rellic::GenerateAST(dec_ctx, unit)); pb.registerFunctionAnalyses(fam); for (auto &func : module.functions()) { fpm.run(func, fam); diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index 33349b4f..1bc58662 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -30,11 +30,13 @@ class ExprGen : public llvm::InstVisitor { ASTBuilder ast; - Provenance &provenance; + DecompilationContext &dec_ctx; + size_t num_literal_structs = 0; + size_t num_declared_structs = 0; public: - ExprGen(clang::ASTUnit &unit, Provenance &provenance) - : ast_ctx(unit.getASTContext()), ast(unit), provenance(provenance) {} + ExprGen(clang::ASTUnit &unit, DecompilationContext &dec_ctx) + : ast_ctx(unit.getASTContext()), ast(unit), dec_ctx(dec_ctx) {} void VisitGlobalVar(llvm::GlobalVariable &gvar); clang::QualType GetQualType(llvm::Type *type); @@ -67,11 +69,11 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { auto a{expr.arg(0)}; auto b{expr.arg(1)}; - llvm::SwitchInst *inst{provenance.z3_sw_vars_inv[a.id()]}; + llvm::SwitchInst *inst{dec_ctx.z3_sw_vars_inv[a.id()]}; unsigned case_idx{}; if (!inst) { - inst = provenance.z3_sw_vars_inv[b.id()]; + inst = dec_ctx.z3_sw_vars_inv[b.id()]; case_idx = a.get_numeral_uint(); } else { case_idx = b.get_numeral_uint(); @@ -88,16 +90,15 @@ clang::Expr 
*IRToASTVisitor::ConvertExpr(z3::expr expr) { } auto hash{expr.id()}; - if (provenance.z3_br_edges_inv.find(hash) != - provenance.z3_br_edges_inv.end()) { - auto edge{provenance.z3_br_edges_inv[hash]}; + if (dec_ctx.z3_br_edges_inv.find(hash) != dec_ctx.z3_br_edges_inv.end()) { + auto edge{dec_ctx.z3_br_edges_inv[hash]}; CHECK(edge.second) << "Inverse map should only be populated for branches " "taken when condition is true"; return CreateOperandExpr(*(edge.first->op_end() - 3)); } - if (provenance.z3_sw_vars_inv.find(hash) != provenance.z3_sw_vars_inv.end()) { - auto inst{provenance.z3_sw_vars_inv[hash]}; + if (dec_ctx.z3_sw_vars_inv.find(hash) != dec_ctx.z3_sw_vars_inv.end()) { + auto inst{dec_ctx.z3_sw_vars_inv[hash]}; return CreateOperandExpr(inst->getOperandUse(0)); } @@ -130,7 +131,7 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { case Z3_OP_NOT: { CHECK_EQ(args.size(), 1) << "Not must have one argument"; auto neg{ast.CreateLNot(args[0])}; - CopyProvenance(args[0], neg, provenance.use_provenance); + CopyProvenance(args[0], neg, dec_ctx.use_provenance); return neg; } default: @@ -141,7 +142,7 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { void ExprGen::VisitGlobalVar(llvm::GlobalVariable &gvar) { DLOG(INFO) << "VisitGlobalVar: " << LLVMThingToString(&gvar); - auto &var{provenance.value_decls[&gvar]}; + auto &var{dec_ctx.value_decls[&gvar]}; if (var) { return; } @@ -242,16 +243,16 @@ clang::QualType ExprGen::GetQualType(llvm::Type *type) { case llvm::Type::StructTyID: { clang::RecordDecl *sdecl{nullptr}; - auto &decl{provenance.type_decls[type]}; + auto &decl{dec_ctx.type_decls[type]}; if (!decl) { auto tudecl{ast_ctx.getTranslationUnitDecl()}; auto strct{llvm::cast(type)}; auto sname{strct->isLiteral() ? 
("literal_struct_" + - std::to_string(provenance.num_literal_structs++)) + std::to_string(dec_ctx.num_literal_structs++)) : strct->getName().str()}; if (sname.empty()) { - sname = "struct" + std::to_string(provenance.num_declared_structs++); + sname = "struct" + std::to_string(dec_ctx.num_declared_structs++); } // Create a C struct declaration @@ -305,13 +306,13 @@ clang::Expr *ExprGen::CreateConstantExpr(llvm::Constant *constant) { if (auto cexpr = llvm::dyn_cast(constant)) { auto inst{cexpr->getAsInstruction()}; auto expr{visit(inst)}; - provenance.use_provenance.erase(expr); + dec_ctx.use_provenance.erase(expr); inst->deleteValue(); return expr; } else if (auto alias = llvm::dyn_cast(constant)) { return CreateConstantExpr(alias->getAliasee()); } else if (auto global = llvm::dyn_cast(constant)) { - auto decl{provenance.value_decls[global]}; + auto decl{dec_ctx.value_decls[global]}; auto ref{ast.CreateDeclRef(decl)}; return ast.CreateAddrOf(ref); } @@ -421,9 +422,9 @@ clang::Expr *ExprGen::CreateLiteralExpr(llvm::Constant *constant) { clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { DLOG(INFO) << "Getting Expr for " << LLVMThingToString(val); auto CreateRef{[this, &val] { - auto decl{provenance.value_decls[val]}; + auto decl{dec_ctx.value_decls[val]}; auto ref{ast.CreateDeclRef(decl)}; - provenance.use_provenance[ref] = &val; + dec_ctx.use_provenance[ref] = &val; return ref; }}; @@ -445,13 +446,12 @@ clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { // the actual argument and use it instead of the actual argument. This // is because `byval` arguments are pointers, so each reference to those // arguments assume they are dealing with pointers. 
- auto &temp{provenance.temp_decls[arg]}; + auto &temp{dec_ctx.temp_decls[arg]}; if (!temp) { auto addr_of_arg{ast.CreateAddrOf(ref)}; auto func{arg->getParent()}; - auto fdecl{provenance.value_decls[func]->getAsFunction()}; - auto argdecl{ - clang::cast(provenance.value_decls[arg])}; + auto fdecl{dec_ctx.value_decls[func]->getAsFunction()}; + auto argdecl{clang::cast(dec_ctx.value_decls[arg])}; temp = ast.CreateVarDecl(fdecl, GetQualType(arg->getType()), argdecl->getName().str() + "_ptr"); temp->setInit(addr_of_arg); @@ -464,7 +464,7 @@ clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { } } else if (auto inst = llvm::dyn_cast(val)) { // Operand is a result of an expression - if (auto decl = provenance.value_decls[inst]) { + if (auto decl = dec_ctx.value_decls[inst]) { res = ast.CreateDeclRef(decl); } else { res = visit(inst); @@ -485,7 +485,7 @@ clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { << "Bitcode: [" << LLVMThingToString(val) << "]\n" << "Type: [" << LLVMThingToString(val->getType()) << "]\n"; } - provenance.use_provenance[res] = &val; + dec_ctx.use_provenance[res] = &val; return res; } @@ -574,7 +574,7 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { auto ptr_type{ ast_ctx.getPointerType(GetQualType(inst.getParamByValType(i)))}; opnd = ast.CreateDeref(ast.CreateCStyleCast(ptr_type, opnd)); - provenance.use_provenance[opnd] = &arg; + dec_ctx.use_provenance[opnd] = &arg; } args.push_back(opnd); } @@ -582,7 +582,7 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { clang::Expr *callexpr{nullptr}; auto &callee{*(inst.op_end() - 1)}; if (auto func = llvm::dyn_cast(callee)) { - auto fdecl{provenance.value_decls[func]->getAsFunction()}; + auto fdecl{dec_ctx.value_decls[func]->getAsFunction()}; if (func->getFunctionType() == inst.getFunctionType()) { callexpr = ast.CreateCall(fdecl, args); } else { @@ -593,7 +593,7 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { callexpr = ast.CreateCall(cast, args); } } 
else if (auto iasm = llvm::dyn_cast(callee)) { - auto fdecl{provenance.value_decls[iasm]->getAsFunction()}; + auto fdecl{dec_ctx.value_decls[iasm]->getAsFunction()}; callexpr = ast.CreateCall(fdecl, args); } else if (llvm::isa(callee->getType())) { auto funcPtr{ast_ctx.getPointerType(GetQualType(inst.getFunctionType()))}; @@ -697,7 +697,7 @@ clang::Expr *ExprGen::visitGetElementPtrInst(llvm::GetElementPtrInst &inst) { case llvm::Type::StructTyID: { auto mem_idx = llvm::dyn_cast(idx); CHECK(mem_idx) << "Non-constant GEP index while indexing a structure"; - auto tdecl{provenance.type_decls[indexed_type]}; + auto tdecl{dec_ctx.type_decls[indexed_type]}; CHECK(tdecl) << "Structure declaration doesn't exist"; auto record{clang::cast(tdecl)}; auto field_it{record->field_begin()}; @@ -755,7 +755,7 @@ clang::Expr *ExprGen::visitExtractValueInst(llvm::ExtractValueInst &inst) { } break; // Structures case llvm::Type::StructTyID: { - auto tdecl{provenance.type_decls[indexed_type]}; + auto tdecl{dec_ctx.type_decls[indexed_type]}; CHECK(tdecl) << "Structure declaration doesn't exist"; auto record{clang::cast(tdecl)}; auto field_it{record->field_begin()}; @@ -879,7 +879,7 @@ clang::Expr *ExprGen::visitCmpInst(llvm::CmpInst &inst) { return op; } else { auto cast{ast.CreateCStyleCast(rt, op)}; - CopyProvenance(op, cast, provenance.use_provenance); + CopyProvenance(op, cast, dec_ctx.use_provenance); return (clang::Expr *)cast; } }}; @@ -1067,15 +1067,12 @@ class StmtGen : public llvm::InstVisitor { clang::ASTContext &ast_ctx; ASTBuilder * ExprGen &expr_gen; - Provenance &provenance; + DecompilationContext &dec_ctx; public: StmtGen(clang::ASTContext &ast_ctx, ASTBuilder &ast, ExprGen &expr_gen, - Provenance &provenance) - : ast_ctx(ast_ctx), - ast(ast), - expr_gen(expr_gen), - provenance(provenance) {} + DecompilationContext &dec_ctx) + : ast_ctx(ast_ctx), ast(ast), expr_gen(expr_gen), dec_ctx(dec_ctx) {} clang::Stmt *visitStoreInst(llvm::StoreInst &inst); clang::Stmt 
*visitCallInst(llvm::CallInst &inst); @@ -1121,13 +1118,13 @@ clang::Stmt *StmtGen::visitStoreInst(llvm::StoreInst &inst) { } else { // Create the assignemnt itself auto deref{ast.CreateDeref(lhs)}; - CopyProvenance(lhs, deref, provenance.use_provenance); + CopyProvenance(lhs, deref, dec_ctx.use_provenance); return ast.CreateAssign(deref, rhs); } } clang::Stmt *StmtGen::visitCallInst(llvm::CallInst &inst) { - auto &var{provenance.value_decls[&inst]}; + auto &var{dec_ctx.value_decls[&inst]}; auto expr{expr_gen.visit(inst)}; if (var) { return ast.CreateAssign(ast.CreateDeclRef(var), expr); @@ -1168,7 +1165,7 @@ clang::Stmt *StmtGen::visitUnreachableInst(llvm::UnreachableInst &inst) { clang::Stmt *StmtGen::visitPHINode(llvm::PHINode &inst) { return nullptr; } clang::Stmt *StmtGen::visitInstruction(llvm::Instruction &inst) { - auto &var{provenance.value_decls[&inst]}; + auto &var{dec_ctx.value_decls[&inst]}; if (var) { auto expr{expr_gen.visit(inst)}; return ast.CreateAssign(ast.CreateDeclRef(var), expr); @@ -1176,20 +1173,21 @@ clang::Stmt *StmtGen::visitInstruction(llvm::Instruction &inst) { return nullptr; } -IRToASTVisitor::IRToASTVisitor(clang::ASTUnit &unit, Provenance &provenance) +IRToASTVisitor::IRToASTVisitor(clang::ASTUnit &unit, + DecompilationContext &dec_ctx) : ast_unit(unit), ast_ctx(unit.getASTContext()), ast(unit), - provenance(provenance) {} + dec_ctx(dec_ctx) {} void IRToASTVisitor::VisitGlobalVar(llvm::GlobalVariable &gvar) { - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; expr_gen.VisitGlobalVar(gvar); } void IRToASTVisitor::VisitArgument(llvm::Argument &arg) { DLOG(INFO) << "VisitArgument: " << LLVMThingToString(&arg); - auto &parm{provenance.value_decls[&arg]}; + auto &parm{dec_ctx.value_decls[&arg]}; if (parm) { return; } @@ -1198,13 +1196,13 @@ void IRToASTVisitor::VisitArgument(llvm::Argument &arg) { : "arg" + std::to_string(arg.getArgNo())}; // Get parent function declaration auto func{arg.getParent()}; - auto 
fdecl{clang::cast(provenance.value_decls[func])}; + auto fdecl{clang::cast(dec_ctx.value_decls[func])}; auto argtype{arg.getType()}; if (arg.hasByValAttr()) { auto byval{arg.getAttribute(llvm::Attribute::ByVal)}; argtype = byval.getValueAsType(); } - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; // Create a declaration parm = ast.CreateParamDecl(fdecl, expr_gen.GetQualType(argtype), name); } @@ -1237,20 +1235,20 @@ static llvm::FunctionType *GetFixedFunctionType(llvm::Function &func) { void IRToASTVisitor::VisitBasicBlock(llvm::BasicBlock &block, std::vector &stmts) { - ExprGen expr_gen{ast_unit, provenance}; - StmtGen stmt_gen{ast_ctx, ast, expr_gen, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; + StmtGen stmt_gen{ast_ctx, ast, expr_gen, dec_ctx}; for (auto &inst : block) { auto stmt{stmt_gen.visit(inst)}; if (stmt) { stmts.push_back(stmt); - provenance.stmt_provenance[stmt] = &inst; + dec_ctx.stmt_provenance[stmt] = &inst; } } - auto &uses{provenance.outgoing_uses[&block]}; + auto &uses{dec_ctx.outgoing_uses[&block]}; for (auto it{uses.rbegin()}; it != uses.rend(); ++it) { auto use{*it}; - auto var{provenance.value_decls[use->getUser()]}; + auto var{dec_ctx.value_decls[use->getUser()]}; auto expr{expr_gen.CreateOperandExpr(*use)}; stmts.push_back(ast.CreateAssign(ast.CreateDeclRef(var), expr)); } @@ -1265,14 +1263,14 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { return; } - auto &decl{provenance.value_decls[&func]}; + auto &decl{dec_ctx.value_decls[&func]}; if (decl) { return; } DLOG(INFO) << "Creating FunctionDecl for " << name; auto tudecl{ast_ctx.getTranslationUnitDecl()}; - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; auto type{expr_gen.GetQualType(GetFixedFunctionType(func))}; decl = ast.CreateFunctionDecl(tudecl, type, name); @@ -1282,14 +1280,14 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { for (auto &arg : func.args()) { VisitArgument(arg); 
params.push_back( - clang::cast(provenance.value_decls[&arg])); + clang::cast(dec_ctx.value_decls[&arg])); } auto fdecl{decl->getAsFunction()}; fdecl->setParams(params); for (auto &inst : llvm::instructions(func)) { - auto &var{provenance.value_decls[&inst]}; + auto &var{dec_ctx.value_decls[&inst]}; if (auto alloca = llvm::dyn_cast(&inst)) { auto name{"var" + std::to_string(GetNumDecls(fdecl))}; // TLDR: Here we discard the variable name as present in the bitcode @@ -1346,7 +1344,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { auto bb{phi->getIncomingBlock(i)}; auto &use{phi->getOperandUse(i)}; - provenance.outgoing_uses[bb].push_back(&use); + dec_ctx.outgoing_uses[bb].push_back(&use); } } } @@ -1356,7 +1354,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { if (auto iasm = llvm::dyn_cast(&opnd)) { // TODO(frabert): We still need to find a way to embed the inline asm // into the function - auto &decl{provenance.value_decls[iasm]}; + auto &decl{dec_ctx.value_decls[iasm]}; if (decl) { return; } @@ -1386,12 +1384,12 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { } clang::Expr *IRToASTVisitor::CreateOperandExpr(llvm::Use &val) { - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; return expr_gen.CreateOperandExpr(val); } clang::Expr *IRToASTVisitor::CreateConstantExpr(llvm::Constant *constant) { - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; return expr_gen.CreateConstantExpr(constant); } diff --git a/lib/AST/InferenceRule.cpp b/lib/AST/InferenceRule.cpp index efcabcc8..7a50f5b0 100644 --- a/lib/AST/InferenceRule.cpp +++ b/lib/AST/InferenceRule.cpp @@ -14,7 +14,7 @@ namespace rellic { clang::Stmt *ApplyFirstMatchingRule( - Provenance &provenance, clang::ASTUnit &unit, clang::Stmt *stmt, + DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt, std::vector> &rules) { clang::ast_matchers::MatchFinder::MatchFinderOptions opts; 
clang::ast_matchers::MatchFinder finder(opts); @@ -27,7 +27,7 @@ clang::Stmt *ApplyFirstMatchingRule( for (auto &rule : rules) { if (*rule) { - return rule->GetOrCreateSubstitution(provenance, unit, stmt); + return rule->GetOrCreateSubstitution(dec_ctx, unit, stmt); } } diff --git a/lib/AST/LocalDeclRenamer.cpp b/lib/AST/LocalDeclRenamer.cpp index ae23fe92..7885007a 100644 --- a/lib/AST/LocalDeclRenamer.cpp +++ b/lib/AST/LocalDeclRenamer.cpp @@ -16,9 +16,9 @@ namespace rellic { -LocalDeclRenamer::LocalDeclRenamer(Provenance &provenance, clang::ASTUnit &unit, - IRToNameMap &names) - : TransformVisitor(provenance, unit), +LocalDeclRenamer::LocalDeclRenamer(DecompilationContext &dec_ctx, + clang::ASTUnit &unit, IRToNameMap &names) + : TransformVisitor(dec_ctx, unit), seen_names(1), names(names) {} @@ -78,7 +78,7 @@ bool LocalDeclRenamer::TraverseFunctionDecl(clang::FunctionDecl *decl) { bool LocalDeclRenamer::shouldTraversePostOrder() { return false; } void LocalDeclRenamer::RunImpl() { - for (auto &pair : provenance.value_decls) { + for (auto &pair : dec_ctx.value_decls) { decls[pair.second] = pair.first; } TraverseDecl(ast_ctx.getTranslationUnitDecl()); diff --git a/lib/AST/LoopRefine.cpp b/lib/AST/LoopRefine.cpp index e39a596d..c3720fcd 100644 --- a/lib/AST/LoopRefine.cpp +++ b/lib/AST/LoopRefine.cpp @@ -47,7 +47,7 @@ class WhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -64,11 +64,10 @@ class WhileRule : public InferenceRule { std::copy(comp->body_begin() + 1, comp->body_end(), std::back_inserter(new_body)); ASTBuilder ast(unit); - auto new_while{ast.CreateWhile(provenance.marker_expr, - ast.CreateCompoundStmt(new_body))}; - provenance.conds[new_while] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back( - 
!provenance.z3_exprs[provenance.conds[ifstmt]]); + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(!dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]); return new_while; } }; @@ -90,7 +89,7 @@ class ElseWhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -102,9 +101,9 @@ class ElseWhileRule : public InferenceRule { auto ifstmt{clang::cast(comp->body_front())}; std::vector new_body; ASTBuilder ast(unit); - auto new_while{ast.CreateWhile(provenance.marker_expr, - ast.CreateCompoundStmt(new_body))}; - provenance.conds[new_while] = provenance.conds[ifstmt]; + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = dec_ctx.conds[ifstmt]; return new_while; } }; @@ -126,7 +125,7 @@ class DoWhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -136,19 +135,19 @@ class DoWhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_back())}; - auto cond{provenance.z3_exprs[provenance.conds[ifstmt]]}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; std::vector new_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); if (auto else_stmt = ifstmt->getElse()) { - auto new_if{ast.CreateIf(provenance.marker_expr, else_stmt)}; - provenance.conds[new_if] = provenance.z3_exprs.size(); + auto new_if{ast.CreateIf(dec_ctx.marker_expr, else_stmt)}; + dec_ctx.conds[new_if] = dec_ctx.z3_exprs.size(); new_body.push_back(new_if); } auto 
new_do{ - ast.CreateDo(provenance.marker_expr, ast.CreateCompoundStmt(new_body))}; - provenance.conds[new_do] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(!cond); + ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_do] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(!cond); return new_do; } }; @@ -170,7 +169,7 @@ class ElseDoWhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -188,8 +187,8 @@ class ElseDoWhileRule : public InferenceRule { new_body.push_back(ifstmt); auto new_do{ - ast.CreateDo(provenance.marker_expr, ast.CreateCompoundStmt(new_body))}; - provenance.conds[new_do] = provenance.conds[ifstmt]; + ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_do] = dec_ctx.conds[ifstmt]; return new_do; } }; @@ -219,7 +218,7 @@ class NestedDoWhileRule : public InferenceRule { matched = true; } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -228,26 +227,26 @@ class NestedDoWhileRule : public InferenceRule { << "Substituted WhileStmt is not the matched WhileStmt!"; auto comp{clang::cast(loop->getBody())}; auto if_stmt{clang::cast(comp->body_back())}; - auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[if_stmt]]}; std::vector do_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); if (auto else_stmt = if_stmt->getElse()) { - auto new_if{ast.CreateIf(provenance.marker_expr, else_stmt)}; - provenance.conds[new_if] = provenance.z3_exprs.size(); + auto new_if{ast.CreateIf(dec_ctx.marker_expr, else_stmt)}; + 
dec_ctx.conds[new_if] = dec_ctx.z3_exprs.size(); do_body.push_back(new_if); } auto do_stmt{ - ast.CreateDo(provenance.marker_expr, ast.CreateCompoundStmt(do_body))}; - provenance.conds[do_stmt] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(!cond); + ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(do_body))}; + dec_ctx.conds[do_stmt] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(!cond); std::vector while_body({do_stmt, if_stmt->getThen()}); - auto new_while{ast.CreateWhile(provenance.marker_expr, + auto new_while{ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(while_body))}; - provenance.conds[new_while] = provenance.conds[loop]; + dec_ctx.conds[new_while] = dec_ctx.conds[loop]; return new_while; } }; @@ -274,7 +273,7 @@ class LoopToSeq : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop = clang::dyn_cast(stmt); @@ -328,7 +327,7 @@ class CondToSeqRule : public InferenceRule { match = result.Nodes.getNodeAs("while"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -339,17 +338,17 @@ class CondToSeqRule : public InferenceRule { ASTBuilder ast(unit); auto body{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(body->body_front())}; - auto inner_loop{ast.CreateWhile(provenance.marker_expr, ifstmt->getThen())}; - provenance.conds[inner_loop] = provenance.conds[ifstmt]; + auto inner_loop{ast.CreateWhile(dec_ctx.marker_expr, ifstmt->getThen())}; + dec_ctx.conds[inner_loop] = dec_ctx.conds[ifstmt]; std::vector new_body({inner_loop}); if (auto comp = clang::dyn_cast(ifstmt->getElse())) { new_body.insert(new_body.end(), comp->body_begin(), comp->body_end()); } else { 
new_body.push_back(ifstmt->getElse()); } - auto new_while{ast.CreateWhile(provenance.marker_expr, - ast.CreateCompoundStmt(new_body))}; - provenance.conds[new_while] = provenance.conds[loop]; + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = dec_ctx.conds[loop]; return new_while; } }; @@ -367,7 +366,7 @@ class CondToSeqNegRule : public InferenceRule { match = result.Nodes.getNodeAs("while"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -378,10 +377,10 @@ class CondToSeqNegRule : public InferenceRule { ASTBuilder ast(unit); auto body{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(body->body_front())}; - auto cond{provenance.z3_exprs[provenance.conds[ifstmt]]}; - auto inner_loop{ast.CreateWhile(provenance.marker_expr, ifstmt->getElse())}; - provenance.conds[inner_loop] = provenance.z3_exprs.size(); - provenance.z3_exprs.push_back(!cond); + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; + auto inner_loop{ast.CreateWhile(dec_ctx.marker_expr, ifstmt->getElse())}; + dec_ctx.conds[inner_loop] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(!cond); std::vector new_body({inner_loop}); if (auto comp = clang::dyn_cast(ifstmt->getThen())) { new_body.insert(new_body.end(), comp->body_begin(), comp->body_end()); @@ -389,17 +388,17 @@ class CondToSeqNegRule : public InferenceRule { new_body.push_back(ifstmt->getThen()); } - auto new_while{ast.CreateWhile(provenance.marker_expr, - ast.CreateCompoundStmt(new_body))}; - provenance.conds[new_while] = provenance.conds[loop]; + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = dec_ctx.conds[loop]; return new_while; } }; } // namespace -LoopRefine::LoopRefine(Provenance &provenance, clang::ASTUnit &u) 
- : TransformVisitor(provenance, u) {} +LoopRefine::LoopRefine(DecompilationContext &dec_ctx, clang::ASTUnit &u) + : TransformVisitor(dec_ctx, u) {} bool LoopRefine::VisitWhileStmt(clang::WhileStmt *loop) { // DLOG(INFO) << "VisitWhileStmt"; @@ -429,7 +428,7 @@ bool LoopRefine::VisitWhileStmt(clang::WhileStmt *loop) { rules.emplace_back(new ElseWhileRule); rules.emplace_back(new ElseDoWhileRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, loop, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, loop, rules)}; if (sub != loop) { substitutions[loop] = sub; } diff --git a/lib/AST/MaterializeConds.cpp b/lib/AST/MaterializeConds.cpp index d70cf745..c41ff94c 100644 --- a/lib/AST/MaterializeConds.cpp +++ b/lib/AST/MaterializeConds.cpp @@ -15,29 +15,30 @@ namespace rellic { -MaterializeConds::MaterializeConds(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit), - ast_gen(unit, provenance) {} +MaterializeConds::MaterializeConds(DecompilationContext &dec_ctx, + clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit), + ast_gen(unit, dec_ctx) {} bool MaterializeConds::VisitIfStmt(clang::IfStmt *stmt) { - auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; - if (stmt->getCond() == provenance.marker_expr) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + if (stmt->getCond() == dec_ctx.marker_expr) { stmt->setCond(ast_gen.ConvertExpr(cond)); } return true; } bool MaterializeConds::VisitWhileStmt(clang::WhileStmt *stmt) { - auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; - if (stmt->getCond() == provenance.marker_expr) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + if (stmt->getCond() == dec_ctx.marker_expr) { stmt->setCond(ast_gen.ConvertExpr(cond)); } return true; } bool MaterializeConds::VisitDoStmt(clang::DoStmt *stmt) { - auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; - if (stmt->getCond() == provenance.marker_expr) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + if 
(stmt->getCond() == dec_ctx.marker_expr) { stmt->setCond(ast_gen.ConvertExpr(cond)); } return true; diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 37e0b0da..3b7d423d 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -105,14 +105,14 @@ struct KnownExprs { class CompoundVisitor : public clang::StmtVisitor { private: - Provenance& provenance; + DecompilationContext& dec_ctx; ASTBuilder& ast; clang::ASTContext& ctx; public: - CompoundVisitor(Provenance& provenance, ASTBuilder& ast, + CompoundVisitor(DecompilationContext& dec_ctx, ASTBuilder& ast, clang::ASTContext& ctx) - : provenance(provenance), ast(ast), ctx(ctx) {} + : dec_ctx(dec_ctx), ast(ast), ctx(ctx) {} bool VisitCompoundStmt(clang::CompoundStmt* compound, KnownExprs& known_exprs) { @@ -126,12 +126,12 @@ class CompoundVisitor } bool VisitWhileStmt(clang::WhileStmt* while_stmt, KnownExprs& known_exprs) { - auto cond_idx{provenance.conds[while_stmt]}; + auto cond_idx{dec_ctx.conds[while_stmt]}; bool changed{false}; - auto old_cond{provenance.z3_exprs[cond_idx]}; + auto old_cond{dec_ctx.z3_exprs[cond_idx]}; auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; - if (while_stmt->getCond() != provenance.marker_expr && changed) { - provenance.z3_exprs.set(cond_idx, new_cond); + if (while_stmt->getCond() != dec_ctx.marker_expr && changed) { + dec_ctx.z3_exprs.set(cond_idx, new_cond); return true; } @@ -146,12 +146,12 @@ class CompoundVisitor } bool VisitDoStmt(clang::DoStmt* do_stmt, KnownExprs& known_exprs) { - auto cond_idx{provenance.conds[do_stmt]}; + auto cond_idx{dec_ctx.conds[do_stmt]}; bool changed{false}; - auto old_cond{provenance.z3_exprs[cond_idx]}; + auto old_cond{dec_ctx.z3_exprs[cond_idx]}; auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; - if (do_stmt->getCond() == provenance.marker_expr && changed) { - provenance.z3_exprs.set(cond_idx, new_cond); + if (do_stmt->getCond() == dec_ctx.marker_expr && changed) { + 
dec_ctx.z3_exprs.set(cond_idx, new_cond); return true; } @@ -166,12 +166,12 @@ class CompoundVisitor } bool VisitIfStmt(clang::IfStmt* if_stmt, KnownExprs& known_exprs) { - auto cond_idx{provenance.conds[if_stmt]}; + auto cond_idx{dec_ctx.conds[if_stmt]}; bool changed{false}; - auto old_cond{provenance.z3_exprs[cond_idx]}; + auto old_cond{dec_ctx.z3_exprs[cond_idx]}; auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; - if (if_stmt->getCond() == provenance.marker_expr && changed) { - provenance.z3_exprs.set(cond_idx, new_cond); + if (if_stmt->getCond() == dec_ctx.marker_expr && changed) { + dec_ctx.z3_exprs.set(cond_idx, new_cond); return true; } @@ -192,14 +192,15 @@ class CompoundVisitor } }; -NestedCondProp::NestedCondProp(Provenance& provenance, clang::ASTUnit& unit) - : ASTPass(provenance, unit) {} +NestedCondProp::NestedCondProp(DecompilationContext& dec_ctx, + clang::ASTUnit& unit) + : ASTPass(dec_ctx, unit) {} void NestedCondProp::RunImpl() { LOG(INFO) << "Propagating conditions"; changed = false; ASTBuilder ast{ast_unit}; - CompoundVisitor visitor{provenance, ast, ast_ctx}; + CompoundVisitor visitor{dec_ctx, ast, ast_ctx}; for (auto decl : ast_ctx.getTranslationUnitDecl()->decls()) { if (auto fdecl = clang::dyn_cast(decl)) { diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index ca18ae6e..90fccd21 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -15,15 +15,15 @@ namespace rellic { -NestedScopeCombine::NestedScopeCombine(Provenance &provenance, +NestedScopeCombine::NestedScopeCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} + : TransformVisitor(dec_ctx, unit) {} bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { // DLOG(INFO) << "VisitIfStmt"; // Determine whether `cond` is a constant expression that is always true and // `ifstmt` should be replaced by `then` in it's parent nodes. 
- auto cond{provenance.z3_exprs[provenance.conds[ifstmt]]}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; if (Prove(cond)) { substitutions[ifstmt] = ifstmt->getThen(); } else if (Prove(!cond) && ifstmt->getElse()) { @@ -33,7 +33,7 @@ bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { } bool NestedScopeCombine::VisitWhileStmt(clang::WhileStmt *stmt) { - auto cond{provenance.z3_exprs[provenance.conds[stmt]]}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; if (Prove(cond)) { auto body{clang::cast(stmt->getBody())}; if (clang::isa(body->body_back())) { diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp index 533be292..824f4a97 100644 --- a/lib/AST/ReachBasedRefine.cpp +++ b/lib/AST/ReachBasedRefine.cpp @@ -15,13 +15,14 @@ namespace rellic { -ReachBasedRefine::ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} +ReachBasedRefine::ReachBasedRefine(DecompilationContext &dec_ctx, + clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit) {} bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector body{compound->body_begin(), compound->body_end()}; std::vector ifs; - z3::expr_vector conds{provenance.z3_ctx}; + z3::expr_vector conds{dec_ctx.z3_ctx}; auto ResetChain = [&]() { ifs.clear(); @@ -37,7 +38,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } ifs.push_back(if_stmt); - auto cond{provenance.z3_exprs[provenance.conds[if_stmt]]}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[if_stmt]]}; if (if_stmt->getElse()) { // We cannot link `if` statements that contain `else` branches diff --git a/lib/AST/StructFieldRenamer.cpp b/lib/AST/StructFieldRenamer.cpp index 6fad637c..3839322f 100644 --- a/lib/AST/StructFieldRenamer.cpp +++ b/lib/AST/StructFieldRenamer.cpp @@ -17,10 +17,10 @@ namespace rellic { -StructFieldRenamer::StructFieldRenamer(Provenance &provenance, +StructFieldRenamer::StructFieldRenamer(DecompilationContext 
&dec_ctx, clang::ASTUnit &unit, IRTypeToDITypeMap &types) - : ASTPass(provenance, unit), types(types) {} + : ASTPass(dec_ctx, unit), types(types) {} bool StructFieldRenamer::VisitRecordDecl(clang::RecordDecl *decl) { auto type{decls[decl]}; @@ -68,7 +68,7 @@ bool StructFieldRenamer::VisitRecordDecl(clang::RecordDecl *decl) { void StructFieldRenamer::RunImpl() { LOG(INFO) << "Renaming struct fields"; - for (auto &pair : provenance.type_decls) { + for (auto &pair : dec_ctx.type_decls) { decls[pair.second] = pair.first; } TraverseDecl(ast_ctx.getTranslationUnitDecl()); diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index ff7fe311..af0aef72 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -15,14 +15,15 @@ namespace rellic { -Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) - : ASTPass(provenance, unit) {} +Z3CondSimplify::Z3CondSimplify(DecompilationContext& dec_ctx, + clang::ASTUnit& unit) + : ASTPass(dec_ctx, unit) {} void Z3CondSimplify::RunImpl() { LOG(INFO) << "Simplifying conditions using Z3"; - for (size_t i{0}; i < provenance.z3_exprs.size() && !Stopped(); ++i) { - auto simpl{Sort(provenance.z3_exprs[i].simplify())}; - provenance.z3_exprs.set(i, simpl); + for (size_t i{0}; i < dec_ctx.z3_exprs.size() && !Stopped(); ++i) { + auto simpl{Sort(dec_ctx.z3_exprs[i].simplify())}; + dec_ctx.z3_exprs.set(i, simpl); } } diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 1297d3cf..5bae24a0 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -85,92 +85,90 @@ Result Decompile( auto ast_unit{clang::tooling::buildASTFromCodeWithArgs("", args, "out.c")}; ASTBuilder ast{*ast_unit}; - rellic::Provenance provenance; - provenance.marker_expr = - ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); - rellic::GenerateAST::run(*module, provenance, *ast_unit); + rellic::DecompilationContext dec_ctx; + dec_ctx.marker_expr = ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); + 
rellic::GenerateAST::run(*module, dec_ctx, *ast_unit); // TODO(surovic): Add llvm::Value* -> clang::Decl* map // Especially for llvm::Argument* and llvm::Function*. - rellic::CompositeASTPass pass_ast(provenance, *ast_unit); + rellic::CompositeASTPass pass_ast(dec_ctx, *ast_unit); auto& ast_passes{pass_ast.GetPasses()}; ast_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); ast_passes.push_back(std::make_unique( - provenance, *ast_unit, dic.GetIRToNameMap())); + dec_ctx, *ast_unit, dic.GetIRToNameMap())); ast_passes.push_back(std::make_unique( - provenance, *ast_unit, dic.GetIRTypeToDITypeMap())); + dec_ctx, *ast_unit, dic.GetIRTypeToDITypeMap())); pass_ast.Run(); - rellic::CompositeASTPass pass_cbr(provenance, *ast_unit); + rellic::CompositeASTPass pass_cbr(dec_ctx, *ast_unit); auto& cbr_passes{pass_cbr.GetPasses()}; cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); while (pass_cbr.Run()) { ; } - rellic::CompositeASTPass pass_loop{provenance, *ast_unit}; + rellic::CompositeASTPass pass_loop{dec_ctx, *ast_unit}; auto& loop_passes{pass_loop.GetPasses()}; loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); while (pass_loop.Run()) { ; } - rellic::CompositeASTPass 
pass_scope{provenance, *ast_unit}; + rellic::CompositeASTPass pass_scope{dec_ctx, *ast_unit}; auto& scope_passes{pass_scope.GetPasses()}; scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); while (pass_scope.Run()) { ; } - rellic::CompositeASTPass pass_ec{provenance, *ast_unit}; + rellic::CompositeASTPass pass_ec{dec_ctx, *ast_unit}; auto& ec_passes{pass_ec.GetPasses()}; ec_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); ec_passes.push_back( - std::make_unique(provenance, *ast_unit)); + std::make_unique(dec_ctx, *ast_unit)); pass_ec.Run(); DecompilationResult result{}; result.ast = std::move(ast_unit); result.module = std::move(module); - CopyMap(provenance.stmt_provenance, result.stmt_provenance_map, + CopyMap(dec_ctx.stmt_provenance, result.stmt_provenance_map, result.value_to_stmt_map); - CopyMap(provenance.value_decls, result.value_to_decl_map, + CopyMap(dec_ctx.value_decls, result.value_to_decl_map, result.decl_provenance_map); - CopyMap(provenance.type_decls, result.type_to_decl_map, + CopyMap(dec_ctx.type_decls, result.type_to_decl_map, result.type_provenance_map); - CopyMap(provenance.use_provenance, result.expr_use_map, - result.use_expr_map); + CopyMap(dec_ctx.use_provenance, result.expr_use_map, result.use_expr_map); return Result(std::move(result)); } catch (Exception& ex) { diff --git a/tools/repl/Repl.cpp b/tools/repl/Repl.cpp index facf8ac8..788de1f8 100644 --- a/tools/repl/Repl.cpp +++ b/tools/repl/Repl.cpp @@ -61,7 +61,7 @@ DECLARE_bool(version); llvm::LLVMContext llvm_ctx; std::unique_ptr module{nullptr}; std::unique_ptr ast_unit{nullptr}; -std::unique_ptr provenance; +std::unique_ptr dec_ctx; std::unique_ptr 
global_pass{nullptr}; static void SetVersion(void) { @@ -147,23 +147,23 @@ class Diff { static std::unique_ptr CreatePass(const std::string& name) { if (name == "cbr") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "dse") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "ec") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "lr") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "mc") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "ncp") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "nsc") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "rbr") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "zcs") { - return std::make_unique(*provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else { return nullptr; } @@ -218,7 +218,7 @@ static void do_load(std::istream& is) { std::vector args{"-Wno-pointer-to-int-cast", "-Wno-pointer-sign", "-target", module->getTargetTriple()}; ast_unit = clang::tooling::buildASTFromCodeWithArgs("", args, "out.c"); - provenance = {}; + dec_ctx = {}; std::cout << "ok." 
<< std::endl; } @@ -300,10 +300,10 @@ static void do_decompile() { try { rellic::DebugInfoCollector dic; dic.visit(*module); - provenance = {}; - rellic::GenerateAST::run(*module, *provenance, *ast_unit); - rellic::LocalDeclRenamer ldr{*provenance, *ast_unit, dic.GetIRToNameMap()}; - rellic::StructFieldRenamer sfr{*provenance, *ast_unit, + dec_ctx = {}; + rellic::GenerateAST::run(*module, *dec_ctx, *ast_unit); + rellic::LocalDeclRenamer ldr{*dec_ctx, *ast_unit, dic.GetIRToNameMap()}; + rellic::StructFieldRenamer sfr{*dec_ctx, *ast_unit, dic.GetIRTypeToDITypeMap()}; ldr.Run(); sfr.Run(); @@ -319,8 +319,7 @@ static void do_run(std::istream& is) { return; } - global_pass = - std::make_unique(*provenance, *ast_unit); + global_pass = std::make_unique(*dec_ctx, *ast_unit); std::string name; while (is >> name) { auto pass{CreatePass(name)}; @@ -357,8 +356,7 @@ static void do_fixpoint(std::istream& is) { return; } - global_pass = - std::make_unique(*provenance, *ast_unit); + global_pass = std::make_unique(*dec_ctx, *ast_unit); std::string name; while (is >> name) { auto pass{CreatePass(name)}; diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index c6411193..315de3ef 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -111,7 +111,7 @@ struct Session { std::unique_ptr Module; std::unique_ptr Unit; std::unique_ptr Pass; - std::unique_ptr Provenance; + std::unique_ptr DecompContext; // Must always be acquired in this order and released all at once std::shared_mutex LoadMutex, MutationMutex; }; @@ -267,21 +267,21 @@ static void Decompile(const httplib::Request& req, httplib::Response& res) { } try { - session.Provenance = std::make_unique(); + session.DecompContext = std::make_unique(); std::vector args{"-Wno-pointer-to-int-cast", "-Wno-pointer-sign", "-target", session.Module->getTargetTriple()}; session.Unit = clang::tooling::buildASTFromCodeWithArgs("", args, "out.c"); rellic::ASTBuilder ast{*session.Unit}; - session.Provenance->marker_expr = + 
session.DecompContext->marker_expr = ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); rellic::DebugInfoCollector dic; dic.visit(*session.Module); - rellic::GenerateAST::run(*session.Module, *session.Provenance, + rellic::GenerateAST::run(*session.Module, *session.DecompContext, *session.Unit); - rellic::LocalDeclRenamer ldr{*session.Provenance, *session.Unit, + rellic::LocalDeclRenamer ldr{*session.DecompContext, *session.Unit, dic.GetIRToNameMap()}; - rellic::StructFieldRenamer sfr{*session.Provenance, *session.Unit, + rellic::StructFieldRenamer sfr{*session.DecompContext, *session.Unit, dic.GetIRTypeToDITypeMap()}; ldr.Run(); sfr.Run(); @@ -382,8 +382,8 @@ class FixpointPass : public rellic::ASTPass { void RunImpl() override { comp.Fixpoint(); } public: - FixpointPass(rellic::Provenance& provenance, clang::ASTUnit& ast_unit) - : ASTPass(provenance, ast_unit), comp(provenance, ast_unit) {} + FixpointPass(rellic::DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit) + : ASTPass(dec_ctx, ast_unit), comp(dec_ctx, ast_unit) {} std::vector>& GetPasses() { return comp.GetPasses(); } @@ -400,31 +400,31 @@ static std::unique_ptr CreatePass( auto str{name->str()}; if (str == "cbr") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "dse") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "ec") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "lr") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "mc") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "ncp") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, 
*session.Unit); } else if (str == "nsc") { - return std::make_unique(*session.Provenance, - *session.Unit); + return std::make_unique( + *session.DecompContext, *session.Unit); } else if (str == "rbr") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "zcs") { - return std::make_unique(*session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else { LOG(ERROR) << "Request contains invalid pass id"; @@ -432,7 +432,7 @@ static std::unique_ptr CreatePass( } } else if (auto arr = val.getAsArray()) { auto fix{ - std::make_unique(*session.Provenance, *session.Unit)}; + std::make_unique(*session.DecompContext, *session.Unit)}; for (auto& pass : *arr) { auto p{CreatePass(session, pass)}; if (!p) { @@ -514,8 +514,8 @@ static void Run(const httplib::Request& req, httplib::Response& res) { return; } - auto composite{std::make_unique(*session.Provenance, - *session.Unit)}; + auto composite{std::make_unique( + *session.DecompContext, *session.Unit)}; for (auto& obj : *json->getAsArray()) { auto pass{CreatePass(session, obj)}; if (!pass) { @@ -584,8 +584,8 @@ static void Fixpoint(const httplib::Request& req, httplib::Response& res) { return; } - auto composite{std::make_unique(*session.Provenance, - *session.Unit)}; + auto composite{std::make_unique( + *session.DecompContext, *session.Unit)}; for (auto& obj : *json->getAsArray()) { auto pass{CreatePass(session, obj)}; if (!pass) { @@ -717,11 +717,11 @@ static void PrintAST(const httplib::Request& req, httplib::Response& res) { rellic::DecompilationResult::IRToTypeDeclMap type_to_decl_map; rellic::DecompilationResult::TypeDeclToIRMap type_provenance_map; - CopyMap(session.Provenance->stmt_provenance, stmt_provenance_map, + CopyMap(session.DecompContext->stmt_provenance, stmt_provenance_map, value_to_stmt_map); - CopyMap(session.Provenance->value_decls, value_to_decl_map, + CopyMap(session.DecompContext->value_decls, 
value_to_decl_map, decl_provenance_map); - CopyMap(session.Provenance->type_decls, type_to_decl_map, + CopyMap(session.DecompContext->type_decls, type_to_decl_map, type_provenance_map); std::string s; @@ -803,31 +803,31 @@ static void PrintProvenance(const httplib::Request& req, } llvm::json::Array stmt_provenance; - for (auto elem : session.Provenance->stmt_provenance) { + for (auto elem : session.DecompContext->stmt_provenance) { stmt_provenance.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array type_decls; - for (auto elem : session.Provenance->type_decls) { + for (auto elem : session.DecompContext->type_decls) { type_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array value_decls; - for (auto elem : session.Provenance->value_decls) { + for (auto elem : session.DecompContext->value_decls) { value_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array temp_decls; - for (auto elem : session.Provenance->temp_decls) { + for (auto elem : session.DecompContext->temp_decls) { temp_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array use_provenance; - for (auto elem : session.Provenance->use_provenance) { + for (auto elem : session.DecompContext->use_provenance) { if (!elem.second) { continue; } From 39edaa3bfc5260c85544468b91df8088baba3847 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 26 Aug 2022 11:27:02 +0200 Subject: [PATCH 34/57] Remove temporary variables for `load`s --- include/rellic/BC/Util.h | 2 -- lib/AST/IRToASTVisitor.cpp | 4 --- lib/BC/Util.cpp | 57 -------------------------------------- lib/Decompiler.cpp | 1 - tools/xref/Xref.cpp | 1 - 5 files changed, 65 deletions(-) diff --git a/include/rellic/BC/Util.h b/include/rellic/BC/Util.h index 2435395a..a08e5aa2 
100644 --- a/include/rellic/BC/Util.h +++ b/include/rellic/BC/Util.h @@ -67,6 +67,4 @@ void RemoveInsertValues(llvm::Module &module); // Converts by value array arguments and wraps them into a struct, so that // semantics are preserved in C void ConvertArrayArguments(llvm::Module &module); - -void FindRedundantLoads(llvm::Module &module); } // namespace rellic \ No newline at end of file diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index 1bc58662..19c872f7 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -1313,11 +1313,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { fdecl->addDecl(var); } else if (inst.hasNUsesOrMore(2) || (inst.hasNUsesOrMore(1) && llvm::isa(inst)) || - llvm::isa(inst) || llvm::isa(inst)) { - if (inst.getMetadata("rellic.notemp")) { - continue; - } if (!inst.getType()->isVoidTy()) { auto GetPrefix{[&](llvm::Instruction *inst) { if (llvm::isa(inst)) { diff --git a/lib/BC/Util.cpp b/lib/BC/Util.cpp index 76052e4c..63f7dd0f 100644 --- a/lib/BC/Util.cpp +++ b/lib/BC/Util.cpp @@ -437,61 +437,4 @@ void ConvertArrayArguments(llvm::Module &m) { CHECK(VerifyModule(&m)) << "Transformation broke module correctness"; } -static void FindRedundantLoads(llvm::Function &func) { - auto HasStoreBeforeUse = [&](llvm::Value *ptr, llvm::User *user, - llvm::LoadInst *load) { - std::unordered_set visited_blocks; - std::vector work_list; - auto inst{llvm::cast(user)}; - work_list.push_back(inst); - while (!work_list.empty()) { - inst = work_list.back(); - work_list.pop_back(); - auto bb{inst->getParent()}; - visited_blocks.insert(bb); - - while (inst) { - if (auto store = llvm::dyn_cast(inst)) { - if (store->getPointerOperand() == ptr) { - return true; - } - } - if (inst == load) { - return false; - } - inst = inst->getPrevNode(); - } - for (auto pred : llvm::predecessors(bb)) { - if (!visited_blocks.count(pred)) { - work_list.push_back(pred->getTerminator()); - } - } - } - - return false; - }; - - 
for (auto &inst : llvm::instructions(func)) { - if (auto load = llvm::dyn_cast(&inst)) { - for (auto &use : load->uses()) { - auto ptr{load->getPointerOperand()}; - if (!llvm::isa(ptr)) { - continue; - } - if (HasStoreBeforeUse(ptr, use.getUser(), load)) { - continue; - } - load->setMetadata("rellic.notemp", - llvm::MDNode::get(func.getContext(), {})); - } - } - } -} - -void FindRedundantLoads(llvm::Module &module) { - for (auto &func : module.functions()) { - FindRedundantLoads(func); - } -} - } // namespace rellic \ No newline at end of file diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 5bae24a0..001d8748 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -73,7 +73,6 @@ Result Decompile( ConvertArrayArguments(*module); RemoveInsertValues(*module); - FindRedundantLoads(*module); InitOptPasses(); rellic::DebugInfoCollector dic; diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index 315de3ef..c5200006 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -248,7 +248,6 @@ static void LoadModule(const httplib::Request& req, httplib::Response& res) { return; } session.Module = std::unique_ptr(mod); - rellic::FindRedundantLoads(*session.Module); llvm::json::Object msg{{"message", "Ok."}}; SendJSON(res, msg); res.status = 200; From fed7db1864c69cda3b8b33f156a668f3666e66a3 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 09:49:00 +0200 Subject: [PATCH 35/57] Rename `POISON_IDX` --- include/rellic/AST/GenerateAST.h | 3 ++- lib/AST/GenerateAST.cpp | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index e976292c..4584a03a 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -25,7 +26,7 @@ class GenerateAST : public llvm::AnalysisInfoMixin { friend llvm::AnalysisInfoMixin; static llvm::AnalysisKey Key; - constexpr 
static unsigned POISON_IDX = static_cast(-1); + constexpr static unsigned poison_idx = std::numeric_limits::max(); // Need to use `map` with these instead of `unordered_map`, because // `std::pair` doesn't have a default hash implementation diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index afe2c720..229800d0 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -140,7 +140,7 @@ static std::string GetName(llvm::Value *v) { unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond) { auto ToExpr = [&](unsigned idx) { - if (idx == POISON_IDX) { + if (idx == poison_idx) { return dec_ctx.z3_ctx.bool_val(false); } return dec_ctx.z3_exprs[idx]; @@ -184,7 +184,7 @@ unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c) { auto ToExpr = [&](unsigned idx) { - if (idx == POISON_IDX) { + if (idx == poison_idx) { return dec_ctx.z3_ctx.bool_val(false); } return dec_ctx.z3_exprs[idx]; @@ -215,7 +215,7 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to) { auto ToExpr = [&](unsigned idx) { - if (idx == POISON_IDX) { + if (idx == poison_idx) { return dec_ctx.z3_ctx.bool_val(false); } return dec_ctx.z3_exprs[idx]; @@ -277,7 +277,7 @@ unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, unsigned GenerateAST::GetReachingCond(llvm::BasicBlock *block) { if (dec_ctx.reaching_conds.find(block) == dec_ctx.reaching_conds.end()) { - return POISON_IDX; + return poison_idx; } return dec_ctx.reaching_conds[block]; @@ -285,7 +285,7 @@ unsigned GenerateAST::GetReachingCond(llvm::BasicBlock *block) { void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { auto ToExpr = [&](unsigned idx) { - if (idx == POISON_IDX) { + if (idx == poison_idx) { return dec_ctx.z3_ctx.bool_val(false); } return 
dec_ctx.z3_exprs[idx]; @@ -310,7 +310,7 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { } auto cond{HeavySimplify(z3::mk_or(conds))}; - if (old_cond_idx == POISON_IDX || !Prove(old_cond == cond)) { + if (old_cond_idx == poison_idx || !Prove(old_cond == cond)) { dec_ctx.reaching_conds[block] = dec_ctx.z3_exprs.size(); dec_ctx.z3_exprs.push_back(cond); reaching_conds_changed = true; @@ -332,7 +332,7 @@ StmtVec GenerateAST::CreateBasicBlockStmts(llvm::BasicBlock *block) { StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { auto ToExpr = [&](unsigned idx) { - if (idx == POISON_IDX) { + if (idx == poison_idx) { return dec_ctx.z3_ctx.bool_val(false); } return dec_ctx.z3_exprs[idx]; @@ -424,7 +424,7 @@ clang::CompoundStmt *GenerateAST::StructureAcyclicRegion(llvm::Region *region) { clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { auto ToExpr = [&](unsigned idx) { - if (idx == POISON_IDX) { + if (idx == poison_idx) { return dec_ctx.z3_ctx.bool_val(false); } return dec_ctx.z3_exprs[idx]; From 94b4dd534648f96b5e7335201e6da9d70b6b1891 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 09:53:28 +0200 Subject: [PATCH 36/57] Put all aliases inside `DecompilationContext` --- include/rellic/AST/Util.h | 29 +++++++++++++++-------------- lib/AST/Util.cpp | 7 ++++--- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index c5006dd0..4a5d3d46 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -51,20 +51,21 @@ size_t GetNumDecls(clang::DeclContext *decl_ctx) { return result; } -using StmtToIRMap = std::unordered_map; -using ExprToUseMap = std::unordered_map; -using IRToTypeDeclMap = std::unordered_map; -using IRToValDeclMap = std::unordered_map; -using IRToStmtMap = std::unordered_map; -using ArgToTempMap = std::unordered_map; -using BlockToUsesMap = - std::unordered_map>; -using Z3CondMap = std::unordered_map; 
- -using BBEdge = std::pair; -using BrEdge = std::pair; -using SwEdge = std::pair; struct DecompilationContext { + using StmtToIRMap = std::unordered_map; + using ExprToUseMap = std::unordered_map; + using IRToTypeDeclMap = std::unordered_map; + using IRToValDeclMap = std::unordered_map; + using IRToStmtMap = std::unordered_map; + using ArgToTempMap = std::unordered_map; + using BlockToUsesMap = + std::unordered_map>; + using Z3CondMap = std::unordered_map; + + using BBEdge = std::pair; + using BrEdge = std::pair; + using SwEdge = std::pair; + StmtToIRMap stmt_provenance; ExprToUseMap use_provenance; IRToTypeDeclMap type_decls; @@ -98,7 +99,7 @@ void CopyProvenance(TKey1 *from, TKey2 *to, } clang::Expr *Clone(clang::ASTUnit &unit, clang::Expr *stmt, - ExprToUseMap &provenance); + DecompilationContext::ExprToUseMap &provenance); std::string ClangThingToString(const clang::Stmt *stmt); diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index ad7fb6c3..0e8ae9f1 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -198,10 +198,11 @@ bool IsEquivalent(clang::Expr *a, clang::Expr *b) { class ExprCloner : public clang::StmtVisitor { ASTBuilder ast; clang::ASTContext &ctx; - ExprToUseMap &provenance; + DecompilationContext::ExprToUseMap &provenance; public: - ExprCloner(clang::ASTUnit &unit, ExprToUseMap &provenance) + ExprCloner(clang::ASTUnit &unit, + DecompilationContext::ExprToUseMap &provenance) : ast(unit), ctx(unit.getASTContext()), provenance(provenance) {} clang::Expr *VisitIntegerLiteral(clang::IntegerLiteral *expr) { @@ -313,7 +314,7 @@ class ExprCloner : public clang::StmtVisitor { }; clang::Expr *Clone(clang::ASTUnit &unit, clang::Expr *expr, - ExprToUseMap &provenance) { + DecompilationContext::ExprToUseMap &provenance) { ExprCloner cloner{unit, provenance}; return CHECK_NOTNULL(cloner.Visit(CHECK_NOTNULL(expr))); } From 923fb57ab18823b6ab8f0e72ff035b5ce117a3e8 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 09:56:35 +0200 
Subject: [PATCH 37/57] Rename `Sort` to `OrderById` --- include/rellic/AST/Util.h | 2 +- lib/AST/Util.cpp | 6 +++--- lib/AST/Z3CondSimplify.cpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 4a5d3d46..8f7280af 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -112,5 +112,5 @@ z3::expr_vector Clone(z3::expr_vector &vec); // Tries to keep each subformula sorted by its id so that they don't get // shuffled around by simplification -z3::expr Sort(z3::expr expr); +z3::expr OrderById(z3::expr expr); } // namespace rellic \ No newline at end of file diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index 0e8ae9f1..b53d0c7c 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -360,7 +360,7 @@ z3::expr_vector Clone(z3::expr_vector &vec) { return clone; } -z3::expr Sort(z3::expr expr) { +z3::expr OrderById(z3::expr expr) { if (expr.is_and() || expr.is_or()) { std::vector args_indices(expr.num_args(), 0); std::iota(args_indices.begin(), args_indices.end(), 0); @@ -370,7 +370,7 @@ z3::expr Sort(z3::expr expr) { }); z3::expr_vector new_args{expr.ctx()}; for (auto idx : args_indices) { - new_args.push_back(Sort(expr.arg(idx))); + new_args.push_back(OrderById(expr.arg(idx))); } if (expr.is_and()) { return z3::mk_and(new_args); @@ -380,7 +380,7 @@ z3::expr Sort(z3::expr expr) { } if (expr.is_not()) { - return !Sort(expr.arg(0)); + return !OrderById(expr.arg(0)); } return expr; diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index af0aef72..828df587 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -22,7 +22,7 @@ Z3CondSimplify::Z3CondSimplify(DecompilationContext& dec_ctx, void Z3CondSimplify::RunImpl() { LOG(INFO) << "Simplifying conditions using Z3"; for (size_t i{0}; i < dec_ctx.z3_exprs.size() && !Stopped(); ++i) { - auto simpl{Sort(dec_ctx.z3_exprs[i].simplify())}; + auto simpl{OrderById(dec_ctx.z3_exprs[i].simplify())}; 
dec_ctx.z3_exprs.set(i, simpl); } } From e8ece2933fd1328c29c3b301c9993e5d04de2052 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 09:57:21 +0200 Subject: [PATCH 38/57] Revert useless change --- lib/AST/CondBasedRefine.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index 977ec3c7..0d6c7062 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -94,8 +94,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } } if (did_something) { - auto new_compound{ast.CreateCompoundStmt(body)}; - substitutions[compound] = new_compound; + substitutions[compound] = ast.CreateCompoundStmt(body); } return !Stopped(); } From 23e9e75dc9e17de4c96868767bb7c204370f917e Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 10:10:43 +0200 Subject: [PATCH 39/57] Simplify exit --- lib/AST/CondBasedRefine.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index 0d6c7062..e6ac58e1 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -66,10 +66,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { body[i] = new_if; body.erase(std::next(body.begin(), i + 1)); did_something = true; - break; - } - - if (Prove(cond_a == !cond_b)) { + } else if (Prove(cond_a == !cond_b)) { if (else_b) { new_then_body.push_back(else_b); } From 886db0be23cafa15547fce1376144f1bce71da01 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 10:11:21 +0200 Subject: [PATCH 40/57] Short-circuit slow condition --- lib/AST/NestedScopeCombine.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index 90fccd21..338d023f 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -26,7 +26,7 @@ 
bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; if (Prove(cond)) { substitutions[ifstmt] = ifstmt->getThen(); - } else if (Prove(!cond) && ifstmt->getElse()) { + } else if (ifstmt->getElse() && Prove(!cond)) { substitutions[ifstmt] = ifstmt->getElse(); } return !Stopped(); From e79a3bf26d37f43e772662176d5c48865f887485 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 10:11:36 +0200 Subject: [PATCH 41/57] Add explanation --- lib/AST/NestedScopeCombine.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index 338d023f..b941e077 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -33,6 +33,8 @@ bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { } bool NestedScopeCombine::VisitWhileStmt(clang::WhileStmt *stmt) { + // Substitute while statements in the form `while(1) { sth; break; }` with + // just `{ sth; }` auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; if (Prove(cond)) { auto body{clang::cast(stmt->getBody())}; From 0752a670c2089963600ed902cbc9d0e4d37afcc0 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 10:21:05 +0200 Subject: [PATCH 42/57] Add comments for `KnownExprs` --- lib/AST/NestedCondProp.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 3b7d423d..f2bb2152 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -33,10 +33,18 @@ struct equal_to { } // namespace std namespace rellic { +// Stores a set of expression that have a known value, so that they can be +// recognized as part of larger expressions and simplified. 
struct KnownExprs { std::unordered_map values; void AddExpr(z3::expr expr, bool value) { + // When adding expressions to the set of known values, it's important that + // they are added in their smallest possible form. E.g., if it's known that + // `A && B` is true, then both of its subexpressions are true, and we should + // add those instead. + // This is so that `A && B && C` can be used to simplify smaller + // expressions, like `A && B`, which would otherwise not be recognized. switch (expr.decl().decl_kind()) { case Z3_OP_TRUE: case Z3_OP_FALSE: @@ -45,16 +53,19 @@ struct KnownExprs { break; } + // If !A has value V, then A has value !V, so add that instead. if (expr.is_not()) { AddExpr(expr.arg(0), !value); return; } + // A unary && or ||, just add the single subexpression if (expr.num_args() == 1) { AddExpr(expr.arg(0), value); return; } + // A true && expression means all of its subexpressions are true if (value && expr.is_and()) { for (auto e : expr.args()) { AddExpr(e, true); @@ -62,6 +73,7 @@ struct KnownExprs { return; } + // A false || expression means all of its subexpressions are false if (!value && expr.is_or()) { for (auto e : expr.args()) { AddExpr(e, false); @@ -72,6 +84,8 @@ struct KnownExprs { values[expr] = value; } + // Simplify an expression `expr` using all the known values stored. Sets + // `found` to true is any simplification has been applied. 
z3::expr ApplyAssumptions(z3::expr expr, bool& found) { if (values.empty()) { return expr; From 51e5397b8f5453d2a99fcbee6c4558e2bb480acf Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 10:37:33 +0200 Subject: [PATCH 43/57] Map expression directly by id --- lib/AST/NestedCondProp.cpp | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index f2bb2152..c6decf7c 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -18,25 +18,11 @@ #include "rellic/AST/ASTBuilder.h" #include "rellic/AST/Util.h" -namespace std { -template <> -struct hash { - size_t operator()(const z3::expr& e) const { return e.id(); } -}; - -template <> -struct equal_to { - bool operator()(const z3::expr& a, const z3::expr& b) const { - return a.id() == b.id(); - } -}; -} // namespace std - namespace rellic { // Stores a set of expression that have a known value, so that they can be // recognized as part of larger expressions and simplified. struct KnownExprs { - std::unordered_map values; + std::unordered_map values; void AddExpr(z3::expr expr, bool value) { // When adding expressions to the set of known values, it's important that @@ -81,7 +67,7 @@ struct KnownExprs { return; } - values[expr] = value; + values[expr.id()] = value; } // Simplify an expression `expr` using all the known values stored. 
Sets @@ -91,9 +77,9 @@ struct KnownExprs { return expr; } - if (values.find(expr) != values.end()) { + if (values.find(expr.id()) != values.end()) { found = true; - return expr.ctx().bool_val(values[expr]); + return expr.ctx().bool_val(values[expr.id()]); } if (expr.is_and() || expr.is_or()) { From ee48c73076422b9612a387438bd00c9f012c4024 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 10:37:50 +0200 Subject: [PATCH 44/57] Add small comment --- lib/AST/GenerateAST.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index 229800d0..87fd7aa6 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -169,6 +169,9 @@ unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, } unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { + // To aide simplification, switch instructions actually produce numerical + // variables instead of boolean ones, but are always compared against a + // constant value. 
if (dec_ctx.z3_sw_vars.find(inst) == dec_ctx.z3_sw_vars.end()) { auto name{GetName(inst)}; auto var{dec_ctx.z3_ctx.int_const(name.c_str())}; From 1a438e4f7e70daf3b9ed7f6f237cacd6f3b0e784 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 11:35:17 +0200 Subject: [PATCH 45/57] Factor out CBR --- lib/AST/CondBasedRefine.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index e6ac58e1..37c9ad11 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -41,11 +41,12 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { auto else_b{if_b->getElse()}; std::vector new_then_body{then_a}; + clang::IfStmt *new_if{nullptr}; if (Prove(cond_a == cond_b)) { new_then_body.push_back(then_b); auto new_then{ast.CreateCompoundStmt(new_then_body)}; - auto new_if{ast.CreateIf(dec_ctx.marker_expr, new_then)}; + new_if = ast.CreateIf(dec_ctx.marker_expr, new_then); if (else_a || else_b) { std::vector new_else_body{}; @@ -62,9 +63,6 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { new_if->setElse(new_else); } - dec_ctx.conds[new_if] = dec_ctx.conds[if_a]; - body[i] = new_if; - body.erase(std::next(body.begin(), i + 1)); did_something = true; } else if (Prove(cond_a == !cond_b)) { if (else_b) { @@ -79,15 +77,18 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } new_else_body.push_back(then_b); - auto new_if{ast.CreateIf(dec_ctx.marker_expr, new_then)}; + new_if = ast.CreateIf(dec_ctx.marker_expr, new_then); auto new_else{ast.CreateCompoundStmt(new_else_body)}; new_if->setElse(new_else); + did_something = true; + } + + if (did_something) { dec_ctx.conds[new_if] = dec_ctx.conds[if_a]; body[i] = new_if; body.erase(std::next(body.begin(), i + 1)); - did_something = true; } } if (did_something) { From 4560a7a6c3b8a90870d4d57dc796a13335a1bb86 Mon Sep 17 00:00:00 2001 From: Francesco 
Bertolaccini Date: Wed, 31 Aug 2022 11:35:26 +0200 Subject: [PATCH 46/57] Rename parameter for clarity --- lib/AST/NestedCondProp.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index c6decf7c..7e99c6de 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -71,21 +71,21 @@ struct KnownExprs { } // Simplify an expression `expr` using all the known values stored. Sets - // `found` to true is any simplification has been applied. - z3::expr ApplyAssumptions(z3::expr expr, bool& found) { + // `changed` to true if any simplification has been applied. + z3::expr ApplyAssumptions(z3::expr expr, bool& changed) { if (values.empty()) { return expr; } if (values.find(expr.id()) != values.end()) { - found = true; + changed = true; return expr.ctx().bool_val(values[expr.id()]); } if (expr.is_and() || expr.is_or()) { z3::expr_vector args{expr.ctx()}; for (auto arg : expr.args()) { - args.push_back(ApplyAssumptions(arg, found)); + args.push_back(ApplyAssumptions(arg, changed)); } if (expr.is_and()) { return z3::mk_and(args); @@ -95,7 +95,7 @@ struct KnownExprs { } if (expr.is_not()) { - return !ApplyAssumptions(expr.arg(0), found); + return !ApplyAssumptions(expr.arg(0), changed); } return expr; From 70be30288bbbe7c34d4fbe7027c0d6ad038a80fc Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 12:44:09 +0200 Subject: [PATCH 47/57] Factor out `ToExpr` --- include/rellic/AST/GenerateAST.h | 1 + lib/AST/GenerateAST.cpp | 37 ++++++-------------------------- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index 4584a03a..488b8a4c 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -27,6 +27,7 @@ class GenerateAST : public llvm::AnalysisInfoMixin { static llvm::AnalysisKey Key; constexpr static unsigned poison_idx =
std::numeric_limits::max(); + z3::expr ToExpr(unsigned idx); // Need to use `map` with these instead of `unordered_map`, because // `std::pair` doesn't have a default hash implementation diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index 87fd7aa6..ddc8cf37 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -137,14 +137,15 @@ static std::string GetName(llvm::Value *v) { return s; } +z3::expr GenerateAST::ToExpr(unsigned idx) { + if (idx == poison_idx) { + return dec_ctx.z3_ctx.bool_val(false); + } + return dec_ctx.z3_exprs[idx]; +} + unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond) { - auto ToExpr = [&](unsigned idx) { - if (idx == poison_idx) { - return dec_ctx.z3_ctx.bool_val(false); - } - return dec_ctx.z3_exprs[idx]; - }; if (dec_ctx.z3_br_edges.find({inst, cond}) == dec_ctx.z3_br_edges.end()) { if (auto constant = llvm::dyn_cast(inst->getCondition())) { @@ -186,12 +187,6 @@ unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c) { - auto ToExpr = [&](unsigned idx) { - if (idx == poison_idx) { - return dec_ctx.z3_ctx.bool_val(false); - } - return dec_ctx.z3_exprs[idx]; - }; if (dec_ctx.z3_sw_edges.find({inst, c}) == dec_ctx.z3_sw_edges.end()) { if (c) { auto sw_case{inst->findCaseValue(c)}; @@ -217,12 +212,6 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to) { - auto ToExpr = [&](unsigned idx) { - if (idx == poison_idx) { - return dec_ctx.z3_ctx.bool_val(false); - } - return dec_ctx.z3_exprs[idx]; - }; if (dec_ctx.z3_edges.find({from, to}) == dec_ctx.z3_edges.end()) { // Construct the edge condition for CFG edge `(from, to)` auto result{dec_ctx.z3_ctx.bool_val(true)}; @@ -334,12 +323,6 @@ StmtVec GenerateAST::CreateBasicBlockStmts(llvm::BasicBlock *block) { } StmtVec 
GenerateAST::CreateRegionStmts(llvm::Region *region) { - auto ToExpr = [&](unsigned idx) { - if (idx == poison_idx) { - return dec_ctx.z3_ctx.bool_val(false); - } - return dec_ctx.z3_exprs[idx]; - }; StmtVec result; for (auto block : rpo_walk) { // Check if the block is a subregion entry @@ -426,12 +409,6 @@ clang::CompoundStmt *GenerateAST::StructureAcyclicRegion(llvm::Region *region) { } clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { - auto ToExpr = [&](unsigned idx) { - if (idx == poison_idx) { - return dec_ctx.z3_ctx.bool_val(false); - } - return dec_ctx.z3_exprs[idx]; - }; DLOG(INFO) << "Region " << GetRegionNameStr(region) << " is cyclic"; auto region_body = CreateRegionStmts(region); // Get the loop for which the entry block of the region is a header From a6e4e56f804c3956e4ca1cead5430ae6aed21a4f Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 12:47:11 +0200 Subject: [PATCH 48/57] Early returns --- lib/AST/GenerateAST.cpp | 57 ++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index ddc8cf37..b57cb103 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -173,41 +173,44 @@ unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { // To aide simplification, switch instructions actually produce numerical // variables instead of boolean ones, but are always compared against a // constant value. 
- if (dec_ctx.z3_sw_vars.find(inst) == dec_ctx.z3_sw_vars.end()) { - auto name{GetName(inst)}; - auto var{dec_ctx.z3_ctx.int_const(name.c_str())}; - dec_ctx.z3_sw_vars[inst] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(var); - dec_ctx.z3_sw_vars_inv[var.id()] = inst; - return dec_ctx.z3_sw_vars[inst]; - } else { + if (dec_ctx.z3_sw_vars.find(inst) != dec_ctx.z3_sw_vars.end()) { return dec_ctx.z3_sw_vars[inst]; } + + auto name{GetName(inst)}; + auto var{dec_ctx.z3_ctx.int_const(name.c_str())}; + dec_ctx.z3_sw_vars[inst] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(var); + dec_ctx.z3_sw_vars_inv[var.id()] = inst; + return dec_ctx.z3_sw_vars[inst]; } unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c) { - if (dec_ctx.z3_sw_edges.find({inst, c}) == dec_ctx.z3_sw_edges.end()) { - if (c) { - auto sw_case{inst->findCaseValue(c)}; - auto var{ToExpr(GetOrCreateVarForSwitch(inst))}; - auto expr{var == dec_ctx.z3_ctx.int_val(sw_case->getCaseIndex())}; - - dec_ctx.z3_sw_edges[{inst, c}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(expr); - } else { - // Default case - z3::expr_vector vec{dec_ctx.z3_ctx}; - for (auto sw_case : inst->cases()) { - vec.push_back( - !ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()))); - } - dec_ctx.z3_sw_edges[{inst, c}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(z3::mk_and(vec)); - } + if (dec_ctx.z3_sw_edges.find({inst, c}) != dec_ctx.z3_sw_edges.end()) { + return dec_ctx.z3_sw_edges[{inst, c}]; } - return dec_ctx.z3_sw_edges[{inst, c}]; + unsigned idx; + if (c) { + auto sw_case{inst->findCaseValue(c)}; + auto var{ToExpr(GetOrCreateVarForSwitch(inst))}; + auto expr{var == dec_ctx.z3_ctx.int_val(sw_case->getCaseIndex())}; + + idx = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(expr); + } else { + // Default case + z3::expr_vector vec{dec_ctx.z3_ctx}; + for (auto sw_case : inst->cases()) { + vec.push_back( + 
!ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()))); + } + idx = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(z3::mk_and(vec)); + } + dec_ctx.z3_sw_edges[{inst, c}] = idx; + return idx; } unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, From 599fa4653c553c7959a9cfc3b64309b9a9a588e8 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 14:14:37 +0200 Subject: [PATCH 49/57] More cleanup --- lib/AST/GenerateAST.cpp | 104 +++++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 55 deletions(-) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index b57cb103..2efa0136 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -215,58 +215,59 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to) { - if (dec_ctx.z3_edges.find({from, to}) == dec_ctx.z3_edges.end()) { - // Construct the edge condition for CFG edge `(from, to)` - auto result{dec_ctx.z3_ctx.bool_val(true)}; - auto term = from->getTerminator(); - switch (term->getOpcode()) { - // Conditional branches - case llvm::Instruction::Br: { - auto br = llvm::cast(term); - if (br->isConditional()) { - result = - ToExpr(GetOrCreateEdgeForBranch(br, to == br->getSuccessor(0))); - } - } break; - // Switches - case llvm::Instruction::Switch: { - auto sw{llvm::cast(term)}; - if (to == sw->getDefaultDest()) { - result = ToExpr(GetOrCreateEdgeForSwitch(sw, nullptr)); - } else { - z3::expr_vector or_vec{dec_ctx.z3_ctx}; - for (auto sw_case : sw->cases()) { - if (sw_case.getCaseSuccessor() == to) { - or_vec.push_back( - ToExpr(GetOrCreateEdgeForSwitch(sw, sw_case.getCaseValue()))); - } + if (dec_ctx.z3_edges.find({from, to}) != dec_ctx.z3_edges.end()) { + return dec_ctx.z3_edges[{from, to}]; + } + + // Construct the edge condition for CFG edge `(from, to)` + auto result{dec_ctx.z3_ctx.bool_val(true)}; + auto term = 
from->getTerminator(); + switch (term->getOpcode()) { + // Conditional branches + case llvm::Instruction::Br: { + auto br = llvm::cast(term); + if (br->isConditional()) { + result = + ToExpr(GetOrCreateEdgeForBranch(br, to == br->getSuccessor(0))); + } + } break; + // Switches + case llvm::Instruction::Switch: { + auto sw{llvm::cast(term)}; + if (to == sw->getDefaultDest()) { + result = ToExpr(GetOrCreateEdgeForSwitch(sw, nullptr)); + } else { + z3::expr_vector or_vec{dec_ctx.z3_ctx}; + for (auto sw_case : sw->cases()) { + if (sw_case.getCaseSuccessor() == to) { + or_vec.push_back( + ToExpr(GetOrCreateEdgeForSwitch(sw, sw_case.getCaseValue()))); } - result = HeavySimplify(z3::mk_or(or_vec)); } - } break; - // Returns - case llvm::Instruction::Ret: - break; - // Exceptions - case llvm::Instruction::Invoke: - case llvm::Instruction::Resume: - case llvm::Instruction::CatchSwitch: - case llvm::Instruction::CatchRet: - case llvm::Instruction::CleanupRet: - THROW() << "Exception terminator '" << term->getOpcodeName() - << "' is not supported yet"; - break; - // Unknown - default: - THROW() << "Unsupported terminator instruction: " - << term->getOpcodeName(); - break; - } - - dec_ctx.z3_edges[{from, to}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(result.simplify()); + result = HeavySimplify(z3::mk_or(or_vec)); + } + } break; + // Returns + case llvm::Instruction::Ret: + break; + // Exceptions + case llvm::Instruction::Invoke: + case llvm::Instruction::Resume: + case llvm::Instruction::CatchSwitch: + case llvm::Instruction::CatchRet: + case llvm::Instruction::CleanupRet: + THROW() << "Exception terminator '" << term->getOpcodeName() + << "' is not supported yet"; + break; + // Unknown + default: + THROW() << "Unsupported terminator instruction: " + << term->getOpcodeName(); + break; } + dec_ctx.z3_edges[{from, to}] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(result.simplify()); return dec_ctx.z3_edges[{from, to}]; } @@ -279,13 +280,6 @@ unsigned 
GenerateAST::GetReachingCond(llvm::BasicBlock *block) { } void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { - auto ToExpr = [&](unsigned idx) { - if (idx == poison_idx) { - return dec_ctx.z3_ctx.bool_val(false); - } - return dec_ctx.z3_exprs[idx]; - }; - auto old_cond_idx{GetReachingCond(block)}; auto old_cond{ToExpr(old_cond_idx)}; if (block->hasNPredecessorsOrMore(1)) { From 88fa360da0e7ff2a228a8822657aa8b2a06f1329 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 15:08:03 +0200 Subject: [PATCH 50/57] More cleanup --- lib/AST/GenerateAST.cpp | 48 ++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index 2efa0136..fc9b0630 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -146,24 +146,25 @@ z3::expr GenerateAST::ToExpr(unsigned idx) { unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond) { - if (dec_ctx.z3_br_edges.find({inst, cond}) == dec_ctx.z3_br_edges.end()) { - if (auto constant = - llvm::dyn_cast(inst->getCondition())) { - dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); - auto edge{dec_ctx.z3_ctx.bool_val(constant->isOne() == cond)}; - dec_ctx.z3_exprs.push_back(edge); - dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; - } else if (cond) { - auto name{GetName(inst)}; - auto edge{dec_ctx.z3_ctx.bool_const(name.c_str())}; - dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(edge); - dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; - } else { - auto edge{!(ToExpr(GetOrCreateEdgeForBranch(inst, true)))}; - dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(edge); - } + if (dec_ctx.z3_br_edges.find({inst, cond}) != dec_ctx.z3_br_edges.end()) { + return dec_ctx.z3_br_edges[{inst, cond}]; + } + + if (auto constant = llvm::dyn_cast(inst->getCondition())) { + 
dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); + auto edge{dec_ctx.z3_ctx.bool_val(constant->isOne() == cond)}; + dec_ctx.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; + } else if (cond) { + auto name{GetName(inst)}; + auto edge{dec_ctx.z3_ctx.bool_const(name.c_str())}; + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; + } else { + auto edge{!(ToExpr(GetOrCreateEdgeForBranch(inst, true)))}; + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(edge); } return dec_ctx.z3_br_edges[{inst, cond}]; @@ -304,12 +305,11 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { dec_ctx.z3_exprs.push_back(cond); reaching_conds_changed = true; } - } else { - if (dec_ctx.reaching_conds.find(block) == dec_ctx.reaching_conds.end()) { - dec_ctx.reaching_conds[block] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(dec_ctx.z3_ctx.bool_val(true)); - reaching_conds_changed = true; - } + } else if (dec_ctx.reaching_conds.find(block) == + dec_ctx.reaching_conds.end()) { + dec_ctx.reaching_conds[block] = dec_ctx.z3_exprs.size(); + dec_ctx.z3_exprs.push_back(dec_ctx.z3_ctx.bool_val(true)); + reaching_conds_changed = true; } } From 0bd968cdb32e61d0f02a4904a9f7290acaf346f9 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 15:16:31 +0200 Subject: [PATCH 51/57] Factor out stuff in NCP --- lib/AST/NestedCondProp.cpp | 59 ++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp index 7e99c6de..7df48f96 100755 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -109,6 +109,29 @@ class CompoundVisitor ASTBuilder& ast; clang::ASTContext& ctx; + template + bool VisitLoop(T* loop, KnownExprs& known_exprs) { + auto cond_idx{dec_ctx.conds[loop]}; + bool 
changed{false}; + auto old_cond{dec_ctx.z3_exprs[cond_idx]}; + auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; + if (loop->getCond() != dec_ctx.marker_expr && changed) { + dec_ctx.z3_exprs.set(cond_idx, new_cond); + return true; + } + + auto inner{known_exprs}; + if constexpr (cond_is_true_in_body) { + inner.AddExpr(new_cond, true); + } + known_exprs.AddExpr(new_cond, false); + + if (Visit(loop->getBody(), inner)) { + return true; + } + return false; + } + public: CompoundVisitor(DecompilationContext& dec_ctx, ASTBuilder& ast, clang::ASTContext& ctx) @@ -126,43 +149,11 @@ class CompoundVisitor } bool VisitWhileStmt(clang::WhileStmt* while_stmt, KnownExprs& known_exprs) { - auto cond_idx{dec_ctx.conds[while_stmt]}; - bool changed{false}; - auto old_cond{dec_ctx.z3_exprs[cond_idx]}; - auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; - if (while_stmt->getCond() != dec_ctx.marker_expr && changed) { - dec_ctx.z3_exprs.set(cond_idx, new_cond); - return true; - } - - auto inner{known_exprs}; - inner.AddExpr(new_cond, true); - known_exprs.AddExpr(new_cond, false); - - if (Visit(while_stmt->getBody(), inner)) { - return true; - } - return false; + return VisitLoop(while_stmt, known_exprs); } bool VisitDoStmt(clang::DoStmt* do_stmt, KnownExprs& known_exprs) { - auto cond_idx{dec_ctx.conds[do_stmt]}; - bool changed{false}; - auto old_cond{dec_ctx.z3_exprs[cond_idx]}; - auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; - if (do_stmt->getCond() == dec_ctx.marker_expr && changed) { - dec_ctx.z3_exprs.set(cond_idx, new_cond); - return true; - } - - auto inner{known_exprs}; - known_exprs.AddExpr(new_cond, false); - - if (Visit(do_stmt->getBody(), inner)) { - return true; - } - - return false; + return VisitLoop(do_stmt, known_exprs); } bool VisitIfStmt(clang::IfStmt* if_stmt, KnownExprs& known_exprs) { From af24ad2553177cd988e9ef3dd81a28cf81230602 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 31 Aug 2022 15:44:56 
+0200 Subject: [PATCH 52/57] Create wrapper method for Z3 expressions --- include/rellic/AST/Util.h | 3 +++ lib/AST/GenerateAST.cpp | 35 +++++++++++++---------------------- lib/AST/LoopRefine.cpp | 13 +++++-------- lib/AST/Util.cpp | 6 ++++++ 4 files changed, 27 insertions(+), 30 deletions(-) diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 8f7280af..004699c4 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -90,6 +90,9 @@ struct DecompilationContext { size_t num_literal_structs = 0; size_t num_declared_structs = 0; + + // Inserts and expression into z3_exprs and returns its index + unsigned InsertZExpr(const z3::expr &e); }; template diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index fc9b0630..93fe344c 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -151,20 +151,17 @@ unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, } if (auto constant = llvm::dyn_cast(inst->getCondition())) { - dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); auto edge{dec_ctx.z3_ctx.bool_val(constant->isOne() == cond)}; - dec_ctx.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; } else if (cond) { auto name{GetName(inst)}; auto edge{dec_ctx.z3_ctx.bool_const(name.c_str())}; - dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; } else { auto edge{!(ToExpr(GetOrCreateEdgeForBranch(inst, true)))}; - dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(edge); + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); } return dec_ctx.z3_br_edges[{inst, cond}]; @@ -180,8 +177,7 @@ unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { auto name{GetName(inst)}; auto 
var{dec_ctx.z3_ctx.int_const(name.c_str())}; - dec_ctx.z3_sw_vars[inst] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(var); + dec_ctx.z3_sw_vars[inst] = dec_ctx.InsertZExpr(var); dec_ctx.z3_sw_vars_inv[var.id()] = inst; return dec_ctx.z3_sw_vars[inst]; } @@ -198,8 +194,7 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, auto var{ToExpr(GetOrCreateVarForSwitch(inst))}; auto expr{var == dec_ctx.z3_ctx.int_val(sw_case->getCaseIndex())}; - idx = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(expr); + idx = dec_ctx.InsertZExpr(expr); } else { // Default case z3::expr_vector vec{dec_ctx.z3_ctx}; @@ -207,8 +202,7 @@ unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, vec.push_back( !ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()))); } - idx = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(z3::mk_and(vec)); + idx = dec_ctx.InsertZExpr(z3::mk_and(vec)); } dec_ctx.z3_sw_edges[{inst, c}] = idx; return idx; @@ -267,8 +261,7 @@ unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, break; } - dec_ctx.z3_edges[{from, to}] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(result.simplify()); + dec_ctx.z3_edges[{from, to}] = dec_ctx.InsertZExpr(result.simplify()); return dec_ctx.z3_edges[{from, to}]; } @@ -301,14 +294,13 @@ void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { auto cond{HeavySimplify(z3::mk_or(conds))}; if (old_cond_idx == poison_idx || !Prove(old_cond == cond)) { - dec_ctx.reaching_conds[block] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(cond); + dec_ctx.reaching_conds[block] = dec_ctx.InsertZExpr(cond); reaching_conds_changed = true; } } else if (dec_ctx.reaching_conds.find(block) == dec_ctx.reaching_conds.end()) { - dec_ctx.reaching_conds[block] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(dec_ctx.z3_ctx.bool_val(true)); + dec_ctx.reaching_conds[block] = + dec_ctx.InsertZExpr(dec_ctx.z3_ctx.bool_val(true)); reaching_conds_changed = 
true; } } @@ -343,8 +335,8 @@ StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { // Gate the compound behind a reaching condition auto z_expr{GetReachingCond(block)}; block_stmts[block] = ast.CreateIf(dec_ctx.marker_expr, compound); - dec_ctx.conds[block_stmts[block]] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(dec_ctx.z3_exprs[z_expr]); + dec_ctx.conds[block_stmts[block]] = + dec_ctx.InsertZExpr(dec_ctx.z3_exprs[z_expr]); // Store the compound result.push_back(block_stmts[block]); } @@ -452,9 +444,8 @@ clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { StmtVec break_stmt({ast.CreateBreak()}); auto exit_stmt = ast.CreateIf(dec_ctx.marker_expr, ast.CreateCompoundStmt(break_stmt)); - dec_ctx.conds[exit_stmt] = dec_ctx.z3_exprs.size(); // Create edge condition - dec_ctx.z3_exprs.push_back( + dec_ctx.conds[exit_stmt] = dec_ctx.InsertZExpr( (ToExpr(GetReachingCond(from)) && ToExpr(GetOrCreateEdgeCond(from, to))) .simplify()); // Insert it after the exiting block statement diff --git a/lib/AST/LoopRefine.cpp b/lib/AST/LoopRefine.cpp index c3720fcd..9710ce8d 100644 --- a/lib/AST/LoopRefine.cpp +++ b/lib/AST/LoopRefine.cpp @@ -66,8 +66,8 @@ class WhileRule : public InferenceRule { ASTBuilder ast(unit); auto new_while{ ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; - dec_ctx.conds[new_while] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(!dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]); + dec_ctx.conds[new_while] = + dec_ctx.InsertZExpr(!dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]); return new_while; } }; @@ -146,8 +146,7 @@ class DoWhileRule : public InferenceRule { } auto new_do{ ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; - dec_ctx.conds[new_do] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(!cond); + dec_ctx.conds[new_do] = dec_ctx.InsertZExpr(!cond); return new_do; } }; @@ -240,8 +239,7 @@ class NestedDoWhileRule : public InferenceRule { auto do_stmt{ 
ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(do_body))}; - dec_ctx.conds[do_stmt] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(!cond); + dec_ctx.conds[do_stmt] = dec_ctx.InsertZExpr(!cond); std::vector while_body({do_stmt, if_stmt->getThen()}); auto new_while{ast.CreateWhile(dec_ctx.marker_expr, @@ -379,8 +377,7 @@ class CondToSeqNegRule : public InferenceRule { auto ifstmt{clang::cast(body->body_front())}; auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; auto inner_loop{ast.CreateWhile(dec_ctx.marker_expr, ifstmt->getElse())}; - dec_ctx.conds[inner_loop] = dec_ctx.z3_exprs.size(); - dec_ctx.z3_exprs.push_back(!cond); + dec_ctx.conds[inner_loop] = dec_ctx.InsertZExpr(!cond); std::vector new_body({inner_loop}); if (auto comp = clang::dyn_cast(ifstmt->getThen())) { new_body.insert(new_body.end(), comp->body_begin(), comp->body_end()); diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index b53d0c7c..e7ae9682 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -385,4 +385,10 @@ z3::expr OrderById(z3::expr expr) { return expr; } + +unsigned DecompilationContext::InsertZExpr(const z3::expr &e) { + auto idx{z3_exprs.size()}; + z3_exprs.push_back(e); + return idx; +} } // namespace rellic From c65fe31a2954a848287f25e958c7ad32e24ea315 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 1 Sep 2022 09:38:19 +0200 Subject: [PATCH 53/57] Add comments for maps --- include/rellic/AST/Util.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 004699c4..c2afd11e 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -79,6 +79,10 @@ struct DecompilationContext { clang::Expr *marker_expr; std::unordered_map z3_br_edges_inv; + + // Pairs do not have a std::hash specialization so we can't use unordered maps + // here. 
If this turns out to be a performance issue, investigate adding hash + // specializations for these specifically std::map z3_br_edges; std::unordered_map z3_sw_vars; @@ -91,7 +95,7 @@ struct DecompilationContext { size_t num_literal_structs = 0; size_t num_declared_structs = 0; - // Inserts and expression into z3_exprs and returns its index + // Inserts an expression into z3_exprs and returns its index unsigned InsertZExpr(const z3::expr &e); }; From 27924aabc97db97c33996c72e6aadc1e779890e8 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 1 Sep 2022 14:52:22 +0200 Subject: [PATCH 54/57] Cleanup and comment `ConvertExpr` --- lib/AST/IRToASTVisitor.cpp | 46 +++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index 19c872f7..3f7bce3b 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -65,6 +65,10 @@ class ExprGen : public llvm::InstVisitor { clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { if (expr.decl().decl_kind() == Z3_OP_EQ) { + // Equalities generated for the reaching conditions of switch instructions + // Always in the for (VAR == CONST) or (CONST == VAR) + // VAR will uniquely identify a SwitchInst, CONST will represent the index + // of the case taken CHECK_EQ(expr.num_args(), 2) << "Equalities must have 2 arguments"; auto a{expr.arg(0)}; auto b{expr.arg(1)}; @@ -72,6 +76,9 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { llvm::SwitchInst *inst{dec_ctx.z3_sw_vars_inv[a.id()]}; unsigned case_idx{}; + // GenerateAST always generates equalities in the form (VAR == CONST), but + // there is a chance that some Z3 simplification inverts the order, so + // handle that here. 
if (!inst) { inst = dec_ctx.z3_sw_vars_inv[b.id()]; case_idx = a.get_numeral_uint(); @@ -94,44 +101,41 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { auto edge{dec_ctx.z3_br_edges_inv[hash]}; CHECK(edge.second) << "Inverse map should only be populated for branches " "taken when condition is true"; - return CreateOperandExpr(*(edge.first->op_end() - 3)); - } - - if (dec_ctx.z3_sw_vars_inv.find(hash) != dec_ctx.z3_sw_vars_inv.end()) { - auto inst{dec_ctx.z3_sw_vars_inv[hash]}; - return CreateOperandExpr(inst->getOperandUse(0)); - } + // expr is a variable that represents the condition of a branch instruction. - std::vector args; - for (auto i{0U}; i < expr.num_args(); ++i) { - args.push_back(ConvertExpr(expr.arg(i))); + // FIXME(frabert): Unfortunately there is no public API in BranchInst that + // gives the operand of the condition. From reverse engineering LLVM code, + // this is the way they obtain uses internally, but it's probably not + // stable. + return CreateOperandExpr(*(edge.first->op_end() - 3)); } switch (expr.decl().decl_kind()) { case Z3_OP_TRUE: - CHECK_EQ(args.size(), 0) << "True cannot have arguments"; + CHECK_EQ(expr.num_args(), 0) << "True cannot have arguments"; return ast.CreateTrue(); case Z3_OP_FALSE: - CHECK_EQ(args.size(), 0) << "False cannot have arguments"; + CHECK_EQ(expr.num_args(), 0) << "False cannot have arguments"; return ast.CreateFalse(); case Z3_OP_AND: { - clang::Expr *res{args[0]}; - for (auto i{1U}; i < args.size(); ++i) { - res = ast.CreateLAnd(res, args[i]); + clang::Expr *res{ConvertExpr(expr.arg(0))}; + for (auto i{1U}; i < expr.num_args(); ++i) { + res = ast.CreateLAnd(res, ConvertExpr(expr.arg(i))); } return res; } case Z3_OP_OR: { - clang::Expr *res{args[0]}; - for (auto i{1U}; i < args.size(); ++i) { - res = ast.CreateLOr(res, args[i]); + clang::Expr *res{ConvertExpr(expr.arg(0))}; + for (auto i{1U}; i < expr.num_args(); ++i) { + res = ast.CreateLOr(res, ConvertExpr(expr.arg(i))); } return res; } case 
Z3_OP_NOT: { - CHECK_EQ(args.size(), 1) << "Not must have one argument"; - auto neg{ast.CreateLNot(args[0])}; - CopyProvenance(args[0], neg, dec_ctx.use_provenance); + CHECK_EQ(expr.num_args(), 1) << "Not must have one argument"; + auto sub{ConvertExpr(expr.arg(0))}; + auto neg{ast.CreateLNot(sub)}; + CopyProvenance(sub, neg, dec_ctx.use_provenance); return neg; } default: From 7510837394f3ee56c0a5d003fb96a2d7279270da Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Thu, 1 Sep 2022 15:05:52 +0200 Subject: [PATCH 55/57] More comments --- include/rellic/AST/GenerateAST.h | 10 ++++++++++ lib/AST/GenerateAST.cpp | 6 ++++++ lib/AST/IRToASTVisitor.cpp | 2 +- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index 488b8a4c..b53eae01 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -46,8 +46,18 @@ class GenerateAST : public llvm::AnalysisInfoMixin { std::vector rpo_walk; + // GetOrCreateEdgeForBranch(branch, true) will return the index of an + // expression that is true when branch is taken. + // Viceversa, GetOrCreateEdgeForBranch(branch, false) is an expression that + // will be true when branch is not taken unsigned GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond); + + // Returns the index of an expression containing a numerical variable that + // represents the condition of a switch. unsigned GetOrCreateVarForSwitch(llvm::SwitchInst *inst); + // Returns the index of an expression that is true when a particular case of a + // switch is taken. If c is nullptr, the expression for the default case will + // be returned. 
unsigned GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c); diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp index 93fe344c..238d4c9d 100755 --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -151,15 +151,21 @@ unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, } if (auto constant = llvm::dyn_cast(inst->getCondition())) { + // This is a conditional branch with a constant condition, so just emit + // whether the condition matches the wanted value auto edge{dec_ctx.z3_ctx.bool_val(constant->isOne() == cond)}; dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; } else if (cond) { + // This is a conditional branch, so the expression that is true when the + // branch is going to be taken is just a new variable. auto name{GetName(inst)}; auto edge{dec_ctx.z3_ctx.bool_const(name.c_str())}; dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; } else { + // Like the previous case, but in this case we want to know the expression + // that will be true when the branch is not going to be taken auto edge{!(ToExpr(GetOrCreateEdgeForBranch(inst, true)))}; dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); } diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index 3f7bce3b..cba86151 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -65,7 +65,7 @@ class ExprGen : public llvm::InstVisitor { clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { if (expr.decl().decl_kind() == Z3_OP_EQ) { - // Equalities generated for the reaching conditions of switch instructions + // Equalities generated form the reaching conditions of switch instructions // Always in the for (VAR == CONST) or (CONST == VAR) // VAR will uniquely identify a SwitchInst, CONST will represent the index // of the case taken From e052b3664eca63acfd8cf7ffb1c2a1347f690eac Mon Sep 
17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 2 Sep 2022 10:04:17 +0200 Subject: [PATCH 56/57] Remove old comment --- include/rellic/AST/GenerateAST.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index b53eae01..b80d10ae 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -29,8 +29,6 @@ class GenerateAST : public llvm::AnalysisInfoMixin { constexpr static unsigned poison_idx = std::numeric_limits::max(); z3::expr ToExpr(unsigned idx); - // Need to use `map` with these instead of `unordered_map`, because - // `std::pair` doesn't have a default hash implementation clang::ASTUnit &unit; clang::ASTContext *ast_ctx; rellic::IRToASTVisitor ast_gen; From 389250eb4cdb8ff31ad0d4de659252e6e64fbf7d Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 2 Sep 2022 10:24:38 +0200 Subject: [PATCH 57/57] More comments --- lib/AST/CondBasedRefine.cpp | 17 +++++++++++++++++ lib/AST/IRToASTVisitor.cpp | 3 +++ 2 files changed, 20 insertions(+) diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index 37c9ad11..51216404 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -27,6 +27,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { auto if_a{clang::dyn_cast(body[i])}; auto if_b{clang::dyn_cast(body[i + 1])}; + // We need two `if` statements to combine if (!if_a || !if_b) { continue; } @@ -43,12 +44,20 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector new_then_body{then_a}; clang::IfStmt *new_if{nullptr}; if (Prove(cond_a == cond_b)) { + // We found two consecutive `if` statements with identical conditions, so + // we can merge their `then` and `else` branches + // + // if(a) { X1; } else { Y1; } + // if(a) { X2; } else { Y2; } + // becomes + // if(a) { X1; X2; } else { Y1; Y2; } new_then_body.push_back(then_b); auto 
new_then{ast.CreateCompoundStmt(new_then_body)}; new_if = ast.CreateIf(dec_ctx.marker_expr, new_then); if (else_a || else_b) { + // At least one of the two `if` statements has an `else` branch std::vector new_else_body{}; if (else_a) { @@ -65,6 +74,14 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { did_something = true; } else if (Prove(cond_a == !cond_b)) { + // We found two consecutive `if` statements with opposite conditions, so + // we can append the else branch of the second to the then branch of the + // first, and viceversa + // + // if(a) { X1; } else { Y1; } + // if(!a) { X2; } else { Y2; } + // becomes + // if(a) { X1; Y2; } else { Y1; X2; } if (else_b) { new_then_body.push_back(else_b); } diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index cba86151..b44b8da6 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -118,6 +118,9 @@ clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { CHECK_EQ(expr.num_args(), 0) << "False cannot have arguments"; return ast.CreateFalse(); case Z3_OP_AND: { + // Since AND and OR expressions are n-ary we need to convert them to + // binary. If they have only one subexpression, we can forego the AND/OR + // altogether. clang::Expr *res{ConvertExpr(expr.arg(0))}; for (auto i{1U}; i < expr.num_args(); ++i) { res = ast.CreateLAnd(res, ConvertExpr(expr.arg(i)));