Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -239,28 +239,29 @@ private MarshallingInfo GetMarshallingInfo(

// If we aren't overriding the marshalling at usage time,
// then fall back to the information on the element type itself.
foreach (AttributeData typeAttribute in type.GetAttributes())
{
if (GetMarshallingInfoForAttribute(typeAttribute, type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo marshallingInfo)
{
return marshallingInfo;
}
}
if (GetMarshallingInfoForAttributes(type.GetAttributes().AsSpan(), type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo info)
return info;

// If the type doesn't have custom attributes that dictate marshalling,
// then consider the type itself.
return GetMarshallingInfoForType(type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) ?? NoMarshallingInfo.Instance;
}

private MarshallingInfo? GetMarshallingInfoForAttribute(AttributeData attribute, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback)
{
return GetMarshallingInfoForAttributes(new AttributeData[] { attribute }.AsSpan(), type, indirectionDepth, useSiteAttributes, marshallingInfoCallback);
}

private MarshallingInfo? GetMarshallingInfoForAttributes(ReadOnlySpan<AttributeData> attrs, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback)
{
foreach (var parser in _marshallingAttributeParsers)
{
// Automatically ignore invalid attributes.
// The compiler will already error on them.
if (attribute.AttributeConstructor is not null && parser.CanParseAttributeType(attribute.AttributeClass))
foreach (var attr in attrs)
{
return parser.ParseAttribute(attribute, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback);
if (attr.AttributeConstructor is not null && parser.CanParseAttributeType(attr.AttributeClass))
{
return parser.ParseAttribute(attr, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback);
}
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;
using Xunit;

namespace ComInterfaceGenerator.Tests
{
public unsafe partial class NativeMarshallingAttributeTests
{
[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_unique_marshalling")]
internal static partial IUniqueMarshalling NewUniqueMarshalling();

[Fact]
public void GetSameComInterfaceTwiceReturnsUniqueInstances()
{
// When using NativeMarshalling with UniqueComInterfaceMarshaller,
// getting the same COM interface twice should return different managed instances
var obj1 = NewUniqueMarshalling();
var obj2 = NewUniqueMarshalling();

Assert.NotSame(obj1, obj2);

// Verify they work independently
obj1.SetValue(42);
obj2.SetValue(100);

Assert.Equal(42, obj1.GetValue());
Assert.Equal(100, obj2.GetValue());
}

[Fact]
public void MethodReturningComInterfaceReturnsUniqueInstance()
{
// When a COM interface method returns the same interface type,
// it should return a new managed instance, not the cached one
var obj = NewUniqueMarshalling();
obj.SetValue(42);

var returnedObj = obj.GetThis();

// Should be a different managed object
Assert.NotSame(obj, returnedObj);

// But should refer to the same underlying COM object
Assert.Equal(42, returnedObj.GetValue());

// Modifying through one should affect the other
returnedObj.SetValue(100);
Assert.Equal(100, obj.GetValue());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -751,5 +751,21 @@ public Bidirectional(IComInterfaceAttributeProvider attributeProvider)

public IComInterfaceAttributeProvider AttributeProvider { get; }
}

public string ComInterfaceWithNativeMarshalling => $$"""
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

[assembly:DisableRuntimeMarshalling]

{{GeneratedComInterface()}}
[NativeMarshalling(typeof(UniqueComInterfaceMarshaller<IFoo>))]
partial interface IFoo
{
void DoWorkTogether(IFoo foo);
}
""";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ public static IEnumerable<object[]> ComInterfaceSnippetsToCompile()
yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("ref readonly") };
yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("in") };
yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("out") };
yield return new object[] { ID(), codeSnippets.ComInterfaceWithNativeMarshalling };
}

public static IEnumerable<object[]> ManagedToUnmanagedComInterfaceSnippetsToCompile()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid(IID)]
[NativeMarshalling(typeof(UniqueComInterfaceMarshaller<IUniqueMarshalling>))]
internal partial interface IUniqueMarshalling
{
int GetValue();
void SetValue(int x);
IUniqueMarshalling GetThis();

public const string IID = "E11D5F3E-DD57-4E7E-A78C-F5F8B8E0A1F4";
}

[GeneratedComClass]
internal partial class UniqueMarshalling : IUniqueMarshalling
{
int _data = 0;
public int GetValue() => _data;
public void SetValue(int x) => _data = x;
public IUniqueMarshalling GetThis() => this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1526,5 +1526,25 @@ public void Free() { }
}
}
""";

public static string ComInterfaceWithNativeMarshallingInLibraryImport => """
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

[GeneratedComInterface]
[Guid("0E7204B5-4B61-4E06-B872-82BA652F2ECA")]
[NativeMarshalling(typeof(UniqueComInterfaceMarshaller<IFoo>))]
partial interface IFoo
{
void DoWork();
}

static partial class PInvokes
{
[LibraryImport("lib")]
[return: MarshalAs(UnmanagedType.I1)]
public static partial bool TryGetFoo(out IFoo foo);
}
""";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()

// Type-level interop generator trigger attributes
yield return new[] { ID(), CodeSnippets.GeneratedComInterface };
yield return new[] { ID(), CodeSnippets.ComInterfaceWithNativeMarshallingInLibraryImport };

// Parameter modifiers
yield return new[] { ID(), CodeSnippets.SingleParameterWithModifier("int", "scoped ref") };
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;

namespace NativeExports.ComInterfaceGenerator
{
public static unsafe class UniqueMarshalling
{
// Call from another assembly to get a ptr to make an RCW
[UnmanagedCallersOnly(EntryPoint = "new_unique_marshalling")]
public static void* CreateComObject()
{
StrategyBasedComWrappers wrappers = new();
var myObject = new SharedTypes.ComInterfaces.UniqueMarshalling();
nint ptr = wrappers.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);

return (void*)ptr;
}
}
}
Loading