Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,9 @@ private void CollectTypes(Context context, Selection selection, TypeContainer pa
return Expression.Convert(Expression.Constant(p.DefaultValue), p.ParameterType);
}

if (!p.ParameterType.IsValueType && IsMarkedAsExplicitlyNonNullable(p))
{
throw new InvalidOperationException(
$"Cannot construct '{context.ParentType.Name}': missing required argument '{p.Name}' "
+ "(non-nullable reference type with no default value).");
}

// Partial projections can omit constructor arguments that are not selected.
// We fall back to default values for missing arguments to keep selector
// construction non-throwing for record-like types.
return Expression.Default(p.ParameterType);
}).ToArray();

Expand Down Expand Up @@ -637,7 +633,4 @@ private static bool ShouldReuseExistingInstance(Type type)
=> type.GetConstructor(Type.EmptyTypes) is not null
&& type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic)
.Any(t => t.GetParameters().Length > 0);

private static bool IsMarkedAsExplicitlyNonNullable(ParameterInfo parameter)
=> new NullabilityInfoContext().Create(parameter).WriteState is NullabilityState.NotNull;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
using System.Linq.Expressions;
using System.Text.Json;
using GreenDonut.Data;
using HotChocolate.Execution;
using HotChocolate.Types;
using HotChocolate.Types.Relay;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Squadron;

namespace HotChocolate.Data;

[Collection(PostgresCacheCollectionFixture.DefinitionName)]
public sealed class AsSelectorRecordProjectionTests(PostgreSqlResource resource)
{
[Fact]
public async Task AsSelector_Should_Project_Record_On_Standard_Field()
{
// arrange
var db = "db_" + Guid.NewGuid().ToString("N");
var connectionString = resource.GetConnectionString(db);

await using var services = CreateServer(connectionString);
await SeedAsync(
services,
new StoreRecord(1, "Zurich", "ZH", "CH"),
new StoreRecord(2, "Basel", "BS", "CH"));

var executor = await services
.GetRequiredService<IRequestExecutorProvider>()
.GetExecutorAsync();

// act
var result = await executor.ExecuteAsync(
"""
{
stores {
name
}
}
""");

// assert
var operationResult = result.ExpectOperationResult();
Assert.Empty(operationResult.Errors ?? []);

using var document = JsonDocument.Parse(result.ToJson());
var names = document.RootElement
.GetProperty("data")
.GetProperty("stores")
.EnumerateArray()
.Select(t => t.GetProperty("name").GetString()!)
.OrderBy(t => t)
.ToArray();

Assert.Equal(["Basel", "Zurich"], names);

var capture = services.GetRequiredService<RecordSelectorCapture>();

AssertSelectorProjects(capture.StandardFieldSelector, nameof(StoreRecord.Name));
Assert.NotNull(capture.StandardFieldSql);
}

[Fact]
public async Task AsSelector_Should_Project_Record_On_Node_Field()
{
// arrange
var db = "db_" + Guid.NewGuid().ToString("N");
var connectionString = resource.GetConnectionString(db);

await using var services = CreateServer(connectionString);
await SeedAsync(services, new StoreRecord(1, "Zurich", "ZH", "CH"));

var executor = await services
.GetRequiredService<IRequestExecutorProvider>()
.GetExecutorAsync();

var serializer = services.GetRequiredService<INodeIdSerializer>();
var storeId = serializer.Format("Store", 1);

var nodeQuery = $$"""
{
node(id: "{{storeId}}") {
id
... on Store {
name
}
}
}
""";

// act
var result = await executor.ExecuteAsync(nodeQuery);

// assert
var operationResult = result.ExpectOperationResult();
Assert.Empty(operationResult.Errors ?? []);

using var document = JsonDocument.Parse(result.ToJson());
var node = document.RootElement
.GetProperty("data")
.GetProperty("node");

Assert.Equal("Zurich", node.GetProperty("name").GetString());

var capture = services.GetRequiredService<RecordSelectorCapture>();

AssertSelectorProjects(
capture.NodeFieldSelector,
nameof(StoreRecord.Id),
nameof(StoreRecord.Name));

Assert.NotNull(capture.NodeFieldSql);
}

private static ServiceProvider CreateServer(string connectionString)
=> new ServiceCollection()
.AddDbContext<RecordStoreContext>(c => c.UseNpgsql(connectionString))
.AddScoped<RecordStoreService>()
.AddSingleton<RecordSelectorCapture>()
.AddGraphQLServer()
.AddQueryContext()
.AddGlobalObjectIdentification()
.AddQueryType(
descriptor => descriptor
.Name(OperationTypeNames.Query)
.Field("stores")
.ResolveWith<RecordStoreQueryResolver>(
t => t.GetStoresAsync(default!, default!, default))
.Type<ListType<NonNullType<ObjectType<StoreRecord>>>>())
.AddObjectType<StoreRecord>(
descriptor => descriptor
.Name("Store")
.ImplementsNode()
.IdField(t => t.Id)
.ResolveNodeWith(
typeof(RecordStoreNodeResolver).GetMethod(
nameof(RecordStoreNodeResolver.GetStoreByIdAsync))!))
.ModifyRequestOptions(o => o.IncludeExceptionDetails = true)
.Services
.BuildServiceProvider();

private static async Task SeedAsync(IServiceProvider services, params StoreRecord[] stores)
{
await using var scope = services.CreateAsyncScope();
var context = scope.ServiceProvider.GetRequiredService<RecordStoreContext>();

await context.Database.EnsureCreatedAsync();
context.Stores.AddRange(stores);
await context.SaveChangesAsync();
}

private static void AssertSelectorProjects(
Expression<Func<StoreRecord, StoreRecord>>? selector,
params string[] projectedMembers)
{
Assert.NotNull(selector);
var body = UnwrapConvert(selector!.Body);

Assert.IsType<NewExpression>(body);

var visitor = new RootMemberAccessVisitor(selector.Parameters[0]);
visitor.Visit(body);

foreach (var member in projectedMembers)
{
Assert.Contains(member, visitor.Members);
}
}

private static Expression UnwrapConvert(Expression expression)
{
while (expression is UnaryExpression
{
NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked,
Operand: { } operand
})
{
expression = operand;
}

return expression;
}

private sealed class RootMemberAccessVisitor(ParameterExpression root) : ExpressionVisitor
{
public HashSet<string> Members { get; } = [];

protected override Expression VisitMember(MemberExpression node)
{
if (IsRootMember(node.Expression))
{
Members.Add(node.Member.Name);
}

return base.VisitMember(node);
}

private bool IsRootMember(Expression? expression)
{
while (expression is UnaryExpression
{
NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked,
Operand: { } operand
})
{
expression = operand;
}

return expression == root;
}
}

private sealed record StoreRecord(
int Id,
string Name,
string Region,
string CountryCode);

private sealed class RecordStoreContext(DbContextOptions<RecordStoreContext> options)
: DbContext(options)
{
public DbSet<StoreRecord> Stores => Set<StoreRecord>();

protected override void OnModelCreating(ModelBuilder modelBuilder)
=> modelBuilder.Entity<StoreRecord>().HasKey(t => t.Id);
}

private sealed class RecordSelectorCapture
{
public Expression<Func<StoreRecord, StoreRecord>>? StandardFieldSelector { get; set; }

public Expression<Func<StoreRecord, StoreRecord>>? NodeFieldSelector { get; set; }

public string? StandardFieldSql { get; set; }

public string? NodeFieldSql { get; set; }
}

private sealed class RecordStoreService(RecordStoreContext context, RecordSelectorCapture capture)
{
public async Task<IReadOnlyList<StoreRecord>> GetStoresAsync(
QueryContext<StoreRecord> query,
CancellationToken cancellationToken)
{
capture.StandardFieldSelector = query.Selector;

var projectedQuery = context.Stores.AsNoTracking().With(query);
capture.StandardFieldSql = projectedQuery.ToQueryString();

return await projectedQuery.ToListAsync(cancellationToken);
}

public async Task<StoreRecord?> GetStoreByIdAsync(
int id,
QueryContext<StoreRecord> query,
CancellationToken cancellationToken)
{
capture.NodeFieldSelector = query.Selector;

var projectedQuery = context.Stores
.AsNoTracking()
.Where(t => t.Id == id)
.With(query);

capture.NodeFieldSql = projectedQuery.ToQueryString();

return await projectedQuery.SingleOrDefaultAsync(cancellationToken);
}
}

private sealed class RecordStoreQueryResolver
{
public Task<IReadOnlyList<StoreRecord>> GetStoresAsync(
QueryContext<StoreRecord> query,
[Service] RecordStoreService service,
CancellationToken cancellationToken)
=> service.GetStoresAsync(query, cancellationToken);
}

private sealed class RecordStoreNodeResolver
{
public static Task<StoreRecord?> GetStoreByIdAsync(
int id,
QueryContext<StoreRecord> query,
[Service] RecordStoreService service,
CancellationToken cancellationToken)
=> service.GetStoreByIdAsync(id, query, cancellationToken);
}
}
Loading