Skip to content

Commit

Permalink
Fix to #12351 - Incorrect SQL generated for Count over Group By (EFCo…
Browse files Browse the repository at this point in the history
…re 2.1)

There was two problems here:
- we were not lifting a select expression that had a GroupBy-Aggregate pattern, and another result operator was composed on top,
- we were not propagating RequiresStreamingGroupBy flag to a parent QM, when we lifted a client group by subquery, which could result in the client methods being wiped out if Count/LongCount was composed on top.
  • Loading branch information
maumar committed Jun 20, 2018
1 parent 0dc07a6 commit eb07c40
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ private static Expression HandleAll(HandlerContext handlerContext)
var sqlTranslatingVisitor
= handlerContext.CreateSqlTranslatingVisitor();

PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

var predicate
= sqlTranslatingVisitor.Visit(
Expand Down Expand Up @@ -224,7 +224,7 @@ private static Expression HandleAverage(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

var expression = handlerContext.SelectExpression.Projection.First();

Expand Down Expand Up @@ -344,7 +344,7 @@ var collectionSelectExpression

private static Expression HandleCount(HandlerContext handlerContext)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

handlerContext.SelectExpression
.SetProjectionExpression(
Expand Down Expand Up @@ -480,8 +480,7 @@ var sqlExpression
&& shapedQueryMethod.Method.MethodIsClosedFormOf(
handlerContext.QueryModelVisitor.QueryCompilationContext.QueryMethodProvider.ShapedQueryMethod))
{
var selectExpression = handlerContext.SelectExpression;
PrepareSelectExpressionForAggregate(selectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

// GroupBy Aggregate
// TODO: InjectParameters type Expression.
Expand Down Expand Up @@ -518,6 +517,8 @@ var sqlExpression
break;
}

var selectExpression = handlerContext.SelectExpression;

if (key != null
|| groupResultOperator.KeySelector is ConstantExpression
|| groupResultOperator.KeySelector is ParameterExpression)
Expand Down Expand Up @@ -594,16 +595,14 @@ var sqlExpression

if (sqlExpression != null)
{
var selectExpression = handlerContext.SelectExpression;

PrepareSelectExpressionForAggregate(selectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

sqlExpression
= sqlTranslatingExpressionVisitor.Visit(groupResultOperator.KeySelector);

var columns = (sqlExpression as ConstantExpression)?.Value as Expression[] ?? new[] { sqlExpression };

selectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));
handlerContext.SelectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));

handlerContext.QueryModelVisitor.RequiresStreamingGroupResultOperator = true;
}
Expand Down Expand Up @@ -834,7 +833,7 @@ private static Expression HandleLast(HandlerContext handlerContext)

private static Expression HandleLongCount(HandlerContext handlerContext)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

handlerContext.SelectExpression
.SetProjectionExpression(
Expand All @@ -853,7 +852,7 @@ private static Expression HandleMin(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();

if (!(expression.RemoveConvert() is SelectExpression))
Expand Down Expand Up @@ -882,7 +881,7 @@ private static Expression HandleMax(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();

if (!(expression.RemoveConvert() is SelectExpression))
Expand Down Expand Up @@ -968,7 +967,7 @@ private static Expression HandleSum(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();

if (!(expression.RemoveConvert() is SelectExpression))
Expand Down Expand Up @@ -1032,17 +1031,23 @@ private static void SetConditionAsProjection(
typeof(bool)));
}

private static void PrepareSelectExpressionForAggregate(SelectExpression selectExpression)
private static void PrepareSelectExpressionForAggregate(HandlerContext handlerContext)
{
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
if (handlerContext.SelectExpression.IsDistinct
|| handlerContext.SelectExpression.Limit != null
|| handlerContext.SelectExpression.Offset != null
|| (handlerContext.SelectExpression.GroupBy.Any()
&& !IsGroupByAggregate(handlerContext.QueryModel)))
{
selectExpression.PushDownSubquery();
selectExpression.ExplodeStarProjection();
handlerContext.SelectExpression.PushDownSubquery();
handlerContext.SelectExpression.ExplodeStarProjection();
}
}

private static bool IsGroupByAggregate(QueryModel queryModel)
=> queryModel.MainFromClause.FromExpression is QuerySourceReferenceExpression mainFromClauseQsre
&& mainFromClauseQsre.ReferencedQuerySource.ItemType.IsGrouping();

private static Expression UnwrapAliasExpression(Expression expression)
=> (expression as AliasExpression)?.Expression ?? expression;

Expand Down
2 changes: 2 additions & 0 deletions src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,8 @@ var subQueryModelVisitor
var subSelectExpression = subQueryModelVisitor.Queries.First();
AddQuery(querySource, subSelectExpression);

RequiresStreamingGroupResultOperator = true;

return subQueryModelVisitor.Expression;
}

Expand Down
19 changes: 19 additions & 0 deletions src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1990,5 +1990,24 @@ public virtual async Task Double_GroupBy_with_aggregate()

#endregion

#region ResultOperatorsAfterGroupBy

[ConditionalFact]
public virtual async Task Count_after_GroupBy()
{
await AssertSingleResult<Order>(
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync());
}

[ConditionalFact]
public virtual async Task LongCount_after_client_GroupBy()
{
await AssertSingleResult<Order>(
os => (from o in os
group o by new { o.CustomerID } into g
select g.Where(e => e.OrderID < 10300).Count()).LongCountAsync());
}

#endregion
}
}
20 changes: 20 additions & 0 deletions src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1994,5 +1994,25 @@ public virtual void Double_GroupBy_with_aggregate()
}

#endregion

#region ResultOperatorsAfterGroupBy

[ConditionalFact]
public virtual void Count_after_GroupBy()
{
AssertSingleResult<Order>(
os =>os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count());
}

[ConditionalFact]
public virtual void LongCount_after_client_GroupBy()
{
AssertSingleResult<Order>(
os => (from o in os
group o by new { o.CustomerID } into g
select g.Where(e => e.OrderID < 10300).Count()).LongCount());
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,29 @@ FROM [Orders] AS [o0]
ORDER BY [o0].[OrderID], [o0].[OrderDate]");
}

public override void Count_after_GroupBy()
{
base.Count_after_GroupBy();

AssertSql(
@"SELECT COUNT(*)
FROM (
SELECT SUM([o].[OrderID]) AS [c]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
) AS [t]");
}

public override void LongCount_after_client_GroupBy()
{
base.LongCount_after_client_GroupBy();

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
ORDER BY [o].[CustomerID]");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit eb07c40

Please sign in to comment.