diff --git a/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs b/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs index 7a238eeb062..2a215e7bb56 100644 --- a/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs +++ b/src/HotChocolate/Core/src/Execution.Projections/SelectionExpressionBuilder.cs @@ -212,6 +212,12 @@ private void CollectTypes(Context context, Selection selection, TypeContainer pa var assignmentList = assignments.ToImmutable(); + // Wee keep EF constructor-injected entities intact by reusing the existing instance. + if (ShouldReuseExistingInstance(context.ParentType)) + { + return context.Parent; + } + // Preferred path for mutable types. var parameterlessConstructor = context.ParentType.GetConstructor(Type.EmptyTypes); if (parameterlessConstructor is not null) @@ -444,6 +450,11 @@ private readonly record struct Context( } } + private static bool ShouldReuseExistingInstance(Type type) + => type.GetConstructor(Type.EmptyTypes) is not null + && type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic) + .Any(t => t.GetParameters().Length > 0); + private static bool IsMarkedAsExplicitlyNonNullable(ParameterInfo parameter) => new NullabilityInfoContext().Create(parameter).WriteState is NullabilityState.NotNull; } diff --git a/src/HotChocolate/Data/src/Data/Projections/Expressions/QueryableProjectionScopeExtensions.cs b/src/HotChocolate/Data/src/Data/Projections/Expressions/QueryableProjectionScopeExtensions.cs index 3a7eb114a08..3e15cc9107a 100644 --- a/src/HotChocolate/Data/src/Data/Projections/Expressions/QueryableProjectionScopeExtensions.cs +++ b/src/HotChocolate/Data/src/Data/Projections/Expressions/QueryableProjectionScopeExtensions.cs @@ -1,5 +1,6 @@ using System.Diagnostics.CodeAnalysis; using System.Linq.Expressions; +using System.Reflection; namespace HotChocolate.Data.Projections.Expressions; @@ -29,6 +30,14 @@ public static Expression> Project( public static Expression CreateMemberInit(this QueryableProjectionScope scope) { + // When the type exposes non-public parameterized constructors (e.g., EF + // constructor service injection), we must preserve the instance that EF + // materialized so that injected services are not lost. + if (ShouldReuseExistingInstance(scope.RuntimeType)) + { + return scope.Instance.Peek(); + } + if (scope.HasAbstractTypes()) { Expression lastValue = Expression.Default(scope.RuntimeType); @@ -53,6 +62,11 @@ public static Expression CreateMemberInit(this QueryableProjectionScope scope) } } + private static bool ShouldReuseExistingInstance(Type type) + => type.GetConstructor(Type.EmptyTypes) is not null + && type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic) + .Any(t => t.GetParameters().Length > 0); + public static Expression CreateMemberInitLambda(this QueryableProjectionScope scope) => Expression.Lambda(scope.CreateMemberInit(), scope.Parameter); diff --git a/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs b/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs index 5ec47bd8e07..9d0df21c3f5 100644 --- a/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs +++ b/src/HotChocolate/Data/test/Data.EntityFramework.Tests/IntegrationTests.cs @@ -1,7 +1,9 @@ using HotChocolate.Execution; +using HotChocolate.Execution.Processing; using HotChocolate.Types; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; +using System.Text.Json; namespace HotChocolate.Data; @@ -459,4 +461,202 @@ public async Task ExecuteAsync_Should_ReturnNull_When_FirstOrDefaultZero_AsyncEn // assert result.MatchSnapshot(); } + + [Fact] + public async Task UseProjection_Should_Preserve_Entity_Constructor_DbContext_Injection() + { + var databaseName = $"db-{Guid.NewGuid():N}"; + + await using (var seedContext = new ConstructorInjectionDbContext( + new DbContextOptionsBuilder() + .UseInMemoryDatabase(databaseName) + .Options)) + { + await seedContext.Database.EnsureCreatedAsync(); + + var blog1 = new ConstructorInjectionBlog { Name = "Blog1" }; + var blog2 = new ConstructorInjectionBlog { Name = "Blog2" }; + + await seedContext.Blogs.AddRangeAsync(blog1, blog2); + await seedContext.SaveChangesAsync(); + + await seedContext.Posts.AddRangeAsync( + new ConstructorInjectionPost { BlogId = blog1.Id }, + new ConstructorInjectionPost { BlogId = blog1.Id }, + new ConstructorInjectionPost { BlogId = blog2.Id }); + await seedContext.SaveChangesAsync(); + } + + var executor = await new ServiceCollection() + .AddDbContext( + b => b.UseInMemoryDatabase(databaseName)) + .AddGraphQL() + .AddProjections() + .AddQueryType() + .BuildRequestExecutorAsync(); + + var result = await executor.ExecuteAsync( + """ + { + blogs { + name + postCount + } + blogsNoProjection { + name + postCount + } + } + """); + + var operationResult = result.ExpectOperationResult(); + Assert.True(operationResult.Errors is null || operationResult.Errors.Count == 0); + Assert.True(operationResult.Data.HasValue); + + using var document = JsonDocument.Parse(result.ToJson()); + var data = document.RootElement.GetProperty("data"); + var projectedCounts = ReadCounts(data.GetProperty("blogs")); + var unprojectedCounts = ReadCounts(data.GetProperty("blogsNoProjection")); + + Assert.Equal(unprojectedCounts, projectedCounts); + Assert.Equal(3, projectedCounts["Blog1"]); + Assert.Equal(3, projectedCounts["Blog2"]); + + static Dictionary ReadCounts(JsonElement value) + { + var result = new Dictionary(); + + foreach (var item in value.EnumerateArray()) + { + result.Add(item.GetProperty("name").GetString()!, item.GetProperty("postCount").GetInt32()); + } + + return result; + } + } + + [Fact] + public async Task AsSelector_Should_Preserve_Entity_Constructor_DbContext_Injection() + { + var databaseName = $"db-{Guid.NewGuid():N}"; + + await using (var seedContext = new ConstructorInjectionDbContext( + new DbContextOptionsBuilder() + .UseInMemoryDatabase(databaseName) + .Options)) + { + await seedContext.Database.EnsureCreatedAsync(); + + var blog1 = new ConstructorInjectionBlog { Name = "Blog1" }; + var blog2 = new ConstructorInjectionBlog { Name = "Blog2" }; + + await seedContext.Blogs.AddRangeAsync(blog1, blog2); + await seedContext.SaveChangesAsync(); + + await seedContext.Posts.AddRangeAsync( + new ConstructorInjectionPost { BlogId = blog1.Id }, + new ConstructorInjectionPost { BlogId = blog1.Id }, + new ConstructorInjectionPost { BlogId = blog2.Id }); + await seedContext.SaveChangesAsync(); + } + + var executor = await new ServiceCollection() + .AddDbContext( + b => b.UseInMemoryDatabase(databaseName)) + .AddGraphQL() + .AddProjections() + .AddQueryType() + .BuildRequestExecutorAsync(); + + var result = await executor.ExecuteAsync( + """ + { + blogsAsSelector { + name + postCount + } + blogsNoProjection { + name + postCount + } + } + """); + + var operationResult = result.ExpectOperationResult(); + Assert.True(operationResult.Errors is null || operationResult.Errors.Count == 0); + Assert.True(operationResult.Data.HasValue); + + using var document = JsonDocument.Parse(result.ToJson()); + var data = document.RootElement.GetProperty("data"); + var projectedCounts = ReadCounts(data.GetProperty("blogsAsSelector")); + var unprojectedCounts = ReadCounts(data.GetProperty("blogsNoProjection")); + + Assert.Equal(unprojectedCounts, projectedCounts); + Assert.Equal(3, projectedCounts["Blog1"]); + Assert.Equal(3, projectedCounts["Blog2"]); + + static Dictionary ReadCounts(JsonElement value) + { + var result = new Dictionary(); + + foreach (var item in value.EnumerateArray()) + { + result.Add(item.GetProperty("name").GetString()!, item.GetProperty("postCount").GetInt32()); + } + + return result; + } + } + + public class ConstructorInjectionQuery + { + [UseProjection] + public IQueryable GetBlogs(ConstructorInjectionDbContext context) + => context.Blogs; + + public IQueryable GetBlogsAsSelector( + ConstructorInjectionDbContext context, + ISelection selection) + => context.Blogs.Select(selection.AsSelector()); + + public IQueryable GetBlogsNoProjection( + ConstructorInjectionDbContext context) + => context.Blogs; + } + + public class ConstructorInjectionDbContext( + DbContextOptions options) + : DbContext(options) + { + public DbSet Blogs => Set(); + + public DbSet Posts => Set(); + } + + public class ConstructorInjectionBlog + { + public ConstructorInjectionBlog() + { + } + + private ConstructorInjectionBlog(ConstructorInjectionDbContext context) + { + Context = context; + } + + private ConstructorInjectionDbContext? Context { get; } + + public int Id { get; set; } + + public string Name { get; set; } = default!; + + public int PostCount => Context?.Posts.Count() ?? 0; + } + + public class ConstructorInjectionPost + { + public int Id { get; set; } + + public int BlogId { get; set; } + } } diff --git a/src/HotChocolate/Data/test/Data.EntityFramework.Tests/Issue5449Tests.cs b/src/HotChocolate/Data/test/Data.EntityFramework.Tests/Issue5449Tests.cs new file mode 100644 index 00000000000..0c3685b415e --- /dev/null +++ b/src/HotChocolate/Data/test/Data.EntityFramework.Tests/Issue5449Tests.cs @@ -0,0 +1,107 @@ +using System.ComponentModel.DataAnnotations.Schema; +using System.Text.Json; +using HotChocolate.Execution; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace HotChocolate.Data; + +public class Issue5449Tests +{ + [Fact] + public async Task UseProjection_Should_Keep_Entity_Service_Injection() + { + // arrange + IServiceProvider services = + new ServiceCollection() + .AddDbContextPool( + b => b.UseInMemoryDatabase($"Issue5449-{Guid.NewGuid():N}")) + .AddGraphQL() + .AddProjections() + .AddQueryType() + .Services + .BuildServiceProvider(); + + var executor = await services + .GetRequiredService() + .GetExecutorAsync(); + + await using (var scope = services.CreateAsyncScope()) + { + await using var context = scope.ServiceProvider.GetRequiredService(); + await context.Blogs.AddAsync(new Issue5449Blog("hc")); + await context.SaveChangesAsync(); + } + + await using (var scope = services.CreateAsyncScope()) + { + await using var context = scope.ServiceProvider.GetRequiredService(); + var directMaterialization = await context.Blogs.AsNoTracking().SingleAsync(); + Assert.True(directMaterialization.ContextInjected); + } + + // act + var result = await executor.ExecuteAsync( + """ + { + blogs { + contextInjected + } + } + """); + + // assert + var operationResult = result.ExpectOperationResult(); + Assert.Empty(operationResult.Errors ?? []); + + using var document = JsonDocument.Parse(result.ToJson()); + var contextInjected = document + .RootElement + .GetProperty("data") + .GetProperty("blogs")[0] + .GetProperty("contextInjected") + .GetBoolean(); + + Assert.True(contextInjected); + } + + public sealed class Issue5449Query + { + [UseProjection] + public IQueryable GetBlogs(Issue5449Context context) + => context.Blogs; + } + + public sealed class Issue5449Context(DbContextOptions options) : DbContext(options) + { + public DbSet Blogs => Set(); + } + + public sealed class Issue5449Blog + { + private readonly Issue5449Context? _context; + + public Issue5449Blog() + { + Name = string.Empty; + } + + public Issue5449Blog(string name) + { + Name = name; + } + + private Issue5449Blog(Issue5449Context context) + { + _context = context; + Name = string.Empty; + } + + public int Id { get; set; } + + public string Name { get; set; } + + [NotMapped] + public bool ContextInjected => _context is not null; + } +}