Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public void Handle(ITypeRegistrar typeRegistrar, TypeReference typeReference)
{
TryMapToExistingRegistration(
typeRegistrar,
typeRef,
typeInfo,
typeReference.Context,
typeReference.Scope);
Expand All @@ -52,10 +53,20 @@ public void Handle(ITypeRegistrar typeRegistrar, TypeReference typeReference)

private static void TryMapToExistingRegistration(
ITypeRegistrar typeRegistrar,
ExtendedTypeReference typeRef,
ITypeInfo typeInfo,
TypeContext context,
string? scope)
{
// If there is an explicit runtime binding for the full type, keep the original
// type reference unresolved so discovery can apply that binding.
if (RuntimeTypeBindingHelper.RequiresExactBinding(typeRef.Type)
&& typeRegistrar.HasRuntimeTypeBinding(typeRef))
{
typeRegistrar.MarkUnresolved(typeRef);
return;
}

ExtendedTypeReference? normalizedTypeRef = null;
var resolved = false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ void Register(

bool IsResolved(TypeReference typeReference);

bool HasRuntimeTypeBinding(ExtendedTypeReference typeReference);

TypeSystemObject CreateInstance(Type namedSchemaType);

IReadOnlyCollection<TypeReference> Unresolved { get; }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System.Collections;
using HotChocolate.Internal;

namespace HotChocolate.Configuration;

internal static class RuntimeTypeBindingHelper
{
public static bool RequiresExactBinding(IExtendedType runtimeType)
{
ArgumentNullException.ThrowIfNull(runtimeType);

return IsByteArray(runtimeType) || IsDictionary(runtimeType.Source);
}

private static bool IsByteArray(IExtendedType runtimeType)
=> runtimeType.IsArray
&& runtimeType.ElementType is { Source: { } elementType }
&& elementType == typeof(byte);

private static bool IsDictionary(Type type)
{
if (typeof(IDictionary).IsAssignableFrom(type))
{
return true;
}

if (type.IsGenericType)
{
var typeDefinition = type.GetGenericTypeDefinition();

if (typeDefinition == typeof(IDictionary<,>)
|| typeDefinition == typeof(IReadOnlyDictionary<,>))
{
return true;
}
}

foreach (var implementedType in type.GetInterfaces())
{
if (implementedType.IsGenericType)
{
var typeDefinition = implementedType.GetGenericTypeDefinition();

if (typeDefinition == typeof(IDictionary<,>)
|| typeDefinition == typeof(IReadOnlyDictionary<,>))
{
return true;
}
}
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,17 @@ public bool TryGetType(TypeReference typeRef, [NotNullWhen(true)] out IType? typ
switch (typeRef)
{
case ExtendedTypeReference r:
var typeFactory = _typeInspector.CreateTypeFactory(r.Type);
type = typeFactory.CreateType(typeDefinition);
if (_typeRegistry.IsExplicitBinding(r)
&& RuntimeTypeBindingHelper.RequiresExactBinding(r.Type))
{
type = CreateExplicitBoundType(typeDefinition, r.Type);
}
else
{
var typeFactory = _typeInspector.CreateTypeFactory(r.Type);
type = typeFactory.CreateType(typeDefinition);
}

_typeCache[typeId] = type;
return true;

Expand Down Expand Up @@ -155,6 +164,18 @@ private static IType CreateType(
return namedType;
}

private static IType CreateExplicitBoundType(ITypeDefinition typeDefinition, IExtendedType runtimeType)
{
IType type = typeDefinition;

if (!runtimeType.IsNullable && typeDefinition.Kind is not TypeKind.NonNull)
{
type = new NonNullType(typeDefinition);
}

return type;
}

private TypeId CreateId(TypeReference typeRef, TypeReference namedTypeRef)
{
switch (typeRef)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,11 @@ private RegisteredType InitializeType(
.Build());
}
}

public bool HasRuntimeTypeBinding(ExtendedTypeReference typeReference)
{
ArgumentNullException.ThrowIfNull(typeReference);

return _typeRegistry.TryGetTypeRef(typeReference, out _);
}
}
19 changes: 18 additions & 1 deletion src/HotChocolate/Core/src/Types/Configuration/TypeRegistry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ internal sealed class TypeRegistry
private readonly Dictionary<TypeReference, RegisteredType> _typeRegister = [];
private readonly Dictionary<ExtendedTypeReference, TypeReference> _runtimeTypeRefs =
new(new ExtendedTypeRefEqualityComparer());
private readonly HashSet<ExtendedTypeReference> _explicitRuntimeTypeRefs =
new(new ExtendedTypeRefEqualityComparer());
private readonly Dictionary<string, TypeReference> _nameRefs = new(StringComparer.Ordinal);
private readonly Dictionary<FactoryTypeReference, TypeReference> _lookups = new(new TypeRefEqualityComparer());
private readonly List<RegisteredType> _types = [];
Expand Down Expand Up @@ -76,6 +78,13 @@ public bool TryGetTypeRef(
return _runtimeTypeRefs.TryGetValue(runtimeTypeRef, out typeRef);
}

public bool IsExplicitBinding(ExtendedTypeReference runtimeTypeRef)
{
ArgumentNullException.ThrowIfNull(runtimeTypeRef);

return _explicitRuntimeTypeRefs.Contains(runtimeTypeRef);
}

public bool TryGetTypeRef(
string typeName,
[NotNullWhen(true)] out TypeReference? typeRef)
Expand All @@ -93,12 +102,20 @@ public bool TryGetTypeRef(

public IEnumerable<TypeReference> GetTypeRefs() => _runtimeTypeRefs.Values;

public void TryRegister(ExtendedTypeReference runtimeTypeRef, TypeReference typeRef)
public void TryRegister(
ExtendedTypeReference runtimeTypeRef,
TypeReference typeRef,
bool explicitBinding = false)
{
ArgumentNullException.ThrowIfNull(runtimeTypeRef);
ArgumentNullException.ThrowIfNull(typeRef);

_runtimeTypeRefs.TryAdd(runtimeTypeRef, typeRef);

if (explicitBinding)
{
_explicitRuntimeTypeRefs.Add(runtimeTypeRef);
}
}

public void Register(RegisteredType registeredType)
Expand Down
7 changes: 5 additions & 2 deletions src/HotChocolate/Core/src/Types/SchemaBuilder.Setup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,12 @@ private static TypeInitializer CreateTypeInitializer(
{
foreach (var binding in bindings.Values)
{
var runtimeTypeRef = binding.GetRuntimeTypeReference(context.TypeInspector);

typeRegistry.TryRegister(
binding.GetRuntimeTypeReference(context.TypeInspector),
binding.GetSchemaTypeReference(context.TypeInspector));
runtimeTypeRef,
binding.GetSchemaTypeReference(context.TypeInspector),
explicitBinding: RuntimeTypeBindingHelper.RequiresExactBinding(runtimeTypeRef.Type));
}
}

Expand Down
40 changes: 40 additions & 0 deletions src/HotChocolate/Core/test/Types.Tests/SchemaBuilderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,36 @@ public void BindClrType_IntToString_IntFieldIsStringField()
schema.ToString().MatchSnapshot();
}

[Fact]
public void BindClrType_DictionaryToAnyType_DictionaryFieldIsScalarField()
{
// arrange
// act
var schema = SchemaBuilder.New()
.AddQueryType<QueryWithDictionaryArgument>()
.BindRuntimeType<IDictionary<string, object>, AnyType>()
.Create();

// assert
var queryType = schema.Types.GetType<ObjectType>("QueryWithDictionaryArgument");
Assert.Equal("Any", queryType.Fields["foo"].Arguments["foo"].Type.Print());
}

[Fact]
public void BindClrType_ByteArrayToBase64Type_ByteArrayFieldIsScalarField()
{
// arrange
// act
var schema = SchemaBuilder.New()
.AddQueryType<QueryWithByteArrayField>()
.BindRuntimeType<byte[], Base64StringType>()
.Create();

// assert
var queryType = schema.Types.GetType<ObjectType>("QueryWithByteArrayField");
Assert.Equal("Base64String!", queryType.Fields["foo"].Type.Print());
}

[Fact]
public void BindClrType_BuilderIsNull_ArgumentNullException()
{
Expand Down Expand Up @@ -2094,6 +2124,16 @@ public class QueryWithIntField
public int Foo { get; set; }
}

public class QueryWithDictionaryArgument
{
public bool Foo(IDictionary<string, object>? foo) => true;
}

public class QueryWithByteArrayField
{
public required byte[] Foo { get; set; }
}

public abstract class AbstractQuery
{
public required string Foo { get; set; }
Expand Down
Loading