Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use registered ComWrappers for object <-> COM interface #33485

Merged
merged 12 commits into from
Mar 25, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,35 @@ private struct ComInterfaceInstance
/// <param name="flags">Flags used to configure the generated interface.</param>
/// <returns>The generated COM interface that can be passed outside the .NET runtime.</returns>
public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags)
{
IntPtr ptr;
if (!TryGetOrCreateComInterfaceForObjectInternal(this, instance, flags, out ptr))
throw new ArgumentException();

return ptr;
}

/// <summary>
/// Create a COM representation of the supplied object that can be passed to a non-managed environment.
/// </summary>
/// <param name="impl">The <see cref="ComWrappers" /> implementation to use when creating the COM representation.</param>
/// <param name="instance">The managed object to expose outside the .NET runtime.</param>
/// <param name="flags">Flags used to configure the generated interface.</param>
/// <param name="retValue">The generated COM interface that can be passed outside the .NET runtime or IntPtr.Zero if it could not be created.</param>
/// <returns>Returns <c>true</c> if a COM representation could be created, <c>false</c> otherwise</returns>
/// <remarks>
/// If <paramref name="impl" /> is <c>null</c>, the global instance (if registered) will be used.
/// </remarks>
private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
{
if (instance == null)
throw new ArgumentNullException(nameof(instance));

ComWrappers impl = this;
return GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags);
return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags, out retValue);
}

[DllImport(RuntimeHelpers.QCall)]
private static extern IntPtr GetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags);
private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue);

/// <summary>
/// Compute the desired Vtable for <paramref name="obj"/> respecting the values of <paramref name="flags"/>.
Expand All @@ -140,13 +159,23 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa
/// All memory returned from this function must either be unmanaged memory, pinned managed memory, or have been
/// allocated with the <see cref="System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(Type, int)"/> API.
///
/// If the interface entries cannot be created and <code>null</code> is returned, the call to <see cref="ComWrappers.GetOrCreateComInterfaceForObject(object, CreateComInterfaceFlags)"/> will throw a <see cref="System.ArgumentNullException"/>.
/// If the interface entries cannot be created and a negative <paramref name="count" /> or <code>null</code> and a non-zero <paramref name="count" /> are returned,
/// the call to <see cref="ComWrappers.GetOrCreateComInterfaceForObject(object, CreateComInterfaceFlags)"/> will throw a <see cref="System.ArgumentException"/>.
/// </remarks>
protected unsafe abstract ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count);

// Call to execute the abstract instance function
internal static unsafe void* CallComputeVtables(ComWrappers? comWrappersImpl, object obj, CreateComInterfaceFlags flags, out int count)
=> (comWrappersImpl ?? s_globalInstance!).ComputeVtables(obj, flags, out count);
{
ComWrappers? impl = comWrappersImpl ?? s_globalInstance;
if (impl is null)
{
count = -1;
return null;
}

return impl.ComputeVtables(obj, flags, out count);
}

/// <summary>
/// Get the currently registered managed object or creates a new managed object and registers it.
Expand All @@ -156,7 +185,11 @@ public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfa
/// <returns>Returns a managed object associated with the supplied external COM object.</returns>
public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags)
{
return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, null);
object? obj;
if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, null, out obj))
throw new ArgumentNullException();

return obj!;
}

/// <summary>
Expand All @@ -172,7 +205,13 @@ public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateOb

// Call to execute the abstract instance function
internal static object? CallCreateObject(ComWrappers? comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags)
=> (comWrappersImpl ?? s_globalInstance!).CreateObject(externalComObject, flags);
{
ComWrappers? impl = comWrappersImpl ?? s_globalInstance;
if (impl == null)
return null;

return impl.CreateObject(externalComObject, flags);
}

/// <summary>
/// Get the currently registered managed object or uses the supplied managed object and registers it.
Expand All @@ -189,24 +228,37 @@ public object GetOrRegisterObjectForComInstance(IntPtr externalComObject, Create
if (wrapper == null)
throw new ArgumentNullException(nameof(externalComObject));

return GetOrCreateObjectForComInstanceInternal(externalComObject, flags, wrapper);
object? obj;
if (!TryGetOrCreateObjectForComInstanceInternal(this, externalComObject, flags, wrapper, out obj))
throw new ArgumentNullException();

return obj!;
}

private object GetOrCreateObjectForComInstanceInternal(IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe)
/// <summary>
/// Get the currently registered managed object or creates a new managed object and registers it.
/// </summary>
/// <param name="impl">The <see cref="ComWrappers" /> implementation to use when creating the managed object.</param>
/// <param name="externalComObject">Object to import for usage into the .NET runtime.</param>
/// <param name="flags">Flags used to describe the external object.</param>
/// <param name="wrapperMaybe">The <see cref="object"/> to be used as the wrapper for the external object.</param>
/// <param name="retValue">The managed object associated with the supplied external COM object or <c>null</c> if it could not be created.</param>
/// <returns>Returns <c>true</c> if a managed object could be retrieved/created, <c>false</c> otherwise</returns>
/// <remarks>
/// If <paramref name="impl" /> is <c>null</c>, the global instance (if registered) will be used.
/// </remarks>
private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue)
{
if (externalComObject == IntPtr.Zero)
throw new ArgumentNullException(nameof(externalComObject));

ComWrappers impl = this;
object? wrapperMaybeLocal = wrapperMaybe;
object? retValue = null;
GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));

return retValue!;
retValue = null;
return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
}

[DllImport(RuntimeHelpers.QCall)]
private static extern void GetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);
private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);

/// <summary>
/// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime.
Expand Down Expand Up @@ -235,8 +287,14 @@ public void RegisterAsGlobalInstance()
{
throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance);
}

SetGlobalInstanceRegistered();
}

[DllImport(RuntimeHelpers.QCall)]
elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
[SuppressGCTransition]
private static extern void SetGlobalInstanceRegistered();

/// <summary>
/// Get the runtime provided IUnknown implementation.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
/// </summary>
public static IntPtr /* IUnknown* */ GetIUnknownForObject(object o)
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
{
if (o is null)
{
throw new ArgumentNullException(nameof(o));
}

return GetIUnknownForObjectNative(o, false);
}

Expand All @@ -344,6 +349,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
/// </summary>
public static IntPtr /* IDispatch */ GetIDispatchForObject(object o)
{
if (o is null)
{
throw new ArgumentNullException(nameof(o));
}

return GetIDispatchForObjectNative(o, false);
}

Expand All @@ -356,6 +366,16 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
/// </summary>
public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T)
{
if (o is null)
{
throw new ArgumentNullException(nameof(o));
}

if (T is null)
{
throw new ArgumentNullException(nameof(T));
}

return GetComInterfaceForObjectNative(o, T, false, true);
}

Expand All @@ -368,15 +388,48 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
/// </summary>
public static IntPtr /* IUnknown* */ GetComInterfaceForObject(object o, Type T, CustomQueryInterfaceMode mode)
{
if (o is null)
{
throw new ArgumentNullException(nameof(o));
}

if (T is null)
{
throw new ArgumentNullException(nameof(T));
}

bool bEnableCustomizedQueryInterface = ((mode == CustomQueryInterfaceMode.Allow) ? true : false);
return GetComInterfaceForObjectNative(o, T, false, bEnableCustomizedQueryInterface);
}

[MethodImpl(MethodImplOptions.InternalCall)]
private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnalbeCustomizedQueryInterface);
private static extern IntPtr /* IUnknown* */ GetComInterfaceForObjectNative(object o, Type t, bool onlyInContext, bool fEnableCustomizedQueryInterface);

/// <summary>
/// Return the managed object representing the IUnknown*
/// </summary>
public static object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk)
{
if (pUnk == IntPtr.Zero)
{
throw new ArgumentNullException(nameof(pUnk));
}

elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
return GetObjectForIUnknownNative(pUnk);
}

[MethodImpl(MethodImplOptions.InternalCall)]
public static extern object GetObjectForIUnknown(IntPtr /* IUnknown* */ pUnk);
private static extern object GetObjectForIUnknownNative(IntPtr /* IUnknown* */ pUnk);

public static object GetUniqueObjectForIUnknown(IntPtr unknown)
{
if (unknown == IntPtr.Zero)
{
throw new ArgumentNullException(nameof(unknown));
}

elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
return GetUniqueObjectForIUnknownNative(unknown);
}

/// <summary>
/// Return a unique Object given an IUnknown. This ensures that you receive a fresh
Expand All @@ -385,7 +438,7 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
/// ReleaseComObject on a RCW and not worry about other active uses ofsaid RCW.
/// </summary>
[MethodImpl(MethodImplOptions.InternalCall)]
public static extern object GetUniqueObjectForIUnknown(IntPtr unknown);
private static extern object GetUniqueObjectForIUnknownNative(IntPtr unknown);

/// <summary>
/// Return an Object for IUnknown, using the Type T.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ internal static class InterfaceMarshaler
internal static extern IntPtr ConvertToNative(object objSrc, IntPtr itfMT, IntPtr classMT, int flags);

[MethodImpl(MethodImplOptions.InternalCall)]
internal static extern object ConvertToManaged(IntPtr pUnk, IntPtr itfMT, IntPtr classMT, int flags);
internal static extern object ConvertToManaged(IntPtr ppUnk, IntPtr itfMT, IntPtr classMT, int flags);

[DllImport(RuntimeHelpers.QCall)]
internal static extern void ClearNative(IntPtr pUnk);
Expand Down
3 changes: 2 additions & 1 deletion src/coreclr/src/interop/comwrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ enum class CreateComInterfaceFlagsEx : int32_t
TrackerSupport = InteropLib::Com::CreateComInterfaceFlags_TrackerSupport,

// Highest bit is reserved for internal usage
IsComActivated = 1 << 30,
IsPegged = 1 << 31,

InternalMask = IsPegged,
InternalMask = IsPegged | IsComActivated,
};

DEFINE_ENUM_FLAG_OPERATORS(CreateComInterfaceFlagsEx);
Expand Down
6 changes: 6 additions & 0 deletions src/coreclr/src/interop/inc/interoplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ namespace InteropLib
// Reactivate the supplied wrapper.
HRESULT ReactivateWrapper(_In_ IUnknown* wrapper, _In_ InteropLib::OBJECTHANDLE handle) noexcept;

// Get the object for the supplied wrapper
HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept;

HRESULT MarkComActivated(_In_ IUnknown* wrapper) noexcept;
HRESULT IsComActivated(_In_ IUnknown* wrapper) noexcept;

struct ExternalWrapperResult
{
// The returned context memory is guaranteed to be initialized to zero.
Expand Down
37 changes: 37 additions & 0 deletions src/coreclr/src/interop/interoplib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,43 @@ namespace InteropLib
return S_OK;
}

HRESULT GetObjectForWrapper(_In_ IUnknown* wrapper, _Outptr_result_maybenull_ OBJECTHANDLE* object) noexcept
{
if (object == nullptr)
return E_POINTER;

*object = nullptr;

HRESULT hr = IsActiveWrapper(wrapper);
if (hr != S_OK)
return hr;

ManagedObjectWrapper *mow = ManagedObjectWrapper::MapFromIUnknown(wrapper);
_ASSERTE(mow != nullptr);

*object = mow->Target;
return S_OK;
}

HRESULT MarkComActivated(_In_ IUnknown* wrapperMaybe) noexcept
{
ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe);
if (wrapper == nullptr)
return E_INVALIDARG;

wrapper->SetFlag(CreateComInterfaceFlagsEx::IsComActivated);
return S_OK;
}

HRESULT IsComActivated(_In_ IUnknown* wrapperMaybe) noexcept
{
ManagedObjectWrapper* wrapper = ManagedObjectWrapper::MapFromIUnknown(wrapperMaybe);
if (wrapper == nullptr)
return E_INVALIDARG;

return wrapper->IsSet(CreateComInterfaceFlagsEx::IsComActivated) ? S_OK : S_FALSE;
}

HRESULT CreateWrapperForExternal(
_In_ IUnknown* external,
_In_ enum CreateObjectFlags flags,
Expand Down
9 changes: 5 additions & 4 deletions src/coreclr/src/vm/ecalllist.h
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,8 @@ FCFuncStart(gInteropMarshalFuncs)
FCFuncElement("GetHRForException", MarshalNative::GetHRForException)
FCFuncElement("GetRawIUnknownForComObjectNoAddRef", MarshalNative::GetRawIUnknownForComObjectNoAddRef)
FCFuncElement("IsComObject", MarshalNative::IsComObject)
FCFuncElement("GetObjectForIUnknown", MarshalNative::GetObjectForIUnknown)
FCFuncElement("GetUniqueObjectForIUnknown", MarshalNative::GetUniqueObjectForIUnknown)
FCFuncElement("GetObjectForIUnknownNative", MarshalNative::GetObjectForIUnknownNative)
FCFuncElement("GetUniqueObjectForIUnknownNative", MarshalNative::GetUniqueObjectForIUnknownNative)
FCFuncElement("AddRef", MarshalNative::AddRef)
FCFuncElement("GetNativeVariantForObject", MarshalNative::GetNativeVariantForObject)
FCFuncElement("GetObjectForNativeVariant", MarshalNative::GetObjectForNativeVariant)
Expand Down Expand Up @@ -997,8 +997,9 @@ FCFuncEnd()
#ifdef FEATURE_COMWRAPPERS
FCFuncStart(gComWrappersFuncs)
QCFuncElement("GetIUnknownImplInternal", ComWrappersNative::GetIUnknownImpl)
QCFuncElement("GetOrCreateComInterfaceForObjectInternal", ComWrappersNative::GetOrCreateComInterfaceForObject)
QCFuncElement("GetOrCreateObjectForComInstanceInternal", ComWrappersNative::GetOrCreateObjectForComInstance)
QCFuncElement("TryGetOrCreateComInterfaceForObjectInternal", ComWrappersNative::TryGetOrCreateComInterfaceForObject)
QCFuncElement("TryGetOrCreateObjectForComInstanceInternal", ComWrappersNative::TryGetOrCreateObjectForComInstance)
QCFuncElement("SetGlobalInstanceRegistered", GlobalComWrappers::SetGlobalInstanceRegistered)
FCFuncEnd()
#endif // FEATURE_COMWRAPPERS

Expand Down
Loading