diff --git a/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs b/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs index 78a6d6838cb..ba2f65f3810 100644 --- a/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs +++ b/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs @@ -280,13 +280,9 @@ private void CollectTypes(Context context, Selection selection, TypeContainer pa return Expression.Convert(Expression.Constant(p.DefaultValue), p.ParameterType); } - if (!p.ParameterType.IsValueType && IsMarkedAsExplicitlyNonNullable(p)) - { - throw new InvalidOperationException( - $"Cannot construct '{context.ParentType.Name}': missing required argument '{p.Name}' " - + "(non-nullable reference type with no default value)."); - } - + // Partial projections can omit constructor arguments that are not selected. + // We fall back to default values for missing arguments to keep selector + // construction non-throwing for record-like types. return Expression.Default(p.ParameterType); }).ToArray(); @@ -637,7 +633,4 @@ private static bool ShouldReuseExistingInstance(Type type) => type.GetConstructor(Type.EmptyTypes) is not null && type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic) .Any(t => t.GetParameters().Length > 0); - - private static bool IsMarkedAsExplicitlyNonNullable(ParameterInfo parameter) - => new NullabilityInfoContext().Create(parameter).WriteState is NullabilityState.NotNull; } diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/AsSelectorRecordProjectionTests.cs b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/AsSelectorRecordProjectionTests.cs new file mode 100644 index 00000000000..0ed1679c002 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/AsSelectorRecordProjectionTests.cs @@ -0,0 +1,290 @@ +using System.Linq.Expressions; +using System.Text.Json; +using GreenDonut.Data; +using HotChocolate.Execution; +using HotChocolate.Types; +using HotChocolate.Types.Relay; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Squadron; + +namespace HotChocolate.Data; + +[Collection(PostgresCacheCollectionFixture.DefinitionName)] +public sealed class AsSelectorRecordProjectionTests(PostgreSqlResource resource) +{ + [Fact] + public async Task AsSelector_Should_Project_Record_On_Standard_Field() + { + // arrange + var db = "db_" + Guid.NewGuid().ToString("N"); + var connectionString = resource.GetConnectionString(db); + + await using var services = CreateServer(connectionString); + await SeedAsync( + services, + new StoreRecord(1, "Zurich", "ZH", "CH"), + new StoreRecord(2, "Basel", "BS", "CH")); + + var executor = await services + .GetRequiredService() + .GetExecutorAsync(); + + // act + var result = await executor.ExecuteAsync( + """ + { + stores { + name + } + } + """); + + // assert + var operationResult = result.ExpectOperationResult(); + Assert.Empty(operationResult.Errors ?? []); + + using var document = JsonDocument.Parse(result.ToJson()); + var names = document.RootElement + .GetProperty("data") + .GetProperty("stores") + .EnumerateArray() + .Select(t => t.GetProperty("name").GetString()!) + .OrderBy(t => t) + .ToArray(); + + Assert.Equal(["Basel", "Zurich"], names); + + var capture = services.GetRequiredService(); + + AssertSelectorProjects(capture.StandardFieldSelector, nameof(StoreRecord.Name)); + Assert.NotNull(capture.StandardFieldSql); + } + + [Fact] + public async Task AsSelector_Should_Project_Record_On_Node_Field() + { + // arrange + var db = "db_" + Guid.NewGuid().ToString("N"); + var connectionString = resource.GetConnectionString(db); + + await using var services = CreateServer(connectionString); + await SeedAsync(services, new StoreRecord(1, "Zurich", "ZH", "CH")); + + var executor = await services + .GetRequiredService() + .GetExecutorAsync(); + + var serializer = services.GetRequiredService(); + var storeId = serializer.Format("Store", 1); + + var nodeQuery = $$""" + { + node(id: "{{storeId}}") { + id + ... on Store { + name + } + } + } + """; + + // act + var result = await executor.ExecuteAsync(nodeQuery); + + // assert + var operationResult = result.ExpectOperationResult(); + Assert.Empty(operationResult.Errors ?? []); + + using var document = JsonDocument.Parse(result.ToJson()); + var node = document.RootElement + .GetProperty("data") + .GetProperty("node"); + + Assert.Equal("Zurich", node.GetProperty("name").GetString()); + + var capture = services.GetRequiredService(); + + AssertSelectorProjects( + capture.NodeFieldSelector, + nameof(StoreRecord.Id), + nameof(StoreRecord.Name)); + + Assert.NotNull(capture.NodeFieldSql); + } + + private static ServiceProvider CreateServer(string connectionString) + => new ServiceCollection() + .AddDbContext(c => c.UseNpgsql(connectionString)) + .AddScoped() + .AddSingleton() + .AddGraphQLServer() + .AddQueryContext() + .AddGlobalObjectIdentification() + .AddQueryType( + descriptor => descriptor + .Name(OperationTypeNames.Query) + .Field("stores") + .ResolveWith( + t => t.GetStoresAsync(default!, default!, default)) + .Type>>>()) + .AddObjectType( + descriptor => descriptor + .Name("Store") + .ImplementsNode() + .IdField(t => t.Id) + .ResolveNodeWith( + typeof(RecordStoreNodeResolver).GetMethod( + nameof(RecordStoreNodeResolver.GetStoreByIdAsync))!)) + .ModifyRequestOptions(o => o.IncludeExceptionDetails = true) + .Services + .BuildServiceProvider(); + + private static async Task SeedAsync(IServiceProvider services, params StoreRecord[] stores) + { + await using var scope = services.CreateAsyncScope(); + var context = scope.ServiceProvider.GetRequiredService(); + + await context.Database.EnsureCreatedAsync(); + context.Stores.AddRange(stores); + await context.SaveChangesAsync(); + } + + private static void AssertSelectorProjects( + Expression>? selector, + params string[] projectedMembers) + { + Assert.NotNull(selector); + var body = UnwrapConvert(selector!.Body); + + Assert.IsType(body); + + var visitor = new RootMemberAccessVisitor(selector.Parameters[0]); + visitor.Visit(body); + + foreach (var member in projectedMembers) + { + Assert.Contains(member, visitor.Members); + } + } + + private static Expression UnwrapConvert(Expression expression) + { + while (expression is UnaryExpression + { + NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked, + Operand: { } operand + }) + { + expression = operand; + } + + return expression; + } + + private sealed class RootMemberAccessVisitor(ParameterExpression root) : ExpressionVisitor + { + public HashSet Members { get; } = []; + + protected override Expression VisitMember(MemberExpression node) + { + if (IsRootMember(node.Expression)) + { + Members.Add(node.Member.Name); + } + + return base.VisitMember(node); + } + + private bool IsRootMember(Expression? expression) + { + while (expression is UnaryExpression + { + NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked, + Operand: { } operand + }) + { + expression = operand; + } + + return expression == root; + } + } + + private sealed record StoreRecord( + int Id, + string Name, + string Region, + string CountryCode); + + private sealed class RecordStoreContext(DbContextOptions options) + : DbContext(options) + { + public DbSet Stores => Set(); + + protected override void OnModelCreating(ModelBuilder modelBuilder) + => modelBuilder.Entity().HasKey(t => t.Id); + } + + private sealed class RecordSelectorCapture + { + public Expression>? StandardFieldSelector { get; set; } + + public Expression>? NodeFieldSelector { get; set; } + + public string? StandardFieldSql { get; set; } + + public string? NodeFieldSql { get; set; } + } + + private sealed class RecordStoreService(RecordStoreContext context, RecordSelectorCapture capture) + { + public async Task> GetStoresAsync( + QueryContext query, + CancellationToken cancellationToken) + { + capture.StandardFieldSelector = query.Selector; + + var projectedQuery = context.Stores.AsNoTracking().With(query); + capture.StandardFieldSql = projectedQuery.ToQueryString(); + + return await projectedQuery.ToListAsync(cancellationToken); + } + + public async Task GetStoreByIdAsync( + int id, + QueryContext query, + CancellationToken cancellationToken) + { + capture.NodeFieldSelector = query.Selector; + + var projectedQuery = context.Stores + .AsNoTracking() + .Where(t => t.Id == id) + .With(query); + + capture.NodeFieldSql = projectedQuery.ToQueryString(); + + return await projectedQuery.SingleOrDefaultAsync(cancellationToken); + } + } + + private sealed class RecordStoreQueryResolver + { + public Task> GetStoresAsync( + QueryContext query, + [Service] RecordStoreService service, + CancellationToken cancellationToken) + => service.GetStoresAsync(query, cancellationToken); + } + + private sealed class RecordStoreNodeResolver + { + public static Task GetStoreByIdAsync( + int id, + QueryContext query, + [Service] RecordStoreService service, + CancellationToken cancellationToken) + => service.GetStoreByIdAsync(id, query, cancellationToken); + } +}