diff --git a/propeller/BUILD b/propeller/BUILD index f34c22cc8a7..b0f695c473b 100644 --- a/propeller/BUILD +++ b/propeller/BUILD @@ -269,6 +269,7 @@ cc_library( deps = [ ":addr2cu", ":status_macros", + "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", @@ -280,6 +281,7 @@ cc_library( "@llvm-project//llvm:DebugInfo", "@llvm-project//llvm:Object", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", ], ) @@ -474,6 +476,7 @@ cc_library( ":binary_address_branch", ":binary_address_branch_path", ":binary_content", + ":branch_aggregation", ":propeller_options_cc_proto", ":propeller_statistics", ":status_macros", diff --git a/propeller/binary_address_mapper.cc b/propeller/binary_address_mapper.cc index 543474ebd2f..31edcf7a2fa 100644 --- a/propeller/binary_address_mapper.cc +++ b/propeller/binary_address_mapper.cc @@ -13,6 +13,7 @@ #include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/base/nullability.h" +#include "absl/container/btree_map.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -32,6 +33,7 @@ #include "propeller/binary_address_branch.h" #include "propeller/binary_address_branch_path.h" #include "propeller/binary_content.h" +#include "propeller/branch_aggregation.h" #include "propeller/propeller_options.pb.h" #include "propeller/propeller_statistics.h" #include "propeller/status_macros.h" @@ -84,7 +86,9 @@ class BinaryAddressMapperBuilder { symtab, std::vector bb_addr_map, PropellerStats &stats, absl::Nonnull options - ABSL_ATTRIBUTE_LIFETIME_BOUND); + ABSL_ATTRIBUTE_LIFETIME_BOUND, + std::optional> + thunk_map); BinaryAddressMapperBuilder(const BinaryAddressMapperBuilder &) = delete; BinaryAddressMapperBuilder &operator=(const BinaryAddressMapper &) = delete; @@ -131,6 +135,9 @@ class BinaryAddressMapperBuilder { int FilterDuplicateNameFunctions( absl::btree_set &selected_functions) const; + // Create a sorted vector of thunks in the binary from `thunk_map_`. + std::optional> GetThunks(); + // BB address map of functions. std::vector bb_addr_map_; // Non-zero sized function symbols from elf symbol table, indexed by @@ -144,6 +151,10 @@ class BinaryAddressMapperBuilder { PropellerStats *stats_; const PropellerOptions *options_; + + // Map of thunks by address. + std::optional> + thunk_map_; }; // Helper class for extracting intra-function paths from binary-address paths. @@ -504,6 +515,42 @@ bool BinaryAddressMapper::CanFallThrough(int from, int to) const { return true; } +std::optional BinaryAddressMapper::GetThunkInfoUsingBinaryAddress( + uint64_t address) const { + std::optional index = FindThunkInfoIndexUsingBinaryAddress(address); + if (!index.has_value()) return std::nullopt; + return thunks_->at(*index); +} + +// Find thunk by binary address +std::optional BinaryAddressMapper::FindThunkInfoIndexUsingBinaryAddress( + uint64_t address) const { + if (!thunks_.has_value()) return std::nullopt; + auto it = absl::c_upper_bound(*thunks_, address, + [](uint64_t addr, const ThunkInfo &thunk) { + return addr < thunk.address; + }); + if (it == thunks_->begin()) return std::nullopt; + it = std::prev(it); + uint64_t thunk_end_address = it->address + it->symbol.getSize(); + if (address >= thunk_end_address) return std::nullopt; + return it - thunks_->begin(); +} + +void BinaryAddressMapper::UpdateThunkTargets( + const BranchAggregation &branch_aggregation) { + if (!thunks_.has_value()) return; + for (auto [branch, weight] : branch_aggregation.branch_counters) { + std::optional thunk_index = + FindThunkInfoIndexUsingBinaryAddress(branch.from); + + if (!thunk_index.has_value()) continue; + + ThunkInfo &thunk_info = thunks_->at(*thunk_index); + thunk_info.target = branch.to; + } +} + // For each lbr record addr1->addr2, find function1/2 that contain addr1/addr2 // and add function1/2's index into the returned set. absl::btree_set BinaryAddressMapperBuilder::CalculateHotFunctions( @@ -638,6 +685,17 @@ absl::btree_set BinaryAddressMapperBuilder::SelectFunctions( return selected_functions; } +std::optional> BinaryAddressMapperBuilder::GetThunks() { + if (!thunk_map_.has_value()) return std::nullopt; + std::vector thunks; + for (const auto &thunk_entry : *thunk_map_) { + uint64_t thunk_address = thunk_entry.first; + llvm::object::ELFSymbolRef thunk_symbol = thunk_entry.second; + thunks.push_back({.address = thunk_address, .symbol = thunk_symbol}); + } + return thunks; +} + std::vector BinaryAddressMapper::ExtractIntraFunctionPaths( const BinaryAddressBranchPath &address_path) const { return IntraFunctionPathsExtractor(this).Extract(address_path); @@ -647,12 +705,15 @@ BinaryAddressMapperBuilder::BinaryAddressMapperBuilder( absl::flat_hash_map> symtab, std::vector bb_addr_map, PropellerStats &stats, - absl::Nonnull options) + absl::Nonnull options, + std::optional> + thunk_map) : bb_addr_map_(std::move(bb_addr_map)), symtab_(std::move(symtab)), symbol_info_map_(GetSymbolInfoMap(symtab_, bb_addr_map_)), stats_(&stats), - options_(options) { + options_(options), + thunk_map_(std::move(thunk_map)) { stats_->bbaddrmap_stats.bbaddrmap_function_does_not_have_symtab_entry += bb_addr_map_.size() - symbol_info_map_.size(); } @@ -661,11 +722,13 @@ BinaryAddressMapper::BinaryAddressMapper( absl::btree_set selected_functions, std::vector bb_addr_map, std::vector bb_handles, - absl::flat_hash_map symbol_info_map) + absl::flat_hash_map symbol_info_map, + std::optional> thunks) : selected_functions_(std::move(selected_functions)), bb_handles_(std::move(bb_handles)), bb_addr_map_(std::move(bb_addr_map)), - symbol_info_map_(std::move(symbol_info_map)) {} + symbol_info_map_(std::move(symbol_info_map)), + thunks_(std::move(thunks)) {} absl::StatusOr> BuildBinaryAddressMapper( const PropellerOptions &options, const BinaryContent &binary_content, @@ -676,7 +739,8 @@ absl::StatusOr> BuildBinaryAddressMapper( ASSIGN_OR_RETURN(bb_addr_map, ReadBbAddrMap(binary_content)); return BinaryAddressMapperBuilder(ReadSymbolTable(binary_content), - std::move(bb_addr_map), stats, &options) + std::move(bb_addr_map), stats, &options, + ReadThunkSymbols(binary_content)) .Build(hot_addresses); } @@ -684,6 +748,7 @@ std::unique_ptr BinaryAddressMapperBuilder::Build( const absl::flat_hash_set *hot_addresses) && { std::optional last_function_address; std::vector bb_handles; + std::optional> thunks = GetThunks(); absl::btree_set selected_functions = SelectFunctions(hot_addresses); DropNonSelectedFunctions(selected_functions); for (int function_index : selected_functions) { @@ -696,9 +761,10 @@ std::unique_ptr BinaryAddressMapperBuilder::Build( bb_handles.push_back({function_index, bb_index}); last_function_address = function_bb_addr_map.getFunctionAddress(); } + return std::make_unique( std::move(selected_functions), std::move(bb_addr_map_), - std::move(bb_handles), std::move(symbol_info_map_)); + std::move(bb_handles), std::move(symbol_info_map_), std::move(thunks)); } } // namespace propeller diff --git a/propeller/binary_address_mapper.h b/propeller/binary_address_mapper.h index c3147a8e550..49333034ee4 100644 --- a/propeller/binary_address_mapper.h +++ b/propeller/binary_address_mapper.h @@ -18,10 +18,12 @@ #include "absl/time/time.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Object/ELFObjectFile.h" #include "llvm/Object/ELFTypes.h" #include "propeller/bb_handle.h" #include "propeller/binary_address_branch_path.h" #include "propeller/binary_content.h" +#include "propeller/branch_aggregation.h" #include "propeller/propeller_options.pb.h" #include "propeller/propeller_statistics.h" @@ -103,6 +105,12 @@ struct BbHandleBranchPath { } }; +struct ThunkInfo { + uint64_t address; + uint64_t target; + llvm::object::ELFSymbolRef symbol; +}; + // Finds basic block entries from binary addresses. class BinaryAddressMapper { public: @@ -120,7 +128,8 @@ class BinaryAddressMapper { absl::btree_set selected_functions, std::vector bb_addr_map, std::vector bb_handles, - absl::flat_hash_map symbol_info_map); + absl::flat_hash_map symbol_info_map, + std::optional> thunks); BinaryAddressMapper(const BinaryAddressMapper &) = delete; BinaryAddressMapper &operator=(const BinaryAddressMapper &) = delete; @@ -141,6 +150,10 @@ class BinaryAddressMapper { return selected_functions_; } + const std::optional> &thunks() const { + return thunks_; + } + // Returns the `bb_handles_` index associated with the binary address // `address` given a branch from/to this address based on `direction`. // It returns nullopt if the no `bb_handles_` index can be mapped. @@ -186,6 +199,20 @@ class BinaryAddressMapper { bool CanFallThrough(int function_index, int from_bb_index, int to_bb_index) const; + // Returns the index of the thunk that contains the given binary address. + // Returns nullopt if no thunk contains the address. + std::optional FindThunkInfoIndexUsingBinaryAddress( + uint64_t address) const; + + // Returns the thunk that contains the given binary address. Returns nullopt + // if no thunk contains the address. + std::optional GetThunkInfoUsingBinaryAddress( + uint64_t address) const; + + // Sets the targets of thunks in `binary_address_mapper_` to the targets of + // their corresponding branches in `branch_aggregation`. + void UpdateThunkTargets(const BranchAggregation &branch_aggregation); + // Returns the full function's BB address map associated with the given // `bb_handle`. const llvm::object::BBAddrMap &GetFunctionEntry(BbHandle bb_handle) const { @@ -268,6 +295,9 @@ class BinaryAddressMapper { // A map from function indices to their symbol info (function names and // section name). absl::flat_hash_map symbol_info_map_; + + // A vector of thunks in the binary, ordered in increasing order of address. + std::optional> thunks_; }; // Builds a `BinaryAddressMapper` for binary represented by `binary_content` and diff --git a/propeller/binary_content.cc b/propeller/binary_content.cc index 2d2a9520333..54c84007e21 100644 --- a/propeller/binary_content.cc +++ b/propeller/binary_content.cc @@ -8,6 +8,7 @@ #include #include +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -35,6 +36,7 @@ #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TargetParser/Triple.h" #include "propeller/addr2cu.h" #include "propeller/status_macros.h" @@ -292,6 +294,35 @@ ReadSymbolTable(const BinaryContent &binary_content) { return symtab; } +// Read thunks from the symbol table in sorted order. +absl::btree_map ReadAArch64ThunkSymbols( + const BinaryContent &binary_content) { + absl::btree_map thunk_map; + for (llvm::object::SymbolRef sr : binary_content.object_file->symbols()) { + llvm::object::ELFSymbolRef symbol(sr); + uint8_t stt = symbol.getELFType(); + if (stt != llvm::ELF::STT_FUNC) continue; + llvm::Expected address = sr.getAddress(); + if (!address || !*address) continue; + llvm::Expected func_name = symbol.getName(); + // TODO(tzussman): More explicit thunk name check. + if (!func_name || !func_name->starts_with("__AArch64")) continue; + const uint64_t func_size = symbol.getSize(); + if (func_size == 0) continue; + + thunk_map.insert({*address, sr}); + } + return thunk_map; +} + +std::optional> +ReadThunkSymbols(const BinaryContent &binary_content) { + if (binary_content.object_file->getArch() == llvm::Triple::aarch64) + return ReadAArch64ThunkSymbols(binary_content); + + return std::nullopt; +} + absl::StatusOr> ReadBbAddrMap( const BinaryContent &binary_content) { auto *elf_object = llvm::dyn_cast( diff --git a/propeller/binary_content.h b/propeller/binary_content.h index 86fb1f656e2..d7ad79797d0 100644 --- a/propeller/binary_content.h +++ b/propeller/binary_content.h @@ -7,6 +7,7 @@ #include #include +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -116,6 +117,14 @@ absl::StatusOr GetSymbolAddress( absl::flat_hash_map> ReadSymbolTable(const BinaryContent &binary_content); +// Returns an AArch64 binary's thunk symbols by reading from its symbol table. +absl::btree_map ReadAArch64ThunkSymbols( + const BinaryContent &binary_content); + +// Returns the binary's thunk symbols by reading from its symbol table. +std::optional> +ReadThunkSymbols(const BinaryContent &binary_content); + // Returns the binary's `BBAddrMap`s by calling LLVM-side decoding function // `ELFObjectFileBase::readBBAddrMap`. Returns error if the call fails or if the // result is empty. diff --git a/propeller/program_cfg_builder.cc b/propeller/program_cfg_builder.cc index 9b55ce00181..6079c00f8e2 100644 --- a/propeller/program_cfg_builder.cc +++ b/propeller/program_cfg_builder.cc @@ -218,6 +218,22 @@ absl::Status ProgramCfgBuilder::CreateEdges( std::optional to_bb_index = binary_address_mapper_->FindBbHandleIndexUsingBinaryAddress( branch.to, BranchDirection::kTo); + + bool is_thunk_call = false; + + // Check if the branch is a thunk call, and if so, find the target of the + // thunk and set `to_bb_index` accordingly. + if (!to_bb_index.has_value()) { + std::optional thunk_info = + binary_address_mapper_->GetThunkInfoUsingBinaryAddress(branch.to); + if (thunk_info.has_value()) { + to_bb_index = + binary_address_mapper_->FindBbHandleIndexUsingBinaryAddress( + thunk_info->target, BranchDirection::kTo); + is_thunk_call = true; + } + } + if (!to_bb_index.has_value()) continue; BbHandle to_bb_handle = binary_address_mapper_->bb_handles()[*to_bb_index]; @@ -235,6 +251,7 @@ absl::Status ProgramCfgBuilder::CreateEdges( if ((!from_bb_index.has_value() || binary_address_mapper_->GetBBEntry(from_bb_handle).hasReturn() || to_bb_handle.function_index != from_bb_handle.function_index) && + !is_thunk_call && // Not a thunk call binary_address_mapper_->GetFunctionEntry(to_bb_handle) .getFunctionAddress() != branch.to && // Not a call // Jump to the beginning of the basicblock @@ -253,7 +270,8 @@ absl::Status ProgramCfgBuilder::CreateEdges( } if (!from_bb_index.has_value()) continue; if (!binary_address_mapper_->GetBBEntry(from_bb_handle).hasReturn() && - binary_address_mapper_->GetAddress(to_bb_handle) != branch.to) { + binary_address_mapper_->GetAddress(to_bb_handle) != branch.to && + !is_thunk_call) { // Jump is not a return and its target is not the beginning of a function // or a basic block. weight_on_dubious_edges += weight; @@ -261,7 +279,8 @@ absl::Status ProgramCfgBuilder::CreateEdges( CFGEdgeKind edge_kind = CFGEdgeKind::kBranchOrFallthough; if (binary_address_mapper_->GetFunctionEntry(to_bb_handle) - .getFunctionAddress() == branch.to) { + .getFunctionAddress() == branch.to || + is_thunk_call) { edge_kind = CFGEdgeKind::kCall; } else if (branch.to != binary_address_mapper_->GetAddress(to_bb_handle) || binary_address_mapper_->GetBBEntry(from_bb_handle).hasReturn()) {