Skip to content

Commit

Permalink
Auto merge of rust-lang#102655 - joboet:windows_tls_opt, r=ChrisDenton
Browse files Browse the repository at this point in the history
Optimize TLS on Windows

This implements the suggestion in the current TLS code to embed the linked list of destructors in the `StaticKey` structure to save allocations. Additionally, locking is avoided when no destructor needs to be run. By using one Windows-provided `Once` per key instead of a global lock, locking is more finely-grained (this unblocks rust-lang#100579).
  • Loading branch information
bors committed Oct 13, 2022
2 parents 3cf5fc5 + d457801 commit fa0ca78
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 117 deletions.
5 changes: 0 additions & 5 deletions library/std/src/sys/sgx/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,3 @@ pub unsafe fn get(key: Key) -> *mut u8 {
pub unsafe fn destroy(key: Key) {
Tls::destroy(AbiKey::from_usize(key))
}

/// Always returns `false` on this target.
// NOTE(review): presumably tells the caller whether TLS key creation has to
// be serialized externally — confirm against the call site.
#[inline]
pub fn requires_synchronized_create() -> bool {
false
}
5 changes: 0 additions & 5 deletions library/std/src/sys/solid/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 {
pub unsafe fn destroy(_key: Key) {
panic!("should not be used on the solid target");
}

/// Placeholder that must never be called on this target.
///
/// # Panics
///
/// Always panics with the message shown in the body.
#[inline]
pub fn requires_synchronized_create() -> bool {
panic!("should not be used on the solid target");
}
5 changes: 0 additions & 5 deletions library/std/src/sys/unix/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,3 @@ pub unsafe fn destroy(key: Key) {
let r = libc::pthread_key_delete(key);
debug_assert_eq!(r, 0);
}

/// Always returns `false` on this target.
// NOTE(review): presumably tells the caller whether TLS key creation has to
// be serialized externally — confirm against the call site.
#[inline]
pub fn requires_synchronized_create() -> bool {
false
}
5 changes: 0 additions & 5 deletions library/std/src/sys/unsupported/thread_local_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,3 @@ pub unsafe fn get(_key: Key) -> *mut u8 {
pub unsafe fn destroy(_key: Key) {
panic!("should not be used on this target");
}

/// Placeholder that must never be called on this target.
///
/// # Panics
///
/// Always panics with the message shown in the body.
#[inline]
pub fn requires_synchronized_create() -> bool {
panic!("should not be used on this target");
}
17 changes: 17 additions & 0 deletions library/std/src/sys/windows/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub type BCRYPT_ALG_HANDLE = LPVOID;
pub type PCONDITION_VARIABLE = *mut CONDITION_VARIABLE;
pub type PLARGE_INTEGER = *mut c_longlong;
pub type PSRWLOCK = *mut SRWLOCK;
pub type LPINIT_ONCE = *mut INIT_ONCE;

pub type SOCKET = crate::os::windows::raw::SOCKET;
pub type socklen_t = c_int;
Expand Down Expand Up @@ -194,6 +195,9 @@ pub const DUPLICATE_SAME_ACCESS: DWORD = 0x00000002;

pub const CONDITION_VARIABLE_INIT: CONDITION_VARIABLE = CONDITION_VARIABLE { ptr: ptr::null_mut() };
pub const SRWLOCK_INIT: SRWLOCK = SRWLOCK { ptr: ptr::null_mut() };
pub const INIT_ONCE_STATIC_INIT: INIT_ONCE = INIT_ONCE { ptr: ptr::null_mut() };

pub const INIT_ONCE_INIT_FAILED: DWORD = 0x00000004;

pub const DETACHED_PROCESS: DWORD = 0x00000008;
pub const CREATE_NEW_PROCESS_GROUP: DWORD = 0x00000200;
Expand Down Expand Up @@ -565,6 +569,10 @@ pub struct CONDITION_VARIABLE {
pub struct SRWLOCK {
pub ptr: LPVOID,
}
/// Binding for the Windows `INIT_ONCE` one-time initialization structure,
/// declared here as a single `LPVOID` field.
#[repr(C)]
pub struct INIT_ONCE {
pub ptr: LPVOID,
}

#[repr(C)]
pub struct REPARSE_MOUNTPOINT_DATA_BUFFER {
Expand Down Expand Up @@ -955,6 +963,7 @@ extern "system" {
pub fn TlsAlloc() -> DWORD;
pub fn TlsGetValue(dwTlsIndex: DWORD) -> LPVOID;
pub fn TlsSetValue(dwTlsIndex: DWORD, lpTlsvalue: LPVOID) -> BOOL;
pub fn TlsFree(dwTlsIndex: DWORD) -> BOOL;
pub fn GetLastError() -> DWORD;
pub fn QueryPerformanceFrequency(lpFrequency: *mut LARGE_INTEGER) -> BOOL;
pub fn QueryPerformanceCounter(lpPerformanceCount: *mut LARGE_INTEGER) -> BOOL;
Expand Down Expand Up @@ -1114,6 +1123,14 @@ extern "system" {
pub fn TryAcquireSRWLockExclusive(SRWLock: PSRWLOCK) -> BOOLEAN;
pub fn TryAcquireSRWLockShared(SRWLock: PSRWLOCK) -> BOOLEAN;

// One-time initialization (`INIT_ONCE`) entry points.
// `fPending` is an out-parameter: the TLS key code in `thread_local_key.rs`
// treats FALSE as "another thread already initialized" and anything else as
// "this caller must initialize and then call `InitOnceComplete`".
pub fn InitOnceBeginInitialize(
lpInitOnce: LPINIT_ONCE,
dwFlags: DWORD,
fPending: LPBOOL,
lpContext: *mut LPVOID,
) -> BOOL;
// Finishes an initialization begun with `InitOnceBeginInitialize`; callers
// pass `INIT_ONCE_INIT_FAILED` in `dwFlags` to report a failed attempt.
pub fn InitOnceComplete(lpInitOnce: LPINIT_ONCE, dwFlags: DWORD, lpContext: LPVOID) -> BOOL;

pub fn CompareStringOrdinal(
lpString1: LPCWSTR,
cchCount1: c_int,
Expand Down
196 changes: 123 additions & 73 deletions library/std/src/sys/windows/thread_local_key.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::mem::ManuallyDrop;
use crate::cell::UnsafeCell;
use crate::ptr;
use crate::sync::atomic::AtomicPtr;
use crate::sync::atomic::Ordering::SeqCst;
use crate::sync::atomic::{
AtomicPtr, AtomicU32,
Ordering::{AcqRel, Acquire, Relaxed, Release},
};
use crate::sys::c;

pub type Key = c::DWORD;
pub type Dtor = unsafe extern "C" fn(*mut u8);
#[cfg(test)]
mod tests;

type Key = c::DWORD;
type Dtor = unsafe extern "C" fn(*mut u8);

// Turns out, like pretty much everything, Windows is pretty close to the
// functionality that Unix provides, but slightly different! In the case of
Expand All @@ -22,60 +27,109 @@ pub type Dtor = unsafe extern "C" fn(*mut u8);
// To accomplish this feat, we perform a number of tricks, all contained
// within this module:
//
// * All TLS destructors are tracked by *us*, not the windows runtime. This
// * All TLS destructors are tracked by *us*, not the Windows runtime. This
// means that we have a global list of destructors for each TLS key that
// we know about.
// * When a thread exits, we run over the entire list and run dtors for all
// non-null keys. This attempts to match Unix semantics in this regard.
//
// This ends up having the overhead of using a global list, having some
// locks here and there, and in general just adding some more code bloat. We
// attempt to optimize runtime by forgetting keys that don't have
// destructors, but this only gets us so far.
//
// For more details and nitty-gritty, see the code sections below!
//
// [1]: https://www.codeproject.com/Articles/8113/Thread-Local-Storage-The-C-Way
// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base
// /threading/thread_local_storage_win.cc#L42
// [2]: https://github.com/ChromiumWebApps/chromium/blob/master/base/threading/thread_local_storage_win.cc#L42

// -------------------------------------------------------------------------
// Native bindings
//
// This section is just raw bindings to the native functions that Windows
// provides. There are a few extra calls to deal with destructors.
/// A Windows TLS key that can live in a `static` and is allocated lazily on
/// first use.
///
/// Keys created with a destructor are additionally linked into the global
/// `DTORS` list when they are initialized.
pub struct StaticKey {
/// The key value shifted up by one. Since TLS_OUT_OF_INDEXES == DWORD::MAX
/// is not a valid key value, this allows us to use zero as the sentinel value
/// for "not yet initialized" without risking overflow.
key: AtomicU32,
/// Destructor for the values stored under this key, if any.
dtor: Option<Dtor>,
/// Next entry in the global `DTORS` list; starts out null and is set when the
/// key's destructor is registered.
next: AtomicPtr<StaticKey>,
/// Currently, destructors cannot be unregistered, so we cannot use racy
/// initialization for keys. Instead, we need to synchronize initialization.
/// Use the Windows-provided `Once` since it does not require TLS.
once: UnsafeCell<c::INIT_ONCE>,
}

#[inline]
pub unsafe fn create(dtor: Option<Dtor>) -> Key {
let key = c::TlsAlloc();
assert!(key != c::TLS_OUT_OF_INDEXES);
if let Some(f) = dtor {
register_dtor(key, f);
impl StaticKey {
#[inline]
pub const fn new(dtor: Option<Dtor>) -> StaticKey {
StaticKey {
key: AtomicU32::new(0),
dtor,
next: AtomicPtr::new(ptr::null_mut()),
once: UnsafeCell::new(c::INIT_ONCE_STATIC_INIT),
}
}
key
}

#[inline]
pub unsafe fn set(key: Key, value: *mut u8) {
let r = c::TlsSetValue(key, value as c::LPVOID);
debug_assert!(r != 0);
}
#[inline]
pub unsafe fn set(&'static self, val: *mut u8) {
let r = c::TlsSetValue(self.key(), val.cast());
debug_assert_eq!(r, c::TRUE);
}

#[inline]
pub unsafe fn get(key: Key) -> *mut u8 {
c::TlsGetValue(key) as *mut u8
}
#[inline]
pub unsafe fn get(&'static self) -> *mut u8 {
c::TlsGetValue(self.key()).cast()
}

#[inline]
pub unsafe fn destroy(_key: Key) {
rtabort!("can't destroy tls keys on windows")
}
#[inline]
unsafe fn key(&'static self) -> Key {
match self.key.load(Acquire) {
0 => self.init(),
key => key - 1,
}
}

#[cold]
unsafe fn init(&'static self) -> Key {
if self.dtor.is_some() {
let mut pending = c::FALSE;
let r = c::InitOnceBeginInitialize(self.once.get(), 0, &mut pending, ptr::null_mut());
assert_eq!(r, c::TRUE);

#[inline]
pub fn requires_synchronized_create() -> bool {
true
if pending == c::FALSE {
// Some other thread initialized the key, load it.
self.key.load(Relaxed) - 1
} else {
let key = c::TlsAlloc();
if key == c::TLS_OUT_OF_INDEXES {
// Wakeup the waiting threads before panicking to avoid deadlock.
c::InitOnceComplete(self.once.get(), c::INIT_ONCE_INIT_FAILED, ptr::null_mut());
panic!("out of TLS indexes");
}

self.key.store(key + 1, Release);
register_dtor(self);

let r = c::InitOnceComplete(self.once.get(), 0, ptr::null_mut());
debug_assert_eq!(r, c::TRUE);

key
}
} else {
// If there is no destructor to clean up, we can use racy initialization.

let key = c::TlsAlloc();
assert_ne!(key, c::TLS_OUT_OF_INDEXES, "out of TLS indexes");

match self.key.compare_exchange(0, key + 1, AcqRel, Acquire) {
Ok(_) => key,
Err(new) => {
// Some other thread completed initialization first, so destroy
// our key and use theirs.
let r = c::TlsFree(key);
debug_assert_eq!(r, c::TRUE);
new - 1
}
}
}
}
}

unsafe impl Send for StaticKey {}
unsafe impl Sync for StaticKey {}

// -------------------------------------------------------------------------
// Dtor registration
//
Expand All @@ -96,29 +150,21 @@ pub fn requires_synchronized_create() -> bool {
// Typically processes have a statically known set of TLS keys which is pretty
// small, and we'd want to keep this memory alive for the whole process anyway
// really.
//
// Perhaps one day we can fold the `Box` here into a static allocation,
// expanding the `StaticKey` structure to contain not only a slot for the TLS
// key but also a slot for the destructor queue on windows. An optimization for
// another day!

static DTORS: AtomicPtr<Node> = AtomicPtr::new(ptr::null_mut());

struct Node {
dtor: Dtor,
key: Key,
next: *mut Node,
}

unsafe fn register_dtor(key: Key, dtor: Dtor) {
let mut node = ManuallyDrop::new(Box::new(Node { key, dtor, next: ptr::null_mut() }));
static DTORS: AtomicPtr<StaticKey> = AtomicPtr::new(ptr::null_mut());

let mut head = DTORS.load(SeqCst);
/// Should only be called once per key, otherwise loops or breaks may occur in
/// the linked list.
unsafe fn register_dtor(key: &'static StaticKey) {
let this = <*const StaticKey>::cast_mut(key);
// Use acquire ordering to pass along the changes done by the previously
// registered keys when we store the new head with release ordering.
let mut head = DTORS.load(Acquire);
loop {
node.next = head;
match DTORS.compare_exchange(head, &mut **node, SeqCst, SeqCst) {
Ok(_) => return, // nothing to drop, we successfully added the node to the list
Err(cur) => head = cur,
key.next.store(head, Relaxed);
match DTORS.compare_exchange_weak(head, this, Release, Acquire) {
Ok(_) => break,
Err(new) => head = new,
}
}
}
Expand Down Expand Up @@ -214,25 +260,29 @@ unsafe extern "system" fn on_tls_callback(h: c::LPVOID, dwReason: c::DWORD, pv:
unsafe fn reference_tls_used() {}
}

#[allow(dead_code)] // actually called above
#[allow(dead_code)] // actually called below
unsafe fn run_dtors() {
let mut any_run = true;
for _ in 0..5 {
if !any_run {
break;
}
any_run = false;
let mut cur = DTORS.load(SeqCst);
let mut any_run = false;

// Use acquire ordering to observe key initialization.
let mut cur = DTORS.load(Acquire);
while !cur.is_null() {
let ptr = c::TlsGetValue((*cur).key);
let key = (*cur).key.load(Relaxed) - 1;
let dtor = (*cur).dtor.unwrap();

let ptr = c::TlsGetValue(key);
if !ptr.is_null() {
c::TlsSetValue((*cur).key, ptr::null_mut());
((*cur).dtor)(ptr as *mut _);
c::TlsSetValue(key, ptr::null_mut());
dtor(ptr as *mut _);
any_run = true;
}

cur = (*cur).next;
cur = (*cur).next.load(Relaxed);
}

if !any_run {
break;
}
}
}
53 changes: 53 additions & 0 deletions library/std/src/sys/windows/thread_local_key/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use super::StaticKey;
use crate::ptr;

/// Two destructor-less keys start out null and then hold independent values.
#[test]
fn smoke() {
    static FIRST: StaticKey = StaticKey::new(None);
    static SECOND: StaticKey = StaticKey::new(None);

    unsafe {
        // Nothing has been stored yet, so both slots must read back as null.
        let initial = FIRST.get();
        assert!(initial.is_null());
        let initial = SECOND.get();
        assert!(initial.is_null());

        // Store a distinct marker address in each slot.
        FIRST.set(ptr::invalid_mut(1));
        SECOND.set(ptr::invalid_mut(2));

        // Each key must return its own marker, not the other one's.
        let first_addr = FIRST.get() as usize;
        assert_eq!(first_addr, 1);
        let second_addr = SECOND.get() as usize;
        assert_eq!(second_addr, 2);
    }
}

/// A key's destructor must run for the value stored by a thread once that
/// thread has exited.
#[test]
fn destructors() {
    use crate::mem::ManuallyDrop;
    use crate::sync::Arc;
    use crate::thread;

    // Reclaims and drops the `Arc` reference that was leaked into the slot.
    unsafe extern "C" fn release(raw: *mut u8) {
        let reclaimed = Arc::from_raw(raw as *const ());
        drop(reclaimed);
    }

    static KEY: StaticKey = StaticKey::new(Some(release));

    let main_ref = Arc::new(());
    let worker_ref = Arc::clone(&main_ref);

    // Park one reference in the main thread's slot.
    unsafe {
        assert!(KEY.get().is_null());
        let raw = Arc::into_raw(main_ref) as *mut u8;
        KEY.set(raw);
    }

    // Park the other reference in a short-lived thread's slot and wait for
    // that thread to finish.
    let worker = thread::spawn(move || unsafe {
        assert!(KEY.get().is_null());
        KEY.set(Arc::into_raw(worker_ref) as *mut u8);
    });
    worker.join().unwrap();

    // Leak the Arc, let the TLS destructor clean it up.
    let main_ref = unsafe {
        let raw = KEY.get() as *const ();
        ManuallyDrop::new(Arc::from_raw(raw))
    };
    let live_refs = Arc::strong_count(&main_ref);
    assert_eq!(
        live_refs,
        1,
        "destructor should have dropped the other reference on thread exit"
    );
}
Loading

0 comments on commit fa0ca78

Please sign in to comment.