diff --git a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs index af899efb..54cae8bd 100644 --- a/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs @@ -76,6 +76,10 @@ internal static SyntaxToken Token(SyntaxKind kind) internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression); + internal static CollectionExpressionSyntax CollectionExpression(SeparatedSyntaxList elements = default) => SyntaxFactory.CollectionExpression(elements); + + internal static ExpressionElementSyntax ExpressionElement(ExpressionSyntax expression) => SyntaxFactory.ExpressionElement(expression); + internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList incrementors, StatementSyntax statement) { SyntaxToken semicolonToken = SyntaxFactory.Token(TriviaList(), SyntaxKind.SemicolonToken, TriviaList(Space)); @@ -321,6 +325,8 @@ internal static SyntaxList List(IEnumerable nodes) internal static TypeConstraintSyntax TypeConstraint(TypeSyntax type) => SyntaxFactory.TypeConstraint(type); + internal static ClassOrStructConstraintSyntax ClassOrStructConstraint(SyntaxKind kind) => SyntaxFactory.ClassOrStructConstraint(kind); + internal static TypeParameterConstraintClauseSyntax TypeParameterConstraintClause(IdentifierNameSyntax name, SeparatedSyntaxList constraints) => SyntaxFactory.TypeParameterConstraintClause(TokenWithSpace(SyntaxKind.WhereKeyword), name, TokenWithSpaces(SyntaxKind.ColonToken), constraints); internal static FieldDeclarationSyntax FieldDeclaration(VariableDeclarationSyntax declaration) => SyntaxFactory.FieldDeclaration(default, default, declaration, Semicolon); diff --git a/src/Microsoft.Windows.CsWin32/Generator.Com.cs b/src/Microsoft.Windows.CsWin32/Generator.Com.cs index 2c210e67..f87a7f05 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.Com.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.Com.cs @@ -68,6 +68,10 @@ private static bool GenerateCcwFor(MetadataReader reader, StringHandle typeName, return true; } + private static StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName), + ArgumentList())); + /// /// Generates a type to represent a COM interface. /// @@ -327,10 +331,6 @@ FunctionPointerParameterSyntax ToFunctionPointerParameter(ParameterSyntax p) if (methodDefinition.Generator.TryGetPropertyAccessorInfo(methodDefinition, originalIfaceName, context, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) && declaredProperties.Contains(propertyName.Identifier.ValueText)) { - StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName), - ArgumentList())); - BlockSyntax? body; switch (accessorKind) { @@ -1307,15 +1307,124 @@ private bool TryDeclareCOMGuidInterfaceIfNecessary() /// Creates an empty class that when instantiated, creates a cocreatable Windows object /// that may implement a number of interfaces at runtime, discoverable only by documentation. /// - private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef) + private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef, Context context) { + bool canUseComImport = context.AllowMarshaling && !this.useSourceGenerators; + IdentifierNameSyntax name = IdentifierName(this.Reader.GetString(typeDef.Name)); Guid guid = this.FindGuidFromAttribute(typeDef) ?? throw new ArgumentException("Type does not have a GuidAttribute."); SyntaxTokenList classModifiers = TokenList(TokenWithSpace(this.Visibility)); classModifiers = classModifiers.Add(TokenWithSpace(SyntaxKind.PartialKeyword)); ClassDeclarationSyntax result = ClassDeclaration(name.Identifier) .WithModifiers(classModifiers) - .AddAttributeLists(AttributeList().AddAttributes(GUID(guid), ComImportAttributeSyntax)); + .AddAttributeLists(AttributeList().AddAttributes(GUID(guid)).AddAttributes(canUseComImport ? [ComImportAttributeSyntax] : [])); + + if (!canUseComImport && !this.Options.ComInterop.UseIntPtrForComOutPointers) + { + string obsoleteMessage = context.AllowMarshaling + ? $"COM source generators do not support direct instantiation of co-creatable classes. Use {name.Identifier}.CreateInstance instead." + : $"Marshaling is disabled, so direct instantiation of co-creatable classes is not supported. Use {name.Identifier}.CreateInstance instead."; + + // Generate a private readonly field for the Guid + // private static readonly Guid CLSID_Foo = new Guid(...); + SyntaxToken clsidFieldName = Identifier($"CLSID_{name.Identifier}"); + FieldDeclarationSyntax clsidField = FieldDeclaration( + VariableDeclaration(IdentifierName(nameof(Guid))) + .AddVariables(VariableDeclarator(clsidFieldName).WithInitializer(EqualsValueClause(GuidValue(guid))))) + .AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword)); + result = result.AddMembers(clsidField); + + // If using source generators or marshalling is disabled, generate a constructor with obsolete attribute like this: + // [Obsolete("COM source generators do not support direct instantiation of co-creatable classes. Use CreateInstance method instead.")] + // public Foo() { throw new NotSupportedException("COM source generators do not support direct instantiation of co-creatable classes. Use CreateInstance method instead."); } + AttributeSyntax obsoleteAttribute = + Attribute(IdentifierName(nameof(ObsoleteAttribute))) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(obsoleteMessage)))); + ConstructorDeclarationSyntax constructor = ConstructorDeclaration(name.Identifier) + .AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword)) + .AddAttributeLists(AttributeList().AddAttributes(obsoleteAttribute)) + .WithBody( + Block( + ThrowStatement( + ObjectCreationExpression(IdentifierName(nameof(NotSupportedException))) + .WithArgumentList( + ArgumentList().AddArguments( + Argument( + LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(obsoleteMessage)))))))); + result = result.AddMembers(constructor); + + this.MainGenerator.TryGenerateExternMethod("CoCreateInstance", out IReadOnlyCollection preciseApi); + this.MainGenerator.TryGenerateConstant("CLSCTX", out preciseApi); + + if (context.AllowMarshaling) + { + // Then add the CreateInstance method: + // public static T CreateInstance() where T : class + // { + // PInvoke.CoCreateInstance(CLSID_Foo, null, CLSCTX.CLSCTX_SERVER, out T ret).ThrowOnFailure(); + // return ret; + // } + TypeParameterSyntax typeParameter = TypeParameter(Identifier("T")); + GenericNameSyntax genericName = GenericName("CreateInstance").AddTypeArgumentListArguments(IdentifierName("T")); + MethodDeclarationSyntax createInstanceMethod = MethodDeclaration(IdentifierName("T"), genericName.Identifier) + .AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword)) + .AddTypeParameterListParameters(typeParameter) + .AddConstraintClauses( + TypeParameterConstraintClause(IdentifierName("T"), SingletonSeparatedList(ClassOrStructConstraint(SyntaxKind.ClassConstraint)))) + .WithBody( + Block( + ThrowOnHRFailure( + InvocationExpression(QualifiedName(ParseName($"{this.Win32NamespacePrefix}.{this.options.ClassName}"), GenericName("CoCreateInstance").AddTypeArgumentListArguments(IdentifierName("T")))) + .WithArgumentList( + ArgumentList().AddArguments( + Argument(IdentifierName(clsidFieldName)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + QualifiedName(ParseName($"{this.Win32NamespacePrefix}.System.Com"), IdentifierName("CLSCTX")), + IdentifierName("CLSCTX_SERVER"))), + Argument(DeclarationExpression(IdentifierName("T").WithTrailingTrivia(Space), SingleVariableDesignation(Identifier("ret")))).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))))), + ReturnStatement(IdentifierName("ret")))); + result = result.AddMembers(createInstanceMethod); + } + else + { + // Then add a CreateInstance method that looks like this: + // public static HRESULT CreateInstance(out T* instance) where T : unmanaged + // { + // return PInvoke.CoCreateInstance(CLSID_Foo, null, CLSCTX.CLSCTX_SERVER, out instance); + // } + TypeParameterSyntax typeParameter = TypeParameter(Identifier("T")); + GenericNameSyntax genericName = GenericName("CreateInstance").AddTypeArgumentListArguments(IdentifierName("T")); + MethodDeclarationSyntax createInstanceMethod = MethodDeclaration(IdentifierName($"{this.Win32NamespacePrefix}.Foundation.HRESULT"), genericName.Identifier) + .AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.UnsafeKeyword)) + .AddTypeParameterListParameters(typeParameter) + .AddConstraintClauses( + TypeParameterConstraintClause(IdentifierName("T"), SingletonSeparatedList(TypeConstraint(IdentifierName("unmanaged"))))) + .WithParameterList( + ParameterList().AddParameters( + Parameter(Identifier("instance")) + .WithType(PointerType(IdentifierName("T"))) + .WithModifiers(TokenList(Token(SyntaxKind.OutKeyword))))) + .WithBody( + Block( + ReturnStatement( + InvocationExpression(QualifiedName(ParseName($"{this.Win32NamespacePrefix}.{this.options.ClassName}"), GenericName("CoCreateInstance").AddTypeArgumentListArguments(IdentifierName("T")))) + .WithArgumentList( + ArgumentList().AddArguments( + Argument(IdentifierName(clsidFieldName)), + Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)), + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + QualifiedName(ParseName($"{this.Win32NamespacePrefix}.System.Com"), IdentifierName("CLSCTX")), + IdentifierName("CLSCTX_SERVER"))), + Argument(IdentifierName("instance")).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))))))); + result = result.AddMembers(createInstanceMethod); + } + } result = this.AddApiDocumentation(name.Identifier.ValueText, result); return result; diff --git a/src/Microsoft.Windows.CsWin32/Generator.cs b/src/Microsoft.Windows.CsWin32/Generator.cs index d12900cd..79e062e8 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.cs @@ -1441,7 +1441,7 @@ private IReadOnlyList FindTypeSymbolsIfAlreadyAvailable(string fullyQua } else if (this.IsEmptyStructWithGuid(typeDef)) { - typeDeclaration = this.DeclareCocreatableClass(typeDef); + typeDeclaration = this.DeclareCocreatableClass(typeDef, context); } else { diff --git a/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs index 59553bc0..1d085800 100644 --- a/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs +++ b/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs @@ -425,6 +425,34 @@ internal static ObjectCreationExpressionSyntax GuidValue(CustomAttribute guidAtt Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(k), k)))); } + internal static ObjectCreationExpressionSyntax GuidValue(Guid guid) + { + byte[] bytes = guid.ToByteArray(); + uint a = BitConverter.ToUInt32(bytes, 0); + ushort b = BitConverter.ToUInt16(bytes, 4); + ushort c = BitConverter.ToUInt16(bytes, 6); + byte d = bytes[8]; + byte e = bytes[9]; + byte f = bytes[10]; + byte g = bytes[11]; + byte h = bytes[12]; + byte i = bytes[13]; + byte j = bytes[14]; + byte k = bytes[15]; + return ObjectCreationExpression(GuidTypeSyntax).AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(a), a))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(b), b))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(c), c))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(d), d))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(e), e))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(f), f))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(g), g))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(h), h))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(i), i))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(j), j))), + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(k), k)))); + } + internal static ExpressionSyntax IntPtrExpr(IntPtr value) => ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments( Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value.ToInt64())))); diff --git a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs index c0ee5c7c..30aea6f8 100644 --- a/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs +++ b/test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs @@ -100,6 +100,21 @@ public async Task TestPlatformCaseSensitivity(string platform) await this.InvokeGeneratorAndCompile($"{nameof(this.TestPlatformCaseSensitivity)}_{platform}"); } + [Fact] + public async Task TestGenerateCoCreateableClass() + { + this.nativeMethods.Add("ShellLink"); + await this.InvokeGeneratorAndCompileFromFact(); + + var shellLinkType = Assert.Single(this.FindGeneratedType("ShellLink")); + + // Check that it does not have the ComImport attribute. + Assert.DoesNotContain(shellLinkType.AttributeLists, al => al.Attributes.Any(attr => attr.Name.ToString().Contains("ComImport"))); + + // Check that it contains a CreateInstance method + Assert.Contains(shellLinkType.DescendantNodes().OfType(), method => method.Identifier.Text == "CreateInstance"); + } + [Theory] [InlineData("IMFMediaKeySession", "get_KeySystem", "winmdroot.Foundation.BSTR* keySystem")] [InlineData("AddPrinterW", "AddPrinter", "winmdroot.Foundation.PWSTR pName, uint Level, Span pPrinter")] diff --git a/test/GenerationSandbox.BuildTask.Tests/COMTests.cs b/test/GenerationSandbox.BuildTask.Tests/COMTests.cs index 908cb7b7..13318860 100644 --- a/test/GenerationSandbox.BuildTask.Tests/COMTests.cs +++ b/test/GenerationSandbox.BuildTask.Tests/COMTests.cs @@ -4,6 +4,7 @@ #pragma warning disable IDE0005 #pragma warning disable SA1201, SA1512, SA1005, SA1507, SA1515, SA1403, SA1402, SA1411, SA1300, SA1313, SA1134, SA1307, SA1308 +using System.ComponentModel; using System.Net.NetworkInformation; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; @@ -17,6 +18,7 @@ using Windows.Win32.Graphics.Direct3D11; using Windows.Win32.System.Com; using Windows.Win32.System.WinRT.Composition; +using Windows.Win32.UI.Shell; [Trait("WindowsOnly", "true")] public partial class COMTests @@ -70,4 +72,12 @@ public async Task CanInteropWithICompositorInterop() Assert.Skip("Skipping due to UnauthorizedAccessException."); } } + + [Fact] + public void CocreatableClassesWithImplicitInterfaces() + { + var shellLinkW = ShellLink.CreateInstance(); + var persistFile = (IPersistFile)shellLinkW; + Assert.NotNull(persistFile); + } } diff --git a/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt b/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt index a4fddae7..f977d933 100644 --- a/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt +++ b/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt @@ -34,4 +34,7 @@ WINTRUST_DATA WINTRUST_FILE_INFO WinVerifyTrust WM_HOTKEY -WNDCLASSW \ No newline at end of file +WNDCLASSW +ShellLink +CoCreateInstance +IShellLinkW \ No newline at end of file diff --git a/test/GenerationSandbox.Unmarshalled.Tests/COMTests.cs b/test/GenerationSandbox.Unmarshalled.Tests/COMTests.cs index 388e88c6..b686a2dd 100644 --- a/test/GenerationSandbox.Unmarshalled.Tests/COMTests.cs +++ b/test/GenerationSandbox.Unmarshalled.Tests/COMTests.cs @@ -1,10 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -#pragma warning disable IDE0005 +#pragma warning disable IDE0005,SA1202 +using System.Runtime.InteropServices; using Windows.Win32; using Windows.Win32.System.Com; +using Windows.Win32.UI.Shell; public class COMTests { @@ -19,5 +21,19 @@ public void COMStaticGuid() private static Guid GetGuid() where T : IComIID => T.Guid; + + [Trait("WindowsOnly", "true")] + [Fact] + public unsafe void CocreatableClassesWithImplicitInterfaces() + { + Assert.SkipUnless(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "Test calls Windows-specific APIs"); + + ShellLink.CreateInstance(out IShellLinkW* shellLinkWPtr).ThrowOnFailure(); + shellLinkWPtr->QueryInterface(typeof(IPersistFile).GUID, out void* ppv).ThrowOnFailure(); + IPersistFile* persistFilePtr = (IPersistFile*)ppv; + Assert.NotNull(persistFilePtr); + persistFilePtr->Release(); + shellLinkWPtr->Release(); + } #endif } diff --git a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt index c59c92ce..9a5ec71c 100644 --- a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt +++ b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt @@ -1,3 +1,5 @@ IEventSubscription IPersistFile IStream +ShellLink +IShellLinkW \ No newline at end of file diff --git a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs index ca8602e2..3218191b 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs @@ -487,4 +487,23 @@ public void IUnknown_Derived_QueryInterfaceGenericHelper() this.FindGeneratedMethod("QueryInterface"), m => m.Parent is StructDeclarationSyntax { Identifier.Text: "ITypeLib" } && m.TypeParameterList?.Parameters.Count == 1); } + + [Theory, PairwiseData] + public void TestGenerateCoCreateableClass(bool useIntPtrForComOutPtr) + { + this.generator = this.CreateGenerator(new GeneratorOptions { AllowMarshaling = false, ComInterop = new GeneratorOptions.ComInteropOptions { UseIntPtrForComOutPointers = useIntPtrForComOutPtr } }); + + this.GenerateApi("ShellLink"); + + var shellLinkType = Assert.Single(this.FindGeneratedType("ShellLink")); + + // Check that it does not have the ComImport attribute. + Assert.DoesNotContain(shellLinkType.AttributeLists, al => al.Attributes.Any(attr => attr.Name.ToString().Contains("ComImport"))); + + if (!useIntPtrForComOutPtr) + { + // Check that it contains a CreateInstance method + Assert.Contains(shellLinkType.DescendantNodes().OfType(), method => method.Identifier.Text == "CreateInstance"); + } + } }