Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,23 @@ private static ObjectType CreateKeyValuePairObjectType(
.Name("key")
.Extend()
.OnBeforeCreate(
(_, field) => field.SetMoreSpecificType(keyType, TypeContext.Output));
(_, field) =>
{
field.SetMoreSpecificType(keyType, TypeContext.Output);
field.SourceType = runtimeType;
field.ResolverType = runtimeType;
});

descriptor.Field(valueProperty)
.Name("value")
.Extend()
.OnBeforeCreate(
(_, field) => field.SetMoreSpecificType(valueType, TypeContext.Output));
(_, field) =>
{
field.SetMoreSpecificType(valueType, TypeContext.Output);
field.SourceType = runtimeType;
field.ResolverType = runtimeType;
});

descriptor.Extend()
.OnBeforeCreate(
Expand Down Expand Up @@ -259,12 +269,12 @@ private static string CreateKeyValuePairTypeName(IExtendedType type, TypeKind ki
var keyName = keyType.Type.Name;
var valueName = valueType.Type.Name;

if (keyType.IsNullable)
if (keyType.IsNullable && keyType.Type.IsValueType)
{
keyName = $"Nullable{keyName}";
}

if (valueType.IsNullable)
if (valueType.IsNullable && valueType.Type.IsValueType)
{
valueName = $"Nullable{valueName}";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class SourceGeneratorOffsetPagingReproTests
[Fact]
public async Task QueryType_SourceGenerator_Path_Works_Like_AddQueryType_Path()
{
var assembly = CompileReproAssembly();
var assembly = CompileOffsetPagingReproAssembly();

var sourceGeneratorException = await BuildSchemaWithSourceGeneratorRegistrationAsync(assembly);
var addQueryTypeException = await BuildSchemaWithAddQueryTypeRegistrationAsync(assembly);
Expand All @@ -32,17 +32,36 @@ public async Task QueryType_SourceGenerator_Path_Works_Like_AddQueryType_Path()
Assert.Null(addQueryTypeException);
}

[Fact]
public async Task Module_QueryType_Dictionary_Result_SourceGenerator_Path_Works_Like_AddQueryType_Path()
{
var assembly = CompileModuleDictionaryReproAssembly();

var sourceGenerated = await ExecuteWithSourceGeneratorRegistrationAsync(
assembly,
registrationMethodName: "AddDemo",
query: "{ foo { key value } }");

var addQueryType = await ExecuteWithAddQueryTypeRegistrationAsync(
assembly,
runtimeQueryTypeName: "Repro.RuntimeQuery",
query: "{ foo { key value } }");

Assert.Contains("foo: [KeyValuePairOfStringAndString!]!", sourceGenerated.Schema);
Assert.Equal(addQueryType.Result, sourceGenerated.Result);
Assert.DoesNotContain("\"errors\"", sourceGenerated.Result, StringComparison.Ordinal);
Assert.Contains("\"key\": \"foo\"", sourceGenerated.Result, StringComparison.Ordinal);
Assert.Contains("\"value\": \"bar\"", sourceGenerated.Result, StringComparison.Ordinal);
}

private static async Task<Exception?> BuildSchemaWithSourceGeneratorRegistrationAsync(Assembly assembly)
{
var services = new ServiceCollection();
var builder = services.AddGraphQLServer(disableDefaultSecurity: true);

var addTypesMethod = assembly
.GetTypes()
.Where(t => t is { IsAbstract: true, IsSealed: true }
&& t.Namespace == "Microsoft.Extensions.DependencyInjection")
.SelectMany(t => t.GetMethods(BindingFlags.Public | BindingFlags.Static))
.Single(m =>
var addTypesMethod = FindRegistrationMethod(
assembly,
m =>
{
var p = m.GetParameters();
return m.Name.StartsWith("Add", StringComparison.Ordinal)
Expand Down Expand Up @@ -71,7 +90,61 @@ public async Task QueryType_SourceGenerator_Path_Works_Like_AddQueryType_Path()
async () => await builder.BuildSchemaAsync());
}

private static Assembly CompileReproAssembly()
private static async Task<ExecutionResult> ExecuteWithSourceGeneratorRegistrationAsync(
Assembly assembly,
string registrationMethodName,
string query)
{
var builder = new ServiceCollection().AddGraphQLServer(disableDefaultSecurity: true);

var addModuleMethod = FindRegistrationMethod(
assembly,
m =>
{
var p = m.GetParameters();
return m.Name.Equals(registrationMethodName, StringComparison.Ordinal)
&& m.ReturnType == typeof(IRequestExecutorBuilder)
&& p.Length == 1
&& p[0].ParameterType == typeof(IRequestExecutorBuilder);
});

addModuleMethod.Invoke(null, [builder]);

var executor = await builder.BuildRequestExecutorAsync();
var result = await executor.ExecuteAsync(query);
return new ExecutionResult(executor.Schema.ToString(), result.ToJson());
}

private static async Task<ExecutionResult> ExecuteWithAddQueryTypeRegistrationAsync(
Assembly assembly,
string runtimeQueryTypeName,
string query)
{
var runtimeQueryType = assembly.GetType(runtimeQueryTypeName)
?? throw new InvalidOperationException("Could not locate runtime query type.");

var builder = new ServiceCollection()
.AddGraphQLServer(disableDefaultSecurity: true)
.AddQueryType(runtimeQueryType);

var executor = await builder.BuildRequestExecutorAsync();
var result = await executor.ExecuteAsync(query);
return new ExecutionResult(executor.Schema.ToString(), result.ToJson());
}

private static MethodInfo FindRegistrationMethod(
Assembly assembly,
Func<MethodInfo, bool> predicate)
{
return assembly
.GetTypes()
.Where(t => t is { IsAbstract: true, IsSealed: true }
&& t.Namespace == "Microsoft.Extensions.DependencyInjection")
.SelectMany(t => t.GetMethods(BindingFlags.Public | BindingFlags.Static))
.Single(predicate);
}

private static Assembly CompileOffsetPagingReproAssembly()
{
const string source = """
using System.Collections.Generic;
Expand Down Expand Up @@ -102,6 +175,45 @@ public async Task<Dictionary<string, string>> UglyLegacyResolver()
}
""";

return CompileReproAssembly(source, "SourceGeneratorOffsetPagingRepro");
}

private static Assembly CompileModuleDictionaryReproAssembly()
{
const string source = """
using System.Collections.Generic;
using HotChocolate;
using HotChocolate.Types;

[assembly: Module("Demo")]

namespace Repro;

[QueryType]
public static partial class SourceGeneratedQuery
{
public static Dictionary<string, string> Foo()
=> new()
{
["foo"] = "bar"
};
}

public class RuntimeQuery
{
public Dictionary<string, string> Foo()
=> new()
{
["foo"] = "bar"
};
}
""";

return CompileReproAssembly(source, "SourceGeneratorDictionaryModuleRepro");
}

private static Assembly CompileReproAssembly(string source, string assemblyName)
{
var parseOptions = CSharpParseOptions.Default;
var syntaxTree = CSharpSyntaxTree.ParseText(source, parseOptions);

Expand Down Expand Up @@ -143,7 +255,7 @@ public async Task<Dictionary<string, string>> UglyLegacyResolver()
];

var compilation = CSharpCompilation.Create(
assemblyName: "SourceGeneratorOffsetPagingRepro",
assemblyName: assemblyName,
syntaxTrees: [syntaxTree],
references: references,
options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
Expand Down Expand Up @@ -179,7 +291,9 @@ public async Task<Dictionary<string, string>> UglyLegacyResolver()

stream.Position = 0;

var context = new AssemblyLoadContext("SourceGeneratorOffsetPagingRepro", isCollectible: true);
var context = new AssemblyLoadContext(assemblyName, isCollectible: true);
return context.LoadFromStream(stream);
}

private sealed record ExecutionResult(string Schema, string Result);
}
Loading