Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -371,21 +371,21 @@ public static ResolverParameterKind GetParameterKind(
return ResolverParameterKind.HttpResponse;
}

if (parameter.IsGlobalState(out key))
if (parameter.IsGlobalState(compilation, out key))
{
return parameter.IsSetState()
? ResolverParameterKind.SetGlobalState
: ResolverParameterKind.GetGlobalState;
}

if (parameter.IsScopedState(out key))
if (parameter.IsScopedState(compilation, out key))
{
return parameter.IsSetState()
? ResolverParameterKind.SetScopedState
: ResolverParameterKind.GetScopedState;
}

if (parameter.IsLocalState(out key))
if (parameter.IsLocalState(compilation, out key))
{
return parameter.IsSetState()
? ResolverParameterKind.SetLocalState
Expand Down
219 changes: 152 additions & 67 deletions src/HotChocolate/Core/src/Types.Analyzers/Helpers/SymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -789,48 +789,45 @@ public static bool IsSelection(this IParameterSymbol parameter)

public static bool IsGlobalState(
this IParameterSymbol parameter,
Compilation compilation,
[NotNullWhen(true)] out string? key)
{
key = null;
=> parameter.TryGetStateKey("HotChocolate.GlobalStateAttribute", compilation, out key);

public static bool IsScopedState(
this IParameterSymbol parameter,
Compilation compilation,
[NotNullWhen(true)] out string? key)
=> parameter.TryGetStateKey("HotChocolate.ScopedStateAttribute", compilation, out key);

public static bool IsLocalState(
this IParameterSymbol parameter,
Compilation compilation,
[NotNullWhen(true)] out string? key)
=> parameter.TryGetStateKey("HotChocolate.LocalStateAttribute", compilation, out key);

public static bool IsEventMessage(
this IParameterSymbol parameter)
{
foreach (var attributeData in parameter.GetAttributes())
{
if (IsOrInheritsFrom(attributeData.AttributeClass, "HotChocolate.GlobalStateAttribute"))
if (attributeData.AttributeClass?.ToDisplayString() == WellKnownAttributes.EventMessageAttribute)
{
if (attributeData.ConstructorArguments.Length == 1
&& attributeData.ConstructorArguments[0].Kind == TypedConstantKind.Primitive
&& attributeData.ConstructorArguments[0].Value is string keyValue)
{
key = keyValue;
return true;
}

foreach (var namedArg in attributeData.NamedArguments)
{
if (namedArg is { Key: "Key", Value.Value: string namedKeyValue })
{
key = namedKeyValue;
return true;
}
}

key = parameter.Name;
return true;
}
}

return false;
}

public static bool IsScopedState(
public static bool IsService(
this IParameterSymbol parameter,
[NotNullWhen(true)] out string? key)
out string? key)
{
key = null;

foreach (var attributeData in parameter.GetAttributes())
{
if (IsOrInheritsFrom(attributeData.AttributeClass, "HotChocolate.ScopedStateAttribute"))
if (attributeData.AttributeClass?.ToDisplayString() == WellKnownAttributes.ServiceAttribute)
{
if (attributeData.ConstructorArguments.Length == 1
&& attributeData.ConstructorArguments[0].Kind == TypedConstantKind.Primitive
Expand All @@ -849,98 +846,186 @@ public static bool IsScopedState(
}
}

key = parameter.Name;
key = null;
return true;
}
}

return false;
}

public static bool IsLocalState(
private static bool TryGetStateKey(
this IParameterSymbol parameter,
string stateAttributeType,
Compilation compilation,
[NotNullWhen(true)] out string? key)
{
key = null;

foreach (var attributeData in parameter.GetAttributes())
{
if (IsOrInheritsFrom(attributeData.AttributeClass, "HotChocolate.LocalStateAttribute"))
if (!IsOrInheritsFrom(attributeData.AttributeClass, stateAttributeType))
{
if (attributeData.ConstructorArguments.Length == 1
&& attributeData.ConstructorArguments[0].Kind == TypedConstantKind.Primitive
&& attributeData.ConstructorArguments[0].Value is string keyValue)
{
key = keyValue;
return true;
}

foreach (var namedArg in attributeData.NamedArguments)
{
if (namedArg is { Key: "Key", Value.Value: string namedKeyValue })
{
key = namedKeyValue;
return true;
}
}
continue;
}

key = parameter.Name;
if (TryGetStateKeyFromAttributeUsage(attributeData, out key)
|| TryGetStateKeyFromAttributeDeclaration(attributeData, compilation, out key))
{
return true;
}

key = parameter.Name;
return true;
}

return false;
}

public static bool IsEventMessage(
this IParameterSymbol parameter)
private static bool TryGetStateKeyFromAttributeUsage(
AttributeData attributeData,
[NotNullWhen(true)] out string? key)
{
foreach (var attributeData in parameter.GetAttributes())
key = null;

if (attributeData.ConstructorArguments.Length == 1
&& attributeData.ConstructorArguments[0].Kind == TypedConstantKind.Primitive
&& attributeData.ConstructorArguments[0].Value is string keyValue)
{
if (attributeData.AttributeClass?.ToDisplayString() == WellKnownAttributes.EventMessageAttribute)
key = keyValue;
return true;
}

foreach (var namedArg in attributeData.NamedArguments)
{
if (namedArg is { Key: "Key", Value.Value: string namedKeyValue })
{
key = namedKeyValue;
return true;
}
}

return false;
}

public static bool IsService(
this IParameterSymbol parameter,
out string? key)
private static bool TryGetStateKeyFromAttributeDeclaration(
AttributeData attributeData,
Compilation compilation,
[NotNullWhen(true)] out string? key)
{
key = null;

foreach (var attributeData in parameter.GetAttributes())
var constructor = attributeData.AttributeConstructor;
if (constructor is not null
&& TryGetStateKeyFromConstructorDeclaration(constructor, compilation, out key))
{
if (attributeData.AttributeClass?.ToDisplayString() == WellKnownAttributes.ServiceAttribute)
return true;
}

var attributeType = attributeData.AttributeClass;
if (attributeType is null)
{
return false;
}

foreach (var syntaxReference in attributeType.DeclaringSyntaxReferences)
{
if (syntaxReference.GetSyntax() is TypeDeclarationSyntax declaration
&& TryGetStateKeyFromPrimaryConstructorBaseCall(declaration, compilation, out key))
{
if (attributeData.ConstructorArguments.Length == 1
&& attributeData.ConstructorArguments[0].Kind == TypedConstantKind.Primitive
&& attributeData.ConstructorArguments[0].Value is string keyValue)
{
key = keyValue;
return true;
}
return true;
}
}

foreach (var namedArg in attributeData.NamedArguments)
return false;
}

private static bool TryGetStateKeyFromConstructorDeclaration(
IMethodSymbol constructor,
Compilation compilation,
[NotNullWhen(true)] out string? key)
{
key = null;

foreach (var syntaxReference in constructor.DeclaringSyntaxReferences)
{
if (syntaxReference.GetSyntax() is ConstructorDeclarationSyntax declaration
&& declaration.Initializer is
{
if (namedArg is { Key: "Key", Value.Value: string namedKeyValue })
{
key = namedKeyValue;
return true;
}
RawKind: (int)SyntaxKind.BaseConstructorInitializer,
ArgumentList: { } argumentList
}
&& TryGetConstantStringFromArgumentList(
argumentList,
declaration.SyntaxTree,
compilation,
out key))
{
return true;
}
}

key = null;
return false;
}

private static bool TryGetStateKeyFromPrimaryConstructorBaseCall(
TypeDeclarationSyntax declaration,
Compilation compilation,
[NotNullWhen(true)] out string? key)
{
key = null;

if (declaration.BaseList is null)
{
return false;
}

foreach (var baseType in declaration.BaseList.Types)
{
var argumentList = baseType.ChildNodes().OfType<ArgumentListSyntax>().FirstOrDefault();
if (argumentList is null)
{
continue;
}

if (TryGetConstantStringFromArgumentList(
argumentList,
declaration.SyntaxTree,
compilation,
out key))
{
return true;
}
}

return false;
}

private static bool TryGetConstantStringFromArgumentList(
ArgumentListSyntax argumentList,
SyntaxTree syntaxTree,
Compilation compilation,
[NotNullWhen(true)] out string? key)
{
key = null;

if (argumentList.Arguments.Count != 1)
{
return false;
}

var model = compilation.GetSemanticModel(syntaxTree);
var constantValue = model.GetConstantValue(argumentList.Arguments[0].Expression);

if (constantValue is { HasValue: true, Value: string keyValue })
{
key = keyValue;
return true;
}

return false;
}

public static bool IsArgument(
this IParameterSymbol parameter,
out string? key)
Expand Down
66 changes: 66 additions & 0 deletions src/HotChocolate/Core/test/Types.Analyzers.Tests/ResolverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,72 @@ internal class Test;
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_ResolverWithLocalStateDerivedAttribute_PrimaryConstructorBaseKey_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System;
using HotChocolate;
using HotChocolate.Types;

namespace TestNamespace;

[ObjectType<Test>]
internal static partial class TestType
{
public static string GetTest([ScopeState] string scope)
{
return scope;
}
}

[AttributeUsage(AttributeTargets.Parameter)]
public sealed class ScopeStateAttribute()
: LocalStateAttribute(LookupKey)
{
public const string LookupKey = "ScopeState";
}

internal class Test;
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_ResolverWithScopedStateDerivedAttribute_BaseConstructorKey_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System;
using HotChocolate;
using HotChocolate.Types;

namespace TestNamespace;

[ObjectType<Test>]
internal static partial class TestType
{
public static string GetTest([ScopeState] string scope)
{
return scope;
}
}

[AttributeUsage(AttributeTargets.Parameter)]
public sealed class ScopeStateAttribute : ScopedStateAttribute
{
public const string LookupKey = "ScopeState";

public ScopeStateAttribute()
: base(LookupKey)
{
}
}

internal class Test;
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_ResolverWithLocalStateSetStateArgument_MatchesSnapshot()
{
Expand Down
Loading
Loading