diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index f5d0e06b016a3..59f3d47ac3585 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -118,16 +118,35 @@ private struct ComInterfaceInstance /// Flags used to configure the generated interface. /// The generated COM interface that can be passed outside the .NET runtime. public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags) + { + IntPtr ptr; + if (!TryGetOrCreateComInterfaceForObjectInternal(this, instance, flags, out ptr)) + throw new ArgumentException(); + + return ptr; + } + + /// + /// Create a COM representation of the supplied object that can be passed to a non-managed environment. + /// + /// The implementation to use when creating the COM representation. + /// The managed object to expose outside the .NET runtime. + /// Flags used to configure the generated interface. + /// The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created. + /// Returns true if a COM representation could be created, false otherwise + /// + /// If is null, the global instance (if registered) will be used. + /// + private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue) { if (instance == null) throw new ArgumentNullException(nameof(instance)); - ComWrappers impl = this; - return GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags); + return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags, out retValue); } [DllImport(RuntimeHelpers.QCall)] - private static extern IntPtr GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags); + private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue); /// /// Compute the desired Vtable for respecting the values of . @@ -140,13 +159,23 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa /// All memory returned from this function must either be unmanaged memory, pinned managed memory, or have been /// allocated with the API. /// - /// If the interface entries cannot be created and null is returned, the call to will throw a . + /// If the interface entries cannot be created and a negative or null and a non-zero are returned, + /// the call to will throw a . /// protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count); // Call to execute the abstract instance function internal static unsafe void* CallComputeVtables(ComWrappers? comWrappersImpl, object obj, CreateComInterfaceFlags flags, out int count) - => (comWrappersImpl ?? s_globalInstance!).ComputeVtables(obj, flags, out count); + { + ComWrappers? impl = comWrappersImpl ?? s_globalInstance; + if (impl is null) + { + count = -1; + return null; + } + + return impl.ComputeVtables(obj, flags, out count); + } /// /// Get the currently registered managed object or creates a new managed object and registers it. @@ -156,7 +185,11 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa /// Returns a managed object associated with the supplied external COM object. public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags) { - return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, null); + object? obj; + if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, null, out obj)) + throw new ArgumentNullException(); + + return obj!; } /// @@ -172,7 +205,13 @@ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateOb // Call to execute the abstract instance function internal static object? CallCreateObject(ComWrappers? comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags) - => (comWrappersImpl ?? s_globalInstance!).CreateObject(externalComObject, flags); + { + ComWrappers? impl = comWrappersImpl ?? s_globalInstance; + if (impl == null) + return null; + + return impl.CreateObject(externalComObject, flags); + } /// /// Get the currently registered managed object or uses the supplied managed object and registers it. @@ -189,24 +228,37 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create if (wrapper == null) throw new ArgumentNullException(nameof(externalComObject)); - return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, wrapper); + object? obj; + if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, wrapper, out obj)) + throw new ArgumentNullException(); + + return obj!; } - private object GetOrCreateObjectForComInstanceInternal(IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe) + /// + /// Get the currently registered managed object or creates a new managed object and registers it. + /// + /// The implementation to use when creating the managed object. + /// Object to import for usage into the .NET runtime. + /// Flags used to describe the external object. + /// The to be used as the wrapper for the external object. + /// The managed object associated with the supplied external COM object or null if it could not be created. + /// Returns true if a managed object could be retrieved/created, false otherwise + /// + /// If is null, the global instance (if registered) will be used. + /// + private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue) { if (externalComObject == IntPtr.Zero) throw new ArgumentNullException(nameof(externalComObject)); - ComWrappers impl = this; object? wrapperMaybeLocal = wrapperMaybe; - object? retValue = null; - GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue)); - - return retValue!; + retValue = null; + return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue)); } [DllImport(RuntimeHelpers.QCall)] - private static extern void GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue); + private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue); /// /// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime. @@ -235,8 +287,14 @@ public void RegisterAsGlobalInstance() { throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance); } + + SetGlobalInstanceRegistered(); } + [DllImport(RuntimeHelpers.QCall)] + [SuppressGCTransition] + private static extern void SetGlobalInstanceRegistered(); + /// /// Get the runtime provided IUnknown implementation. /// diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs index 07b7ce866c3d6..e9b8b3df601b5 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.CoreCLR.cs @@ -326,6 +326,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IUnknown* */ GetIUnknownForObject(object o) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + return GetIUnknownForObjectNative(o, false); } @@ -344,6 +349,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IDispatch */ GetIDispatchForObject(object o) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + return GetIDispatchForObjectNative(o, false); } @@ -356,6 +366,16 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + + if (T is null) + { + throw new ArgumentNullException(nameof(T)); + } + return GetComInterfaceForObjectNative(o, T, false, true); } @@ -368,15 +388,48 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T, CustomQueryInterfaceMode mode) { + if (o is null) + { + throw new ArgumentNullException(nameof(o)); + } + + if (T is null) + { + throw new ArgumentNullException(nameof(T)); + } + bool bEnableCustomizedQueryInterface = ((mode == CustomQueryInterfaceMode.Allow) ? true : false); return GetComInterfaceForObjectNative(o, T, false, bEnableCustomizedQueryInterface); } [MethodImpl(MethodImplOptions.InternalCall)] - private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnalbeCustomizedQueryInterface); + private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnableCustomizedQueryInterface); + + /// + /// Return the managed object representing the IUnknown* + /// + public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk) + { + if (pUnk == IntPtr.Zero) + { + throw new ArgumentNullException(nameof(pUnk)); + } + + return GetObjectForIUnknownNative(pUnk); + } [MethodImpl(MethodImplOptions.InternalCall)] - public static extern object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk); + private static extern object GetObjectForIUnknownNative(IntPtr /* IUnknown* */ pUnk); + + public static object GetUniqueObjectForIUnknown(IntPtr unknown) + { + if (unknown == IntPtr.Zero) + { + throw new ArgumentNullException(nameof(unknown)); + } + + return GetUniqueObjectForIUnknownNative(unknown); + } /// /// Return a unique Object given an IUnknown. This ensures that you receive a fresh @@ -385,7 +438,7 @@ public static string GetTypeInfoName(ITypeInfo typeInfo) /// ReleaseComObject on a RCW and not worry about other active uses ofsaid RCW. /// [MethodImpl(MethodImplOptions.InternalCall)] - public static extern object GetUniqueObjectForIUnknown(IntPtr unknown); + private static extern object GetUniqueObjectForIUnknownNative(IntPtr unknown); /// /// Return an Object for IUnknown, using the Type T. diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs index 8c94325722e13..691113f665be0 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/StubHelpers.cs @@ -690,7 +690,7 @@ internal static class InterfaceMarshaler internal static extern IntPtr ConvertToNative(object objSrc, IntPtr itfMT, IntPtr classMT, int flags); [MethodImpl(MethodImplOptions.InternalCall)] - internal static extern object ConvertToManaged(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags); + internal static extern object ConvertToManaged(IntPtr ppUnk, IntPtr itfMT, IntPtr classMT, int flags); [DllImport(RuntimeHelpers.QCall)] internal static extern void ClearNative(IntPtr pUnk); diff --git a/src/coreclr/src/interop/comwrappers.hpp b/src/coreclr/src/interop/comwrappers.hpp index d2337746e39b8..e9838646b0189 100644 --- a/src/coreclr/src/interop/comwrappers.hpp +++ b/src/coreclr/src/interop/comwrappers.hpp @@ -16,9 +16,10 @@ enum class CreateComInterfaceFlagsEx : int32_t TrackerSupport = InteropLib::Com::CreateComInterfaceFlags_TrackerSupport, // Highest bit is reserved for internal usage + IsComActivated = 1 << 30, IsPegged = 1 << 31, - InternalMask = IsPegged, + InternalMask = IsPegged | IsComActivated, }; DEFINE_ENUM_FLAG_OPERATORS(CreateComInterfaceFlagsEx); diff --git a/src/coreclr/src/interop/inc/interoplib.h b/src/coreclr/src/interop/inc/interoplib.h index df59b8b90a584..fa0e0c3e25918 100644 --- a/src/coreclr/src/interop/inc/interoplib.h +++ b/src/coreclr/src/interop/inc/interoplib.h @@ -45,6 +45,12 @@ namespace InteropLib // Reactivate the supplied wrapper. HRESULT ReactivateWrapper(_In_ IUnknown* wrapper, _In_ InteropLib::OBJECTHANDLE handle) noexcept; + // Get the object for the supplied wrapper + HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept; + + HRESULT MarkComActivated(_In_ IUnknown* wrapper) noexcept; + HRESULT IsComActivated(_In_ IUnknown* wrapper) noexcept; + struct ExternalWrapperResult { // The returned context memory is guaranteed to be initialized to zero. diff --git a/src/coreclr/src/interop/interoplib.cpp b/src/coreclr/src/interop/interoplib.cpp index 8283024ad5d4a..fb7af5aa7ec65 100644 --- a/src/coreclr/src/interop/interoplib.cpp +++ b/src/coreclr/src/interop/interoplib.cpp @@ -89,6 +89,43 @@ namespace InteropLib return S_OK; } + HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept + { + if (object == nullptr) + return E_POINTER; + + *object = nullptr; + + HRESULT hr = IsActiveWrapper(wrapper); + if (hr != S_OK) + return hr; + + ManagedObjectWrapper *mow = ManagedObjectWrapper::MapFromIUnknown(wrapper); + _ASSERTE(mow != nullptr); + + *object = mow->Target; + return S_OK; + } + + HRESULT MarkComActivated(_In_ IUnknown* wrapperMaybe) noexcept + { + ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); + if (wrapper == nullptr) + return E_INVALIDARG; + + wrapper->SetFlag(CreateComInterfaceFlagsEx::IsComActivated); + return S_OK; + } + + HRESULT IsComActivated(_In_ IUnknown* wrapperMaybe) noexcept + { + ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe); + if (wrapper == nullptr) + return E_INVALIDARG; + + return wrapper->IsSet(CreateComInterfaceFlagsEx::IsComActivated) ? S_OK : S_FALSE; + } + HRESULT CreateWrapperForExternal( _In_ IUnknown* external, _In_ enum CreateObjectFlags flags, diff --git a/src/coreclr/src/vm/ecalllist.h b/src/coreclr/src/vm/ecalllist.h index 3cbfa0e5ec40e..691bad54bf537 100644 --- a/src/coreclr/src/vm/ecalllist.h +++ b/src/coreclr/src/vm/ecalllist.h @@ -805,8 +805,8 @@ FCFuncStart(gInteropMarshalFuncs) FCFuncElement("GetHRForException", MarshalNative::GetHRForException) FCFuncElement("GetRawIUnknownForComObjectNoAddRef", MarshalNative::GetRawIUnknownForComObjectNoAddRef) FCFuncElement("IsComObject", MarshalNative::IsComObject) - FCFuncElement("GetObjectForIUnknown", MarshalNative::GetObjectForIUnknown) - FCFuncElement("GetUniqueObjectForIUnknown", MarshalNative::GetUniqueObjectForIUnknown) + FCFuncElement("GetObjectForIUnknownNative", MarshalNative::GetObjectForIUnknownNative) + FCFuncElement("GetUniqueObjectForIUnknownNative", MarshalNative::GetUniqueObjectForIUnknownNative) FCFuncElement("AddRef", MarshalNative::AddRef) FCFuncElement("GetNativeVariantForObject", MarshalNative::GetNativeVariantForObject) FCFuncElement("GetObjectForNativeVariant", MarshalNative::GetObjectForNativeVariant) @@ -997,8 +997,9 @@ FCFuncEnd() #ifdef FEATURE_COMWRAPPERS FCFuncStart(gComWrappersFuncs) QCFuncElement("GetIUnknownImplInternal", ComWrappersNative::GetIUnknownImpl) - QCFuncElement("GetOrCreateComInterfaceForObjectInternal", ComWrappersNative::GetOrCreateComInterfaceForObject) - QCFuncElement("GetOrCreateObjectForComInstanceInternal", ComWrappersNative::GetOrCreateObjectForComInstance) + QCFuncElement("TryGetOrCreateComInterfaceForObjectInternal", ComWrappersNative::TryGetOrCreateComInterfaceForObject) + QCFuncElement("TryGetOrCreateObjectForComInstanceInternal", ComWrappersNative::TryGetOrCreateObjectForComInstance) + QCFuncElement("SetGlobalInstanceRegistered", GlobalComWrappers::SetGlobalInstanceRegistered) FCFuncEnd() #endif // FEATURE_COMWRAPPERS diff --git a/src/coreclr/src/vm/interopconverter.cpp b/src/coreclr/src/vm/interopconverter.cpp index 58ac452295bd9..44470d7d297b4 100644 --- a/src/coreclr/src/vm/interopconverter.cpp +++ b/src/coreclr/src/vm/interopconverter.cpp @@ -17,9 +17,66 @@ #include "runtimecallablewrapper.h" #include "cominterfacemarshaler.h" #include "binder.h" +#include #include "winrttypenameconverter.h" #include "typestring.h" +namespace +{ + bool TryGetComIPFromObjectRefUsingComWrappers( + _In_ OBJECTREF instance, + _Outptr_ IUnknown** wrapperRaw) + { +#ifdef FEATURE_COMWRAPPERS + return GlobalComWrappers::TryGetOrCreateComInterfaceForObject(instance, (void**)wrapperRaw); +#else + return false; +#endif // FEATURE_COMWRAPPERS + } + + bool TryGetObjectRefFromComIPUsingComWrappers( + _In_ IUnknown* pUnknown, + _In_ DWORD dwFlags, + _Out_ OBJECTREF *pObjOut) + { +#ifdef FEATURE_COMWRAPPERS + return GlobalComWrappers::TryGetOrCreateObjectForComInstance(pUnknown, dwFlags, pObjOut); +#else + return false; +#endif // FEATURE_COMWRAPPERS + } + + void EnsureObjectRefIsValidForSpecifiedClass( + _In_ OBJECTREF *obj, + _In_ DWORD dwFlags, + _In_ MethodTable *pMTClass) + { + _ASSERTE(*obj != NULL); + _ASSERTE(pMTClass != NULL); + + if ((dwFlags & ObjFromComIP::CLASS_IS_HINT) != 0) + return; + + // make sure we can cast to the specified class + FAULT_NOT_FATAL(); + + // Bad format exception thrown for backward compatibility + THROW_BAD_FORMAT_MAYBE(pMTClass->IsArray() == FALSE, BFA_UNEXPECTED_ARRAY_TYPE, pMTClass); + + if (CanCastComObject(*obj, pMTClass)) + return; + + StackSString ssObjClsName; + StackSString ssDestClsName; + + (*obj)->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName); + pMTClass->_GetFullyQualifiedNameForClass(ssDestClsName); + + COMPlusThrow(kInvalidCastException, IDS_EE_CANNOTCAST, + ssObjClsName.GetUnicode(), ssDestClsName.GetUnicode()); + } +} + //-------------------------------------------------------------------------------- // IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, ...) // Convert ObjectRef to a COM IP, based on MethodTable* pMT. @@ -45,6 +102,11 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, MethodTable *pMT, BOOL bSecuri if (*poref == NULL) RETURN NULL; + if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) + { + pUnk.SuppressRelease(); + RETURN pUnk; + } SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -111,6 +173,30 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, ComIpType ReqIpType, ComIpType if (*poref == NULL) RETURN NULL; + if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) + { + hr = S_OK; + + SafeComHolder pvObj; + if (ReqIpType & ComIpType_Dispatch) + { + hr = pUnk->QueryInterface(IID_IDispatch, &pvObj); + } + else if (ReqIpType & ComIpType_Inspectable) + { + SafeComHolder pvObj; + hr = pUnk->QueryInterface(IID_IInspectable, &pvObj); + } + + if (FAILED(hr)) + COMPlusThrowHR(hr); + + if (pFetchedIpType != NULL) + *pFetchedIpType = ReqIpType; + + RETURN pUnk; + } + MethodTable *pMT = (*poref)->GetMethodTable(); SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -376,6 +462,16 @@ IUnknown *GetComIPFromObjectRef(OBJECTREF *poref, REFIID iid, bool throwIfNoComI if (*poref == NULL) RETURN NULL; + if (TryGetComIPFromObjectRefUsingComWrappers(*poref, &pUnk)) + { + SafeComHolder pvObj; + hr = pUnk->QueryInterface(iid, &pvObj); + if (FAILED(hr)) + COMPlusThrowHR(hr); + + RETURN pUnk; + } + MethodTable *pMT = (*poref)->GetMethodTable(); SyncBlock* pBlock = (*poref)->GetSyncBlock(); @@ -434,9 +530,18 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM _ASSERTE(g_fComStarted && "COM has not been started up, make sure EnsureComStarted is called before any COM objects are used!"); IUnknown *pUnk = *ppUnk; + *pObjOut = NULL; + + if (TryGetObjectRefFromComIPUsingComWrappers(pUnk, dwFlags, pObjOut)) + { + if (pMTClass != NULL) + EnsureObjectRefIsValidForSpecifiedClass(pObjOut, dwFlags, pMTClass); + + return; + } + Thread * pThread = GetThread(); - *pObjOut = NULL; IUnknown* pOuter = pUnk; SafeComHolder pAutoOuterUnk = NULL; @@ -537,22 +642,7 @@ void GetObjectRefFromComIP(OBJECTREF* pObjOut, IUnknown **ppUnk, MethodTable *pM // make sure we can cast to the specified class if (pMTClass != NULL) { - FAULT_NOT_FATAL(); - - // Bad format exception thrown for backward compatibility - THROW_BAD_FORMAT_MAYBE(pMTClass->IsArray() == FALSE, BFA_UNEXPECTED_ARRAY_TYPE, pMTClass); - - if (!CanCastComObject(*pObjOut, pMTClass)) - { - StackSString ssObjClsName; - StackSString ssDestClsName; - - (*pObjOut)->GetMethodTable()->_GetFullyQualifiedNameForClass(ssObjClsName); - pMTClass->_GetFullyQualifiedNameForClass(ssDestClsName); - - COMPlusThrow(kInvalidCastException, IDS_EE_CANNOTCAST, - ssObjClsName.GetUnicode(), ssDestClsName.GetUnicode()); - } + EnsureObjectRefIsValidForSpecifiedClass(pObjOut, dwFlags, pMTClass); } else if (dwFlags & ObjFromComIP::REQUIRE_IINSPECTABLE) { diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index 70e5f14f2c7df..c71b0d955fd53 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -397,9 +397,12 @@ namespace } }; - // Global instance + // Global instance of the external object cache Volatile ExtObjCxtCache::g_Instance; + // Indicator for if a ComWrappers implementation is globally registered + bool g_IsGlobalComWrappersRegistered; + // Defined handle types for the specific object uses. const HandleType InstanceHandleType{ HNDTYPE_STRONG }; @@ -478,24 +481,25 @@ namespace CALL_MANAGED_METHOD_NORET(args); } - void* GetOrCreateComInterfaceForObjectInternal( + bool TryGetOrCreateComInterfaceForObjectInternal( _In_opt_ OBJECTREF impl, _In_ OBJECTREF instance, - _In_ CreateComInterfaceFlags flags) + _In_ CreateComInterfaceFlags flags, + _Outptr_ void** wrapperRaw) { - CONTRACT(void*) + CONTRACT(bool) { THROWS; MODE_COOPERATIVE; PRECONDITION(instance != NULL); - POSTCONDITION(CheckPointer(RETVAL)); + PRECONDITION(wrapperRaw != NULL); } CONTRACT_END; HRESULT hr; SafeComHolder newWrapper; - void* wrapperRaw = NULL; + void* wrapperRawMaybe = NULL; struct { @@ -514,7 +518,7 @@ namespace _ASSERTE(syncBlock->IsPrecious()); // Query the associated InteropSyncBlockInfo for an existing managed object wrapper. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw)) + if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)) { // Compute VTables for the new existing COM object using the supplied COM Wrappers implementation. // @@ -525,7 +529,8 @@ namespace void* vtables = CallComputeVTables(&gc.implRef, &gc.instRef, flags, &vtableCount); // Re-query the associated InteropSyncBlockInfo for an existing managed object wrapper. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw)) + if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe) + && ((vtables != nullptr && vtableCount > 0) || (vtableCount == 0))) { OBJECTHANDLE instHandle = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType); @@ -551,7 +556,7 @@ namespace // If the managed object wrapper couldn't be set, then // it should be possible to get the current one. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRaw)) + if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)) { UNREACHABLE(); } @@ -564,20 +569,18 @@ namespace { // A new managed object wrapper was created, remove the object from the holder. // No AddRef() here since the wrapper should be created with a reference. - wrapperRaw = newWrapper.Extract(); - STRESS_LOG1(LF_INTEROP, LL_INFO100, "Created MOW: 0x%p\n", wrapperRaw); + wrapperRawMaybe = newWrapper.Extract(); + STRESS_LOG1(LF_INTEROP, LL_INFO100, "Created MOW: 0x%p\n", wrapperRawMaybe); } - else + else if (wrapperRawMaybe != NULL) { - _ASSERTE(wrapperRaw != NULL); - // It is possible the supplied wrapper is no longer valid. If so, reactivate the // wrapper using the protected OBJECTREF. - IUnknown* wrapper = static_cast(wrapperRaw); + IUnknown* wrapper = static_cast(wrapperRawMaybe); hr = InteropLib::Com::IsActiveWrapper(wrapper); if (hr == S_FALSE) { - STRESS_LOG1(LF_INTEROP, LL_INFO100, "Reactivating MOW: 0x%p\n", wrapperRaw); + STRESS_LOG1(LF_INTEROP, LL_INFO100, "Reactivating MOW: 0x%p\n", wrapperRawMaybe); OBJECTHANDLE h = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType); hr = InteropLib::Com::ReactivateWrapper(wrapper, static_cast(h)); } @@ -588,21 +591,29 @@ namespace GCPROTECT_END(); - RETURN wrapperRaw; + *wrapperRaw = wrapperRawMaybe; + RETURN (wrapperRawMaybe != NULL); } - OBJECTREF GetOrCreateObjectForComInstanceInternal( + // The unwrap parameter indicates whether or not COM instances that are actually CCWs should + // be unwrapped to the original managed object. + // For implicit usage of ComWrappers (i.e. automatically called by the runtime when there is a global instance), + // CCWs should be unwrapped to allow for round-tripping object -> COM instance -> object. + // For explicit usage of ComWrappers (i.e. directly called via a ComWrappers APIs), CCWs should not be unwrapped. + bool TryGetOrCreateObjectForComInstanceInternal( _In_opt_ OBJECTREF impl, _In_ IUnknown* identity, _In_ CreateObjectFlags flags, - _In_opt_ OBJECTREF wrapperMaybe) + _In_ bool unwrap, + _In_opt_ OBJECTREF wrapperMaybe, + _Out_ OBJECTREF* objRef) { - CONTRACT(OBJECTREF) + CONTRACT(bool) { THROWS; MODE_COOPERATIVE; PRECONDITION(identity != NULL); - POSTCONDITION(RETVAL != NULL); + PRECONDITION(objRef != NULL); } CONTRACT_END; @@ -613,7 +624,7 @@ namespace { OBJECTREF implRef; OBJECTREF wrapperMaybeRef; - OBJECTREF objRef; + OBJECTREF objRefMaybe; } gc; ::ZeroMemory(&gc, sizeof(gc)); GCPROTECT_BEGIN(gc); @@ -622,6 +633,7 @@ namespace gc.wrapperMaybeRef = wrapperMaybe; ExtObjCxtCache* cache = ExtObjCxtCache::GetInstance(); + InteropLib::OBJECTHANDLE handle = NULL; // Check if the user requested a unique instance. bool uniqueInstance = !!(flags & CreateObjectFlags::CreateObjectFlags_UniqueInstance); @@ -630,93 +642,123 @@ namespace // Query the external object cache ExtObjCxtCache::LockHolder lock(cache); extObjCxt = cache->Find(identity); + + // If is no object found in the cache, check if the object COM instance is actually the CCW + // representing a managed object. + if (extObjCxt == NULL && unwrap) + { + // If the COM instance is a CCW that is not COM-activated, use the object of that wrapper object. + InteropLib::OBJECTHANDLE handleLocal; + if (InteropLib::Com::GetObjectForWrapper(identity, &handleLocal) == S_OK + && InteropLib::Com::IsComActivated(identity) == S_FALSE) + { + handle = handleLocal; + } + } } if (extObjCxt != NULL) { - gc.objRef = extObjCxt->GetObjectRef(); + gc.objRefMaybe = extObjCxt->GetObjectRef(); + } + else if (handle != NULL) + { + // We have an object handle from the COM instance which is a CCW. Use that object. + // This allows for the round-trip from object -> COM instance -> object. + ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle); + gc.objRefMaybe = ObjectFromHandle(objectHandle); } else { // Create context instance for the possibly new external object. ExternalWrapperResultHolder resultHolder; - hr = InteropLib::Com::CreateWrapperForExternal( - identity, - flags, - sizeof(ExternalObjectContext), - &resultHolder); + + { + GCX_PREEMP(); + hr = InteropLib::Com::CreateWrapperForExternal( + identity, + flags, + sizeof(ExternalObjectContext), + &resultHolder); + } + if (FAILED(hr)) COMPlusThrowHR(hr); // The user could have supplied a wrapper so assign that now. - gc.objRef = gc.wrapperMaybeRef; + gc.objRefMaybe = gc.wrapperMaybeRef; // If the wrapper hasn't been set yet, call the implementation to create one. - if (gc.objRef == NULL) + if (gc.objRefMaybe == NULL) { - gc.objRef = CallGetObject(&gc.implRef, identity, flags); - if (gc.objRef == NULL) - COMPlusThrow(kArgumentNullException); + gc.objRefMaybe = CallGetObject(&gc.implRef, identity, flags); } - // Construct the new context with the object details. - DWORD flags = (resultHolder.Result.FromTrackerRuntime - ? ExternalObjectContext::Flags_ReferenceTracker - : ExternalObjectContext::Flags_None) | - (uniqueInstance - ? ExternalObjectContext::Flags_None - : ExternalObjectContext::Flags_InCache); - ExternalObjectContext::Construct( - resultHolder.GetContext(), - identity, - GetCurrentCtxCookie(), - gc.objRef->GetSyncBlockIndex(), - flags); - - if (uniqueInstance) + // The object may be null if the specified ComWrapper implementation returns null + // or there is no registered global instance. It is the caller's responsibility + // to handle this case and error if necessary. + if (gc.objRefMaybe != NULL) { - extObjCxt = resultHolder.GetContext(); - } - else - { - // Attempt to insert the new context into the cache. - ExtObjCxtCache::LockHolder lock(cache); - extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext()); - } - - // If the returned context matches the new context it means the - // new context was inserted or a unique instance was requested. - if (extObjCxt == resultHolder.GetContext()) - { - // Update the object's SyncBlock with a handle to the context for runtime cleanup. - SyncBlock* syncBlock = gc.objRef->GetSyncBlock(); - InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo(); - _ASSERTE(syncBlock->IsPrecious()); - - // Since the caller has the option of providing a wrapper, it is - // possible the supplied wrapper already has an associated external - // object and an object can only be associated with one external object. - if (!interopInfo->TrySetExternalComObjectContext((void**)extObjCxt)) + // Construct the new context with the object details. + DWORD flags = (resultHolder.Result.FromTrackerRuntime + ? ExternalObjectContext::Flags_ReferenceTracker + : ExternalObjectContext::Flags_None) | + (uniqueInstance + ? ExternalObjectContext::Flags_None + : ExternalObjectContext::Flags_InCache); + ExternalObjectContext::Construct( + resultHolder.GetContext(), + identity, + GetCurrentCtxCookie(), + gc.objRefMaybe->GetSyncBlockIndex(), + flags); + + if (uniqueInstance) { - // Failed to set the context; one must already exist. - // Remove from the cache above as well. + extObjCxt = resultHolder.GetContext(); + } + else + { + // Attempt to insert the new context into the cache. ExtObjCxtCache::LockHolder lock(cache); - cache->Remove(resultHolder.GetContext()); + extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext()); + } + + // If the returned context matches the new context it means the + // new context was inserted or a unique instance was requested. + if (extObjCxt == resultHolder.GetContext()) + { + // Update the object's SyncBlock with a handle to the context for runtime cleanup. + SyncBlock* syncBlock = gc.objRefMaybe->GetSyncBlock(); + InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo(); + _ASSERTE(syncBlock->IsPrecious()); + + // Since the caller has the option of providing a wrapper, it is + // possible the supplied wrapper already has an associated external + // object and an object can only be associated with one external object. + if (!interopInfo->TrySetExternalComObjectContext((void**)extObjCxt)) + { + // Failed to set the context; one must already exist. + // Remove from the cache above as well. + ExtObjCxtCache::LockHolder lock(cache); + cache->Remove(resultHolder.GetContext()); + + COMPlusThrow(kNotSupportedException); + } - COMPlusThrow(kNotSupportedException); + // Detach from the holder to avoid cleanup. + (void)resultHolder.DetachContext(); + STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created EOC (Unique Instance: %d): 0x%p\n", (int)uniqueInstance, extObjCxt); } - // Detach from the holder to avoid cleanup. - (void)resultHolder.DetachContext(); - STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created EOC (Unique Instance: %d): 0x%p\n", (int)uniqueInstance, extObjCxt); + _ASSERTE(extObjCxt->IsActive()); } - - _ASSERTE(extObjCxt->IsActive()); } GCPROTECT_END(); - RETURN gc.objRef; + *objRef = gc.objRefMaybe; + RETURN (gc.objRefMaybe != NULL); } } @@ -949,19 +991,29 @@ namespace InteropLibImports gc.implRef = NULL; // Use the globally registered implementation. gc.wrapperMaybeRef = NULL; // No supplied wrapper here. + bool unwrapIfManagedObjectWrapper = false; // Don't unwrap CCWs // Get wrapper for external object - gc.objRef = GetOrCreateObjectForComInstanceInternal( + bool success = TryGetOrCreateObjectForComInstanceInternal( gc.implRef, externalComObject, externalObjectFlags, - gc.wrapperMaybeRef); + unwrapIfManagedObjectWrapper, + gc.wrapperMaybeRef, + &gc.objRef); + + if (!success) + COMPlusThrow(kArgumentNullException); // Get wrapper for managed object - *trackerTarget = GetOrCreateComInterfaceForObjectInternal( + success = TryGetOrCreateComInterfaceForObjectInternal( gc.implRef, gc.objRef, - trackerTargetFlags); + trackerTargetFlags, + trackerTarget); + + if (!success) + COMPlusThrow(kArgumentException); STRESS_LOG2(LF_INTEROP, LL_INFO100, "Created Target for External: 0x%p => 0x%p\n", OBJECTREFToObject(gc.objRef), *trackerTarget); GCPROTECT_END(); @@ -1055,14 +1107,15 @@ namespace InteropLibImports #ifdef FEATURE_COMWRAPPERS -void* QCALLTYPE ComWrappersNative::GetOrCreateComInterfaceForObject( +BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ QCall::ObjectHandleOnStack instance, - _In_ INT32 flags) + _In_ INT32 flags, + _Outptr_ void** wrapper) { QCALL_CONTRACT; - void* wrapper = NULL; + bool success; BEGIN_QCALL; @@ -1070,19 +1123,19 @@ void* QCALLTYPE ComWrappersNative::GetOrCreateComInterfaceForObject( // are being manipulated. { GCX_COOP(); - wrapper = GetOrCreateComInterfaceForObjectInternal( + success = TryGetOrCreateComInterfaceForObjectInternal( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), ObjectToOBJECTREF(*instance.m_ppObject), - (CreateComInterfaceFlags)flags); + (CreateComInterfaceFlags)flags, + wrapper); } END_QCALL; - _ASSERTE(wrapper != NULL); - return wrapper; + return (success ? TRUE : FALSE); } -void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance( +BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ void* ext, _In_ INT32 flags, @@ -1093,6 +1146,8 @@ void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance( _ASSERTE(ext != NULL); + bool success; + BEGIN_QCALL; HRESULT hr; @@ -1107,17 +1162,25 @@ void QCALLTYPE ComWrappersNative::GetOrCreateObjectForComInstance( // are being manipulated. { GCX_COOP(); - OBJECTREF newObj = GetOrCreateObjectForComInstanceInternal( + + bool unwrapIfManagedObjectWrapper = false; // Don't unwrap CCWs + OBJECTREF newObj; + success = TryGetOrCreateObjectForComInstanceInternal( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), identity, (CreateObjectFlags)flags, - ObjectToOBJECTREF(*wrapperMaybe.m_ppObject)); + unwrapIfManagedObjectWrapper, + ObjectToOBJECTREF(*wrapperMaybe.m_ppObject), + &newObj); // Set the return value - retValue.Set(newObj); + if (success) + retValue.Set(newObj); } END_QCALL; + + return (success ? TRUE : FALSE); } void QCALLTYPE ComWrappersNative::GetIUnknownImpl( @@ -1198,6 +1261,85 @@ void ComWrappersNative::MarkExternalComObjectContextCollected(_In_ void* context } } +void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe) +{ + CONTRACTL + { + NOTHROW; + MODE_ANY; + PRECONDITION(wrapperMaybe != NULL); + } + CONTRACTL_END; + + // The IUnknown may or may not represent a wrapper, so E_INVALIDARG is okay here. + HRESULT hr = InteropLib::Com::MarkComActivated(wrapperMaybe); + _ASSERTE(SUCCEEDED(hr) || hr == E_INVALIDARG); +} + +void QCALLTYPE GlobalComWrappers::SetGlobalInstanceRegistered() +{ + // QCALL contracts are not used here because the managed declaration + // uses the SuppressGCTransition attribute + + _ASSERTE(!g_IsGlobalComWrappersRegistered); + g_IsGlobalComWrappersRegistered = true; +} + +bool GlobalComWrappers::TryGetOrCreateComInterfaceForObject( + _In_ OBJECTREF instance, + _Outptr_ void** wrapperRaw) +{ + if (!g_IsGlobalComWrappersRegistered) + return false; + + // Switch to Cooperative mode since object references + // are being manipulated. + { + GCX_COOP(); + + CreateComInterfaceFlags flags = CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport; + + // Passing NULL as the ComWrappers implementation indicates using the globally registered instance + return TryGetOrCreateComInterfaceForObjectInternal( + NULL, + instance, + flags, + wrapperRaw); + } +} + +bool GlobalComWrappers::TryGetOrCreateObjectForComInstance( + _In_ IUnknown* externalComObject, + _In_ INT32 objFromComIPFlags, + _Out_ OBJECTREF* objRef) +{ + if (!g_IsGlobalComWrappersRegistered) + return false; + + // Switch to Cooperative mode since object references + // are being manipulated. + { + GCX_COOP(); + + int flags = CreateObjectFlags::CreateObjectFlags_TrackerObject; + if ((objFromComIPFlags & ObjFromComIP::UNIQUE_OBJECT) != 0) + flags |= CreateObjectFlags::CreateObjectFlags_UniqueInstance; + + // For implicit usage of ComWrappers (i.e. automatically called by the runtime when there is a global instance), + // unwrap CCWs to allow for round-tripping object -> COM instance -> object. + bool unwrapIfManagedObjectWrapper = true; + + // Passing NULL as the ComWrappers implementation indicates using the globally registered instance + return TryGetOrCreateObjectForComInstanceInternal( + NULL /*comWrappersImpl*/, + externalComObject, + (CreateObjectFlags)flags, + unwrapIfManagedObjectWrapper, + NULL /*wrapperMaybe*/, + objRef); + } +} + #endif // FEATURE_COMWRAPPERS void Interop::OnGCStarted(_In_ int nCondemnedGeneration) diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h index 2ed54cfc90706..2985412db4b29 100644 --- a/src/coreclr/src/vm/interoplibinterface.h +++ b/src/coreclr/src/vm/interoplibinterface.h @@ -17,12 +17,13 @@ class ComWrappersNative _Out_ void** fpAddRef, _Out_ void** fpRelease); - static void* QCALLTYPE GetOrCreateComInterfaceForObject( + static BOOL QCALLTYPE TryGetOrCreateComInterfaceForObject( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ QCall::ObjectHandleOnStack instance, - _In_ INT32 flags); + _In_ INT32 flags, + _Outptr_ void** wrapperRaw); - static void QCALLTYPE GetOrCreateObjectForComInstance( + static BOOL QCALLTYPE TryGetOrCreateObjectForComInstance( _In_ QCall::ObjectHandleOnStack comWrappersImpl, _In_ void* externalComObject, _In_ INT32 flags, @@ -33,6 +34,27 @@ class ComWrappersNative static void DestroyManagedObjectComWrapper(_In_ void* wrapper); static void DestroyExternalComObjectContext(_In_ void* context); static void MarkExternalComObjectContextCollected(_In_ void* context); + +public: // COM activation + static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe); +}; + +class GlobalComWrappers +{ +public: + // Native QCall for the ComWrappers managed type to indicate a global instance is registered + // This should be set if the private static member representing the global instance on ComWrappers is non-null. + static void QCALLTYPE SetGlobalInstanceRegistered(); + +public: // Functions operating on a registered global instance + static bool TryGetOrCreateComInterfaceForObject( + _In_ OBJECTREF instance, + _Outptr_ void** wrapperRaw); + + static bool TryGetOrCreateObjectForComInstance( + _In_ IUnknown* externalComObject, + _In_ INT32 objFromComIPFlags, + _Out_ OBJECTREF* objRef); }; #endif // FEATURE_COMWRAPPERS diff --git a/src/coreclr/src/vm/marshalnative.cpp b/src/coreclr/src/vm/marshalnative.cpp index adc20310648cb..d0d3b27c68980 100644 --- a/src/coreclr/src/vm/marshalnative.cpp +++ b/src/coreclr/src/vm/marshalnative.cpp @@ -372,7 +372,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE) FieldDesc *pField = refField->GetField(); TypeHandle th = TypeHandle(pField->GetApproxEnclosingMethodTable()); - + if (th.IsBlittable()) { return pField->GetOffset(); @@ -388,7 +388,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE) { // It isn't marshalable so throw an ArgumentException. StackSString strTypeName; - TypeString::AppendType(strTypeName, th); + TypeString::AppendType(strTypeName, th); COMPlusThrow(kArgumentException, IDS_CANNOT_MARSHAL, strTypeName.GetUnicode(), NULL, NULL); } EEClassNativeLayoutInfo const* pNativeLayoutInfo = th.GetMethodTable()->GetNativeLayoutInfo(); @@ -401,7 +401,7 @@ FCIMPL1(UINT32, MarshalNative::OffsetOfHelper, ReflectFieldObject *pFieldUNSAFE) #endif while (numReferenceFields--) { - if (pNFD->GetFieldDesc() == pField) + if (pNFD->GetFieldDesc() == pField) { externalOffset = pNFD->GetExternalOffset(); INDEBUG(foundField = true); @@ -802,11 +802,7 @@ FCIMPL2(IUnknown*, MarshalNative::GetIUnknownForObjectNative, Object* orefUNSAFE OBJECTREF oref = (OBJECTREF) orefUNSAFE; HELPER_METHOD_FRAME_BEGIN_RET_1(oref); - HRESULT hr = S_OK; - - if(!oref) - COMPlusThrowArgumentNull(W("o")); - + _ASSERTE(oref != NULL); // Ensure COM is started up. EnsureComStarted(); @@ -868,11 +864,7 @@ FCIMPL2(IDispatch*, MarshalNative::GetIDispatchForObjectNative, Object* orefUNSA OBJECTREF oref = (OBJECTREF) orefUNSAFE; HELPER_METHOD_FRAME_BEGIN_RET_1(oref); - HRESULT hr = S_OK; - - if(!oref) - COMPlusThrowArgumentNull(W("o")); - + _ASSERTE(oref != NULL); // Ensure COM is started up. EnsureComStarted(); @@ -899,13 +891,8 @@ FCIMPL4(IUnknown*, MarshalNative::GetComInterfaceForObjectNative, Object* orefUN REFLECTCLASSBASEREF refClass = (REFLECTCLASSBASEREF) refClassUNSAFE; HELPER_METHOD_FRAME_BEGIN_RET_2(oref, refClass); - HRESULT hr = S_OK; - - if(!oref) - COMPlusThrowArgumentNull(W("o")); - if(!refClass) - COMPlusThrowArgumentNull(W("t")); - + _ASSERTE(oref != NULL); + _ASSERTE(refClass != NULL); // Ensure COM is started up. EnsureComStarted(); @@ -946,23 +933,18 @@ FCIMPLEND //==================================================================== // return an Object for IUnknown //==================================================================== -FCIMPL1(Object*, MarshalNative::GetObjectForIUnknown, IUnknown* pUnk) +FCIMPL1(Object*, MarshalNative::GetObjectForIUnknownNative, IUnknown* pUnk) { CONTRACTL { FCALL_CHECK; - PRECONDITION(CheckPointer(pUnk, NULL_OK)); + PRECONDITION(CheckPointer(pUnk)); } CONTRACTL_END; OBJECTREF oref = NULL; HELPER_METHOD_FRAME_BEGIN_RET_1(oref); - HRESULT hr = S_OK; - - if(!pUnk) - COMPlusThrowArgumentNull(W("pUnk")); - // Ensure COM is started up. EnsureComStarted(); @@ -974,12 +956,12 @@ FCIMPL1(Object*, MarshalNative::GetObjectForIUnknown, IUnknown* pUnk) FCIMPLEND -FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknown, IUnknown* pUnk) +FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknownNative, IUnknown* pUnk) { CONTRACTL { FCALL_CHECK; - PRECONDITION(CheckPointer(pUnk, NULL_OK)); + PRECONDITION(CheckPointer(pUnk)); } CONTRACTL_END; @@ -988,9 +970,6 @@ FCIMPL1(Object*, MarshalNative::GetUniqueObjectForIUnknown, IUnknown* pUnk) HRESULT hr = S_OK; - if(!pUnk) - COMPlusThrowArgumentNull(W("pUnk")); - // Ensure COM is started up. EnsureComStarted(); diff --git a/src/coreclr/src/vm/marshalnative.h b/src/coreclr/src/vm/marshalnative.h index 7ff167c586fc3..3325a0d05544e 100644 --- a/src/coreclr/src/vm/marshalnative.h +++ b/src/coreclr/src/vm/marshalnative.h @@ -106,12 +106,12 @@ class MarshalNative //==================================================================== // return an Object for IUnknown //==================================================================== - static FCDECL1(Object*, GetObjectForIUnknown, IUnknown* pUnk); + static FCDECL1(Object*, GetObjectForIUnknownNative, IUnknown* pUnk); //==================================================================== // return a unique cacheless Object for IUnknown //==================================================================== - static FCDECL1(Object*, GetUniqueObjectForIUnknown, IUnknown* pUnk); + static FCDECL1(Object*, GetUniqueObjectForIUnknownNative, IUnknown* pUnk); //==================================================================== // return a unique cacheless Object for IUnknown diff --git a/src/coreclr/src/vm/runtimecallablewrapper.cpp b/src/coreclr/src/vm/runtimecallablewrapper.cpp index f61ed2b61b417..61475bf9ba0d5 100644 --- a/src/coreclr/src/vm/runtimecallablewrapper.cpp +++ b/src/coreclr/src/vm/runtimecallablewrapper.cpp @@ -52,6 +52,7 @@ SLIST_HEADER RCW::s_RCWStandbyList; #endif // FEATURE_COMINTEROP_APARTMENT_SUPPORT #ifdef FEATURE_COMINTEROP_UNMANAGED_ACTIVATION +#include "interoplibinterface.h" #ifndef CROSSGEN_COMPILE @@ -247,6 +248,8 @@ IUnknown *ComClassFactory::CreateInstanceFromClassFactory(IClassFactory *pClassF if (ccw != NULL) ccw->MarkComActivated(); + ComWrappersNative::MarkWrapperAsComActivated(pUnk); + pUnk.SuppressRelease(); RETURN pUnk; } diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj similarity index 82% rename from src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj rename to src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj index e82960ae1e101..83acfa1f6fd5e 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/ComWrappersTests.csproj +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/ComWrappersTests.csproj @@ -9,8 +9,9 @@ + - + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs similarity index 68% rename from src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs rename to src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs index e85f87a4b8269..2c30bb6c2a61d 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/Program.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs @@ -7,154 +7,16 @@ namespace ComWrappersTests using System; using System.Collections; using System.Collections.Generic; - using System.IO; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; + using ComWrappersTests.Common; using TestLibrary; class Program { - // - // Managed object with native wrapper definition. - // - [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")] - interface ITest - { - void SetValue(int i); - } - - class Test : ITest - { - public static int InstanceCount = 0; - - private int value = -1; - public Test() { InstanceCount++; } - ~Test() { InstanceCount--; } - - public void SetValue(int i) => this.value = i; - public int GetValue() => this.value; - } - - public struct IUnknownVtbl - { - public IntPtr QueryInterface; - public IntPtr AddRef; - public IntPtr Release; - } - - public struct ITestVtbl - { - public IUnknownVtbl IUnknownImpl; - public IntPtr SetValue; - - public delegate int _SetValue(IntPtr thisPtr, int i); - public static _SetValue pSetValue = new _SetValue(SetValueInternal); - - public static int SetValueInternal(IntPtr dispatchPtr, int i) - { - unsafe - { - try - { - ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)dispatchPtr).SetValue(i); - } - catch (Exception e) - { - return e.HResult; - } - } - return 0; // S_OK; - } - } - - // - // Native interface defintion with managed wrapper for tracker object - // - struct MockReferenceTrackerRuntime - { - [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static IntPtr CreateTrackerObject(); - - [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static void ReleaseAllTrackerObjects(); - - [DllImport(nameof(MockReferenceTrackerRuntime))] - extern public static int Trigger_NotifyEndOfReferenceTrackingOnThread(); - } - - [Guid("42951130-245C-485E-B60B-4ED4254256F8")] - public interface ITrackerObject - { - int AddObjectRef(IntPtr obj); - void DropObjectRef(int id); - }; - - public struct VtblPtr - { - public IntPtr Vtbl; - } - - public class ITrackerObjectWrapper : ITrackerObject - { - private struct ITrackerObjectWrapperVtbl - { - public IntPtr QueryInterface; - public _AddRef AddRef; - public _Release Release; - public _AddObjectRef AddObjectRef; - public _DropObjectRef DropObjectRef; - } - - private delegate int _AddRef(IntPtr This); - private delegate int _Release(IntPtr This); - private delegate int _AddObjectRef(IntPtr This, IntPtr obj, out int id); - private delegate int _DropObjectRef(IntPtr This, int id); - - private readonly IntPtr instance; - private readonly ITrackerObjectWrapperVtbl vtable; - - public ITrackerObjectWrapper(IntPtr instance) - { - var inst = Marshal.PtrToStructure(instance); - this.vtable = Marshal.PtrToStructure(inst.Vtbl); - this.instance = instance; - } - - ~ITrackerObjectWrapper() - { - if (this.instance != IntPtr.Zero) - { - this.vtable.Release(this.instance); - } - } - - public int AddObjectRef(IntPtr obj) - { - int id; - int hr = this.vtable.AddObjectRef(this.instance, obj, out id); - if (hr != 0) - { - throw new COMException($"{nameof(AddObjectRef)}", hr); - } - - return id; - } - - public void DropObjectRef(int id) - { - int hr = this.vtable.DropObjectRef(this.instance, id); - if (hr != 0) - { - throw new COMException($"{nameof(DropObjectRef)}", hr); - } - } - } - class TestComWrappers : ComWrappers { - public static readonly TestComWrappers Global = new TestComWrappers(); - protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) { Assert.IsTrue(obj is Test); @@ -188,11 +50,11 @@ class TestComWrappers : ComWrappers protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag) { var iid = typeof(ITrackerObject).GUID; - IntPtr iTestComObject; - int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTestComObject); - Assert.AreEqual(hr, 0); + IntPtr iTrackerComObject; + int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTrackerComObject); + Assert.AreEqual(0, hr); - return new ITrackerObjectWrapper(iTestComObject); + return new ITrackerObjectWrapper(iTrackerComObject); } public const int ReleaseObjectsCallAck = unchecked((int)-1); @@ -458,37 +320,6 @@ static void ValidateRuntimeTrackerScenario() GC.Collect(); } - static void ValidateGlobalInstanceScenarios() - { - Console.WriteLine($"Running {nameof(ValidateGlobalInstanceScenarios)}..."); - Console.WriteLine($"Validate RegisterAsGlobalInstance()..."); - - var wrappers1 = TestComWrappers.Global; - wrappers1.RegisterAsGlobalInstance(); - - Assert.Throws( - () => - { - wrappers1.RegisterAsGlobalInstance(); - }, "Should not be able to re-register for global ComWrappers"); - - var wrappers2 = new TestComWrappers(); - Assert.Throws( - () => - { - wrappers2.RegisterAsGlobalInstance(); - }, "Should not be able to reset for global ComWrappers"); - - Console.WriteLine($"Validate NotifyEndOfReferenceTrackingOnThread()..."); - - int hr; - var cw = TestComWrappers.Global; - - // Trigger the thread lifetime end API and verify the callback occurs. - hr = MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread(); - Assert.AreEqual(TestComWrappers.ReleaseObjectsCallAck, hr); - } - static int Main(string[] doNotUse) { try @@ -499,10 +330,6 @@ static int Main(string[] doNotUse) ValidateIUnknownImpls(); ValidateBadComWrapperImpl(); ValidateRuntimeTrackerScenario(); - - // Perform all global impacting test scenarios last to - // avoid polluting non-global tests. - ValidateGlobalInstanceScenarios(); } catch (Exception e) { diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs new file mode 100644 index 0000000000000..06c052b0a237a --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/Common.cs @@ -0,0 +1,146 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace ComWrappersTests.Common +{ + using System; + using System.Runtime.InteropServices; + + // + // Managed object with native wrapper definition. + // + [Guid("447BB9ED-DA48-4ABC-8963-5BB5C3E0AA09")] + interface ITest + { + void SetValue(int i); + } + + class Test : ITest + { + public static int InstanceCount = 0; + + private int value = -1; + public Test() { InstanceCount++; } + ~Test() { InstanceCount--; } + + public void SetValue(int i) => this.value = i; + public int GetValue() => this.value; + } + + public struct IUnknownVtbl + { + public IntPtr QueryInterface; + public IntPtr AddRef; + public IntPtr Release; + } + + public struct ITestVtbl + { + public IUnknownVtbl IUnknownImpl; + public IntPtr SetValue; + + public delegate int _SetValue(IntPtr thisPtr, int i); + public static _SetValue pSetValue = new _SetValue(SetValueInternal); + + public static int SetValueInternal(IntPtr dispatchPtr, int i) + { + unsafe + { + try + { + ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)dispatchPtr).SetValue(i); + } + catch (Exception e) + { + return e.HResult; + } + } + return 0; // S_OK; + } + } + + // + // Native interface defintion with managed wrapper for tracker object + // + struct MockReferenceTrackerRuntime + { + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static IntPtr CreateTrackerObject(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static void ReleaseAllTrackerObjects(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int Trigger_NotifyEndOfReferenceTrackingOnThread(); + } + + [Guid("42951130-245C-485E-B60B-4ED4254256F8")] + public interface ITrackerObject + { + int AddObjectRef(IntPtr obj); + void DropObjectRef(int id); + }; + + public struct VtblPtr + { + public IntPtr Vtbl; + } + + public class ITrackerObjectWrapper : ITrackerObject + { + private struct ITrackerObjectWrapperVtbl + { + public IntPtr QueryInterface; + public _AddRef AddRef; + public _Release Release; + public _AddObjectRef AddObjectRef; + public _DropObjectRef DropObjectRef; + } + + private delegate int _AddRef(IntPtr This); + private delegate int _Release(IntPtr This); + private delegate int _AddObjectRef(IntPtr This, IntPtr obj, out int id); + private delegate int _DropObjectRef(IntPtr This, int id); + + private readonly IntPtr instance; + private readonly ITrackerObjectWrapperVtbl vtable; + + public ITrackerObjectWrapper(IntPtr instance) + { + var inst = Marshal.PtrToStructure(instance); + this.vtable = Marshal.PtrToStructure(inst.Vtbl); + this.instance = instance; + } + + ~ITrackerObjectWrapper() + { + if (this.instance != IntPtr.Zero) + { + this.vtable.Release(this.instance); + } + } + + public int AddObjectRef(IntPtr obj) + { + int id; + int hr = this.vtable.AddObjectRef(this.instance, obj, out id); + if (hr != 0) + { + throw new COMException($"{nameof(AddObjectRef)}", hr); + } + + return id; + } + + public void DropObjectRef(int id) + { + int hr = this.vtable.DropObjectRef(this.instance, id); + if (hr != 0) + { + throw new COMException($"{nameof(DropObjectRef)}", hr); + } + } + } +} + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest new file mode 100644 index 0000000000000..bb7ec83fae8b8 --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/App.manifest @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest new file mode 100644 index 0000000000000..abb39fbb21c7d --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/CoreShim.X.manifest @@ -0,0 +1,16 @@ + + + + + + + + + + + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj new file mode 100644 index 0000000000000..1e5e14f330e2e --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/GlobalInstanceTests.csproj @@ -0,0 +1,37 @@ + + + Exe + App.manifest + true + true + + BuildOnly + + true + true + + true + true + + + + + + + + + + + + + false + Content + Always + + + + + PreserveNewest + + + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs new file mode 100644 index 0000000000000..3bd14141570cd --- /dev/null +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/GlobalInstance/Program.cs @@ -0,0 +1,451 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace ComWrappersTests.GlobalInstance +{ + using System; + using System.Collections; + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + + using ComWrappersTests.Common; + using TestLibrary; + + class Program + { + struct MarshalInterface + { + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [return: MarshalAs(UnmanagedType.IUnknown)] + extern public static object CreateTrackerObjectAsIUnknown(); + + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint=nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [return: MarshalAs(UnmanagedType.Interface)] + extern public static FakeWrapper CreateTrackerObjectAsInterface(); + + [DllImport(nameof(MockReferenceTrackerRuntime), EntryPoint = nameof(MockReferenceTrackerRuntime.CreateTrackerObject))] + [return: MarshalAs(UnmanagedType.Interface)] + extern public static Test CreateTrackerObjectWrongType(); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsIUnknown( + [MarshalAs(UnmanagedType.IUnknown)] object testObj, + int i, + [MarshalAs(UnmanagedType.IUnknown)] out object ret); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsIDispatch( + [MarshalAs(UnmanagedType.IDispatch)] object testObj, + int i, + [MarshalAs(UnmanagedType.IDispatch)] out object ret); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsIInspectable( + [MarshalAs(UnmanagedType.IInspectable)] object testObj, + int i, + [MarshalAs(UnmanagedType.IInspectable)] out object ret); + + [DllImport(nameof(MockReferenceTrackerRuntime))] + extern public static int UpdateTestObjectAsInterface( + [MarshalAs(UnmanagedType.Interface)] Test testObj, + int i, + [Out, MarshalAs(UnmanagedType.Interface)] out Test ret); + } + + private const string ManagedServerTypeName = "ConsumeNETServerTesting"; + + private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046"; + private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90"; + class TestEx : Test + { + public readonly Guid[] Interfaces; + public TestEx(params string[] iids) + { + Interfaces = new Guid[iids.Length]; + for (int i = 0; i < iids.Length; i++) + Interfaces[i] = Guid.Parse(iids[i]); + } + } + + class FakeWrapper + { + private delegate int _AddRef(IntPtr This); + private delegate int _Release(IntPtr This); + private struct IUnknownWrapperVtbl + { + public IntPtr QueryInterface; + public _AddRef AddRef; + public _Release Release; + } + + private readonly IntPtr wrappedInstance; + + private readonly IUnknownWrapperVtbl vtable; + + public FakeWrapper(IntPtr instance) + { + this.wrappedInstance = instance; + var inst = Marshal.PtrToStructure(instance); + this.vtable = Marshal.PtrToStructure(inst.Vtbl); + } + + ~FakeWrapper() + { + if (this.wrappedInstance != IntPtr.Zero) + { + this.vtable.Release(this.wrappedInstance); + } + } + } + + class GlobalComWrappers : ComWrappers + { + public static GlobalComWrappers Instance = new GlobalComWrappers(); + + public bool ReturnInvalid { get; set; } + + public object LastComputeVtablesObject { get; private set; } + + protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + LastComputeVtablesObject = obj; + + if (ReturnInvalid) + { + count = -1; + return null; + } + + if (obj is Test) + { + return ComputeVtablesForTestObject((Test)obj, out count); + } + else if (string.Equals(ManagedServerTypeName, obj.GetType().Name)) + { + IntPtr fpQueryInteface = default; + IntPtr fpAddRef = default; + IntPtr fpRelease = default; + ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); + + var vtbl = new IUnknownVtbl() + { + QueryInterface = fpQueryInteface, + AddRef = fpAddRef, + Release = fpRelease + }; + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(IUnknownVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); + + // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here. + var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(ComInterfaceEntry)); + entryRaw[0].IID = typeof(Server.Contract.IConsumeNETServer).GUID; + entryRaw[0].Vtable = vtblRaw; + + count = 1; + return entryRaw; + } + + count = -1; + return null; + } + + protected override object? CreateObject(IntPtr externalComObject, CreateObjectFlags flag) + { + if (ReturnInvalid) + return null; + + Guid[] iids = { + typeof(ITrackerObject).GUID, + typeof(ITest).GUID, + typeof(Server.Contract.IDispatchTesting).GUID, + typeof(Server.Contract.IConsumeNETServer).GUID + }; + + for (var i = 0; i < iids.Length; i++) + { + var iid = iids[i]; + IntPtr comObject; + int hr = Marshal.QueryInterface(externalComObject, ref iid, out comObject); + if (hr == 0) + return new FakeWrapper(comObject); + } + + return null; + } + + public const int ReleaseObjectsCallAck = unchecked((int)-1); + + protected override void ReleaseObjects(IEnumerable objects) + { + throw new Exception() { HResult = ReleaseObjectsCallAck }; + } + + private unsafe ComInterfaceEntry* ComputeVtablesForTestObject(Test obj, out int count) + { + IntPtr fpQueryInteface = default; + IntPtr fpAddRef = default; + IntPtr fpRelease = default; + ComWrappers.GetIUnknownImpl(out fpQueryInteface, out fpAddRef, out fpRelease); + + var iUnknownVtbl = new IUnknownVtbl() + { + QueryInterface = fpQueryInteface, + AddRef = fpAddRef, + Release = fpRelease + }; + + var vtbl = new ITestVtbl() + { + IUnknownImpl = iUnknownVtbl, + SetValue = Marshal.GetFunctionPointerForDelegate(ITestVtbl.pSetValue) + }; + var vtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ITestVtbl)); + Marshal.StructureToPtr(vtbl, vtblRaw, false); + + int countLocal = obj is TestEx ? ((TestEx)obj).Interfaces.Length + 1 : 1; + var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ITestVtbl), sizeof(ComInterfaceEntry) * countLocal); + entryRaw[0].IID = typeof(ITest).GUID; + entryRaw[0].Vtable = vtblRaw; + + if (obj is TestEx) + { + var iUnknownVtblRaw = RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(IUnknownVtbl), sizeof(IUnknownVtbl)); + Marshal.StructureToPtr(iUnknownVtbl, iUnknownVtblRaw, false); + + var testEx = (TestEx)obj; + for (int i = 1; i < testEx.Interfaces.Length + 1; i++) + { + // Including interfaces to allow QI, but not actually returning a valid vtable, since it is not needed for the tests here. + entryRaw[i].IID = testEx.Interfaces[i - 1]; + entryRaw[i].Vtable = iUnknownVtblRaw; + } + } + + count = countLocal; + return entryRaw; + } + } + + private static void ValidateRegisterAsGlobalInstance() + { + Console.WriteLine($"Running {nameof(ValidateRegisterAsGlobalInstance)}..."); + + var wrappers1 = GlobalComWrappers.Instance; + wrappers1.RegisterAsGlobalInstance(); + Assert.Throws( + () => + { + wrappers1.RegisterAsGlobalInstance(); + }, "Should not be able to re-register for global ComWrappers"); + + var wrappers2 = new GlobalComWrappers(); + Assert.Throws( + () => + { + wrappers2.RegisterAsGlobalInstance(); + }, "Should not be able to reset for global ComWrappers"); + } + + private static void ValidateMarshalAPIs(bool validateUseRegistered) + { + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + Console.WriteLine($"Running {nameof(ValidateMarshalAPIs)}: {scenario}..."); + + GlobalComWrappers registeredWrapper = GlobalComWrappers.Instance; + registeredWrapper.ReturnInvalid = !validateUseRegistered; + + Console.WriteLine($" -- Validate Marshal.GetIUnknownForObject..."); + + var testObj = new Test(); + IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj); + Assert.AreNotEqual(IntPtr.Zero, comWrapper1); + Assert.AreEqual(testObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + + IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj); + Assert.AreEqual(comWrapper1, comWrapper2); + + Marshal.Release(comWrapper1); + Marshal.Release(comWrapper2); + + Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject..."); + + Assert.Throws(() => Marshal.GetIDispatchForObject(testObj)); + + if (validateUseRegistered) + { + var dispatchObj = new TestEx(IID_IDISPATCH); + IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj); + Assert.AreNotEqual(IntPtr.Zero, dispatchWrapper); + Assert.AreEqual(dispatchObj, registeredWrapper.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + } + + Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown..."); + + IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject(); + object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(validateUseRegistered, objWrapper1 is FakeWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(objWrapper1, objWrapper2); + + Console.WriteLine($" -- Validate Marshal.GetUniqueObjectForIUnknown..."); + + object objWrapper3 = Marshal.GetUniqueObjectForIUnknown(trackerObjRaw); + Assert.AreEqual(validateUseRegistered, objWrapper3 is FakeWrapper, $"GetObjectForIUnknown should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + + Assert.AreNotEqual(objWrapper1, objWrapper3); + + Marshal.Release(trackerObjRaw); + } + + private static void ValidatePInvokes(bool validateUseRegistered) + { + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + Console.WriteLine($"Running {nameof(ValidatePInvokes)}: {scenario}..."); + + GlobalComWrappers.Instance.ReturnInvalid = !validateUseRegistered; + + Console.WriteLine($" -- Validate MarshalAs IUnknown..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIUnknown, validateUseRegistered); + object obj = MarshalInterface.CreateTrackerObjectAsIUnknown(); + Assert.AreEqual(validateUseRegistered, obj is FakeWrapper, $"Should{(validateUseRegistered ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + + if (validateUseRegistered) + { + Console.WriteLine($" -- Validate MarshalAs IDispatch..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIDispatch, validateUseRegistered, new TestEx(IID_IDISPATCH)); + + Console.WriteLine($" -- Validate MarshalAs IInspectable..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsIInspectable, validateUseRegistered, new TestEx(IID_IINSPECTABLE)); + } + + Console.WriteLine($" -- Validate MarshalAs Interface..."); + ValidateInterfaceMarshaler(MarshalInterface.UpdateTestObjectAsInterface, validateUseRegistered); + + if (validateUseRegistered) + { + Assert.Throws(() => MarshalInterface.CreateTrackerObjectWrongType()); + + FakeWrapper wrapper = MarshalInterface.CreateTrackerObjectAsInterface(); + Assert.IsNotNull(obj, $"Should have returned {nameof(FakeWrapper)} instance"); + } + } + + private delegate int UpdateTestObject(T testObj, int i, out T ret) where T : class; + private static void ValidateInterfaceMarshaler(UpdateTestObject func, bool validateUseRegistered, Test testObj = null) where T : class + { + const int E_NOINTERFACE = unchecked((int)0x80004002); + int value = 10; + + if (testObj == null) + testObj = new Test(); + + T retObj; + int hr = func(testObj as T, value, out retObj); + Assert.AreEqual(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject, "Registered ComWrappers instance should have been called"); + if (validateUseRegistered) + { + Assert.IsTrue(retObj is Test); + Assert.AreEqual(value, testObj.GetValue()); + Assert.AreEqual(testObj, retObj); + } + else + { + Assert.AreEqual(E_NOINTERFACE, hr); + } + } + + private static void ValidateComActivation(bool validateUseRegistered) + { + string scenario = validateUseRegistered ? "use registered wrapper" : "fall back to runtime"; + Console.WriteLine($"Running {nameof(ValidateComActivation)}: {scenario}..."); + GlobalComWrappers.Instance.ReturnInvalid = !validateUseRegistered; + + Console.WriteLine($" -- Validate native server..."); + ValidateNativeServerActivation(); + + Console.WriteLine($" -- Validate managed server..."); + ValidateManagedServerActivation(); + } + + private static void ValidateNativeServerActivation() + { + bool returnValid = !GlobalComWrappers.Instance.ReturnInvalid; + + Type t= Type.GetTypeFromCLSID(Guid.Parse(Server.Contract.Guids.DispatchTesting)); + var server = Activator.CreateInstance(t); + Assert.AreEqual(returnValid, server is FakeWrapper, $"Should{(returnValid ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + + IntPtr ptr = Marshal.GetIUnknownForObject(server); + var obj = Marshal.GetObjectForIUnknown(ptr); + Assert.AreEqual(server, obj); + } + + private static void ValidateManagedServerActivation() + { + bool returnValid = !GlobalComWrappers.Instance.ReturnInvalid; + + // Initialize CoreShim and hostpolicymock + HostPolicyMock.Initialize(Environment.CurrentDirectory, null); + Environment.SetEnvironmentVariable("CORESHIM_COMACT_ASSEMBLYNAME", "NETServer"); + Environment.SetEnvironmentVariable("CORESHIM_COMACT_TYPENAME", ManagedServerTypeName); + + using (HostPolicyMock.Mock_corehost_resolve_component_dependencies(0, string.Empty, string.Empty, string.Empty)) + { + Type t = Type.GetTypeFromCLSID(Guid.Parse(Server.Contract.Guids.ConsumeNETServerTesting)); + var server = Activator.CreateInstance(t); + Assert.AreEqual(returnValid, server is FakeWrapper, $"Should{(returnValid ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + object serverUnwrapped = GlobalComWrappers.Instance.LastComputeVtablesObject; + Assert.AreEqual(ManagedServerTypeName, serverUnwrapped.GetType().Name); + + IntPtr ptr = Marshal.GetIUnknownForObject(server); + var obj = Marshal.GetObjectForIUnknown(ptr); + Assert.AreEqual(server, obj); + Assert.AreEqual(returnValid, obj is FakeWrapper, $"Should{(returnValid ? string.Empty : "not")} have returned {nameof(FakeWrapper)} instance"); + serverUnwrapped.GetType().GetMethod("NotEqualByRCW").Invoke(serverUnwrapped, new object[] { obj }); + } + } + + private static void ValidateNotifyEndOfReferenceTrackingOnThread() + { + Console.WriteLine($"Running {nameof(ValidateNotifyEndOfReferenceTrackingOnThread)}..."); + + // Make global instance return invalid object so that the Exception thrown by + // GlobalComWrappers.ReleaseObjects is marshalled using the built-in system. + GlobalComWrappers.Instance.ReturnInvalid = true; + + // Trigger the thread lifetime end API and verify the callback occurs. + int hr = MockReferenceTrackerRuntime.Trigger_NotifyEndOfReferenceTrackingOnThread(); + Assert.AreEqual(GlobalComWrappers.ReleaseObjectsCallAck, hr); + } + + static int Main(string[] doNotUse) + { + try + { + // The first test registereds a global ComWrappers instance + // Subsequents tests assume the global instance has already been registered. + ValidateRegisterAsGlobalInstance(); + + ValidateMarshalAPIs(validateUseRegistered: true); + ValidateMarshalAPIs(validateUseRegistered: false); + + ValidatePInvokes(validateUseRegistered: true); + ValidatePInvokes(validateUseRegistered: false); + + ValidateComActivation(validateUseRegistered: true); + ValidateComActivation(validateUseRegistered: false); + + ValidateNotifyEndOfReferenceTrackingOnThread(); + } + catch (Exception e) + { + Console.WriteLine($"Test Failure: {e}"); + return 101; + } + + return 100; + } + } +} + diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp index f0a0f30b277e7..867efa01866e4 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/MockReferenceTrackerRuntime/ReferenceTrackerRuntime.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace API { @@ -337,3 +338,41 @@ extern "C" DLL_EXPORT int STDMETHODCALLTYPE Trigger_NotifyEndOfReferenceTracking { return TrackerRuntimeManager.NotifyEndOfReferenceTrackingOnThread(); } + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsIUnknown(IUnknown *obj, int i, IUnknown **out) +{ + if (obj == nullptr) + return E_POINTER; + + HRESULT hr; + ComSmartPtr testObj; + RETURN_IF_FAILED(obj->QueryInterface(&testObj)) + RETURN_IF_FAILED(testObj->SetValue(i)); + + *out = testObj.Detach(); + return S_OK; +} + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsIDispatch(IDispatch *obj, int i, IDispatch **out) +{ + if (obj == nullptr) + return E_POINTER; + + return UpdateTestObjectAsIUnknown(obj, i, (IUnknown**)out); +} + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsIInspectable(IInspectable * obj, int i, IInspectable **out) +{ + if (obj == nullptr) + return E_POINTER; + + return UpdateTestObjectAsIUnknown(obj, i, (IUnknown **)out); +} + +extern "C" DLL_EXPORT int STDMETHODCALLTYPE UpdateTestObjectAsInterface(ITest *obj, int i, ITest **out) +{ + if (obj == nullptr) + return E_POINTER; + + return UpdateTestObjectAsIUnknown(obj, i, (IUnknown**)out); +} diff --git a/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs b/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs index d8dba73982d9c..ec9f92ff3e69a 100644 --- a/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs +++ b/src/coreclr/tests/src/Interop/COM/NETServer/ConsumeNETServerTesting.cs @@ -15,6 +15,9 @@ public class ConsumeNETServerTesting : Server.Contract.IConsumeNETServer public ConsumeNETServerTesting() { _ccw = Marshal.GetIUnknownForObject(this); + + // At this point, the CCW has not been marked as COM-activated, + // so the returned RCW will be unwrapped. _rcwUnwrapped = Marshal.GetObjectForIUnknown(_ccw); } diff --git a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs index f72fa3af23a00..26e1876686853 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetComInterfaceForObjectTests.cs @@ -141,8 +141,8 @@ public void GetComInterfaceForObject_NullObject_ThrowsArgumentNullException() [PlatformSpecific(TestPlatforms.Windows)] public void GetComInterfaceForObject_NullType_ThrowsArgumentNullException() { - AssertExtensions.Throws("t", () => Marshal.GetComInterfaceForObject(new object(), null)); - AssertExtensions.Throws("t", () => Marshal.GetComInterfaceForObject(new object(), null, CustomQueryInterfaceMode.Allow)); + AssertExtensions.Throws("T", () => Marshal.GetComInterfaceForObject(new object(), null)); + AssertExtensions.Throws("T", () => Marshal.GetComInterfaceForObject(new object(), null, CustomQueryInterfaceMode.Allow)); } public static IEnumerable GetComInterfaceForObject_InvalidType_TestData() diff --git a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs index 1da3b1529af9b..88d36ee5b5b35 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/Marshal/GetUniqueObjectForIUnknownTests.cs @@ -63,7 +63,7 @@ public void GetUniqueObjectForIUnknown_Unix_ThrowsPlatformNotSupportedException( [PlatformSpecific(TestPlatforms.Windows)] public void GetUniqueObjectForIUnknown_NullPointer_ThrowsArgumentNullException() { - AssertExtensions.Throws("pUnk", () => Marshal.GetUniqueObjectForIUnknown(IntPtr.Zero)); + AssertExtensions.Throws("unknown", () => Marshal.GetUniqueObjectForIUnknown(IntPtr.Zero)); } private static void NonGenericMethod(int i) { }