Skip to content

Commit

Permalink
Generate implements extension for any interface implemented by Object… (
Browse files Browse the repository at this point in the history
#1749)

…Type classes
  • Loading branch information
pekkah authored Feb 23, 2024
1 parent 4a804f3 commit e3bd658
Show file tree
Hide file tree
Showing 46 changed files with 780 additions and 36 deletions.
47 changes: 44 additions & 3 deletions samples/GraphQL.Samples.SG.Subscription/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
options.AddGeneratedTypes(types =>
{
// Add generated controllers
types
.AddQueryController()
.AddSubscriptionController();
types.AddGlobalTypes();
});
});

Expand Down Expand Up @@ -45,11 +43,54 @@ public static async IAsyncEnumerable<int> Random(int from, int to, int count, [E
await Task.Delay(500, cancellationToken);
}
}

/// <summary>
/// This is subscription field producing values implementing interface
/// </summary>
/// <returns></returns>
public static async IAsyncEnumerable<IValue> RandomValues(int count, [EnumeratorCancellation] CancellationToken cancellationToken)
{
var r = new Random();

for (var i = 0; i < count; i++)
{
cancellationToken.ThrowIfCancellationRequested();
var next = r.Next(0, 2);
if (next == 0)
yield return new IntValue { Value = next };
else
yield return new StringValue { Value = next.ToString("F") };

await Task.Delay(500, cancellationToken);
}
}
}

[ObjectType]
public static partial class Query
{
// this is required as the graphiql will error without a query field
public static string Hello() => "Hello World!";
}

[InterfaceType]
public partial interface IValue
{
public string Hello { get; }
}

[ObjectType]
public partial class IntValue : IValue
{
public required int Value { get; init; }

public string Hello => GetType().Name;
}

[ObjectType]
public partial class StringValue : IValue
{
public required string Value { get; init; }

public string Hello => GetType().Name;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"Logging": {
"LogLevel": {
"Default": "Information",
"Microsoft.AspNetCore": "Warning"
"Microsoft.AspNetCore": "Warning",
"Tanka": "Debug"
}
}
}
11 changes: 11 additions & 0 deletions src/GraphQL.Server.SourceGenerators/BaseDefinition.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System.Collections.Generic;

namespace Tanka.GraphQL.Server.SourceGenerators;

public record BaseDefinition(
bool IsClass,
string Identifier,
string Namespace,
string? GraphQLName,
IReadOnlyList<ObjectPropertyDefinition> Properties,
IReadOnlyList<ObjectMethodDefinition> Methods);
16 changes: 16 additions & 0 deletions src/GraphQL.Server.SourceGenerators/NamedTypeExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,20 @@ public static string GetName(SemanticModel semanticModel, TypeDeclarationSyntax

return graphQLName;
}

public static string GetName(INamedTypeSymbol namedType)
{
var graphQLNameAttribute = namedType.GetAttributes()
.FirstOrDefault(attribute => attribute.AttributeClass?.Name == "GraphQLNameAttribute");

if (graphQLNameAttribute != null)
{
var graphQLName = (string?)graphQLNameAttribute.ConstructorArguments[0].Value ?? string.Empty;

if (!string.IsNullOrWhiteSpace(graphQLName))
return graphQLName;
}

return namedType.Name;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@ public class ObjectControllerDefinition: TypeDefinition, IEquatable<ObjectContro
public List<ObjectPropertyDefinition> Properties { get; init; } = [];

public List<ObjectMethodDefinition> Methods { get; init; } = [];

public IEnumerable<ObjectMethodDefinition> Subscribers => Methods.Where(m => m.IsSubscription);

public IReadOnlyList<ObjectMethodDefinition> AllMethods =>
Methods.Concat(Implements.SelectMany(i => i.Methods)).ToList();

public IEnumerable<ObjectMethodDefinition> AllSubscribers => AllMethods.Where(m => m.IsSubscription);

public ParentClass? ParentClass { get; init; }

public bool IsStatic { get; init; }

public IReadOnlyList<string> Usings { get; init; } = [];

public IReadOnlyList<BaseDefinition> Implements { get; set; } = [];

public bool Equals(ObjectControllerDefinition? other)
{
if (ReferenceEquals(null, other)) return false;
Expand All @@ -27,7 +32,8 @@ public bool Equals(ObjectControllerDefinition? other)
&& Methods.SequenceEqual(other.Methods)
&& ParentClass?.Equals(other.ParentClass) == true
&& IsStatic == other.IsStatic
&& Usings.SequenceEqual(other.Usings);
&& Usings.SequenceEqual(other.Usings)
&& Implements.SequenceEqual(other.Implements);
}

public override bool Equals(object? obj)
Expand Down
1 change: 1 addition & 0 deletions src/GraphQL.Server.SourceGenerators/ObjectTypeEmitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public static void Emit(
Methods = definition.Methods,
Properties = definition.Properties,
Usings = definition.Usings,
Implements = definition.Implements,
NamedTypeExtension = NamedTypeExtension.Render(
"class",
definition.TargetType,
Expand Down
8 changes: 7 additions & 1 deletion src/GraphQL.Server.SourceGenerators/ObjectTypeParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@ public static ObjectControllerDefinition ParseObjectControllerDefinition(
Properties = properties,
Methods = methods,
ParentClass = TypeHelper.GetParentClasses(classDeclaration),
Usings = TypeHelper.GetUsings(classDeclaration)
Usings = TypeHelper.GetUsings(classDeclaration),
Implements = GetImplements(context.SemanticModel, classDeclaration),
};
}

private static IReadOnlyList<BaseDefinition> GetImplements(SemanticModel model, ClassDeclarationSyntax classDeclaration)
{
return SymbolHelper.GetImplements(model.GetDeclaredSymbol(classDeclaration) ?? throw new InvalidOperationException());
}

private static (List<ObjectPropertyDefinition> Properties, List<ObjectMethodDefinition> Methods) ParseMembers(
GeneratorAttributeSyntaxContext context,
ClassDeclarationSyntax classDeclaration)
Expand Down
178 changes: 178 additions & 0 deletions src/GraphQL.Server.SourceGenerators/SymbolHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
using System.Collections.Generic;
using System.Linq;

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

using Tanka.GraphQL.Server.SourceGenerators.Internal;

namespace Tanka.GraphQL.Server.SourceGenerators;

public static class SymbolHelper
{
public static (List<ObjectPropertyDefinition> Properties, List<ObjectMethodDefinition> Methods) ParseMembers(
INamedTypeSymbol classSymbol)
{
var properties = new List<ObjectPropertyDefinition>();
var methods = new List<ObjectMethodDefinition>();

foreach (ISymbol memberSymbol in classSymbol
.GetMembers()
.Where(m => m.DeclaredAccessibility == Accessibility.Public))
{
if (memberSymbol is IPropertySymbol property)
{
var propertyDefinition = new ObjectPropertyDefinition()
{
IsStatic = property.IsStatic,
Name = property.Name,
ReturnType = property.Type.ToString(),
ClosestMatchingGraphQLTypeName = GetClosestMatchingGraphQLTypeName(property.Type),
};
properties.Add(propertyDefinition);
}
else if (memberSymbol is IMethodSymbol { MethodKind: not (MethodKind.PropertyGet or MethodKind.PropertySet) } method)
{
var methodDefinition = new ObjectMethodDefinition()
{
IsStatic = method.IsStatic,
Name = method.Name,
ReturnType = UnwrapTaskType(method.ReturnType).ToString(),
ClosestMatchingGraphQLTypeName = GetClosestMatchingGraphQLTypeName(method.ReturnType),
Type = GetMethodType(method),
Parameters = method.Parameters
.Select(p => new ParameterDefinition()
{
Name = p.Name,
Type = p.Type.ToString(),
ClosestMatchingGraphQLTypeName = GetClosestMatchingGraphQLTypeName(UnwrapTaskType(p.Type)),
IsNullable = p.NullableAnnotation == NullableAnnotation.Annotated,
IsPrimitive = IsPrimitiveType(p.Type!),
FromArguments = HasAttribute(p, "FromArgumentsAttribute"),
FromServices = HasAttribute(p, "FromServicesAttribute")
}).ToEquatableArray()
};
methods.Add(methodDefinition);
}
}

return (properties, methods);
}

public static bool HasAttribute(ISymbol symbol, string attributeName)
{
return symbol.GetAttributes()
.Any(attr => attr.AttributeClass?.Name == attributeName
&& attr.AttributeClass.ContainingNamespace.ToDisplayString().StartsWith("Tanka"));
}


public static bool IsPrimitiveType(ITypeSymbol typeSymbol)
{
// Define a list of C# primitive types
SpecialType[] primitiveTypes = new SpecialType[]
{
SpecialType.System_Boolean,
SpecialType.System_Byte,
SpecialType.System_SByte,
SpecialType.System_Char,
SpecialType.System_Decimal,
SpecialType.System_Double,
SpecialType.System_Single, // float
SpecialType.System_Int32,
SpecialType.System_UInt32,
SpecialType.System_Int64,
SpecialType.System_UInt64,
SpecialType.System_Int16,
SpecialType.System_UInt16,
SpecialType.System_String
};

// Check if the type symbol's special type is a primitive type
return primitiveTypes.Contains(typeSymbol.SpecialType);
}


private static ITypeSymbol UnwrapTaskType(ITypeSymbol possibleTaskType)
{
if (possibleTaskType is INamedTypeSymbol namedType)
{
if (namedType.ConstructedFrom?.ToString() == "System.Threading.Tasks.Task`1" ||
namedType.ConstructedFrom?.ToString() == "System.Threading.Tasks.ValueTask`1" ||
namedType.ConstructedFrom?.ToString() == "System.Collections.Generic.IAsyncEnumerable`1")
{
return namedType.TypeArguments[0];
}
}

return possibleTaskType;
}

private static string GetClosestMatchingGraphQLTypeName(ITypeSymbol type)
{
return TypeHelper.GetGraphQLTypeName(UnwrapTaskType(type));
}

private static MethodType GetMethodType(IMethodSymbol method)
{
var returnType = method.ReturnType;

if (returnType.SpecialType == SpecialType.System_Void)
return MethodType.Void;

if (returnType is INamedTypeSymbol namedType)
{
switch (namedType.ConstructedFrom?.ToString())
{
case "System.Threading.Tasks.Task":
return namedType.TypeArguments.Length == 0 ? MethodType.Task : MethodType.TaskOfT;
case "System.Threading.Tasks.ValueTask":
return namedType.TypeArguments.Length == 0 ? MethodType.ValueTask : MethodType.ValueTaskOfT;
case "System.Collections.Generic.IAsyncEnumerable`1":
return MethodType.AsyncEnumerableOfT;
case "System.Collections.Generic.IEnumerable`1":
return MethodType.EnumerableT;
default:
return MethodType.T;
}
}

return MethodType.Unknown;
}

public static IReadOnlyList<BaseDefinition> GetImplements(INamedTypeSymbol namedTypeSymbol)
{
var baseDefinitions = new List<BaseDefinition>();

/*if (namedTypeSymbol.BaseType != null && namedTypeSymbol.BaseType is not { SpecialType: SpecialType.System_Object })
{
baseDefinitions.Add(
GetBaseDefinition(namedTypeSymbol.BaseType)
);
}*/

baseDefinitions.AddRange(
namedTypeSymbol
.Interfaces
.Where(i => HasAttribute(i, "InterfaceTypeAttribute"))
.Select(GetBaseDefinition)
);

return baseDefinitions;
}

public static BaseDefinition GetBaseDefinition(INamedTypeSymbol baseNamedTypeSymbol)
{
var (properties, methods) = SymbolHelper.ParseMembers(baseNamedTypeSymbol);

return new BaseDefinition(
baseNamedTypeSymbol.TypeKind == TypeKind.Class,
baseNamedTypeSymbol.Name,
baseNamedTypeSymbol.ContainingNamespace.ToDisplayString(),
NamedTypeExtension.GetName(baseNamedTypeSymbol),
properties,
methods
);
}

}
20 changes: 19 additions & 1 deletion src/GraphQL.Server.SourceGenerators/Templates/ObjectTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,14 @@ public static class {{name}}ControllerExtensions
{{~ end ~}}
{{~ end ~}}
));
{{~ if implements.size > 0 ~}}
builder.Builder.Configure(options => options.Builder.Add(
"""
extend type {{name}} implements {{ for base in implements ~}}{{base.graph_qlname}}{{~ if for.last ~}}{{~ else }} & {{ end ~}}{{~ end }}
"""));
{{~ end ~}}
return builder;
}
}
Expand Down Expand Up @@ -211,13 +218,24 @@ public required IEnumerable<string> Usings

public required string? TypeName { get; set; }

public IReadOnlyList<ObjectPropertyDefinition> AllProperties =>
Properties.Concat(Implements.SelectMany(i => i.Properties)).ToList();

public required IEnumerable<ObjectPropertyDefinition> Properties { get; set; } = [];

public IReadOnlyList<ObjectMethodDefinition> AllMethods =>
Methods.Concat(Implements.SelectMany(i => i.Methods))
.OrderBy(m => m.Name)
.ToList();

public required IEnumerable<ObjectMethodDefinition> Methods { get; set; } = [];


public IEnumerable<ObjectMethodDefinition> Subscribers => Methods.Where(m => m.IsSubscription);

public required string NamedTypeExtension { get; set; }

public IReadOnlyList<BaseDefinition> Implements { get; set; } = [];

public string Render()
{
Expand Down
Loading

0 comments on commit e3bd658

Please sign in to comment.