diff --git a/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableAscendingSortOperationHandler.cs b/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableAscendingSortOperationHandler.cs index 4efecd24f23..627a78ad34e 100644 --- a/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableAscendingSortOperationHandler.cs +++ b/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableAscendingSortOperationHandler.cs @@ -13,9 +13,7 @@ protected override QueryableSortOperation HandleOperation( QueryableFieldSelector fieldSelector, ISortField field, SortEnumValue? sortEnumValue) - { - return AscendingSortOperation.From(fieldSelector); - } + => AscendingSortOperation.From(fieldSelector); public static QueryableAscendingSortOperationHandler Create(SortProviderContext context) => new(); @@ -28,6 +26,30 @@ private AscendingSortOperation(QueryableFieldSelector fieldSelector) public override Expression CompileOrderBy(Expression expression) { + // We try to push the sort through any .Select() projection so the database can sort + // before projecting. If that works, we apply the sort on the source and re-attach the projection. + if (QueryableSortExpressionOptimizer.TryRewriteSelectorToSource( + expression, + ParameterExpression, + Selector, + out var rewrittenSource, + out var rewrittenSelector, + out var projection)) + { + var sortedSource = Expression.Call( + rewrittenSource.GetEnumerableKind(), + nameof(Queryable.OrderBy), + [rewrittenSelector.Parameters[0].Type, rewrittenSelector.ReturnType], + rewrittenSource, + rewrittenSelector); + + return QueryableSortExpressionOptimizer.ReapplyProjection( + sortedSource, + projection); + } + + // If the optimization is not possible, we fall back to a plain OrderBy on + // the expression as-is. return Expression.Call( expression.GetEnumerableKind(), nameof(Queryable.OrderBy), @@ -38,6 +60,30 @@ public override Expression CompileOrderBy(Expression expression) public override Expression CompileThenBy(Expression expression) { + // We try to push the sort through any .Select() projection so the database can sort + // before projecting. If that works, we apply the sort on the source and re-attach the projection. + if (QueryableSortExpressionOptimizer.TryRewriteSelectorToSource( + expression, + ParameterExpression, + Selector, + out var rewrittenSource, + out var rewrittenSelector, + out var projection)) + { + var sortedSource = Expression.Call( + rewrittenSource.GetEnumerableKind(), + nameof(Queryable.ThenBy), + [rewrittenSelector.Parameters[0].Type, rewrittenSelector.ReturnType], + rewrittenSource, + rewrittenSelector); + + return QueryableSortExpressionOptimizer.ReapplyProjection( + sortedSource, + projection); + } + + // If the optimization is not possible, we fall back to a plain ThenBy on + // the expression as-is. return Expression.Call( expression.GetEnumerableKind(), nameof(Queryable.ThenBy), @@ -46,7 +92,7 @@ public override Expression CompileThenBy(Expression expression) Expression.Lambda(Selector, ParameterExpression)); } - public static AscendingSortOperation From(QueryableFieldSelector selector) => - new AscendingSortOperation(selector); + public static AscendingSortOperation From(QueryableFieldSelector selector) + => new AscendingSortOperation(selector); } } diff --git a/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableDescendingSortOperationHandler.cs b/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableDescendingSortOperationHandler.cs index ae7cf4fd464..1b52a4c1ef0 100644 --- a/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableDescendingSortOperationHandler.cs +++ b/src/HotChocolate/Data/src/Data/Sorting/Expressions/Handlers/QueryableDescendingSortOperationHandler.cs @@ -13,9 +13,7 @@ protected override QueryableSortOperation HandleOperation( QueryableFieldSelector fieldSelector, ISortField field, SortEnumValue? sortEnumValue) - { - return DescendingSortOperation.From(fieldSelector); - } + => DescendingSortOperation.From(fieldSelector); public static QueryableDescendingSortOperationHandler Create(SortProviderContext context) => new(); @@ -28,6 +26,30 @@ private DescendingSortOperation(QueryableFieldSelector fieldSelector) public override Expression CompileOrderBy(Expression expression) { + // We try to push the sort through any .Select() projection so the database can sort + // before projecting. If that works, we apply the sort on the source and re-attach the projection. + if (QueryableSortExpressionOptimizer.TryRewriteSelectorToSource( + expression, + ParameterExpression, + Selector, + out var rewrittenSource, + out var rewrittenSelector, + out var projection)) + { + var sortedSource = Expression.Call( + rewrittenSource.GetEnumerableKind(), + nameof(Queryable.OrderByDescending), + [rewrittenSelector.Parameters[0].Type, rewrittenSelector.ReturnType], + rewrittenSource, + rewrittenSelector); + + return QueryableSortExpressionOptimizer.ReapplyProjection( + sortedSource, + projection); + } + + // If the optimization is not possible, we fall back to a plain OrderByDescending on + // the expression as-is. return Expression.Call( expression.GetEnumerableKind(), nameof(Queryable.OrderByDescending), @@ -38,6 +60,30 @@ public override Expression CompileOrderBy(Expression expression) public override Expression CompileThenBy(Expression expression) { + // We try to push the sort through any .Select() projection so the database can sort + // before projecting. If that works, we apply the sort on the source and re-attach the projection. + if (QueryableSortExpressionOptimizer.TryRewriteSelectorToSource( + expression, + ParameterExpression, + Selector, + out var rewrittenSource, + out var rewrittenSelector, + out var projection)) + { + var sortedSource = Expression.Call( + rewrittenSource.GetEnumerableKind(), + nameof(Queryable.ThenByDescending), + [rewrittenSelector.Parameters[0].Type, rewrittenSelector.ReturnType], + rewrittenSource, + rewrittenSelector); + + return QueryableSortExpressionOptimizer.ReapplyProjection( + sortedSource, + projection); + } + + // If the optimization is not possible, we fall back to a plain ThenByDescending on + // the expression as-is. return Expression.Call( expression.GetEnumerableKind(), nameof(Queryable.ThenByDescending), @@ -46,7 +92,7 @@ public override Expression CompileThenBy(Expression expression) Expression.Lambda(Selector, ParameterExpression)); } - public static DescendingSortOperation From(QueryableFieldSelector selector) => - new DescendingSortOperation(selector); + public static DescendingSortOperation From(QueryableFieldSelector selector) + => new DescendingSortOperation(selector); } } diff --git a/src/HotChocolate/Data/src/Data/Sorting/Expressions/QueryableSortExpressionOptimizer.cs b/src/HotChocolate/Data/src/Data/Sorting/Expressions/QueryableSortExpressionOptimizer.cs new file mode 100644 index 00000000000..e8f7746b6ac --- /dev/null +++ b/src/HotChocolate/Data/src/Data/Sorting/Expressions/QueryableSortExpressionOptimizer.cs @@ -0,0 +1,260 @@ +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using System.Reflection; + +namespace HotChocolate.Data.Sorting.Expressions; + +internal static class QueryableSortExpressionOptimizer +{ + private static readonly PropertyInfo s_dateTimeOffsetDateTime = + typeof(DateTimeOffset).GetProperty(nameof(DateTimeOffset.DateTime))!; + + /// + /// Tries to push a sort selector through a .Select(...) projection so that sorting + /// happens on the original source instead of the projected type. This allows the database + /// to apply the sort before the projection, which produces more efficient SQL. + /// + /// The query expression, expected to be a .Select(...) call. + /// The parameter the sort selector operates on (the projected type). + /// The sort selector expression to rewrite. + /// The original source before the .Select(...), if successful. + /// The sort selector rewritten to operate on the original source type, if successful. + /// The original .Select(...) lambda so it can be re-applied after sorting, if successful. + /// true if the rewrite succeeded; otherwise false. + public static bool TryRewriteSelectorToSource( + Expression source, + ParameterExpression selectorParameter, + Expression selector, + [NotNullWhen(true)] out Expression? rewrittenSource, + [NotNullWhen(true)] out LambdaExpression? rewrittenSelector, + [NotNullWhen(true)] out LambdaExpression? projection) + { + rewrittenSource = null; + rewrittenSelector = null; + projection = null; + + // We only proceed if the source is a .Select() call whose projection produces the same type + // that the sort selector operates on. Anything else can't be optimized. + if (source is not MethodCallExpression selectCall + || !IsSelectMethod(selectCall.Method) + || selectCall.Arguments.Count != 2 + || TryExtractLambda(selectCall.Arguments[1]) is not { Parameters.Count: 1 } selectLambda + || selectLambda.ReturnType != selectorParameter.Type) + { + return false; + } + + // Next, we try to trace the sort expression back through the projection to find the equivalent + // expression on the original source. If we can't, there is nothing we can do here. + if (!TryRewriteProjectedExpression( + selector, + selectorParameter, + selectLambda.Body, + out var sourceSelector)) + { + return false; + } + + // Finally, we strip any .DateTime access off DateTimeOffset fields so the sort translates + // cleanly to SQL, then package everything up and return success. + sourceSelector = DateTimeOffsetDateTimeExpressionVisitor.Rewrite(sourceSelector); + rewrittenSource = selectCall.Arguments[0]; + rewrittenSelector = Expression.Lambda(sourceSelector, selectLambda.Parameters[0]); + projection = selectLambda; + return true; + } + + /// + /// Re-applies the original .Select(...) projection on top of + /// after the sort has been pushed down to the underlying source query. + /// + /// The sorted source expression to project over. + /// The original Select lambda captured from . + /// A new expression equivalent to source.Select(projection). + public static Expression ReapplyProjection( + Expression source, + LambdaExpression projection) + => Expression.Call( + source.GetEnumerableKind(), + nameof(Queryable.Select), + [projection.Parameters[0].Type, projection.ReturnType], + source, + projection); + + private static bool IsSelectMethod(MethodInfo method) + => method.Name.Equals(nameof(Queryable.Select), StringComparison.Ordinal) + && (method.DeclaringType == typeof(Queryable) + || method.DeclaringType == typeof(Enumerable)); + + private static LambdaExpression? TryExtractLambda(Expression expression) + => expression switch + { + UnaryExpression { NodeType: ExpressionType.Quote, Operand: LambdaExpression lambda } => lambda, + LambdaExpression lambda => lambda, + _ => null + }; + + private static bool TryRewriteProjectedExpression( + Expression expression, + ParameterExpression selectorParameter, + Expression projection, + [NotNullWhen(true)] out Expression? rewritten) + { + if (expression == selectorParameter) + { + rewritten = projection; + return true; + } + + // We handle member access (e.g. dto.Name) by recursively rewriting the parent expression + // and then looking up which source expression was assigned to that member in the projection. + if (expression is MemberExpression memberExpression) + { + if (memberExpression.Expression is null + || !TryRewriteProjectedExpression( + memberExpression.Expression, + selectorParameter, + projection, + out var parentExpression) + || !TryBindMember(parentExpression, memberExpression.Member, out rewritten)) + { + rewritten = null; + return false; + } + + return true; + } + + // Next, we handle type casts by rewriting the inner operand and then rebuilding the cast + // around the rewritten expression. + if (expression is UnaryExpression unaryExpression + && (unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.ConvertChecked)) + { + if (!TryRewriteProjectedExpression( + unaryExpression.Operand, + selectorParameter, + projection, + out var operand)) + { + rewritten = null; + return false; + } + + rewritten = Expression.MakeUnary( + unaryExpression.NodeType, + operand, + unaryExpression.Type, + unaryExpression.Method); + return true; + } + + if (!ParameterExpressionVisitor.Contains(expression, selectorParameter)) + { + rewritten = expression; + return true; + } + + rewritten = null; + return false; + } + + private static bool TryBindMember( + Expression source, + MemberInfo member, + [NotNullWhen(true)] out Expression? rewritten) + { + // We first check for object initializer expressions (new Foo { Name = ... }) and look + // for a binding that matches the member name. + if (source is MemberInitExpression memberInit) + { + foreach (var binding in memberInit.Bindings) + { + if (binding is MemberAssignment assignment + && binding.Member.Name.Equals(member.Name, StringComparison.Ordinal)) + { + rewritten = assignment.Expression; + return true; + } + } + + rewritten = null; + return false; + } + + // Next, we check for constructor expressions (new Foo(...)) and match the member by name + // against the constructor parameters. + if (source is NewExpression { Members: not null } newExpression) + { + for (var i = 0; i < newExpression.Members!.Count; i++) + { + if (newExpression.Members[i].Name.Equals(member.Name, StringComparison.Ordinal)) + { + rewritten = newExpression.Arguments[i]; + return true; + } + } + + rewritten = null; + return false; + } + + // Finally, if the source type directly exposes the member, we just access it directly. + if (member.DeclaringType?.IsAssignableFrom(source.Type) ?? false) + { + rewritten = Expression.MakeMemberAccess(source, member); + return true; + } + + rewritten = null; + return false; + } + + private sealed class DateTimeOffsetDateTimeExpressionVisitor : ExpressionVisitor + { + protected override Expression VisitMember(MemberExpression node) + { + if (node.Member == s_dateTimeOffsetDateTime + && node.Expression is not null + && node.Expression.Type == typeof(DateTimeOffset)) + { + return Visit(node.Expression); + } + + return base.VisitMember(node); + } + + public static Expression Rewrite(Expression expression) + => new DateTimeOffsetDateTimeExpressionVisitor().Visit(expression); + } + + private sealed class ParameterExpressionVisitor(ParameterExpression parameter) : ExpressionVisitor + { + private readonly ParameterExpression _parameter = parameter; + + public bool ContainsParameter { get; private set; } + + public override Expression? Visit(Expression? node) + { + if (ContainsParameter || node is null) + { + return node; + } + + return base.Visit(node); + } + + protected override Expression VisitParameter(ParameterExpression node) + { + ContainsParameter = node == _parameter; + return node; + } + + public static bool Contains(Expression expression, ParameterExpression parameter) + { + var visitor = new ParameterExpressionVisitor(parameter); + visitor.Visit(expression); + return visitor.ContainsParameter; + } + } +} diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/DateTimeOffsetSortingTests.cs b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/DateTimeOffsetSortingTests.cs new file mode 100644 index 00000000000..ec082a363f9 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/DateTimeOffsetSortingTests.cs @@ -0,0 +1,109 @@ +using System.Text.Json; +using HotChocolate.Execution; +using HotChocolate.Types; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Squadron; + +namespace HotChocolate.Data; + +[Collection(PostgresCacheCollectionFixture.DefinitionName)] +public sealed class DateTimeOffsetSortingTests(PostgreSqlResource resource) +{ + [Fact] + public async Task Sort_Projected_DateTime_From_DateTimeOffset() + { + // arrange + var db = "db_" + Guid.NewGuid().ToString("N"); + var connectionString = resource.GetConnectionString(db); + + await using var services = new ServiceCollection() + .AddDbContext(c => c.UseNpgsql(connectionString)) + .AddGraphQLServer() + .AddSorting() + .ModifyRequestOptions(o => o.IncludeExceptionDetails = true) + .AddQueryType( + d => d + .Name("Query") + .Field("events") + .Type>>() + .Resolve(ctx => ctx.Service().Events + .Select(x => new ProjectedEvent + { + Timestamp = x.Timestamp.DateTime + }) + .AsExecutable()) + .UseSorting()) + .Services + .BuildServiceProvider(); + + await using var scope = services.CreateAsyncScope(); + await using var context = scope.ServiceProvider.GetRequiredService(); + + await context.Database.EnsureCreatedAsync(); + + context.Events.AddRange( + new EventEntity + { + Timestamp = new DateTimeOffset(2025, 11, 14, 8, 0, 0, TimeSpan.Zero) + }, + new EventEntity + { + Timestamp = new DateTimeOffset(2025, 11, 14, 13, 30, 0, TimeSpan.Zero) + }, + new EventEntity + { + Timestamp = new DateTimeOffset(2025, 11, 15, 17, 0, 0, TimeSpan.Zero) + }); + + await context.SaveChangesAsync(); + + var executor = await services + .GetRequiredService() + .GetExecutorAsync(); + + // act + var result = await executor.ExecuteAsync( + """ + { + events(order: [{ timestamp: DESC }]) { + timestamp + } + } + """); + + // assert + using var document = JsonDocument.Parse(result.ToJson()); + Assert.False( + document.RootElement.TryGetProperty("errors", out _), + result.ToJson()); + + var values = document.RootElement + .GetProperty("data") + .GetProperty("events") + .EnumerateArray() + .Select(t => DateTimeOffset.Parse(t.GetProperty("timestamp").GetString()!)) + .ToArray(); + + var sorted = values.OrderByDescending(t => t).ToArray(); + Assert.Equal(sorted, values); + } + + private sealed class EventContext(DbContextOptions options) + : DbContext(options) + { + public DbSet Events { get; set; } = null!; + } + + private sealed class EventEntity + { + public int Id { get; set; } + + public DateTimeOffset Timestamp { get; set; } + } + + private sealed record ProjectedEvent + { + public DateTime Timestamp { get; init; } + } +} diff --git a/src/HotChocolate/Data/test/Data.Sorting.SqlLite.Tests/QueryableSortVisitorDateTimeTests.cs b/src/HotChocolate/Data/test/Data.Sorting.SqlLite.Tests/QueryableSortVisitorDateTimeTests.cs new file mode 100644 index 00000000000..a1d74ea8c6b --- /dev/null +++ b/src/HotChocolate/Data/test/Data.Sorting.SqlLite.Tests/QueryableSortVisitorDateTimeTests.cs @@ -0,0 +1,110 @@ +using System.Text.Json; +using HotChocolate.Execution; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace HotChocolate.Data.Sorting; + +public sealed class QueryableSortVisitorDateTimeTests +{ + [Fact] + public async Task Sort_Projected_DateTime_From_DateTimeOffset() + { + // arrange + var databaseName = $"{Guid.NewGuid():N}.db"; + var executor = await new ServiceCollection() + .AddDbContext(b => b.UseSqlite($"Data Source={databaseName}")) + .AddGraphQL() + .AddSorting() + .ModifyRequestOptions(o => o.IncludeExceptionDetails = true) + .AddQueryType() + .BuildRequestExecutorAsync(); + + // act + var result = await executor.ExecuteAsync( + """ + { + events(order: [{ timestamp: DESC }]) { + timestamp + } + } + """); + + // assert + var json = result.ToJson(); + using var document = JsonDocument.Parse(json); + + Assert.True(document.RootElement.TryGetProperty("errors", out _), json); + Assert.Contains("DateTimeOffset", json, StringComparison.Ordinal); + Assert.DoesNotContain("Timestamp.DateTime", json, StringComparison.Ordinal); + } + + public sealed class Query + { + [UseSorting] + public IExecutable GetEvents(EventContext context) + { + SeedData(context); + + return context.Events + .Select(x => new ProjectedEvent + { + Timestamp = x.Timestamp.DateTime + }) + .AsExecutable(); + } + + private static void SeedData(EventContext context) + { + context.Database.EnsureCreated(); + + if (context.Events.Any()) + { + return; + } + + context.Events.AddRange( + new EventEntity + { + Timestamp = new DateTimeOffset(2025, 11, 14, 9, 0, 0, TimeSpan.FromHours(1)) + }, + new EventEntity + { + Timestamp = new DateTimeOffset(2025, 11, 14, 14, 30, 0, TimeSpan.FromHours(1)) + }, + new EventEntity + { + Timestamp = new DateTimeOffset(2025, 11, 15, 18, 0, 0, TimeSpan.FromHours(1)) + }); + + context.SaveChanges(); + } + } + + public sealed class EventContext(DbContextOptions options) + : DbContext(options) + { + public DbSet Events { get; set; } = null!; + } + + public sealed class EventEntity + { + public int Id { get; set; } + + public DateTimeOffset Timestamp { get; set; } + } + + public sealed record ProjectedEvent + { + public DateTime Timestamp { get; init; } + } + + public sealed class EventSortType : SortInputType + { + protected override void Configure(ISortInputTypeDescriptor descriptor) + { + descriptor.BindFieldsExplicitly(); + descriptor.Field(f => f.Timestamp); + } + } +}