Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ internal static SyntaxToken Token(SyntaxKind kind)

internal static ImplicitArrayCreationExpressionSyntax ImplicitArrayCreationExpression(InitializerExpressionSyntax initializerExpression) => SyntaxFactory.ImplicitArrayCreationExpression(Token(SyntaxKind.NewKeyword), Token(SyntaxKind.OpenBracketToken), default, Token(SyntaxKind.CloseBracketToken), initializerExpression);

internal static CollectionExpressionSyntax CollectionExpression(SeparatedSyntaxList<CollectionElementSyntax> elements = default) => SyntaxFactory.CollectionExpression(elements);

internal static ExpressionElementSyntax ExpressionElement(ExpressionSyntax expression) => SyntaxFactory.ExpressionElement(expression);

internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? declaration, ExpressionSyntax condition, SeparatedSyntaxList<ExpressionSyntax> incrementors, StatementSyntax statement)
{
SyntaxToken semicolonToken = SyntaxFactory.Token(TriviaList(), SyntaxKind.SemicolonToken, TriviaList(Space));
Expand Down Expand Up @@ -321,6 +325,8 @@ internal static SyntaxList<TNode> List<TNode>(IEnumerable<TNode> nodes)

internal static TypeConstraintSyntax TypeConstraint(TypeSyntax type) => SyntaxFactory.TypeConstraint(type);

internal static ClassOrStructConstraintSyntax ClassOrStructConstraint(SyntaxKind kind) => SyntaxFactory.ClassOrStructConstraint(kind);

internal static TypeParameterConstraintClauseSyntax TypeParameterConstraintClause(IdentifierNameSyntax name, SeparatedSyntaxList<TypeParameterConstraintSyntax> constraints) => SyntaxFactory.TypeParameterConstraintClause(TokenWithSpace(SyntaxKind.WhereKeyword), name, TokenWithSpaces(SyntaxKind.ColonToken), constraints);

internal static FieldDeclarationSyntax FieldDeclaration(VariableDeclarationSyntax declaration) => SyntaxFactory.FieldDeclaration(default, default, declaration, Semicolon);
Expand Down
121 changes: 115 additions & 6 deletions src/Microsoft.Windows.CsWin32/Generator.Com.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ private static bool GenerateCcwFor(MetadataReader reader, StringHandle typeName,
return true;
}

private static StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName),
ArgumentList()));

/// <summary>
/// Generates a type to represent a COM interface.
/// </summary>
Expand Down Expand Up @@ -327,10 +331,6 @@ FunctionPointerParameterSyntax ToFunctionPointerParameter(ParameterSyntax p)
if (methodDefinition.Generator.TryGetPropertyAccessorInfo(methodDefinition, originalIfaceName, context, out IdentifierNameSyntax? propertyName, out SyntaxKind? accessorKind, out TypeSyntax? propertyType) &&
declaredProperties.Contains(propertyName.Identifier.ValueText))
{
StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, hrExpression, HRThrowOnFailureMethodName),
ArgumentList()));

BlockSyntax? body;
switch (accessorKind)
{
Expand Down Expand Up @@ -1307,15 +1307,124 @@ private bool TryDeclareCOMGuidInterfaceIfNecessary()
/// Creates an empty class that when instantiated, creates a cocreatable Windows object
/// that may implement a number of interfaces at runtime, discoverable only by documentation.
/// </summary>
private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef)
private ClassDeclarationSyntax DeclareCocreatableClass(TypeDefinition typeDef, Context context)
{
bool canUseComImport = context.AllowMarshaling && !this.useSourceGenerators;

IdentifierNameSyntax name = IdentifierName(this.Reader.GetString(typeDef.Name));
Guid guid = this.FindGuidFromAttribute(typeDef) ?? throw new ArgumentException("Type does not have a GuidAttribute.");
SyntaxTokenList classModifiers = TokenList(TokenWithSpace(this.Visibility));
classModifiers = classModifiers.Add(TokenWithSpace(SyntaxKind.PartialKeyword));
ClassDeclarationSyntax result = ClassDeclaration(name.Identifier)
.WithModifiers(classModifiers)
.AddAttributeLists(AttributeList().AddAttributes(GUID(guid), ComImportAttributeSyntax));
.AddAttributeLists(AttributeList().AddAttributes(GUID(guid)).AddAttributes(canUseComImport ? [ComImportAttributeSyntax] : []));

if (!canUseComImport && !this.Options.ComInterop.UseIntPtrForComOutPointers)
{
string obsoleteMessage = context.AllowMarshaling
? $"COM source generators do not support direct instantiation of co-creatable classes. Use {name.Identifier}.CreateInstance<T> instead."
: $"Marshaling is disabled, so direct instantiation of co-creatable classes is not supported. Use {name.Identifier}.CreateInstance<T> instead.";

// Generate a private readonly field for the Guid
// private static readonly Guid CLSID_Foo = new Guid(...);
SyntaxToken clsidFieldName = Identifier($"CLSID_{name.Identifier}");
FieldDeclarationSyntax clsidField = FieldDeclaration(
VariableDeclaration(IdentifierName(nameof(Guid)))
.AddVariables(VariableDeclarator(clsidFieldName).WithInitializer(EqualsValueClause(GuidValue(guid)))))
.AddModifiers(TokenWithSpace(SyntaxKind.PrivateKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.ReadOnlyKeyword));
result = result.AddMembers(clsidField);

// If using source generators or marshalling is disabled, generate a constructor with obsolete attribute like this:
// [Obsolete("COM source generators do not support direct instantiation of co-creatable classes. Use CreateInstance<T> method instead.")]
// public Foo() { throw new NotSupportedException("COM source generators do not support direct instantiation of co-creatable classes. Use CreateInstance<T> method instead."); }
AttributeSyntax obsoleteAttribute =
Attribute(IdentifierName(nameof(ObsoleteAttribute)))
.AddArgumentListArguments(
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(obsoleteMessage))));
ConstructorDeclarationSyntax constructor = ConstructorDeclaration(name.Identifier)
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword))
.AddAttributeLists(AttributeList().AddAttributes(obsoleteAttribute))
.WithBody(
Block(
ThrowStatement(
ObjectCreationExpression(IdentifierName(nameof(NotSupportedException)))
.WithArgumentList(
ArgumentList().AddArguments(
Argument(
LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(obsoleteMessage))))))));
result = result.AddMembers(constructor);

this.MainGenerator.TryGenerateExternMethod("CoCreateInstance", out IReadOnlyCollection<string> preciseApi);
this.MainGenerator.TryGenerateConstant("CLSCTX", out preciseApi);

if (context.AllowMarshaling)
{
// Then add the CreateInstance<T> method:
// public static T CreateInstance<T>() where T : class
// {
// PInvoke.CoCreateInstance<T>(CLSID_Foo, null, CLSCTX.CLSCTX_SERVER, out T ret).ThrowOnFailure();
// return ret;
// }
TypeParameterSyntax typeParameter = TypeParameter(Identifier("T"));
GenericNameSyntax genericName = GenericName("CreateInstance").AddTypeArgumentListArguments(IdentifierName("T"));
MethodDeclarationSyntax createInstanceMethod = MethodDeclaration(IdentifierName("T"), genericName.Identifier)
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword))
.AddTypeParameterListParameters(typeParameter)
.AddConstraintClauses(
TypeParameterConstraintClause(IdentifierName("T"), SingletonSeparatedList<TypeParameterConstraintSyntax>(ClassOrStructConstraint(SyntaxKind.ClassConstraint))))
.WithBody(
Block(
ThrowOnHRFailure(
InvocationExpression(QualifiedName(ParseName($"{this.Win32NamespacePrefix}.{this.options.ClassName}"), GenericName("CoCreateInstance").AddTypeArgumentListArguments(IdentifierName("T"))))
.WithArgumentList(
ArgumentList().AddArguments(
Argument(IdentifierName(clsidFieldName)),
Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)),
Argument(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
QualifiedName(ParseName($"{this.Win32NamespacePrefix}.System.Com"), IdentifierName("CLSCTX")),
IdentifierName("CLSCTX_SERVER"))),
Argument(DeclarationExpression(IdentifierName("T").WithTrailingTrivia(Space), SingleVariableDesignation(Identifier("ret")))).WithRefKindKeyword(Token(SyntaxKind.OutKeyword))))),
ReturnStatement(IdentifierName("ret"))));
result = result.AddMembers(createInstanceMethod);
}
else
{
// Then add a CreateInstance<T> method that looks like this:
// public static HRESULT CreateInstance<T>(out T* instance) where T : unmanaged
// {
// return PInvoke.CoCreateInstance<T>(CLSID_Foo, null, CLSCTX.CLSCTX_SERVER, out instance);
// }
TypeParameterSyntax typeParameter = TypeParameter(Identifier("T"));
GenericNameSyntax genericName = GenericName("CreateInstance").AddTypeArgumentListArguments(IdentifierName("T"));
MethodDeclarationSyntax createInstanceMethod = MethodDeclaration(IdentifierName($"{this.Win32NamespacePrefix}.Foundation.HRESULT"), genericName.Identifier)
.AddModifiers(TokenWithSpace(SyntaxKind.PublicKeyword), TokenWithSpace(SyntaxKind.StaticKeyword), TokenWithSpace(SyntaxKind.UnsafeKeyword))
.AddTypeParameterListParameters(typeParameter)
.AddConstraintClauses(
TypeParameterConstraintClause(IdentifierName("T"), SingletonSeparatedList<TypeParameterConstraintSyntax>(TypeConstraint(IdentifierName("unmanaged")))))
.WithParameterList(
ParameterList().AddParameters(
Parameter(Identifier("instance"))
.WithType(PointerType(IdentifierName("T")))
.WithModifiers(TokenList(Token(SyntaxKind.OutKeyword)))))
.WithBody(
Block(
ReturnStatement(
InvocationExpression(QualifiedName(ParseName($"{this.Win32NamespacePrefix}.{this.options.ClassName}"), GenericName("CoCreateInstance").AddTypeArgumentListArguments(IdentifierName("T"))))
.WithArgumentList(
ArgumentList().AddArguments(
Argument(IdentifierName(clsidFieldName)),
Argument(LiteralExpression(SyntaxKind.NullLiteralExpression)),
Argument(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
QualifiedName(ParseName($"{this.Win32NamespacePrefix}.System.Com"), IdentifierName("CLSCTX")),
IdentifierName("CLSCTX_SERVER"))),
Argument(IdentifierName("instance")).WithRefKindKeyword(Token(SyntaxKind.OutKeyword)))))));
result = result.AddMembers(createInstanceMethod);
}
}

result = this.AddApiDocumentation(name.Identifier.ValueText, result);
return result;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,7 @@ private IReadOnlyList<ISymbol> FindTypeSymbolsIfAlreadyAvailable(string fullyQua
}
else if (this.IsEmptyStructWithGuid(typeDef))
{
typeDeclaration = this.DeclareCocreatableClass(typeDef);
typeDeclaration = this.DeclareCocreatableClass(typeDef, context);
}
else
{
Expand Down
28 changes: 28 additions & 0 deletions src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,34 @@ internal static ObjectCreationExpressionSyntax GuidValue(CustomAttribute guidAtt
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(k), k))));
}

internal static ObjectCreationExpressionSyntax GuidValue(Guid guid)
{
byte[] bytes = guid.ToByteArray();
uint a = BitConverter.ToUInt32(bytes, 0);
ushort b = BitConverter.ToUInt16(bytes, 4);
ushort c = BitConverter.ToUInt16(bytes, 6);
byte d = bytes[8];
byte e = bytes[9];
byte f = bytes[10];
byte g = bytes[11];
byte h = bytes[12];
byte i = bytes[13];
byte j = bytes[14];
byte k = bytes[15];
return ObjectCreationExpression(GuidTypeSyntax).AddArgumentListArguments(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(a), a))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(b), b))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(c), c))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(d), d))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(e), e))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(f), f))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(g), g))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(h), h))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(i), i))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(j), j))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ToHex(k), k))));
}

internal static ExpressionSyntax IntPtrExpr(IntPtr value) => ObjectCreationExpression(IntPtrTypeSyntax).AddArgumentListArguments(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value.ToInt64()))));

Expand Down
15 changes: 15 additions & 0 deletions test/CsWin32Generator.Tests/CsWin32GeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ public async Task TestPlatformCaseSensitivity(string platform)
await this.InvokeGeneratorAndCompile($"{nameof(this.TestPlatformCaseSensitivity)}_{platform}");
}

[Fact]
public async Task TestGenerateCoCreateableClass()
{
this.nativeMethods.Add("ShellLink");
await this.InvokeGeneratorAndCompileFromFact();

var shellLinkType = Assert.Single(this.FindGeneratedType("ShellLink"));

// Check that it does not have the ComImport attribute.
Assert.DoesNotContain(shellLinkType.AttributeLists, al => al.Attributes.Any(attr => attr.Name.ToString().Contains("ComImport")));

// Check that it contains a CreateInstance method
Assert.Contains(shellLinkType.DescendantNodes().OfType<MethodDeclarationSyntax>(), method => method.Identifier.Text == "CreateInstance");
}

[Theory]
[InlineData("IMFMediaKeySession", "get_KeySystem", "winmdroot.Foundation.BSTR* keySystem")]
[InlineData("AddPrinterW", "AddPrinter", "winmdroot.Foundation.PWSTR pName, uint Level, Span<byte> pPrinter")]
Expand Down
10 changes: 10 additions & 0 deletions test/GenerationSandbox.BuildTask.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma warning disable IDE0005
#pragma warning disable SA1201, SA1512, SA1005, SA1507, SA1515, SA1403, SA1402, SA1411, SA1300, SA1313, SA1134, SA1307, SA1308

using System.ComponentModel;
using System.Net.NetworkInformation;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
Expand All @@ -17,6 +18,7 @@
using Windows.Win32.Graphics.Direct3D11;
using Windows.Win32.System.Com;
using Windows.Win32.System.WinRT.Composition;
using Windows.Win32.UI.Shell;

[Trait("WindowsOnly", "true")]
public partial class COMTests
Expand Down Expand Up @@ -70,4 +72,12 @@ public async Task CanInteropWithICompositorInterop()
Assert.Skip("Skipping due to UnauthorizedAccessException.");
}
}

[Fact]
public void CocreatableClassesWithImplicitInterfaces()
{
var shellLinkW = ShellLink.CreateInstance<IShellLinkW>();
var persistFile = (IPersistFile)shellLinkW;
Assert.NotNull(persistFile);
}
}
5 changes: 4 additions & 1 deletion test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ WINTRUST_DATA
WINTRUST_FILE_INFO
WinVerifyTrust
WM_HOTKEY
WNDCLASSW
WNDCLASSW
ShellLink
CoCreateInstance
IShellLinkW
18 changes: 17 additions & 1 deletion test/GenerationSandbox.Unmarshalled.Tests/COMTests.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

#pragma warning disable IDE0005
#pragma warning disable IDE0005,SA1202

using System.Runtime.InteropServices;
using Windows.Win32;
using Windows.Win32.System.Com;
using Windows.Win32.UI.Shell;

public class COMTests
{
Expand All @@ -19,5 +21,19 @@ public void COMStaticGuid()
private static Guid GetGuid<T>()
where T : IComIID
=> T.Guid;

[Trait("WindowsOnly", "true")]
[Fact]
public unsafe void CocreatableClassesWithImplicitInterfaces()
{
Assert.SkipUnless(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "Test calls Windows-specific APIs");

ShellLink.CreateInstance(out IShellLinkW* shellLinkWPtr).ThrowOnFailure();
shellLinkWPtr->QueryInterface(typeof(IPersistFile).GUID, out void* ppv).ThrowOnFailure();
IPersistFile* persistFilePtr = (IPersistFile*)ppv;
Assert.NotNull(persistFilePtr);
persistFilePtr->Release();
shellLinkWPtr->Release();
}
#endif
}
2 changes: 2 additions & 0 deletions test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
IEventSubscription
IPersistFile
IStream
ShellLink
IShellLinkW
19 changes: 19 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -487,4 +487,23 @@ public void IUnknown_Derived_QueryInterfaceGenericHelper()
this.FindGeneratedMethod("QueryInterface"),
m => m.Parent is StructDeclarationSyntax { Identifier.Text: "ITypeLib" } && m.TypeParameterList?.Parameters.Count == 1);
}

[Theory, PairwiseData]
public void TestGenerateCoCreateableClass(bool useIntPtrForComOutPtr)
{
this.generator = this.CreateGenerator(new GeneratorOptions { AllowMarshaling = false, ComInterop = new GeneratorOptions.ComInteropOptions { UseIntPtrForComOutPointers = useIntPtrForComOutPtr } });

this.GenerateApi("ShellLink");

var shellLinkType = Assert.Single(this.FindGeneratedType("ShellLink"));

// Check that it does not have the ComImport attribute.
Assert.DoesNotContain(shellLinkType.AttributeLists, al => al.Attributes.Any(attr => attr.Name.ToString().Contains("ComImport")));

if (!useIntPtrForComOutPtr)
{
// Check that it contains a CreateInstance method
Assert.Contains(shellLinkType.DescendantNodes().OfType<MethodDeclarationSyntax>(), method => method.Identifier.Text == "CreateInstance");
}
}
}
Loading