Skip to content

Commit

Permalink
std: optimize TLS on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
joboet committed Oct 8, 2022
1 parent cba4a38 commit d457801
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 117 deletions.
5 changes: 0 additions & 5 deletions library/std/src/sys/sgx/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,3 @@ pub unsafe fn get(key: Key) -> *mut u8 {
pub unsafe fn destroy(key: Key) {
Tls::destroy(AbiKey::from_usize(key))
}

#[inline]
pub fn requires_synchronized_create() -> bool {
false
}
5 changes: 0 additions & 5 deletions library/std/src/sys/solid/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 {
pub unsafe fn destroy(_key: Key) {
panic!("should not be used on the solid target");
}

#[inline]
pub fn requires_synchronized_create() -> bool {
panic!("should not be used on the solid target");
}
5 changes: 0 additions & 5 deletions library/std/src/sys/unix/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,3 @@ pub unsafe fn destroy(key: Key) {
let r = libc::pthread_key_delete(key);
debug_assert_eq!(r, 0);
}

#[inline]
pub fn requires_synchronized_create() -> bool {
false
}
5 changes: 0 additions & 5 deletions library/std/src/sys/unsupported/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 {
pub unsafe fn destroy(_key: Key) {
panic!("should not be used on this target");
}

#[inline]
pub fn requires_synchronized_create() -> bool {
panic!("should not be used on this target");
}
17 changes: 17 additions & 0 deletions library/std/src/sys/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub type BCRYPT_ALG_HANDLE = LPVOID;
pub type PCONDITION_VARIABLE = *mut CONDITION_VARIABLE;
pub type PLARGE_INTEGER = *mut c_longlong;
pub type PSRWLOCK = *mut SRWLOCK;
pub type LPINIT_ONCE = *mut INIT_ONCE;

pub type SOCKET = crate::os::windows::raw::SOCKET;
pub type socklen_t = c_int;
Expand Down Expand Up @@ -194,6 +195,9 @@ pub const DUPLICATE_SAME_ACCESS: DWORD = 0x00000002;

pub const CONDITION_VARIABLE_INIT: CONDITION_VARIABLE = CONDITION_VARIABLE { ptr: ptr::null_mut() };
pub const SRWLOCK_INIT: SRWLOCK = SRWLOCK { ptr: ptr::null_mut() };
pub const INIT_ONCE_STATIC_INIT: INIT_ONCE = INIT_ONCE { ptr: ptr::null_mut() };

pub const INIT_ONCE_INIT_FAILED: DWORD = 0x00000004;

pub const DETACHED_PROCESS: DWORD = 0x00000008;
pub const CREATE_NEW_PROCESS_GROUP: DWORD = 0x00000200;
Expand Down Expand Up @@ -565,6 +569,10 @@ pub struct CONDITION_VARIABLE {
pub struct SRWLOCK {
pub ptr: LPVOID,
}
#[repr(C)]
pub struct INIT_ONCE {
pub ptr: LPVOID,
}

#[repr(C)]
pub struct REPARSE_MOUNTPOINT_DATA_BUFFER {
Expand Down Expand Up @@ -959,6 +967,7 @@ extern "system" {
pub fn TlsAlloc() -> DWORD;
pub fn TlsGetValue(dwTlsIndex: DWORD) -> LPVOID;
pub fn TlsSetValue(dwTlsIndex: DWORD, lpTlsvalue: LPVOID) -> BOOL;
pub fn TlsFree(dwTlsIndex: DWORD) -> BOOL;
pub fn GetLastError() -> DWORD;
pub fn QueryPerformanceFrequency(lpFrequency: *mut LARGE_INTEGER) -> BOOL;
pub fn QueryPerformanceCounter(lpPerformanceCount: *mut LARGE_INTEGER) -> BOOL;
Expand Down Expand Up @@ -1118,6 +1127,14 @@ extern "system" {
pub fn TryAcquireSRWLockExclusive(SRWLock: PSRWLOCK) -> BOOLEAN;
pub fn TryAcquireSRWLockShared(SRWLock: PSRWLOCK) -> BOOLEAN;

pub fn InitOnceBeginInitialize(
lpInitOnce: LPINIT_ONCE,
dwFlags: DWORD,
fPending: LPBOOL,
lpContext: *mut LPVOID,
) -> BOOL;
pub fn InitOnceComplete(lpInitOnce: LPINIT_ONCE, dwFlags: DWORD, lpContext: LPVOID) -> BOOL;

pub fn CompareStringOrdinal(
lpString1: LPCWSTR,
cchCount1: c_int,
Expand Down
196 changes: 123 additions & 73 deletions library/std/src/sys/windows/thread_local_key.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::mem::ManuallyDrop;
use crate::cell::UnsafeCell;
use crate::ptr;
use crate::sync::atomic::AtomicPtr;
use crate::sync::atomic::Ordering::SeqCst;
use crate::sync::atomic::{
AtomicPtr, AtomicU32,
Ordering::{AcqRel, Acquire, Relaxed, Release},
};
use crate::sys::c;

pub type Key = c::DWORD;
pub type Dtor = unsafe extern "C" fn(*mut u8);
#[cfg(test)]
mod tests;

type Key = c::DWORD;
type Dtor = unsafe extern "C" fn(*mut u8);

// Turns out, like pretty much everything, Windows is pretty close the
// functionality that Unix provides, but slightly different! In the case of
Expand All @@ -22,60 +27,109 @@ pub type Dtor = unsafe extern "C" fn(*mut u8);
// To accomplish this feat, we perform a number of threads, all contained
// within this module:
//
// * All TLS destructors are tracked by *us*, not the windows runtime. This
// * All TLS destructors are tracked by *us*, not the Windows runtime. This
// means that we have a global list of destructors for each TLS key that
// we know about.
// * When a thread exits, we run over the entire list and run dtors for all
// non-null keys. This attempts to match Unix semantics in this regard.
//
// This ends up having the overhead of using a global list, having some
// locks here and there, and in general just adding some more code bloat. We
// attempt to optimize runtime by forgetting keys that don't have
// destructors, but this only gets us so far.
//
// For more details and nitty-gritty, see the code sections below!
//
// [1]: https://www.codeproject.com/Articles/8113/Thread-Local-Storage-The-C-Way
// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base
// /threading/thread_local_storage_win.cc#L42
// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base/threading/thread_local_storage_win.cc#L42

// -------------------------------------------------------------------------
// Native bindings
//
// This section is just raw bindings to the native functions that Windows
// provides, There's a few extra calls to deal with destructors.
pub struct StaticKey {
/// The key value shifted up by one. Since TLS_OUT_OF_INDEXES == DWORD::MAX
/// is not a valid key value, this allows us to use zero as sentinel value
/// without risking overflow.
key: AtomicU32,
dtor: Option<Dtor>,
next: AtomicPtr<StaticKey>,
/// Currently, destructors cannot be unregistered, so we cannot use racy
/// initialization for keys. Instead, we need synchronize initialization.
/// Use the Windows-provided `Once` since it does not require TLS.
once: UnsafeCell<c::INIT_ONCE>,
}

#[inline]
pub unsafe fn create(dtor: Option<Dtor>) -> Key {
let key = c::TlsAlloc();
assert!(key != c::TLS_OUT_OF_INDEXES);
if let Some(f) = dtor {
register_dtor(key, f);
impl StaticKey {
#[inline]
pub const fn new(dtor: Option<Dtor>) -> StaticKey {
StaticKey {
key: AtomicU32::new(0),
dtor,
next: AtomicPtr::new(ptr::null_mut()),
once: UnsafeCell::new(c::INIT_ONCE_STATIC_INIT),
}
}
key
}

#[inline]
pub unsafe fn set(key: Key, value: *mut u8) {
let r = c::TlsSetValue(key, value as c::LPVOID);
debug_assert!(r != 0);
}
#[inline]
pub unsafe fn set(&'static self, val: *mut u8) {
let r = c::TlsSetValue(self.key(), val.cast());
debug_assert_eq!(r, c::TRUE);
}

#[inline]
pub unsafe fn get(key: Key) -> *mut u8 {
c::TlsGetValue(key) as *mut u8
}
#[inline]
pub unsafe fn get(&'static self) -> *mut u8 {
c::TlsGetValue(self.key()).cast()
}

#[inline]
pub unsafe fn destroy(_key: Key) {
rtabort!("can't destroy tls keys on windows")
}
#[inline]
unsafe fn key(&'static self) -> Key {
match self.key.load(Acquire) {
0 => self.init(),
key => key - 1,
}
}

#[cold]
unsafe fn init(&'static self) -> Key {
if self.dtor.is_some() {
let mut pending = c::FALSE;
let r = c::InitOnceBeginInitialize(self.once.get(), 0, &mut pending, ptr::null_mut());
assert_eq!(r, c::TRUE);

#[inline]
pub fn requires_synchronized_create() -> bool {
true
if pending == c::FALSE {
// Some other thread initialized the key, load it.
self.key.load(Relaxed) - 1
} else {
let key = c::TlsAlloc();
if key == c::TLS_OUT_OF_INDEXES {
// Wakeup the waiting threads before panicking to avoid deadlock.
c::InitOnceComplete(self.once.get(), c::INIT_ONCE_INIT_FAILED, ptr::null_mut());
panic!("out of TLS indexes");
}

self.key.store(key + 1, Release);
register_dtor(self);

let r = c::InitOnceComplete(self.once.get(), 0, ptr::null_mut());
debug_assert_eq!(r, c::TRUE);

key
}
} else {
// If there is no destructor to clean up, we can use racy initialization.

let key = c::TlsAlloc();
assert_ne!(key, c::TLS_OUT_OF_INDEXES, "out of TLS indexes");

match self.key.compare_exchange(0, key + 1, AcqRel, Acquire) {
Ok(_) => key,
Err(new) => {
// Some other thread completed initialization first, so destroy
// our key and use theirs.
let r = c::TlsFree(key);
debug_assert_eq!(r, c::TRUE);
new - 1
}
}
}
}
}

unsafe impl Send for StaticKey {}
unsafe impl Sync for StaticKey {}

// -------------------------------------------------------------------------
// Dtor registration
//
Expand All @@ -96,29 +150,21 @@ pub fn requires_synchronized_create() -> bool {
// Typically processes have a statically known set of TLS keys which is pretty
// small, and we'd want to keep this memory alive for the whole process anyway
// really.
//
// Perhaps one day we can fold the `Box` here into a static allocation,
// expanding the `StaticKey` structure to contain not only a slot for the TLS
// key but also a slot for the destructor queue on windows. An optimization for
// another day!

static DTORS: AtomicPtr<Node> = AtomicPtr::new(ptr::null_mut());

struct Node {
dtor: Dtor,
key: Key,
next: *mut Node,
}

unsafe fn register_dtor(key: Key, dtor: Dtor) {
let mut node = ManuallyDrop::new(Box::new(Node { key, dtor, next: ptr::null_mut() }));
static DTORS: AtomicPtr<StaticKey> = AtomicPtr::new(ptr::null_mut());

let mut head = DTORS.load(SeqCst);
/// Should only be called once per key, otherwise loops or breaks may occur in
/// the linked list.
unsafe fn register_dtor(key: &'static StaticKey) {
let this = <*const StaticKey>::cast_mut(key);
// Use acquire ordering to pass along the changes done by the previously
// registered keys when we store the new head with release ordering.
let mut head = DTORS.load(Acquire);
loop {
node.next = head;
match DTORS.compare_exchange(head, &mut **node, SeqCst, SeqCst) {
Ok(_) => return, // nothing to drop, we successfully added the node to the list
Err(cur) => head = cur,
key.next.store(head, Relaxed);
match DTORS.compare_exchange_weak(head, this, Release, Acquire) {
Ok(_) => break,
Err(new) => head = new,
}
}
}
Expand Down Expand Up @@ -214,25 +260,29 @@ unsafe extern "system" fn on_tls_callback(h: c::LPVOID, dwReason: c::DWORD, pv:
unsafe fn reference_tls_used() {}
}

#[allow(dead_code)] // actually called above
#[allow(dead_code)] // actually called below
unsafe fn run_dtors() {
let mut any_run = true;
for _ in 0..5 {
if !any_run {
break;
}
any_run = false;
let mut cur = DTORS.load(SeqCst);
let mut any_run = false;

// Use acquire ordering to observe key initialization.
let mut cur = DTORS.load(Acquire);
while !cur.is_null() {
let ptr = c::TlsGetValue((*cur).key);
let key = (*cur).key.load(Relaxed) - 1;
let dtor = (*cur).dtor.unwrap();

let ptr = c::TlsGetValue(key);
if !ptr.is_null() {
c::TlsSetValue((*cur).key, ptr::null_mut());
((*cur).dtor)(ptr as *mut _);
c::TlsSetValue(key, ptr::null_mut());
dtor(ptr as *mut _);
any_run = true;
}

cur = (*cur).next;
cur = (*cur).next.load(Relaxed);
}

if !any_run {
break;
}
}
}
53 changes: 53 additions & 0 deletions library/std/src/sys/windows/thread_local_key/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use super::StaticKey;
use crate::ptr;

#[test]
fn smoke() {
static K1: StaticKey = StaticKey::new(None);
static K2: StaticKey = StaticKey::new(None);

unsafe {
assert!(K1.get().is_null());
assert!(K2.get().is_null());
K1.set(ptr::invalid_mut(1));
K2.set(ptr::invalid_mut(2));
assert_eq!(K1.get() as usize, 1);
assert_eq!(K2.get() as usize, 2);
}
}

#[test]
fn destructors() {
use crate::mem::ManuallyDrop;
use crate::sync::Arc;
use crate::thread;

unsafe extern "C" fn destruct(ptr: *mut u8) {
drop(Arc::from_raw(ptr as *const ()));
}

static KEY: StaticKey = StaticKey::new(Some(destruct));

let shared1 = Arc::new(());
let shared2 = Arc::clone(&shared1);

unsafe {
assert!(KEY.get().is_null());
KEY.set(Arc::into_raw(shared1) as *mut u8);
}

thread::spawn(move || unsafe {
assert!(KEY.get().is_null());
KEY.set(Arc::into_raw(shared2) as *mut u8);
})
.join()
.unwrap();

// Leak the Arc, let the TLS destructor clean it up.
let shared1 = unsafe { ManuallyDrop::new(Arc::from_raw(KEY.get() as *const ())) };
assert_eq!(
Arc::strong_count(&shared1),
1,
"destructor should have dropped the other reference on thread exit"
);
}
Loading

0 comments on commit d457801

Please sign in to comment.