Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring ast_unit, ast_ctx, ast into dec_ctx #303

Merged
merged 4 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions include/rellic/AST/ASTPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,15 @@ class ASTPass {

protected:
DecompilationContext& dec_ctx;
clang::ASTUnit& ast_unit;
clang::ASTContext& ast_ctx;
ASTBuilder ast;

bool changed{false};

virtual void RunImpl() = 0;
virtual void StopImpl() {}

public:
ASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit)
: dec_ctx(dec_ctx),
ast_unit(ast_unit),
ast_ctx(ast_unit.getASTContext()),
ast(ast_unit) {}
ASTPass(DecompilationContext& dec_ctx)
: dec_ctx(dec_ctx) {}
virtual ~ASTPass() = default;
void Stop() {
stop = true;
Expand Down Expand Up @@ -89,8 +83,8 @@ class CompositeASTPass : public ASTPass {
}

public:
CompositeASTPass(DecompilationContext& dec_ctx, clang::ASTUnit& ast_unit)
: ASTPass(dec_ctx, ast_unit) {}
CompositeASTPass(DecompilationContext& dec_ctx)
: ASTPass(dec_ctx) {}
std::vector<std::unique_ptr<ASTPass>>& GetPasses() { return passes; }
};
} // namespace rellic
2 changes: 1 addition & 1 deletion include/rellic/AST/CondBasedRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CondBasedRefine : public TransformVisitor<CondBasedRefine> {
void RunImpl() override;

public:
CondBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
CondBasedRefine(DecompilationContext &dec_ctx);

bool VisitCompoundStmt(clang::CompoundStmt *compound);
};
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/DeadStmtElim.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DeadStmtElim : public TransformVisitor<DeadStmtElim> {
void RunImpl() override;

public:
DeadStmtElim(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
DeadStmtElim(DecompilationContext &dec_ctx);

bool VisitIfStmt(clang::IfStmt *ifstmt);
bool VisitCompoundStmt(clang::CompoundStmt *compound);
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/ExprCombine.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ExprCombine : public TransformVisitor<ExprCombine> {
void RunImpl() override;

public:
ExprCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
ExprCombine(DecompilationContext &dec_ctx);

bool VisitCStyleCastExpr(clang::CStyleCastExpr *cast);
bool VisitUnaryOperator(clang::UnaryOperator *op);
Expand Down
10 changes: 4 additions & 6 deletions include/rellic/AST/GenerateAST.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <map>
#include <unordered_set>

#include "rellic/AST/ASTBuilder.h"
#include "rellic/AST/IRToASTVisitor.h"

namespace rellic {
Expand All @@ -29,11 +30,9 @@ class GenerateAST : public llvm::AnalysisInfoMixin<GenerateAST> {
constexpr static unsigned poison_idx = std::numeric_limits<unsigned>::max();
z3::expr ToExpr(unsigned idx);

clang::ASTUnit &unit;
clang::ASTContext *ast_ctx;
rellic::IRToASTVisitor ast_gen;
rellic::ASTBuilder ast;
DecompilationContext &dec_ctx;
ASTBuilder &ast;
bool reaching_conds_changed{true};
std::unordered_map<llvm::BasicBlock *, clang::IfStmt *> block_stmts;
std::unordered_map<llvm::Region *, clang::CompoundStmt *> region_stmts;
Expand Down Expand Up @@ -78,13 +77,12 @@ class GenerateAST : public llvm::AnalysisInfoMixin<GenerateAST> {

public:
using Result = llvm::PreservedAnalyses;
GenerateAST(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
GenerateAST(DecompilationContext &dec_ctx);

Result run(llvm::Function &F, llvm::FunctionAnalysisManager &FAM);
Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM);

static void run(llvm::Module &M, DecompilationContext &dec_ctx,
clang::ASTUnit &unit);
static void run(llvm::Module &M, DecompilationContext &dec_ctx);
};

} // namespace rellic
8 changes: 2 additions & 6 deletions include/rellic/AST/IRToASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@
namespace rellic {
class IRToASTVisitor {
private:
clang::ASTUnit &ast_unit;
clang::ASTContext &ast_ctx;

ASTBuilder ast;

DecompilationContext &dec_ctx;
ASTBuilder &ast;

void VisitArgument(llvm::Argument &arg);

public:
IRToASTVisitor(clang::ASTUnit &unit, DecompilationContext &dec_ctx);
IRToASTVisitor(DecompilationContext &dec_ctx);

clang::Expr *CreateOperandExpr(llvm::Use &val);
clang::Expr *CreateConstantExpr(llvm::Constant *constant);
Expand Down
3 changes: 1 addition & 2 deletions include/rellic/AST/InferenceRule.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ class InferenceRule : public clang::ast_matchers::MatchFinder::MatchCallback {
}

virtual clang::Stmt *GetOrCreateSubstitution(DecompilationContext &dec_ctx,
clang::ASTUnit &unit,
clang::Stmt *stmt) = 0;
};

clang::Stmt *ApplyFirstMatchingRule(
DecompilationContext &dec_ctx, clang::ASTUnit &unit, clang::Stmt *stmt,
DecompilationContext &dec_ctx, clang::Stmt *stmt,
std::vector<std::unique_ptr<InferenceRule>> &rules);

} // namespace rellic
3 changes: 1 addition & 2 deletions include/rellic/AST/LocalDeclRenamer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class LocalDeclRenamer : public TransformVisitor<LocalDeclRenamer> {
void RunImpl() override;

public:
LocalDeclRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit,
IRToNameMap &names);
LocalDeclRenamer(DecompilationContext &dec_ctx, IRToNameMap &names);

bool shouldTraversePostOrder() override;
bool VisitVarDecl(clang::VarDecl *decl);
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/LoopRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class LoopRefine : public TransformVisitor<LoopRefine> {
void RunImpl() override;

public:
LoopRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
LoopRefine(DecompilationContext &dec_ctx);

bool VisitWhileStmt(clang::WhileStmt *loop);
};
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/MaterializeConds.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MaterializeConds : public TransformVisitor<MaterializeConds> {
void RunImpl() override;

public:
MaterializeConds(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
MaterializeConds(DecompilationContext &dec_ctx);

bool VisitIfStmt(clang::IfStmt *stmt);
bool VisitWhileStmt(clang::WhileStmt *loop);
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/NestedCondProp.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class NestedCondProp : public ASTPass {
void RunImpl() override;

public:
NestedCondProp(DecompilationContext& dec_ctx, clang::ASTUnit& unit);
NestedCondProp(DecompilationContext& dec_ctx);
};

} // namespace rellic
2 changes: 1 addition & 1 deletion include/rellic/AST/NestedScopeCombine.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class NestedScopeCombine : public TransformVisitor<NestedScopeCombine> {
void RunImpl() override;

public:
NestedScopeCombine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
NestedScopeCombine(DecompilationContext &dec_ctx);

bool VisitIfStmt(clang::IfStmt *ifstmt);
bool VisitWhileStmt(clang::WhileStmt *stmt);
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/ReachBasedRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ReachBasedRefine : public TransformVisitor<ReachBasedRefine> {
void RunImpl() override;

public:
ReachBasedRefine(DecompilationContext &dec_ctx, clang::ASTUnit &unit);
ReachBasedRefine(DecompilationContext &dec_ctx);

bool VisitCompoundStmt(clang::CompoundStmt *compound);
};
Expand Down
3 changes: 1 addition & 2 deletions include/rellic/AST/StructFieldRenamer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class StructFieldRenamer
void RunImpl() override;

public:
StructFieldRenamer(DecompilationContext &dec_ctx, clang::ASTUnit &unit,
IRTypeToDITypeMap &types);
StructFieldRenamer(DecompilationContext &dec_ctx, IRTypeToDITypeMap &types);

bool VisitRecordDecl(clang::RecordDecl *decl);
};
Expand Down
3 changes: 1 addition & 2 deletions include/rellic/AST/TransformVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ class TransformVisitor : public ASTPass,
void RunImpl() override { substitutions.clear(); }

public:
TransformVisitor(DecompilationContext &dec_ctx, clang::ASTUnit &unit)
: ASTPass(dec_ctx, unit) {}
TransformVisitor(DecompilationContext &dec_ctx) : ASTPass(dec_ctx) {}

virtual bool shouldTraversePostOrder() { return true; }

Expand Down
7 changes: 7 additions & 0 deletions include/rellic/AST/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <clang/AST/ASTContext.h>
#include <clang/AST/DeclBase.h>
#include <clang/AST/Stmt.h>
#include <clang/Frontend/ASTUnit.h>
Expand Down Expand Up @@ -66,6 +67,12 @@ struct DecompilationContext {
using BrEdge = std::pair<llvm::BranchInst *, bool>;
using SwEdge = std::pair<llvm::SwitchInst *, llvm::ConstantInt *>;

DecompilationContext(clang::ASTUnit &ast_unit);

clang::ASTUnit &ast_unit;
clang::ASTContext &ast_ctx;
ASTBuilder ast;

StmtToIRMap stmt_provenance;
ExprToUseMap use_provenance;
IRToTypeDeclMap type_decls;
Expand Down
2 changes: 1 addition & 1 deletion include/rellic/AST/Z3CondSimplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Z3CondSimplify : public ASTPass {
void RunImpl() override;

public:
Z3CondSimplify(DecompilationContext& dec_ctx, clang::ASTUnit& unit);
Z3CondSimplify(DecompilationContext& dec_ctx);
};

} // namespace rellic
21 changes: 10 additions & 11 deletions lib/AST/CondBasedRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

namespace rellic {

CondBasedRefine::CondBasedRefine(DecompilationContext &dec_ctx,
clang::ASTUnit &unit)
: TransformVisitor<CondBasedRefine>(dec_ctx, unit) {}
CondBasedRefine::CondBasedRefine(DecompilationContext &dec_ctx)
: TransformVisitor<CondBasedRefine>(dec_ctx) {}

bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) {
std::vector<clang::Stmt *> body{compound->body_begin(), compound->body_end()};
Expand Down Expand Up @@ -52,9 +51,9 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) {
// becomes
// if(a) { X1; X2; } else { Y1; Y2; }
new_then_body.push_back(then_b);
auto new_then{ast.CreateCompoundStmt(new_then_body)};
auto new_then{dec_ctx.ast.CreateCompoundStmt(new_then_body)};

new_if = ast.CreateIf(dec_ctx.marker_expr, new_then);
new_if = dec_ctx.ast.CreateIf(dec_ctx.marker_expr, new_then);

if (else_a || else_b) {
// At least one of the two `if` statements has an `else` branch
Expand All @@ -68,7 +67,7 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) {
new_else_body.push_back(else_b);
}

auto new_else{ast.CreateCompoundStmt(new_else_body)};
auto new_else{dec_ctx.ast.CreateCompoundStmt(new_else_body)};
new_if->setElse(new_else);
}

Expand All @@ -86,17 +85,17 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) {
new_then_body.push_back(else_b);
}

auto new_then{ast.CreateCompoundStmt(new_then_body)};
auto new_then{dec_ctx.ast.CreateCompoundStmt(new_then_body)};

std::vector<clang::Stmt *> new_else_body{};
if (else_a) {
new_else_body.push_back(else_a);
}
new_else_body.push_back(then_b);

new_if = ast.CreateIf(dec_ctx.marker_expr, new_then);
new_if = dec_ctx.ast.CreateIf(dec_ctx.marker_expr, new_then);

auto new_else{ast.CreateCompoundStmt(new_else_body)};
auto new_else{dec_ctx.ast.CreateCompoundStmt(new_else_body)};
new_if->setElse(new_else);

did_something = true;
Expand All @@ -109,15 +108,15 @@ bool CondBasedRefine::VisitCompoundStmt(clang::CompoundStmt *compound) {
}
}
if (did_something) {
substitutions[compound] = ast.CreateCompoundStmt(body);
substitutions[compound] = dec_ctx.ast.CreateCompoundStmt(body);
}
return !Stopped();
}

void CondBasedRefine::RunImpl() {
LOG(INFO) << "Condition-based refinement";
TransformVisitor<CondBasedRefine>::RunImpl();
TraverseDecl(ast_ctx.getTranslationUnitDecl());
TraverseDecl(dec_ctx.ast_ctx.getTranslationUnitDecl());
}

} // namespace rellic
10 changes: 5 additions & 5 deletions lib/AST/DeadStmtElim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

namespace rellic {

DeadStmtElim::DeadStmtElim(DecompilationContext &dec_ctx, clang::ASTUnit &unit)
: TransformVisitor<DeadStmtElim>(dec_ctx, unit) {}
DeadStmtElim::DeadStmtElim(DecompilationContext &dec_ctx)
: TransformVisitor<DeadStmtElim>(dec_ctx) {}

bool DeadStmtElim::VisitIfStmt(clang::IfStmt *ifstmt) {
// DLOG(INFO) << "VisitIfStmt";
Expand All @@ -40,7 +40,7 @@ bool DeadStmtElim::VisitCompoundStmt(clang::CompoundStmt *compound) {
}
// Add only necessary statements
if (auto expr = clang::dyn_cast<clang::Expr>(stmt)) {
if (expr->HasSideEffects(ast_ctx)) {
if (expr->HasSideEffects(dec_ctx.ast_ctx)) {
new_body.push_back(stmt);
}
} else if (!clang::dyn_cast<clang::NullStmt>(stmt)) {
Expand All @@ -49,15 +49,15 @@ bool DeadStmtElim::VisitCompoundStmt(clang::CompoundStmt *compound) {
}
// Create the a new compound
if (changed || new_body.size() < compound->size()) {
substitutions[compound] = ast.CreateCompoundStmt(new_body);
substitutions[compound] = dec_ctx.ast.CreateCompoundStmt(new_body);
}
return !Stopped();
}

void DeadStmtElim::RunImpl() {
LOG(INFO) << "Eliminating dead statements";
TransformVisitor<DeadStmtElim>::RunImpl();
TraverseDecl(ast_ctx.getTranslationUnitDecl());
TraverseDecl(dec_ctx.ast_ctx.getTranslationUnitDecl());
}

} // namespace rellic
Loading