Skip to content

Commit

Permalink
Z3 all the things (#282)
Browse files Browse the repository at this point in the history
* WIP: Z3 all the things

* Fix bugs

* Fix bug

* Combine whiles

* Improve simplification

* Better condition generation. Still slow.

* Revert to simple substitution

* Refactor ZCS

* Refactor NCP

* Fix NCP bugs

* Fix condition sharing bug

* Fix CBR

* Temporarily disable RBR

* Improve `else-if` recognition

* Improve switch conditions generation

* Factor out `Sort`

* Sort during simplification, use faster simplify

* Improve NCP

* Remove RBR

* Fix bug

* Try improve performance

* Reintroduce RBR

* Fix NCP

* Remove `IsConstant`

Too slow

* Reactivate RBR

* Remove ignored tests

* Fix CBR & RBR

* Factor out RBR lambda

* CBR & RBR review

* Cleanup

* Remove `nc` from tools

* Maybe this time `HeavySimplify` will work?

* Rename "Provenance" to "DecompilationContext"

* Remove temporary variables for `load`s

* Rename `POISON_IDX`

* Put all aliases inside `DecompilationContext`

* Rename `Sort` to `OrderById`

* Revert useless change

* Simplify exit

* Short-circuit slow condition

* Add explanation

* Add comments for `KnownExprs`

* Map expression directly by id

* Add small comment

* Factor out CBR

* Rename parameter for clarity

* Factor out `ToExpr`

* Early returns

* More cleanup

* More cleanup

* Factor out stuff in NCP

* Create wrapper method for Z3 expressions

* Add comments for maps

* Cleanup and comment `ConvertExpr`

* More comments

* Remove old comment

* More comments
  • Loading branch information
frabert authored Sep 6, 2022
1 parent cd60369 commit fb4a663
Show file tree
Hide file tree
Showing 47 changed files with 1,010 additions and 2,570 deletions.
4 changes: 0 additions & 4 deletions ci/angha_1k_test_settings.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
{
"tests.ignore": [
"amd64/linux/tools/perf/bench/extr_numa.c___bench_numa.bc",
"amd64/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc",
"arm64/linux/tools/perf/bench/extr_numa.c___bench_numa.bc",
"arm64/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc",
"armv7/linux/tools/perf/bench/extr_numa.c___bench_numa.bc",
"armv7/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc",
"x86/linux/tools/perf/bench/extr_numa.c___bench_numa.bc",
"x86/SoftEtherVPN/src/See/extr_memory_t.h_SW_LONG_AT.bc"
]
}
10 changes: 5 additions & 5 deletions include/rellic/AST/ASTPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ASTPass {
std::atomic_bool stop{false};

protected:
Provenance& provenance;
DecompilationContext& dec_ctx;
clang::ASTUnit& ast_unit;
clang::ASTContext& ast_ctx;
ASTBuilder ast;
Expand All @@ -32,8 +32,8 @@ class ASTPass {
virtual void StopImpl() {}

public:
ASTPass(Provenance& provenance, clang::ASTUnit& ast_unit)
: provenance(provenance),
ASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit)
: dec_ctx(dec_ctx),
ast_unit(ast_unit),
ast_ctx(ast_unit.getASTContext()),
ast(ast_unit) {}
Expand Down Expand Up @@ -89,8 +89,8 @@ class CompositeASTPass : public ASTPass {
}

public:
CompositeASTPass(Provenance& provenance, clang::ASTUnit& ast_unit)
: ASTPass(provenance, ast_unit) {}
CompositeASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit)
: ASTPass(dec_ctx, ast_unit) {}
std::vector<std::unique_ptr<ASTPass>>& GetPasses() { return passes; }
};
} // namespace rellic
10 changes: 1 addition & 9 deletions include/rellic/AST/CondBasedRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "rellic/AST/ASTPass.h"
#include "rellic/AST/IRToASTVisitor.h"
#include "rellic/AST/TransformVisitor.h"
#include "rellic/AST/Z3ConvVisitor.h"

namespace rellic {

Expand All @@ -35,18 +34,11 @@ namespace rellic {
*/
class CondBasedRefine : public TransformVisitor<CondBasedRefine> {
private:
std::unique_ptr<z3::context> z3_ctx;
std::unique_ptr<rellic::Z3ConvVisitor> z3_gen;

z3::tactic z3_solver;

z3::expr GetZ3Cond(clang::IfStmt *ifstmt);

protected:
void RunImpl() override;

public:
CondBasedRefine(Provenance &provenance, clang::ASTUnit &unit);
CondBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

bool VisitCompoundStmt(clang::CompoundStmt *compound);
};
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/DeadStmtElim.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DeadStmtElim : public TransformVisitor<DeadStmtElim> {
void RunImpl() override;

public:
DeadStmtElim(Provenance &provenance, clang::ASTUnit &unit);
DeadStmtElim(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

bool VisitIfStmt(clang::IfStmt *ifstmt);
bool VisitCompoundStmt(clang::CompoundStmt *compound);
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/ExprCombine.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ExprCombine : public TransformVisitor<ExprCombine> {
void RunImpl() override;

public:
ExprCombine(Provenance &provenance, clang::ASTUnit &unit);
ExprCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

bool VisitCStyleCastExpr(clang::CStyleCastExpr *cast);
bool VisitUnaryOperator(clang::UnaryOperator *op);
Expand Down
48 changes: 22 additions & 26 deletions include/rellic/AST/GenerateAST.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <llvm/IR/PassManager.h>
#include <z3++.h>

#include <limits>
#include <map>
#include <unordered_set>

Expand All @@ -25,28 +26,14 @@ class GenerateAST : public llvm::AnalysisInfoMixin<GenerateAST> {
friend llvm::AnalysisInfoMixin<GenerateAST>;
static llvm::AnalysisKey Key;

// Need to use `map` with these instead of `unordered_map`, because
// `std::pair` doesn't have a default hash implementation
using BBEdge = std::pair<llvm::BasicBlock *, llvm::BasicBlock *>;
using BrEdge = std::pair<llvm::BranchInst *, bool>;
using SwEdge = std::pair<llvm::SwitchInst *, llvm::ConstantInt *>;
constexpr static unsigned poison_idx = std::numeric_limits<unsigned>::max();
z3::expr ToExpr(unsigned idx);

clang::ASTUnit &unit;
clang::ASTContext *ast_ctx;
rellic::IRToASTVisitor ast_gen;
rellic::ASTBuilder ast;
z3::context *z_ctx;

Provenance &provenance;

z3::expr_vector z_exprs;
std::unordered_map<unsigned, BrEdge> z_br_edges_inv;
std::map<BrEdge, unsigned> z_br_edges;

std::unordered_map<unsigned, SwEdge> z_sw_edges_inv;
std::map<SwEdge, unsigned> z_sw_edges;

std::map<BBEdge, unsigned> z_edges;
std::unordered_map<llvm::BasicBlock *, unsigned> reaching_conds;
DecompilationContext &dec_ctx;
bool reaching_conds_changed{true};
std::unordered_map<llvm::BasicBlock *, clang::IfStmt *> block_stmts;
std::unordered_map<llvm::Region *, clang::CompoundStmt *> region_stmts;
Expand All @@ -57,16 +44,25 @@ class GenerateAST : public llvm::AnalysisInfoMixin<GenerateAST> {

std::vector<llvm::BasicBlock *> rpo_walk;

z3::expr GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond);
z3::expr GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst,
// GetOrCreateEdgeForBranch(branch, true) will return the index of an
// expression that is true when branch is taken.
// Viceversa, GetOrCreateEdgeForBranch(branch, false) is an expression that
// will be true when branch is not taken
unsigned GetOrCreateEdgeForBranch(llvm::BranchInst *inst, bool cond);

// Returns the index of an expression containing a numerical variable that
// represents the condition of a switch.
unsigned GetOrCreateVarForSwitch(llvm::SwitchInst *inst);
// Returns the index of an expression that is true when a particular case of a
// switch is taken. If c is nullptr, the expression for the default case will
// be returned.
unsigned GetOrCreateEdgeForSwitch(llvm::SwitchInst *inst,
llvm::ConstantInt *c);

z3::expr GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to);
z3::expr GetReachingCond(llvm::BasicBlock *block);
unsigned GetOrCreateEdgeCond(llvm::BasicBlock *from, llvm::BasicBlock *to);
unsigned GetReachingCond(llvm::BasicBlock *block);
void CreateReachingCond(llvm::BasicBlock *block);

clang::Expr *ConvertExpr(z3::expr expr);

std::vector<clang::Stmt *> CreateBasicBlockStmts(llvm::BasicBlock *block);
std::vector<clang::Stmt *> CreateRegionStmts(llvm::Region *region);

Expand All @@ -82,12 +78,12 @@ class GenerateAST : public llvm::AnalysisInfoMixin<GenerateAST> {

public:
using Result = llvm::PreservedAnalyses;
GenerateAST(Provenance &provenance, clang::ASTUnit &unit);
GenerateAST(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

Result run(llvm::Function &F, llvm::FunctionAnalysisManager &FAM);
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM);

static void run(llvm::Module &M, Provenance &provenance,
static void run(llvm::Module &M, DecompilationContext &dec_ctx,
clang::ASTUnit &unit);
};

Expand Down
5 changes: 3 additions & 2 deletions include/rellic/AST/IRToASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ class IRToASTVisitor {

ASTBuilder ast;

Provenance &provenance;
DecompilationContext &dec_ctx;

void VisitArgument(llvm::Argument &arg);

public:
IRToASTVisitor(clang::ASTUnit &unit, Provenance &provenance);
IRToASTVisitor(clang::ASTUnit &unit, DecompilationContext &dec_ctx);

clang::Expr *CreateOperandExpr(llvm::Use &val);
clang::Expr *CreateConstantExpr(llvm::Constant *constant);
clang::Expr *ConvertExpr(z3::expr expr);

void VisitGlobalVar(llvm::GlobalVariable &var);
void VisitFunctionDecl(llvm::Function &func);
Expand Down
4 changes: 2 additions & 2 deletions include/rellic/AST/InferenceRule.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ class InferenceRule : public clang::ast_matchers::MatchFinder::MatchCallback {
return cond;
}

virtual clang::Stmt *GetOrCreateSubstitution(Provenance &provenance,
virtual clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx,
clang::ASTUnit &unit,
clang::Stmt *stmt) = 0;
};

clang::Stmt *ApplyFirstMatchingRule(
Provenance &provenance, clang::ASTUnit &unit, clang::Stmt *stmt,
DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt,
std::vector<std::unique_ptr<InferenceRule>> &rules);

} // namespace rellic
2 changes: 1 addition & 1 deletion include/rellic/AST/LocalDeclRenamer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class LocalDeclRenamer : public TransformVisitor<LocalDeclRenamer> {
void RunImpl() override;

public:
LocalDeclRenamer(Provenance &provenance, clang::ASTUnit &unit,
LocalDeclRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit,
IRToNameMap &names);

bool shouldTraversePostOrder() override;
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/LoopRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class LoopRefine : public TransformVisitor<LoopRefine> {
void RunImpl() override;

public:
LoopRefine(Provenance &provenance, clang::ASTUnit &unit);
LoopRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

bool VisitWhileStmt(clang::WhileStmt *loop);
};
Expand Down
37 changes: 37 additions & 0 deletions include/rellic/AST/MaterializeConds.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2021-present, Trail of Bits, Inc.
* All rights reserved.
*
* This source code is licensed in accordance with the terms specified in
* the LICENSE file found in the root directory of this source tree.
*/

#pragma once

#include <clang/AST/ASTContext.h>

#include "rellic/AST/IRToASTVisitor.h"
#include "rellic/AST/TransformVisitor.h"

namespace rellic {

/*
* This pass substitutes the marker expression in loops and `if` statements for
* their translation from Z3 formulas
*/
class MaterializeConds : public TransformVisitor<MaterializeConds> {
private:
IRToASTVisitor ast_gen;

protected:
void RunImpl() override;

public:
MaterializeConds(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

bool VisitIfStmt(clang::IfStmt *stmt);
bool VisitWhileStmt(clang::WhileStmt *loop);
bool VisitDoStmt(clang::DoStmt *loop);
};

} // namespace rellic
2 changes: 1 addition & 1 deletion include/rellic/AST/NestedCondProp.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class NestedCondProp : public ASTPass {
void RunImpl() override;

public:
NestedCondProp(Provenance &provenance, clang::ASTUnit &unit);
NestedCondProp(DecompilationContext& dec_ctx, clang::ASTUnit& unit);
};

} // namespace rellic
3 changes: 2 additions & 1 deletion include/rellic/AST/NestedScopeCombine.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ class NestedScopeCombine : public TransformVisitor<NestedScopeCombine> {
void RunImpl() override;

public:
NestedScopeCombine(Provenance &provenance, clang::ASTUnit &unit);
NestedScopeCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

bool VisitIfStmt(clang::IfStmt *ifstmt);
bool VisitWhileStmt(clang::WhileStmt *stmt);
bool VisitCompoundStmt(clang::CompoundStmt *compound);
};

Expand Down
33 changes: 0 additions & 33 deletions include/rellic/AST/NormalizeCond.h

This file was deleted.

8 changes: 1 addition & 7 deletions include/rellic/AST/ReachBasedRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#pragma once

#include "rellic/AST/TransformVisitor.h"
#include "rellic/AST/Z3ConvVisitor.h"

namespace rellic {

Expand Down Expand Up @@ -38,16 +37,11 @@ namespace rellic {
*/
class ReachBasedRefine : public TransformVisitor<ReachBasedRefine> {
private:
std::unique_ptr<z3::context> z3_ctx;
std::unique_ptr<rellic::Z3ConvVisitor> z3_gen;

z3::expr GetZ3Cond(clang::IfStmt *ifstmt);

protected:
void RunImpl() override;

public:
ReachBasedRefine(Provenance &provenance, clang::ASTUnit &unit);
ReachBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);

bool VisitCompoundStmt(clang::CompoundStmt *compound);
};
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/StructFieldRenamer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class StructFieldRenamer
void RunImpl() override;

public:
StructFieldRenamer(Provenance &provenance, clang::ASTUnit &unit,
StructFieldRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit,
IRTypeToDITypeMap &types);

bool VisitRecordDecl(clang::RecordDecl *decl);
Expand Down
8 changes: 4 additions & 4 deletions include/rellic/AST/TransformVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class TransformVisitor : public ASTPass,
StmtSubMap substitutions;

void CopyProvenance(clang::Stmt *from, clang::Stmt *to) {
::rellic::CopyProvenance(from, to, provenance.stmt_provenance);
::rellic::CopyProvenance(from, to, dec_ctx.stmt_provenance);
}

void CopyProvenance(clang::Expr *from, clang::Expr *to) {
::rellic::CopyProvenance(from, to, provenance.use_provenance);
::rellic::CopyProvenance(from, to, dec_ctx.use_provenance);
}

bool ReplaceChildren(clang::Stmt *stmt, StmtSubMap &repl_map) {
Expand All @@ -55,8 +55,8 @@ class TransformVisitor : public ASTPass,
void RunImpl() override { substitutions.clear(); }

public:
TransformVisitor(Provenance &provenance, clang::ASTUnit &unit)
: ASTPass(provenance, unit) {}
TransformVisitor(DecompilationContext &dec_ctx, clang::ASTUnit &unit)
: ASTPass(dec_ctx, unit) {}

virtual bool shouldTraversePostOrder() { return true; }

Expand Down
Loading

0 comments on commit fb4a663

Please sign in to comment.