diff --git a/src/coreclr/vm/interoputil.cpp b/src/coreclr/vm/interoputil.cpp index a8f58ead1b3603..049e8854c13afa 100644 --- a/src/coreclr/vm/interoputil.cpp +++ b/src/coreclr/vm/interoputil.cpp @@ -3090,19 +3090,49 @@ void IUInvokeDispMethod( } } - // // Retrieve the IDispatch interface that will be invoked on. // if (pInvokedMT->IsInterface()) { - // The invoked type is a dispatch or dual interface so we will make the - // invocation on it. - pUnk = ComObject::GetComIPFromRCWThrowing(pTarget, pInvokedMT); - hr = SafeQueryInterface(pUnk, IID_IDispatch, (IUnknown**)&pDisp); + // COMPAT: We must invoke any methods on this specific IDispatch implementation, + // as the canonical implementation (returned by QueryInterface) may not + // expose all members. + // This can occur when pTarget is a CCW for a .NET object and that object's class + // has a custom default interface (specified by the System.Runtime.InteropServices.ComDefaultInterfaceAttribute attribute). + // In that case, the default IDispatch pointer will only resolve members defined on that interface, + // not all members defined on the class. + // + // We will still do the QI here because we want to protect the user from the following scenario: + // - The user has a COM object that implements one IUnknown interface, which we'll call ICallback. + // - The user defines a managed interface (represented by pInvokedMT)to represent ICallback + // - The user incorrectly marks ICallback as "implements IDispatch and not dual". + // - The user tries to call a method on ICallback. + // + // In this case, pInvokedMT will represent the managed ICallback definition. + // The underlying COM object will not implement IDispatch. + // + // We cannot verify that the vtable of the COM object returned by pInvokedMT will have the IDispatch methods, + // but we must use that IDispatch implementation (see above). + // To catch the simple case above (where there is no IDispatch implementation at all), + // we do a QI for IDispatch and throw away the result just to make sure we don't try to invoke on an object that doesn't implement IDispatch. + // + // If the underlying COM object implements IDispatch but the COM interface represented by pInvokedMT is not dispatch or dual, + // we will not correctly detect that the user did something wrong and will crash. + // This is a known issue with no solution. + // Our check here is best effort to catch the simple case where a user may make a mistake. + SafeComHolder pInvokedMTUnknown = ComObject::GetComIPFromRCWThrowing(pTarget, pInvokedMT); + + // QI for IDispatch to catch the simple error case (COM object has no IDispatch but pInvokedMT is specified as a dispatch or dual interface) + SafeComHolder pCanonicalDisp; + hr = SafeQueryInterface(pInvokedMTUnknown, IID_IDispatch, &pCanonicalDisp); if (FAILED(hr)) COMPlusThrow(kTargetException, W("TargetInvocation_TargetDoesNotImplementIDispatch")); + + _ASSERTE(IsDispatchBasedItf(pInvokedMT->GetComInterfaceType())); + // Extract the IDispatch pointer that is associated with pInvokedMT specifically. + pDisp = (IDispatch*)pInvokedMTUnknown.Extract(); } else { diff --git a/src/tests/Interop/COM/NETClients/ConsumeNETServer/ConsumeNETServer.csproj b/src/tests/Interop/COM/NETClients/ConsumeNETServer/ConsumeNETServer.csproj index 8902714d19b95e..402f04ce472bb0 100644 --- a/src/tests/Interop/COM/NETClients/ConsumeNETServer/ConsumeNETServer.csproj +++ b/src/tests/Interop/COM/NETClients/ConsumeNETServer/ConsumeNETServer.csproj @@ -14,6 +14,7 @@ + diff --git a/src/tests/Interop/COM/NETClients/ConsumeNETServer/Program.cs b/src/tests/Interop/COM/NETClients/ConsumeNETServer/Program.cs index 2e51444ff2115a..8427e94e58536c 100644 --- a/src/tests/Interop/COM/NETClients/ConsumeNETServer/Program.cs +++ b/src/tests/Interop/COM/NETClients/ConsumeNETServer/Program.cs @@ -1,17 +1,16 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +using TestLibrary; using Xunit; +using Server.Contract; namespace NetClient { - using System; - using System.Collections.Generic; - using System.Runtime.CompilerServices; - using System.Runtime.InteropServices; - - using TestLibrary; - using Xunit; - using Server.Contract; using CoClass = Server.Contract.Servers; @@ -105,7 +104,7 @@ static void Validate_Server_CCW_RCW() } [Fact] - public static int TestEntryPoint() + public static int ConsumeNETServerTests() { // RegFree COM is not supported on Windows Nano if (Utilities.IsWindowsNanoServer) @@ -141,5 +140,44 @@ public static int TestEntryPoint() return 100; } + + [Fact] + public static void Validate_IDispatch_Custom_DefaultInterface() + { + using (ComActivationHelpers.RegisterTypeForActivation()) + { + var myObjectType = Type.GetTypeFromCLSID(typeof(MyObject).GUID, throwOnError: true)!; + object obj = Activator.CreateInstance(myObjectType)!; + var iSecond = (ISecond)obj; + iSecond.Invoke2(); // Verify that we can invoke a method on a non-default IDispatch non-dual interface. + } + } } } + +[ComVisible(true)] +[ProgId("MyCompany.MyObject")] +[Guid("CF824E95-642E-4F8D-8CD1-F67BC588B107")] +[ClassInterface(ClassInterfaceType.None)] +[ComDefaultInterface(typeof(IFirst))] +public sealed class MyObject : IFirst, ISecond +{ + public void Invoke1() { } + public void Invoke2() { } +} + +[ComVisible(true)] +[Guid("59DAED5B-726B-4DDD-8F07-9236382038F7")] +[InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] +public interface IFirst +{ + void Invoke1(); +} + +[ComVisible(true)] +[Guid("88C3FF65-0155-48A9-917D-E96DFAC8409B")] +[InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] +public interface ISecond +{ + void Invoke2(); +} diff --git a/src/tests/Interop/common/ComActivationHelpers.cs b/src/tests/Interop/common/ComActivationHelpers.cs new file mode 100644 index 00000000000000..1cf9bd22c1f060 --- /dev/null +++ b/src/tests/Interop/common/ComActivationHelpers.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; + +public static class ComActivationHelpers +{ + public static IDisposable RegisterTypeForActivation() + where T: new() + { + Factory factory = new(() => new T()); + CoRegisterClassObject( + typeof(T).GUID, + factory, + 1, // CLSCTX_INPROC_SERVER + 1, // REGCLS_MULTIPLEUSE + out uint cookie); + return new RegistrationToken(cookie); + } + + private class RegistrationToken(uint token) : IDisposable + { + public void Dispose() + { + CoRevokeClassObject(token); + } + } + + [DllImport("ole32")] + private static extern int CoRegisterClassObject( + in Guid rclsid, + [MarshalAs(UnmanagedType.Interface)] IClassFactory pUnk, + uint dwClsContext, + uint flags, + out uint lpdwRegister); + + [DllImport("ole32")] + private static extern int CoRevokeClassObject(uint dwRegister); + + [ComImport] + [ComVisible(false)] + [Guid("00000001-0000-0000-C000-000000000046")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + private interface IClassFactory + { + void CreateInstance( + [MarshalAs(UnmanagedType.Interface)] object? pUnkOuter, + in Guid riid, + out nint ppvObject); + + void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock); + } + + private sealed class Factory(Func factory) : IClassFactory + { + public void CreateInstance(object? pUnkOuter, in Guid riid, out nint ppvObject) + { + if (pUnkOuter != null) + { + const int CLASS_E_NOAGGREGATION = unchecked((int)0x80040110); + Marshal.ThrowExceptionForHR(CLASS_E_NOAGGREGATION); + } + + nint ccw = Marshal.GetIUnknownForObject(factory()); + try + { + int hr = Marshal.QueryInterface(ccw, in riid, out ppvObject); + Marshal.ThrowExceptionForHR(hr); + } + finally + { + Marshal.Release(ccw); + } + } + + public void LockServer(bool fLock) + { + // No-op + } + } +}