diff --git a/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs b/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs index 7a238eeb062..2a215e7bb56 100644 --- a/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs +++ b/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs @@ -212,6 +212,12 @@ private void CollectTypes(Context context, Selection selection, TypeContainer pa var assignmentList = assignments.ToImmutable(); + // Wee keep EF constructor-injected entities intact by reusing the existing instance. + if (ShouldReuseExistingInstance(context.ParentType)) + { + return context.Parent; + } + // Preferred path for mutable types. var parameterlessConstructor = context.ParentType.GetConstructor(Type.EmptyTypes); if (parameterlessConstructor is not null) @@ -444,6 +450,11 @@ private readonly record struct Context( } } + private static bool ShouldReuseExistingInstance(Type type) + => type.GetConstructor(Type.EmptyTypes) is not null + && type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic) + .Any(t => t.GetParameters().Length > 0); + private static bool IsMarkedAsExplicitlyNonNullable(ParameterInfo parameter) => new NullabilityInfoContext().Create(parameter).WriteState is NullabilityState.NotNull; } diff --git a/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs b/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs index f30af7a1b16..9d0df21c3f5 100644 --- a/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs +++ b/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs @@ -1,4 +1,5 @@ using HotChocolate.Execution; +using HotChocolate.Execution.Processing; using HotChocolate.Types; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; @@ -534,12 +535,90 @@ static Dictionary ReadCounts(JsonElement value) } } + [Fact] + public async Task AsSelector_Should_Preserve_Entity_Constructor_DbContext_Injection() + { + var databaseName = $"db-{Guid.NewGuid():N}"; + + await using (var seedContext = new ConstructorInjectionDbContext( + new DbContextOptionsBuilder() + .UseInMemoryDatabase(databaseName) + .Options)) + { + await seedContext.Database.EnsureCreatedAsync(); + + var blog1 = new ConstructorInjectionBlog { Name = "Blog1" }; + var blog2 = new ConstructorInjectionBlog { Name = "Blog2" }; + + await seedContext.Blogs.AddRangeAsync(blog1, blog2); + await seedContext.SaveChangesAsync(); + + await seedContext.Posts.AddRangeAsync( + new ConstructorInjectionPost { BlogId = blog1.Id }, + new ConstructorInjectionPost { BlogId = blog1.Id }, + new ConstructorInjectionPost { BlogId = blog2.Id }); + await seedContext.SaveChangesAsync(); + } + + var executor = await new ServiceCollection() + .AddDbContext( + b => b.UseInMemoryDatabase(databaseName)) + .AddGraphQL() + .AddProjections() + .AddQueryType() + .BuildRequestExecutorAsync(); + + var result = await executor.ExecuteAsync( + """ + { + blogsAsSelector { + name + postCount + } + blogsNoProjection { + name + postCount + } + } + """); + + var operationResult = result.ExpectOperationResult(); + Assert.True(operationResult.Errors is null || operationResult.Errors.Count == 0); + Assert.True(operationResult.Data.HasValue); + + using var document = JsonDocument.Parse(result.ToJson()); + var data = document.RootElement.GetProperty("data"); + var projectedCounts = ReadCounts(data.GetProperty("blogsAsSelector")); + var unprojectedCounts = ReadCounts(data.GetProperty("blogsNoProjection")); + + Assert.Equal(unprojectedCounts, projectedCounts); + Assert.Equal(3, projectedCounts["Blog1"]); + Assert.Equal(3, projectedCounts["Blog2"]); + + static Dictionary ReadCounts(JsonElement value) + { + var result = new Dictionary(); + + foreach (var item in value.EnumerateArray()) + { + result.Add(item.GetProperty("name").GetString()!, item.GetProperty("postCount").GetInt32()); + } + + return result; + } + } + public class ConstructorInjectionQuery { [UseProjection] public IQueryable GetBlogs(ConstructorInjectionDbContext context) => context.Blogs; + public IQueryable GetBlogsAsSelector( + ConstructorInjectionDbContext context, + ISelection selection) + => context.Blogs.Select(selection.AsSelector()); + public IQueryable GetBlogsNoProjection( ConstructorInjectionDbContext context) => context.Blogs;