diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
index f5d0e06b016a3..59f3d47ac3585 100644
--- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
+++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
@@ -118,16 +118,35 @@ private struct ComInterfaceInstance
/// Flags used to configure the generated interface.
/// The generated COM interface that can be passed outside the .NET runtime.
public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags)
+ {
+ IntPtr ptr;
+ if (!TryGetOrCreateComInterfaceForObjectInternal(this, instance, flags, out ptr))
+ throw new ArgumentException();
+
+ return ptr;
+ }
+
+ ///
+ /// Create a COM representation of the supplied object that can be passed to a non-managed environment.
+ ///
+ /// The implementation to use when creating the COM representation.
+ /// The managed object to expose outside the .NET runtime.
+ /// Flags used to configure the generated interface.
+ /// The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created.
+ /// Returns true if a COM representation could be created, false otherwise
+ ///
+ /// If is null, the global instance (if registered) will be used.
+ ///
+ private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
{
if (instance == null)
throw new ArgumentNullException(nameof(instance));
- ComWrappers impl = this;
- return GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags);
+ return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags, out retValue);
}
[DllImport(RuntimeHelpers.QCall)]
- private static extern IntPtr GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags);
+ private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue);
///
/// Compute the desired Vtable for respecting the values of .
@@ -140,13 +159,23 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa
/// All memory returned from this function must either be unmanaged memory, pinned managed memory, or have been
/// allocated with the API.
///
- /// If the interface entries cannot be created and null
is returned, the call to will throw a .
+ /// If the interface entries cannot be created and a negative or null
and a non-zero are returned,
+ /// the call to will throw a .
///
protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count);
// Call to execute the abstract instance function
internal static unsafe void* CallComputeVtables(ComWrappers? comWrappersImpl, object obj, CreateComInterfaceFlags flags, out int count)
- => (comWrappersImpl ?? s_globalInstance!).ComputeVtables(obj, flags, out count);
+ {
+ ComWrappers? impl = comWrappersImpl ?? s_globalInstance;
+ if (impl is null)
+ {
+ count = -1;
+ return null;
+ }
+
+ return impl.ComputeVtables(obj, flags, out count);
+ }
///
/// Get the currently registered managed object or creates a new managed object and registers it.
@@ -156,7 +185,11 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa
/// Returns a managed object associated with the supplied external COM object.
public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags)
{
- return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, null);
+ object? obj;
+ if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, null, out obj))
+ throw new ArgumentNullException();
+
+ return obj!;
}
///
@@ -172,7 +205,13 @@ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateOb
// Call to execute the abstract instance function
internal static object? CallCreateObject(ComWrappers? comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags)
- => (comWrappersImpl ?? s_globalInstance!).CreateObject(externalComObject, flags);
+ {
+ ComWrappers? impl = comWrappersImpl ?? s_globalInstance;
+ if (impl == null)
+ return null;
+
+ return impl.CreateObject(externalComObject, flags);
+ }
///
/// Get the currently registered managed object or uses the supplied managed object and registers it.
@@ -189,24 +228,37 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create
if (wrapper == null)
throw new ArgumentNullException(nameof(externalComObject));
- return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, wrapper);
+ object? obj;
+ if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, wrapper, out obj))
+ throw new ArgumentNullException();
+
+ return obj!;
}
- private object GetOrCreateObjectForComInstanceInternal(IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe)
+ ///
+ /// Get the currently registered managed object or creates a new managed object and registers it.
+ ///
+ /// The implementation to use when creating the managed object.
+ /// Object to import for usage into the .NET runtime.
+ /// Flags used to describe the external object.
+ /// The to be used as the wrapper for the external object.
+ /// The managed object associated with the supplied external COM object or null if it could not be created.
+ /// Returns true if a managed object could be retrieved/created, false otherwise
+ ///
+ /// If is null, the global instance (if registered) will be used.
+ ///
+ private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue)
{
if (externalComObject == IntPtr.Zero)
throw new ArgumentNullException(nameof(externalComObject));
- ComWrappers impl = this;
object? wrapperMaybeLocal = wrapperMaybe;
- object? retValue = null;
- GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
-
- return retValue!;
+ retValue = null;
+ return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
}
[DllImport(RuntimeHelpers.QCall)]
- private static extern void GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);
+ private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);
///
/// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime.
@@ -235,8 +287,14 @@ public void RegisterAsGlobalInstance()
{
throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance);
}
+
+ SetGlobalInstanceRegistered();
}
+ [DllImport(RuntimeHelpers.QCall)]
+ [SuppressGCTransition]
+ private static extern void SetGlobalInstanceRegistered();
+
///
/// Get the runtime provided IUnknown implementation.
///
diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs
index 07b7ce866c3d6..e9b8b3df601b5 100644
--- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs
+++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs
@@ -326,6 +326,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
///
public static IntPtr /* IUnknown* */ GetIUnknownForObject(object o)
{
+ if (o is null)
+ {
+ throw new ArgumentNullException(nameof(o));
+ }
+
return GetIUnknownForObjectNative(o, false);
}
@@ -344,6 +349,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
///
public static IntPtr /* IDispatch */ GetIDispatchForObject(object o)
{
+ if (o is null)
+ {
+ throw new ArgumentNullException(nameof(o));
+ }
+
return GetIDispatchForObjectNative(o, false);
}
@@ -356,6 +366,16 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
///
public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T)
{
+ if (o is null)
+ {
+ throw new ArgumentNullException(nameof(o));
+ }
+
+ if (T is null)
+ {
+ throw new ArgumentNullException(nameof(T));
+ }
+
return GetComInterfaceForObjectNative(o, T, false, true);
}
@@ -368,15 +388,48 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
///
public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T, CustomQueryInterfaceMode mode)
{
+ if (o is null)
+ {
+ throw new ArgumentNullException(nameof(o));
+ }
+
+ if (T is null)
+ {
+ throw new ArgumentNullException(nameof(T));
+ }
+
bool bEnableCustomizedQueryInterface = ((mode == CustomQueryInterfaceMode.Allow) ? true : false);
return GetComInterfaceForObjectNative(o, T, false, bEnableCustomizedQueryInterface);
}
[MethodImpl(MethodImplOptions.InternalCall)]
- private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnalbeCustomizedQueryInterface);
+ private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnableCustomizedQueryInterface);
+
+ ///
+ /// Return the managed object representing the IUnknown*
+ ///
+ public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk)
+ {
+ if (pUnk == IntPtr.Zero)
+ {
+ throw new ArgumentNullException(nameof(pUnk));
+ }
+
+ return GetObjectForIUnknownNative(pUnk);
+ }
[MethodImpl(MethodImplOptions.InternalCall)]
- public static extern object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk);
+ private static extern object GetObjectForIUnknownNative(IntPtr /* IUnknown* */ pUnk);
+
+ public static object GetUniqueObjectForIUnknown(IntPtr unknown)
+ {
+ if (unknown == IntPtr.Zero)
+ {
+ throw new ArgumentNullException(nameof(unknown));
+ }
+
+ return GetUniqueObjectForIUnknownNative(unknown);
+ }
///
/// Return a unique Object given an IUnknown. This ensures that you receive a fresh
@@ -385,7 +438,7 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
/// ReleaseComObject on a RCW and not worry about other active uses ofsaid RCW.
///
[MethodImpl(MethodImplOptions.InternalCall)]
- public static extern object GetUniqueObjectForIUnknown(IntPtr unknown);
+ private static extern object GetUniqueObjectForIUnknownNative(IntPtr unknown);
///
/// Return an Object for IUnknown, using the Type T.
diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs
index 8c94325722e13..691113f665be0 100644
--- a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs
+++ b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs
@@ -690,7 +690,7 @@ internal static class InterfaceMarshaler
internal static extern IntPtr ConvertToNative(object objSrc, IntPtr itfMT, IntPtr classMT, int flags);
[MethodImpl(MethodImplOptions.InternalCall)]
- internal static extern object ConvertToManaged(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags);
+ internal static extern object ConvertToManaged(IntPtr ppUnk, IntPtr itfMT, IntPtr classMT, int flags);
[DllImport(RuntimeHelpers.QCall)]
internal static extern void ClearNative(IntPtr pUnk);
diff --git a/src/coreclr/src/interop/comwrappers.hpp b/src/coreclr/src/interop/comwrappers.hpp
index d2337746e39b8..e9838646b0189 100644
--- a/src/coreclr/src/interop/comwrappers.hpp
+++ b/src/coreclr/src/interop/comwrappers.hpp
@@ -16,9 +16,10 @@ enum class CreateComInterfaceFlagsEx : int32_t
TrackerSupport = InteropLib::Com::CreateComInterfaceFlags_TrackerSupport,
// Highest bit is reserved for internal usage
+ IsComActivated = 1 << 30,
IsPegged = 1 << 31,
- InternalMask = IsPegged,
+ InternalMask = IsPegged | IsComActivated,
};
DEFINE_ENUM_FLAG_OPERATORS(CreateComInterfaceFlagsEx);
diff --git a/src/coreclr/src/interop/inc/interoplib.h b/src/coreclr/src/interop/inc/interoplib.h
index df59b8b90a584..fa0e0c3e25918 100644
--- a/src/coreclr/src/interop/inc/interoplib.h
+++ b/src/coreclr/src/interop/inc/interoplib.h
@@ -45,6 +45,12 @@ namespace InteropLib
// Reactivate the supplied wrapper.
HRESULT ReactivateWrapper(_In_ IUnknown* wrapper, _In_ InteropLib::OBJECTHANDLE handle) noexcept;
+ // Get the object for the supplied wrapper
+ HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept;
+
+ HRESULT MarkComActivated(_In_ IUnknown* wrapper) noexcept;
+ HRESULT IsComActivated(_In_ IUnknown* wrapper) noexcept;
+
struct ExternalWrapperResult
{
// The returned context memory is guaranteed to be initialized to zero.
diff --git a/src/coreclr/src/interop/interoplib.cpp b/src/coreclr/src/interop/interoplib.cpp
index 8283024ad5d4a..fb7af5aa7ec65 100644
--- a/src/coreclr/src/interop/interoplib.cpp
+++ b/src/coreclr/src/interop/interoplib.cpp
@@ -89,6 +89,43 @@ namespace InteropLib
return S_OK;
}
+ HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept
+ {
+ if (object == nullptr)
+ return E_POINTER;
+
+ *object = nullptr;
+
+ HRESULT hr = IsActiveWrapper(wrapper);
+ if (hr != S_OK)
+ return hr;
+
+ ManagedObjectWrapper *mow = ManagedObjectWrapper::MapFromIUnknown(wrapper);
+ _ASSERTE(mow != nullptr);
+
+ *object = mow->Target;
+ return S_OK;
+ }
+
+ HRESULT MarkComActivated(_In_ IUnknown* wrapperMaybe) noexcept
+ {
+ ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe);
+ if (wrapper == nullptr)
+ return E_INVALIDARG;
+
+ wrapper->SetFlag(CreateComInterfaceFlagsEx::IsComActivated);
+ return S_OK;
+ }
+
+ HRESULT IsComActivated(_In_ IUnknown* wrapperMaybe) noexcept
+ {
+ ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe);
+ if (wrapper == nullptr)
+ return E_INVALIDARG;
+
+ return wrapper->IsSet(CreateComInterfaceFlagsEx::IsComActivated) ? S_OK : S_FALSE;
+ }
+
HRESULT CreateWrapperForExternal(
_In_ IUnknown* external,
_In_ enum CreateObjectFlags flags,
diff --git a/src/coreclr/src/vm/ecalllist.h b/src/coreclr/src/vm/ecalllist.h
index 3cbfa0e5ec40e..691bad54bf537 100644
--- a/src/coreclr/src/vm/ecalllist.h
+++ b/src/coreclr/src/vm/ecalllist.h
@@ -805,8 +805,8 @@ FCFuncStart(gInteropMarshalFuncs)
FCFuncElement("GetHRForException", MarshalNative::GetHRForException)
FCFuncElement("GetRawIUnknownForComObjectNoAddRef", MarshalNative::GetRawIUnknownForComObjectNoAddRef)
FCFuncElement("IsComObject", MarshalNative::IsComObject)
- FCFuncElement("GetObjectForIUnknown", MarshalNative::GetObjectForIUnknown)
- FCFuncElement("GetUniqueObjectForIUnknown", MarshalNative::GetUniqueObjectForIUnknown)
+ FCFuncElement("GetObjectForIUnknownNative", MarshalNative::GetObjectForIUnknownNative)
+ FCFuncElement("GetUniqueObjectForIUnknownNative", MarshalNative::GetUniqueObjectForIUnknownNative)
FCFuncElement("AddRef", MarshalNative::AddRef)
FCFuncElement("GetNativeVariantForObject", MarshalNative::GetNativeVariantForObject)
FCFuncElement("GetObjectForNativeVariant", MarshalNative::GetObjectForNativeVariant)
@@ -997,8 +997,9 @@ FCFuncEnd()
#ifdef FEATURE_COMWRAPPERS
FCFuncStart(gComWrappersFuncs)
QCFuncElement("GetIUnknownImplInternal", ComWrappersNative::GetIUnknownImpl)
- QCFuncElement("GetOrCreateComInterfaceForObjectInternal", ComWrappersNative::GetOrCreateComInterfaceForObject)
- QCFuncElement("GetOrCreateObjectForComInstanceInternal", ComWrappersNative::GetOrCreateObjectForComInstance)
+ QCFuncElement("TryGetOrCreateComInterfaceForObjectInternal", ComWrappersNative::TryGetOrCreateComInterfaceForObject)
+ QCFuncElement("TryGetOrCreateObjectForComInstanceInternal", ComWrappersNative::TryGetOrCreateObjectForComInstance)
+ QCFuncElement("SetGlobalInstanceRegistered", GlobalComWrappers::SetGlobalInstanceRegistered)
FCFuncEnd()
#endif // FEATURE_COMWRAPPERS
diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp
index 58ac452295bd9..44470d7d297b4 100644
--- a/src/coreclr/src/vm/interopconverter.cpp
+++ b/src/coreclr/src/vm/interopconverter.cpp
@@ -17,9 +17,66 @@
#include "runtimecallablewrapper.h"
#include "cominterfacemarshaler.h"
#include "binder.h"
+#include
#include "winrttypenameconverter.h"
#include "typestring.h"
+namespace
+{
+ bool TryGetComIPFromObjectRefUsingComWrappers(
+ _In_ OBJECTREF instance,
+ _Outptr_ IUnknown** wrapperRaw)
+ {
+#ifdef FEATURE_COMWRAPPERS
+ return GlobalComWrappers::TryGetOrCreateComInterfaceForObject(instance, (void**)wrapperRaw);
+#else
+ return false;
+#endif // FEATURE_COMWRAPPERS
+ }
+
+ bool TryGetObjectRefFromComIPUsingComWrappers(
+ _In_ IUnknown* pUnknown,
+ _In_ DWORD dwFlags,
+ _Out_ OBJECTREF *pObjOut)
+ {
+#ifdef FEATURE_COMWRAPPERS
+ return GlobalComWrappers::TryGetOrCreateObjectForComInstance(pUnknown, dwFlags, pObjOut);
+#else
+ return false;
+#endif // FEATURE_COMWRAPPERS
+ }
+
+ void EnsureObjectRefIsValidForSpecifiedClass(
+ _In_ OBJECTREF *obj,
+ _In_ DWORD dwFlags,
+ _In_ MethodTable *pMTClass)
+ {
+ _ASSERTE(*obj != NULL);
+ _ASSERTE(pMTClass != NULL);
+
+ if ((dwFlags & ObjFromComIP::CLASS_IS_HINT) != 0)
+ return;
+
+ // make sure we can cast to the specified class
+ FAULT_NOT_FATAL();
+
+ // Bad format exception thrown for backward compatibility
+ THROW_BAD_FORMAT_MAYBE(pMTClass->IsArray() == FALSE, BFA_UNEXPECTED_ARRAY_TYPE, pMTClass);
+
+ if (CanCastComObject(*obj, pMTClass))
+ return;
+
+ StackSString ssObjClsName;
+ StackSString ssDestClsName;
+
+ (*obj)->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName);
+ pMTClass->_GetFullyQualifiedNameForClass(ssDestClsName);
+
+ COMPlusThrow(kInvalidCastException, IDS_EE_CANNOTCAST,
+ ssObjClsName.GetUnicode(), ssDestClsName.GetUnicode());
+ }
+}
+
//--------------------------------------------------------------------------------
// IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, ...)
// Convert ObjectRef to a COM IP, based on MethodTable* pMT.
@@ -45,6 +102,11 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, BOOL bSecuri
if (*poref == NULL)
RETURN NULL;
+ if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk))
+ {
+ pUnk.SuppressRelease();
+ RETURN pUnk;
+ }
SyncBlock* pBlock = (*poref)->GetSyncBlock();
@@ -111,6 +173,30 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType
if (*poref == NULL)
RETURN NULL;
+ if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk))
+ {
+ hr = S_OK;
+
+ SafeComHolder pvObj;
+ if (ReqIpType & ComIpType_Dispatch)
+ {
+ hr = pUnk->QueryInterface(IID_IDispatch, &pvObj);
+ }
+ else if (ReqIpType & ComIpType_Inspectable)
+ {
+ SafeComHolder pvObj;
+ hr = pUnk->QueryInterface(IID_IInspectable, &pvObj);
+ }
+
+ if (FAILED(hr))
+ COMPlusThrowHR(hr);
+
+ if (pFetchedIpType != NULL)
+ *pFetchedIpType = ReqIpType;
+
+ RETURN pUnk;
+ }
+
MethodTable *pMT = (*poref)->GetMethodTable();
SyncBlock* pBlock = (*poref)->GetSyncBlock();
@@ -376,6 +462,16 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, REFIID iid, bool throwIfNoComI
if (*poref == NULL)
RETURN NULL;
+ if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk))
+ {
+ SafeComHolder pvObj;
+ hr = pUnk->QueryInterface(iid, &pvObj);
+ if (FAILED(hr))
+ COMPlusThrowHR(hr);
+
+ RETURN pUnk;
+ }
+
MethodTable *pMT = (*poref)->GetMethodTable();
SyncBlock* pBlock = (*poref)->GetSyncBlock();
@@ -434,9 +530,18 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM
_ASSERTE(g_fComStarted && "COM has not been started up, make sure EnsureComStarted is called before any COM objects are used!");
IUnknown *pUnk = *ppUnk;
+ *pObjOut = NULL;
+
+ if (TryGetObjectRefFromComIPUsingComWrappers(pUnk, dwFlags, pObjOut))
+ {
+ if (pMTClass != NULL)
+ EnsureObjectRefIsValidForSpecifiedClass(pObjOut, dwFlags, pMTClass);
+
+ return;
+ }
+
Thread * pThread = GetThread();
- *pObjOut = NULL;
IUnknown* pOuter = pUnk;
SafeComHolder pAutoOuterUnk = NULL;
@@ -537,22 +642,7 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM
// make sure we can cast to the specified class
if (pMTClass != NULL)
{
- FAULT_NOT_FATAL();
-
- // Bad format exception thrown for backward compatibility
- THROW_BAD_FORMAT_MAYBE(pMTClass->IsArray() == FALSE, BFA_UNEXPECTED_ARRAY_TYPE, pMTClass);
-
- if (!CanCastComObject(*pObjOut, pMTClass))
- {
- StackSString ssObjClsName;
- StackSString ssDestClsName;
-
- (*pObjOut)->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName);
- pMTClass->_GetFullyQualifiedNameForClass(ssDestClsName);
-
- COMPlusThrow(kInvalidCastException, IDS_EE_CANNOTCAST,
- ssObjClsName.GetUnicode(), ssDestClsName.GetUnicode());
- }
+ EnsureObjectRefIsValidForSpecifiedClass(pObjOut, dwFlags, pMTClass);
}
else if (dwFlags & ObjFromComIP::REQUIRE_IINSPECTABLE)
{
diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp
index 70e5f14f2c7df..c71b0d955fd53 100644
--- a/src/coreclr/src/vm/interoplibinterface.cpp
+++ b/src/coreclr/src/vm/interoplibinterface.cpp
@@ -397,9 +397,12 @@ namespace
}
};
- // Global instance
+ // Global instance of the external object cache
Volatile ExtObjCxtCache::g_Instance;
+ // Indicator for if a ComWrappers implementation is globally registered
+ bool g_IsGlobalComWrappersRegistered;
+
// Defined handle types for the specific object uses.
const HandleType InstanceHandleType{ HNDTYPE_STRONG };
@@ -478,24 +481,25 @@ namespace
CALL_MANAGED_METHOD_NORET(args);
}
- void* GetOrCreateComInterfaceForObjectInternal(
+ bool TryGetOrCreateComInterfaceForObjectInternal(
_In_opt_ OBJECTREF impl,
_In_ OBJECTREF instance,
- _In_ CreateComInterfaceFlags flags)
+ _In_ CreateComInterfaceFlags flags,
+ _Outptr_ void** wrapperRaw)
{
- CONTRACT(void*)
+ CONTRACT(bool)
{
THROWS;
MODE_COOPERATIVE;
PRECONDITION(instance != NULL);
- POSTCONDITION(CheckPointer(RETVAL));
+ PRECONDITION(wrapperRaw != NULL);
}
CONTRACT_END;
HRESULT hr;
SafeComHolder newWrapper;
- void* wrapperRaw = NULL;
+ void* wrapperRawMaybe = NULL;
struct
{
@@ -514,7 +518,7 @@ namespace
_ASSERTE(syncBlock->IsPrecious());
// Query the associated InteropSyncBlockInfo for an existing managed object wrapper.
- if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw))
+ if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe))
{
// Compute VTables for the new existing COM object using the supplied COM Wrappers implementation.
//
@@ -525,7 +529,8 @@ namespace
void* vtables = CallComputeVTables(&gc.implRef, &gc.instRef, flags, &vtableCount);
// Re-query the associated InteropSyncBlockInfo for an existing managed object wrapper.
- if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw))
+ if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)
+ && ((vtables != nullptr && vtableCount > 0) || (vtableCount == 0)))
{
OBJECTHANDLE instHandle = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType);
@@ -551,7 +556,7 @@ namespace
// If the managed object wrapper couldn't be set, then
// it should be possible to get the current one.
- if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw))
+ if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe))
{
UNREACHABLE();
}
@@ -564,20 +569,18 @@ namespace
{
// A new managed object wrapper was created, remove the object from the holder.
// No AddRef() here since the wrapper should be created with a reference.
- wrapperRaw = newWrapper.Extract();
- STRESS_LOG1(LF_INTEROP, LL_INFO100, "Created MOW: 0x%p\n", wrapperRaw);
+ wrapperRawMaybe = newWrapper.Extract();
+ STRESS_LOG1(LF_INTEROP, LL_INFO100, "Created MOW: 0x%p\n", wrapperRawMaybe);
}
- else
+ else if (wrapperRawMaybe != NULL)
{
- _ASSERTE(wrapperRaw != NULL);
-
// It is possible the supplied wrapper is no longer valid. If so, reactivate the
// wrapper using the protected OBJECTREF.
- IUnknown* wrapper = static_cast(wrapperRaw);
+ IUnknown* wrapper = static_cast(wrapperRawMaybe);
hr = InteropLib::Com::IsActiveWrapper(wrapper);
if (hr == S_FALSE)
{
- STRESS_LOG1(LF_INTEROP, LL_INFO100, "Reactivating MOW: 0x%p\n", wrapperRaw);
+ STRESS_LOG1(LF_INTEROP, LL_INFO100, "Reactivating MOW: 0x%p\n", wrapperRawMaybe);
OBJECTHANDLE h = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType);
hr = InteropLib::Com::ReactivateWrapper(wrapper, static_cast(h));
}
@@ -588,21 +591,29 @@ namespace
GCPROTECT_END();
- RETURN wrapperRaw;
+ *wrapperRaw = wrapperRawMaybe;
+ RETURN (wrapperRawMaybe != NULL);
}
- OBJECTREF GetOrCreateObjectForComInstanceInternal(
+ // The unwrap parameter indicates whether or not COM instances that are actually CCWs should
+ // be unwrapped to the original managed object.
+ // For implicit usage of ComWrappers (i.e. automatically called by the runtime when there is a global instance),
+ // CCWs should be unwrapped to allow for round-tripping object -> COM instance -> object.
+ // For explicit usage of ComWrappers (i.e. directly called via a ComWrappers APIs), CCWs should not be unwrapped.
+ bool TryGetOrCreateObjectForComInstanceInternal(
_In_opt_ OBJECTREF impl,
_In_ IUnknown* identity,
_In_ CreateObjectFlags flags,
- _In_opt_ OBJECTREF wrapperMaybe)
+ _In_ bool unwrap,
+ _In_opt_ OBJECTREF wrapperMaybe,
+ _Out_ OBJECTREF* objRef)
{
- CONTRACT(OBJECTREF)
+ CONTRACT(bool)
{
THROWS;
MODE_COOPERATIVE;
PRECONDITION(identity != NULL);
- POSTCONDITION(RETVAL != NULL);
+ PRECONDITION(objRef != NULL);
}
CONTRACT_END;
@@ -613,7 +624,7 @@ namespace
{
OBJECTREF implRef;
OBJECTREF wrapperMaybeRef;
- OBJECTREF objRef;
+ OBJECTREF objRefMaybe;
} gc;
::ZeroMemory(&gc, sizeof(gc));
GCPROTECT_BEGIN(gc);
@@ -622,6 +633,7 @@ namespace
gc.wrapperMaybeRef = wrapperMaybe;
ExtObjCxtCache* cache = ExtObjCxtCache::GetInstance();
+ InteropLib::OBJECTHANDLE handle = NULL;
// Check if the user requested a unique instance.
bool uniqueInstance = !!(flags & CreateObjectFlags::CreateObjectFlags_UniqueInstance);
@@ -630,93 +642,123 @@ namespace
// Query the external object cache
ExtObjCxtCache::LockHolder lock(cache);
extObjCxt = cache->Find(identity);
+
+ // If is no object found in the cache, check if the object COM instance is actually the CCW
+ // representing a managed object.
+ if (extObjCxt == NULL && unwrap)
+ {
+ // If the COM instance is a CCW that is not COM-activated, use the object of that wrapper object.
+ InteropLib::OBJECTHANDLE handleLocal;
+ if (InteropLib::Com::GetObjectForWrapper(identity, &handleLocal) == S_OK
+ && InteropLib::Com::IsComActivated(identity) == S_FALSE)
+ {
+ handle = handleLocal;
+ }
+ }
}
if (extObjCxt != NULL)
{
- gc.objRef = extObjCxt->GetObjectRef();
+ gc.objRefMaybe = extObjCxt->GetObjectRef();
+ }
+ else if (handle != NULL)
+ {
+ // We have an object handle from the COM instance which is a CCW. Use that object.
+ // This allows for the round-trip from object -> COM instance -> object.
+ ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle);
+ gc.objRefMaybe = ObjectFromHandle(objectHandle);
}
else
{
// Create context instance for the possibly new external object.
ExternalWrapperResultHolder resultHolder;
- hr = InteropLib::Com::CreateWrapperForExternal(
- identity,
- flags,
- sizeof(ExternalObjectContext),
- &resultHolder);
+
+ {
+ GCX_PREEMP();
+ hr = InteropLib::Com::CreateWrapperForExternal(
+ identity,
+ flags,
+ sizeof(ExternalObjectContext),
+ &resultHolder);
+ }
+
if (FAILED(hr))
COMPlusThrowHR(hr);
// The user could have supplied a wrapper so assign that now.
- gc.objRef = gc.wrapperMaybeRef;
+ gc.objRefMaybe = gc.wrapperMaybeRef;
// If the wrapper hasn't been set yet, call the implementation to create one.
- if (gc.objRef == NULL)
+ if (gc.objRefMaybe == NULL)
{
- gc.objRef = CallGetObject(&gc.implRef, identity, flags);
- if (gc.objRef == NULL)
- COMPlusThrow(kArgumentNullException);
+ gc.objRefMaybe = CallGetObject(&gc.implRef, identity, flags);
}
- // Construct the new context with the object details.
- DWORD flags = (resultHolder.Result.FromTrackerRuntime
- ? ExternalObjectContext::Flags_ReferenceTracker
- : ExternalObjectContext::Flags_None) |
- (uniqueInstance
- ? ExternalObjectContext::Flags_None
- : ExternalObjectContext::Flags_InCache);
- ExternalObjectContext::Construct(
- resultHolder.GetContext(),
- identity,
- GetCurrentCtxCookie(),
- gc.objRef->GetSyncBlockIndex(),
- flags);
-
- if (uniqueInstance)
+ // The object may be null if the specified ComWrapper implementation returns null
+ // or there is no registered global instance. It is the caller's responsibility
+ // to handle this case and error if necessary.
+ if (gc.objRefMaybe != NULL)
{
- extObjCxt = resultHolder.GetContext();
- }
- else
- {
- // Attempt to insert the new context into the cache.
- ExtObjCxtCache::LockHolder lock(cache);
- extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext());
- }
-
- // If the returned context matches the new context it means the
- // new context was inserted or a unique instance was requested.
- if (extObjCxt == resultHolder.GetContext())
- {
- // Update the object's SyncBlock with a handle to the context for runtime cleanup.
- SyncBlock* syncBlock = gc.objRef->GetSyncBlock();
- InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo();
- _ASSERTE(syncBlock->IsPrecious());
-
- // Since the caller has the option of providing a wrapper, it is
- // possible the supplied wrapper already has an associated external
- // object and an object can only be associated with one external object.
- if (!interopInfo->TrySetExternalComObjectContext((void**)extObjCxt))
+ // Construct the new context with the object details.
+ DWORD flags = (resultHolder.Result.FromTrackerRuntime
+ ? ExternalObjectContext::Flags_ReferenceTracker
+ : ExternalObjectContext::Flags_None) |
+ (uniqueInstance
+ ? ExternalObjectContext::Flags_None
+ : ExternalObjectContext::Flags_InCache);
+ ExternalObjectContext::Construct(
+ resultHolder.GetContext(),
+ identity,
+ GetCurrentCtxCookie(),
+ gc.objRefMaybe->GetSyncBlockIndex(),
+ flags);
+
+ if (uniqueInstance)
{
- // Failed to set the context; one must already exist.
- // Remove from the cache above as well.
+ extObjCxt = resultHolder.GetContext();
+ }
+ else
+ {
+ // Attempt to insert the new context into the cache.
ExtObjCxtCache::LockHolder lock(cache);
- cache->Remove(resultHolder.GetContext());
+ extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext());
+ }
+
+ // If the returned context matches the new context it means the
+ // new context was inserted or a unique instance was requested.
+ if (extObjCxt == resultHolder.GetContext())
+ {
+ // Update the object's SyncBlock with a handle to the context for runtime cleanup.
+ SyncBlock* syncBlock = gc.objRefMaybe->GetSyncBlock();
+ InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo();
+ _ASSERTE(syncBlock->IsPrecious());
+
+ // Since the caller has the option of providing a wrapper, it is
+ // possible the supplied wrapper already has an associated external
+ // object and an object can only be associated with one external object.
+ if (!interopInfo->TrySetExternalComObjectContext((void**)extObjCxt))
+ {
+ // Failed to set the context; one must already exist.
+ // Remove from the cache above as well.
+ ExtObjCxtCache::LockHolder lock(cache);
+ cache->Remove(resultHolder.GetContext());
+
+ COMPlusThrow(kNotSupportedException);
+ }
- COMPlusThrow(kNotSupportedException);
+ // Detach from the holder to avoid cleanup.
+ (void)resultHolder.DetachContext();
+ STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created EOC (Unique Instance: %d): 0x%p\n", (int)uniqueInstance, extObjCxt);
}
- // Detach from the holder to avoid cleanup.
- (void)resultHolder.DetachContext();
- STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created EOC (Unique Instance: %d): 0x%p\n", (int)uniqueInstance, extObjCxt);
+ _ASSERTE(extObjCxt->IsActive());
}
-
- _ASSERTE(extObjCxt->IsActive());
}
GCPROTECT_END();
- RETURN gc.objRef;
+ *objRef = gc.objRefMaybe;
+ RETURN (gc.objRefMaybe != NULL);
}
}
@@ -949,19 +991,29 @@ namespace InteropLibImports
gc.implRef = NULL; // Use the globally registered implementation.
gc.wrapperMaybeRef = NULL; // No supplied wrapper here.
+ bool unwrapIfManagedObjectWrapper = false; // Don't unwrap CCWs
// Get wrapper for external object
- gc.objRef = GetOrCreateObjectForComInstanceInternal(
+ bool success = TryGetOrCreateObjectForComInstanceInternal(
gc.implRef,
externalComObject,
externalObjectFlags,
- gc.wrapperMaybeRef);
+ unwrapIfManagedObjectWrapper,
+ gc.wrapperMaybeRef,
+ &gc.objRef);
+
+ if (!success)
+ COMPlusThrow(kArgumentNullException);
// Get wrapper for managed object
- *trackerTarget = GetOrCreateComInterfaceForObjectInternal(
+ success = TryGetOrCreateComInterfaceForObjectInternal(
gc.implRef,
gc.objRef,
- trackerTargetFlags);
+ trackerTargetFlags,
+ trackerTarget);
+
+ if (!success)
+ COMPlusThrow(kArgumentException);
STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created Target for External: 0x%p => 0x%p\n", OBJECTREFToObject(gc.objRef), *trackerTarget);
GCPROTECT_END();
@@ -1055,14 +1107,15 @@ namespace InteropLibImports
#ifdef FEATURE_COMWRAPPERS
-void* QCALLTYPE ComWrappersNative::GetOrCreateComInterfaceForObject(
+BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
_In_ QCall::ObjectHandleOnStack instance,
- _In_ INT32 flags)
+ _In_ INT32 flags,
+ _Outptr_ void** wrapper)
{
QCALL_CONTRACT;
- void* wrapper = NULL;
+ bool success;
BEGIN_QCALL;
@@ -1070,19 +1123,19 @@ void* QCALLTYPE ComWrappersNative::GetOrCreateComInterfaceForObject(
// are being manipulated.
{
GCX_COOP();
- wrapper = GetOrCreateComInterfaceForObjectInternal(
+ success = TryGetOrCreateComInterfaceForObjectInternal(
ObjectToOBJECTREF(*comWrappersImpl.m_ppObject),
ObjectToOBJECTREF(*instance.m_ppObject),
- (CreateComInterfaceFlags)flags);
+ (CreateComInterfaceFlags)flags,
+ wrapper);
}
END_QCALL;
- _ASSERTE(wrapper != NULL);
- return wrapper;
+ return (success ? TRUE : FALSE);
}
-void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance(
+BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
_In_ void* ext,
_In_ INT32 flags,
@@ -1093,6 +1146,8 @@ void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance(
_ASSERTE(ext != NULL);
+ bool success;
+
BEGIN_QCALL;
HRESULT hr;
@@ -1107,17 +1162,25 @@ void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance(
// are being manipulated.
{
GCX_COOP();
- OBJECTREF newObj = GetOrCreateObjectForComInstanceInternal(
+
+ bool unwrapIfManagedObjectWrapper = false; // Don't unwrap CCWs
+ OBJECTREF newObj;
+ success = TryGetOrCreateObjectForComInstanceInternal(
ObjectToOBJECTREF(*comWrappersImpl.m_ppObject),
identity,
(CreateObjectFlags)flags,
- ObjectToOBJECTREF(*wrapperMaybe.m_ppObject));
+ unwrapIfManagedObjectWrapper,
+ ObjectToOBJECTREF(*wrapperMaybe.m_ppObject),
+ &newObj);
// Set the return value
- retValue.Set(newObj);
+ if (success)
+ retValue.Set(newObj);
}
END_QCALL;
+
+ return (success ? TRUE : FALSE);
}
void QCALLTYPE ComWrappersNative::GetIUnknownImpl(
@@ -1198,6 +1261,85 @@ void ComWrappersNative::MarkExternalComObjectContextCollected(_In_ void* context
}
}
+void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe)
+{
+ CONTRACTL
+ {
+ NOTHROW;
+ MODE_ANY;
+ PRECONDITION(wrapperMaybe != NULL);
+ }
+ CONTRACTL_END;
+
+ // The IUnknown may or may not represent a wrapper, so E_INVALIDARG is okay here.
+ HRESULT hr = InteropLib::Com::MarkComActivated(wrapperMaybe);
+ _ASSERTE(SUCCEEDED(hr) || hr == E_INVALIDARG);
+}
+
+void QCALLTYPE GlobalComWrappers::SetGlobalInstanceRegistered()
+{
+ // QCALL contracts are not used here because the managed declaration
+ // uses the SuppressGCTransition attribute
+
+ _ASSERTE(!g_IsGlobalComWrappersRegistered);
+ g_IsGlobalComWrappersRegistered = true;
+}
+
+bool GlobalComWrappers::TryGetOrCreateComInterfaceForObject(
+ _In_ OBJECTREF instance,
+ _Outptr_ void** wrapperRaw)
+{
+ if (!g_IsGlobalComWrappersRegistered)
+ return false;
+
+ // Switch to Cooperative mode since object references
+ // are being manipulated.
+ {
+ GCX_COOP();
+
+ CreateComInterfaceFlags flags = CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport;
+
+ // Passing NULL as the ComWrappers implementation indicates using the globally registered instance
+ return TryGetOrCreateComInterfaceForObjectInternal(
+ NULL,
+ instance,
+ flags,
+ wrapperRaw);
+ }
+}
+
+bool GlobalComWrappers::TryGetOrCreateObjectForComInstance(
+ _In_ IUnknown* externalComObject,
+ _In_ INT32 objFromComIPFlags,
+ _Out_ OBJECTREF* objRef)
+{
+ if (!g_IsGlobalComWrappersRegistered)
+ return false;
+
+ // Switch to Cooperative mode since object references
+ // are being manipulated.
+ {
+ GCX_COOP();
+
+ int flags = CreateObjectFlags::CreateObjectFlags_TrackerObject;
+ if ((objFromComIPFlags & ObjFromComIP::UNIQUE_OBJECT) != 0)
+ flags |= CreateObjectFlags::CreateObjectFlags_UniqueInstance;
+
+ // For implicit usage of ComWrappers (i.e. automatically called by the runtime when there is a global instance),
+ // unwrap CCWs to allow for round-tripping object -> COM instance -> object.
+ bool unwrapIfManagedObjectWrapper = true;
+
+ // Passing NULL as the ComWrappers implementation indicates using the globally registered instance
+ return TryGetOrCreateObjectForComInstanceInternal(
+ NULL /*comWrappersImpl*/,
+ externalComObject,
+ (CreateObjectFlags)flags,
+ unwrapIfManagedObjectWrapper,
+ NULL /*wrapperMaybe*/,
+ objRef);
+ }
+}
+
#endif // FEATURE_COMWRAPPERS
void Interop::OnGCStarted(_In_ int nCondemnedGeneration)
diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h
index 2ed54cfc90706..2985412db4b29 100644
--- a/src/coreclr/src/vm/interoplibinterface.h
+++ b/src/coreclr/src/vm/interoplibinterface.h
@@ -17,12 +17,13 @@ class ComWrappersNative
_Out_ void** fpAddRef,
_Out_ void** fpRelease);
- static void* QCALLTYPE GetOrCreateComInterfaceForObject(
+ static BOOL QCALLTYPE TryGetOrCreateComInterfaceForObject(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
_In_ QCall::ObjectHandleOnStack instance,
- _In_ INT32 flags);
+ _In_ INT32 flags,
+ _Outptr_ void** wrapperRaw);
- static void QCALLTYPE GetOrCreateObjectForComInstance(
+ static BOOL QCALLTYPE TryGetOrCreateObjectForComInstance(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
_In_ void* externalComObject,
_In_ INT32 flags,
@@ -33,6 +34,27 @@ class ComWrappersNative
static void DestroyManagedObjectComWrapper(_In_ void* wrapper);
static void DestroyExternalComObjectContext(_In_ void* context);
static void MarkExternalComObjectContextCollected(_In_ void* context);
+
+public: // COM activation
+ static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe);
+};
+
+class GlobalComWrappers
+{
+public:
+ // Native QCall for the ComWrappers managed type to indicate a global instance is registered
+ // This should be set if the private static member representing the global instance on ComWrappers is non-null.
+ static void QCALLTYPE SetGlobalInstanceRegistered();
+
+public: // Functions operating on a registered global instance
+ static bool TryGetOrCreateComInterfaceForObject(
+ _In_ OBJECTREF instance,
+ _Outptr_ void** wrapperRaw);
+
+ static bool TryGetOrCreateObjectForComInstance(
+ _In_ IUnknown* externalComObject,
+ _In_ INT32 objFromComIPFlags,
+ _Out_ OBJECTREF* objRef);
};
#endif // FEATURE_COMWRAPPERS
diff --git a/src/coreclr/src/vm/marshalnative.cpp b/src/coreclr/src/vm/marshalnative.cpp
index adc20310648cb..d0d3b27c68980 100644
--- a/src/coreclr/src/vm/marshalnative.cpp
+++ b/src/coreclr/src/vm/marshalnative.cpp
@@ -372,7 +372,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE)
FieldDesc *pField = refField->GetField();
TypeHandle th = TypeHandle(pField->GetApproxEnclosingMethodTable());
-
+
if (th.IsBlittable())
{
return pField->GetOffset();
@@ -388,7 +388,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE)
{
// It isn't marshalable so throw an ArgumentException.
StackSString strTypeName;
- TypeString::AppendType(strTypeName, th);
+ TypeString::AppendType(strTypeName, th);
COMPlusThrow(kArgumentException, IDS_CANNOT_MARSHAL, strTypeName.GetUnicode(), NULL, NULL);
}
EEClassNativeLayoutInfo const* pNativeLayoutInfo = th.GetMethodTable()->GetNativeLayoutInfo();
@@ -401,7 +401,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE)
#endif
while (numReferenceFields--)
{
- if (pNFD->GetFieldDesc() == pField)
+ if (pNFD->GetFieldDesc() == pField)
{
externalOffset = pNFD->GetExternalOffset();
INDEBUG(foundField = true);
@@ -802,11 +802,7 @@ FCIMPL2(IUnknown*, MarshalNative::GetIUnknownForObjectNative, Object* orefUNSAFE
OBJECTREF oref = (OBJECTREF) orefUNSAFE;
HELPER_METHOD_FRAME_BEGIN_RET_1(oref);
- HRESULT hr = S_OK;
-
- if(!oref)
- COMPlusThrowArgumentNull(W("o"));
-
+ _ASSERTE(oref != NULL);
// Ensure COM is started up.
EnsureComStarted();
@@ -868,11 +864,7 @@ FCIMPL2(IDispatch*, MarshalNative::GetIDispatchForObjectNative, Object* orefUNSA
OBJECTREF oref = (OBJECTREF) orefUNSAFE;
HELPER_METHOD_FRAME_BEGIN_RET_1(oref);
- HRESULT hr = S_OK;
-
- if(!oref)
- COMPlusThrowArgumentNull(W("o"));
-
+ _ASSERTE(oref != NULL);
// Ensure COM is started up.
EnsureComStarted();
@@ -899,13 +891,8 @@ FCIMPL4(IUnknown*, MarshalNative::GetComInterfaceForObjectNative, Object* orefUN
REFLECTCLASSBASEREF refClass = (REFLECTCLASSBASEREF) refClassUNSAFE;
HELPER_METHOD_FRAME_BEGIN_RET_2(oref, refClass);
- HRESULT hr = S_OK;
-
- if(!oref)
- COMPlusThrowArgumentNull(W("o"));
- if(!refClass)
- COMPlusThrowArgumentNull(W("t"));
-
+ _ASSERTE(oref != NULL);
+ _ASSERTE(refClass != NULL);
// Ensure COM is started up.
EnsureComStarted();
@@ -946,23 +933,18 @@ FCIMPLEND
//====================================================================
// return an Object for IUnknown
//====================================================================
-FCIMPL1(Object*, MarshalNative::GetObjectForIUnknown, IUnknown* pUnk)
+FCIMPL1(Object*, MarshalNative::GetObjectForIUnknownNative, IUnknown* pUnk)
{
CONTRACTL
{
FCALL_CHECK;
- PRECONDITION(CheckPointer(pUnk, NULL_OK));
+ PRECONDITION(CheckPointer(pUnk));
}
CONTRACTL_END;
OBJECTREF oref = NULL;
HELPER_METHOD_FRAME_BEGIN_RET_1(oref);
- HRESULT hr = S_OK;
-
- if(!pUnk)
- COMPlusThrowArgumentNull(W("pUnk"));
-
// Ensure COM is started up.
EnsureComStarted();
@@ -974,12 +956,12 @@ FCIMPL1(Object*, MarshalNative::GetObjectForIUnknown, IUnknown* pUnk)
FCIMPLEND
-FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknown, IUnknown* pUnk)
+FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknownNative, IUnknown* pUnk)
{
CONTRACTL
{
FCALL_CHECK;
- PRECONDITION(CheckPointer(pUnk, NULL_OK));
+ PRECONDITION(CheckPointer(pUnk));
}
CONTRACTL_END;
@@ -988,9 +970,6 @@ FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknown, IUnknown* pUnk)
HRESULT hr = S_OK;
- if(!pUnk)
- COMPlusThrowArgumentNull(W("pUnk"));
-
// Ensure COM is started up.
EnsureComStarted();
diff --git a/src/coreclr/src/vm/marshalnative.h b/src/coreclr/src/vm/marshalnative.h
index 7ff167c586fc3..3325a0d05544e 100644
--- a/src/coreclr/src/vm/marshalnative.h
+++ b/src/coreclr/src/vm/marshalnative.h
@@ -106,12 +106,12 @@ class MarshalNative
//====================================================================
// return an Object for IUnknown
//====================================================================
- static FCDECL1(Object*, GetObjectForIUnknown, IUnknown* pUnk);
+ static FCDECL1(Object*, GetObjectForIUnknownNative, IUnknown* pUnk);
//====================================================================
// return a unique cacheless Object for IUnknown
//====================================================================
- static FCDECL1(Object*, GetUniqueObjectForIUnknown, IUnknown* pUnk);
+ static FCDECL1(Object*, GetUniqueObjectForIUnknownNative, IUnknown* pUnk);
//====================================================================
// return a unique cacheless Object for IUnknown
diff --git a/src/coreclr/src/vm/runtimecallablewrapper.cpp b/src/coreclr/src/vm/runtimecallablewrapper.cpp
index f61ed2b61b417..61475bf9ba0d5 100644
--- a/src/coreclr/src/vm/runtimecallablewrapper.cpp
+++ b/src/coreclr/src/vm/runtimecallablewrapper.cpp
@@ -52,6 +52,7 @@ SLIST_HEADER RCW::s_RCWStandbyList;
#endif // FEATURE_COMINTEROP_APARTMENT_SUPPORT
#ifdef FEATURE_COMINTEROP_UNMANAGED_ACTIVATION
+#include "interoplibinterface.h"
#ifndef CROSSGEN_COMPILE
@@ -247,6 +248,8 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF
if (ccw != NULL)
ccw->MarkComActivated();
+ ComWrappersNative::MarkWrapperAsComActivated(pUnk);
+
pUnk.SuppressRelease();
RETURN pUnk;
}
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj
similarity index 82%
rename from src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj
rename to src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj
index e82960ae1e101..83acfa1f6fd5e 100644
--- a/src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj
@@ -9,8 +9,9 @@
+
-
+
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
similarity index 68%
rename from src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs
rename to src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
index e85f87a4b8269..2c30bb6c2a61d 100644
--- a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
@@ -7,154 +7,16 @@ namespace ComWrappersTests
using System;
using System.Collections;
using System.Collections.Generic;
- using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
+ using ComWrappersTests.Common;
using TestLibrary;
class Program
{
- //
- // Managed object with native wrapper definition.
- //
- [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")]
- interface ITest
- {
- void SetValue(int i);
- }
-
- class Test : ITest
- {
- public static int InstanceCount = 0;
-
- private int value = -1;
- public Test() { InstanceCount++; }
- ~Test() { InstanceCount--; }
-
- public void SetValue(int i) => this.value = i;
- public int GetValue() => this.value;
- }
-
- public struct IUnknownVtbl
- {
- public IntPtr QueryInterface;
- public IntPtr AddRef;
- public IntPtr Release;
- }
-
- public struct ITestVtbl
- {
- public IUnknownVtbl IUnknownImpl;
- public IntPtr SetValue;
-
- public delegate int _SetValue(IntPtr thisPtr, int i);
- public static _SetValue pSetValue = new _SetValue(SetValueInternal);
-
- public static int SetValueInternal(IntPtr dispatchPtr, int i)
- {
- unsafe
- {
- try
- {
- ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)dispatchPtr).SetValue(i);
- }
- catch (Exception e)
- {
- return e.HResult;
- }
- }
- return 0; // S_OK;
- }
- }
-
- //
- // Native interface defintion with managed wrapper for tracker object
- //
- struct MockReferenceTrackerRuntime
- {
- [DllImport(nameof(MockReferenceTrackerRuntime))]
- extern public static IntPtr CreateTrackerObject();
-
- [DllImport(nameof(MockReferenceTrackerRuntime))]
- extern public static void ReleaseAllTrackerObjects();
-
- [DllImport(nameof(MockReferenceTrackerRuntime))]
- extern public static int Trigger_NotifyEndOfReferenceTrackingOnThread();
- }
-
- [Guid("42951130-245C-485E-B60B-4ED4254256F8")]
- public interface ITrackerObject
- {
- int AddObjectRef(IntPtr obj);
- void DropObjectRef(int id);
- };
-
- public struct VtblPtr
- {
- public IntPtr Vtbl;
- }
-
- public class ITrackerObjectWrapper : ITrackerObject
- {
- private struct ITrackerObjectWrapperVtbl
- {
- public IntPtr QueryInterface;
- public _AddRef AddRef;
- public _Release Release;
- public _AddObjectRef AddObjectRef;
- public _DropObjectRef DropObjectRef;
- }
-
- private delegate int _AddRef(IntPtr This);
- private delegate int _Release(IntPtr This);
- private delegate int _AddObjectRef(IntPtr This, IntPtr obj, out int id);
- private delegate int _DropObjectRef(IntPtr This, int id);
-
- private readonly IntPtr instance;
- private readonly ITrackerObjectWrapperVtbl vtable;
-
- public ITrackerObjectWrapper(IntPtr instance)
- {
- var inst = Marshal.PtrToStructure(instance);
- this.vtable = Marshal.PtrToStructure(inst.Vtbl);
- this.instance = instance;
- }
-
- ~ITrackerObjectWrapper()
- {
- if (this.instance != IntPtr.Zero)
- {
- this.vtable.Release(this.instance);
- }
- }
-
- public int AddObjectRef(IntPtr obj)
- {
- int id;
- int hr = this.vtable.AddObjectRef(this.instance, obj, out id);
- if (hr != 0)
- {
- throw new COMException($"{nameof(AddObjectRef)}", hr);
- }
-
- return id;
- }
-
- public void DropObjectRef(int id)
- {
- int hr = this.vtable.DropObjectRef(this.instance, id);
- if (hr != 0)
- {
- throw new COMException($"{nameof(DropObjectRef)}", hr);
- }
- }
- }
-
class TestComWrappers : ComWrappers
{
- public static readonly TestComWrappers Global = new TestComWrappers();
-
protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
Assert.IsTrue(obj is Test);
@@ -188,11 +50,11 @@ class TestComWrappers : ComWrappers
protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
{
var iid = typeof(ITrackerObject).GUID;
- IntPtr iTestComObject;
- int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTestComObject);
- Assert.AreEqual(hr, 0);
+ IntPtr iTrackerComObject;
+ int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTrackerComObject);
+ Assert.AreEqual(0, hr);
- return new ITrackerObjectWrapper(iTestComObject);
+ return new ITrackerObjectWrapper(iTrackerComObject);
}
public const int ReleaseObjectsCallAck = unchecked((int)-1);
@@ -458,37 +320,6 @@ static void ValidateRuntimeTrackerScenario()
GC.Collect();
}
- static void ValidateGlobalInstanceScenarios()
- {
- Console.WriteLine($"Running {nameof(ValidateGlobalInstanceScenarios)}...");
- Console.WriteLine($"Validate RegisterAsGlobalInstance()...");
-
- var wrappers1 = TestComWrappers.Global;
- wrappers1.RegisterAsGlobalInstance();
-
- Assert.Throws(
- () =>
- {
- wrappers1.RegisterAsGlobalInstance();
- }, "Should not be able to re-register for global ComWrappers");
-
- var wrappers2 = new TestComWrappers();
- Assert.Throws(
- () =>
- {
- wrappers2.RegisterAsGlobalInstance();
- }, "Should not be able to reset for global ComWrappers");
-
- Console.WriteLine($"Validate NotifyEndOfReferenceTrackingOnThread()...");
-
- int hr;
- var cw = TestComWrappers.Global;
-
- // Trigger the thread lifetime end API and verify the callback occurs.
- hr = MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread();
- Assert.AreEqual(TestComWrappers.ReleaseObjectsCallAck, hr);
- }
-
static int Main(string[] doNotUse)
{
try
@@ -499,10 +330,6 @@ static int Main(string[] doNotUse)
ValidateIUnknownImpls();
ValidateBadComWrapperImpl();
ValidateRuntimeTrackerScenario();
-
- // Perform all global impacting test scenarios last to
- // avoid polluting non-global tests.
- ValidateGlobalInstanceScenarios();
}
catch (Exception e)
{
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs
new file mode 100644
index 0000000000000..06c052b0a237a
--- /dev/null
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs
@@ -0,0 +1,146 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace ComWrappersTests.Common
+{
+ using System;
+ using System.Runtime.InteropServices;
+
+ //
+ // Managed object with native wrapper definition.
+ //
+ [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")]
+ interface ITest
+ {
+ void SetValue(int i);
+ }
+
+ class Test : ITest
+ {
+ public static int InstanceCount = 0;
+
+ private int value = -1;
+ public Test() { InstanceCount++; }
+ ~Test() { InstanceCount--; }
+
+ public void SetValue(int i) => this.value = i;
+ public int GetValue() => this.value;
+ }
+
+ public struct IUnknownVtbl
+ {
+ public IntPtr QueryInterface;
+ public IntPtr AddRef;
+ public IntPtr Release;
+ }
+
+ public struct ITestVtbl
+ {
+ public IUnknownVtbl IUnknownImpl;
+ public IntPtr SetValue;
+
+ public delegate int _SetValue(IntPtr thisPtr, int i);
+ public static _SetValue pSetValue = new _SetValue(SetValueInternal);
+
+ public static int SetValueInternal(IntPtr dispatchPtr, int i)
+ {
+ unsafe
+ {
+ try
+ {
+ ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)dispatchPtr).SetValue(i);
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+ }
+ return 0; // S_OK;
+ }
+ }
+
+ //
+ // Native interface defintion with managed wrapper for tracker object
+ //
+ struct MockReferenceTrackerRuntime
+ {
+ [DllImport(nameof(MockReferenceTrackerRuntime))]
+ extern public static IntPtr CreateTrackerObject();
+
+ [DllImport(nameof(MockReferenceTrackerRuntime))]
+ extern public static void ReleaseAllTrackerObjects();
+
+ [DllImport(nameof(MockReferenceTrackerRuntime))]
+ extern public static int Trigger_NotifyEndOfReferenceTrackingOnThread();
+ }
+
+ [Guid("42951130-245C-485E-B60B-4ED4254256F8")]
+ public interface ITrackerObject
+ {
+ int AddObjectRef(IntPtr obj);
+ void DropObjectRef(int id);
+ };
+
+ public struct VtblPtr
+ {
+ public IntPtr Vtbl;
+ }
+
+ public class ITrackerObjectWrapper : ITrackerObject
+ {
+ private struct ITrackerObjectWrapperVtbl
+ {
+ public IntPtr QueryInterface;
+ public _AddRef AddRef;
+ public _Release Release;
+ public _AddObjectRef AddObjectRef;
+ public _DropObjectRef DropObjectRef;
+ }
+
+ private delegate int _AddRef(IntPtr This);
+ private delegate int _Release(IntPtr This);
+ private delegate int _AddObjectRef(IntPtr This, IntPtr obj, out int id);
+ private delegate int _DropObjectRef(IntPtr This, int id);
+
+ private readonly IntPtr instance;
+ private readonly ITrackerObjectWrapperVtbl vtable;
+
+ public ITrackerObjectWrapper(IntPtr instance)
+ {
+ var inst = Marshal.PtrToStructure(instance);
+ this.vtable = Marshal.PtrToStructure(inst.Vtbl);
+ this.instance = instance;
+ }
+
+ ~ITrackerObjectWrapper()
+ {
+ if (this.instance != IntPtr.Zero)
+ {
+ this.vtable.Release(this.instance);
+ }
+ }
+
+ public int AddObjectRef(IntPtr obj)
+ {
+ int id;
+ int hr = this.vtable.AddObjectRef(this.instance, obj, out id);
+ if (hr != 0)
+ {
+ throw new COMException($"{nameof(AddObjectRef)}", hr);
+ }
+
+ return id;
+ }
+
+ public void DropObjectRef(int id)
+ {
+ int hr = this.vtable.DropObjectRef(this.instance, id);
+ if (hr != 0)
+ {
+ throw new COMException($"{nameof(DropObjectRef)}", hr);
+ }
+ }
+ }
+}
+
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest
new file mode 100644
index 0000000000000..bb7ec83fae8b8
--- /dev/null
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest
new file mode 100644
index 0000000000000..abb39fbb21c7d
--- /dev/null
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest
@@ -0,0 +1,16 @@
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj
new file mode 100644
index 0000000000000..1e5e14f330e2e
--- /dev/null
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj
@@ -0,0 +1,37 @@
+
+
+ Exe
+ App.manifest
+ true
+ true
+
+ BuildOnly
+
+ true
+ true
+
+ true
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+ false
+ Content
+ Always
+
+
+
+
+ PreserveNewest
+
+
+
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs
new file mode 100644
index 0000000000000..3bd14141570cd
--- /dev/null
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs
@@ -0,0 +1,451 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace ComWrappersTests.GlobalInstance
+{
+ using System;
+ using System.Collections;
+ using System.Runtime.CompilerServices;
+ using System.Runtime.InteropServices;
+
+ using ComWrappersTests.Common;
+ using TestLibrary;
+
+ class Program
+ {
+ struct MarshalInterface
+ {
+ [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))]
+ [return: MarshalAs(UnmanagedType.IUnknown)]
+ extern public static object CreateTrackerObjectAsIUnknown();
+
+ [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))]
+ [return: MarshalAs(UnmanagedType.Interface)]
+ extern public static FakeWrapper CreateTrackerObjectAsInterface();
+
+ [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint = nameof(MockReferenceTrackerRuntime.CreateTrackerObject))]
+ [return: MarshalAs(UnmanagedType.Interface)]
+ extern public static Test CreateTrackerObjectWrongType();
+
+ [DllImport(nameof(MockReferenceTrackerRuntime))]
+ extern public static int UpdateTestObjectAsIUnknown(
+ [MarshalAs(UnmanagedType.IUnknown)] object testObj,
+ int i,
+ [MarshalAs(UnmanagedType.IUnknown)] out object ret);
+
+ [DllImport(nameof(MockReferenceTrackerRuntime))]
+ extern public static int UpdateTestObjectAsIDispatch(
+ [MarshalAs(UnmanagedType.IDispatch)] object testObj,
+ int i,
+ [MarshalAs(UnmanagedType.IDispatch)] out object ret);
+
+ [DllImport(nameof(MockReferenceTrackerRuntime))]
+ extern public static int UpdateTestObjectAsIInspectable(
+ [MarshalAs(UnmanagedType.IInspectable)] object testObj,
+ int i,
+ [MarshalAs(UnmanagedType.IInspectable)] out object ret);
+
+ [DllImport(nameof(MockReferenceTrackerRuntime))]
+ extern public static int UpdateTestObjectAsInterface(
+ [MarshalAs(UnmanagedType.Interface)] Test testObj,
+ int i,
+ [Out, MarshalAs(UnmanagedType.Interface)] out Test ret);
+ }
+
+ private const string ManagedServerTypeName = "ConsumeNETServerTesting";
+
+ private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046";
+ private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90";
+ class TestEx : Test
+ {
+ public readonly Guid[] Interfaces;
+ public TestEx(params string[] iids)
+ {
+ Interfaces = new Guid[iids.Length];
+ for (int i = 0; i < iids.Length; i++)
+ Interfaces[i] = Guid.Parse(iids[i]);
+ }
+ }
+
+ class FakeWrapper
+ {
+ private delegate int _AddRef(IntPtr This);
+ private delegate int _Release(IntPtr This);
+ private struct IUnknownWrapperVtbl
+ {
+ public IntPtr QueryInterface;
+ public _AddRef AddRef;
+ public _Release Release;
+ }
+
+ private readonly IntPtr wrappedInstance;
+
+ private readonly IUnknownWrapperVtbl vtable;
+
+ public FakeWrapper(IntPtr instance)
+ {
+ this.wrappedInstance = instance;
+ var inst = Marshal.PtrToStructure(instance);
+ this.vtable = Marshal.PtrToStructure(inst.Vtbl);
+ }
+
+ ~FakeWrapper()
+ {
+ if (this.wrappedInstance != IntPtr.Zero)
+ {
+ this.vtable.Release(this.wrappedInstance);
+ }
+ }
+ }
+
+ class GlobalComWrappers : ComWrappers
+ {
+ public static GlobalComWrappers Instance = new GlobalComWrappers();
+
+ public bool ReturnInvalid { get; set; }
+
+ public object LastComputeVtablesObject { get; private set; }
+
+ protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+ {
+ LastComputeVtablesObject = obj;
+
+ if (ReturnInvalid)
+ {
+ count = -1;
+ return null;
+ }
+
+ if (obj is Test)
+ {
+ return ComputeVtablesForTestObject((Test)obj, out count);
+ }
+ else if (string.Equals(ManagedServerTypeName, obj.GetType().Name))
+ {
+ IntPtr fpQueryInteface = default;
+ IntPtr fpAddRef = default;
+ IntPtr fpRelease = default;
+ ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease);
+
+ var vtbl = new IUnknownVtbl()
+ {
+ QueryInterface = fpQueryInteface,
+ AddRef = fpAddRef,
+ Release = fpRelease
+ };
+ var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(IUnknownVtbl));
+ Marshal.StructureToPtr(vtbl, vtblRaw, false);
+
+ // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here.
+ var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(ComInterfaceEntry));
+ entryRaw[0].IID = typeof(Server.Contract.IConsumeNETServer).GUID;
+ entryRaw[0].Vtable = vtblRaw;
+
+ count = 1;
+ return entryRaw;
+ }
+
+ count = -1;
+ return null;
+ }
+
+ protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
+ {
+ if (ReturnInvalid)
+ return null;
+
+ Guid[] iids = {
+ typeof(ITrackerObject).GUID,
+ typeof(ITest).GUID,
+ typeof(Server.Contract.IDispatchTesting).GUID,
+ typeof(Server.Contract.IConsumeNETServer).GUID
+ };
+
+ for (var i = 0; i < iids.Length; i++)
+ {
+ var iid = iids[i];
+ IntPtr comObject;
+ int hr = Marshal.QueryInterface(externalComObject, ref iid, out comObject);
+ if (hr == 0)
+ return new FakeWrapper(comObject);
+ }
+
+ return null;
+ }
+
+ public const int ReleaseObjectsCallAck = unchecked((int)-1);
+
+ protected override void ReleaseObjects(IEnumerable objects)
+ {
+ throw new Exception() { HResult = ReleaseObjectsCallAck };
+ }
+
+ private unsafe ComInterfaceEntry* ComputeVtablesForTestObject(Test obj, out int count)
+ {
+ IntPtr fpQueryInteface = default;
+ IntPtr fpAddRef = default;
+ IntPtr fpRelease = default;
+ ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease);
+
+ var iUnknownVtbl = new IUnknownVtbl()
+ {
+ QueryInterface = fpQueryInteface,
+ AddRef = fpAddRef,
+ Release = fpRelease
+ };
+
+ var vtbl = new ITestVtbl()
+ {
+ IUnknownImpl = iUnknownVtbl,
+ SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue)
+ };
+ var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl));
+ Marshal.StructureToPtr(vtbl, vtblRaw, false);
+
+ int countLocal = obj is TestEx ? ((TestEx)obj).Interfaces.Length + 1 : 1;
+ var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry) * countLocal);
+ entryRaw[0].IID = typeof(ITest).GUID;
+ entryRaw[0].Vtable = vtblRaw;
+
+ if (obj is TestEx)
+ {
+ var iUnknownVtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(IUnknownVtbl));
+ Marshal.StructureToPtr(iUnknownVtbl, iUnknownVtblRaw, false);
+
+ var testEx = (TestEx)obj;
+ for (int i = 1; i < testEx.Interfaces.Length + 1; i++)
+ {
+ // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here.
+ entryRaw[i].IID = testEx.Interfaces[i - 1];
+ entryRaw[i].Vtable = iUnknownVtblRaw;
+ }
+ }
+
+ count = countLocal;
+ return entryRaw;
+ }
+ }
+
+ private static void ValidateRegisterAsGlobalInstance()
+ {
+ Console.WriteLine($"Running {nameof(ValidateRegisterAsGlobalInstance)}...");
+
+ var wrappers1 = GlobalComWrappers.Instance;
+ wrappers1.RegisterAsGlobalInstance();
+ Assert.Throws(
+ () =>
+ {
+ wrappers1.RegisterAsGlobalInstance();
+ }, "Should not be able to re-register for global ComWrappers");
+
+ var wrappers2 = new GlobalComWrappers();
+ Assert.Throws(
+ () =>
+ {
+ wrappers2.RegisterAsGlobalInstance();
+ }, "Should not be able to reset for global ComWrappers");
+ }
+
+ private static void ValidateMarshalAPIs(bool validateUseRegistered)
+ {
+ string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime";
+ Console.WriteLine($"Running {nameof(ValidateMarshalAPIs)}: {scenario}...");
+
+ GlobalComWrappers registeredWrapper = GlobalComWrappers.Instance;
+ registeredWrapper.ReturnInvalid = !validateUseRegistered;
+
+ Console.WriteLine($" -- Validate Marshal.GetIUnknownForObject...");
+
+ var testObj = new Test();
+ IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj);
+ Assert.AreNotEqual(IntPtr.Zero, comWrapper1);
+ Assert.AreEqual(testObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called");
+
+ IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj);
+ Assert.AreEqual(comWrapper1, comWrapper2);
+
+ Marshal.Release(comWrapper1);
+ Marshal.Release(comWrapper2);
+
+ Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject...");
+
+ Assert.Throws(() => Marshal.GetIDispatchForObject(testObj));
+
+ if (validateUseRegistered)
+ {
+ var dispatchObj = new TestEx(IID_IDISPATCH);
+ IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj);
+ Assert.AreNotEqual(IntPtr.Zero, dispatchWrapper);
+ Assert.AreEqual(dispatchObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called");
+ }
+
+ Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown...");
+
+ IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject();
+ object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw);
+ Assert.AreEqual(validateUseRegistered, objWrapper1 is FakeWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance");
+ object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw);
+ Assert.AreEqual(objWrapper1, objWrapper2);
+
+ Console.WriteLine($" -- Validate Marshal.GetUniqueObjectForIUnknown...");
+
+ object objWrapper3 = Marshal.GetUniqueObjectForIUnknown(trackerObjRaw);
+ Assert.AreEqual(validateUseRegistered, objWrapper3 is FakeWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance");
+
+ Assert.AreNotEqual(objWrapper1, objWrapper3);
+
+ Marshal.Release(trackerObjRaw);
+ }
+
+ private static void ValidatePInvokes(bool validateUseRegistered)
+ {
+ string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime";
+ Console.WriteLine($"Running {nameof(ValidatePInvokes)}: {scenario}...");
+
+ GlobalComWrappers.Instance.ReturnInvalid = !validateUseRegistered;
+
+ Console.WriteLine($" -- Validate MarshalAs IUnknown...");
+ ValidateInterfaceMarshaler