diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs index 23457d50f85b0a..ed09b33a9d7283 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs @@ -239,13 +239,8 @@ private MarshallingInfo GetMarshallingInfo( // If we aren't overriding the marshalling at usage time, // then fall back to the information on the element type itself. - foreach (AttributeData typeAttribute in type.GetAttributes()) - { - if (GetMarshallingInfoForAttribute(typeAttribute, type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo marshallingInfo) - { - return marshallingInfo; - } - } + if (GetMarshallingInfoForAttributes(type.GetAttributes().AsSpan(), type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo info) + return info; // If the type doesn't have custom attributes that dictate marshalling, // then consider the type itself. @@ -253,14 +248,20 @@ private MarshallingInfo GetMarshallingInfo( } private MarshallingInfo? GetMarshallingInfoForAttribute(AttributeData attribute, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) + { + return GetMarshallingInfoForAttributes(new AttributeData[] { attribute }.AsSpan(), type, indirectionDepth, useSiteAttributes, marshallingInfoCallback); + } + + private MarshallingInfo? GetMarshallingInfoForAttributes(ReadOnlySpan attrs, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback) { foreach (var parser in _marshallingAttributeParsers) { - // Automatically ignore invalid attributes. - // The compiler will already error on them. - if (attribute.AttributeConstructor is not null && parser.CanParseAttributeType(attribute.AttributeClass)) + foreach (var attr in attrs) { - return parser.ParseAttribute(attribute, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback); + if (attr.AttributeConstructor is not null && parser.CanParseAttributeType(attr.AttributeClass)) + { + return parser.ParseAttribute(attr, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback); + } } } return null; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/NativeMarshallingAttributeTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/NativeMarshallingAttributeTests.cs new file mode 100644 index 00000000000000..89f54764983be2 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/NativeMarshallingAttributeTests.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using SharedTypes.ComInterfaces; +using Xunit; + +namespace ComInterfaceGenerator.Tests +{ + public unsafe partial class NativeMarshallingAttributeTests + { + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_unique_marshalling")] + internal static partial IUniqueMarshalling NewUniqueMarshalling(); + + [Fact] + public void MethodReturningComInterfaceReturnsUniqueInstance() + { + // When a COM interface method returns the same interface type, + // it should return a new managed instance, not the cached one + var obj = NewUniqueMarshalling(); + obj.SetValue(42); + + var returnedObj = obj.GetThis(); + + // Should be a different managed object + Assert.NotSame(obj, returnedObj); + + // But should refer to the same underlying COM object + Assert.Equal(42, returnedObj.GetValue()); + + // Modifying through one should affect the other + returnedObj.SetValue(100); + Assert.Equal(100, obj.GetValue()); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs index 8b3f5652bbab6b..514b400c3f1c91 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs @@ -751,5 +751,21 @@ public Bidirectional(IComInterfaceAttributeProvider attributeProvider) public IComInterfaceAttributeProvider AttributeProvider { get; } } + + public string ComInterfaceWithNativeMarshalling => $$""" + using System; + using System.Runtime.CompilerServices; + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [assembly:DisableRuntimeMarshalling] + + {{GeneratedComInterface()}} + [NativeMarshalling(typeof(UniqueComInterfaceMarshaller))] + partial interface IFoo + { + void DoWorkTogether(IFoo foo); + } + """; } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs index 247bd7a84f56e0..ba4928fec1d367 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs @@ -349,6 +349,7 @@ public static IEnumerable ComInterfaceSnippetsToCompile() yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("ref readonly") }; yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("in") }; yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("out") }; + yield return new object[] { ID(), codeSnippets.ComInterfaceWithNativeMarshalling }; } public static IEnumerable ManagedToUnmanagedComInterfaceSnippetsToCompile() diff --git a/src/libraries/System.Runtime.InteropServices/tests/Common/ComInterfaces/IUniqueMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/Common/ComInterfaces/IUniqueMarshalling.cs new file mode 100644 index 00000000000000..628f90fa58c4ef --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/Common/ComInterfaces/IUniqueMarshalling.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace SharedTypes.ComInterfaces +{ + [GeneratedComInterface] + [Guid(IID)] + [NativeMarshalling(typeof(UniqueComInterfaceMarshaller))] + internal partial interface IUniqueMarshalling + { + int GetValue(); + void SetValue(int x); + IUniqueMarshalling GetThis(); + + public const string IID = "E11D5F3E-DD57-4E7E-A78C-F5F8B8E0A1F4"; + } + + [GeneratedComClass] + internal partial class UniqueMarshalling : IUniqueMarshalling + { + int _data = 0; + public int GetValue() => _data; + public void SetValue(int x) => _data = x; + public IUniqueMarshalling GetThis() => this; + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/LibraryImportGenerator.Tests.csproj b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/LibraryImportGenerator.Tests.csproj index 8e99f4edd9dc0c..0c5778c8390cc0 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/LibraryImportGenerator.Tests.csproj +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/LibraryImportGenerator.Tests.csproj @@ -14,6 +14,8 @@ + diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/NativeMarshallingAttributeTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/NativeMarshallingAttributeTests.cs new file mode 100644 index 00000000000000..c58a67e30e97fb --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/NativeMarshallingAttributeTests.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using SharedTypes.ComInterfaces; +using Xunit; + +namespace LibraryImportGenerator.IntegrationTests +{ + partial class NativeExportsNE + { + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "new_unique_marshalling")] + internal static partial IUniqueMarshalling GetUniqueMarshalling(); + } + + public class NativeMarshallingAttributeTests + { + [Fact] + public void GetSameComInterfaceTwiceReturnsUniqueInstances() + { + // When using NativeMarshalling with UniqueComInterfaceMarshaller, + // calling GetUniqueMarshalling() twice returns different managed instances for the same COM object + var obj1 = NativeExportsNE.GetUniqueMarshalling(); + var obj2 = NativeExportsNE.GetUniqueMarshalling(); + + Assert.NotSame(obj1, obj2); + + // Both refer to the same underlying COM object (same cached pointer) + obj1.SetValue(42); + Assert.Equal(42, obj2.GetValue()); + + // Modifying through one should affect the other + obj2.SetValue(100); + Assert.Equal(100, obj1.GetValue()); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs index 3228bfc44ed64f..c0896cc09e8f5d 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs @@ -1526,5 +1526,25 @@ public void Free() { } } } """; + + public static string ComInterfaceWithNativeMarshallingInLibraryImport => """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + [Guid("0E7204B5-4B61-4E06-B872-82BA652F2ECA")] + [NativeMarshalling(typeof(UniqueComInterfaceMarshaller))] + partial interface IFoo + { + void DoWork(); + } + + static partial class PInvokes + { + [LibraryImport("lib")] + [return: MarshalAs(UnmanagedType.I1)] + public static partial bool TryGetFoo(out IFoo foo); + } + """; } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs index b7d3f37f87a2c8..afa6af2bb4c9a6 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs @@ -255,6 +255,7 @@ public static IEnumerable CodeSnippetsToCompile() // Type-level interop generator trigger attributes yield return new[] { ID(), CodeSnippets.GeneratedComInterface }; + yield return new[] { ID(), CodeSnippets.ComInterfaceWithNativeMarshallingInLibraryImport }; // Parameter modifiers yield return new[] { ID(), CodeSnippets.SingleParameterWithModifier("int", "scoped ref") }; diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/UniqueMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/UniqueMarshalling.cs new file mode 100644 index 00000000000000..f7dc257e01261f --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/UniqueMarshalling.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using SharedTypes.ComInterfaces; + +namespace NativeExports.ComInterfaceGenerator +{ + public static unsafe class UniqueMarshalling + { + private static void* s_cachedPtr = null; + + // Call from another assembly to get a ptr to make an RCW + [UnmanagedCallersOnly(EntryPoint = "new_unique_marshalling")] + public static void* CreateComObject() + { + if (s_cachedPtr == null) + { + StrategyBasedComWrappers wrappers = new(); + var myObject = new SharedTypes.ComInterfaces.UniqueMarshalling(); + nint ptr = wrappers.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None); + s_cachedPtr = (void*)ptr; + } + + return s_cachedPtr; + } + } +}