diff --git a/include/rellic/AST/DecompilationContext.h b/include/rellic/AST/DecompilationContext.h new file mode 100644 index 00000000..b9ba9417 --- /dev/null +++ b/include/rellic/AST/DecompilationContext.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "rellic/AST/ASTBuilder.h" +#include "rellic/AST/TypeProvider.h" + +namespace rellic { + +struct DecompilationContext { + using StmtToIRMap = std::unordered_map; + using ExprToUseMap = std::unordered_map; + using IRToTypeDeclMap = std::unordered_map; + using IRToValDeclMap = std::unordered_map; + using IRToStmtMap = std::unordered_map; + using ArgToTempMap = std::unordered_map; + using BlockToUsesMap = + std::unordered_map>; + using Z3CondMap = std::unordered_map; + + using BBEdge = std::pair; + using BrEdge = std::pair; + using SwEdge = std::pair; + + DecompilationContext(clang::ASTUnit &ast_unit); + + clang::ASTUnit &ast_unit; + clang::ASTContext &ast_ctx; + ASTBuilder ast; + + std::unique_ptr type_provider; + + StmtToIRMap stmt_provenance; + ExprToUseMap use_provenance; + IRToTypeDeclMap type_decls; + IRToValDeclMap value_decls; + ArgToTempMap temp_decls; + BlockToUsesMap outgoing_uses; + z3::context z3_ctx; + z3::expr_vector z3_exprs{z3_ctx}; + Z3CondMap conds; + + clang::Expr *marker_expr; + + std::unordered_map z3_br_edges_inv; + + // Pairs do not have a std::hash specialization so we can't use unordered maps + // here. If this turns out to be a performance issue, investigate adding hash + // specializations for these specifically + std::map z3_br_edges; + + std::unordered_map z3_sw_vars; + std::unordered_map z3_sw_vars_inv; + std::map z3_sw_edges; + + std::map z3_edges; + std::unordered_map reaching_conds; + + size_t num_literal_structs = 0; + size_t num_declared_structs = 0; + + // Inserts an expression into z3_exprs and returns its index + unsigned InsertZExpr(const z3::expr &e); + + clang::QualType GetQualType(llvm::Type *type); +}; + +} // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/TypeProvider.h b/include/rellic/AST/TypeProvider.h new file mode 100644 index 00000000..e3137d44 --- /dev/null +++ b/include/rellic/AST/TypeProvider.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include + +#include "rellic/AST/ASTBuilder.h" + +namespace rellic { +struct DecompilationContext; + +class TypeProvider { + protected: + DecompilationContext& dec_ctx; + + public: + TypeProvider(DecompilationContext& dec_ctx); + virtual ~TypeProvider(); + + // Returns the return type of a function if available. + // A null return value is assumed to mean that no info is available. + virtual clang::QualType GetFunctionReturnType(llvm::Function& func); + + // Returns the type of the argument if available. + // A null return value is assumed to mean that no info is available. + virtual clang::QualType GetArgumentType(llvm::Argument& arg); + + // Returns the type of a global variable if available. + // A null return value is assumed to mean that no info is available. + virtual clang::QualType GetGlobalVarType(llvm::GlobalVariable& gvar); +}; + +class TypeProviderCombiner : public TypeProvider { + private: + std::vector> providers; + + public: + TypeProviderCombiner(DecompilationContext& dec_ctx); + template + void AddProvider(TArgs&&... args) { + providers.push_back( + std::make_unique(dec_ctx, std::forward(args)...)); + } + + void AddProvider(std::unique_ptr provider); + + clang::QualType GetFunctionReturnType(llvm::Function& func) override; + clang::QualType GetArgumentType(llvm::Argument& arg) override; + clang::QualType GetGlobalVarType(llvm::GlobalVariable& gvar) override; +}; +} // namespace rellic \ No newline at end of file diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index 2971ad51..e42363e9 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -8,17 +8,7 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "rellic/AST/ASTBuilder.h" +#include "rellic/AST/DecompilationContext.h" namespace rellic { @@ -52,60 +42,6 @@ size_t GetNumDecls(clang::DeclContext *decl_ctx) { return result; } -struct DecompilationContext { - using StmtToIRMap = std::unordered_map; - using ExprToUseMap = std::unordered_map; - using IRToTypeDeclMap = std::unordered_map; - using IRToValDeclMap = std::unordered_map; - using IRToStmtMap = std::unordered_map; - using ArgToTempMap = std::unordered_map; - using BlockToUsesMap = - std::unordered_map>; - using Z3CondMap = std::unordered_map; - - using BBEdge = std::pair; - using BrEdge = std::pair; - using SwEdge = std::pair; - - DecompilationContext(clang::ASTUnit &ast_unit); - - clang::ASTUnit &ast_unit; - clang::ASTContext &ast_ctx; - ASTBuilder ast; - - StmtToIRMap stmt_provenance; - ExprToUseMap use_provenance; - IRToTypeDeclMap type_decls; - IRToValDeclMap value_decls; - ArgToTempMap temp_decls; - BlockToUsesMap outgoing_uses; - z3::context z3_ctx; - z3::expr_vector z3_exprs{z3_ctx}; - Z3CondMap conds; - - clang::Expr *marker_expr; - - std::unordered_map z3_br_edges_inv; - - // Pairs do not have a std::hash specialization so we can't use unordered maps - // here. If this turns out to be a performance issue, investigate adding hash - // specializations for these specifically - std::map z3_br_edges; - - std::unordered_map z3_sw_vars; - std::unordered_map z3_sw_vars_inv; - std::map z3_sw_edges; - - std::map z3_edges; - std::unordered_map reaching_conds; - - size_t num_literal_structs = 0; - size_t num_declared_structs = 0; - - // Inserts an expression into z3_exprs and returns its index - unsigned InsertZExpr(const z3::expr &e); -}; - template void CopyProvenance(TKey1 *from, TKey2 *to, std::unordered_map &map) { diff --git a/include/rellic/Decompiler.h b/include/rellic/Decompiler.h index a62fc659..bfee1e43 100644 --- a/include/rellic/Decompiler.h +++ b/include/rellic/Decompiler.h @@ -13,13 +13,39 @@ #include #include +#include #include "Result.h" +#include "rellic/AST/TypeProvider.h" namespace rellic { + +/* This additional level of indirection is needed to alleviate the users from + * the burden of having to instantiate custom TypeProviders before the actual + * DecompilationContext has been created */ +class TypeProviderFactory { + public: + virtual ~TypeProviderFactory() = default; + virtual std::unique_ptr create(DecompilationContext& ctx) = 0; +}; + +template +class SimpleTypeProviderFactory final : public TypeProviderFactory { + public: + std::unique_ptr create(DecompilationContext& ctx) override { + return std::make_unique(ctx); + } +}; + struct DecompilationOptions { + using TypeProviderFactoryPtr = std::unique_ptr; + bool lower_switches = false; bool remove_phi_nodes = false; + + // Additional type providers to be used during code generation. + // Providers added later will have higher priority. + std::vector additional_providers; }; struct DecompilationResult { diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index cd232f47..8e033d61 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -20,6 +20,7 @@ #include #include "rellic/AST/IRToASTVisitor.h" +#include "rellic/AST/TypeProvider.h" #include "rellic/BC/Util.h" #include "rellic/Exception.h" @@ -37,7 +38,6 @@ class ExprGen : public llvm::InstVisitor { : dec_ctx(dec_ctx), ast(dec_ctx.ast), ast_ctx(dec_ctx.ast_ctx) {} void VisitGlobalVar(llvm::GlobalVariable &gvar); - clang::QualType GetQualType(llvm::Type *type); clang::Expr *CreateConstantExpr(llvm::Constant *constant); clang::Expr *CreateLiteralExpr(llvm::Constant *constant); @@ -159,7 +159,7 @@ void ExprGen::VisitGlobalVar(llvm::GlobalVariable &gvar) { return; } - auto type{gvar.getValueType()}; + auto type{dec_ctx.type_provider->GetGlobalVarType(gvar)}; auto tudecl{ast_ctx.getTranslationUnitDecl()}; auto name{gvar.getName().str()}; if (name.empty()) { @@ -167,7 +167,7 @@ void ExprGen::VisitGlobalVar(llvm::GlobalVariable &gvar) { } // Create a variable declaration - var = ast.CreateVarDecl(tudecl, GetQualType(type), name); + var = ast.CreateVarDecl(tudecl, type, name); // Add to translation unit tudecl->addDecl(var); @@ -181,128 +181,6 @@ void ExprGen::VisitGlobalVar(llvm::GlobalVariable &gvar) { } } -clang::QualType ExprGen::GetQualType(llvm::Type *type) { - DLOG(INFO) << "GetQualType: " << LLVMThingToString(type); - - clang::QualType result; - switch (type->getTypeID()) { - case llvm::Type::VoidTyID: - result = ast_ctx.VoidTy; - break; - - case llvm::Type::HalfTyID: - result = ast_ctx.HalfTy; - break; - - case llvm::Type::FloatTyID: - result = ast_ctx.FloatTy; - break; - - case llvm::Type::DoubleTyID: - result = ast_ctx.DoubleTy; - break; - - case llvm::Type::X86_FP80TyID: - result = ast_ctx.LongDoubleTy; - break; - - case llvm::Type::FP128TyID: - result = ast_ctx.Float128Ty; - break; - - case llvm::Type::IntegerTyID: { - auto size{type->getIntegerBitWidth()}; - CHECK(size > 0) << "Integer bit width has to be greater than 0"; - result = ast.GetLeastIntTypeForBitWidth(size, /*sign=*/0); - } break; - - case llvm::Type::FunctionTyID: { - auto func{llvm::cast(type)}; - auto ret{GetQualType(func->getReturnType())}; - std::vector params; - for (auto param : func->params()) { - params.push_back(GetQualType(param)); - } - auto epi{clang::FunctionProtoType::ExtProtoInfo()}; - epi.Variadic = func->isVarArg(); - result = ast_ctx.getFunctionType(ret, params, epi); - } break; - - case llvm::Type::PointerTyID: { - auto ptr_type{llvm::cast(type)}; - if (ptr_type->isOpaque()) { - result = ast_ctx.VoidPtrTy; - } else { - result = ast_ctx.getPointerType( - GetQualType(ptr_type->getNonOpaquePointerElementType())); - } - } break; - - case llvm::Type::ArrayTyID: { - auto arr{llvm::cast(type)}; - auto elm{GetQualType(arr->getElementType())}; - result = ast_ctx.getConstantArrayType( - elm, llvm::APInt(64, arr->getNumElements()), nullptr, - clang::ArrayType::ArraySizeModifier::Normal, 0); - } break; - - case llvm::Type::StructTyID: { - clang::RecordDecl *sdecl{nullptr}; - auto &decl{dec_ctx.type_decls[type]}; - if (!decl) { - auto tudecl{ast_ctx.getTranslationUnitDecl()}; - auto strct{llvm::cast(type)}; - auto sname{strct->isLiteral() - ? ("literal_struct_" + - std::to_string(dec_ctx.num_literal_structs++)) - : strct->getName().str()}; - if (sname.empty()) { - sname = "struct" + std::to_string(dec_ctx.num_declared_structs++); - } - - // Create a C struct declaration - decl = sdecl = ast.CreateStructDecl(tudecl, sname); - - // Add fields to the C struct - for (auto ecnt{0U}; ecnt < strct->getNumElements(); ++ecnt) { - auto etype{GetQualType(strct->getElementType(ecnt))}; - auto fname{"field" + std::to_string(ecnt)}; - sdecl->addDecl(ast.CreateFieldDecl(sdecl, etype, fname)); - } - - // Complete the C struct definition - sdecl->completeDefinition(); - // Add C struct to translation unit - tudecl->addDecl(sdecl); - - } else { - sdecl = clang::cast(decl); - } - result = ast_ctx.getRecordType(sdecl); - } break; - - case llvm::Type::MetadataTyID: - result = ast_ctx.VoidPtrTy; - break; - - default: { - if (type->isVectorTy()) { - auto vtype{llvm::cast(type)}; - auto etype{GetQualType(vtype->getElementType())}; - auto ecnt{vtype->getNumElements()}; - auto vkind{clang::VectorType::GenericVector}; - result = ast_ctx.getVectorType(etype, ecnt, vkind); - } else { - THROW() << "Unknown LLVM Type: " << LLVMThingToString(type); - } - } break; - } - - CHECK_THROW(!result.isNull()) << "Unknown LLVM Type"; - - return result; -} - clang::Expr *ExprGen::CreateConstantExpr(llvm::Constant *constant) { if (auto gvar = llvm::dyn_cast(constant)) { VisitGlobalVar(*gvar); @@ -330,7 +208,7 @@ clang::Expr *ExprGen::CreateLiteralExpr(llvm::Constant *constant) { clang::Expr *result{nullptr}; auto l_type{constant->getType()}; - auto c_type{GetQualType(l_type)}; + auto c_type{dec_ctx.GetQualType(l_type)}; auto CreateInitListLiteral{[this, &constant] { std::vector init_exprs; @@ -457,7 +335,7 @@ clang::Expr *ExprGen::CreateOperandExpr(llvm::Use &val) { auto func{arg->getParent()}; auto fdecl{dec_ctx.value_decls[func]->getAsFunction()}; auto argdecl{clang::cast(dec_ctx.value_decls[arg])}; - temp = ast.CreateVarDecl(fdecl, GetQualType(arg->getType()), + temp = ast.CreateVarDecl(fdecl, dec_ctx.GetQualType(arg->getType()), argdecl->getName().str() + "_ptr"); temp->setInit(addr_of_arg); fdecl->addDecl(temp); @@ -576,8 +454,8 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { auto &arg{inst.getArgOperandUse(i)}; auto opnd{CreateOperandExpr(arg)}; if (inst.getParamAttr(i, llvm::Attribute::ByVal).isValid()) { - auto ptr_type{ - ast_ctx.getPointerType(GetQualType(inst.getParamByValType(i)))}; + auto ptr_type{ast_ctx.getPointerType( + dec_ctx.GetQualType(inst.getParamByValType(i)))}; opnd = ast.CreateDeref(ast.CreateCStyleCast(ptr_type, opnd)); dec_ctx.use_provenance[opnd] = &arg; } @@ -592,7 +470,8 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { callexpr = ast.CreateCall(fdecl, args); } else { // Cast function type to match the one used in the call instruction - auto funcPtr{ast_ctx.getPointerType(GetQualType(inst.getFunctionType()))}; + auto funcPtr{ + ast_ctx.getPointerType(dec_ctx.GetQualType(inst.getFunctionType()))}; auto callee{ast.CreateAddrOf(ast.CreateDeclRef(fdecl))}; auto cast{ast.CreateCStyleCast(funcPtr, callee)}; callexpr = ast.CreateCall(cast, args); @@ -601,7 +480,8 @@ clang::Expr *ExprGen::visitCallInst(llvm::CallInst &inst) { auto fdecl{dec_ctx.value_decls[iasm]->getAsFunction()}; callexpr = ast.CreateCall(fdecl, args); } else if (llvm::isa(callee->getType())) { - auto funcPtr{ast_ctx.getPointerType(GetQualType(inst.getFunctionType()))}; + auto funcPtr{ + ast_ctx.getPointerType(dec_ctx.GetQualType(inst.getFunctionType()))}; auto cast{ast.CreateCStyleCast(funcPtr, CreateOperandExpr(callee))}; callexpr = ast.CreateCall(cast, args); } else { @@ -677,11 +557,11 @@ clang::Expr *ExprGen::visitGetElementPtrInst(llvm::GetElementPtrInst &inst) { auto base{CreateOperandExpr(ptr_opnd)}; auto ptr_type{ - ast_ctx.getPointerType(GetQualType(inst.getSourceElementType()))}; + ast_ctx.getPointerType(dec_ctx.GetQualType(inst.getSourceElementType()))}; base = ast.CreateCStyleCast(ptr_type, base); for (auto &idx : llvm::make_range(inst.idx_begin(), inst.idx_end())) { - GetQualType(indexed_type); + dec_ctx.GetQualType(indexed_type); switch (indexed_type->getTypeID()) { // Initial pointer case llvm::Type::PointerTyID: { @@ -718,7 +598,7 @@ clang::Expr *ExprGen::visitGetElementPtrInst(llvm::GetElementPtrInst &inst) { if (indexed_type->isVectorTy()) { auto l_vec_ty{llvm::cast(indexed_type)}; auto l_elm_ty{l_vec_ty->getElementType()}; - auto c_elm_ty{GetQualType(l_elm_ty)}; + auto c_elm_ty{dec_ctx.GetQualType(l_elm_ty)}; base = ast.CreateCStyleCast(ast_ctx.getPointerType(c_elm_ty), ast.CreateAddrOf(base)); base = ast.CreateArraySub(base, CreateOperandExpr(idx)); @@ -745,11 +625,11 @@ clang::Expr *ExprGen::visitExtractValueInst(llvm::ExtractValueInst &inst) { auto base{CreateOperandExpr(inst.getOperandUse(0))}; auto indexed_type{inst.getAggregateOperand()->getType()}; if (clang::isa(base)) { - base = ast.CreateCompoundLit(GetQualType(indexed_type), base); + base = ast.CreateCompoundLit(dec_ctx.GetQualType(indexed_type), base); } for (auto idx : llvm::make_range(inst.idx_begin(), inst.idx_end())) { - GetQualType(indexed_type); + dec_ctx.GetQualType(indexed_type); switch (indexed_type->getTypeID()) { // Arrays case llvm::Type::ArrayTyID: { @@ -783,7 +663,7 @@ clang::Expr *ExprGen::visitExtractValueInst(llvm::ExtractValueInst &inst) { clang::Expr *ExprGen::visitLoadInst(llvm::LoadInst &inst) { DLOG(INFO) << "visitLoadInst: " << LLVMThingToString(&inst); - auto ptr_type{ast_ctx.getPointerType(GetQualType(inst.getType()))}; + auto ptr_type{ast_ctx.getPointerType(dec_ctx.GetQualType(inst.getType()))}; auto cast{ ast.CreateCStyleCast(ptr_type, CreateOperandExpr(inst.getOperandUse(0)))}; return ast.CreateDeref(cast); @@ -984,7 +864,7 @@ clang::Expr *ExprGen::visitCastInst(llvm::CastInst &inst) { // Get a C-language expression of the operand auto operand{CreateOperandExpr(inst.getOperandUse(0))}; // Get destination type - auto type{GetQualType(inst.getType())}; + auto type{dec_ctx.GetQualType(inst.getType())}; // Adjust type switch (inst.getOpcode()) { case llvm::CastInst::Trunc: { @@ -1099,7 +979,7 @@ clang::Stmt *StmtGen::visitStoreInst(llvm::StoreInst &inst) { // Stores in LLVM IR correspond to value assignments in C // Get the operand we're assigning to auto ptr_type{ - ast_ctx.getPointerType(expr_gen.GetQualType(value_opnd->getType()))}; + ast_ctx.getPointerType(dec_ctx.GetQualType(value_opnd->getType()))}; auto lhs{ast.CreateCStyleCast( ptr_type, expr_gen.CreateOperandExpr( inst.getOperandUse(inst.getPointerOperandIndex())))}; @@ -1200,40 +1080,9 @@ void IRToASTVisitor::VisitArgument(llvm::Argument &arg) { // Get parent function declaration auto func{arg.getParent()}; auto fdecl{clang::cast(dec_ctx.value_decls[func])}; - auto argtype{arg.getType()}; - if (arg.hasByValAttr()) { - auto byval{arg.getAttribute(llvm::Attribute::ByVal)}; - argtype = byval.getValueAsType(); - } - ExprGen expr_gen{dec_ctx}; + auto argtype{dec_ctx.type_provider->GetArgumentType(arg)}; // Create a declaration - parm = ast.CreateParamDecl(fdecl, expr_gen.GetQualType(argtype), name); -} - -// This function fixes function types for those functions that have arguments -// that are passed by value using the `byval` attribute. -// They need special treatment because those arguments, instead of actually -// being passed by value, are instead passed "by reference" from a bitcode point -// of view, with the caveat that the actual semantics are more like "create a -// copy of the reference before calling, and pass a pointer to that copy -// instead" (this is done implicitly). -// Thus, we need to convert a function type like -// i32 @do_foo(%struct.foo* byval(%struct.foo) align 4 %f) -// into -// i32 @do_foo(%struct.foo %f) -static llvm::FunctionType *GetFixedFunctionType(llvm::Function &func) { - std::vector new_arg_types{}; - - for (auto &arg : func.args()) { - if (arg.hasByValAttr()) { - new_arg_types.push_back(arg.getParamByValType()); - } else { - new_arg_types.push_back(arg.getType()); - } - } - - return llvm::FunctionType::get(func.getReturnType(), new_arg_types, - func.isVarArg()); + parm = ast.CreateParamDecl(fdecl, argtype, name); } void IRToASTVisitor::VisitBasicBlock(llvm::BasicBlock &block, @@ -1273,9 +1122,16 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { DLOG(INFO) << "Creating FunctionDecl for " << name; auto tudecl{dec_ctx.ast_ctx.getTranslationUnitDecl()}; - ExprGen expr_gen{dec_ctx}; - auto type{expr_gen.GetQualType(GetFixedFunctionType(func))}; - decl = ast.CreateFunctionDecl(tudecl, type, name); + + std::vector arg_types; + for (auto &arg : func.args()) { + arg_types.push_back(dec_ctx.type_provider->GetArgumentType(arg)); + } + auto ret_type{dec_ctx.type_provider->GetFunctionReturnType(func)}; + clang::FunctionProtoType::ExtProtoInfo epi; + epi.Variadic = func.isVarArg(); + auto ftype{dec_ctx.ast_ctx.getFunctionType(ret_type, arg_types, epi)}; + decl = ast.CreateFunctionDecl(tudecl, ftype, name); tudecl->addDecl(decl); @@ -1312,7 +1168,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { // storage for parameters e.g. a parameter named "foo" has a corresponding // local variable named "foo_addr"). var = ast.CreateVarDecl( - fdecl, expr_gen.GetQualType(alloca->getAllocatedType()), name); + fdecl, dec_ctx.GetQualType(alloca->getAllocatedType()), name); fdecl->addDecl(var); } else if (inst.hasNUsesOrMore(2) || (inst.hasNUsesOrMore(1) && llvm::isa(inst)) || @@ -1330,7 +1186,7 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { auto name{GetPrefix(&inst) + std::to_string(GetNumDecls(fdecl))}; - auto type{expr_gen.GetQualType(inst.getType())}; + auto type{dec_ctx.GetQualType(inst.getType())}; if (auto arrayType = clang::dyn_cast(type)) { type = dec_ctx.ast_ctx.getPointerType(arrayType->getElementType()); } @@ -1362,12 +1218,12 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { auto name{"asm_" + std::to_string(GetNumDecls(tudecl))}; auto ftype{iasm->getFunctionType()}; - auto type{expr_gen.GetQualType(ftype)}; + auto type{dec_ctx.GetQualType(ftype)}; decl = ast.CreateFunctionDecl(tudecl, type, name); std::vector iasm_params; for (auto arg : ftype->params()) { - auto arg_type{expr_gen.GetQualType(arg)}; + auto arg_type{dec_ctx.GetQualType(arg)}; auto name{"arg_" + std::to_string(iasm_params.size())}; iasm_params.push_back( ast.CreateParamDecl(decl->getDeclContext(), arg_type, name)); diff --git a/lib/AST/TypeProvider.cpp b/lib/AST/TypeProvider.cpp new file mode 100644 index 00000000..7b64b543 --- /dev/null +++ b/lib/AST/TypeProvider.cpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2022-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include "rellic/AST/TypeProvider.h" + +#include "rellic/AST/Util.h" + +namespace rellic { +TypeProvider::TypeProvider(DecompilationContext& dec_ctx) : dec_ctx(dec_ctx) {} +TypeProvider::~TypeProvider() = default; + +clang::QualType TypeProvider::GetFunctionReturnType(llvm::Function&) { + return {}; +} + +clang::QualType TypeProvider::GetArgumentType(llvm::Argument&) { return {}; } + +clang::QualType TypeProvider::GetGlobalVarType(llvm::GlobalVariable&) { + return {}; +} + +// Defers to DecompilationContext::GetQualType +class FallbackTypeProvider : public TypeProvider { + public: + FallbackTypeProvider(DecompilationContext& dec_ctx); + clang::QualType GetFunctionReturnType(llvm::Function& func) override; + clang::QualType GetArgumentType(llvm::Argument& arg) override; + clang::QualType GetGlobalVarType(llvm::GlobalVariable& gvar) override; +}; + +FallbackTypeProvider::FallbackTypeProvider(DecompilationContext& dec_ctx) + : TypeProvider(dec_ctx) {} + +clang::QualType FallbackTypeProvider::GetFunctionReturnType( + llvm::Function& func) { + return dec_ctx.GetQualType(func.getReturnType()); +} + +clang::QualType FallbackTypeProvider::GetArgumentType(llvm::Argument& arg) { + return dec_ctx.GetQualType(arg.getType()); +} + +clang::QualType FallbackTypeProvider::GetGlobalVarType( + llvm::GlobalVariable& gvar) { + return dec_ctx.GetQualType(gvar.getValueType()); +} + +// Fixes function arguments that have a byval attribute +class ByValFixupTypeProvider : public TypeProvider { + public: + ByValFixupTypeProvider(DecompilationContext& dec_ctx); + + // This function fixes types for those arguments that are passed by value + // using the `byval` attribute. They need special treatment because those + // arguments, instead of actually being passed by value, are instead passed + // "by reference" from a bitcode point of view, with the caveat that the + // actual semantics are more like "create a copy of the reference before + // calling, and pass a pointer to that copy instead" (this is done + // implicitly). Thus, we need to convert a function type like + // + // `i32 @do_foo(%struct.foo* byval(%struct.foo) align 4 %f)` + // + // into + // + // `i32 @do_foo(%struct.foo %f)` + clang::QualType GetArgumentType(llvm::Argument& arg) override; +}; + +ByValFixupTypeProvider::ByValFixupTypeProvider(DecompilationContext& dec_ctx) + : TypeProvider(dec_ctx) {} + +clang::QualType ByValFixupTypeProvider::GetArgumentType(llvm::Argument& arg) { + if (!arg.hasByValAttr()) { + return {}; + } + + auto byval{arg.getAttribute(llvm::Attribute::ByVal)}; + return dec_ctx.GetQualType(byval.getValueAsType()); +} + +// Fixes the function signature for `main` +class MainFuncTypeProvider : public TypeProvider { + public: + MainFuncTypeProvider(DecompilationContext& dec_ctx); + clang::QualType GetFunctionReturnType(llvm::Function& func) override; + clang::QualType GetArgumentType(llvm::Argument& arg) override; +}; + +MainFuncTypeProvider::MainFuncTypeProvider(DecompilationContext& dec_ctx) + : TypeProvider(dec_ctx) {} +clang::QualType MainFuncTypeProvider::GetFunctionReturnType( + llvm::Function& func) { + if (func.getName() != "main") { + return {}; + } + return dec_ctx.ast_ctx.IntTy; +} + +clang::QualType MainFuncTypeProvider::GetArgumentType(llvm::Argument& arg) { + if (arg.getParent()->getName() != "main") { + return {}; + } + + auto arg_no{arg.getArgNo()}; + switch (arg.getArgNo()) { + case 0: // argc + return dec_ctx.ast_ctx.IntTy; + case 1: // argv and envp + case 2: { + auto char_ty{dec_ctx.ast_ctx.CharTy}; + auto char_ptr{dec_ctx.ast_ctx.getPointerType(char_ty)}; + auto char_ptr_ptr{dec_ctx.ast_ctx.getPointerType(char_ptr)}; + return char_ptr_ptr; + } + default: + return {}; + } +} + +TypeProviderCombiner::TypeProviderCombiner(DecompilationContext& dec_ctx) + : TypeProvider(dec_ctx) { + AddProvider(); + AddProvider(); + AddProvider(); +} + +void TypeProviderCombiner::AddProvider(std::unique_ptr provider) { + providers.push_back(std::move(provider)); +} + +clang::QualType TypeProviderCombiner::GetFunctionReturnType( + llvm::Function& func) { + for (auto it{providers.rbegin()}; it != providers.rend(); ++it) { + auto& provider{*it}; + auto res{provider->GetFunctionReturnType(func)}; + if (!res.isNull()) { + return res; + } + } + return {}; +} + +clang::QualType TypeProviderCombiner::GetArgumentType(llvm::Argument& arg) { + for (auto it{providers.rbegin()}; it != providers.rend(); ++it) { + auto& provider{*it}; + auto res{provider->GetArgumentType(arg)}; + if (!res.isNull()) { + return res; + } + } + return {}; +} + +clang::QualType TypeProviderCombiner::GetGlobalVarType( + llvm::GlobalVariable& gvar) { + for (auto it{providers.rbegin()}; it != providers.rend(); ++it) { + auto& provider{*it}; + auto res{provider->GetGlobalVarType(gvar)}; + if (!res.isNull()) { + return res; + } + } + return {}; +} +} // namespace rellic \ No newline at end of file diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index dfc3158c..e6c1da99 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -19,6 +19,8 @@ #include #include "rellic/AST/ASTBuilder.h" +#include "rellic/AST/TypeProvider.h" +#include "rellic/BC/Util.h" #include "rellic/Exception.h" namespace rellic { @@ -391,11 +393,137 @@ DecompilationContext::DecompilationContext(clang::ASTUnit &ast_unit) : ast_unit(ast_unit), ast_ctx(ast_unit.getASTContext()), ast(ast_unit), - marker_expr(ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse())) {} + marker_expr(ast.CreateAdd(ast.CreateFalse(), ast.CreateFalse())), + type_provider(std::make_unique(*this)) {} unsigned DecompilationContext::InsertZExpr(const z3::expr &e) { auto idx{z3_exprs.size()}; z3_exprs.push_back(e); return idx; } + +clang::QualType DecompilationContext::GetQualType(llvm::Type *type) { + DLOG(INFO) << "GetQualType: " << LLVMThingToString(type); + + clang::QualType result; + switch (type->getTypeID()) { + case llvm::Type::VoidTyID: + result = ast_ctx.VoidTy; + break; + + case llvm::Type::HalfTyID: + result = ast_ctx.HalfTy; + break; + + case llvm::Type::FloatTyID: + result = ast_ctx.FloatTy; + break; + + case llvm::Type::DoubleTyID: + result = ast_ctx.DoubleTy; + break; + + case llvm::Type::X86_FP80TyID: + result = ast_ctx.LongDoubleTy; + break; + + case llvm::Type::FP128TyID: + result = ast_ctx.Float128Ty; + break; + + case llvm::Type::IntegerTyID: { + auto size{type->getIntegerBitWidth()}; + CHECK(size > 0) << "Integer bit width has to be greater than 0"; + if (size == 8) { + result = ast_ctx.CharTy; + } else { + result = ast.GetLeastIntTypeForBitWidth(size, /*sign=*/0); + } + } break; + + case llvm::Type::FunctionTyID: { + auto func{llvm::cast(type)}; + auto ret{GetQualType(func->getReturnType())}; + std::vector params; + for (auto param : func->params()) { + params.push_back(GetQualType(param)); + } + auto epi{clang::FunctionProtoType::ExtProtoInfo()}; + epi.Variadic = func->isVarArg(); + result = ast_ctx.getFunctionType(ret, params, epi); + } break; + + case llvm::Type::PointerTyID: { + auto ptr_type{llvm::cast(type)}; + if (ptr_type->isOpaque()) { + result = ast_ctx.VoidPtrTy; + } else { + result = ast_ctx.getPointerType( + GetQualType(ptr_type->getNonOpaquePointerElementType())); + } + } break; + + case llvm::Type::ArrayTyID: { + auto arr{llvm::cast(type)}; + auto elm{GetQualType(arr->getElementType())}; + result = ast_ctx.getConstantArrayType( + elm, llvm::APInt(64, arr->getNumElements()), nullptr, + clang::ArrayType::ArraySizeModifier::Normal, 0); + } break; + + case llvm::Type::StructTyID: { + clang::RecordDecl *sdecl{nullptr}; + auto &decl{type_decls[type]}; + if (!decl) { + auto tudecl{ast_ctx.getTranslationUnitDecl()}; + auto strct{llvm::cast(type)}; + auto sname{strct->isLiteral() ? ("literal_struct_" + + std::to_string(num_literal_structs++)) + : strct->getName().str()}; + if (sname.empty()) { + sname = "struct" + std::to_string(num_declared_structs++); + } + + // Create a C struct declaration + decl = sdecl = ast.CreateStructDecl(tudecl, sname); + + // Add fields to the C struct + for (auto ecnt{0U}; ecnt < strct->getNumElements(); ++ecnt) { + auto etype{GetQualType(strct->getElementType(ecnt))}; + auto fname{"field" + std::to_string(ecnt)}; + sdecl->addDecl(ast.CreateFieldDecl(sdecl, etype, fname)); + } + + // Complete the C struct definition + sdecl->completeDefinition(); + // Add C struct to translation unit + tudecl->addDecl(sdecl); + + } else { + sdecl = clang::cast(decl); + } + result = ast_ctx.getRecordType(sdecl); + } break; + + case llvm::Type::MetadataTyID: + result = ast_ctx.VoidPtrTy; + break; + + default: { + if (type->isVectorTy()) { + auto vtype{llvm::cast(type)}; + auto etype{GetQualType(vtype->getElementType())}; + auto ecnt{vtype->getNumElements()}; + auto vkind{clang::VectorType::GenericVector}; + result = ast_ctx.getVectorType(etype, ecnt, vkind); + } else { + THROW() << "Unknown LLVM Type: " << LLVMThingToString(type); + } + } break; + } + + CHECK_THROW(!result.isNull()) << "Unknown LLVM Type"; + + return result; +} } // namespace rellic diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 8d6f590b..d9646b55 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -40,6 +40,7 @@ set(AST_HEADERS "${include_dir}/AST/StructGenerator.h" "${include_dir}/AST/SubprogramGenerator.h" "${include_dir}/AST/TransformVisitor.h" + "${include_dir}/AST/TypeProvider.h" "${include_dir}/AST/Util.h" "${include_dir}/AST/Z3CondSimplify.h" ) @@ -80,6 +81,7 @@ set(AST_SOURCES AST/StructFieldRenamer.cpp AST/StructGenerator.cpp AST/SubprogramGenerator.cpp + AST/TypeProvider.cpp ) set(BC_SOURCES diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 156dc197..ebd7b48f 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -83,6 +83,11 @@ Result Decompile( module->getTargetTriple()}; auto ast_unit{clang::tooling::buildASTFromCodeWithArgs("", args, "out.c")}; rellic::DecompilationContext dec_ctx(*ast_unit); + + for (auto& provider : options.additional_providers) { + dec_ctx.type_provider->AddProvider(provider->create(dec_ctx)); + } + rellic::GenerateAST::run(*module, dec_ctx); // TODO(surovic): Add llvm::Value* -> clang::Decl* map // Especially for llvm::Argument* and llvm::Function*. diff --git a/tools/decomp/Decomp.cpp b/tools/decomp/Decomp.cpp index f457ad40..0e4df1f1 100644 --- a/tools/decomp/Decomp.cpp +++ b/tools/decomp/Decomp.cpp @@ -124,7 +124,7 @@ int main(int argc, char* argv[]) { opts.lower_switches = FLAGS_lower_switch; opts.remove_phi_nodes = FLAGS_remove_phi_nodes; - auto result{rellic::Decompile(std::move(module), opts)}; + auto result{rellic::Decompile(std::move(module), std::move(opts))}; if (result.Succeeded()) { auto value{result.TakeValue()}; value.ast->getASTContext().getTranslationUnitDecl()->print(output);