Skip to content
Merged
54 changes: 54 additions & 0 deletions src/coreclr/nativeaot/Runtime/thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,23 @@ bool Thread::CheckPendingRedirect(PCODE eip)

#endif // TARGET_X86

void Thread::SetInterrupted(bool isInterrupted)
{
if (isInterrupted)
{
SetState(TSF_Interrupted);
}
else
{
ClearState(TSF_Interrupted);
}
}

bool Thread::CheckInterrupted()
{
return IsStateSet(TSF_Interrupted);
}

#endif // !DACCESS_COMPILE

void Thread::ValidateExInfoStack()
Expand Down Expand Up @@ -1323,6 +1340,43 @@ FCIMPL0(size_t, RhGetDefaultStackSize)
}
FCIMPLEND

#ifdef TARGET_WINDOWS
// Native APC callback for Thread.Interrupt
// This callback sets the interrupt flag on the current thread
static VOID CALLBACK InterruptApcCallback(ULONG_PTR /* parameter */)
{
Thread* pCurrentThread = ThreadStore::RawGetCurrentThread();
if (!pCurrentThread->IsInitialized())
{
// If the thread was interrupted before it was started
// the thread won't have been initialized.
// Attach the thread here if it's the first time we're seeing it.
ThreadStore::AttachCurrentThread();
}

pCurrentThread->SetInterrupted(true);
}

// Function to get the address of the interrupt APC callback
FCIMPL0(void*, RhGetInterruptApcCallback)
{
return (void*)InterruptApcCallback;
}
FCIMPLEND

FCIMPL0(FC_BOOL_RET, RhCheckAndClearPendingInterrupt)
{
Thread* pCurrentThread = ThreadStore::RawGetCurrentThread();
if (pCurrentThread->CheckInterrupted())
{
pCurrentThread->SetInterrupted(false);
FC_RETURN_BOOL(true);
}
FC_RETURN_BOOL(false);
}
FCIMPLEND
#endif // TARGET_WINDOWS

// Standard calling convention variant and actual implementation for RhpReversePInvokeAttachOrTrapThread
EXTERN_C NOINLINE void FASTCALL RhpReversePInvokeAttachOrTrapThread2(ReversePInvokeFrame* pFrame)
{
Expand Down
6 changes: 5 additions & 1 deletion src/coreclr/nativeaot/Runtime/thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct ee_alloc_context

struct RuntimeThreadLocals
{
ee_alloc_context m_eeAllocContext;
ee_alloc_context m_eeAllocContext;
uint32_t volatile m_ThreadStateFlags; // see Thread::ThreadStateFlags enum
PInvokeTransitionFrame* m_pTransitionFrame;
PInvokeTransitionFrame* m_pDeferredTransitionFrame; // see Thread::EnablePreemptiveMode
Expand Down Expand Up @@ -214,6 +214,7 @@ class Thread : private RuntimeThreadLocals
//
// On Unix this is an optimization to not queue up more signals when one is
// still being processed.
TSF_Interrupted = 0x00000200, // Set to indicate Thread.Interrupt() has been called on this thread
};
private:

Expand Down Expand Up @@ -390,6 +391,9 @@ class Thread : private RuntimeThreadLocals
void SetPendingRedirect(PCODE eip);
bool CheckPendingRedirect(PCODE eip);
#endif

void SetInterrupted(bool isInterrupted);
bool CheckInterrupted();
};

#ifndef __GCENV_BASE_INCLUDED__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,14 @@ internal static IntPtr RhGetModuleSection(TypeManagerHandle module, ReadyToRunSe
[RuntimeImport(RuntimeLibrary, "RhGetDefaultStackSize")]
internal static extern unsafe IntPtr RhGetDefaultStackSize();

[MethodImplAttribute(MethodImplOptions.InternalCall)]
[RuntimeImport(RuntimeLibrary, "RhGetInterruptApcCallback")]
internal static extern unsafe delegate* unmanaged<nuint, void> RhGetInterruptApcCallback();

[MethodImplAttribute(MethodImplOptions.InternalCall)]
[RuntimeImport(RuntimeLibrary, "RhCheckAndClearPendingInterrupt")]
internal static extern bool RhCheckAndClearPendingInterrupt();

[MethodImplAttribute(MethodImplOptions.InternalCall)]
[RuntimeImport("*", "RhGetCurrentThunkContext")]
internal static extern IntPtr GetCurrentInteropThunkContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,59 @@ public sealed partial class Thread

partial void PlatformSpecificInitialize();

internal static void SleepInternal(int millisecondsTimeout)
{
Debug.Assert(millisecondsTimeout >= Timeout.Infinite);

CheckForPendingInterrupt();

Thread currentThread = CurrentThread;
if (millisecondsTimeout == Timeout.Infinite)
{
// Infinite wait - use alertable wait
currentThread.SetWaitSleepJoinState();
uint result;
while (true)
{
result = Interop.Kernel32.SleepEx(Timeout.UnsignedInfinite, true);
if (result != Interop.Kernel32.WAIT_IO_COMPLETION)
{
break;
}
CheckForPendingInterrupt();
}

currentThread.ClearWaitSleepJoinState();
}
else
{
// Timed wait - use alertable wait
currentThread.SetWaitSleepJoinState();
long startTime = Environment.TickCount64;
while (true)
{
uint result = Interop.Kernel32.SleepEx((uint)millisecondsTimeout, true);
if (result != Interop.Kernel32.WAIT_IO_COMPLETION)
{
break;
}
// Check if this was our interrupt APC
CheckForPendingInterrupt();
// Handle APC completion by adjusting timeout and retrying
long currentTime = Environment.TickCount64;
long elapsed = currentTime - startTime;
if (elapsed >= millisecondsTimeout)
{
break;
}
millisecondsTimeout -= (int)elapsed;
startTime = currentTime;
}

currentThread.ClearWaitSleepJoinState();
}
}

// Platform-specific initialization of foreign threads, i.e. threads not created by Thread.Start
private void PlatformSpecificInitializeExistingThread()
{
Expand Down Expand Up @@ -154,18 +207,57 @@ private bool JoinInternal(int millisecondsTimeout)

try
{
int result;

if (millisecondsTimeout == 0)
{
result = (int)Interop.Kernel32.WaitForSingleObject(waitHandle.DangerousGetHandle(), 0);
int result = (int)Interop.Kernel32.WaitForSingleObject(waitHandle.DangerousGetHandle(), 0);
return result == (int)Interop.Kernel32.WAIT_OBJECT_0;
}
else
{
result = WaitHandle.WaitOneCore(waitHandle.DangerousGetHandle(), millisecondsTimeout, useTrivialWaits: false);
Thread currentThread = CurrentThread;
currentThread.SetWaitSleepJoinState();
uint result;
if (millisecondsTimeout == Timeout.Infinite)
{
// Infinite wait
while (true)
{
result = Interop.Kernel32.WaitForSingleObjectEx(waitHandle.DangerousGetHandle(), Timeout.UnsignedInfinite, Interop.BOOL.TRUE);
if (result != Interop.Kernel32.WAIT_IO_COMPLETION)
{
break;
}
// Check if this was our interrupt APC
CheckForPendingInterrupt();
}
}
else
{
long startTime = Environment.TickCount64;
while (true)
{
result = Interop.Kernel32.WaitForSingleObjectEx(waitHandle.DangerousGetHandle(), (uint)millisecondsTimeout, Interop.BOOL.TRUE);
if (result != Interop.Kernel32.WAIT_IO_COMPLETION)
{
break;
}
// Check if this was our interrupt APC
CheckForPendingInterrupt();
// Handle APC completion by adjusting timeout and retrying
long currentTime = Environment.TickCount64;
long elapsed = currentTime - startTime;
if (elapsed >= millisecondsTimeout)
{
result = Interop.Kernel32.WAIT_TIMEOUT;
break;
}
millisecondsTimeout -= (int)elapsed;
startTime = currentTime;
}
}
currentThread.ClearWaitSleepJoinState();
return result == (int)Interop.Kernel32.WAIT_OBJECT_0;
}

return result == (int)Interop.Kernel32.WAIT_OBJECT_0;
}
finally
{
Expand Down Expand Up @@ -212,6 +304,13 @@ private unsafe bool CreateThread(GCHandle<Thread> thisThreadHandle)
// CoreCLR ignores OS errors while setting the priority, so do we
SetPriorityLive(_priority);

// If the thread was interrupted before it was started, queue the interruption now
if (GetThreadStateBit(Interrupted))
{
ClearThreadStateBit(Interrupted);
Interrupt();
}

Interop.Kernel32.ResumeThread(_osHandle);
return true;
}
Expand Down Expand Up @@ -393,7 +492,39 @@ internal static Thread EnsureThreadPoolThreadInitialized()
return InitializeExistingThreadPoolThread();
}

public void Interrupt() { throw new PlatformNotSupportedException(); }
public void Interrupt()
{
using (_lock.EnterScope())
{
// Thread.Interrupt for dead thread should do nothing
if (IsDead())
{
return;
}

// Thread.Interrupt for thread that has not been started yet should queue a pending interrupt
// for when we actually create the thread.
if (_osHandle?.IsInvalid ?? true)
{
SetThreadStateBit(Interrupted);
return;
}

unsafe
{
Interop.Kernel32.QueueUserAPC(RuntimeImports.RhGetInterruptApcCallback(), _osHandle, 0);
}
}
}

internal static void CheckForPendingInterrupt()
{
if (RuntimeImports.RhCheckAndClearPendingInterrupt())
{
CurrentThread.ClearWaitSleepJoinState();
throw new ThreadInterruptedException();
}
}

internal static bool ReentrantWaitsEnabled =>
GetCurrentApartmentType() == ApartmentType.STA;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public sealed partial class Thread
{
// Extra bits used in _threadState
private const ThreadState ThreadPoolThread = (ThreadState)0x1000;
#if TARGET_WINDOWS
private const ThreadState Interrupted = (ThreadState)0x2000;
#endif

// Bits of _threadState that are returned by the ThreadState property
private const ThreadState PublicThreadStateMask = (ThreadState)0x1FF;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,12 @@ internal enum ThreadPriority : int
[LibraryImport(Libraries.Kernel32, SetLastError = true)]
[return: MarshalAs(UnmanagedType.Bool)]
internal static partial bool GetThreadIOPendingFlag(nint hThread, out BOOL lpIOIsPending);

[LibraryImport(Libraries.Kernel32, SetLastError = true)]
[return: MarshalAs(UnmanagedType.Bool)]
internal static unsafe partial bool QueueUserAPC(delegate* unmanaged<nuint, void> pfnAPC, SafeHandle hThread, nuint dwData);

[LibraryImport(Libraries.Kernel32)]
internal static partial uint SleepEx(uint dwMilliseconds, [MarshalAs(UnmanagedType.Bool)] bool bAlertable);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public sealed partial class Thread
{
internal static void UninterruptibleSleep0() => Interop.Kernel32.Sleep(0);

#if !CORECLR
#if MONO
private static void SleepInternal(int millisecondsTimeout)
{
Debug.Assert(millisecondsTimeout >= -1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ private static unsafe int WaitForMultipleObjectsIgnoringSyncContext(IntPtr* pHan
}

int result;

Thread.CheckForPendingInterrupt();

while (true)
{
#if NATIVEAOT
Expand All @@ -75,8 +78,10 @@ private static unsafe int WaitForMultipleObjectsIgnoringSyncContext(IntPtr* pHan
if (result != Interop.Kernel32.WAIT_IO_COMPLETION)
break;

Thread.CheckForPendingInterrupt();

// Handle APC completion by adjusting timeout and retrying
if (millisecondsTimeout != -1)
if (millisecondsTimeout != Timeout.Infinite)
{
long currentTime = Environment.TickCount64;
long elapsed = currentTime - startTime;
Expand All @@ -89,6 +94,7 @@ private static unsafe int WaitForMultipleObjectsIgnoringSyncContext(IntPtr* pHan
startTime = currentTime;
}
}

currentThread.ClearWaitSleepJoinState();

if (result == Interop.Kernel32.WAIT_FAILED)
Expand Down Expand Up @@ -134,12 +140,16 @@ private static int SignalAndWaitCore(IntPtr handleToSignal, IntPtr handleToWaitO
startTime = Environment.TickCount64;
}

Thread.CheckForPendingInterrupt();

// Signal the object and wait for the first time
int ret = (int)Interop.Kernel32.SignalObjectAndWait(handleToSignal, handleToWaitOn, (uint)millisecondsTimeout, Interop.BOOL.TRUE);

// Handle APC completion by retrying with WaitForSingleObjectEx (without signaling again)
while (ret == Interop.Kernel32.WAIT_IO_COMPLETION)
{
Thread.CheckForPendingInterrupt();

if (millisecondsTimeout != -1)
{
long currentTime = Environment.TickCount64;
Expand Down
4 changes: 1 addition & 3 deletions src/libraries/System.Threading.Thread/tests/ThreadTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ public static void AbortSuspendTest()
verify();

e.Set();
waitForThread();
waitForThread();
}

private static void VerifyLocalDataSlot(LocalDataStoreSlot slot)
Expand Down Expand Up @@ -916,7 +916,6 @@ public static void LocalDataSlotTest()

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/49521", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)]
[ActiveIssue("https://github.com/dotnet/runtime/issues/69919", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
public static void InterruptTest()
{
// Interrupting a thread that is not blocked does not do anything, but once the thread starts blocking, it gets
Expand Down Expand Up @@ -966,7 +965,6 @@ public static void InterruptTest()
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/69919", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
[ActiveIssue("https://github.com/dotnet/runtime/issues/49521", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)]
public static void InterruptInFinallyBlockTest_SkipOnDesktopFramework()
{
Expand Down
Loading
Loading