Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 97 additions & 104 deletions cranelift/jit/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::ffi::CString;
use std::io::Write;
use std::ptr;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicPtr, Ordering};
use target_lexicon::PointerWidth;
#[cfg(windows)]
use winapi;
Expand Down Expand Up @@ -129,6 +130,15 @@ impl JITBuilder {
}
}

/// A pending update to the GOT.
struct GotUpdate {
/// The entry that is to be updated.
entry: NonNull<AtomicPtr<u8>>,

/// The new value of the entry.
ptr: *const u8,
}

/// A `JITModule` implements `Module` and emits code and data into memory where it can be
/// directly called and accessed.
///
Expand All @@ -140,15 +150,18 @@ pub struct JITModule {
libcall_names: Box<dyn Fn(ir::LibCall) -> String>,
memory: MemoryHandle,
declarations: ModuleDeclarations,
function_got_entries: SecondaryMap<FuncId, Option<NonNull<*const u8>>>,
function_got_entries: SecondaryMap<FuncId, Option<NonNull<AtomicPtr<u8>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<NonNull<[u8; 16]>>>,
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<*const u8>>>,
libcall_got_entries: HashMap<ir::LibCall, NonNull<*const u8>>,
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<AtomicPtr<u8>>>>,
libcall_got_entries: HashMap<ir::LibCall, NonNull<AtomicPtr<u8>>>,
libcall_plt_entries: HashMap<ir::LibCall, NonNull<[u8; 16]>>,
compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
functions_to_finalize: Vec<FuncId>,
data_objects_to_finalize: Vec<DataId>,

/// Updates to the GOT awaiting relocations to be made and region protections to be set
pending_got_updates: Vec<GotUpdate>,
}

/// A handle to allow freeing memory allocated by the `Module`.
Expand Down Expand Up @@ -180,54 +193,53 @@ impl JITModule {
.or_else(|| lookup_with_dlsym(name))
}

fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
fn new_got_entry(&mut self, val: *const u8) -> NonNull<AtomicPtr<u8>> {
let got_entry = self
.memory
.writable
.allocate(
std::mem::size_of::<*const u8>(),
std::mem::align_of::<*const u8>().try_into().unwrap(),
std::mem::size_of::<AtomicPtr<u8>>(),
std::mem::align_of::<AtomicPtr<u8>>().try_into().unwrap(),
)
.unwrap()
.cast::<*const u8>();
self.function_got_entries[id] = Some(NonNull::new(got_entry).unwrap());
.cast::<AtomicPtr<u8>>();
unsafe {
std::ptr::write(got_entry, val);
std::ptr::write(got_entry, AtomicPtr::new(val as *mut _));
}
NonNull::new(got_entry).unwrap()
}

fn new_plt_entry(&mut self, got_entry: NonNull<AtomicPtr<u8>>) -> NonNull<[u8; 16]> {
let plt_entry = self
.memory
.code
.allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT)
.unwrap()
.cast::<[u8; 16]>();
unsafe {
Self::write_plt_entry_bytes(plt_entry, got_entry);
}
NonNull::new(plt_entry).unwrap()
}

fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.function_got_entries[id] = Some(got_entry);
let plt_entry = self.new_plt_entry(got_entry);
self.record_function_for_perf(
plt_entry as *mut _,
plt_entry.as_ptr().cast(),
std::mem::size_of::<[u8; 16]>(),
&format!("{}@plt", self.declarations.get_function_decl(id).name),
);
self.function_plt_entries[id] = Some(NonNull::new(plt_entry).unwrap());
unsafe {
Self::write_plt_entry_bytes(plt_entry, got_entry);
}
self.function_plt_entries[id] = Some(plt_entry);
}

fn new_data_got_entry(&mut self, id: DataId, val: *const u8) {
let got_entry = self
.memory
.writable
.allocate(
std::mem::size_of::<*const u8>(),
std::mem::align_of::<*const u8>().try_into().unwrap(),
)
.unwrap()
.cast::<*const u8>();
self.data_object_got_entries[id] = Some(NonNull::new(got_entry).unwrap());
unsafe {
std::ptr::write(got_entry, val);
}
let got_entry = self.new_got_entry(val);
self.data_object_got_entries[id] = Some(got_entry);
}

unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: *mut *const u8) {
unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull<AtomicPtr<u8>>) {
assert!(
cfg!(target_arch = "x86_64"),
"PLT is currently only supported on x86_64"
Expand All @@ -236,7 +248,7 @@ impl JITModule {
let mut plt_val = [
0xff, 0x25, 0, 0, 0, 0, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b,
];
let what = got_ptr as isize - 4;
let what = got_ptr.as_ptr() as isize - 4;
let at = plt_ptr as isize + 2;
plt_val[2..6].copy_from_slice(&i32::to_ne_bytes(i32::try_from(what - at).unwrap()));
std::ptr::write(plt_ptr, plt_val);
Expand Down Expand Up @@ -289,32 +301,25 @@ impl JITModule {
///
/// Panics if there's no entry in the table for the given function.
pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 {
unsafe { *self.function_got_entries[func_id].unwrap().as_ptr() }
let got_entry = self.function_got_entries[func_id].unwrap();
unsafe { got_entry.as_ref() }.load(Ordering::SeqCst)
}

fn get_got_address(&self, name: &ir::ExternalName) -> *const u8 {
fn get_got_address(&self, name: &ir::ExternalName) -> NonNull<AtomicPtr<u8>> {
match *name {
ir::ExternalName::User { .. } => {
if ModuleDeclarations::is_function(name) {
let func_id = FuncId::from_name(name);
self.function_got_entries[func_id]
.unwrap()
.as_ptr()
.cast::<u8>()
self.function_got_entries[func_id].unwrap()
} else {
let data_id = DataId::from_name(name);
self.data_object_got_entries[data_id]
.unwrap()
.as_ptr()
.cast::<u8>()
self.data_object_got_entries[data_id].unwrap()
}
}
ir::ExternalName::LibCall(ref libcall) => self
ir::ExternalName::LibCall(ref libcall) => *self
.libcall_got_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.as_ptr()
.cast::<u8>(),
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)),
_ => panic!("invalid ExternalName {}", name),
}
}
Expand Down Expand Up @@ -406,7 +411,7 @@ impl JITModule {
.expect("function must be compiled before it can be finalized");
func.perform_relocations(
|name| self.get_address(name),
|name| self.get_got_address(name),
|name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name),
);
}
Expand All @@ -419,14 +424,18 @@ impl JITModule {
.expect("data object must be compiled before it can be finalized");
data.perform_relocations(
|name| self.get_address(name),
|name| self.get_got_address(name),
|name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name),
);
}

// Now that we're done patching, prepare the memory for execution!
self.memory.readonly.set_readonly();
self.memory.code.set_readable_and_executable();

for update in self.pending_got_updates.drain(..) {
unsafe { update.entry.as_ref() }.store(update.ptr as *mut _, Ordering::SeqCst);
}
}

/// Create a new `JITModule`.
Expand All @@ -438,33 +447,38 @@ impl JITModule {
);
}

let mut memory = MemoryHandle {
code: Memory::new(),
readonly: Memory::new(),
writable: Memory::new(),
let mut module = Self {
isa: builder.isa,
hotswap_enabled: builder.hotswap_enabled,
symbols: builder.symbols,
libcall_names: builder.libcall_names,
memory: MemoryHandle {
code: Memory::new(),
readonly: Memory::new(),
writable: Memory::new(),
},
declarations: ModuleDeclarations::default(),
function_got_entries: SecondaryMap::new(),
function_plt_entries: SecondaryMap::new(),
data_object_got_entries: SecondaryMap::new(),
libcall_got_entries: HashMap::new(),
libcall_plt_entries: HashMap::new(),
compiled_functions: SecondaryMap::new(),
compiled_data_objects: SecondaryMap::new(),
functions_to_finalize: Vec::new(),
data_objects_to_finalize: Vec::new(),
pending_got_updates: Vec::new(),
};

let mut libcall_got_entries = HashMap::new();
let mut libcall_plt_entries = HashMap::new();

// Pre-create a GOT and PLT entry for each libcall.
let all_libcalls = if builder.isa.flags().is_pic() {
let all_libcalls = if module.isa.flags().is_pic() {
ir::LibCall::all_libcalls()
} else {
&[] // Not PIC, so no GOT and PLT entries necessary
};
for &libcall in all_libcalls {
let got_entry = memory
.writable
.allocate(
std::mem::size_of::<*const u8>(),
std::mem::align_of::<*const u8>().try_into().unwrap(),
)
.unwrap()
.cast::<*const u8>();
libcall_got_entries.insert(libcall, NonNull::new(got_entry).unwrap());
let sym = (builder.libcall_names)(libcall);
let addr = if let Some(addr) = builder
let sym = (module.libcall_names)(libcall);
let addr = if let Some(addr) = module
.symbols
.get(&sym)
.copied()
Expand All @@ -474,37 +488,13 @@ impl JITModule {
} else {
continue;
};
unsafe {
std::ptr::write(got_entry, addr);
}
let plt_entry = memory
.code
.allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT)
.unwrap()
.cast::<[u8; 16]>();
libcall_plt_entries.insert(libcall, NonNull::new(plt_entry).unwrap());
unsafe {
Self::write_plt_entry_bytes(plt_entry, got_entry);
}
let got_entry = module.new_got_entry(addr);
module.libcall_got_entries.insert(libcall, got_entry);
let plt_entry = module.new_plt_entry(got_entry);
module.libcall_plt_entries.insert(libcall, plt_entry);
}

Self {
isa: builder.isa,
hotswap_enabled: builder.hotswap_enabled,
symbols: builder.symbols,
libcall_names: builder.libcall_names,
memory,
declarations: ModuleDeclarations::default(),
function_got_entries: SecondaryMap::new(),
function_plt_entries: SecondaryMap::new(),
data_object_got_entries: SecondaryMap::new(),
libcall_got_entries,
libcall_plt_entries,
compiled_functions: SecondaryMap::new(),
compiled_data_objects: SecondaryMap::new(),
functions_to_finalize: Vec::new(),
data_objects_to_finalize: Vec::new(),
}
module
}

/// Allow a single future `define_function` on a previously defined function. This allows for
Expand Down Expand Up @@ -682,9 +672,10 @@ impl Module for JITModule {
});

if self.isa.flags().is_pic() {
unsafe {
std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr);
}
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
ptr,
})
}

if self.hotswap_enabled {
Expand All @@ -704,7 +695,7 @@ impl Module for JITModule {
.cast::<u8>(),
_ => panic!("invalid ExternalName {}", name),
},
|name| self.get_got_address(name),
|name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name),
);
} else {
Expand Down Expand Up @@ -754,9 +745,10 @@ impl Module for JITModule {
});

if self.isa.flags().is_pic() {
unsafe {
std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr);
}
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
ptr,
})
}

if self.hotswap_enabled {
Expand All @@ -765,7 +757,7 @@ impl Module for JITModule {
.unwrap()
.perform_relocations(
|name| unreachable!("non GOT or PLT relocation in function {} to {}", id, name),
|name| self.get_got_address(name),
|name| self.get_got_address(name).as_ptr().cast(),
|name| self.get_plt_address(name),
);
} else {
Expand Down Expand Up @@ -836,9 +828,10 @@ impl Module for JITModule {
self.compiled_data_objects[id] = Some(CompiledBlob { ptr, size, relocs });
self.data_objects_to_finalize.push(id);
if self.isa.flags().is_pic() {
unsafe {
std::ptr::write(self.data_object_got_entries[id].unwrap().as_ptr(), ptr);
}
self.pending_got_updates.push(GotUpdate {
entry: self.data_object_got_entries[id].unwrap(),
ptr,
})
}

Ok(())
Expand Down