Skip to content

Commit

Permalink
Query: Translate ToList over grouping element
Browse files Browse the repository at this point in the history
Part 2 & 3 of #26046
  • Loading branch information
smitpatel committed Sep 15, 2021
1 parent 24998f0 commit 72f83a6
Show file tree
Hide file tree
Showing 13 changed files with 355 additions and 144 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class QueryExpressionReplacingExpressionVisitor : ExpressionVisitor
{
private readonly Expression _oldQuery;
private readonly Expression _newQuery;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public QueryExpressionReplacingExpressionVisitor(Expression oldQuery, Expression newQuery)
{
_oldQuery = oldQuery;
_newQuery = newQuery;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
return expression is ProjectionBindingExpression projectionBindingExpression
&& ReferenceEquals(projectionBindingExpression.QueryExpression, _oldQuery)
? projectionBindingExpression.ProjectionMember != null
? new ProjectionBindingExpression(
_newQuery, projectionBindingExpression.ProjectionMember!, projectionBindingExpression.Type)
: new ProjectionBindingExpression(
_newQuery, projectionBindingExpression.Index!.Value, projectionBindingExpression.Type)
: base.Visit(expression);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -464,29 +463,7 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
source = TranslateSelect(source, elementSelector);
}

if (translatedKey is NewExpression newExpression
&& newExpression.Arguments.Count == 0)
{
selectExpression.ApplyGrouping(_sqlExpressionFactory.ApplyDefaultTypeMapping(_sqlExpressionFactory.Constant(1)));
}
else
{
translatedKey = selectExpression.ApplyGrouping(translatedKey);
}
var clonedSelectExpression = selectExpression.Clone();
// If the grouping key is empty then there may not be any group by terms.
var correlationPredicate = selectExpression.GroupBy.Zip(clonedSelectExpression.GroupBy)
.Select(e => _sqlExpressionFactory.Equal(e.First, e.Second))
.Aggregate((l, r) => _sqlExpressionFactory.AndAlso(l, r));
clonedSelectExpression.ClearGroupBy();
clonedSelectExpression.ApplyPredicate(correlationPredicate);

var groupByShaper = new GroupByShaperExpression(
translatedKey,
new ShapedQueryExpression(
clonedSelectExpression,
new QueryExpressionReplacingExpressionVisitor(selectExpression, clonedSelectExpression).Visit(source.ShaperExpression)));

var groupByShaper = selectExpression.ApplyGrouping(translatedKey, source.ShaperExpression, _sqlExpressionFactory);
if (resultSelector == null)
{
return source.UpdateShaperExpression(groupByShaper);
Expand Down Expand Up @@ -1697,30 +1674,5 @@ static void PopulatePredicateTerms(SqlExpression predicate, List<SqlExpression>
}
}
}

private sealed class QueryExpressionReplacingExpressionVisitor : ExpressionVisitor
{
private readonly Expression _oldQuery;
private readonly Expression _newQuery;

public QueryExpressionReplacingExpressionVisitor(Expression oldQuery, Expression newQuery)
{
_oldQuery = oldQuery;
_newQuery = newQuery;
}

[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
return expression is ProjectionBindingExpression projectionBindingExpression
&& ReferenceEquals(projectionBindingExpression.QueryExpression, _oldQuery)
? projectionBindingExpression.ProjectionMember != null
? new ProjectionBindingExpression(
_newQuery, projectionBindingExpression.ProjectionMember!, projectionBindingExpression.Type)
: new ProjectionBindingExpression(
_newQuery, projectionBindingExpression.Index!.Value, projectionBindingExpression.Type)
: base.Visit(expression);
}
}
}
}
90 changes: 78 additions & 12 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
Expand Down Expand Up @@ -1126,15 +1127,15 @@ public void ApplyPredicate(SqlExpression sqlExpression)
/// Applies grouping from given key selector.
/// </summary>
/// <param name="keySelector"> An key selector expression for the GROUP BY. </param>
public Expression ApplyGrouping(Expression keySelector)
public void ApplyGrouping(Expression keySelector)
{
Check.NotNull(keySelector, nameof(keySelector));

ClearOrdering();

var groupByTerms = new List<SqlExpression>();
var groupByAliases = new List<string?>();
AppendGroupBy(keySelector, groupByTerms, groupByAliases, "Key");
PopulateGroupByTerms(keySelector, groupByTerms, groupByAliases, "Key");

if (groupByTerms.Any(e => e is SqlConstantExpression || e is SqlParameterExpression || e is ScalarSubqueryExpression))
{
Expand Down Expand Up @@ -1163,19 +1164,84 @@ public Expression ApplyGrouping(Expression keySelector)
_identifier.AddRange(_groupBy.Select(e => ((ColumnExpression)e, e.TypeMapping!.KeyComparer)));
}
}

return keySelector;
}

/// <summary>
/// Clears existing group by terms.
/// Applies grouping from given key selector and generate <see cref="GroupByShaperExpression"/> to shape results.
/// </summary>
public void ClearGroupBy()
/// <param name="keySelector"> An key selector expression for the GROUP BY. </param>
/// <param name="shaperExpression"> The shaper expression for current query. </param>
/// <param name="sqlExpressionFactory"> The sql expression factory to use. </param>
/// <returns> A <see cref="GroupByShaperExpression"/> which represents the result of the grouping operation. </returns>
public GroupByShaperExpression ApplyGrouping(Expression keySelector, Expression shaperExpression, ISqlExpressionFactory sqlExpressionFactory)
{
_groupBy.Clear();
Check.NotNull(keySelector, nameof(keySelector));

ClearOrdering();

var keySelectorToAdd = keySelector;
var emptyKey = keySelector is NewExpression newExpression
&& newExpression.Arguments.Count == 0;
if (emptyKey)
{
keySelectorToAdd = sqlExpressionFactory.ApplyDefaultTypeMapping(sqlExpressionFactory.Constant(1));
}

var groupByTerms = new List<SqlExpression>();
var groupByAliases = new List<string?>();
PopulateGroupByTerms(keySelectorToAdd, groupByTerms, groupByAliases, "Key");

if (groupByTerms.Any(e => e is SqlConstantExpression || e is SqlParameterExpression || e is ScalarSubqueryExpression))
{
// EmptyKey will always hit this path.
var sqlRemappingVisitor = PushdownIntoSubqueryInternal();
var newGroupByTerms = new List<SqlExpression>(groupByTerms.Count);
var subquery = (SelectExpression)_tables[0];
var subqueryTableReference = _tableReferences[0];
for (var i = 0; i < groupByTerms.Count; i++)
{
var item = groupByTerms[i];
var newItem = subquery._projection.Any(e => e.Expression.Equals(item))
? sqlRemappingVisitor.Remap(item)
: subquery.GenerateOuterColumn(subqueryTableReference, item, groupByAliases[i] ?? "Key");
newGroupByTerms.Add(newItem);
}
if (!emptyKey)
{
// If non-empty key then we need to regenerate the key selector
keySelector = new ReplacingExpressionVisitor(groupByTerms, newGroupByTerms).Visit(keySelector);
}
groupByTerms = newGroupByTerms;
}

_groupBy.AddRange(groupByTerms);

// We generate the cloned expression before changing identifier for this SelectExpression
// because we are going to erase grouping for cloned expression.
var clonedSelectExpression = Clone();
var correlationPredicate = groupByTerms.Zip(clonedSelectExpression._groupBy)
.Select(e => sqlExpressionFactory.Equal(e.First, e.Second))
.Aggregate((l, r) => sqlExpressionFactory.AndAlso(l, r));
clonedSelectExpression._groupBy.Clear();
clonedSelectExpression.ApplyPredicate(correlationPredicate);

if (!_identifier.All(e => _groupBy.Contains(e.Column)))
{
_identifier.Clear();
if (_groupBy.All(e => e is ColumnExpression))
{
_identifier.AddRange(_groupBy.Select(e => ((ColumnExpression)e, e.TypeMapping!.KeyComparer)));
}
}

return new GroupByShaperExpression(
keySelector,
new ShapedQueryExpression(
clonedSelectExpression,
new QueryExpressionReplacingExpressionVisitor(this, clonedSelectExpression).Visit(shaperExpression)));
}

private void AppendGroupBy(Expression keySelector, List<SqlExpression> groupByTerms, List<string?> groupByAliases, string? name)
private void PopulateGroupByTerms(Expression keySelector, List<SqlExpression> groupByTerms, List<string?> groupByAliases, string? name)
{
Check.NotNull(keySelector, nameof(keySelector));

Expand All @@ -1189,23 +1255,23 @@ private void AppendGroupBy(Expression keySelector, List<SqlExpression> groupByTe
case NewExpression newExpression:
for (var i = 0; i < newExpression.Arguments.Count; i++)
{
AppendGroupBy(newExpression.Arguments[i], groupByTerms, groupByAliases, newExpression.Members?[i].Name);
PopulateGroupByTerms(newExpression.Arguments[i], groupByTerms, groupByAliases, newExpression.Members?[i].Name);
}
break;

case MemberInitExpression memberInitExpression:
AppendGroupBy(memberInitExpression.NewExpression, groupByTerms, groupByAliases, null);
PopulateGroupByTerms(memberInitExpression.NewExpression, groupByTerms, groupByAliases, null);
foreach (var argument in memberInitExpression.Bindings)
{
var memberAssignment = (MemberAssignment)argument;
AppendGroupBy(memberAssignment.Expression, groupByTerms, groupByAliases, memberAssignment.Member.Name);
PopulateGroupByTerms(memberAssignment.Expression, groupByTerms, groupByAliases, memberAssignment.Member.Name);
}
break;

case UnaryExpression unaryExpression
when unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked:
AppendGroupBy(unaryExpression.Operand, groupByTerms, groupByAliases, name);
PopulateGroupByTerms(unaryExpression.Operand, groupByTerms, groupByAliases, name);
break;

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,9 @@ public GroupingElementReplacingExpressionVisitor(
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable
&& (methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable
|| methodCallExpression.Method.GetGenericMethodDefinition() == EnumerableMethods.ToList
|| methodCallExpression.Method.GetGenericMethodDefinition() == EnumerableMethods.ToArray)
&& methodCallExpression.Arguments[0] == _parameterExpression)
{
var currentTree = _cloningExpressionVisitor.Clone(_navigationExpansionExpression.CurrentTree);
Expand Down
Loading

0 comments on commit 72f83a6

Please sign in to comment.