Skip to content

Commit

Permalink
Implement ComWrappers.RegisterForTrackerSupport to be able create CCW
Browse files Browse the repository at this point in the history
This moves a bit towards dotnet#306 and dotnet#1453
  • Loading branch information
kant2002 committed Sep 9, 2021
1 parent 46831cd commit 1324c71
Showing 1 changed file with 191 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ namespace System.Runtime.InteropServices
/// </summary>
public abstract partial class ComWrappers
{
private const int TrackerRefShift = 32;
private const ulong TrackerRefCounter = 1UL << TrackerRefShift;
private const ulong DestroySentinel = 0x0000000080000000UL;
private const ulong TrackerRefCountMask = 0xffffffff00000000UL;
private const ulong ComRefCountMask = 0x000000007fffffffUL;

internal static IntPtr DefaultIUnknownVftblPtr { get; } = CreateDefaultIUnknownVftbl();
internal static IntPtr DefaultIReferenceTrackerTargetVftblPtr { get; } = CreateDefaultIReferenceTrackerTargetVftbl();

internal static Guid IID_IUnknown = new Guid(0x00000000, 0x0000, 0x0000, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46);
internal static Guid IID_IReferenceTrackerTarget = new Guid(0x64bd43f8, 0xbfee, 0x4ec4, 0xb7, 0xeb, 0x29, 0x35, 0x15, 0x8d, 0xae, 0x21);

private readonly ConditionalWeakTable<object, ManagedObjectWrapperHolder> _ccwTable = new ConditionalWeakTable<object, ManagedObjectWrapperHolder>();
private readonly Lock _lock = new Lock();
Expand Down Expand Up @@ -55,26 +63,100 @@ internal unsafe struct InternalComInterfaceDispatch
internal ManagedObjectWrapper* _thisPtr;
}

internal enum CreateComInterfaceFlagsEx
{
None = 0,

/// <summary>
/// The caller will provide an IUnknown Vtable.
/// </summary>
/// <remarks>
/// This is useful in scenarios when the caller has no need to rely on an IUnknown instance
/// that is used when running managed code is not possible (i.e. during a GC). In traditional
/// COM scenarios this is common, but scenarios involving <see href="https://docs.microsoft.com/windows/win32/api/windows.ui.xaml.hosting.referencetracker/nn-windows-ui-xaml-hosting-referencetracker-ireferencetrackertarget">Reference Tracker hosting</see>
/// calling of the IUnknown API during a GC is possible.
/// </remarks>
CallerDefinedIUnknown = 1,

/// <summary>
/// Flag used to indicate the COM interface should implement <see href="https://docs.microsoft.com/windows/win32/api/windows.ui.xaml.hosting.referencetracker/nn-windows-ui-xaml-hosting-referencetracker-ireferencetrackertarget">IReferenceTrackerTarget</see>.
/// When this flag is passed, the resulting COM interface will have an internal implementation of IUnknown
/// and as such none should be supplied by the caller.
/// </summary>
TrackerSupport = 2,

LacksICustomQueryInterface = 1 << 29,
IsComActivated = 1 << 30,
IsPegged = 1 << 31,

InternalMask = IsPegged | IsComActivated | LacksICustomQueryInterface,
}

internal unsafe struct ManagedObjectWrapper
{
public IntPtr Target; // This is GC Handle
public uint RefCount;
public ulong RefCount;

public int UserDefinedCount;
public ComInterfaceEntry* UserDefined;
internal InternalComInterfaceDispatch* Dispatches;

internal CreateComInterfaceFlags Flags;
internal CreateComInterfaceFlagsEx Flags;

public uint AddRef()
{
return Interlocked.Increment(ref RefCount);
return GetComCount(Interlocked.Increment(ref RefCount));
}

public uint Release()
{
Debug.Assert(RefCount != 0);
return Interlocked.Decrement(ref RefCount);
Debug.Assert(GetComCount(RefCount) != 0);
return GetComCount(Interlocked.Decrement(ref RefCount));
}

public uint AddRefFromReferenceTracker()
{
ulong prev;
ulong curr;
do
{
prev = RefCount;
curr = prev + TrackerRefCounter;
} while (Interlocked.CompareExchange(ref RefCount, curr, prev) != prev);

return GetTrackerCount(curr);
}

public uint ReleaseFromReferenceTracker()
{
Debug.Assert(GetTrackerCount(RefCount) != 0);
ulong prev;
ulong curr;
do
{
prev = RefCount;
curr = prev - TrackerRefCounter;
}
while (Interlocked.CompareExchange(ref RefCount, curr, prev) != prev);

// If we observe the destroy sentinel, then this release
// must destroy the wrapper.
if (RefCount == DestroySentinel)
Destroy();

return GetTrackerCount(RefCount);
}

public uint Peg()
{
SetFlag(CreateComInterfaceFlagsEx.IsPegged);
return HResults.S_OK;
}

public uint Unpeg()
{
ResetFlag(CreateComInterfaceFlagsEx.IsPegged);
return HResults.S_OK;
}

public unsafe int QueryInterface(in Guid riid, out IntPtr ppvObject)
Expand Down Expand Up @@ -114,12 +196,30 @@ public unsafe void Destroy()

private unsafe IntPtr AsRuntimeDefined(in Guid riid)
{
if ((Flags & CreateComInterfaceFlags.CallerDefinedIUnknown) == CreateComInterfaceFlags.None)
if ((Flags & CreateComInterfaceFlagsEx.CallerDefinedIUnknown) == CreateComInterfaceFlagsEx.None)
{
if (riid == IID_IUnknown)
{
return (IntPtr)(Dispatches + UserDefinedCount);
}

if ((Flags & CreateComInterfaceFlagsEx.TrackerSupport) == CreateComInterfaceFlagsEx.TrackerSupport)
{
if (riid == IID_IReferenceTrackerTarget)
{
return (IntPtr)(Dispatches + UserDefinedCount + 1);
}
}
}
else
{
if ((Flags & CreateComInterfaceFlagsEx.TrackerSupport) == CreateComInterfaceFlagsEx.TrackerSupport)
{
if (riid == IID_IReferenceTrackerTarget)
{
return (IntPtr)(Dispatches + UserDefinedCount);
}
}
}

return IntPtr.Zero;
Expand All @@ -137,6 +237,33 @@ private unsafe IntPtr AsUserDefined(in Guid riid)

return IntPtr.Zero;
}

private void SetFlag(CreateComInterfaceFlagsEx flag)
{
int setMask = (int)flag;
Interlocked.Or(ref Unsafe.As<CreateComInterfaceFlagsEx, int>(ref Flags), setMask);
}

private void ResetFlag(CreateComInterfaceFlagsEx flag)
{
int resetMask = ~(int)flag;
Interlocked.And(ref Unsafe.As<CreateComInterfaceFlagsEx, int>(ref Flags), resetMask);
}

private static uint GetTrackerCount(ulong c)
{
return (uint)((c & TrackerRefCountMask) >> TrackerRefShift);
}

private static uint GetComCount(ulong c)
{
return (uint)(c & ComRefCountMask);
}

private static bool IsMarkedToDestroy(ulong c)
{
return (c & DestroySentinel) != 0;
}
}

internal unsafe class ManagedObjectWrapperHolder
Expand Down Expand Up @@ -184,12 +311,10 @@ public NativeObjectWrapper(IntPtr externalComObject, ComWrappers comWrappers, ob
}
}

#if false
/// <summary>
/// Globally registered instance of the ComWrappers class for reference tracker support.
/// </summary>
private static ComWrappers? s_globalInstanceForTrackerSupport;
#endif

/// <summary>
/// Globally registered instance of the ComWrappers class for marshalling.
Expand Down Expand Up @@ -240,6 +365,11 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
runtimeDefinedVtable[runtimeDefinedCount++] = DefaultIUnknownVftblPtr;
}

if ((flags & CreateComInterfaceFlags.TrackerSupport) == CreateComInterfaceFlags.TrackerSupport)
{
runtimeDefinedVtable[runtimeDefinedCount++] = DefaultIReferenceTrackerTargetVftblPtr;
}

// Compute size for ManagedObjectWrapper instance.
int totalDefinedCount = runtimeDefinedCount + userDefinedCount;

Expand All @@ -262,7 +392,7 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
mow->RefCount = 1;
mow->UserDefinedCount = userDefinedCount;
mow->UserDefined = userDefined;
mow->Flags = flags;
mow->Flags = (CreateComInterfaceFlagsEx)flags;
mow->Dispatches = pDispatches;
return mow;
}
Expand Down Expand Up @@ -364,6 +494,9 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal(
if (flags.HasFlag(CreateObjectFlags.Aggregation))
throw new NotImplementedException();

if (flags.HasFlag(CreateObjectFlags.TrackerObject))
throw new NotImplementedException();

if (flags.HasFlag(CreateObjectFlags.Unwrap))
{
var comInterfaceDispatch = TryGetComInterfaceDispatch(externalComObject);
Expand Down Expand Up @@ -440,17 +573,13 @@ private void RemoveRCWFromCache(IntPtr comPointer)
/// </remarks>
public static void RegisterForTrackerSupport(ComWrappers instance)
{
#if false
if (instance == null)
throw new ArgumentNullException(nameof(instance));

if (null != Interlocked.CompareExchange(ref s_globalInstanceForTrackerSupport, instance, null))
{
throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance);
}
#else
throw new NotImplementedException();
#endif
}

/// <summary>
Expand Down Expand Up @@ -554,11 +683,60 @@ internal static unsafe uint IUnknown_Release(IntPtr pThis)
return refcount;
}

[UnmanagedCallersOnly]
internal static unsafe int IReferenceTrackerTarget_QueryInterface(IntPtr pThis, Guid* guid, IntPtr* ppObject)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->QueryInterface(in *guid, out *ppObject);
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_AddRefFromReferenceTracker(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->AddRefFromReferenceTracker();
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_ReleaseFromReferenceTracker(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->ReleaseFromReferenceTracker();
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_Peg(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->Peg();
}

[UnmanagedCallersOnly]
internal static unsafe uint IReferenceTrackerTarget_Unpeg(IntPtr pThis)
{
ManagedObjectWrapper* wrapper = ComInterfaceDispatch.ToManagedObjectWrapper((ComInterfaceDispatch*)pThis);
return wrapper->Unpeg();
}

private static unsafe IntPtr CreateDefaultIUnknownVftbl()
{
IntPtr* vftbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappers), 3 * sizeof(IntPtr));
GetIUnknownImpl(out vftbl[0], out vftbl[1], out vftbl[2]);
return (IntPtr)vftbl;
}

private static unsafe IntPtr CreateDefaultIReferenceTrackerTargetVftbl()
{
IntPtr* vftbl = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappers), 7 * sizeof(IntPtr));
GetIUnknownImpl(out vftbl[0], out vftbl[1], out vftbl[2]);
vftbl[0] = (IntPtr)(delegate* unmanaged<IntPtr, Guid*, IntPtr*, int>)&ComWrappers.IReferenceTrackerTarget_QueryInterface;
vftbl[1] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IUnknown_AddRef;
vftbl[2] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IUnknown_Release;
vftbl[3] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_AddRefFromReferenceTracker;
vftbl[4] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_ReleaseFromReferenceTracker;
vftbl[5] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_Peg;
vftbl[6] = (IntPtr)(delegate* unmanaged<IntPtr, uint>)&ComWrappers.IReferenceTrackerTarget_Unpeg;
return (IntPtr)vftbl;
}
}
}

0 comments on commit 1324c71

Please sign in to comment.