diff --git a/ci/angha_1k_test_settings.json b/ci/angha_1k_test_settings.json index 06f7aafe..34b4d485 100644 --- a/ci/angha_1k_test_settings.json +++ b/ci/angha_1k_test_settings.json @@ -1,12 +1,8 @@ { "tests.ignore": [ - "amd64/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "amd64/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc", - "arm64/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "arm64/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc", - "armv7/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "armv7/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc", - "x86/linux/tools/perf/bench/extr_numa.c___bench_numa.bc", "x86/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc" ] } \ No newline at end of file diff --git a/include/rellic/AST/ASTPass.h b/include/rellic/AST/ASTPass.h index 54c48128..a20263b1 100644 --- a/include/rellic/AST/ASTPass.h +++ b/include/rellic/AST/ASTPass.h @@ -21,7 +21,7 @@ class ASTPass { std::atomic_bool stop{false}; protected: - Provenance& provenance; + DecompilationContext& dec_ctx; clang::ASTUnit& ast_unit; clang::ASTContext& ast_ctx; ASTBuilder ast; @@ -32,8 +32,8 @@ class ASTPass { virtual void StopImpl() {} public: - ASTPass(Provenance& provenance, clang::ASTUnit& ast_unit) - : provenance(provenance), + ASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit) + : dec_ctx(dec_ctx), ast_unit(ast_unit), ast_ctx(ast_unit.getASTContext()), ast(ast_unit) {} @@ -89,8 +89,8 @@ class CompositeASTPass : public ASTPass { } public: - CompositeASTPass(Provenance& provenance, clang::ASTUnit& ast_unit) - : ASTPass(provenance, ast_unit) {} + CompositeASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit) + : ASTPass(dec_ctx, ast_unit) {} std::vector>& GetPasses() { return passes; } }; } // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/CondBasedRefine.h b/include/rellic/AST/CondBasedRefine.h index e7d0f686..28905b7f 100644 --- a/include/rellic/AST/CondBasedRefine.h +++ b/include/rellic/AST/CondBasedRefine.h @@ -11,7 +11,6 @@ #include "rellic/AST/ASTPass.h" #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/TransformVisitor.h" -#include "rellic/AST/Z3ConvVisitor.h" namespace rellic { @@ -35,18 +34,11 @@ namespace rellic { */ class CondBasedRefine : public TransformVisitor { private: - std::unique_ptr z3_ctx; - std::unique_ptr z3_gen; - - z3::tactic z3_solver; - - z3::expr GetZ3Cond(clang::IfStmt *ifstmt); - protected: void RunImpl() override; public: - CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit); + CondBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitCompoundStmt(clang::CompoundStmt *compound); }; diff --git a/include/rellic/AST/DeadStmtElim.h b/include/rellic/AST/DeadStmtElim.h index 9d011c30..0eec0eda 100644 --- a/include/rellic/AST/DeadStmtElim.h +++ b/include/rellic/AST/DeadStmtElim.h @@ -22,7 +22,7 @@ class DeadStmtElim : public TransformVisitor { void RunImpl() override; public: - DeadStmtElim(Provenance &provenance, clang::ASTUnit &unit); + DeadStmtElim(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitIfStmt(clang::IfStmt *ifstmt); bool VisitCompoundStmt(clang::CompoundStmt *compound); diff --git a/include/rellic/AST/ExprCombine.h b/include/rellic/AST/ExprCombine.h index cd4be916..1fe86dd2 100644 --- a/include/rellic/AST/ExprCombine.h +++ b/include/rellic/AST/ExprCombine.h @@ -21,7 +21,7 @@ class ExprCombine : public TransformVisitor { void RunImpl() override; public: - ExprCombine(Provenance &provenance, clang::ASTUnit &unit); + ExprCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitCStyleCastExpr(clang::CStyleCastExpr *cast); bool VisitUnaryOperator(clang::UnaryOperator *op); diff --git a/include/rellic/AST/GenerateAST.h b/include/rellic/AST/GenerateAST.h index 7f31b997..b80d10ae 100644 --- a/include/rellic/AST/GenerateAST.h +++ b/include/rellic/AST/GenerateAST.h @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -25,28 +26,14 @@ class GenerateAST : public llvm::AnalysisInfoMixin { friend llvm::AnalysisInfoMixin; static llvm::AnalysisKey Key; - // Need to use `map` with these instead of `unordered_map`, because - // `std::pair` doesn't have a default hash implementation - using BBEdge = std::pair; - using BrEdge = std::pair; - using SwEdge = std::pair; + constexpr static unsigned poison_idx = std::numeric_limits::max(); + z3::expr ToExpr(unsigned idx); + clang::ASTUnit &unit; clang::ASTContext *ast_ctx; rellic::IRToASTVisitor ast_gen; rellic::ASTBuilder ast; - z3::context *z_ctx; - - Provenance &provenance; - - z3::expr_vector z_exprs; - std::unordered_map z_br_edges_inv; - std::map z_br_edges; - - std::unordered_map z_sw_edges_inv; - std::map z_sw_edges; - - std::map z_edges; - std::unordered_map reaching_conds; + DecompilationContext &dec_ctx; bool reaching_conds_changed{true}; std::unordered_map block_stmts; std::unordered_map region_stmts; @@ -57,16 +44,25 @@ class GenerateAST : public llvm::AnalysisInfoMixin { std::vector rpo_walk; - z3::expr GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond); - z3::expr GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, + // GetOrCreateEdgeForBranch(branch, true) will return the index of an + // expression that is true when branch is taken. + // Viceversa, GetOrCreateEdgeForBranch(branch, false) is an expression that + // will be true when branch is not taken + unsigned GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond); + + // Returns the index of an expression containing a numerical variable that + // represents the condition of a switch. + unsigned GetOrCreateVarForSwitch(llvm::SwitchInst *inst); + // Returns the index of an expression that is true when a particular case of a + // switch is taken. If c is nullptr, the expression for the default case will + // be returned. + unsigned GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c); - z3::expr GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to); - z3::expr GetReachingCond(llvm::BasicBlock *block); + unsigned GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to); + unsigned GetReachingCond(llvm::BasicBlock *block); void CreateReachingCond(llvm::BasicBlock *block); - clang::Expr *ConvertExpr(z3::expr expr); - std::vector CreateBasicBlockStmts(llvm::BasicBlock *block); std::vector CreateRegionStmts(llvm::Region *region); @@ -82,12 +78,12 @@ class GenerateAST : public llvm::AnalysisInfoMixin { public: using Result = llvm::PreservedAnalyses; - GenerateAST(Provenance &provenance, clang::ASTUnit &unit); + GenerateAST(DecompilationContext &dec_ctx, clang::ASTUnit &unit); Result run(llvm::Function &F, llvm::FunctionAnalysisManager &FAM); Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM); - static void run(llvm::Module &M, Provenance &provenance, + static void run(llvm::Module &M, DecompilationContext &dec_ctx, clang::ASTUnit &unit); }; diff --git a/include/rellic/AST/IRToASTVisitor.h b/include/rellic/AST/IRToASTVisitor.h index 720b6d2c..5722a0a5 100644 --- a/include/rellic/AST/IRToASTVisitor.h +++ b/include/rellic/AST/IRToASTVisitor.h @@ -32,15 +32,16 @@ class IRToASTVisitor { ASTBuilder ast; - Provenance &provenance; + DecompilationContext &dec_ctx; void VisitArgument(llvm::Argument &arg); public: - IRToASTVisitor(clang::ASTUnit &unit, Provenance &provenance); + IRToASTVisitor(clang::ASTUnit &unit, DecompilationContext &dec_ctx); clang::Expr *CreateOperandExpr(llvm::Use &val); clang::Expr *CreateConstantExpr(llvm::Constant *constant); + clang::Expr *ConvertExpr(z3::expr expr); void VisitGlobalVar(llvm::GlobalVariable &var); void VisitFunctionDecl(llvm::Function &func); diff --git a/include/rellic/AST/InferenceRule.h b/include/rellic/AST/InferenceRule.h index 73124fd2..a1f89c30 100644 --- a/include/rellic/AST/InferenceRule.h +++ b/include/rellic/AST/InferenceRule.h @@ -34,13 +34,13 @@ class InferenceRule : public clang::ast_matchers::MatchFinder::MatchCallback { return cond; } - virtual clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + virtual clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) = 0; }; clang::Stmt *ApplyFirstMatchingRule( - Provenance &provenance, clang::ASTUnit &unit, clang::Stmt *stmt, + DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt, std::vector> &rules); } // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/LocalDeclRenamer.h b/include/rellic/AST/LocalDeclRenamer.h index fdd83869..8b57be8e 100644 --- a/include/rellic/AST/LocalDeclRenamer.h +++ b/include/rellic/AST/LocalDeclRenamer.h @@ -35,7 +35,7 @@ class LocalDeclRenamer : public TransformVisitor { void RunImpl() override; public: - LocalDeclRenamer(Provenance &provenance, clang::ASTUnit &unit, + LocalDeclRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit, IRToNameMap &names); bool shouldTraversePostOrder() override; diff --git a/include/rellic/AST/LoopRefine.h b/include/rellic/AST/LoopRefine.h index 022690cb..363e7b35 100644 --- a/include/rellic/AST/LoopRefine.h +++ b/include/rellic/AST/LoopRefine.h @@ -34,7 +34,7 @@ class LoopRefine : public TransformVisitor { void RunImpl() override; public: - LoopRefine(Provenance &provenance, clang::ASTUnit &unit); + LoopRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitWhileStmt(clang::WhileStmt *loop); }; diff --git a/include/rellic/AST/MaterializeConds.h b/include/rellic/AST/MaterializeConds.h new file mode 100644 index 00000000..fcd3a587 --- /dev/null +++ b/include/rellic/AST/MaterializeConds.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "rellic/AST/IRToASTVisitor.h" +#include "rellic/AST/TransformVisitor.h" + +namespace rellic { + +/* + * This pass substitutes the marker expression in loops and `if` statements for + * their translation from Z3 formulas + */ +class MaterializeConds : public TransformVisitor { + private: + IRToASTVisitor ast_gen; + + protected: + void RunImpl() override; + + public: + MaterializeConds(DecompilationContext &dec_ctx, clang::ASTUnit &unit); + + bool VisitIfStmt(clang::IfStmt *stmt); + bool VisitWhileStmt(clang::WhileStmt *loop); + bool VisitDoStmt(clang::DoStmt *loop); +}; + +} // namespace rellic diff --git a/include/rellic/AST/NestedCondProp.h b/include/rellic/AST/NestedCondProp.h index f5afe0cb..1534f332 100644 --- a/include/rellic/AST/NestedCondProp.h +++ b/include/rellic/AST/NestedCondProp.h @@ -54,7 +54,7 @@ class NestedCondProp : public ASTPass { void RunImpl() override; public: - NestedCondProp(Provenance &provenance, clang::ASTUnit &unit); + NestedCondProp(DecompilationContext& dec_ctx, clang::ASTUnit& unit); }; } // namespace rellic diff --git a/include/rellic/AST/NestedScopeCombine.h b/include/rellic/AST/NestedScopeCombine.h index 42db6bb8..b9158cd6 100644 --- a/include/rellic/AST/NestedScopeCombine.h +++ b/include/rellic/AST/NestedScopeCombine.h @@ -35,9 +35,10 @@ class NestedScopeCombine : public TransformVisitor { void RunImpl() override; public: - NestedScopeCombine(Provenance &provenance, clang::ASTUnit &unit); + NestedScopeCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitIfStmt(clang::IfStmt *ifstmt); + bool VisitWhileStmt(clang::WhileStmt *stmt); bool VisitCompoundStmt(clang::CompoundStmt *compound); }; diff --git a/include/rellic/AST/NormalizeCond.h b/include/rellic/AST/NormalizeCond.h deleted file mode 100644 index 45bca02b..00000000 --- a/include/rellic/AST/NormalizeCond.h +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2022-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#pragma once - -#include "rellic/AST/TransformVisitor.h" - -namespace rellic { - -/* - * This pass turns conditions into conjunctive normal form (CNF). Warning: this - * has the potential of creating an exponential number of terms, so it's best to - * perform this pass after simplification. - */ -class NormalizeCond : public TransformVisitor { - protected: - void RunImpl() override; - - public: - static char ID; - - NormalizeCond(Provenance &provenance, clang::ASTUnit &unit); - - bool VisitUnaryOperator(clang::UnaryOperator *op); - bool VisitBinaryOperator(clang::BinaryOperator *op); -}; - -} // namespace rellic diff --git a/include/rellic/AST/ReachBasedRefine.h b/include/rellic/AST/ReachBasedRefine.h index 0bb97fb2..1781cf7b 100644 --- a/include/rellic/AST/ReachBasedRefine.h +++ b/include/rellic/AST/ReachBasedRefine.h @@ -9,7 +9,6 @@ #pragma once #include "rellic/AST/TransformVisitor.h" -#include "rellic/AST/Z3ConvVisitor.h" namespace rellic { @@ -38,16 +37,11 @@ namespace rellic { */ class ReachBasedRefine : public TransformVisitor { private: - std::unique_ptr z3_ctx; - std::unique_ptr z3_gen; - - z3::expr GetZ3Cond(clang::IfStmt *ifstmt); - protected: void RunImpl() override; public: - ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit); + ReachBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit); bool VisitCompoundStmt(clang::CompoundStmt *compound); }; diff --git a/include/rellic/AST/StructFieldRenamer.h b/include/rellic/AST/StructFieldRenamer.h index 556290b3..9af6040a 100644 --- a/include/rellic/AST/StructFieldRenamer.h +++ b/include/rellic/AST/StructFieldRenamer.h @@ -29,7 +29,7 @@ class StructFieldRenamer void RunImpl() override; public: - StructFieldRenamer(Provenance &provenance, clang::ASTUnit &unit, + StructFieldRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit, IRTypeToDITypeMap &types); bool VisitRecordDecl(clang::RecordDecl *decl); diff --git a/include/rellic/AST/TransformVisitor.h b/include/rellic/AST/TransformVisitor.h index 13b444dd..50e35965 100644 --- a/include/rellic/AST/TransformVisitor.h +++ b/include/rellic/AST/TransformVisitor.h @@ -26,11 +26,11 @@ class TransformVisitor : public ASTPass, StmtSubMap substitutions; void CopyProvenance(clang::Stmt *from, clang::Stmt *to) { - ::rellic::CopyProvenance(from, to, provenance.stmt_provenance); + ::rellic::CopyProvenance(from, to, dec_ctx.stmt_provenance); } void CopyProvenance(clang::Expr *from, clang::Expr *to) { - ::rellic::CopyProvenance(from, to, provenance.use_provenance); + ::rellic::CopyProvenance(from, to, dec_ctx.use_provenance); } bool ReplaceChildren(clang::Stmt *stmt, StmtSubMap &repl_map) { @@ -55,8 +55,8 @@ class TransformVisitor : public ASTPass, void RunImpl() override { substitutions.clear(); } public: - TransformVisitor(Provenance &provenance, clang::ASTUnit &unit) - : ASTPass(provenance, unit) {} + TransformVisitor(DecompilationContext &dec_ctx, clang::ASTUnit &unit) + : ASTPass(dec_ctx, unit) {} virtual bool shouldTraversePostOrder() { return true; } diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index f79ae3dc..c2afd11e 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -50,24 +51,52 @@ size_t GetNumDecls(clang::DeclContext *decl_ctx) { return result; } -using StmtToIRMap = std::unordered_map; -using ExprToUseMap = std::unordered_map; -using IRToTypeDeclMap = std::unordered_map; -using IRToValDeclMap = std::unordered_map; -using IRToStmtMap = std::unordered_map; -using ArgToTempMap = std::unordered_map; -using BlockToUsesMap = - std::unordered_map>; -struct Provenance { +struct DecompilationContext { + using StmtToIRMap = std::unordered_map; + using ExprToUseMap = std::unordered_map; + using IRToTypeDeclMap = std::unordered_map; + using IRToValDeclMap = std::unordered_map; + using IRToStmtMap = std::unordered_map; + using ArgToTempMap = std::unordered_map; + using BlockToUsesMap = + std::unordered_map>; + using Z3CondMap = std::unordered_map; + + using BBEdge = std::pair; + using BrEdge = std::pair; + using SwEdge = std::pair; + StmtToIRMap stmt_provenance; ExprToUseMap use_provenance; IRToTypeDeclMap type_decls; IRToValDeclMap value_decls; ArgToTempMap temp_decls; BlockToUsesMap outgoing_uses; + z3::context z3_ctx; + z3::expr_vector z3_exprs{z3_ctx}; + Z3CondMap conds; + + clang::Expr *marker_expr; + + std::unordered_map z3_br_edges_inv; + + // Pairs do not have a std::hash specialization so we can't use unordered maps + // here. If this turns out to be a performance issue, investigate adding hash + // specializations for these specifically + std::map z3_br_edges; + + std::unordered_map z3_sw_vars; + std::unordered_map z3_sw_vars_inv; + std::map z3_sw_edges; + + std::map z3_edges; + std::unordered_map reaching_conds; size_t num_literal_structs = 0; size_t num_declared_structs = 0; + + // Inserts an expression into z3_exprs and returns its index + unsigned InsertZExpr(const z3::expr &e); }; template @@ -77,15 +106,18 @@ void CopyProvenance(TKey1 *from, TKey2 *to, } clang::Expr *Clone(clang::ASTUnit &unit, clang::Expr *stmt, - ExprToUseMap &provenance); - -// Negates an expression while stripping parentheses and double negations -clang::Expr *Negate(ASTBuilder &ast, clang::Expr *expr); + DecompilationContext::ExprToUseMap &provenance); std::string ClangThingToString(const clang::Stmt *stmt); -z3::goal ApplyTactic(z3::context &ctx, const z3::tactic &tactic, z3::expr expr); +z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr); + +bool Prove(z3::expr expr); -bool Prove(z3::context &ctx, z3::expr expr); +z3::expr HeavySimplify(z3::expr expr); +z3::expr_vector Clone(z3::expr_vector &vec); +// Tries to keep each subformula sorted by its id so that they don't get +// shuffled around by simplification +z3::expr OrderById(z3::expr expr); } // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/Z3CondSimplify.h b/include/rellic/AST/Z3CondSimplify.h index e851c035..edd27522 100644 --- a/include/rellic/AST/Z3CondSimplify.h +++ b/include/rellic/AST/Z3CondSimplify.h @@ -10,9 +10,7 @@ #include -#include "rellic/AST/TransformVisitor.h" -#include "rellic/AST/Util.h" -#include "rellic/AST/Z3ConvVisitor.h" +#include "rellic/AST/ASTPass.h" namespace rellic { @@ -20,54 +18,13 @@ namespace rellic { * This pass simplifies conditions using Z3 by trying to remove terms that are * trivially true or false */ -class Z3CondSimplify : public TransformVisitor { +class Z3CondSimplify : public ASTPass { private: - std::unique_ptr z_ctx; - std::unique_ptr z_gen; - - z3::expr ToZ3(clang::Expr *e); - - std::unordered_map hashes; - - struct Hash { - clang::ASTContext &ctx; - std::unordered_map &hashes; - std::size_t operator()(clang::Expr *e) const noexcept { - auto &hash{hashes[e]}; - if (!hash) { - hash = GetHash(ctx, e); - } - return hash; - } - }; - - struct KeyEqual { - bool operator()(clang::Expr *a, clang::Expr *b) const noexcept { - return IsEquivalent(a, b); - } - }; - Hash hash_adaptor; - KeyEqual ke_adaptor; - - std::unordered_map proven_true; - std::unordered_map proven_false; - - bool IsProvenTrue(clang::Expr *e); - bool IsProvenFalse(clang::Expr *e); - - clang::Expr *Simplify(clang::Expr *e); - protected: void RunImpl() override; public: - Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit); - - z3::context &GetZ3Context() { return *z_ctx; } - - bool VisitIfStmt(clang::IfStmt *stmt); - bool VisitWhileStmt(clang::WhileStmt *loop); - bool VisitDoStmt(clang::DoStmt *loop); + Z3CondSimplify(DecompilationContext& dec_ctx, clang::ASTUnit& unit); }; } // namespace rellic diff --git a/include/rellic/AST/Z3ConvVisitor.h b/include/rellic/AST/Z3ConvVisitor.h deleted file mode 100644 index 889976eb..00000000 --- a/include/rellic/AST/Z3ConvVisitor.h +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2021-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#pragma once - -#include -#include - -#include - -#include "rellic/AST/ASTBuilder.h" - -namespace rellic { - -class Z3ConvVisitor : public clang::RecursiveASTVisitor { - private: - clang::ASTContext *c_ctx; - ASTBuilder ast; - - z3::context *z_ctx; - - // Expression maps - z3::expr_vector z_expr_vec; - std::unordered_map z_expr_map; - std::unordered_map c_expr_map; - // Declaration maps - z3::func_decl_vector z_decl_vec; - std::unordered_map z_decl_map; - std::unordered_map c_decl_map; - - void InsertZ3Expr(clang::Expr *c_expr, z3::expr z_expr); - z3::expr GetZ3Expr(clang::Expr *c_expr); - - void InsertCExpr(z3::expr z_expr, clang::Expr *c_expr); - clang::Expr *GetCExpr(z3::expr z_expr); - - void InsertZ3Decl(clang::ValueDecl *c_decl, z3::func_decl z_decl); - z3::func_decl GetZ3Decl(clang::ValueDecl *c_decl); - - void InsertCValDecl(z3::func_decl z_decl, clang::ValueDecl *c_decl); - clang::ValueDecl *GetCValDecl(z3::func_decl z_decl); - - z3::sort GetZ3Sort(clang::QualType type); - - clang::Expr *CreateLiteralExpr(z3::expr z_expr); - - z3::expr CreateZ3BitwiseCast(z3::expr expr, size_t src, size_t dst, - bool sign); - - void VisitZ3Expr(z3::expr z_expr); - void VisitZ3Decl(z3::func_decl z_decl); - - template - bool HandleCastExpr(T *c_cast); - clang::Expr *HandleZ3Concat(z3::expr z_op); - clang::Expr *HandleZ3Uninterpreted(z3::expr z_op); - - public: - z3::func_decl GetOrCreateZ3Decl(clang::ValueDecl *c_decl); - z3::expr GetOrCreateZ3Expr(clang::Expr *c_expr); - - clang::Expr *GetOrCreateCExpr(z3::expr z_expr); - clang::ValueDecl *GetOrCreateCValDecl(z3::func_decl z_decl); - - Z3ConvVisitor(clang::ASTUnit &unit, z3::context *z_ctx); - bool shouldTraversePostOrder() { return true; } - - z3::expr Z3BoolCast(z3::expr expr); - z3::expr Z3BoolToBVCast(z3::expr expr); - - bool VisitArraySubscriptExpr(clang::ArraySubscriptExpr *sub); - bool VisitImplicitCastExpr(clang::ImplicitCastExpr *cast); - bool VisitCStyleCastExpr(clang::CStyleCastExpr *cast); - bool VisitMemberExpr(clang::MemberExpr *expr); - bool VisitCallExpr(clang::CallExpr *call); - bool VisitParenExpr(clang::ParenExpr *parens); - bool VisitUnaryOperator(clang::UnaryOperator *c_op); - bool VisitBinaryOperator(clang::BinaryOperator *c_op); - bool VisitConditionalOperator(clang::ConditionalOperator *c_op); - bool VisitDeclRefExpr(clang::DeclRefExpr *c_ref); - bool VisitCharacterLiteral(clang::CharacterLiteral *lit); - bool VisitIntegerLiteral(clang::IntegerLiteral *lit); - bool VisitFloatingLiteral(clang::FloatingLiteral *lit); - - bool VisitVarDecl(clang::VarDecl *var); - bool VisitFieldDecl(clang::FieldDecl *field); - bool VisitFunctionDecl(clang::FunctionDecl *func); - - // Do not traverse function bodies - bool TraverseFunctionDecl(clang::FunctionDecl *func) { - WalkUpFromFunctionDecl(func); - return true; - } - - void VisitConstant(z3::expr z_const); - void VisitUnaryApp(z3::expr z_op); - void VisitBinaryApp(z3::expr z_op); - void VisitTernaryApp(z3::expr z_op); -}; - -} // namespace rellic diff --git a/include/rellic/BC/Util.h b/include/rellic/BC/Util.h index 2435395a..a08e5aa2 100644 --- a/include/rellic/BC/Util.h +++ b/include/rellic/BC/Util.h @@ -67,6 +67,4 @@ void RemoveInsertValues(llvm::Module &module); // Converts by value array arguments and wraps them into a struct, so that // semantics are preserved in C void ConvertArrayArguments(llvm::Module &module); - -void FindRedundantLoads(llvm::Module &module); } // namespace rellic \ No newline at end of file diff --git a/include/rellic/Decompiler.h b/include/rellic/Decompiler.h index a229d5b6..a62fc659 100644 --- a/include/rellic/Decompiler.h +++ b/include/rellic/Decompiler.h @@ -20,35 +20,6 @@ namespace rellic { struct DecompilationOptions { bool lower_switches = false; bool remove_phi_nodes = false; - bool disable_z3 = false; - bool dead_stmt_elimination = true; - struct { - bool z3_cond_simplify = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - bool cond_base_refine = true; - bool reach_based_refine = true; - bool expression_normalize = false; - } condition_based_refinement; - struct { - bool loop_refine = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - bool expression_normalize = false; - } loop_refinement; - struct { - bool z3_cond_simplify = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - bool expression_normalize = false; - } scope_refinement; - bool expression_normalize = false; - bool expression_combine = true; - struct { - bool z3_cond_simplify = true; - bool nested_cond_propagate = true; - bool nested_scope_combine = true; - } final_refinement; }; struct DecompilationResult { diff --git a/lib/AST/CondBasedRefine.cpp b/lib/AST/CondBasedRefine.cpp index c597e527..51216404 100644 --- a/lib/AST/CondBasedRefine.cpp +++ b/lib/AST/CondBasedRefine.cpp @@ -11,83 +11,103 @@ #include #include -#include - -#include "rellic/AST/Util.h" +#include namespace rellic { -CondBasedRefine::CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit), - z3_ctx(new z3::context()), - z3_gen(new rellic::Z3ConvVisitor(unit, z3_ctx.get())), - z3_solver(*z3_ctx, "sat") {} - -z3::expr CondBasedRefine::GetZ3Cond(clang::IfStmt *ifstmt) { - auto cond = ifstmt->getCond(); - auto expr = z3_gen->Z3BoolCast(z3_gen->GetOrCreateZ3Expr(cond)); - return expr.simplify(); -} +CondBasedRefine::CondBasedRefine(DecompilationContext &dec_ctx, + clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit) {} bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector body{compound->body_begin(), compound->body_end()}; bool did_something{false}; - for (size_t i{0}; i + 1 < body.size() && !Stopped(); ++i) { + + for (size_t i{0}; i + 1 < body.size() && !did_something; ++i) { auto if_a{clang::dyn_cast(body[i])}; auto if_b{clang::dyn_cast(body[i + 1])}; + // We need two `if` statements to combine if (!if_a || !if_b) { continue; } + auto cond_a{dec_ctx.z3_exprs[dec_ctx.conds[if_a]]}; + auto cond_b{dec_ctx.z3_exprs[dec_ctx.conds[if_b]]}; + auto then_a{if_a->getThen()}; auto then_b{if_b->getThen()}; auto else_a{if_a->getElse()}; auto else_b{if_b->getElse()}; - auto cond_a{GetZ3Cond(if_a)}; - auto cond_b{GetZ3Cond(if_b)}; - - clang::IfStmt *new_if{}; std::vector new_then_body{then_a}; - if (Prove(*z3_ctx, cond_a == cond_b)) { + clang::IfStmt *new_if{nullptr}; + if (Prove(cond_a == cond_b)) { + // We found two consecutive `if` statements with identical conditions, so + // we can merge their `then` and `else` branches + // + // if(a) { X1; } else { Y1; } + // if(a) { X2; } else { Y2; } + // becomes + // if(a) { X1; X2; } else { Y1; Y2; } new_then_body.push_back(then_b); auto new_then{ast.CreateCompoundStmt(new_then_body)}; - new_if = ast.CreateIf(if_a->getCond(), new_then); + + new_if = ast.CreateIf(dec_ctx.marker_expr, new_then); if (else_a || else_b) { - std::vector new_else_body; + // At least one of the two `if` statements has an `else` branch + std::vector new_else_body{}; + if (else_a) { new_else_body.push_back(else_a); } + if (else_b) { new_else_body.push_back(else_b); } - new_if->setElse(ast.CreateCompoundStmt(new_else_body)); + + auto new_else{ast.CreateCompoundStmt(new_else_body)}; + new_if->setElse(new_else); } - } else if (Prove(*z3_ctx, cond_a == !cond_b)) { + + did_something = true; + } else if (Prove(cond_a == !cond_b)) { + // We found two consecutive `if` statements with opposite conditions, so + // we can append the else branch of the second to the then branch of the + // first, and viceversa + // + // if(a) { X1; } else { Y1; } + // if(!a) { X2; } else { Y2; } + // becomes + // if(a) { X1; Y2; } else { Y1; X2; } if (else_b) { new_then_body.push_back(else_b); } + auto new_then{ast.CreateCompoundStmt(new_then_body)}; - new_if = ast.CreateIf(if_a->getCond(), new_then); - std::vector new_else_body; + std::vector new_else_body{}; if (else_a) { new_else_body.push_back(else_a); } new_else_body.push_back(then_b); - new_if->setElse(ast.CreateCompoundStmt(new_else_body)); + + new_if = ast.CreateIf(dec_ctx.marker_expr, new_then); + + auto new_else{ast.CreateCompoundStmt(new_else_body)}; + new_if->setElse(new_else); + + did_something = true; } - if (new_if) { + if (did_something) { + dec_ctx.conds[new_if] = dec_ctx.conds[if_a]; body[i] = new_if; body.erase(std::next(body.begin(), i + 1)); - did_something = true; } } - if (did_something) { substitutions[compound] = ast.CreateCompoundStmt(body); } diff --git a/lib/AST/DeadStmtElim.cpp b/lib/AST/DeadStmtElim.cpp index 2309b1ca..038a5b83 100644 --- a/lib/AST/DeadStmtElim.cpp +++ b/lib/AST/DeadStmtElim.cpp @@ -12,22 +12,19 @@ namespace rellic { -DeadStmtElim::DeadStmtElim(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} +DeadStmtElim::DeadStmtElim(DecompilationContext &dec_ctx, clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit) {} bool DeadStmtElim::VisitIfStmt(clang::IfStmt *ifstmt) { // DLOG(INFO) << "VisitIfStmt"; - bool expr_bool_value = false; - auto if_const_expr = ifstmt->getCond()->getIntegerConstantExpr(ast_ctx); - - bool is_const = if_const_expr.hasValue(); - if (is_const) { - expr_bool_value = if_const_expr->getBoolValue(); + bool can_delete = false; + if (ifstmt->getCond() == dec_ctx.marker_expr) { + can_delete = Prove(!dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]); } auto compound = clang::dyn_cast(ifstmt->getThen()); bool is_empty = compound ? compound->body_empty() : false; - if ((is_const && !expr_bool_value) || is_empty) { + if (can_delete || is_empty) { substitutions[ifstmt] = nullptr; } return true; diff --git a/lib/AST/ExprCombine.cpp b/lib/AST/ExprCombine.cpp index 441848f5..9303d479 100644 --- a/lib/AST/ExprCombine.cpp +++ b/lib/AST/ExprCombine.cpp @@ -40,7 +40,7 @@ class ArraySubscriptAddrOfRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto sub = clang::cast(stmt); @@ -48,7 +48,7 @@ class ArraySubscriptAddrOfRule : public InferenceRule { "ArraySubscriptExpr!"; auto paren = clang::cast(sub->getBase()); auto addr_of = clang::cast(paren->getSubExpr()); - CopyProvenance(addr_of, addr_of->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of, addr_of->getSubExpr(), dec_ctx.use_provenance); return addr_of->getSubExpr(); } }; @@ -66,7 +66,7 @@ class AddrOfArraySubscriptRule : public InferenceRule { match = result.Nodes.getNodeAs("addr_of"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto addr_of = clang::cast(stmt); @@ -74,7 +74,7 @@ class AddrOfArraySubscriptRule : public InferenceRule { << "Substituted UnaryOperator is not the matched UnaryOperator!"; auto subexpr = addr_of->getSubExpr()->IgnoreParenImpCasts(); auto sub = clang::cast(subexpr); - CopyProvenance(sub, sub->getBase(), provenance.use_provenance); + CopyProvenance(sub, sub->getBase(), dec_ctx.use_provenance); return sub->getBase(); } }; @@ -91,7 +91,7 @@ class DerefAddrOfRule : public InferenceRule { match = result.Nodes.getNodeAs("deref"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto deref = clang::cast(stmt); @@ -99,7 +99,7 @@ class DerefAddrOfRule : public InferenceRule { << "Substituted UnaryOperator is not the matched UnaryOperator!"; auto subexpr = deref->getSubExpr()->IgnoreParenImpCasts(); auto addr_of = clang::cast(subexpr); - CopyProvenance(addr_of, addr_of->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of, addr_of->getSubExpr(), dec_ctx.use_provenance); return addr_of->getSubExpr(); } }; @@ -120,7 +120,7 @@ class DerefAddrOfConditionalRule : public InferenceRule { match = result.Nodes.getNodeAs("deref"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto deref = clang::cast(stmt); @@ -137,8 +137,8 @@ class DerefAddrOfConditionalRule : public InferenceRule { clang::cast(conditional->getTrueExpr()); auto addr_of2 = clang::cast(conditional->getFalseExpr()); - CopyProvenance(addr_of1, addr_of1->getSubExpr(), provenance.use_provenance); - CopyProvenance(addr_of2, addr_of2->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of1, addr_of1->getSubExpr(), dec_ctx.use_provenance); + CopyProvenance(addr_of2, addr_of2->getSubExpr(), dec_ctx.use_provenance); return ASTBuilder(unit).CreateConditional( conditional->getCond(), addr_of1->getSubExpr(), addr_of2->getSubExpr()); @@ -161,7 +161,7 @@ class NegComparisonRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto op = clang::cast(stmt); @@ -172,7 +172,7 @@ class NegComparisonRule : public InferenceRule { auto opc = clang::BinaryOperator::negateComparisonOp(binop->getOpcode()); auto res = ASTBuilder(unit).CreateBinaryOp(opc, binop->getLHS(), binop->getRHS()); - CopyProvenance(op, res, provenance.stmt_provenance); + CopyProvenance(op, res, dec_ctx.stmt_provenance); return res; } }; @@ -189,7 +189,7 @@ class ParenDeclRefExprStripRule : public InferenceRule { match = result.Nodes.getNodeAs("paren"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto paren = clang::cast(stmt); @@ -209,7 +209,7 @@ class DoubleParenStripRule : public InferenceRule { match = result.Nodes.getNodeAs("paren"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto paren = clang::cast(stmt); @@ -235,7 +235,7 @@ class MemberExprAddrOfRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto arrow{clang::cast(stmt)}; @@ -246,7 +246,7 @@ class MemberExprAddrOfRule : public InferenceRule { auto field{clang::dyn_cast(arrow->getMemberDecl())}; CHECK(field != nullptr) << "Substituted MemberExpr is not a structure field access!"; - CopyProvenance(addr_of, addr_of->getSubExpr(), provenance.use_provenance); + CopyProvenance(addr_of, addr_of->getSubExpr(), dec_ctx.use_provenance); return ASTBuilder(unit).CreateDot(addr_of->getSubExpr(), field); } }; @@ -268,7 +268,7 @@ class MemberExprArraySubRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto dot{clang::cast(stmt)}; @@ -279,7 +279,7 @@ class MemberExprArraySubRule : public InferenceRule { auto field{clang::dyn_cast(dot->getMemberDecl())}; CHECK(field != nullptr) << "Substituted MemberExpr is not a structure field access!"; - CopyProvenance(sub, sub->getBase(), provenance.use_provenance); + CopyProvenance(sub, sub->getBase(), dec_ctx.use_provenance); return ASTBuilder(unit).CreateArrow(sub->getBase(), field); } }; @@ -300,7 +300,7 @@ class AssignCastedExprRule : public InferenceRule { }; } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto assign{clang::cast(stmt)}; @@ -336,7 +336,7 @@ class UnsignedToSignedCStyleCastRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -370,7 +370,7 @@ class TripleCStyleCastElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -409,7 +409,7 @@ class VoidToTypePtrCastElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -444,7 +444,7 @@ class CStyleZeroToPtrCastElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -466,7 +466,7 @@ class CStyleConstElimRule : public InferenceRule { match = result.Nodes.getNodeAs("cast"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto cast{clang::cast(stmt)}; @@ -485,8 +485,8 @@ class CStyleConstElimRule : public InferenceRule { } // namespace -ExprCombine::ExprCombine(Provenance &provenance, clang::ASTUnit &u) - : TransformVisitor(provenance, u) {} +ExprCombine::ExprCombine(DecompilationContext &dec_ctx, clang::ASTUnit &u) + : TransformVisitor(dec_ctx, u) {} bool ExprCombine::VisitCStyleCastExpr(clang::CStyleCastExpr *cast) { // TODO(frabert): Re-enable nullptr casts simplification @@ -500,7 +500,7 @@ bool ExprCombine::VisitCStyleCastExpr(clang::CStyleCastExpr *cast) { pre_rules.emplace_back(new VoidToTypePtrCastElimRule); - auto pre_sub{ApplyFirstMatchingRule(provenance, ast_unit, cast, pre_rules)}; + auto pre_sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, cast, pre_rules)}; if (pre_sub != cast) { substitutions[cast] = pre_sub; return true; @@ -533,7 +533,7 @@ bool ExprCombine::VisitCStyleCastExpr(clang::CStyleCastExpr *cast) { rules.emplace_back(new TripleCStyleCastElimRule); rules.emplace_back(new CStyleConstElimRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, cast, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, cast, rules)}; if (sub != cast) { substitutions[cast] = sub; } @@ -551,7 +551,7 @@ bool ExprCombine::VisitUnaryOperator(clang::UnaryOperator *op) { rules.emplace_back(new DerefAddrOfConditionalRule); rules.emplace_back(new AddrOfArraySubscriptRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, op, rules)}; if (sub != op) { substitutions[op] = sub; } @@ -565,7 +565,7 @@ bool ExprCombine::VisitBinaryOperator(clang::BinaryOperator *op) { rules.emplace_back(new AssignCastedExprRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, op, rules)}; if (sub != op) { substitutions[op] = sub; } @@ -579,7 +579,7 @@ bool ExprCombine::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *expr) { rules.emplace_back(new ArraySubscriptAddrOfRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, expr, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, expr, rules)}; if (sub != expr) { substitutions[expr] = sub; } @@ -594,7 +594,7 @@ bool ExprCombine::VisitMemberExpr(clang::MemberExpr *expr) { rules.emplace_back(new MemberExprAddrOfRule); rules.emplace_back(new MemberExprArraySubRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, expr, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, expr, rules)}; if (sub != expr) { substitutions[expr] = sub; } @@ -609,7 +609,7 @@ bool ExprCombine::VisitParenExpr(clang::ParenExpr *expr) { rules.emplace_back(new ParenDeclRefExprStripRule); rules.emplace_back(new DoubleParenStripRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, expr, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, expr, rules)}; if (sub != expr) { substitutions[expr] = sub; } diff --git a/lib/AST/GenerateAST.cpp b/lib/AST/GenerateAST.cpp old mode 100644 new mode 100755 index e05cab26..238d4c9d --- a/lib/AST/GenerateAST.cpp +++ b/lib/AST/GenerateAST.cpp @@ -26,6 +26,7 @@ #include #include "rellic/AST/ASTBuilder.h" +#include "rellic/AST/Util.h" #include "rellic/BC/Util.h" #include "rellic/Exception.h" @@ -136,155 +137,177 @@ static std::string GetName(llvm::Value *v) { return s; } -z3::expr GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, +z3::expr GenerateAST::ToExpr(unsigned idx) { + if (idx == poison_idx) { + return dec_ctx.z3_ctx.bool_val(false); + } + return dec_ctx.z3_exprs[idx]; +} + +unsigned GenerateAST::GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond) { - if (z_br_edges.find({inst, cond}) == z_br_edges.end()) { - if (cond) { - auto name{GetName(inst)}; - auto edge{z_ctx->bool_const(name.c_str())}; - z_br_edges[{inst, cond}] = z_exprs.size(); - z_exprs.push_back(edge); - z_br_edges_inv[edge.id()] = {inst, true}; - } else { - auto edge{!(GetOrCreateEdgeForBranch(inst, true))}; - z_br_edges[{inst, cond}] = z_exprs.size(); - z_exprs.push_back(edge); - } + if (dec_ctx.z3_br_edges.find({inst, cond}) != dec_ctx.z3_br_edges.end()) { + return dec_ctx.z3_br_edges[{inst, cond}]; } - return z_exprs[z_br_edges[{inst, cond}]]; + if (auto constant = llvm::dyn_cast(inst->getCondition())) { + // This is a conditional branch with a constant condition, so just emit + // whether the condition matches the wanted value + auto edge{dec_ctx.z3_ctx.bool_val(constant->isOne() == cond)}; + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); + dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; + } else if (cond) { + // This is a conditional branch, so the expression that is true when the + // branch is going to be taken is just a new variable. + auto name{GetName(inst)}; + auto edge{dec_ctx.z3_ctx.bool_const(name.c_str())}; + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); + dec_ctx.z3_br_edges_inv[edge.id()] = {inst, true}; + } else { + // Like the previous case, but in this case we want to know the expression + // that will be true when the branch is not going to be taken + auto edge{!(ToExpr(GetOrCreateEdgeForBranch(inst, true)))}; + dec_ctx.z3_br_edges[{inst, cond}] = dec_ctx.InsertZExpr(edge); + } + + return dec_ctx.z3_br_edges[{inst, cond}]; +} + +unsigned GenerateAST::GetOrCreateVarForSwitch(llvm::SwitchInst *inst) { + // To aide simplification, switch instructions actually produce numerical + // variables instead of boolean ones, but are always compared against a + // constant value. + if (dec_ctx.z3_sw_vars.find(inst) != dec_ctx.z3_sw_vars.end()) { + return dec_ctx.z3_sw_vars[inst]; + } + + auto name{GetName(inst)}; + auto var{dec_ctx.z3_ctx.int_const(name.c_str())}; + dec_ctx.z3_sw_vars[inst] = dec_ctx.InsertZExpr(var); + dec_ctx.z3_sw_vars_inv[var.id()] = inst; + return dec_ctx.z3_sw_vars[inst]; } -z3::expr GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, +unsigned GenerateAST::GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst, llvm::ConstantInt *c) { - if (z_sw_edges.find({inst, c}) == z_sw_edges.end()) { - if (c) { - auto name{GetName(inst) + - GetName(inst->findCaseValue(c)->getCaseSuccessor())}; - auto edge{z_ctx->bool_const(name.c_str())}; - z_sw_edges_inv[edge.id()] = {inst, c}; - - z_sw_edges[{inst, c}] = z_exprs.size(); - z_exprs.push_back(edge); - } else { - // Default case - auto edge{z_ctx->bool_val(true)}; - for (auto sw_case : inst->cases()) { - edge = edge && !GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()); - } - edge = edge.simplify(); - z_sw_edges[{inst, c}] = z_exprs.size(); - z_exprs.push_back(edge); - } + if (dec_ctx.z3_sw_edges.find({inst, c}) != dec_ctx.z3_sw_edges.end()) { + return dec_ctx.z3_sw_edges[{inst, c}]; } - return z_exprs[z_sw_edges[{inst, c}]]; + unsigned idx; + if (c) { + auto sw_case{inst->findCaseValue(c)}; + auto var{ToExpr(GetOrCreateVarForSwitch(inst))}; + auto expr{var == dec_ctx.z3_ctx.int_val(sw_case->getCaseIndex())}; + + idx = dec_ctx.InsertZExpr(expr); + } else { + // Default case + z3::expr_vector vec{dec_ctx.z3_ctx}; + for (auto sw_case : inst->cases()) { + vec.push_back( + !ToExpr(GetOrCreateEdgeForSwitch(inst, sw_case.getCaseValue()))); + } + idx = dec_ctx.InsertZExpr(z3::mk_and(vec)); + } + dec_ctx.z3_sw_edges[{inst, c}] = idx; + return idx; } -z3::expr GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, +unsigned GenerateAST::GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to) { - if (z_edges.find({from, to}) == z_edges.end()) { - // Construct the edge condition for CFG edge `(from, to)` - auto result{z_ctx->bool_val(true)}; - auto term = from->getTerminator(); - switch (term->getOpcode()) { - // Conditional branches - case llvm::Instruction::Br: { - auto br = llvm::cast(term); - if (br->isConditional()) { - result = GetOrCreateEdgeForBranch(br, to == br->getSuccessor(0)); - } - } break; - // Switches - case llvm::Instruction::Switch: { - auto sw{llvm::cast(term)}; - if (to == sw->getDefaultDest()) { - result = GetOrCreateEdgeForSwitch(sw, nullptr); - } else { - result = z_ctx->bool_val(false); - for (auto sw_case : sw->cases()) { - if (sw_case.getCaseSuccessor() == to) { - result = result || - GetOrCreateEdgeForSwitch(sw, sw_case.getCaseValue()); - } + if (dec_ctx.z3_edges.find({from, to}) != dec_ctx.z3_edges.end()) { + return dec_ctx.z3_edges[{from, to}]; + } + + // Construct the edge condition for CFG edge `(from, to)` + auto result{dec_ctx.z3_ctx.bool_val(true)}; + auto term = from->getTerminator(); + switch (term->getOpcode()) { + // Conditional branches + case llvm::Instruction::Br: { + auto br = llvm::cast(term); + if (br->isConditional()) { + result = + ToExpr(GetOrCreateEdgeForBranch(br, to == br->getSuccessor(0))); + } + } break; + // Switches + case llvm::Instruction::Switch: { + auto sw{llvm::cast(term)}; + if (to == sw->getDefaultDest()) { + result = ToExpr(GetOrCreateEdgeForSwitch(sw, nullptr)); + } else { + z3::expr_vector or_vec{dec_ctx.z3_ctx}; + for (auto sw_case : sw->cases()) { + if (sw_case.getCaseSuccessor() == to) { + or_vec.push_back( + ToExpr(GetOrCreateEdgeForSwitch(sw, sw_case.getCaseValue()))); } } - } break; - // Returns - case llvm::Instruction::Ret: - break; - // Exceptions - case llvm::Instruction::Invoke: - case llvm::Instruction::Resume: - case llvm::Instruction::CatchSwitch: - case llvm::Instruction::CatchRet: - case llvm::Instruction::CleanupRet: - THROW() << "Exception terminator '" << term->getOpcodeName() - << "' is not supported yet"; - break; - // Unknown - default: - THROW() << "Unsupported terminator instruction: " - << term->getOpcodeName(); - break; - } - - z_edges[{from, to}] = z_exprs.size(); - z_exprs.push_back(result.simplify()); + result = HeavySimplify(z3::mk_or(or_vec)); + } + } break; + // Returns + case llvm::Instruction::Ret: + break; + // Exceptions + case llvm::Instruction::Invoke: + case llvm::Instruction::Resume: + case llvm::Instruction::CatchSwitch: + case llvm::Instruction::CatchRet: + case llvm::Instruction::CleanupRet: + THROW() << "Exception terminator '" << term->getOpcodeName() + << "' is not supported yet"; + break; + // Unknown + default: + THROW() << "Unsupported terminator instruction: " + << term->getOpcodeName(); + break; } - return z_exprs[z_edges[{from, to}]]; + dec_ctx.z3_edges[{from, to}] = dec_ctx.InsertZExpr(result.simplify()); + return dec_ctx.z3_edges[{from, to}]; } -z3::expr GenerateAST::GetReachingCond(llvm::BasicBlock *block) { - if (reaching_conds.find(block) == reaching_conds.end()) { - return z_ctx->bool_val(false); +unsigned GenerateAST::GetReachingCond(llvm::BasicBlock *block) { + if (dec_ctx.reaching_conds.find(block) == dec_ctx.reaching_conds.end()) { + return poison_idx; } - return z_exprs[reaching_conds[block]]; + return dec_ctx.reaching_conds[block]; } void GenerateAST::CreateReachingCond(llvm::BasicBlock *block) { - auto Simplify{[this](z3::expr expr) { - if (Prove(*z_ctx, expr)) { - return z_ctx->bool_val(true); - } - - z3::tactic aig(*z_ctx, "aig"); - z3::tactic simplify(*z_ctx, "simplify"); - z3::tactic ctx_solver_simplify(*z_ctx, "ctx-solver-simplify"); - auto tactic{simplify & aig & ctx_solver_simplify}; - return ApplyTactic(*z_ctx, tactic, expr).as_expr(); - }}; - - auto old_cond{GetReachingCond(block)}; + auto old_cond_idx{GetReachingCond(block)}; + auto old_cond{ToExpr(old_cond_idx)}; if (block->hasNPredecessorsOrMore(1)) { // Gather reaching conditions from predecessors of the block - auto cond{z_ctx->bool_val(false)}; + z3::expr_vector conds{dec_ctx.z3_ctx}; for (auto pred : llvm::predecessors(block)) { - auto pred_cond{GetReachingCond(pred)}; - auto edge_cond{GetOrCreateEdgeCond(pred, block)}; + auto pred_cond{ToExpr(GetReachingCond(pred))}; + auto edge_cond{ToExpr(GetOrCreateEdgeCond(pred, block))}; // Construct reaching condition from `pred` to `block` as // `reach_cond[pred] && edge_cond(pred, block)` or one of // the two if the other one is missing. - auto conj_cond{Simplify(pred_cond && edge_cond)}; + auto conj_cond{HeavySimplify(pred_cond && edge_cond)}; // Append `conj_cond` to reaching conditions of other // predecessors via an `||`. Use `conj_cond` if there // is no `cond` yet. - cond = Simplify(cond || conj_cond); + conds.push_back(conj_cond); } - if (!Prove(*z_ctx, old_cond == cond)) { - reaching_conds[block] = z_exprs.size(); - z_exprs.push_back(cond); - reaching_conds_changed = true; - } - } else { - if (reaching_conds.find(block) == reaching_conds.end()) { - reaching_conds[block] = z_exprs.size(); - z_exprs.push_back(z_ctx->bool_val(true)); + auto cond{HeavySimplify(z3::mk_or(conds))}; + if (old_cond_idx == poison_idx || !Prove(old_cond == cond)) { + dec_ctx.reaching_conds[block] = dec_ctx.InsertZExpr(cond); reaching_conds_changed = true; } + } else if (dec_ctx.reaching_conds.find(block) == + dec_ctx.reaching_conds.end()) { + dec_ctx.reaching_conds[block] = + dec_ctx.InsertZExpr(dec_ctx.z3_ctx.bool_val(true)); + reaching_conds_changed = true; } } @@ -294,64 +317,6 @@ StmtVec GenerateAST::CreateBasicBlockStmts(llvm::BasicBlock *block) { return result; } -clang::Expr *GenerateAST::ConvertExpr(z3::expr expr) { - auto hash{expr.id()}; - if (z_br_edges_inv.find(hash) != z_br_edges_inv.end()) { - auto edge{z_br_edges_inv[hash]}; - CHECK(edge.second) << "Inverse map should only be populated for branches " - "taken when condition is true"; - return ast_gen.CreateOperandExpr(*(edge.first->op_end() - 3)); - } - - if (z_sw_edges_inv.find(hash) != z_sw_edges_inv.end()) { - auto edge{z_sw_edges_inv[hash]}; - CHECK(edge.second) - << "Inverse map should only be populated for not-default switch cases"; - - auto opnd{ast_gen.CreateOperandExpr(edge.first->getOperandUse(0))}; - return ast.CreateEQ(opnd, ast_gen.CreateConstantExpr(edge.second)); - } - - std::vector args; - for (auto i{0U}; i < expr.num_args(); ++i) { - args.push_back(ConvertExpr(expr.arg(i))); - } - - switch (expr.decl().decl_kind()) { - case Z3_OP_TRUE: - CHECK_EQ(args.size(), 0) << "True cannot have arguments"; - return ast.CreateTrue(); - case Z3_OP_FALSE: - CHECK_EQ(args.size(), 0) << "False cannot have arguments"; - return ast.CreateFalse(); - case Z3_OP_AND: { - CHECK_GE(args.size(), 2) << "And must have at least 2 arguments"; - clang::Expr *res{args[0]}; - for (auto i{1U}; i < args.size(); ++i) { - res = ast.CreateLAnd(res, args[i]); - } - return res; - } - case Z3_OP_OR: { - CHECK_GE(args.size(), 2) << "Or must have at least 2 arguments"; - clang::Expr *res{args[0]}; - for (auto i{1U}; i < args.size(); ++i) { - res = ast.CreateLOr(res, args[i]); - } - return res; - } - case Z3_OP_NOT: { - CHECK_EQ(args.size(), 1) << "Not must have one argument"; - auto neg{ast.CreateLNot(args[0])}; - CopyProvenance(args[0], neg, provenance.use_provenance); - return neg; - } - default: - LOG(FATAL) << "Invalid z3 op"; - } - return nullptr; -} - StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { StmtVec result; for (auto block : rpo_walk) { @@ -375,7 +340,9 @@ StmtVec GenerateAST::CreateRegionStmts(llvm::Region *region) { } // Gate the compound behind a reaching condition auto z_expr{GetReachingCond(block)}; - block_stmts[block] = ast.CreateIf(ConvertExpr(z_expr), compound); + block_stmts[block] = ast.CreateIf(dec_ctx.marker_expr, compound); + dec_ctx.conds[block_stmts[block]] = + dec_ctx.InsertZExpr(dec_ctx.z3_exprs[z_expr]); // Store the compound result.push_back(block_stmts[block]); } @@ -476,15 +443,17 @@ clang::CompoundStmt *GenerateAST::StructureCyclicRegion(llvm::Region *region) { for (auto edge : exits) { auto from = edge.first; auto to = edge.second; - // Create edge condition - auto z_expr{GetReachingCond(from) && GetOrCreateEdgeCond(from, to)}; - auto cond{ConvertExpr(z_expr.simplify())}; // Find the statement corresponding to the exiting block auto it = std::find(loop_body.begin(), loop_body.end(), block_stmts[from]); CHECK(it != loop_body.end()); // Create a loop exiting `break` statement StmtVec break_stmt({ast.CreateBreak()}); - auto exit_stmt = ast.CreateIf(cond, ast.CreateCompoundStmt(break_stmt)); + auto exit_stmt = + ast.CreateIf(dec_ctx.marker_expr, ast.CreateCompoundStmt(break_stmt)); + // Create edge condition + dec_ctx.conds[exit_stmt] = dec_ctx.InsertZExpr( + (ToExpr(GetReachingCond(from)) && ToExpr(GetOrCreateEdgeCond(from, to))) + .simplify()); // Insert it after the exiting block statement loop_body.insert(std::next(it), exit_stmt); } @@ -567,14 +536,12 @@ clang::CompoundStmt *GenerateAST::StructureRegion(llvm::Region *region) { llvm::AnalysisKey GenerateAST::Key; -GenerateAST::GenerateAST(Provenance &provenance, clang::ASTUnit &unit) +GenerateAST::GenerateAST(DecompilationContext &dec_ctx, clang::ASTUnit &unit) : ast_ctx(&unit.getASTContext()), unit(unit), - provenance(provenance), - ast_gen(unit, provenance), - ast(unit), - z_ctx(new z3::context()), - z_exprs(*z_ctx) {} + dec_ctx(dec_ctx), + ast_gen(unit, dec_ctx), + ast(unit) {} GenerateAST::Result GenerateAST::run(llvm::Module &module, llvm::ModuleAnalysisManager &MAM) { @@ -652,13 +619,13 @@ GenerateAST::Result GenerateAST::run(llvm::Function &func, // Call the above declared bad boy POWalkSubRegions(regions->getTopLevelRegion()); // Get the function declaration AST node for `func` - auto fdecl = clang::cast(provenance.value_decls[&func]); + auto fdecl = clang::cast(dec_ctx.value_decls[&func]); // Create a redeclaration of `fdecl` that will serve as a definition auto tudecl = ast_ctx->getTranslationUnitDecl(); auto fdefn = ast.CreateFunctionDecl(tudecl, fdecl->getType(), fdecl->getIdentifier()); fdefn->setPreviousDecl(fdecl); - provenance.value_decls[&func] = fdefn; + dec_ctx.value_decls[&func] = fdefn; tudecl->addDecl(fdefn); // Set parameters to the same as the previous declaration fdefn->setParams(fdecl->parameters()); @@ -680,20 +647,20 @@ GenerateAST::Result GenerateAST::run(llvm::Function &func, return llvm::PreservedAnalyses::all(); } -void GenerateAST::run(llvm::Module &module, Provenance &provenance, +void GenerateAST::run(llvm::Module &module, DecompilationContext &dec_ctx, clang::ASTUnit &unit) { llvm::ModulePassManager mpm; llvm::ModuleAnalysisManager mam; llvm::PassBuilder pb; - mam.registerPass([&] { return rellic::GenerateAST(provenance, unit); }); - mpm.addPass(rellic::GenerateAST(provenance, unit)); + mam.registerPass([&] { return rellic::GenerateAST(dec_ctx, unit); }); + mpm.addPass(rellic::GenerateAST(dec_ctx, unit)); pb.registerModuleAnalyses(mam); mpm.run(module, mam); llvm::FunctionPassManager fpm; llvm::FunctionAnalysisManager fam; - fam.registerPass([&] { return rellic::GenerateAST(provenance, unit); }); - fpm.addPass(rellic::GenerateAST(provenance, unit)); + fam.registerPass([&] { return rellic::GenerateAST(dec_ctx, unit); }); + fpm.addPass(rellic::GenerateAST(dec_ctx, unit)); pb.registerFunctionAnalyses(fam); for (auto &func : module.functions()) { fpm.run(func, fam); diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp old mode 100644 new mode 100755 index da7eb92e..b44b8da6 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -30,11 +30,13 @@ class ExprGen : public llvm::InstVisitor { ASTBuilder ast; - Provenance &provenance; + DecompilationContext &dec_ctx; + size_t num_literal_structs = 0; + size_t num_declared_structs = 0; public: - ExprGen(clang::ASTUnit &unit, Provenance &provenance) - : ast_ctx(unit.getASTContext()), ast(unit), provenance(provenance) {} + ExprGen(clang::ASTUnit &unit, DecompilationContext &dec_ctx) + : ast_ctx(unit.getASTContext()), ast(unit), dec_ctx(dec_ctx) {} void VisitGlobalVar(llvm::GlobalVariable &gvar); clang::QualType GetQualType(llvm::Type *type); @@ -61,9 +63,93 @@ class ExprGen : public llvm::InstVisitor { clang::Expr *visitUnaryOperator(llvm::UnaryOperator &inst); }; +clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { + if (expr.decl().decl_kind() == Z3_OP_EQ) { + // Equalities generated form the reaching conditions of switch instructions + // Always in the for (VAR == CONST) or (CONST == VAR) + // VAR will uniquely identify a SwitchInst, CONST will represent the index + // of the case taken + CHECK_EQ(expr.num_args(), 2) << "Equalities must have 2 arguments"; + auto a{expr.arg(0)}; + auto b{expr.arg(1)}; + + llvm::SwitchInst *inst{dec_ctx.z3_sw_vars_inv[a.id()]}; + unsigned case_idx{}; + + // GenerateAST always generates equalities in the form (VAR == CONST), but + // there is a chance that some Z3 simplification inverts the order, so + // handle that here. + if (!inst) { + inst = dec_ctx.z3_sw_vars_inv[b.id()]; + case_idx = a.get_numeral_uint(); + } else { + case_idx = b.get_numeral_uint(); + } + + for (auto sw_case : inst->cases()) { + if (sw_case.getCaseIndex() == case_idx) { + return ast.CreateEQ(CreateOperandExpr(inst->getOperandUse(0)), + CreateConstantExpr(sw_case.getCaseValue())); + } + } + + LOG(FATAL) << "Couldn't find switch case"; + } + + auto hash{expr.id()}; + if (dec_ctx.z3_br_edges_inv.find(hash) != dec_ctx.z3_br_edges_inv.end()) { + auto edge{dec_ctx.z3_br_edges_inv[hash]}; + CHECK(edge.second) << "Inverse map should only be populated for branches " + "taken when condition is true"; + // expr is a variable that represents the condition of a branch instruction. + + // FIXME(frabert): Unfortunately there is no public API in BranchInst that + // gives the operand of the condition. From reverse engineering LLVM code, + // this is the way they obtain uses internally, but it's probably not + // stable. + return CreateOperandExpr(*(edge.first->op_end() - 3)); + } + + switch (expr.decl().decl_kind()) { + case Z3_OP_TRUE: + CHECK_EQ(expr.num_args(), 0) << "True cannot have arguments"; + return ast.CreateTrue(); + case Z3_OP_FALSE: + CHECK_EQ(expr.num_args(), 0) << "False cannot have arguments"; + return ast.CreateFalse(); + case Z3_OP_AND: { + // Since AND and OR expressions are n-ary we need to convert them to + // binary. If they have only one subexpression, we can forego the AND/OR + // altogether. + clang::Expr *res{ConvertExpr(expr.arg(0))}; + for (auto i{1U}; i < expr.num_args(); ++i) { + res = ast.CreateLAnd(res, ConvertExpr(expr.arg(i))); + } + return res; + } + case Z3_OP_OR: { + clang::Expr *res{ConvertExpr(expr.arg(0))}; + for (auto i{1U}; i < expr.num_args(); ++i) { + res = ast.CreateLOr(res, ConvertExpr(expr.arg(i))); + } + return res; + } + case Z3_OP_NOT: { + CHECK_EQ(expr.num_args(), 1) << "Not must have one argument"; + auto sub{ConvertExpr(expr.arg(0))}; + auto neg{ast.CreateLNot(sub)}; + CopyProvenance(sub, neg, dec_ctx.use_provenance); + return neg; + } + default: + LOG(FATAL) << "Invalid z3 op"; + } + return nullptr; +} + void ExprGen::VisitGlobalVar(llvm::GlobalVariable &gvar) { DLOG(INFO) << "VisitGlobalVar: " << LLVMThingToString(&gvar); - auto &var{provenance.value_decls[&gvar]}; + auto &var{dec_ctx.value_decls[&gvar]}; if (var) { return; } @@ -164,16 +250,16 @@ clang::QualType ExprGen::GetQualType(llvm::Type *type) { case llvm::Type::StructTyID: { clang::RecordDecl *sdecl{nullptr}; - auto &decl{provenance.type_decls[type]}; + auto &decl{dec_ctx.type_decls[type]}; if (!decl) { auto tudecl{ast_ctx.getTranslationUnitDecl()}; auto strct{llvm::cast(type)}; auto sname{strct->isLiteral() ? ("literal_struct_" + - std::to_string(provenance.num_literal_structs++)) + std::to_string(dec_ctx.num_literal_structs++)) : strct->getName().str()}; if (sname.empty()) { - sname = "struct" + std::to_string(provenance.num_declared_structs++); + sname = "struct" + std::to_string(dec_ctx.num_declared_structs++); } // Create a C struct declaration @@ -227,13 +313,13 @@ clang::Expr *ExprGen::CreateConstantExpr(llvm::Constant *constant) { if (auto cexpr = llvm::dyn_cast(constant)) { auto inst{cexpr->getAsInstruction()}; auto expr{visit(inst)}; - provenance.use_provenance.erase(expr); + dec_ctx.use_provenance.erase(expr); inst->deleteValue(); return expr; } else if (auto alias = llvm::dyn_cast(constant)) { return CreateConstantExpr(alias->getAliasee()); } else if (auto global = llvm::dyn_cast(constant)) { - auto decl{provenance.value_decls[global]}; + auto decl{dec_ctx.value_decls[global]}; auto ref{ast.CreateDeclRef(decl)}; return ast.CreateAddrOf(ref); } @@ -343,9 +429,9 @@ clang::Expr *ExprGen::CreateLiteralExpr(llvm::Constant *constant) { clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { DLOG(INFO) << "Getting Expr for " << LLVMThingToString(val); auto CreateRef{[this, &val] { - auto decl{provenance.value_decls[val]}; + auto decl{dec_ctx.value_decls[val]}; auto ref{ast.CreateDeclRef(decl)}; - provenance.use_provenance[ref] = &val; + dec_ctx.use_provenance[ref] = &val; return ref; }}; @@ -367,13 +453,12 @@ clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { // the actual argument and use it instead of the actual argument. This // is because `byval` arguments are pointers, so each reference to those // arguments assume they are dealing with pointers. - auto &temp{provenance.temp_decls[arg]}; + auto &temp{dec_ctx.temp_decls[arg]}; if (!temp) { auto addr_of_arg{ast.CreateAddrOf(ref)}; auto func{arg->getParent()}; - auto fdecl{provenance.value_decls[func]->getAsFunction()}; - auto argdecl{ - clang::cast(provenance.value_decls[arg])}; + auto fdecl{dec_ctx.value_decls[func]->getAsFunction()}; + auto argdecl{clang::cast(dec_ctx.value_decls[arg])}; temp = ast.CreateVarDecl(fdecl, GetQualType(arg->getType()), argdecl->getName().str() + "_ptr"); temp->setInit(addr_of_arg); @@ -386,7 +471,7 @@ clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { } } else if (auto inst = llvm::dyn_cast(val)) { // Operand is a result of an expression - if (auto decl = provenance.value_decls[inst]) { + if (auto decl = dec_ctx.value_decls[inst]) { res = ast.CreateDeclRef(decl); } else { res = visit(inst); @@ -407,7 +492,7 @@ clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { << "Bitcode: [" << LLVMThingToString(val) << "]\n" << "Type: [" << LLVMThingToString(val->getType()) << "]\n"; } - provenance.use_provenance[res] = &val; + dec_ctx.use_provenance[res] = &val; return res; } @@ -496,7 +581,7 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { auto ptr_type{ ast_ctx.getPointerType(GetQualType(inst.getParamByValType(i)))}; opnd = ast.CreateDeref(ast.CreateCStyleCast(ptr_type, opnd)); - provenance.use_provenance[opnd] = &arg; + dec_ctx.use_provenance[opnd] = &arg; } args.push_back(opnd); } @@ -504,7 +589,7 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { clang::Expr *callexpr{nullptr}; auto &callee{*(inst.op_end() - 1)}; if (auto func = llvm::dyn_cast(callee)) { - auto fdecl{provenance.value_decls[func]->getAsFunction()}; + auto fdecl{dec_ctx.value_decls[func]->getAsFunction()}; if (func->getFunctionType() == inst.getFunctionType()) { callexpr = ast.CreateCall(fdecl, args); } else { @@ -515,7 +600,7 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { callexpr = ast.CreateCall(cast, args); } } else if (auto iasm = llvm::dyn_cast(callee)) { - auto fdecl{provenance.value_decls[iasm]->getAsFunction()}; + auto fdecl{dec_ctx.value_decls[iasm]->getAsFunction()}; callexpr = ast.CreateCall(fdecl, args); } else if (llvm::isa(callee->getType())) { auto funcPtr{ast_ctx.getPointerType(GetQualType(inst.getFunctionType()))}; @@ -619,7 +704,7 @@ clang::Expr *ExprGen::visitGetElementPtrInst(llvm::GetElementPtrInst &inst) { case llvm::Type::StructTyID: { auto mem_idx = llvm::dyn_cast(idx); CHECK(mem_idx) << "Non-constant GEP index while indexing a structure"; - auto tdecl{provenance.type_decls[indexed_type]}; + auto tdecl{dec_ctx.type_decls[indexed_type]}; CHECK(tdecl) << "Structure declaration doesn't exist"; auto record{clang::cast(tdecl)}; auto field_it{record->field_begin()}; @@ -677,7 +762,7 @@ clang::Expr *ExprGen::visitExtractValueInst(llvm::ExtractValueInst &inst) { } break; // Structures case llvm::Type::StructTyID: { - auto tdecl{provenance.type_decls[indexed_type]}; + auto tdecl{dec_ctx.type_decls[indexed_type]}; CHECK(tdecl) << "Structure declaration doesn't exist"; auto record{clang::cast(tdecl)}; auto field_it{record->field_begin()}; @@ -801,7 +886,7 @@ clang::Expr *ExprGen::visitCmpInst(llvm::CmpInst &inst) { return op; } else { auto cast{ast.CreateCStyleCast(rt, op)}; - CopyProvenance(op, cast, provenance.use_provenance); + CopyProvenance(op, cast, dec_ctx.use_provenance); return (clang::Expr *)cast; } }}; @@ -989,15 +1074,12 @@ class StmtGen : public llvm::InstVisitor { clang::ASTContext &ast_ctx; ASTBuilder * ExprGen &expr_gen; - Provenance &provenance; + DecompilationContext &dec_ctx; public: StmtGen(clang::ASTContext &ast_ctx, ASTBuilder &ast, ExprGen &expr_gen, - Provenance &provenance) - : ast_ctx(ast_ctx), - ast(ast), - expr_gen(expr_gen), - provenance(provenance) {} + DecompilationContext &dec_ctx) + : ast_ctx(ast_ctx), ast(ast), expr_gen(expr_gen), dec_ctx(dec_ctx) {} clang::Stmt *visitStoreInst(llvm::StoreInst &inst); clang::Stmt *visitCallInst(llvm::CallInst &inst); @@ -1043,13 +1125,13 @@ clang::Stmt *StmtGen::visitStoreInst(llvm::StoreInst &inst) { } else { // Create the assignemnt itself auto deref{ast.CreateDeref(lhs)}; - CopyProvenance(lhs, deref, provenance.use_provenance); + CopyProvenance(lhs, deref, dec_ctx.use_provenance); return ast.CreateAssign(deref, rhs); } } clang::Stmt *StmtGen::visitCallInst(llvm::CallInst &inst) { - auto &var{provenance.value_decls[&inst]}; + auto &var{dec_ctx.value_decls[&inst]}; auto expr{expr_gen.visit(inst)}; if (var) { return ast.CreateAssign(ast.CreateDeclRef(var), expr); @@ -1090,7 +1172,7 @@ clang::Stmt *StmtGen::visitUnreachableInst(llvm::UnreachableInst &inst) { clang::Stmt *StmtGen::visitPHINode(llvm::PHINode &inst) { return nullptr; } clang::Stmt *StmtGen::visitInstruction(llvm::Instruction &inst) { - auto &var{provenance.value_decls[&inst]}; + auto &var{dec_ctx.value_decls[&inst]}; if (var) { auto expr{expr_gen.visit(inst)}; return ast.CreateAssign(ast.CreateDeclRef(var), expr); @@ -1098,20 +1180,21 @@ clang::Stmt *StmtGen::visitInstruction(llvm::Instruction &inst) { return nullptr; } -IRToASTVisitor::IRToASTVisitor(clang::ASTUnit &unit, Provenance &provenance) +IRToASTVisitor::IRToASTVisitor(clang::ASTUnit &unit, + DecompilationContext &dec_ctx) : ast_unit(unit), ast_ctx(unit.getASTContext()), ast(unit), - provenance(provenance) {} + dec_ctx(dec_ctx) {} void IRToASTVisitor::VisitGlobalVar(llvm::GlobalVariable &gvar) { - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; expr_gen.VisitGlobalVar(gvar); } void IRToASTVisitor::VisitArgument(llvm::Argument &arg) { DLOG(INFO) << "VisitArgument: " << LLVMThingToString(&arg); - auto &parm{provenance.value_decls[&arg]}; + auto &parm{dec_ctx.value_decls[&arg]}; if (parm) { return; } @@ -1120,13 +1203,13 @@ void IRToASTVisitor::VisitArgument(llvm::Argument &arg) { : "arg" + std::to_string(arg.getArgNo())}; // Get parent function declaration auto func{arg.getParent()}; - auto fdecl{clang::cast(provenance.value_decls[func])}; + auto fdecl{clang::cast(dec_ctx.value_decls[func])}; auto argtype{arg.getType()}; if (arg.hasByValAttr()) { auto byval{arg.getAttribute(llvm::Attribute::ByVal)}; argtype = byval.getValueAsType(); } - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; // Create a declaration parm = ast.CreateParamDecl(fdecl, expr_gen.GetQualType(argtype), name); } @@ -1159,20 +1242,20 @@ static llvm::FunctionType *GetFixedFunctionType(llvm::Function &func) { void IRToASTVisitor::VisitBasicBlock(llvm::BasicBlock &block, std::vector &stmts) { - ExprGen expr_gen{ast_unit, provenance}; - StmtGen stmt_gen{ast_ctx, ast, expr_gen, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; + StmtGen stmt_gen{ast_ctx, ast, expr_gen, dec_ctx}; for (auto &inst : block) { auto stmt{stmt_gen.visit(inst)}; if (stmt) { stmts.push_back(stmt); - provenance.stmt_provenance[stmt] = &inst; + dec_ctx.stmt_provenance[stmt] = &inst; } } - auto &uses{provenance.outgoing_uses[&block]}; + auto &uses{dec_ctx.outgoing_uses[&block]}; for (auto it{uses.rbegin()}; it != uses.rend(); ++it) { auto use{*it}; - auto var{provenance.value_decls[use->getUser()]}; + auto var{dec_ctx.value_decls[use->getUser()]}; auto expr{expr_gen.CreateOperandExpr(*use)}; stmts.push_back(ast.CreateAssign(ast.CreateDeclRef(var), expr)); } @@ -1187,14 +1270,14 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { return; } - auto &decl{provenance.value_decls[&func]}; + auto &decl{dec_ctx.value_decls[&func]}; if (decl) { return; } DLOG(INFO) << "Creating FunctionDecl for " << name; auto tudecl{ast_ctx.getTranslationUnitDecl()}; - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; auto type{expr_gen.GetQualType(GetFixedFunctionType(func))}; decl = ast.CreateFunctionDecl(tudecl, type, name); @@ -1204,14 +1287,14 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { for (auto &arg : func.args()) { VisitArgument(arg); params.push_back( - clang::cast(provenance.value_decls[&arg])); + clang::cast(dec_ctx.value_decls[&arg])); } auto fdecl{decl->getAsFunction()}; fdecl->setParams(params); for (auto &inst : llvm::instructions(func)) { - auto &var{provenance.value_decls[&inst]}; + auto &var{dec_ctx.value_decls[&inst]}; if (auto alloca = llvm::dyn_cast(&inst)) { auto name{"var" + std::to_string(GetNumDecls(fdecl))}; // TLDR: Here we discard the variable name as present in the bitcode @@ -1237,11 +1320,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { fdecl->addDecl(var); } else if (inst.hasNUsesOrMore(2) || (inst.hasNUsesOrMore(1) && llvm::isa(inst)) || - llvm::isa(inst) || llvm::isa(inst)) { - if (inst.getMetadata("rellic.notemp")) { - continue; - } if (!inst.getType()->isVoidTy()) { auto GetPrefix{[&](llvm::Instruction *inst) { if (llvm::isa(inst)) { @@ -1268,7 +1347,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { auto bb{phi->getIncomingBlock(i)}; auto &use{phi->getOperandUse(i)}; - provenance.outgoing_uses[bb].push_back(&use); + dec_ctx.outgoing_uses[bb].push_back(&use); } } } @@ -1278,7 +1357,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { if (auto iasm = llvm::dyn_cast(&opnd)) { // TODO(frabert): We still need to find a way to embed the inline asm // into the function - auto &decl{provenance.value_decls[iasm]}; + auto &decl{dec_ctx.value_decls[iasm]}; if (decl) { return; } @@ -1308,12 +1387,12 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { } clang::Expr *IRToASTVisitor::CreateOperandExpr(llvm::Use &val) { - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; return expr_gen.CreateOperandExpr(val); } clang::Expr *IRToASTVisitor::CreateConstantExpr(llvm::Constant *constant) { - ExprGen expr_gen{ast_unit, provenance}; + ExprGen expr_gen{ast_unit, dec_ctx}; return expr_gen.CreateConstantExpr(constant); } diff --git a/lib/AST/InferenceRule.cpp b/lib/AST/InferenceRule.cpp index efcabcc8..7a50f5b0 100644 --- a/lib/AST/InferenceRule.cpp +++ b/lib/AST/InferenceRule.cpp @@ -14,7 +14,7 @@ namespace rellic { clang::Stmt *ApplyFirstMatchingRule( - Provenance &provenance, clang::ASTUnit &unit, clang::Stmt *stmt, + DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt, std::vector> &rules) { clang::ast_matchers::MatchFinder::MatchFinderOptions opts; clang::ast_matchers::MatchFinder finder(opts); @@ -27,7 +27,7 @@ clang::Stmt *ApplyFirstMatchingRule( for (auto &rule : rules) { if (*rule) { - return rule->GetOrCreateSubstitution(provenance, unit, stmt); + return rule->GetOrCreateSubstitution(dec_ctx, unit, stmt); } } diff --git a/lib/AST/LocalDeclRenamer.cpp b/lib/AST/LocalDeclRenamer.cpp index ae23fe92..7885007a 100644 --- a/lib/AST/LocalDeclRenamer.cpp +++ b/lib/AST/LocalDeclRenamer.cpp @@ -16,9 +16,9 @@ namespace rellic { -LocalDeclRenamer::LocalDeclRenamer(Provenance &provenance, clang::ASTUnit &unit, - IRToNameMap &names) - : TransformVisitor(provenance, unit), +LocalDeclRenamer::LocalDeclRenamer(DecompilationContext &dec_ctx, + clang::ASTUnit &unit, IRToNameMap &names) + : TransformVisitor(dec_ctx, unit), seen_names(1), names(names) {} @@ -78,7 +78,7 @@ bool LocalDeclRenamer::TraverseFunctionDecl(clang::FunctionDecl *decl) { bool LocalDeclRenamer::shouldTraversePostOrder() { return false; } void LocalDeclRenamer::RunImpl() { - for (auto &pair : provenance.value_decls) { + for (auto &pair : dec_ctx.value_decls) { decls[pair.second] = pair.first; } TraverseDecl(ast_ctx.getTranslationUnitDecl()); diff --git a/lib/AST/LoopRefine.cpp b/lib/AST/LoopRefine.cpp index 76118a37..9710ce8d 100644 --- a/lib/AST/LoopRefine.cpp +++ b/lib/AST/LoopRefine.cpp @@ -47,7 +47,7 @@ class WhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -57,7 +57,6 @@ class WhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_front())}; - auto cond{ifstmt->getCond()}; std::vector new_body; if (auto else_stmt = ifstmt->getElse()) { new_body.push_back(else_stmt); @@ -65,9 +64,11 @@ class WhileRule : public InferenceRule { std::copy(comp->body_begin() + 1, comp->body_end(), std::back_inserter(new_body)); ASTBuilder ast(unit); - auto new_cond{ast.CreateLNot(cond)}; - CopyProvenance(cond, new_cond, provenance.use_provenance); - return ast.CreateWhile(new_cond, ast.CreateCompoundStmt(new_body)); + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = + dec_ctx.InsertZExpr(!dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]); + return new_while; } }; @@ -88,7 +89,7 @@ class ElseWhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -98,13 +99,12 @@ class ElseWhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_front())}; - auto cond{ifstmt->getCond()}; std::vector new_body; - new_body.push_back(ifstmt->getThen()); - std::copy(comp->body_begin() + 1, comp->body_end(), - std::back_inserter(new_body)); ASTBuilder ast(unit); - return ast.CreateWhile(cond, ast.CreateCompoundStmt(new_body)); + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = dec_ctx.conds[ifstmt]; + return new_while; } }; @@ -125,7 +125,7 @@ class DoWhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -135,18 +135,19 @@ class DoWhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_back())}; - auto cond{ifstmt->getCond()}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; std::vector new_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); if (auto else_stmt = ifstmt->getElse()) { - auto cond_inv{ast.CreateLNot(cond)}; - CopyProvenance(cond, cond_inv, provenance.use_provenance); - new_body.push_back(ast.CreateIf(cond_inv, else_stmt)); + auto new_if{ast.CreateIf(dec_ctx.marker_expr, else_stmt)}; + dec_ctx.conds[new_if] = dec_ctx.z3_exprs.size(); + new_body.push_back(new_if); } - auto cond_inv{ast.CreateLNot(cond)}; - CopyProvenance(cond, cond_inv, provenance.use_provenance); - return ast.CreateDo(cond_inv, ast.CreateCompoundStmt(new_body)); + auto new_do{ + ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_do] = dec_ctx.InsertZExpr(!cond); + return new_do; } }; @@ -167,7 +168,7 @@ class ElseDoWhileRule : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -177,7 +178,6 @@ class ElseDoWhileRule : public InferenceRule { auto comp{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(comp->body_back())}; - auto cond{ifstmt->getCond()}; std::vector new_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); @@ -185,8 +185,10 @@ class ElseDoWhileRule : public InferenceRule { ifstmt->setElse(nullptr); new_body.push_back(ifstmt); - auto cond_clone{Clone(unit, cond, provenance.use_provenance)}; - return ast.CreateDo(cond_clone, ast.CreateCompoundStmt(new_body)); + auto new_do{ + ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_do] = dec_ctx.conds[ifstmt]; + return new_do; } }; @@ -215,7 +217,7 @@ class NestedDoWhileRule : public InferenceRule { matched = true; } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -223,23 +225,27 @@ class NestedDoWhileRule : public InferenceRule { CHECK(loop && loop == match) << "Substituted WhileStmt is not the matched WhileStmt!"; auto comp{clang::cast(loop->getBody())}; - auto cond{clang::cast(comp->body_back())}; + auto if_stmt{clang::cast(comp->body_back())}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[if_stmt]]}; std::vector do_body(comp->body_begin(), comp->body_end() - 1); ASTBuilder ast(unit); - if (auto else_stmt = cond->getElse()) { - auto cond_inv{ast.CreateLNot(cond->getCond())}; - CopyProvenance(cond->getCond(), cond_inv, provenance.use_provenance); - do_body.push_back(ast.CreateIf(cond_inv, else_stmt)); + if (auto else_stmt = if_stmt->getElse()) { + auto new_if{ast.CreateIf(dec_ctx.marker_expr, else_stmt)}; + dec_ctx.conds[new_if] = dec_ctx.z3_exprs.size(); + do_body.push_back(new_if); } - auto do_cond{ast.CreateLNot(cond->getCond())}; - CopyProvenance(cond->getCond(), do_cond, provenance.use_provenance); - auto do_stmt{ast.CreateDo(do_cond, ast.CreateCompoundStmt(do_body))}; + auto do_stmt{ + ast.CreateDo(dec_ctx.marker_expr, ast.CreateCompoundStmt(do_body))}; + dec_ctx.conds[do_stmt] = dec_ctx.InsertZExpr(!cond); - std::vector while_body({do_stmt, cond->getThen()}); - return ast.CreateWhile(loop->getCond(), ast.CreateCompoundStmt(while_body)); + std::vector while_body({do_stmt, if_stmt->getThen()}); + auto new_while{ast.CreateWhile(dec_ctx.marker_expr, + ast.CreateCompoundStmt(while_body))}; + dec_ctx.conds[new_while] = dec_ctx.conds[loop]; + return new_while; } }; @@ -265,7 +271,7 @@ class LoopToSeq : public InferenceRule { } } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop = clang::dyn_cast(stmt); @@ -319,7 +325,7 @@ class CondToSeqRule : public InferenceRule { match = result.Nodes.getNodeAs("while"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -330,14 +336,18 @@ class CondToSeqRule : public InferenceRule { ASTBuilder ast(unit); auto body{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(body->body_front())}; - auto inner_loop{ast.CreateWhile(ifstmt->getCond(), ifstmt->getThen())}; + auto inner_loop{ast.CreateWhile(dec_ctx.marker_expr, ifstmt->getThen())}; + dec_ctx.conds[inner_loop] = dec_ctx.conds[ifstmt]; std::vector new_body({inner_loop}); if (auto comp = clang::dyn_cast(ifstmt->getElse())) { new_body.insert(new_body.end(), comp->body_begin(), comp->body_end()); } else { new_body.push_back(ifstmt->getElse()); } - return ast.CreateWhile(loop->getCond(), ast.CreateCompoundStmt(new_body)); + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = dec_ctx.conds[loop]; + return new_while; } }; @@ -354,7 +364,7 @@ class CondToSeqNegRule : public InferenceRule { match = result.Nodes.getNodeAs("while"); } - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, + clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt) override { auto loop{clang::dyn_cast(stmt)}; @@ -365,9 +375,9 @@ class CondToSeqNegRule : public InferenceRule { ASTBuilder ast(unit); auto body{clang::cast(loop->getBody())}; auto ifstmt{clang::cast(body->body_front())}; - auto cond{ast.CreateLNot(ifstmt->getCond())}; - CopyProvenance(ifstmt->getCond(), cond, provenance.use_provenance); - auto inner_loop{ast.CreateWhile(cond, ifstmt->getElse())}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; + auto inner_loop{ast.CreateWhile(dec_ctx.marker_expr, ifstmt->getElse())}; + dec_ctx.conds[inner_loop] = dec_ctx.InsertZExpr(!cond); std::vector new_body({inner_loop}); if (auto comp = clang::dyn_cast(ifstmt->getThen())) { new_body.insert(new_body.end(), comp->body_begin(), comp->body_end()); @@ -375,14 +385,17 @@ class CondToSeqNegRule : public InferenceRule { new_body.push_back(ifstmt->getThen()); } - return ast.CreateWhile(loop->getCond(), ast.CreateCompoundStmt(new_body)); + auto new_while{ + ast.CreateWhile(dec_ctx.marker_expr, ast.CreateCompoundStmt(new_body))}; + dec_ctx.conds[new_while] = dec_ctx.conds[loop]; + return new_while; } }; } // namespace -LoopRefine::LoopRefine(Provenance &provenance, clang::ASTUnit &u) - : TransformVisitor(provenance, u) {} +LoopRefine::LoopRefine(DecompilationContext &dec_ctx, clang::ASTUnit &u) + : TransformVisitor(dec_ctx, u) {} bool LoopRefine::VisitWhileStmt(clang::WhileStmt *loop) { // DLOG(INFO) << "VisitWhileStmt"; @@ -412,7 +425,7 @@ bool LoopRefine::VisitWhileStmt(clang::WhileStmt *loop) { rules.emplace_back(new ElseWhileRule); rules.emplace_back(new ElseDoWhileRule); - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, loop, rules)}; + auto sub{ApplyFirstMatchingRule(dec_ctx, ast_unit, loop, rules)}; if (sub != loop) { substitutions[loop] = sub; } diff --git a/lib/AST/MaterializeConds.cpp b/lib/AST/MaterializeConds.cpp new file mode 100644 index 00000000..c41ff94c --- /dev/null +++ b/lib/AST/MaterializeConds.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include "rellic/AST/MaterializeConds.h" + +#include +#include + +#include "rellic/AST/Util.h" + +namespace rellic { + +MaterializeConds::MaterializeConds(DecompilationContext &dec_ctx, + clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit), + ast_gen(unit, dec_ctx) {} + +bool MaterializeConds::VisitIfStmt(clang::IfStmt *stmt) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + if (stmt->getCond() == dec_ctx.marker_expr) { + stmt->setCond(ast_gen.ConvertExpr(cond)); + } + return true; +} + +bool MaterializeConds::VisitWhileStmt(clang::WhileStmt *stmt) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + if (stmt->getCond() == dec_ctx.marker_expr) { + stmt->setCond(ast_gen.ConvertExpr(cond)); + } + return true; +} + +bool MaterializeConds::VisitDoStmt(clang::DoStmt *stmt) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + if (stmt->getCond() == dec_ctx.marker_expr) { + stmt->setCond(ast_gen.ConvertExpr(cond)); + } + return true; +} + +void MaterializeConds::RunImpl() { + LOG(INFO) << "Materializing conditions"; + TransformVisitor::RunImpl(); + TraverseDecl(ast_ctx.getTranslationUnitDecl()); +} + +} // namespace rellic diff --git a/lib/AST/NestedCondProp.cpp b/lib/AST/NestedCondProp.cpp old mode 100644 new mode 100755 index 766aac61..7df48f96 --- a/lib/AST/NestedCondProp.cpp +++ b/lib/AST/NestedCondProp.cpp @@ -12,106 +12,199 @@ #include #include +#include #include #include "rellic/AST/ASTBuilder.h" #include "rellic/AST/Util.h" namespace rellic { -using ExprVec = std::vector; +// Stores a set of expression that have a known value, so that they can be +// recognized as part of larger expressions and simplified. +struct KnownExprs { + std::unordered_map values; + + void AddExpr(z3::expr expr, bool value) { + // When adding expressions to the set of known values, it's important that + // they are added in their smallest possible form. E.g., if it's known that + // `A && B` is true, then both of its subexpressions are true, and we should + // add those instead. + // This is so that `A && B && C` can be used to simplify smaller + // expressions, like `A && B`, which would otherwise not be recognized. + switch (expr.decl().decl_kind()) { + case Z3_OP_TRUE: + case Z3_OP_FALSE: + return; + default: + break; + } -static bool isConstant(const clang::ASTContext& ctx, clang::Expr* expr) { - return expr->getIntegerConstantExpr(ctx).hasValue(); -} + // If !A has value V, then A has value !V, so add that instead. + if (expr.is_not()) { + AddExpr(expr.arg(0), !value); + return; + } -class CompoundVisitor - : public clang::StmtVisitor { - private: - ASTBuilder& ast; - clang::ASTContext& ctx; + // A unary && or ||, just add the single subexpression + if (expr.num_args() == 1) { + AddExpr(expr.arg(0), value); + return; + } - public: - CompoundVisitor(ASTBuilder& ast, clang::ASTContext& ctx) - : ast(ast), ctx(ctx) {} + // A true && expression means all of its subexpressions are true + if (value && expr.is_and()) { + for (auto e : expr.args()) { + AddExpr(e, true); + } + return; + } - bool VisitCompoundStmt(clang::CompoundStmt* compound, ExprVec& true_exprs) { - bool changed{false}; - for (auto stmt : compound->body()) { - changed |= Visit(stmt, true_exprs); + // A false || expression means all of its subexpressions are false + if (!value && expr.is_or()) { + for (auto e : expr.args()) { + AddExpr(e, false); + } + return; } - return changed; + values[expr.id()] = value; } - bool VisitWhileStmt(clang::WhileStmt* while_stmt, ExprVec& true_exprs) { - bool changed{false}; - auto cond{while_stmt->getCond()}; - for (auto true_expr : true_exprs) { - changed |= Replace(true_expr, ast.CreateTrue(), &cond); + // Simplify an expression `expr` using all the known values stored. Sets + // `changed` to true is any simplification has been applied. + z3::expr ApplyAssumptions(z3::expr expr, bool& changed) { + if (values.empty()) { + return expr; } - while_stmt->setCond(cond); - ExprVec inner{true_exprs}; - if (!isConstant(ctx, while_stmt->getCond())) { - inner.push_back(while_stmt->getCond()); + if (values.find(expr.id()) != values.end()) { + changed = true; + return expr.ctx().bool_val(values[expr.id()]); } - changed |= Visit(while_stmt->getBody(), inner); - true_exprs.push_back(Negate(ast, while_stmt->getCond())); - return changed; + if (expr.is_and() || expr.is_or()) { + z3::expr_vector args{expr.ctx()}; + for (auto arg : expr.args()) { + args.push_back(ApplyAssumptions(arg, changed)); + } + if (expr.is_and()) { + return z3::mk_and(args); + } else { + return z3::mk_or(args); + } + } + + if (expr.is_not()) { + return !ApplyAssumptions(expr.arg(0), changed); + } + + return expr; } +}; - bool VisitDoStmt(clang::DoStmt* do_stmt, ExprVec& true_exprs) { +class CompoundVisitor + : public clang::StmtVisitor { + private: + DecompilationContext& dec_ctx; + ASTBuilder& ast; + clang::ASTContext& ctx; + + template + bool VisitLoop(T* loop, KnownExprs& known_exprs) { + auto cond_idx{dec_ctx.conds[loop]}; bool changed{false}; - auto cond{do_stmt->getCond()}; - for (auto true_expr : true_exprs) { - changed |= Replace(true_expr, ast.CreateTrue(), &cond); + auto old_cond{dec_ctx.z3_exprs[cond_idx]}; + auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; + if (loop->getCond() != dec_ctx.marker_expr && changed) { + dec_ctx.z3_exprs.set(cond_idx, new_cond); + return true; + } + + auto inner{known_exprs}; + if constexpr (cond_is_true_in_body) { + inner.AddExpr(new_cond, true); + } + known_exprs.AddExpr(new_cond, false); + + if (Visit(loop->getBody(), inner)) { + return true; + } + return false; + } + + public: + CompoundVisitor(DecompilationContext& dec_ctx, ASTBuilder& ast, + clang::ASTContext& ctx) + : dec_ctx(dec_ctx), ast(ast), ctx(ctx) {} + + bool VisitCompoundStmt(clang::CompoundStmt* compound, + KnownExprs& known_exprs) { + for (auto stmt : compound->body()) { + if (Visit(stmt, known_exprs)) { + return true; + } } - do_stmt->setCond(cond); - ExprVec inner{true_exprs}; - Visit(do_stmt->getBody(), inner); + return false; + } - true_exprs.push_back(Negate(ast, do_stmt->getCond())); - return changed; + bool VisitWhileStmt(clang::WhileStmt* while_stmt, KnownExprs& known_exprs) { + return VisitLoop(while_stmt, known_exprs); } - bool VisitIfStmt(clang::IfStmt* if_stmt, ExprVec& true_exprs) { + bool VisitDoStmt(clang::DoStmt* do_stmt, KnownExprs& known_exprs) { + return VisitLoop(do_stmt, known_exprs); + } + + bool VisitIfStmt(clang::IfStmt* if_stmt, KnownExprs& known_exprs) { + auto cond_idx{dec_ctx.conds[if_stmt]}; bool changed{false}; - auto cond{if_stmt->getCond()}; - for (auto true_expr : true_exprs) { - changed |= Replace(true_expr, ast.CreateTrue(), &cond); + auto old_cond{dec_ctx.z3_exprs[cond_idx]}; + auto new_cond{known_exprs.ApplyAssumptions(old_cond, changed)}; + if (if_stmt->getCond() == dec_ctx.marker_expr && changed) { + dec_ctx.z3_exprs.set(cond_idx, new_cond); + return true; } - if_stmt->setCond(cond); - ExprVec inner_then{true_exprs}; - if (!isConstant(ctx, if_stmt->getCond())) { - inner_then.push_back(if_stmt->getCond()); + auto inner_then{known_exprs}; + inner_then.AddExpr(new_cond, true); + if (Visit(if_stmt->getThen(), inner_then)) { + return true; } - Visit(if_stmt->getThen(), inner_then); if (if_stmt->getElse()) { - ExprVec inner_else{true_exprs}; - inner_else.push_back(Negate(ast, if_stmt->getCond())); - Visit(if_stmt->getElse(), inner_else); + auto inner_else{known_exprs}; + inner_else.AddExpr(new_cond, false); + if (Visit(if_stmt->getElse(), inner_else)) { + return true; + } } - return changed; + return false; } }; -NestedCondProp::NestedCondProp(Provenance& provenance, clang::ASTUnit& unit) - : ASTPass(provenance, unit) {} +NestedCondProp::NestedCondProp(DecompilationContext& dec_ctx, + clang::ASTUnit& unit) + : ASTPass(dec_ctx, unit) {} void NestedCondProp::RunImpl() { + LOG(INFO) << "Propagating conditions"; changed = false; ASTBuilder ast{ast_unit}; - CompoundVisitor visitor{ast, ast_ctx}; + CompoundVisitor visitor{dec_ctx, ast, ast_ctx}; for (auto decl : ast_ctx.getTranslationUnitDecl()->decls()) { if (auto fdecl = clang::dyn_cast(decl)) { + if (Stopped()) { + return; + } + if (fdecl->hasBody()) { - ExprVec true_exprs; - changed |= visitor.Visit(fdecl->getBody(), true_exprs); + KnownExprs known_exprs{}; + if (visitor.Visit(fdecl->getBody(), known_exprs)) { + changed = true; + return; + } } } } diff --git a/lib/AST/NestedScopeCombine.cpp b/lib/AST/NestedScopeCombine.cpp index 222f97d2..b941e077 100644 --- a/lib/AST/NestedScopeCombine.cpp +++ b/lib/AST/NestedScopeCombine.cpp @@ -11,26 +11,37 @@ #include #include +#include "rellic/AST/Util.h" + namespace rellic { -NestedScopeCombine::NestedScopeCombine(Provenance &provenance, +NestedScopeCombine::NestedScopeCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit) {} + : TransformVisitor(dec_ctx, unit) {} bool NestedScopeCombine::VisitIfStmt(clang::IfStmt *ifstmt) { // DLOG(INFO) << "VisitIfStmt"; // Determine whether `cond` is a constant expression that is always true and // `ifstmt` should be replaced by `then` in it's parent nodes. - auto if_const_expr = ifstmt->getCond()->getIntegerConstantExpr(ast_ctx); - bool is_const = if_const_expr.hasValue(); - if (is_const && if_const_expr->getBoolValue()) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[ifstmt]]}; + if (Prove(cond)) { substitutions[ifstmt] = ifstmt->getThen(); - } else if (is_const && !if_const_expr->getBoolValue()) { - if (auto else_stmt = ifstmt->getElse()) { - substitutions[ifstmt] = else_stmt; - } else { - std::vector body; - substitutions[ifstmt] = ast.CreateCompoundStmt(body); + } else if (ifstmt->getElse() && Prove(!cond)) { + substitutions[ifstmt] = ifstmt->getElse(); + } + return !Stopped(); +} + +bool NestedScopeCombine::VisitWhileStmt(clang::WhileStmt *stmt) { + // Substitute while statements in the form `while(1) { sth; break; }` with + // just `{ sth; }` + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + if (Prove(cond)) { + auto body{clang::cast(stmt->getBody())}; + if (clang::isa(body->body_back())) { + std::vector new_body{body->body_begin(), + body->body_end() - 1}; + substitutions[stmt] = ast.CreateCompoundStmt(new_body); } } return !Stopped(); diff --git a/lib/AST/NormalizeCond.cpp b/lib/AST/NormalizeCond.cpp deleted file mode 100644 index 94d217cc..00000000 --- a/lib/AST/NormalizeCond.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) 2022-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#include "rellic/AST/NormalizeCond.h" - -#include -#include -#include -#include -#include -#include - -#include "rellic/AST/ASTBuilder.h" -#include "rellic/AST/InferenceRule.h" - -namespace rellic { - -namespace { - -using namespace clang::ast_matchers; - -static const auto zero_int_lit = integerLiteral(equals(0)); - -static inline std::string GetOperatorName(clang::BinaryOperator::Opcode op) { - return op == clang::BO_LAnd ? "&&" : "||"; -} - -class DeMorganRule : public InferenceRule { - clang::BinaryOperator::Opcode to; - - public: - DeMorganRule(clang::BinaryOperator::Opcode from, - clang::BinaryOperator::Opcode to) - : InferenceRule( - unaryOperator(hasOperatorName("!"), - has(ignoringParenImpCasts(binaryOperator( - hasOperatorName(GetOperatorName(from)))))) - .bind("not")), - to(to) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("not"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto unop{clang::cast(stmt)}; - CHECK(unop == match) - << "Substituted UnaryOperator is not the matched UnaryOperator"; - auto binop{clang::cast( - unop->getSubExpr()->IgnoreParenImpCasts())}; - - auto new_lhs{ast.CreateLNot(binop->getLHS())}; - auto new_rhs{ast.CreateLNot(binop->getRHS())}; - CopyProvenance(binop->getLHS(), new_lhs, provenance.use_provenance); - CopyProvenance(binop->getRHS(), new_rhs, provenance.use_provenance); - return ast.CreateBinaryOp(to, new_lhs, new_rhs); - } -}; - -class AssociativeRule : public InferenceRule { - clang::BinaryOperator::Opcode op; - - public: - AssociativeRule(clang::BinaryOperator::Opcode op) - : InferenceRule( - binaryOperator(hasOperatorName(GetOperatorName(op)), - hasRHS(ignoringParenImpCasts(binaryOperator( - hasOperatorName(GetOperatorName(op)))))) - .bind("binop")), - op(op) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("binop"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto outer{clang::cast(stmt)}; - CHECK(outer == match) - << "Substituted BinaryOperator is not the matched BinaryOperator"; - auto inner{clang::cast( - outer->getRHS()->IgnoreParenImpCasts())}; - - auto new_lhs{ast.CreateBinaryOp(op, outer->getLHS(), inner->getLHS())}; - auto new_rhs{inner->getRHS()}; - return ast.CreateBinaryOp(op, new_lhs, new_rhs); - } -}; - -class LDistributiveRule : public InferenceRule { - public: - LDistributiveRule() - : InferenceRule( - binaryOperator(hasOperatorName("||"), - hasLHS(ignoringParenImpCasts( - binaryOperator(hasOperatorName("&&"))))) - .bind("binop")) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("binop"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto outer{clang::cast(stmt)}; - CHECK(outer == match) - << "Substituted BinaryOperator is not the matched BinaryOperator"; - auto inner{clang::cast( - outer->getLHS()->IgnoreParenImpCasts())}; - - auto new_lhs{ast.CreateLOr(inner->getLHS(), outer->getRHS())}; - auto new_rhs{ast.CreateLOr(inner->getRHS(), outer->getRHS())}; - return ast.CreateLAnd(new_lhs, new_rhs); - } -}; - -class RDistributiveRule : public InferenceRule { - public: - RDistributiveRule() - : InferenceRule( - binaryOperator(hasOperatorName("||"), - hasRHS(ignoringParenImpCasts( - binaryOperator(hasOperatorName("&&"))))) - .bind("binop")) {} - - void run(const MatchFinder::MatchResult &result) override { - match = result.Nodes.getNodeAs("binop"); - } - - clang::Stmt *GetOrCreateSubstitution(Provenance &provenance, - clang::ASTUnit &unit, - clang::Stmt *stmt) override { - auto &ctx{unit.getASTContext()}; - ASTBuilder ast{unit}; - - auto outer{clang::cast(stmt)}; - CHECK(outer == match) - << "Substituted BinaryOperator is not the matched BinaryOperator"; - auto inner{clang::cast( - outer->getRHS()->IgnoreParenImpCasts())}; - - auto new_lhs{ast.CreateLOr(outer->getLHS(), inner->getLHS())}; - auto new_rhs{ast.CreateLOr(outer->getLHS(), inner->getRHS())}; - return ast.CreateLAnd(new_lhs, new_rhs); - } -}; - -} // namespace - -NormalizeCond::NormalizeCond(Provenance &provenance, clang::ASTUnit &u) - : TransformVisitor(provenance, u) {} - -bool NormalizeCond::VisitUnaryOperator(clang::UnaryOperator *op) { - std::vector> rules; - - rules.emplace_back(new DeMorganRule(clang::BO_LAnd, clang::BO_LOr)); - rules.emplace_back(new DeMorganRule(clang::BO_LOr, clang::BO_LAnd)); - - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; - if (sub != op) { - substitutions[op] = sub; - } - - return !Stopped(); -} - -bool NormalizeCond::VisitBinaryOperator(clang::BinaryOperator *op) { - std::vector> rules; - - rules.emplace_back(new AssociativeRule(clang::BO_LAnd)); - rules.emplace_back(new AssociativeRule(clang::BO_LOr)); - rules.emplace_back(new LDistributiveRule); - rules.emplace_back(new RDistributiveRule); - - auto sub{ApplyFirstMatchingRule(provenance, ast_unit, op, rules)}; - if (sub != op) { - substitutions[op] = sub; - } - - return !Stopped(); -} - -void NormalizeCond::RunImpl() { - LOG(INFO) << "Conversion into conjunctive normal form"; - TransformVisitor::RunImpl(); - TraverseDecl(ast_ctx.getTranslationUnitDecl()); -} - -} // namespace rellic \ No newline at end of file diff --git a/lib/AST/ReachBasedRefine.cpp b/lib/AST/ReachBasedRefine.cpp index 794d721b..824f4a97 100644 --- a/lib/AST/ReachBasedRefine.cpp +++ b/lib/AST/ReachBasedRefine.cpp @@ -15,21 +15,14 @@ namespace rellic { -ReachBasedRefine::ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit), - z3_ctx(new z3::context()), - z3_gen(new rellic::Z3ConvVisitor(unit, z3_ctx.get())) {} - -z3::expr ReachBasedRefine::GetZ3Cond(clang::IfStmt *ifstmt) { - auto cond = ifstmt->getCond(); - auto expr = z3_gen->Z3BoolCast(z3_gen->GetOrCreateZ3Expr(cond)); - return expr.simplify(); -} +ReachBasedRefine::ReachBasedRefine(DecompilationContext &dec_ctx, + clang::ASTUnit &unit) + : TransformVisitor(dec_ctx, unit) {} bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { std::vector body{compound->body_begin(), compound->body_end()}; std::vector ifs; - z3::expr_vector conds{*z3_ctx}; + z3::expr_vector conds{dec_ctx.z3_ctx}; auto ResetChain = [&]() { ifs.clear(); @@ -45,7 +38,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } ifs.push_back(if_stmt); - auto cond{GetZ3Cond(if_stmt)}; + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[if_stmt]]}; if (if_stmt->getElse()) { // We cannot link `if` statements that contain `else` branches @@ -54,17 +47,17 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { } // Is the current `if` statement unreachable from all the others? - bool is_unreachable{Prove(*z3_ctx, !(cond && z3::mk_or(conds)))}; + bool is_unreachable{Prove(HeavySimplify(!(cond && z3::mk_or(conds))))}; if (!is_unreachable) { ResetChain(); continue; } - conds.push_back(GetZ3Cond(if_stmt)); + conds.push_back(cond); // Do the collected statements cover all possibilities? - auto is_complete{Prove(*z3_ctx, z3::mk_or(conds))}; + auto is_complete{Prove(HeavySimplify(z3::mk_or(conds)))}; if (ifs.size() <= 2 || !is_complete) { // We need to collect more statements @@ -89,7 +82,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } ... - i - 1 : if(cond_n-1) { } else if(cond_n) { } + i - 1 : if(cond_n-1) { } else { } i : if(cond_n) { } ... */ @@ -114,7 +107,7 @@ bool ReachBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) { i - n + 1: if(cond_1) { } else if(cond_2) { } ... else if(cond_n) { } i - n + 2: if(cond_2) { } else if(cond_3) { } ... else if(cond_n) { } ... - i - 1 : if(cond_n-1) { } else { } + i - 1 : if(cond_n-1) { } else if(cond_n) { } i : if(cond_n) { } ... diff --git a/lib/AST/StructFieldRenamer.cpp b/lib/AST/StructFieldRenamer.cpp index 6fad637c..3839322f 100644 --- a/lib/AST/StructFieldRenamer.cpp +++ b/lib/AST/StructFieldRenamer.cpp @@ -17,10 +17,10 @@ namespace rellic { -StructFieldRenamer::StructFieldRenamer(Provenance &provenance, +StructFieldRenamer::StructFieldRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit, IRTypeToDITypeMap &types) - : ASTPass(provenance, unit), types(types) {} + : ASTPass(dec_ctx, unit), types(types) {} bool StructFieldRenamer::VisitRecordDecl(clang::RecordDecl *decl) { auto type{decls[decl]}; @@ -68,7 +68,7 @@ bool StructFieldRenamer::VisitRecordDecl(clang::RecordDecl *decl) { void StructFieldRenamer::RunImpl() { LOG(INFO) << "Renaming struct fields"; - for (auto &pair : provenance.type_decls) { + for (auto &pair : dec_ctx.type_decls) { decls[pair.second] = pair.first; } TraverseDecl(ast_ctx.getTranslationUnitDecl()); diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index 9bb3ef11..e7ae9682 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -15,6 +15,8 @@ #include #include +#include + #include "rellic/AST/ASTBuilder.h" #include "rellic/Exception.h" @@ -196,10 +198,11 @@ bool IsEquivalent(clang::Expr *a, clang::Expr *b) { class ExprCloner : public clang::StmtVisitor { ASTBuilder ast; clang::ASTContext &ctx; - ExprToUseMap &provenance; + DecompilationContext::ExprToUseMap &provenance; public: - ExprCloner(clang::ASTUnit &unit, ExprToUseMap &provenance) + ExprCloner(clang::ASTUnit &unit, + DecompilationContext::ExprToUseMap &provenance) : ast(unit), ctx(unit.getASTContext()), provenance(provenance) {} clang::Expr *VisitIntegerLiteral(clang::IntegerLiteral *expr) { @@ -311,24 +314,11 @@ class ExprCloner : public clang::StmtVisitor { }; clang::Expr *Clone(clang::ASTUnit &unit, clang::Expr *expr, - ExprToUseMap &provenance) { + DecompilationContext::ExprToUseMap &provenance) { ExprCloner cloner{unit, provenance}; return CHECK_NOTNULL(cloner.Visit(CHECK_NOTNULL(expr))); } -static clang::Expr *ApplyLNot(rellic::ASTBuilder &ast, clang::Expr *expr) { - if (auto unop = clang::dyn_cast(expr)) { - if (unop->getOpcode() == clang::UO_LNot) { - return unop->getSubExpr(); - } - } - return ast.CreateLNot(expr); -} - -clang::Expr *Negate(rellic::ASTBuilder &ast, clang::Expr *expr) { - return ApplyLNot(ast, expr->IgnoreParens())->IgnoreParens(); -} - std::string ClangThingToString(const clang::Stmt *stmt) { std::string s; llvm::raw_string_ostream os(s); @@ -336,17 +326,69 @@ std::string ClangThingToString(const clang::Stmt *stmt) { return s; } -z3::goal ApplyTactic(z3::context &ctx, const z3::tactic &tactic, - z3::expr expr) { - z3::goal goal(ctx); +z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr) { + z3::goal goal(tactic.ctx()); goal.add(expr.simplify()); auto app{tactic(goal)}; CHECK(app.size() == 1) << "Unexpected multiple goals in application!"; return app[0]; } -bool Prove(z3::context &ctx, z3::expr expr) { - return ApplyTactic(ctx, z3::tactic(ctx, "sat"), !(expr.simplify())) +bool Prove(z3::expr expr) { + return ApplyTactic(z3::tactic(expr.ctx(), "sat"), !(expr).simplify()) .is_decided_unsat(); } + +z3::expr HeavySimplify(z3::expr expr) { + if (Prove(expr)) { + return expr.ctx().bool_val(true); + } + + z3::tactic aig(expr.ctx(), "aig"); + z3::tactic simplify(expr.ctx(), "simplify"); + z3::tactic ctx_solver_simplify(expr.ctx(), "ctx-solver-simplify"); + auto tactic{simplify & aig & ctx_solver_simplify}; + return ApplyTactic(tactic, expr).as_expr(); +} + +z3::expr_vector Clone(z3::expr_vector &vec) { + z3::expr_vector clone{vec.ctx()}; + for (auto expr : vec) { + clone.push_back(expr); + } + + return clone; +} + +z3::expr OrderById(z3::expr expr) { + if (expr.is_and() || expr.is_or()) { + std::vector args_indices(expr.num_args(), 0); + std::iota(args_indices.begin(), args_indices.end(), 0); + std::sort(args_indices.begin(), args_indices.end(), + [&expr](unsigned a, unsigned b) { + return expr.arg(a).id() < expr.arg(b).id(); + }); + z3::expr_vector new_args{expr.ctx()}; + for (auto idx : args_indices) { + new_args.push_back(OrderById(expr.arg(idx))); + } + if (expr.is_and()) { + return z3::mk_and(new_args); + } else { + return z3::mk_or(new_args); + } + } + + if (expr.is_not()) { + return !OrderById(expr.arg(0)); + } + + return expr; +} + +unsigned DecompilationContext::InsertZExpr(const z3::expr &e) { + auto idx{z3_exprs.size()}; + z3_exprs.push_back(e); + return idx; +} } // namespace rellic diff --git a/lib/AST/Z3CondSimplify.cpp b/lib/AST/Z3CondSimplify.cpp index 45977c81..828df587 100644 --- a/lib/AST/Z3CondSimplify.cpp +++ b/lib/AST/Z3CondSimplify.cpp @@ -8,135 +8,23 @@ #include "rellic/AST/Z3CondSimplify.h" -#include -#include -#include #include #include -namespace rellic { - -Z3CondSimplify::Z3CondSimplify(Provenance &provenance, clang::ASTUnit &unit) - : TransformVisitor(provenance, unit), - z_ctx(new z3::context()), - z_gen(new Z3ConvVisitor(unit, z_ctx.get())), - hash_adaptor{ast_ctx, hashes}, - proven_true(10, hash_adaptor, ke_adaptor), - proven_false(10, hash_adaptor, ke_adaptor) {} - -bool Z3CondSimplify::IsProvenTrue(clang::Expr *e) { - auto it{proven_true.find(e)}; - if (it == proven_true.end()) { - proven_true[e] = Prove(*z_ctx, ToZ3(e)); - } - return proven_true[e]; -} - -bool Z3CondSimplify::IsProvenFalse(clang::Expr *e) { - auto it{proven_false.find(e)}; - if (it == proven_false.end()) { - proven_false[e] = Prove(*z_ctx, !ToZ3(e)); - } - return proven_false[e]; -} - -z3::expr Z3CondSimplify::ToZ3(clang::Expr *e) { - return z_gen->Z3BoolCast(z_gen->GetOrCreateZ3Expr(e)); -} - -clang::Expr *Z3CondSimplify::Simplify(clang::Expr *c_expr) { - if (auto binop = clang::dyn_cast(c_expr)) { - auto lhs{Simplify(binop->getLHS())}; - auto rhs{Simplify(binop->getRHS())}; +#include "rellic/AST/Util.h" - auto opcode{binop->getOpcode()}; - if (opcode == clang::BO_LAnd || opcode == clang::BO_LOr) { - auto lhs_proven{IsProvenTrue(lhs)}; - auto rhs_proven{IsProvenTrue(rhs)}; - auto not_lhs_proven{IsProvenFalse(lhs)}; - auto not_rhs_proven{IsProvenFalse(rhs)}; - if (opcode == clang::BO_LAnd) { - if (lhs_proven && rhs_proven) { - changed = true; - return ast.CreateTrue(); - } else if (not_lhs_proven || not_rhs_proven) { - changed = true; - return ast.CreateFalse(); - } else if (lhs_proven) { - changed = true; - return rhs; - } else if (rhs_proven) { - changed = true; - return lhs; - } - } else { - if (not_lhs_proven && not_rhs_proven) { - changed = true; - return ast.CreateFalse(); - } else if (lhs_proven || rhs_proven) { - changed = true; - return ast.CreateTrue(); - } else if (not_lhs_proven) { - changed = true; - return rhs; - } else if (not_rhs_proven) { - changed = true; - return lhs; - } - } - - binop->setLHS(lhs); - binop->setRHS(rhs); - hashes[binop] = 0; - if (IsProvenTrue(binop)) { - changed = true; - return ast.CreateTrue(); - } else if (IsProvenFalse(binop)) { - changed = true; - return ast.CreateFalse(); - } - } - } else if (auto unop = clang::dyn_cast(c_expr)) { - if (unop->getOpcode() == clang::UO_LNot) { - auto sub{Simplify(unop->getSubExpr())}; - if (IsProvenTrue(sub)) { - changed = true; - return ast.CreateFalse(); - } else if (IsProvenFalse(sub)) { - changed = true; - return ast.CreateTrue(); - } - unop->setSubExpr(sub); - hashes[unop] = 0; - } - } else if (auto paren = clang::dyn_cast(c_expr)) { - auto sub{Simplify(paren->getSubExpr())}; - paren->setSubExpr(sub); - hashes[paren] = 0; - } - - return c_expr; -} - -bool Z3CondSimplify::VisitIfStmt(clang::IfStmt *stmt) { - stmt->setCond(Simplify(stmt->getCond())); - return true; -} +namespace rellic { -bool Z3CondSimplify::VisitWhileStmt(clang::WhileStmt *loop) { - loop->setCond(Simplify(loop->getCond())); - return true; -} - -bool Z3CondSimplify::VisitDoStmt(clang::DoStmt *loop) { - loop->setCond(Simplify(loop->getCond())); - return true; -} +Z3CondSimplify::Z3CondSimplify(DecompilationContext& dec_ctx, + clang::ASTUnit& unit) + : ASTPass(dec_ctx, unit) {} void Z3CondSimplify::RunImpl() { LOG(INFO) << "Simplifying conditions using Z3"; - TransformVisitor::RunImpl(); - TraverseDecl(ast_ctx.getTranslationUnitDecl()); + for (size_t i{0}; i < dec_ctx.z3_exprs.size() && !Stopped(); ++i) { + auto simpl{OrderById(dec_ctx.z3_exprs[i].simplify())}; + dec_ctx.z3_exprs.set(i, simpl); + } } } // namespace rellic diff --git a/lib/AST/Z3ConvVisitor.cpp b/lib/AST/Z3ConvVisitor.cpp deleted file mode 100644 index ef248197..00000000 --- a/lib/AST/Z3ConvVisitor.cpp +++ /dev/null @@ -1,1223 +0,0 @@ -/* - * Copyright (c) 2021-present, Trail of Bits, Inc. - * All rights reserved. - * - * This source code is licensed in accordance with the terms specified in - * the LICENSE file found in the root directory of this source tree. - */ - -#define GOOGLE_STRIP_LOG 1 - -#include "rellic/AST/Z3ConvVisitor.h" - -#include -#include - -#include "rellic/AST/Util.h" -#include "rellic/Exception.h" - -namespace rellic { - -namespace { - -static unsigned GetZ3SortSize(z3::sort sort) { - switch (sort.sort_kind()) { - case Z3_BOOL_SORT: - return 1; - break; - - case Z3_BV_SORT: - return sort.bv_size(); - break; - - case Z3_FLOATING_POINT_SORT: { - auto &ctx{sort.ctx()}; - return Z3_fpa_get_sbits(ctx, sort) + Z3_fpa_get_ebits(ctx, sort); - } break; - - case Z3_UNINTERPRETED_SORT: - return 0; - break; - - default: - LOG(FATAL) << "Unknown Z3 sort: " << sort; - break; - } - // This code is unreachable, but sometimes we need to - // fix 'error: control reaches end of non-void function' on some compilers. - return (unsigned)(-1); -} - -static unsigned GetZ3SortSize(z3::expr expr) { - return GetZ3SortSize(expr.get_sort()); -} - -// Determine if `op` is a `z3::concat(l, r)` that's -// equivalent to a sign extension. This is done by -// checking if `l` is an "all-one" or "all-zero" bit -// value. -static bool IsSignExt(z3::expr op) { - if (op.decl().decl_kind() != Z3_OP_CONCAT) { - return false; - } - - auto lhs{op.arg(0)}; - - if (lhs.is_numeral()) { - auto size{GetZ3SortSize(lhs)}; - llvm::APInt val(size, Z3_get_numeral_string(op.ctx(), lhs), 10); - return val.isAllOnesValue() || val.isNullValue(); - } - return false; -} - -static std::string CreateZ3DeclName(clang::NamedDecl *decl) { - std::stringstream ss; - ss << std::hex << decl << std::dec; - ss << '_' << decl->getNameAsString(); - return ss.str(); -} - -} // namespace - -Z3ConvVisitor::Z3ConvVisitor(clang::ASTUnit &unit, z3::context *z_ctx) - : c_ctx(&unit.getASTContext()), - ast(unit), - z_ctx(z_ctx), - z_expr_vec(*z_ctx), - z_decl_vec(*z_ctx) {} - -// Inserts a `clang::Expr` <=> `z3::expr` mapping into -void Z3ConvVisitor::InsertZ3Expr(clang::Expr *c_expr, z3::expr z_expr) { - CHECK(c_expr) << "Inserting null clang::Expr key."; - CHECK(bool(z_expr)) << "Inserting null z3::expr value."; - CHECK(!z_expr_map.count(c_expr)) << "clang::Expr key already exists."; - z_expr_map[c_expr] = z_expr_vec.size(); - z_expr_vec.push_back(z_expr); -} - -// Retrieves a `z3::expr` corresponding to `c_expr`. -// The `z3::expr` needs to be created and inserted by -// `Z3ConvVisistor::InsertZ3Expr` first. -z3::expr Z3ConvVisitor::GetZ3Expr(clang::Expr *c_expr) { - auto iter{z_expr_map.find(c_expr)}; - CHECK(iter != z_expr_map.end()); - return z_expr_vec[iter->second]; -} - -// Inserts a `clang::ValueDecl` <=> `z3::func_decl` mapping into -void Z3ConvVisitor::InsertZ3Decl(clang::ValueDecl *c_decl, - z3::func_decl z_decl) { - CHECK(c_decl) << "Inserting null clang::ValueDecl key."; - CHECK(bool(z_decl)) << "Inserting null z3::func_decl value."; - CHECK(!z_decl_map.count(c_decl)) << "clang::ValueDecl key already exists."; - z_decl_map[c_decl] = z_decl_vec.size(); - z_decl_vec.push_back(z_decl); -} - -// Retrieves a `z3::func_decl` corresponding to `c_decl`. -// The `z3::func_decl` needs to be created and inserted by -// `Z3ConvVisistor::InsertZ3Decl` first. -z3::func_decl Z3ConvVisitor::GetZ3Decl(clang::ValueDecl *c_decl) { - auto iter{z_decl_map.find(c_decl)}; - CHECK(iter != z_decl_map.end()); - return z_decl_vec[iter->second]; -} - -// If `expr` is not boolean, returns a `z3::expr` that corresponds -// to the non-boolean to boolean expression cast in C. Otherwise -// returns `expr`. -z3::expr Z3ConvVisitor::Z3BoolCast(z3::expr expr) { - if (expr.is_bool()) { - return expr; - } else { - return expr != z_ctx->num_val(0, expr.get_sort()); - } -} - -z3::expr Z3ConvVisitor::Z3BoolToBVCast(z3::expr expr) { - if (expr.is_bv()) { - return expr; - } - - CHECK(expr.is_bool()); - - auto size{c_ctx->getTypeSize(c_ctx->IntTy)}; - return z3::ite(expr, z_ctx->bv_val(1U, size), z_ctx->bv_val(0U, size)); -} - -void Z3ConvVisitor::InsertCExpr(z3::expr z_expr, clang::Expr *c_expr) { - CHECK(bool(z_expr)) << "Inserting null z3::expr key."; - CHECK(c_expr) << "Inserting null clang::Expr value."; - auto hash{z_expr.hash()}; - CHECK(!c_expr_map.count(hash)) << "z3::expr key already exists."; - c_expr_map[hash] = c_expr; -} - -clang::Expr *Z3ConvVisitor::GetCExpr(z3::expr z_expr) { - auto hash{z_expr.hash()}; - CHECK(c_expr_map.count(hash)) << "No Z3 equivalent for C expression!"; - return c_expr_map[hash]; -} - -void Z3ConvVisitor::InsertCValDecl(z3::func_decl z_decl, - clang::ValueDecl *c_decl) { - CHECK(c_decl) << "Inserting null z3::func_decl key."; - CHECK(bool(z_decl)) << "Inserting null clang::ValueDecl value."; - CHECK(!c_decl_map.count(z_decl.id())) << "z3::func_decl key already exists."; - c_decl_map[z_decl.id()] = c_decl; -} - -clang::ValueDecl *Z3ConvVisitor::GetCValDecl(z3::func_decl z_decl) { - CHECK(c_decl_map.count(z_decl.id())) << "No C equivalent for Z3 declaration!"; - return c_decl_map[z_decl.id()]; -} - -z3::sort Z3ConvVisitor::GetZ3Sort(clang::QualType type) { - // Void - if (type->isVoidType()) { - return z_ctx->uninterpreted_sort("void"); - } - // Booleans - if (type->isBooleanType()) { - return z_ctx->bool_sort(); - } - // Structures - if (type->isStructureType()) { - auto decl{clang::cast(type)->getDecl()}; - return z_ctx->uninterpreted_sort(decl->getNameAsString().c_str()); - } - auto bitwidth{c_ctx->getTypeSize(type)}; - // Floating points - if (type->isRealFloatingType()) { - switch (bitwidth) { - case 16: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_16(*z_ctx)); - break; - - case 32: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_32(*z_ctx)); - break; - - case 64: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_64(*z_ctx)); - break; - - case 128: - return z3::to_sort(*z_ctx, Z3_mk_fpa_sort_128(*z_ctx)); - break; - - default: - LOG(FATAL) << "Unsupported floating-point bitwidth!"; - break; - } - } - // Default to bitvectors - return z3::to_sort(*z_ctx, z_ctx->bv_sort(bitwidth)); -} - -clang::Expr *Z3ConvVisitor::CreateLiteralExpr(z3::expr z_expr) { - DLOG(INFO) << "Creating literal clang::Expr for " << z_expr; - CHECK(z_expr.is_const()) << "Can only create C literal from Z3 constant!"; - auto z_sort{z_expr.get_sort()}; - clang::Expr *result{nullptr}; - switch (z_sort.sort_kind()) { - case Z3_BOOL_SORT: { - auto val{z_expr.bool_value() == Z3_L_TRUE ? 1U : 0U}; - result = ast.CreateIntLit(llvm::APInt(/*BitWidth=*/1U, val)); - } break; - - case Z3_BV_SORT: { - llvm::APInt val(GetZ3SortSize(z_expr), - Z3_get_numeral_string(z_expr.ctx(), z_expr), 10); - // Handle `char` and `short` types separately, because clang - // adds non-standard `Ui8` and `Ui16` suffixes respectively. - result = ast.CreateAdjustedIntLit(val); - } break; - - case Z3_FLOATING_POINT_SORT: { - auto size{GetZ3SortSize(z_sort)}; - const llvm::fltSemantics *semantics{nullptr}; - switch (size) { - case 16: - semantics = &llvm::APFloat::IEEEhalf(); - break; - case 32: - semantics = &llvm::APFloat::IEEEsingle(); - break; - case 64: - semantics = &llvm::APFloat::IEEEdouble(); - break; - case 128: - semantics = &llvm::APFloat::IEEEquad(); - break; - default: - THROW() << "Unknown Z3 floating-point sort!"; - break; - } - z3::expr bv(*z_ctx, Z3_mk_fpa_to_ieee_bv(*z_ctx, z_expr)); - auto bits{Z3_get_numeral_string(*z_ctx, bv.simplify())}; - CHECK(std::strlen(bits) > 0) - << "Failed to convert IEEE bitvector to string!"; - llvm::APInt api(size, bits, /*radix=*/10U); - result = ast.CreateFPLit(llvm::APFloat(*semantics, api)); - } break; - - default: - LOG(FATAL) << "Unknown Z3 sort: " << z_sort; - break; - } - - return result; -} - -// Retrieves or creates`z3::expr`s from `clang::Expr`. -z3::expr Z3ConvVisitor::GetOrCreateZ3Expr(clang::Expr *c_expr) { - if (!z_expr_map.count(c_expr)) { - TraverseStmt(c_expr); - } - - return GetZ3Expr(c_expr); -} - -z3::func_decl Z3ConvVisitor::GetOrCreateZ3Decl(clang::ValueDecl *c_decl) { - if (!z_decl_map.count(c_decl)) { - TraverseDecl(c_decl); - } - - auto z_decl{GetZ3Decl(c_decl)}; - if (!c_decl_map.count(z_decl.id())) { - InsertCValDecl(z_decl, c_decl); - } - - return z_decl; -} - -// Retrieves or creates `clang::Expr` from `z3::expr`. -clang::Expr *Z3ConvVisitor::GetOrCreateCExpr(z3::expr z_expr) { - if (!c_expr_map.count(z_expr.hash())) { - VisitZ3Expr(z_expr); - } - return GetCExpr(z_expr); -} - -// Retrieves or creates `clang::ValueDecl` from `z3::func_decl`. -clang::ValueDecl *Z3ConvVisitor::GetOrCreateCValDecl(z3::func_decl z_decl) { - if (!c_decl_map.count(z_decl.id())) { - VisitZ3Decl(z_decl); - } - return GetCValDecl(z_decl); -} - -bool Z3ConvVisitor::VisitVarDecl(clang::VarDecl *c_var) { - auto c_name{c_var->getNameAsString()}; - DLOG(INFO) << "VisitVarDecl: " << c_name; - if (z_decl_map.count(c_var)) { - DLOG(INFO) << "Re-declaration of " << c_name << "; Returning."; - return true; - } - - auto z_name{CreateZ3DeclName(c_var)}; - auto z_sort{GetZ3Sort(c_var->getType())}; - auto z_const{z_ctx->constant(z_name.c_str(), z_sort)}; - - InsertZ3Decl(c_var, z_const.decl()); - - return true; -} - -bool Z3ConvVisitor::VisitFieldDecl(clang::FieldDecl *c_field) { - auto c_name{c_field->getNameAsString()}; - DLOG(INFO) << "VisitFieldDecl: " << c_name; - if (z_decl_map.count(c_field)) { - DLOG(INFO) << "Re-declaration of " << c_name << "; Returning."; - return true; - } - - auto z_name{CreateZ3DeclName(c_field->getParent()) + "_" + c_name}; - auto z_sort{GetZ3Sort(c_field->getType())}; - auto z_const{z_ctx->constant(z_name.c_str(), z_sort)}; - - InsertZ3Decl(c_field, z_const.decl()); - - return true; -} - -bool Z3ConvVisitor::VisitFunctionDecl(clang::FunctionDecl *c_func) { - auto c_name{c_func->getNameAsString()}; - DLOG(INFO) << "VisitFunctionDecl: " << c_name; - if (z_decl_map.count(c_func)) { - DLOG(INFO) << "Re-declaration of " << c_name << "; Returning."; - return true; - } - - z3::sort_vector z_domains(*z_ctx); - for (auto c_param : c_func->parameters()) { - z_domains.push_back(GetZ3Sort(c_param->getType())); - } - - auto z_range{GetZ3Sort(c_func->getReturnType())}; - auto z_func{z_ctx->function(c_name.c_str(), z_domains, z_range)}; - - InsertZ3Decl(c_func, z_func); - - return true; -} - -z3::expr Z3ConvVisitor::CreateZ3BitwiseCast(z3::expr expr, size_t src, - size_t dst, bool sign) { - if (expr.is_bool()) { - expr = Z3BoolToBVCast(expr); - } - - // CHECK(expr.is_bv()) << "z3::expr is not a bitvector!"; - CHECK_EQ(GetZ3SortSize(expr), src); - - int64_t diff = dst - src; - // extend - if (diff > 0) { - return sign ? z3::sext(expr, diff) : z3::zext(expr, diff); - } - // truncate - if (diff < 0) { - return expr.extract(dst - 1, 0); - } - // nothing - return expr; -} - -template -bool Z3ConvVisitor::HandleCastExpr(T *c_cast) { - CHECK(clang::isa(c_cast) || - clang::isa(c_cast)); - // C exprs - auto c_sub{c_cast->getSubExpr()}; - // C types - auto c_src_ty{c_sub->getType()}; - auto c_dst_ty{c_cast->getType()}; - // C type sizes - auto src_ty_size{c_ctx->getTypeSize(c_src_ty)}; - auto dst_ty_size{c_ctx->getTypeSize(c_dst_ty)}; - // Z3 exprs - auto z_sub{GetZ3Expr(c_sub)}; - auto z_cast{z_sub}; - // Z3 sorts - auto z_src_sort{z_sub.get_sort()}; - auto z_dst_sort{z_ctx->bv_sort(dst_ty_size)}; - switch (c_cast->getCastKind()) { - case clang::CastKind::CK_IntegralCast: - z_cast = CreateZ3BitwiseCast(z_sub, src_ty_size, dst_ty_size, - c_src_ty->isSignedIntegerType()); - break; - - case clang::CastKind::CK_PointerToIntegral: - z_cast = z_ctx->function("PtrToInt", z_src_sort, z_dst_sort)(z_sub); - break; - - case clang::CastKind::CK_NullToPointer: - case clang::CastKind::CK_IntegralToPointer: { - auto c_dst_ty_ptr{reinterpret_cast(c_dst_ty.getAsOpaquePtr())}; - auto z_dst_ty_ptr{z_ctx->bv_val(c_dst_ty_ptr, 8 * sizeof(void *))}; - z_cast = z_ctx->function("IntToPtr", z_dst_ty_ptr.get_sort(), z_src_sort, - z_dst_sort)(z_dst_ty_ptr, z_sub); - } break; - - case clang::CastKind::CK_BitCast: { - auto c_dst_ty_ptr{reinterpret_cast(c_dst_ty.getAsOpaquePtr())}; - auto z_dst_ty_ptr{z_ctx->bv_val(c_dst_ty_ptr, 8 * sizeof(void *))}; - auto z_dst_sort{z_ctx->bv_sort(dst_ty_size)}; - z_cast = z_ctx->function("BitCast", z_dst_ty_ptr.get_sort(), z_src_sort, - z_dst_sort)(z_dst_ty_ptr, z_sub); - } break; - - case clang::CastKind::CK_ArrayToPointerDecay: { - CHECK(z_sub.is_bv()) << "Pointer cast operand is not a bit-vector"; - auto z_ptr_sort{GetZ3Sort(c_cast->getType())}; - auto z_arr_sort{z_sub.get_sort()}; - z_cast = z_ctx->function("PtrDecay", z_arr_sort, z_ptr_sort)(z_sub); - } break; - - case clang::CastKind::CK_NoOp: - case clang::CastKind::CK_LValueToRValue: - case clang::CastKind::CK_FunctionToPointerDecay: - // case clang::CastKind::CK_ArrayToPointerDecay: - break; - - default: - THROW() << "Unsupported cast type: " << c_cast->getCastKindName(); - break; - } - // Save - InsertZ3Expr(c_cast, z_cast); - - return true; -} - -bool Z3ConvVisitor::VisitCStyleCastExpr(clang::CStyleCastExpr *c_cast) { - DLOG(INFO) << "VisitCStyleCastExpr"; - if (z_expr_map.count(c_cast)) { - return true; - } - return HandleCastExpr(c_cast); -} - -bool Z3ConvVisitor::VisitImplicitCastExpr(clang::ImplicitCastExpr *c_cast) { - DLOG(INFO) << "VisitImplicitCastExpr"; - if (z_expr_map.count(c_cast)) { - return true; - } - return HandleCastExpr(c_cast); -} - -bool Z3ConvVisitor::VisitArraySubscriptExpr(clang::ArraySubscriptExpr *sub) { - DLOG(INFO) << "VisitArraySubscriptExpr"; - if (z_expr_map.count(sub)) { - return true; - } - // Get base - auto z_base{GetZ3Expr(sub->getBase())}; - auto base_sort{z_base.get_sort()}; - CHECK(base_sort.is_bv()) << "Invalid Z3 sort for base expression"; - // Get index - auto z_idx{GetZ3Expr(sub->getIdx())}; - auto idx_sort{z_idx.get_sort()}; - CHECK(idx_sort.is_bv()) << "Invalid Z3 sort for index expression"; - // Get result - auto elm_sort{GetZ3Sort(sub->getType())}; - // Create a z_function - auto z_arr_sub{z_ctx->function("ArraySub", base_sort, idx_sort, elm_sort)}; - // Create a z3 expression - InsertZ3Expr(sub, z_arr_sub(z_base, z_idx)); - // Done - return true; -} - -bool Z3ConvVisitor::VisitMemberExpr(clang::MemberExpr *expr) { - DLOG(INFO) << "VisitMemberExpr"; - if (z_expr_map.count(expr)) { - return true; - } - - auto z_mem{GetOrCreateZ3Decl(expr->getMemberDecl())()}; - auto z_base{GetZ3Expr(expr->getBase())}; - auto z_mem_expr{z_ctx->function("Member", z_base.get_sort(), z_mem.get_sort(), - z_mem.get_sort())}; - - InsertZ3Expr(expr, z_mem_expr(z_base, z_mem)); - - return true; -} - -bool Z3ConvVisitor::VisitCallExpr(clang::CallExpr *c_call) { - DLOG(INFO) << "VisitCallExpr"; - if (z_expr_map.count(c_call)) { - return true; - } - z3::expr_vector z_args(*z_ctx); - // Get call id - z_args.push_back(z_ctx->bv_val(GetHash(*c_ctx, c_call), /*sz=*/64U)); - // Get callee - auto z_callee{GetZ3Expr(c_call->getCallee())}; - z_args.push_back(z_callee); - // Get call arguments - for (auto c_arg : c_call->arguments()) { - z_args.push_back(GetZ3Expr(c_arg)); - } - // Build the z3 call - z3::sort_vector z_domain(*z_ctx); - for (auto z_arg : z_args) { - z_domain.push_back(z_arg.get_sort()); - } - auto z_range{z_callee.is_array() ? z_callee.get_sort().array_range() - : z_callee.get_sort()}; - auto z_call{z_ctx->function("Call", z_domain, z_range)}; - // Insert the call - InsertZ3Expr(c_call, z_call(z_args)); - - return true; -} - -// Translates clang unary operators expressions to Z3 equivalents. -bool Z3ConvVisitor::VisitParenExpr(clang::ParenExpr *parens) { - DLOG(INFO) << "VisitParenExpr"; - if (z_expr_map.count(parens)) { - return true; - } - - InsertZ3Expr(parens, GetZ3Expr(parens->getSubExpr())); - - return true; -} - -// Translates clang unary operators expressions to Z3 equivalents. -bool Z3ConvVisitor::VisitUnaryOperator(clang::UnaryOperator *c_op) { - DLOG(INFO) << "VisitUnaryOperator: " - << c_op->getOpcodeStr(c_op->getOpcode()).str(); - if (z_expr_map.count(c_op)) { - return true; - } - // Get operand - auto operand{GetZ3Expr(c_op->getSubExpr())}; - // Conditionally cast operands to a bitvector - auto CondBoolToBVCast{[this, &operand]() { - if (operand.is_bool()) { - operand = Z3BoolToBVCast(operand); - } - }}; - // Create z3 unary op - switch (c_op->getOpcode()) { - case clang::UO_LNot: - InsertZ3Expr(c_op, !Z3BoolCast(operand)); - break; - - case clang::UO_Minus: - InsertZ3Expr(c_op, -operand); - break; - - case clang::UO_Not: - CondBoolToBVCast(); - InsertZ3Expr(c_op, ~operand); - break; - - case clang::UO_AddrOf: { - auto ptr_sort{GetZ3Sort(c_op->getType())}; - auto z_addrof{z_ctx->function("AddrOf", operand.get_sort(), ptr_sort)}; - InsertZ3Expr(c_op, z_addrof(operand)); - } break; - - case clang::UO_Deref: { - auto elm_sort{GetZ3Sort(c_op->getType())}; - auto z_deref{z_ctx->function("Deref", operand.get_sort(), elm_sort)}; - InsertZ3Expr(c_op, z_deref(operand)); - } break; - - default: - THROW() << "Unknown clang::UnaryOperator operation: " - << c_op->getOpcodeStr(c_op->getOpcode()).str(); - break; - } - return true; -} - -// Translates clang binary operators expressions to Z3 equivalents. -bool Z3ConvVisitor::VisitBinaryOperator(clang::BinaryOperator *c_op) { - DLOG(INFO) << "VisitBinaryOperator: " << c_op->getOpcodeStr().str(); - if (z_expr_map.count(c_op)) { - return true; - } - // Get operands - auto lhs{GetZ3Expr(c_op->getLHS())}; - auto rhs{GetZ3Expr(c_op->getRHS())}; - // Conditionally cast operands match size to the wider one - auto CondSizeCast{[this, &lhs, &rhs] { - auto lhs_size{GetZ3SortSize(lhs)}; - auto rhs_size{GetZ3SortSize(rhs)}; - if (lhs_size < rhs_size) { - lhs = CreateZ3BitwiseCast(lhs, lhs_size, rhs_size, /*sign=*/true); - } else if (lhs_size > rhs_size) { - rhs = CreateZ3BitwiseCast(rhs, rhs_size, lhs_size, /*sign=*/true); - } - }}; - // Conditionally cast operands to bool - auto CondBoolCast{[this, &lhs, &rhs]() { - if (lhs.is_bool() || rhs.is_bool()) { - lhs = Z3BoolCast(lhs); - rhs = Z3BoolCast(rhs); - } - }}; - // Conditionally cast operands to a bitvector - auto CondBoolToBVCast{[this, &lhs, &rhs]() { - if (lhs.is_bool()) { - CHECK(rhs.is_bv() || rhs.is_bool()); - lhs = Z3BoolToBVCast(lhs); - } - if (rhs.is_bool()) { - CHECK(lhs.is_bv()); - rhs = Z3BoolToBVCast(rhs); - } - }}; - // Create z3 binary op - switch (c_op->getOpcode()) { - case clang::BO_LAnd: - CondBoolCast(); - InsertZ3Expr(c_op, lhs.is_bv() && rhs.is_bv() ? lhs & rhs : lhs && rhs); - break; - - case clang::BO_LOr: - CondBoolCast(); - InsertZ3Expr(c_op, lhs.is_bv() && rhs.is_bv() ? lhs | rhs : lhs || rhs); - break; - - case clang::BO_EQ: { - CondBoolCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs == rhs); - } break; - - case clang::BO_NE: - CondBoolCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs != rhs); - break; - - case clang::BO_GE: - CondSizeCast(); - InsertZ3Expr(c_op, lhs >= rhs); - break; - - case clang::BO_GT: - CondSizeCast(); - InsertZ3Expr(c_op, lhs > rhs); - break; - - case clang::BO_LE: - CondSizeCast(); - InsertZ3Expr(c_op, lhs <= rhs); - break; - - case clang::BO_LT: - CondSizeCast(); - InsertZ3Expr(c_op, lhs < rhs); - break; - - case clang::BO_Rem: - CondSizeCast(); - InsertZ3Expr(c_op, z3::srem(lhs, rhs)); - break; - - case clang::BO_Add: - CondSizeCast(); - InsertZ3Expr(c_op, lhs + rhs); - break; - - case clang::BO_Sub: - CondSizeCast(); - InsertZ3Expr(c_op, lhs - rhs); - break; - - case clang::BO_Mul: - CondSizeCast(); - InsertZ3Expr(c_op, lhs * rhs); - break; - - case clang::BO_Div: - CondSizeCast(); - InsertZ3Expr(c_op, lhs / rhs); - break; - - case clang::BO_And: - CondBoolToBVCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs & rhs); - break; - - case clang::BO_Or: - CondBoolToBVCast(); - CondSizeCast(); - InsertZ3Expr(c_op, lhs | rhs); - break; - - case clang::BO_Xor: - CondSizeCast(); - InsertZ3Expr(c_op, lhs ^ rhs); - break; - - case clang::BO_Shr: - CondBoolToBVCast(); - CondSizeCast(); - InsertZ3Expr(c_op, c_op->getLHS()->getType()->isSignedIntegerType() - ? z3::ashr(lhs, rhs) - : z3::lshr(lhs, rhs)); - break; - - case clang::BO_Shl: - CondSizeCast(); - InsertZ3Expr(c_op, z3::shl(lhs, rhs)); - break; - - default: - THROW() << "Unknown clang::BinaryOperator operation: " - << c_op->getOpcodeStr().str(); - break; - } - return true; -} - -bool Z3ConvVisitor::VisitConditionalOperator(clang::ConditionalOperator *c_op) { - DLOG(INFO) << "VisitConditionalOperator"; - if (z_expr_map.count(c_op)) { - return true; - } - - auto z_cond{Z3BoolCast(GetZ3Expr(c_op->getCond()))}; - auto z_then{GetZ3Expr(c_op->getTrueExpr())}; - auto z_else{GetZ3Expr(c_op->getFalseExpr())}; - - auto z_then_size{GetZ3SortSize(z_then)}; - auto z_else_size{GetZ3SortSize(z_else)}; - - if (z_then_size > z_else_size) { - z_else = - CreateZ3BitwiseCast(z_else, z_else_size, z_then_size, /*sign=*/true); - } else if (z_then_size < z_else_size) { - z_then = - CreateZ3BitwiseCast(z_then, z_then_size, z_else_size, /*sign=*/true); - } - - InsertZ3Expr(c_op, z3::ite(z_cond, z_then, z_else)); - - return true; -} - -// Translates clang variable references to Z3 constants. -bool Z3ConvVisitor::VisitDeclRefExpr(clang::DeclRefExpr *c_ref) { - auto c_ref_decl{c_ref->getDecl()}; - auto c_ref_name{c_ref_decl->getNameAsString()}; - DLOG(INFO) << "VisitDeclRefExpr: " << c_ref_name; - if (z_expr_map.count(c_ref)) { - return true; - } - - auto z_decl{GetOrCreateZ3Decl(c_ref_decl)}; - auto z_ref{z_decl.is_const() ? z_decl() : z3::as_array(z_decl)}; - - InsertZ3Expr(c_ref, z_ref); - - return true; -} - -// Translates clang character literals references to Z3 numeral values. -bool Z3ConvVisitor::VisitCharacterLiteral(clang::CharacterLiteral *c_lit) { - auto c_val{c_lit->getValue()}; - DLOG(INFO) << "VisitCharacterLiteral: " << c_val; - if (z_expr_map.count(c_lit)) { - return true; - } - - auto z_sort{GetZ3Sort(c_lit->getType())}; - auto z_val{z_ctx->num_val(c_val, z_sort)}; - InsertZ3Expr(c_lit, z_val); - - return true; -} - -// Translates clang integer literal references to Z3 numeral values. -bool Z3ConvVisitor::VisitIntegerLiteral(clang::IntegerLiteral *c_lit) { - auto c_val{c_lit->getValue().getLimitedValue()}; - DLOG(INFO) << "VisitIntegerLiteral: " << c_val; - if (z_expr_map.count(c_lit)) { - return true; - } - - auto z_sort{GetZ3Sort(c_lit->getType())}; - auto z_val{z_sort.is_bool() ? z_ctx->bool_val(c_val != 0) - : z_ctx->num_val(c_val, z_sort)}; - InsertZ3Expr(c_lit, z_val); - - return true; -} - -// Translates clang floating point literal references to Z3 numeral values. -bool Z3ConvVisitor::VisitFloatingLiteral(clang::FloatingLiteral *lit) { - auto api{lit->getValue().bitcastToAPInt()}; - DLOG(INFO) << "VisitFloatingLiteral: " << api.bitsToDouble(); - if (z_expr_map.count(lit)) { - return true; - } - - auto size{api.getBitWidth()}; - llvm::SmallString<64U> bits; - api.toString(bits, /*Radix=*/10U, /*Signed=*/false); - auto sort{GetZ3Sort(lit->getType())}; - auto bv{z_ctx->bv_val(bits.c_str(), size)}; - auto fpa{z3::to_expr(*z_ctx, Z3_mk_fpa_to_fp_bv(*z_ctx, bv, sort))}; - - InsertZ3Expr(lit, fpa); - - return true; -} - -void Z3ConvVisitor::VisitZ3Expr(z3::expr z_expr) { - auto z_decl{z_expr.decl()}; - CHECK(z_expr.is_app()) << "Unexpected Z3 operation: " << z_decl.name(); - // Handle arguments first - for (auto i{0U}; i < z_expr.num_args(); ++i) { - GetOrCreateCExpr(z_expr.arg(i)); - } - // TODO(msurovic): Rework this into a visitor based on - // z_expr.decl().decl_kind() - // Handle bitvector concats - if (z_decl.decl_kind() == Z3_OP_CONCAT) { - InsertCExpr(z_expr, HandleZ3Concat(z_expr)); - return; - } - // Handle uninterpreted functions - if (z_decl.decl_kind() == Z3_OP_UNINTERPRETED) { - InsertCExpr(z_expr, HandleZ3Uninterpreted(z_expr)); - return; - } - // Handle the rest - switch (z_decl.arity()) { - case 0: - VisitConstant(z_expr); - break; - - case 1: - VisitUnaryApp(z_expr); - break; - - case 2: - VisitBinaryApp(z_expr); - break; - - case 3: - VisitTernaryApp(z_expr); - break; - - default: - LOG(FATAL) << "Unexpected Z3 operation: " << z_decl.name(); - break; - } -} - -void Z3ConvVisitor::VisitConstant(z3::expr z_const) { - DLOG(INFO) << "VisitConstant: " << z_const; - CHECK(z_const.is_const()) << "Z3 expression is not a constant!"; - // Create C literals and variable references - clang::Expr *c_expr{nullptr}; - switch (z_const.decl().decl_kind()) { - // Boolean literals - case Z3_OP_TRUE: - case Z3_OP_FALSE: - // Arithmetic numerals - case Z3_OP_ANUM: - // Bitvector numerals - case Z3_OP_BNUM: - // Floating-point numerals - case Z3_OP_FPA_NUM: - case Z3_OP_FPA_PLUS_INF: - case Z3_OP_FPA_MINUS_INF: - case Z3_OP_FPA_NAN: - case Z3_OP_FPA_PLUS_ZERO: - case Z3_OP_FPA_MINUS_ZERO: - c_expr = CreateLiteralExpr(z_const); - break; - // Functions-as-array expressions - case Z3_OP_AS_ARRAY: { - z3::func_decl z_decl(*z_ctx, Z3_get_as_array_func_decl(*z_ctx, z_const)); - c_expr = ast.CreateDeclRef(GetOrCreateCValDecl(z_decl)); - } break; - // Internal constants handled by parent Z3 exprs - case Z3_OP_INTERNAL: - break; - // Unknowns - default: - THROW() << "Unknown Z3 constant: " << z_const; - break; - } - InsertCExpr(z_const, c_expr); -} - -clang::Expr *Z3ConvVisitor::HandleZ3Concat(z3::expr z_op) { - auto lhs{GetCExpr(z_op.arg(0U))}; - for (auto i{1U}; i < z_op.num_args(); ++i) { - // Given a `(concat l r)` we generate `((t)l << w) | r` where - // * `w` is the bitwidth of `r` - // - // * `t` is the smallest integer type that can fit the result - // of `(concat l r)` - auto rhs{GetCExpr(z_op.arg(i))}; - auto res_ty{ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - /*sign=*/0U)}; - if (!IsSignExt(z_op)) { - auto cast{ast.CreateCStyleCast(res_ty, lhs)}; - auto shl_val{ - ast.CreateIntLit(llvm::APInt(32U, GetZ3SortSize(z_op.arg(1U))))}; - auto shl{ast.CreateShl(cast, shl_val)}; - lhs = ast.CreateOr(shl, rhs); - } else { - lhs = ast.CreateCStyleCast(res_ty, rhs); - } - } - return lhs; -} - -clang::Expr *Z3ConvVisitor::HandleZ3Uninterpreted(z3::expr z_op) { - clang::Expr *c_op{nullptr}; - // Constants - if (z_op.is_const()) { - auto c_decl{GetOrCreateCValDecl(z_op.decl())}; - if (clang::isa(c_decl)) { - c_op = ast.CreateNull(); - } else { - c_op = ast.CreateDeclRef(c_decl); - } - return c_op; - } - // Functions - auto z_decl{z_op.decl()}; - auto z_func_name{z_decl.name().str()}; - auto lhs{[this, &z_op] { return GetCExpr(z_op.arg(0U)); }}; - auto rhs{[this, &z_op] { return GetCExpr(z_op.arg(1U)); }}; - // Get result type for casts - auto GetTypeFromOpaquePtrLiteral{[&lhs] { - auto c_lit{clang::cast(lhs())}; - auto t_dst_opaque_ptr_val{c_lit->getValue().getLimitedValue()}; - auto t_dst_opaque_ptr{reinterpret_cast(t_dst_opaque_ptr_val)}; - return clang::QualType::getFromOpaquePtr(t_dst_opaque_ptr); - }}; - if (z_func_name == "AddrOf") { - c_op = ast.CreateAddrOf(lhs()); - } else if (z_func_name == "Deref") { - c_op = ast.CreateDeref(lhs()); - } else if (z_func_name == "PtrToInt") { - auto s_size{GetZ3SortSize(z_op)}; - auto t_op{ast.GetLeastIntTypeForBitWidth(s_size, /*sign=*/0U)}; - c_op = ast.CreateCStyleCast(t_op, lhs()); - } else if (z_func_name == "PtrDecay") { - c_op = lhs(); - } else if (z_func_name == "ArraySub") { - c_op = ast.CreateArraySub(lhs(), rhs()); - } else if (z_func_name == "Member") { - auto mem{GetOrCreateCValDecl(z_op.arg(1U).decl())}; - auto field{clang::dyn_cast(mem)}; - CHECK(field != nullptr) << "Operand is not a clang::FieldDecl"; - c_op = ast.CreateDot(lhs(), field); - } else if (z_func_name == "IntToPtr" || z_func_name == "BitCast") { - c_op = ast.CreateCStyleCast(GetTypeFromOpaquePtrLiteral(), rhs()); - } else if (z_func_name == "Call") { - auto c_callee{rhs()}; - std::vector c_args; - for (auto i{2U}; i < z_op.num_args(); ++i) { - c_args.push_back(GetCExpr(z_op.arg(i))); - } - c_op = ast.CreateCall(c_callee, c_args); - } else { - LOG(FATAL) << "Unknown Z3 uninterpreted function: " << z_func_name; - } - return c_op; -} - -void Z3ConvVisitor::VisitUnaryApp(z3::expr z_op) { - DLOG(INFO) << "VisitUnaryApp: " << z_op; - CHECK(z_op.is_app() && z_op.decl().arity() == 1) - << "Z3 expression is not a unary operator!"; - // Get operand - auto c_sub{GetCExpr(z_op.arg(0U))}; - // Get z3 function declaration - auto z_decl{z_op.decl()}; - // Create C unary operator - clang::Expr *c_op{nullptr}; - switch (z_decl.decl_kind()) { - case Z3_OP_NOT: - c_op = ast.CreateLNot(c_sub); - break; - - case Z3_OP_BNOT: - c_op = ast.CreateNot(c_sub); - break; - // Given a `(extract hi lo o)` we generate `(o >> lo & m)` where: - // - // * `o` is the operand from which we extract a bit sequence - // - // * `hi` is the upper bound of the extracted sequence - // - // * `lo` is the lower bound of the extracted sequence - // - // * `m` is a bitmask integer literal - case Z3_OP_EXTRACT: { - if (z_op.lo() != 0U) { - auto shr_val{ast.CreateIntLit(llvm::APInt(32U, z_op.lo()))}; - auto shr{ast.CreateShr(c_sub, shr_val)}; - auto mask_val{ast.CreateIntLit( - llvm::APInt::getAllOnesValue(GetZ3SortSize(z_op)))}; - c_op = ast.CreateAnd(shr, mask_val); - } else { - c_op = ast.CreateCStyleCast( - ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - /*sign=*/0), - c_sub); - } - - } break; - - case Z3_OP_ZERO_EXT: - c_op = ast.CreateCStyleCast( - ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - /*sign=*/0), - c_sub); - break; - - case Z3_OP_SIGN_EXT: - c_op = ast.CreateCStyleCast( - ast.GetLeastIntTypeForBitWidth(GetZ3SortSize(z_op), - /*sign=*/1), - c_sub); - break; - - default: - LOG(FATAL) << "Unknown Z3 unary operator: " << z_decl.name(); - break; - } - // Save - InsertCExpr(z_op, c_op); -} - -void Z3ConvVisitor::VisitBinaryApp(z3::expr z_op) { - DLOG(INFO) << "VisitBinaryApp: " << z_op; - CHECK(z_op.is_app() && z_op.decl().arity() == 2) - << "Z3 expression is not a binary operator!"; - // Get operands - auto lhs{GetCExpr(z_op.arg(0U))}; - auto rhs{GetCExpr(z_op.arg(1U))}; - // Get z3 function declaration - auto z_decl{z_op.decl()}; - // Create C binary operator - clang::Expr *c_op{nullptr}; - switch (z_decl.decl_kind()) { - // `&&` in z3 can be n-ary, so we create a tree of C binary `&&`. - case Z3_OP_AND: { - c_op = lhs; - for (auto i{1U}; i < z_op.num_args(); ++i) { - rhs = GetCExpr(z_op.arg(i)); - c_op = ast.CreateLAnd(c_op, rhs); - } - } break; - // `||` in z3 can be n-ary, so we create a tree of C binary `||`. - case Z3_OP_OR: { - c_op = lhs; - for (auto i{1U}; i < z_op.num_args(); ++i) { - rhs = GetCExpr(z_op.arg(i)); - c_op = ast.CreateLOr(c_op, rhs); - } - } break; - - case Z3_OP_EQ: - c_op = ast.CreateEQ(lhs, rhs); - break; - - case Z3_OP_SLEQ: - case Z3_OP_ULEQ: - case Z3_OP_FPA_LE: - c_op = ast.CreateLE(lhs, rhs); - break; - - case Z3_OP_SLT: - case Z3_OP_ULT: - case Z3_OP_FPA_LT: - c_op = ast.CreateLT(lhs, rhs); - break; - - case Z3_OP_SGT: - case Z3_OP_UGT: - case Z3_OP_FPA_GT: - c_op = ast.CreateGT(lhs, rhs); - break; - - case Z3_OP_BADD: - c_op = ast.CreateAdd(lhs, rhs); - break; - - case Z3_OP_BLSHR: - c_op = ast.CreateShr(lhs, rhs); - break; - - case Z3_OP_BASHR: { - auto size{c_ctx->getTypeSize(lhs->getType())}; - auto type{ast.GetLeastIntTypeForBitWidth(size, /*sign=*/1U)}; - auto cast{ast.CreateCStyleCast(type, lhs)}; - c_op = ast.CreateShr(cast, rhs); - } break; - - case Z3_OP_BSHL: - c_op = ast.CreateShl(lhs, rhs); - break; - - case Z3_OP_BAND: - c_op = ast.CreateAnd(lhs, rhs); - break; - - case Z3_OP_BOR: - c_op = ast.CreateOr(lhs, rhs); - break; - - case Z3_OP_BXOR: - c_op = ast.CreateXor(lhs, rhs); - break; - - case Z3_OP_BMUL: - c_op = ast.CreateMul(lhs, rhs); - break; - - case Z3_OP_BSDIV: - case Z3_OP_BSDIV_I: - c_op = ast.CreateDiv(lhs, rhs); - break; - - case Z3_OP_BSREM: - case Z3_OP_BSREM_I: - c_op = ast.CreateRem(lhs, rhs); - break; - - case Z3_OP_DISTINCT: - c_op = ast.CreateNE(lhs, rhs); - break; - - // Unknowns - default: - LOG(FATAL) << "Unknown Z3 binary operator: " << z_decl.name(); - break; - } - // Save - InsertCExpr(z_op, c_op); -} - -void Z3ConvVisitor::VisitTernaryApp(z3::expr z_op) { - DLOG(INFO) << "VisitTernaryApp: " << z_op; - CHECK(z_op.is_app() && z_op.decl().arity() == 3) - << "Z3 expression is not a ternary operator!"; - // Create C binary operator - clang::Expr *c_op{nullptr}; - auto z_decl{z_op.decl()}; - switch (z_decl.decl_kind()) { - case Z3_OP_ITE: { - auto z_cond{z_op.arg(0U)}; - auto z_then{z_op.arg(1U)}; - auto z_else{z_op.arg(2U)}; - uint64_t then_val{0U}; - uint64_t else_val{1U}; - if (z_then.is_numeral_u64(then_val) && z_else.is_numeral_u64(else_val) && - then_val == 1U && else_val == 0U) { - c_op = GetCExpr(z_cond); - } else { - c_op = ast.CreateConditional(GetCExpr(z_cond), GetCExpr(z_then), - GetCExpr(z_else)); - } - } break; - // Unknowns - default: - LOG(FATAL) << "Unknown Z3 ternary operator: " << z_decl.name(); - break; - } - // Save - InsertCExpr(z_op, c_op); -} - -void Z3ConvVisitor::VisitZ3Decl(z3::func_decl z_decl) { - THROW() << "Unimplemented Z3 declaration visitor!"; - // switch (z_decl.decl_kind()) { - // case Z3_OP_UNINTERPRETED: - // if (!z_decl.is_const()) { - - // } else { - // LOG(FATAL) << "Unimplemented Z3 declaration!"; - // } - // break; - - // default: - // LOG(FATAL) << "Unknown Z3 function declaration: " << z_decl.name(); - // break; - // } -} - -} // namespace rellic \ No newline at end of file diff --git a/lib/BC/Util.cpp b/lib/BC/Util.cpp index 76052e4c..63f7dd0f 100644 --- a/lib/BC/Util.cpp +++ b/lib/BC/Util.cpp @@ -437,61 +437,4 @@ void ConvertArrayArguments(llvm::Module &m) { CHECK(VerifyModule(&m)) << "Transformation broke module correctness"; } -static void FindRedundantLoads(llvm::Function &func) { - auto HasStoreBeforeUse = [&](llvm::Value *ptr, llvm::User *user, - llvm::LoadInst *load) { - std::unordered_set visited_blocks; - std::vector work_list; - auto inst{llvm::cast(user)}; - work_list.push_back(inst); - while (!work_list.empty()) { - inst = work_list.back(); - work_list.pop_back(); - auto bb{inst->getParent()}; - visited_blocks.insert(bb); - - while (inst) { - if (auto store = llvm::dyn_cast(inst)) { - if (store->getPointerOperand() == ptr) { - return true; - } - } - if (inst == load) { - return false; - } - inst = inst->getPrevNode(); - } - for (auto pred : llvm::predecessors(bb)) { - if (!visited_blocks.count(pred)) { - work_list.push_back(pred->getTerminator()); - } - } - } - - return false; - }; - - for (auto &inst : llvm::instructions(func)) { - if (auto load = llvm::dyn_cast(&inst)) { - for (auto &use : load->uses()) { - auto ptr{load->getPointerOperand()}; - if (!llvm::isa(ptr)) { - continue; - } - if (HasStoreBeforeUse(ptr, use.getUser(), load)) { - continue; - } - load->setMetadata("rellic.notemp", - llvm::MDNode::get(func.getContext(), {})); - } - } - } -} - -void FindRedundantLoads(llvm::Module &module) { - for (auto &func : module.functions()) { - FindRedundantLoads(func); - } -} - } // namespace rellic \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 85ac463d..8d6f590b 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -32,9 +32,9 @@ set(AST_HEADERS "${include_dir}/AST/InferenceRule.h" "${include_dir}/AST/LocalDeclRenamer.h" "${include_dir}/AST/LoopRefine.h" + "${include_dir}/AST/MaterializeConds.h" "${include_dir}/AST/NestedCondProp.h" "${include_dir}/AST/NestedScopeCombine.h" - "${include_dir}/AST/NormalizeCond.h" "${include_dir}/AST/ReachBasedRefine.h" "${include_dir}/AST/StructFieldRenamer.h" "${include_dir}/AST/StructGenerator.h" @@ -42,7 +42,6 @@ set(AST_HEADERS "${include_dir}/AST/TransformVisitor.h" "${include_dir}/AST/Util.h" "${include_dir}/AST/Z3CondSimplify.h" - "${include_dir}/AST/Z3ConvVisitor.h" ) set(BC_HEADERS @@ -72,12 +71,11 @@ set(AST_SOURCES AST/IRToASTVisitor.cpp AST/LocalDeclRenamer.cpp AST/LoopRefine.cpp + AST/MaterializeConds.cpp AST/NestedCondProp.cpp AST/NestedScopeCombine.cpp - AST/NormalizeCond.cpp AST/Util.cpp AST/Z3CondSimplify.cpp - AST/Z3ConvVisitor.cpp AST/ReachBasedRefine.cpp AST/StructFieldRenamer.cpp AST/StructGenerator.cpp diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 38a28b3b..001d8748 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -29,9 +29,9 @@ #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/LocalDeclRenamer.h" #include "rellic/AST/LoopRefine.h" +#include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" -#include "rellic/AST/NormalizeCond.h" #include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" @@ -73,7 +73,6 @@ Result Decompile( ConvertArrayArguments(*module); RemoveInsertValues(*module); - FindRedundantLoads(*module); InitOptPasses(); rellic::DebugInfoCollector dic; @@ -84,153 +83,91 @@ Result Decompile( module->getTargetTriple()}; auto ast_unit{clang::tooling::buildASTFromCodeWithArgs("", args, "out.c")}; - rellic::Provenance provenance; - rellic::GenerateAST::run(*module, provenance, *ast_unit); + ASTBuilder ast{*ast_unit}; + rellic::DecompilationContext dec_ctx; + dec_ctx.marker_expr = ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); + rellic::GenerateAST::run(*module, dec_ctx, *ast_unit); // TODO(surovic): Add llvm::Value* -> clang::Decl* map // Especially for llvm::Argument* and llvm::Function*. - rellic::CompositeASTPass pass_ast(provenance, *ast_unit); + rellic::CompositeASTPass pass_ast(dec_ctx, *ast_unit); auto& ast_passes{pass_ast.GetPasses()}; - if (options.dead_stmt_elimination) { - ast_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } + ast_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); ast_passes.push_back(std::make_unique( - provenance, *ast_unit, dic.GetIRToNameMap())); + dec_ctx, *ast_unit, dic.GetIRToNameMap())); ast_passes.push_back(std::make_unique( - provenance, *ast_unit, dic.GetIRTypeToDITypeMap())); + dec_ctx, *ast_unit, dic.GetIRTypeToDITypeMap())); pass_ast.Run(); - rellic::CompositeASTPass pass_cbr(provenance, *ast_unit); + rellic::CompositeASTPass pass_cbr(dec_ctx, *ast_unit); auto& cbr_passes{pass_cbr.GetPasses()}; - if (options.condition_based_refinement.expression_normalize) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (!options.disable_z3) { - auto zcs{std::make_unique(provenance, *ast_unit)}; - if (options.condition_based_refinement.z3_cond_simplify) { - cbr_passes.push_back(std::move(zcs)); - } - if (options.condition_based_refinement.nested_cond_propagate) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - } + cbr_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + cbr_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); - if (options.condition_based_refinement.nested_scope_combine) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } + cbr_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); - if (!options.disable_z3) { - if (options.condition_based_refinement.cond_base_refine) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.condition_based_refinement.reach_based_refine) { - cbr_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - } + cbr_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + cbr_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); while (pass_cbr.Run()) { ; } - rellic::CompositeASTPass pass_loop{provenance, *ast_unit}; + rellic::CompositeASTPass pass_loop{dec_ctx, *ast_unit}; auto& loop_passes{pass_loop.GetPasses()}; - if (options.loop_refinement.loop_refine) { - loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.loop_refinement.nested_cond_propagate) { - loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.loop_refinement.nested_scope_combine) { - loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.loop_refinement.expression_normalize) { - loop_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } + loop_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + loop_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + loop_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + while (pass_loop.Run()) { ; } - rellic::CompositeASTPass pass_scope{provenance, *ast_unit}; + rellic::CompositeASTPass pass_scope{dec_ctx, *ast_unit}; auto& scope_passes{pass_scope.GetPasses()}; - if (!options.disable_z3) { - auto zcs{std::make_unique(provenance, *ast_unit)}; - if (options.condition_based_refinement.z3_cond_simplify) { - scope_passes.push_back(std::move(zcs)); - } - if (options.condition_based_refinement.nested_cond_propagate) { - scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - } + scope_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + scope_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + + scope_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); - if (options.scope_refinement.nested_scope_combine) { - scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.scope_refinement.expression_normalize) { - scope_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } while (pass_scope.Run()) { ; } - rellic::CompositeASTPass pass_ec{provenance, *ast_unit}; + rellic::CompositeASTPass pass_ec{dec_ctx, *ast_unit}; auto& ec_passes{pass_ec.GetPasses()}; - if (options.expression_combine) { - ec_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.expression_normalize) { - ec_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - while (pass_ec.Run()) { - ; - } + ec_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); + ec_passes.push_back( + std::make_unique(dec_ctx, *ast_unit)); - rellic::CompositeASTPass pass_final{provenance, *ast_unit}; - auto& final_passes{pass_final.GetPasses()}; - if (options.final_refinement.z3_cond_simplify) { - final_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.final_refinement.nested_cond_propagate) { - final_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - if (options.final_refinement.nested_scope_combine) { - final_passes.push_back( - std::make_unique(provenance, *ast_unit)); - } - while (pass_final.Run()) { - ; - } + pass_ec.Run(); DecompilationResult result{}; result.ast = std::move(ast_unit); result.module = std::move(module); - CopyMap(provenance.stmt_provenance, result.stmt_provenance_map, + CopyMap(dec_ctx.stmt_provenance, result.stmt_provenance_map, result.value_to_stmt_map); - CopyMap(provenance.value_decls, result.value_to_decl_map, + CopyMap(dec_ctx.value_decls, result.value_to_decl_map, result.decl_provenance_map); - CopyMap(provenance.type_decls, result.type_to_decl_map, + CopyMap(dec_ctx.type_decls, result.type_to_decl_map, result.type_provenance_map); - CopyMap(provenance.use_provenance, result.expr_use_map, - result.use_expr_map); + CopyMap(dec_ctx.use_provenance, result.expr_use_map, result.use_expr_map); return Result(std::move(result)); } catch (Exception& ex) { diff --git a/tools/decomp/Decomp.cpp b/tools/decomp/Decomp.cpp index 628507b8..f457ad40 100644 --- a/tools/decomp/Decomp.cpp +++ b/tools/decomp/Decomp.cpp @@ -121,7 +121,6 @@ int main(int argc, char* argv[]) { CHECK(!ec) << "Failed to create output file: " << ec.message(); rellic::DecompilationOptions opts{}; - opts.disable_z3 = FLAGS_disable_z3; opts.lower_switches = FLAGS_lower_switch; opts.remove_phi_nodes = FLAGS_remove_phi_nodes; diff --git a/tools/repl/Repl.cpp b/tools/repl/Repl.cpp index 24367398..788de1f8 100644 --- a/tools/repl/Repl.cpp +++ b/tools/repl/Repl.cpp @@ -41,9 +41,9 @@ #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/LocalDeclRenamer.h" #include "rellic/AST/LoopRefine.h" +#include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" -#include "rellic/AST/NormalizeCond.h" #include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" #include "rellic/AST/Z3CondSimplify.h" @@ -61,7 +61,7 @@ DECLARE_bool(version); llvm::LLVMContext llvm_ctx; std::unique_ptr module{nullptr}; std::unique_ptr ast_unit{nullptr}; -rellic::Provenance provenance; +std::unique_ptr dec_ctx; std::unique_ptr global_pass{nullptr}; static void SetVersion(void) { @@ -93,8 +93,8 @@ static void SetVersion(void) { google::SetVersionString(version.str()); } -static const char* available_passes[] = {"cbr", "dse", "ec", "lr", "ncp", - "nsc", "nc", "rbr", "zcs"}; +static const char* available_passes[] = {"cbr", "dse", "ec", "lr", + "ncp", "nsc", "rbr", "zcs"}; static bool diff = false; @@ -147,23 +147,23 @@ class Diff { static std::unique_ptr CreatePass(const std::string& name) { if (name == "cbr") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "dse") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "ec") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "lr") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); + } else if (name == "mc") { + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "ncp") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "nsc") { - return std::make_unique(provenance, *ast_unit); - } else if (name == "nc") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "rbr") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else if (name == "zcs") { - return std::make_unique(provenance, *ast_unit); + return std::make_unique(*dec_ctx, *ast_unit); } else { return nullptr; } @@ -197,7 +197,7 @@ static void do_help() { << " dse Dead statement elimination\n" << " ec Expression combination\n" << " lr Loop refinement\n" - << " nc Condition normalization\n" + << " mc Condition materialization\n" << " ncp Nested condition propagation\n" << " nsc Nested scope combination\n" << " rbr Reach-based refinement\n" @@ -218,7 +218,7 @@ static void do_load(std::istream& is) { std::vector args{"-Wno-pointer-to-int-cast", "-Wno-pointer-sign", "-target", module->getTargetTriple()}; ast_unit = clang::tooling::buildASTFromCodeWithArgs("", args, "out.c"); - provenance = {}; + dec_ctx = {}; std::cout << "ok." << std::endl; } @@ -300,10 +300,10 @@ static void do_decompile() { try { rellic::DebugInfoCollector dic; dic.visit(*module); - provenance = {}; - rellic::GenerateAST::run(*module, provenance, *ast_unit); - rellic::LocalDeclRenamer ldr{provenance, *ast_unit, dic.GetIRToNameMap()}; - rellic::StructFieldRenamer sfr{provenance, *ast_unit, + dec_ctx = {}; + rellic::GenerateAST::run(*module, *dec_ctx, *ast_unit); + rellic::LocalDeclRenamer ldr{*dec_ctx, *ast_unit, dic.GetIRToNameMap()}; + rellic::StructFieldRenamer sfr{*dec_ctx, *ast_unit, dic.GetIRTypeToDITypeMap()}; ldr.Run(); sfr.Run(); @@ -319,8 +319,7 @@ static void do_run(std::istream& is) { return; } - global_pass = - std::make_unique(provenance, *ast_unit); + global_pass = std::make_unique(*dec_ctx, *ast_unit); std::string name; while (is >> name) { auto pass{CreatePass(name)}; @@ -357,8 +356,7 @@ static void do_fixpoint(std::istream& is) { return; } - global_pass = - std::make_unique(provenance, *ast_unit); + global_pass = std::make_unique(*dec_ctx, *ast_unit); std::string name; while (is >> name) { auto pass{CreatePass(name)}; diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index fa4ebff9..c5200006 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -49,11 +49,12 @@ #include "rellic/AST/IRToASTVisitor.h" #include "rellic/AST/LocalDeclRenamer.h" #include "rellic/AST/LoopRefine.h" +#include "rellic/AST/MaterializeConds.h" #include "rellic/AST/NestedCondProp.h" #include "rellic/AST/NestedScopeCombine.h" -#include "rellic/AST/NormalizeCond.h" #include "rellic/AST/ReachBasedRefine.h" #include "rellic/AST/StructFieldRenamer.h" +#include "rellic/AST/Util.h" #include "rellic/AST/Z3CondSimplify.h" #include "rellic/BC/Util.h" #include "rellic/Decompiler.h" @@ -110,7 +111,7 @@ struct Session { std::unique_ptr Module; std::unique_ptr Unit; std::unique_ptr Pass; - rellic::Provenance Provenance; + std::unique_ptr DecompContext; // Must always be acquired in this order and released all at once std::shared_mutex LoadMutex, MutationMutex; }; @@ -247,7 +248,6 @@ static void LoadModule(const httplib::Request& req, httplib::Response& res) { return; } session.Module = std::unique_ptr(mod); - rellic::FindRedundantLoads(*session.Module); llvm::json::Object msg{{"message", "Ok."}}; SendJSON(res, msg); res.status = 200; @@ -266,18 +266,21 @@ static void Decompile(const httplib::Request& req, httplib::Response& res) { } try { - session.Provenance = {}; + session.DecompContext = std::make_unique(); std::vector args{"-Wno-pointer-to-int-cast", "-Wno-pointer-sign", "-target", session.Module->getTargetTriple()}; session.Unit = clang::tooling::buildASTFromCodeWithArgs("", args, "out.c"); + rellic::ASTBuilder ast{*session.Unit}; + session.DecompContext->marker_expr = + ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse()); rellic::DebugInfoCollector dic; dic.visit(*session.Module); - rellic::GenerateAST::run(*session.Module, session.Provenance, + rellic::GenerateAST::run(*session.Module, *session.DecompContext, *session.Unit); - rellic::LocalDeclRenamer ldr{session.Provenance, *session.Unit, + rellic::LocalDeclRenamer ldr{*session.DecompContext, *session.Unit, dic.GetIRToNameMap()}; - rellic::StructFieldRenamer sfr{session.Provenance, *session.Unit, + rellic::StructFieldRenamer sfr{*session.DecompContext, *session.Unit, dic.GetIRTypeToDITypeMap()}; ldr.Run(); sfr.Run(); @@ -378,8 +381,8 @@ class FixpointPass : public rellic::ASTPass { void RunImpl() override { comp.Fixpoint(); } public: - FixpointPass(rellic::Provenance& provenance, clang::ASTUnit& ast_unit) - : ASTPass(provenance, ast_unit), comp(provenance, ast_unit) {} + FixpointPass(rellic::DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit) + : ASTPass(dec_ctx, ast_unit), comp(dec_ctx, ast_unit) {} std::vector>& GetPasses() { return comp.GetPasses(); } @@ -396,38 +399,39 @@ static std::unique_ptr CreatePass( auto str{name->str()}; if (str == "cbr") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "dse") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "ec") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "lr") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); + } else if (str == "mc") { + return std::make_unique(*session.DecompContext, + *session.Unit); } else if (str == "ncp") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "nsc") { - return std::make_unique(session.Provenance, - *session.Unit); - } else if (str == "nc") { - return std::make_unique(session.Provenance, - *session.Unit); + return std::make_unique( + *session.DecompContext, *session.Unit); } else if (str == "rbr") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else if (str == "zcs") { - return std::make_unique(session.Provenance, + return std::make_unique(*session.DecompContext, *session.Unit); } else { LOG(ERROR) << "Request contains invalid pass id"; return nullptr; } } else if (auto arr = val.getAsArray()) { - auto fix{std::make_unique(session.Provenance, *session.Unit)}; + auto fix{ + std::make_unique(*session.DecompContext, *session.Unit)}; for (auto& pass : *arr) { auto p{CreatePass(session, pass)}; if (!p) { @@ -509,8 +513,8 @@ static void Run(const httplib::Request& req, httplib::Response& res) { return; } - auto composite{std::make_unique(session.Provenance, - *session.Unit)}; + auto composite{std::make_unique( + *session.DecompContext, *session.Unit)}; for (auto& obj : *json->getAsArray()) { auto pass{CreatePass(session, obj)}; if (!pass) { @@ -579,8 +583,8 @@ static void Fixpoint(const httplib::Request& req, httplib::Response& res) { return; } - auto composite{std::make_unique(session.Provenance, - *session.Unit)}; + auto composite{std::make_unique( + *session.DecompContext, *session.Unit)}; for (auto& obj : *json->getAsArray()) { auto pass{CreatePass(session, obj)}; if (!pass) { @@ -712,11 +716,12 @@ static void PrintAST(const httplib::Request& req, httplib::Response& res) { rellic::DecompilationResult::IRToTypeDeclMap type_to_decl_map; rellic::DecompilationResult::TypeDeclToIRMap type_provenance_map; - CopyMap(session.Provenance.stmt_provenance, stmt_provenance_map, + CopyMap(session.DecompContext->stmt_provenance, stmt_provenance_map, value_to_stmt_map); - CopyMap(session.Provenance.value_decls, value_to_decl_map, + CopyMap(session.DecompContext->value_decls, value_to_decl_map, decl_provenance_map); - CopyMap(session.Provenance.type_decls, type_to_decl_map, type_provenance_map); + CopyMap(session.DecompContext->type_decls, type_to_decl_map, + type_provenance_map); std::string s; llvm::raw_string_ostream os(s); @@ -797,31 +802,31 @@ static void PrintProvenance(const httplib::Request& req, } llvm::json::Array stmt_provenance; - for (auto elem : session.Provenance.stmt_provenance) { + for (auto elem : session.DecompContext->stmt_provenance) { stmt_provenance.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array type_decls; - for (auto elem : session.Provenance.type_decls) { + for (auto elem : session.DecompContext->type_decls) { type_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array value_decls; - for (auto elem : session.Provenance.value_decls) { + for (auto elem : session.DecompContext->value_decls) { value_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array temp_decls; - for (auto elem : session.Provenance.temp_decls) { + for (auto elem : session.DecompContext->temp_decls) { temp_decls.push_back(llvm::json::Array( {(unsigned long long)elem.first, (unsigned long long)elem.second})); } llvm::json::Array use_provenance; - for (auto elem : session.Provenance.use_provenance) { + for (auto elem : session.DecompContext->use_provenance) { if (!elem.second) { continue; } diff --git a/tools/xref/www/main.js b/tools/xref/www/main.js index 24f5eea9..074e87b5 100644 --- a/tools/xref/www/main.js +++ b/tools/xref/www/main.js @@ -32,9 +32,9 @@ const ec = { id: "ec", label: "Expression combination" } -const nc = { - id: "nc", - label: "Condition normalization" +const mc = { + id: "mc", + label: "Materialize conditions" } Vue.component('list-comp', { @@ -106,7 +106,7 @@ const app = new Vue({ rbr, lr, ec, - nc + mc ], actions: [ { @@ -208,20 +208,20 @@ const app = new Vue({ throw (await res.json()).message } const prov = await res.json() - for(let map in prov) { - for(let [from, to] of prov[map]) { - if(!from || !to) { + for (let map in prov) { + for (let [from, to] of prov[map]) { + if (!from || !to) { continue } const from_hex = from.toString(16) const to_hex = to.toString(16) - if(!this.provenance[from_hex]) { + if (!this.provenance[from_hex]) { this.provenance[from_hex] = [] } this.provenance[from_hex].push(to_hex) - if(!this.provenance[to_hex]) { + if (!this.provenance[to_hex]) { this.provenance[to_hex] = [] } this.provenance[to_hex].push(from_hex) @@ -316,8 +316,7 @@ const app = new Vue({ [zcs, ncp, nsc, cbr, rbr], [lr, nsc], [zcs, ncp, nsc], - ec, - [zcs, ncp, nsc] + mc, ec ] }, openAngha() { diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index 164abd67..684d0fb8 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -8,15 +8,14 @@ add_executable(${RELLIC_UNITTEST} AST/ASTBuilder.cpp AST/StructGenerator.cpp AST/Util.cpp - AST/Z3ConvVisitor.cpp UnitTest.cpp ) target_link_libraries(${RELLIC_UNITTEST} PRIVATE - "${PROJECT_NAME}_cxx_settings" - "${PROJECT_NAME}" - z3::libz3 - doctest::doctest + "${PROJECT_NAME}_cxx_settings" + "${PROJECT_NAME}" + z3::libz3 + doctest::doctest ) target_compile_options(${RELLIC_UNITTEST} PRIVATE -fexceptions)