Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1741,11 +1741,9 @@ private void WriteAssignTypeRef(
{
case SchemaTypeReferenceKind.ExtendedTypeReference:
Writer.WriteIndentedLine(
"{0} = typeInspector.GetTypeRef(typeof({1}), {2}.{3}){4}",
"{0} = {1}{2}",
propertyName,
typeReference.TypeString,
WellKnownTypes.TypeContext,
context,
CreateTypeDefinitionReferenceExpression(typeReference, context),
lineEnd);
break;

Expand All @@ -1768,10 +1766,8 @@ private void WriteAssignTypeRef(
using (Writer.IncreaseIndent())
{
Writer.WriteIndentedLine(
"typeInspector.GetTypeRef(typeof({0}), {1}.{2}),",
typeReference.TypeString,
WellKnownTypes.TypeContext,
context);
"{0},",
CreateTypeDefinitionReferenceExpression(typeReference, context));
Writer.WriteIndentedLine(
"{0}){1}",
typeReference.TypeStructure,
Expand All @@ -1784,6 +1780,28 @@ private void WriteAssignTypeRef(
}
}

private static string CreateTypeDefinitionReferenceExpression(
SchemaTypeReference typeReference,
string context)
{
if (typeReference.Nullability is { } nullability)
{
return string.Format(
"global::{0}.Create(typeInspector.GetType(typeof({1}), {2}), {3}.{4})",
WellKnownTypes.TypeReference,
typeReference.TypeString,
nullability,
WellKnownTypes.TypeContext,
context);
}

return string.Format(
"typeInspector.GetTypeRef(typeof({0}), {1}.{2})",
typeReference.TypeString,
WellKnownTypes.TypeContext,
context);
}

private static string GetResolverArgumentAssignments(int parameterCount)
{
if (parameterCount == 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,24 @@ public static SchemaTypeReference CreateTypeReference(
}

// Next, we create a key that describes the type and ensures we are only executing the type factory once.
var (typeStructure, typeDefinition, isSimpleType) = CreateTypeKey(unwrapped);
var (typeStructure, typeDefinition, nullability, isSimpleType) = CreateTypeKey(unwrapped);

if (isSimpleType)
{
return new SchemaTypeReference(
SchemaTypeReferenceKind.ExtendedTypeReference,
typeDefinition);
typeDefinition,
nullability: nullability);
}

return new SchemaTypeReference(
SchemaTypeReferenceKind.FactoryTypeReference,
typeDefinition,
typeStructure);
typeStructure,
nullability);
}

private static (string TypeStructure, string TypeDefinition, bool IsSimpleType) CreateTypeKey(
private static (string TypeStructure, string TypeDefinition, string? Nullability, bool IsSimpleType) CreateTypeKey(
ITypeSymbol unwrappedType)
{
bool isNullable;
Expand Down Expand Up @@ -127,7 +129,7 @@ private static (string TypeStructure, string TypeDefinition, bool IsSimpleType)

if (underlyingType is INamedTypeSymbol namedType && TryGetListElementType(namedType, out var listElementType))
{
var (typeStructure, typeDefinition, _) = CreateTypeKey(listElementType);
var (typeStructure, typeDefinition, elementNullability, _) = CreateTypeKey(listElementType);

if (isNullable)
{
Expand All @@ -145,12 +147,12 @@ private static (string TypeStructure, string TypeDefinition, bool IsSimpleType)
typeStructure);
}

return (typeStructure, typeDefinition, false);
return (typeStructure, typeDefinition, elementNullability, false);
}

if (IsArrayType(unwrappedType, out var arrayElementType))
{
var (typeStructure, typeDefinition, _) = CreateTypeKey(arrayElementType);
var (typeStructure, typeDefinition, elementNullability, _) = CreateTypeKey(arrayElementType);

if (isNullable)
{
Expand All @@ -168,19 +170,22 @@ private static (string TypeStructure, string TypeDefinition, bool IsSimpleType)
typeStructure);
}

return (typeStructure, typeDefinition, false);
return (typeStructure, typeDefinition, elementNullability, false);
}

var typeName = GetFullyQualifiedTypeName(underlyingType);
var compliantTypeName = MakeGraphQLCompliant(typeName);
var nullability = ShouldPreserveNullability(underlyingType)
? CreateNullabilityLiteral(underlyingType, isNullable)
: null;

if (isNullable)
{
var typeStructure = string.Format(
"new global::{0}(\"{1}\")",
WellKnownTypes.NamedTypeNode,
compliantTypeName);
return (typeStructure, typeName, IsSimpleType: unwrappedType.IsReferenceType);
return (typeStructure, typeName, nullability, IsSimpleType: unwrappedType.IsReferenceType);
}
else
{
Expand All @@ -189,10 +194,69 @@ private static (string TypeStructure, string TypeDefinition, bool IsSimpleType)
WellKnownTypes.NonNullTypeNode,
WellKnownTypes.NamedTypeNode,
compliantTypeName);
return (typeStructure, typeName, IsSimpleType: false);
return (typeStructure, typeName, nullability, IsSimpleType: false);
}
}

private static bool ShouldPreserveNullability(ITypeSymbol typeSymbol)
=> typeSymbol is INamedTypeSymbol { IsGenericType: true };

private static string CreateNullabilityLiteral(
ITypeSymbol typeSymbol,
bool isNullable)
{
var flags = new List<string>();
CollectNullability(typeSymbol, isNullable, flags);

return flags.Count == 0
? "[]"
: $"[{string.Join(", ", flags)}]";
}

private static void CollectNullability(
ITypeSymbol typeSymbol,
bool isNullable,
List<string> flags)
{
flags.Add(isNullable ? "true" : "false");

if (typeSymbol is not INamedTypeSymbol namedType || !namedType.IsGenericType)
{
return;
}

// Nullable<T> is represented by the wrapped value type and a nullable root flag.
if (namedType.OriginalDefinition.SpecialType is SpecialType.System_Nullable_T)
{
if (namedType.TypeArguments.Length == 1
&& namedType.TypeArguments[0] is INamedTypeSymbol innerNamed
&& innerNamed.IsGenericType)
{
foreach (var argument in innerNamed.TypeArguments)
{
CollectNullability(argument, IsNullable(argument), flags);
}
}
return;
}

foreach (var argument in namedType.TypeArguments)
{
CollectNullability(argument, IsNullable(argument), flags);
}
}

private static bool IsNullable(ITypeSymbol typeSymbol)
{
if (typeSymbol is INamedTypeSymbol { OriginalDefinition.SpecialType: SpecialType.System_Nullable_T })
{
return true;
}

return typeSymbol.IsReferenceType
&& typeSymbol.NullableAnnotation == NullableAnnotation.Annotated;
}

private static ITypeSymbol? UnwrapListElementType(ITypeSymbol typeSymbol)
{
if (typeSymbol is IArrayTypeSymbol arrayType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@ public readonly struct SchemaTypeReference
public SchemaTypeReference(
SchemaTypeReferenceKind kind,
string typeString,
string? typeStructure = null)
string? typeStructure = null,
string? nullability = null)
{
Kind = kind;
TypeString = typeString;
TypeStructure = typeStructure;
Nullability = nullability;
}

public SchemaTypeReferenceKind Kind { get; }

public string TypeString { get; }

public string? TypeStructure { get; }

public string? Nullability { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ public static IObjectFieldDescriptor UseOffsetPaging(
if (currentTypeRef is FactoryTypeReference factoryTypeRef
&& factoryTypeRef.TypeStructure.IsListType())
{
typeRef = factoryTypeRef.TypeDefinition;
// Preserve list element nullability from the generated type structure.
typeRef = factoryTypeRef.GetElementType();
}

if (typeRef is null
Expand Down Expand Up @@ -261,7 +262,8 @@ public static IInterfaceFieldDescriptor UseOffsetPaging(
if (currentTypeRef is FactoryTypeReference factoryTypeRef
&& factoryTypeRef.TypeStructure.IsListType())
{
typeRef = factoryTypeRef.TypeDefinition;
// Preserve list element nullability from the generated type structure.
typeRef = factoryTypeRef.GetElementType();
}

if (typeRef is null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using HotChocolate.Types.Descriptors;
using HotChocolate.Types;

namespace HotChocolate.Configuration;

Expand All @@ -9,15 +10,16 @@ internal sealed class SourceGeneratorTypeReferenceHandler(
{
private readonly ExtendedTypeReferenceHandler _innerHandler = new(context.TypeInspector);

private readonly HashSet<string> _handled = [];
private readonly HashSet<(string Key, int TypeHash, TypeContext Context)> _handled = [];

public TypeReferenceKind Kind => TypeReferenceKind.Factory;

public void Handle(ITypeRegistrar typeRegistrar, TypeReference typeReference)
{
var typeRef = (FactoryTypeReference)typeReference;
var marker = (typeRef.Key, typeRef.TypeDefinition.GetHashCode(), typeRef.TypeDefinition.Context);

if (_handled.Add(typeRef.Key))
if (_handled.Add(marker))
{
typeRegistry.Register(typeRef, typeRef.TypeDefinition);
_innerHandler.Handle(typeRegistrar, typeRef.TypeDefinition);
Expand Down
2 changes: 1 addition & 1 deletion src/HotChocolate/Core/src/Types/Types/Scalars/AnyType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ protected override void OnCoerceOutputValue(JsonElement runtimeValue, ResultElem
case JsonValueKind.String:
{
var value = JsonMarshal.GetRawUtf8Value(runtimeValue);
resultValue.SetStringValue(value[1..^1]);
resultValue.SetStringValue(value[1..^1], isEncoded: true);
break;
}

Expand Down
Loading
Loading