Skip to content

Commit

Permalink
Process AArch64 range extension thunks in Propeller CFG construction
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699227260
  • Loading branch information
Propeller Team authored and copybara-github committed Nov 27, 2024
1 parent 2ad32b5 commit bed8f68
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 10 deletions.
3 changes: 3 additions & 0 deletions propeller/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ cc_library(
deps = [
":addr2cu",
":status_macros",
"@abseil-cpp//absl/container:btree",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
Expand All @@ -280,6 +281,7 @@ cc_library(
"@llvm-project//llvm:DebugInfo",
"@llvm-project//llvm:Object",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
],
)

Expand Down Expand Up @@ -474,6 +476,7 @@ cc_library(
":binary_address_branch",
":binary_address_branch_path",
":binary_content",
":branch_aggregation",
":propeller_options_cc_proto",
":propeller_statistics",
":status_macros",
Expand Down
80 changes: 73 additions & 7 deletions propeller/binary_address_mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/container/btree_map.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand All @@ -32,6 +33,7 @@
#include "propeller/binary_address_branch.h"
#include "propeller/binary_address_branch_path.h"
#include "propeller/binary_content.h"
#include "propeller/branch_aggregation.h"
#include "propeller/propeller_options.pb.h"
#include "propeller/propeller_statistics.h"
#include "propeller/status_macros.h"
Expand Down Expand Up @@ -84,7 +86,9 @@ class BinaryAddressMapperBuilder {
symtab,
std::vector<llvm::object::BBAddrMap> bb_addr_map, PropellerStats &stats,
absl::Nonnull<const PropellerOptions *> options
ABSL_ATTRIBUTE_LIFETIME_BOUND);
ABSL_ATTRIBUTE_LIFETIME_BOUND,
std::optional<absl::btree_map<uint64_t, llvm::object::ELFSymbolRef>>
thunk_map);

BinaryAddressMapperBuilder(const BinaryAddressMapperBuilder &) = delete;
BinaryAddressMapperBuilder &operator=(const BinaryAddressMapper &) = delete;
Expand Down Expand Up @@ -131,6 +135,9 @@ class BinaryAddressMapperBuilder {
int FilterDuplicateNameFunctions(
absl::btree_set<int> &selected_functions) const;

// Create a sorted vector of thunks in the binary from `thunk_map_`.
std::optional<std::vector<ThunkInfo>> GetThunks();

// BB address map of functions.
std::vector<llvm::object::BBAddrMap> bb_addr_map_;
// Non-zero sized function symbols from elf symbol table, indexed by
Expand All @@ -144,6 +151,10 @@ class BinaryAddressMapperBuilder {

PropellerStats *stats_;
const PropellerOptions *options_;

// Map of thunks by address.
std::optional<absl::btree_map<uint64_t, llvm::object::ELFSymbolRef>>
thunk_map_;
};

// Helper class for extracting intra-function paths from binary-address paths.
Expand Down Expand Up @@ -504,6 +515,42 @@ bool BinaryAddressMapper::CanFallThrough(int from, int to) const {
return true;
}

std::optional<ThunkInfo> BinaryAddressMapper::GetThunkInfoUsingBinaryAddress(
uint64_t address) const {
std::optional<int> index = FindThunkInfoIndexUsingBinaryAddress(address);
if (!index.has_value()) return std::nullopt;
return thunks_->at(*index);
}

// Find thunk by binary address
std::optional<int> BinaryAddressMapper::FindThunkInfoIndexUsingBinaryAddress(
uint64_t address) const {
if (!thunks_.has_value()) return std::nullopt;
auto it = absl::c_upper_bound(*thunks_, address,
[](uint64_t addr, const ThunkInfo &thunk) {
return addr < thunk.address;
});
if (it == thunks_->begin()) return std::nullopt;
it = std::prev(it);
uint64_t thunk_end_address = it->address + it->symbol.getSize();
if (address >= thunk_end_address) return std::nullopt;
return it - thunks_->begin();
}

void BinaryAddressMapper::UpdateThunkTargets(
const BranchAggregation &branch_aggregation) {
if (!thunks_.has_value()) return;
for (auto [branch, weight] : branch_aggregation.branch_counters) {
std::optional<int> thunk_index =
FindThunkInfoIndexUsingBinaryAddress(branch.from);

if (!thunk_index.has_value()) continue;

ThunkInfo &thunk_info = thunks_->at(*thunk_index);
thunk_info.target = branch.to;
}
}

// For each lbr record addr1->addr2, find function1/2 that contain addr1/addr2
// and add function1/2's index into the returned set.
absl::btree_set<int> BinaryAddressMapperBuilder::CalculateHotFunctions(
Expand Down Expand Up @@ -638,6 +685,17 @@ absl::btree_set<int> BinaryAddressMapperBuilder::SelectFunctions(
return selected_functions;
}

std::optional<std::vector<ThunkInfo>> BinaryAddressMapperBuilder::GetThunks() {
if (!thunk_map_.has_value()) return std::nullopt;
std::vector<ThunkInfo> thunks;
for (const auto &thunk_entry : *thunk_map_) {
uint64_t thunk_address = thunk_entry.first;
llvm::object::ELFSymbolRef thunk_symbol = thunk_entry.second;
thunks.push_back({.address = thunk_address, .symbol = thunk_symbol});
}
return thunks;
}

std::vector<BbHandleBranchPath> BinaryAddressMapper::ExtractIntraFunctionPaths(
const BinaryAddressBranchPath &address_path) const {
return IntraFunctionPathsExtractor(this).Extract(address_path);
Expand All @@ -647,12 +705,15 @@ BinaryAddressMapperBuilder::BinaryAddressMapperBuilder(
absl::flat_hash_map<uint64_t, llvm::SmallVector<llvm::object::ELFSymbolRef>>
symtab,
std::vector<llvm::object::BBAddrMap> bb_addr_map, PropellerStats &stats,
absl::Nonnull<const PropellerOptions *> options)
absl::Nonnull<const PropellerOptions *> options,
std::optional<absl::btree_map<uint64_t, llvm::object::ELFSymbolRef>>
thunk_map)
: bb_addr_map_(std::move(bb_addr_map)),
symtab_(std::move(symtab)),
symbol_info_map_(GetSymbolInfoMap(symtab_, bb_addr_map_)),
stats_(&stats),
options_(options) {
options_(options),
thunk_map_(std::move(thunk_map)) {
stats_->bbaddrmap_stats.bbaddrmap_function_does_not_have_symtab_entry +=
bb_addr_map_.size() - symbol_info_map_.size();
}
Expand All @@ -661,11 +722,13 @@ BinaryAddressMapper::BinaryAddressMapper(
absl::btree_set<int> selected_functions,
std::vector<llvm::object::BBAddrMap> bb_addr_map,
std::vector<BbHandle> bb_handles,
absl::flat_hash_map<int, FunctionSymbolInfo> symbol_info_map)
absl::flat_hash_map<int, FunctionSymbolInfo> symbol_info_map,
std::optional<std::vector<ThunkInfo>> thunks)
: selected_functions_(std::move(selected_functions)),
bb_handles_(std::move(bb_handles)),
bb_addr_map_(std::move(bb_addr_map)),
symbol_info_map_(std::move(symbol_info_map)) {}
symbol_info_map_(std::move(symbol_info_map)),
thunks_(std::move(thunks)) {}

absl::StatusOr<std::unique_ptr<BinaryAddressMapper>> BuildBinaryAddressMapper(
const PropellerOptions &options, const BinaryContent &binary_content,
Expand All @@ -676,14 +739,16 @@ absl::StatusOr<std::unique_ptr<BinaryAddressMapper>> BuildBinaryAddressMapper(
ASSIGN_OR_RETURN(bb_addr_map, ReadBbAddrMap(binary_content));

return BinaryAddressMapperBuilder(ReadSymbolTable(binary_content),
std::move(bb_addr_map), stats, &options)
std::move(bb_addr_map), stats, &options,
ReadThunkSymbols(binary_content))
.Build(hot_addresses);
}

std::unique_ptr<BinaryAddressMapper> BinaryAddressMapperBuilder::Build(
const absl::flat_hash_set<uint64_t> *hot_addresses) && {
std::optional<uint64_t> last_function_address;
std::vector<BbHandle> bb_handles;
std::optional<std::vector<ThunkInfo>> thunks = GetThunks();
absl::btree_set<int> selected_functions = SelectFunctions(hot_addresses);
DropNonSelectedFunctions(selected_functions);
for (int function_index : selected_functions) {
Expand All @@ -696,9 +761,10 @@ std::unique_ptr<BinaryAddressMapper> BinaryAddressMapperBuilder::Build(
bb_handles.push_back({function_index, bb_index});
last_function_address = function_bb_addr_map.getFunctionAddress();
}

return std::make_unique<BinaryAddressMapper>(
std::move(selected_functions), std::move(bb_addr_map_),
std::move(bb_handles), std::move(symbol_info_map_));
std::move(bb_handles), std::move(symbol_info_map_), std::move(thunks));
}

} // namespace propeller
32 changes: 31 additions & 1 deletion propeller/binary_address_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
#include "absl/time/time.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Object/ELFObjectFile.h"
#include "llvm/Object/ELFTypes.h"
#include "propeller/bb_handle.h"
#include "propeller/binary_address_branch_path.h"
#include "propeller/binary_content.h"
#include "propeller/branch_aggregation.h"
#include "propeller/propeller_options.pb.h"
#include "propeller/propeller_statistics.h"

Expand Down Expand Up @@ -103,6 +105,12 @@ struct BbHandleBranchPath {
}
};

struct ThunkInfo {
uint64_t address;
uint64_t target;
llvm::object::ELFSymbolRef symbol;
};

// Finds basic block entries from binary addresses.
class BinaryAddressMapper {
public:
Expand All @@ -120,7 +128,8 @@ class BinaryAddressMapper {
absl::btree_set<int> selected_functions,
std::vector<llvm::object::BBAddrMap> bb_addr_map,
std::vector<BbHandle> bb_handles,
absl::flat_hash_map<int, FunctionSymbolInfo> symbol_info_map);
absl::flat_hash_map<int, FunctionSymbolInfo> symbol_info_map,
std::optional<std::vector<ThunkInfo>> thunks);

BinaryAddressMapper(const BinaryAddressMapper &) = delete;
BinaryAddressMapper &operator=(const BinaryAddressMapper &) = delete;
Expand All @@ -141,6 +150,10 @@ class BinaryAddressMapper {
return selected_functions_;
}

const std::optional<std::vector<ThunkInfo>> &thunks() const {
return thunks_;
}

// Returns the `bb_handles_` index associated with the binary address
// `address` given a branch from/to this address based on `direction`.
// It returns nullopt if the no `bb_handles_` index can be mapped.
Expand Down Expand Up @@ -186,6 +199,20 @@ class BinaryAddressMapper {
bool CanFallThrough(int function_index, int from_bb_index,
int to_bb_index) const;

// Returns the index of the thunk that contains the given binary address.
// Returns nullopt if no thunk contains the address.
std::optional<int> FindThunkInfoIndexUsingBinaryAddress(
uint64_t address) const;

// Returns the thunk that contains the given binary address. Returns nullopt
// if no thunk contains the address.
std::optional<ThunkInfo> GetThunkInfoUsingBinaryAddress(
uint64_t address) const;

// Sets the targets of thunks in `binary_address_mapper_` to the targets of
// their corresponding branches in `branch_aggregation`.
void UpdateThunkTargets(const BranchAggregation &branch_aggregation);

// Returns the full function's BB address map associated with the given
// `bb_handle`.
const llvm::object::BBAddrMap &GetFunctionEntry(BbHandle bb_handle) const {
Expand Down Expand Up @@ -268,6 +295,9 @@ class BinaryAddressMapper {
// A map from function indices to their symbol info (function names and
// section name).
absl::flat_hash_map<int, FunctionSymbolInfo> symbol_info_map_;

// A vector of thunks in the binary, ordered in increasing order of address.
std::optional<std::vector<ThunkInfo>> thunks_;
};

// Builds a `BinaryAddressMapper` for binary represented by `binary_content` and
Expand Down
31 changes: 31 additions & 0 deletions propeller/binary_content.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <utility>
#include <vector>

#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
Expand Down Expand Up @@ -35,6 +36,7 @@
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
#include "propeller/addr2cu.h"
#include "propeller/status_macros.h"

Expand Down Expand Up @@ -292,6 +294,35 @@ ReadSymbolTable(const BinaryContent &binary_content) {
return symtab;
}

// Read thunks from the symbol table in sorted order.
absl::btree_map<uint64_t, llvm::object::ELFSymbolRef> ReadAArch64ThunkSymbols(
const BinaryContent &binary_content) {
absl::btree_map<uint64_t, llvm::object::ELFSymbolRef> thunk_map;
for (llvm::object::SymbolRef sr : binary_content.object_file->symbols()) {
llvm::object::ELFSymbolRef symbol(sr);
uint8_t stt = symbol.getELFType();
if (stt != llvm::ELF::STT_FUNC) continue;
llvm::Expected<uint64_t> address = sr.getAddress();
if (!address || !*address) continue;
llvm::Expected<llvm::StringRef> func_name = symbol.getName();
// TODO(tzussman): More explicit thunk name check.
if (!func_name || !func_name->starts_with("__AArch64")) continue;
const uint64_t func_size = symbol.getSize();
if (func_size == 0) continue;

thunk_map.insert({*address, sr});
}
return thunk_map;
}

std::optional<absl::btree_map<uint64_t, llvm::object::ELFSymbolRef>>
ReadThunkSymbols(const BinaryContent &binary_content) {
if (binary_content.object_file->getArch() == llvm::Triple::aarch64)
return ReadAArch64ThunkSymbols(binary_content);

return std::nullopt;
}

absl::StatusOr<std::vector<llvm::object::BBAddrMap>> ReadBbAddrMap(
const BinaryContent &binary_content) {
auto *elf_object = llvm::dyn_cast<llvm::object::ELFObjectFileBase>(
Expand Down
9 changes: 9 additions & 0 deletions propeller/binary_content.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <string>
#include <vector>

#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -116,6 +117,14 @@ absl::StatusOr<int64_t> GetSymbolAddress(
absl::flat_hash_map<uint64_t, llvm::SmallVector<llvm::object::ELFSymbolRef>>
ReadSymbolTable(const BinaryContent &binary_content);

// Returns an AArch64 binary's thunk symbols by reading from its symbol table.
absl::btree_map<uint64_t, llvm::object::ELFSymbolRef> ReadAArch64ThunkSymbols(
const BinaryContent &binary_content);

// Returns the binary's thunk symbols by reading from its symbol table.
std::optional<absl::btree_map<uint64_t, llvm::object::ELFSymbolRef>>
ReadThunkSymbols(const BinaryContent &binary_content);

// Returns the binary's `BBAddrMap`s by calling LLVM-side decoding function
// `ELFObjectFileBase::readBBAddrMap`. Returns error if the call fails or if the
// result is empty.
Expand Down
Loading

0 comments on commit bed8f68

Please sign in to comment.