Skip to content

Commit

Permalink
Change IMarshallingGenerator resolution to use a sequential design in…
Browse files Browse the repository at this point in the history
…stead of recursive (#96632)

* Make generator factories non-nested (instead using a sequential search model similar to how we parse attribute info) and rename to a better name based on the return type.

* Use empty instead of default

* Include all core resolvers in element marshalling resolution, not just CustomMarshaller marshallers.

* Run the attributed marshalling model (which handles UnmanagedBlittable marshalling info) before char marshalling to ensure that char marshalling works as expected.

* Rename all references to GeneratorFactory to GeneratorResolver
  • Loading branch information
jkoritzinsky authored Jan 10, 2024
1 parent cc5f1df commit 80aa8a0
Show file tree
Hide file tree
Showing 39 changed files with 529 additions and 479 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public JSExportCodeGenerator(
JSExportData attributeData,
JSSignatureContext signatureContext,
GeneratorDiagnosticsBag diagnosticsBag,
IMarshallingGeneratorFactory generatorFactory)
IMarshallingGeneratorResolver generatorResolver)
{
_signatureContext = signatureContext;
NativeToManagedStubCodeContext innerContext = new NativeToManagedStubCodeContext(ReturnIdentifier, ReturnIdentifier)
Expand All @@ -33,7 +33,7 @@ public JSExportCodeGenerator(
};
_context = new JSExportCodeContext(attributeData, innerContext);

_marshallers = BoundGenerators.Create(argTypes, generatorFactory, _context, new EmptyJSGenerator(), out var bindingFailures);
_marshallers = BoundGenerators.Create(argTypes, generatorResolver, _context, new EmptyJSGenerator(), out var bindingFailures);

diagnosticsBag.ReportGeneratorDiagnostics(bindingFailures);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ private static (MemberDeclarationSyntax, StatementSyntax, AttributeListSyntax, I
incrementalContext.JSExportData,
incrementalContext.SignatureContext,
diagnostics,
new JSGeneratorFactory());
new JSGeneratorResolver());

var wrapperName = "__Wrapper_" + incrementalContext.StubMethodSyntaxTemplate.Identifier + "_" + incrementalContext.SignatureContext.TypesHash;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace Microsoft.Interop.JavaScript
{
internal sealed class JSGeneratorFactory : IMarshallingGeneratorFactory
internal sealed class JSGeneratorResolver : IMarshallingGeneratorResolver
{
public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ public JSImportCodeGenerator(
JSImportData attributeData,
JSSignatureContext signatureContext,
GeneratorDiagnosticsBag diagnosticsBag,
IMarshallingGeneratorFactory generatorFactory)
IMarshallingGeneratorResolver generatorResolver)
{
_signatureContext = signatureContext;
ManagedToNativeStubCodeContext innerContext = new ManagedToNativeStubCodeContext(ReturnIdentifier, ReturnIdentifier)
{
CodeEmitOptions = new(SkipInit: true)
};
_context = new JSImportCodeContext(attributeData, innerContext);
_marshallers = BoundGenerators.Create(argTypes, generatorFactory, _context, new EmptyJSGenerator(), out var bindingFailures);
_marshallers = BoundGenerators.Create(argTypes, generatorResolver, _context, new EmptyJSGenerator(), out var bindingFailures);

diagnosticsBag.ReportGeneratorDiagnostics(bindingFailures);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ private static (MemberDeclarationSyntax, ImmutableArray<DiagnosticInfo>) Generat
incrementalContext.JSImportData,
incrementalContext.SignatureContext,
diagnostics,
new JSGeneratorFactory());
new JSGeneratorResolver());

BlockSyntax code = stubGenerator.GenerateJSImportBody();

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ public override void Initialize(AnalysisContext context)
new CodeEmitOptions(SkipInit: true),
typeof(ConvertComImportToGeneratedComInterfaceAnalyzer).Assembly);
var managedToUnmanagedFactory = ComInterfaceGeneratorHelpers.GetGeneratorFactory(env.EnvironmentFlags, MarshalDirection.ManagedToUnmanaged);
var unmanagedToManagedFactory = ComInterfaceGeneratorHelpers.GetGeneratorFactory(env.EnvironmentFlags, MarshalDirection.UnmanagedToManaged);
var managedToUnmanagedFactory = ComInterfaceGeneratorHelpers.GetGeneratorResolver(env.EnvironmentFlags, MarshalDirection.ManagedToUnmanaged);
var unmanagedToManagedFactory = ComInterfaceGeneratorHelpers.GetGeneratorResolver(env.EnvironmentFlags, MarshalDirection.UnmanagedToManaged);
mayRequireAdditionalWork = diagnostics.Diagnostics.Any();
bool anyExplicitlyUnsupportedInfo = false;
Expand All @@ -92,7 +92,7 @@ public override void Initialize(AnalysisContext context)
var forwarder = new Forwarder();
// We don't actually need the bound generators. We just need them to be attempted to be bound to determine if the generator will be able to bind them.
BoundGenerators generators = BoundGenerators.Create(targetSignatureContext.ElementTypeInformation, new CallbackGeneratorFactory((info, context) =>
BoundGenerators generators = BoundGenerators.Create(targetSignatureContext.ElementTypeInformation, new CallbackGeneratorResolver((info, context) =>
{
if (s_unsupportedTypeNames.Contains(info.ManagedType.FullTypeName))
{
Expand Down Expand Up @@ -186,11 +186,11 @@ private static bool HasUnsupportedMarshalAsInfo(TypePositionInfo info)
|| unmanagedType == UnmanagedType.SafeArray;
}

private sealed class CallbackGeneratorFactory : IMarshallingGeneratorFactory
private sealed class CallbackGeneratorResolver : IMarshallingGeneratorResolver
{
private readonly Func<TypePositionInfo, StubCodeContext, ResolvedGenerator> _func;

public CallbackGeneratorFactory(Func<TypePositionInfo, StubCodeContext, ResolvedGenerator> func)
public CallbackGeneratorResolver(Func<TypePositionInfo, StubCodeContext, ResolvedGenerator> func)
{
_func = func;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterf
.Where(context => context.UnmanagedToManagedStub.Diagnostics.All(diag => diag.Descriptor.DefaultSeverity != DiagnosticSeverity.Error))
.Select(context => context.GenerationContext),
vtableLocalName,
ComInterfaceGeneratorHelpers.GetGeneratorFactory);
ComInterfaceGeneratorHelpers.GetGeneratorResolver);

return ImplementationInterfaceTemplate
.AddMembers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

<ItemGroup>
<Compile Include="$(LibrariesProjectRoot)System.Runtime.InteropServices\src\System\Runtime\InteropServices\Marshalling\ComInterfaceOptions.cs" Link="Production\ComInterfaceOptions.cs" />
<Compile Include="..\Common\DefaultMarshallingInfoParser.cs" Link="Common\DefaultMarshallingInfoParser.cs" />
<Compile Include="..\..\tests\Common\ExceptionMarshalling.cs" Link="Common\ExceptionMarshalling.cs" />
<Compile Include="$(CommonPath)\Roslyn\DiagnosticDescriptorHelper.cs" Link="Common\Roslyn\DiagnosticDescriptorHelper.cs" />
<Compile Include="$(CommonPath)\Roslyn\GetBestTypeByMetadataName.cs" Link="Common\Roslyn\GetBestTypeByMetadataName.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,34 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
{
internal static class ComInterfaceGeneratorHelpers
{
private static readonly IMarshallingGeneratorFactory s_managedToUnmanagedDisabledMarshallingGeneratorFactory = CreateGeneratorFactory(EnvironmentFlags.DisableRuntimeMarshalling, MarshalDirection.ManagedToUnmanaged);
private static readonly IMarshallingGeneratorFactory s_unmanagedToManagedDisabledMarshallingGeneratorFactory = CreateGeneratorFactory(EnvironmentFlags.DisableRuntimeMarshalling, MarshalDirection.UnmanagedToManaged);
private static readonly IMarshallingGeneratorFactory s_managedToUnmanagedEnabledMarshallingGeneratorFactory = CreateGeneratorFactory(EnvironmentFlags.None, MarshalDirection.ManagedToUnmanaged);
private static readonly IMarshallingGeneratorFactory s_unmanagedToManagedEnabledMarshallingGeneratorFactory = CreateGeneratorFactory(EnvironmentFlags.None, MarshalDirection.UnmanagedToManaged);

private static IMarshallingGeneratorFactory CreateGeneratorFactory(EnvironmentFlags env, MarshalDirection direction)
{
IMarshallingGeneratorFactory generatorFactory;

// If we're in a "supported" scenario, then emit a diagnostic as our final fallback.
generatorFactory = new UnsupportedMarshallingFactory();

generatorFactory = new NoMarshallingInfoErrorMarshallingFactory(generatorFactory, TypeNames.GeneratedComInterfaceAttribute_ShortName);

// Since the char type can go into the P/Invoke signature here, we can only use it when
// runtime marshalling is disabled.
generatorFactory = new CharMarshallingGeneratorFactory(generatorFactory, useBlittableMarshallerForUtf16: env.HasFlag(EnvironmentFlags.DisableRuntimeMarshalling), TypeNames.GeneratedComInterfaceAttribute_ShortName);

InteropGenerationOptions interopGenerationOptions = new(UseMarshalType: true);
generatorFactory = new MarshalAsMarshallingGeneratorFactory(interopGenerationOptions, generatorFactory);

generatorFactory = new StructAsHResultMarshallerFactory(generatorFactory);

IMarshallingGeneratorFactory elementFactory = new AttributedMarshallingModelGeneratorFactory(
// Since the char type in an array will not be part of the P/Invoke signature, we can
// use the regular blittable marshaller in all cases.
new CharMarshallingGeneratorFactory(generatorFactory, useBlittableMarshallerForUtf16: true, TypeNames.GeneratedComInterfaceAttribute_ShortName),
new AttributedMarshallingModelOptions(env.HasFlag(EnvironmentFlags.DisableRuntimeMarshalling), MarshalMode.ElementIn, MarshalMode.ElementRef, MarshalMode.ElementOut));
// We don't need to include the later generator factories for collection elements
// as the later generator factories only apply to parameters.
generatorFactory = new AttributedMarshallingModelGeneratorFactory(
generatorFactory,
elementFactory,
new AttributedMarshallingModelOptions(
env.HasFlag(EnvironmentFlags.DisableRuntimeMarshalling),
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedIn
: MarshalMode.UnmanagedToManagedOut,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedRef
: MarshalMode.UnmanagedToManagedRef,
direction == MarshalDirection.ManagedToUnmanaged
? MarshalMode.ManagedToUnmanagedOut
: MarshalMode.UnmanagedToManagedIn));

generatorFactory = new ManagedHResultExceptionMarshallerFactory(generatorFactory, direction);

generatorFactory = new ComInterfaceDispatchMarshallerFactory(generatorFactory);

generatorFactory = new ByValueContentsMarshalKindValidator(generatorFactory);
generatorFactory = new BreakingChangeDetector(generatorFactory);

return generatorFactory;
}

public static IMarshallingGeneratorFactory GetGeneratorFactory(EnvironmentFlags env, MarshalDirection direction)
private static readonly IMarshallingGeneratorResolver s_managedToUnmanagedDisabledMarshallingGeneratorResolver = CreateGeneratorResolver(EnvironmentFlags.DisableRuntimeMarshalling, MarshalDirection.ManagedToUnmanaged);
private static readonly IMarshallingGeneratorResolver s_unmanagedToManagedDisabledMarshallingGeneratorResolver = CreateGeneratorResolver(EnvironmentFlags.DisableRuntimeMarshalling, MarshalDirection.UnmanagedToManaged);
private static readonly IMarshallingGeneratorResolver s_managedToUnmanagedEnabledMarshallingGeneratorResolver = CreateGeneratorResolver(EnvironmentFlags.None, MarshalDirection.ManagedToUnmanaged);
private static readonly IMarshallingGeneratorResolver s_unmanagedToManagedEnabledMarshallingGeneratorResolver = CreateGeneratorResolver(EnvironmentFlags.None, MarshalDirection.UnmanagedToManaged);

private static IMarshallingGeneratorResolver CreateGeneratorResolver(EnvironmentFlags env, MarshalDirection direction)
=> DefaultMarshallingGeneratorResolver.Create(env, direction, TypeNames.GeneratedComInterfaceAttribute_ShortName,
[
new StructAsHResultMarshallerFactory(),
new ManagedHResultExceptionGeneratorResolver(direction),
new ComInterfaceDispatchMarshallingResolver(),
]);

public static IMarshallingGeneratorResolver GetGeneratorResolver(EnvironmentFlags env, MarshalDirection direction)
=> (env.HasFlag(EnvironmentFlags.DisableRuntimeMarshalling), direction) switch
{
(true, MarshalDirection.ManagedToUnmanaged) => s_managedToUnmanagedDisabledMarshallingGeneratorFactory,
(true, MarshalDirection.UnmanagedToManaged) => s_unmanagedToManagedDisabledMarshallingGeneratorFactory,
(false, MarshalDirection.ManagedToUnmanaged) => s_managedToUnmanagedEnabledMarshallingGeneratorFactory,
(false, MarshalDirection.UnmanagedToManaged) => s_unmanagedToManagedEnabledMarshallingGeneratorFactory,
(true, MarshalDirection.ManagedToUnmanaged) => s_managedToUnmanagedDisabledMarshallingGeneratorResolver,
(true, MarshalDirection.UnmanagedToManaged) => s_unmanagedToManagedDisabledMarshallingGeneratorResolver,
(false, MarshalDirection.ManagedToUnmanaged) => s_managedToUnmanagedEnabledMarshallingGeneratorResolver,
(false, MarshalDirection.UnmanagedToManaged) => s_unmanagedToManagedEnabledMarshallingGeneratorResolver,
_ => throw new UnreachableException(),
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private GeneratedMethodContextBase CreateManagedToUnmanagedStub()
{
return new SkippedStubContext(OriginalDeclaringInterface.Info.Type);
}
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorFactory);
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
}

Expand All @@ -93,7 +93,7 @@ private GeneratedMethodContextBase CreateUnmanagedToManagedStub()
{
return new SkippedStubContext(GenerationContext.OriginalDefiningType);
}
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorFactory);
var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext, ComInterfaceGeneratorHelpers.GetGeneratorResolver);
return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public ManagedToNativeVTableMethodGenerator(
bool setLastError,
bool implicitThis,
GeneratorDiagnosticsBag diagnosticsBag,
IMarshallingGeneratorFactory generatorFactory)
IMarshallingGeneratorResolver generatorResolver)
{
_setLastError = setLastError;
if (implicitThis)
Expand All @@ -73,7 +73,7 @@ public ManagedToNativeVTableMethodGenerator(
}

_context = new ManagedToNativeStubCodeContext(ReturnIdentifier, ReturnIdentifier);
_marshallers = BoundGenerators.Create(argTypes, generatorFactory, _context, new Forwarder(), out var bindingFailures);
_marshallers = BoundGenerators.Create(argTypes, generatorResolver, _context, new Forwarder(), out var bindingFailures);

diagnosticsBag.ReportGeneratorDiagnostics(bindingFailures);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,10 @@ internal sealed record ComInterfaceDispatchMarshallingInfo : MarshallingInfo
public static readonly ComInterfaceDispatchMarshallingInfo Instance = new();
}

internal sealed class ComInterfaceDispatchMarshallerFactory : IMarshallingGeneratorFactory
internal sealed class ComInterfaceDispatchMarshallingResolver : IMarshallingGeneratorResolver
{
private readonly IMarshallingGeneratorFactory _inner;
public ComInterfaceDispatchMarshallerFactory(IMarshallingGeneratorFactory inner)
{
_inner = inner;
}

public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
=> info.MarshallingAttributeInfo is ComInterfaceDispatchMarshallingInfo ? ResolvedGenerator.Resolved(new Marshaller()) : _inner.Create(info, context);
=> info.MarshallingAttributeInfo is ComInterfaceDispatchMarshallingInfo ? ResolvedGenerator.Resolved(new Marshaller()) : ResolvedGenerator.UnresolvedGenerator;

private sealed class Marshaller : IMarshallingGenerator
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,16 @@ namespace Microsoft.Interop
{
internal sealed record ManagedHResultExceptionMarshallingInfo : MarshallingInfo;

internal sealed class ManagedHResultExceptionMarshallerFactory : IMarshallingGeneratorFactory
internal sealed class ManagedHResultExceptionGeneratorResolver : IMarshallingGeneratorResolver
{
private readonly IMarshallingGeneratorFactory _inner;
private readonly MarshalDirection _direction;

public ManagedHResultExceptionMarshallerFactory(IMarshallingGeneratorFactory inner, MarshalDirection direction)
public ManagedHResultExceptionGeneratorResolver(MarshalDirection direction)
{
if (direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.UnmanagedToManaged))
{
throw new ArgumentOutOfRangeException(nameof(direction));
}
_inner = inner;
_direction = direction;
}

Expand All @@ -41,7 +39,7 @@ public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
}
else
{
return _inner.Create(info, context);
return ResolvedGenerator.UnresolvedGenerator;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,10 @@ namespace Microsoft.Interop
{
internal sealed record ObjectUnwrapperInfo(TypeSyntax UnwrapperType) : MarshallingInfo;

internal sealed class ObjectUnwrapperMarshallerFactory : IMarshallingGeneratorFactory
internal sealed class ObjectUnwrapperResolver : IMarshallingGeneratorResolver
{
private readonly IMarshallingGeneratorFactory _inner;
public ObjectUnwrapperMarshallerFactory(IMarshallingGeneratorFactory inner)
{
_inner = inner;
}

public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
=> info.MarshallingAttributeInfo is ObjectUnwrapperInfo ? ResolvedGenerator.Resolved(new Marshaller()) : _inner.Create(info, context);
=> info.MarshallingAttributeInfo is ObjectUnwrapperInfo ? ResolvedGenerator.Resolved(new Marshaller()) : ResolvedGenerator.UnresolvedGenerator;

private sealed class Marshaller : IMarshallingGenerator
{
Expand Down
Loading

0 comments on commit 80aa8a0

Please sign in to comment.