Skip to content

Commit

Permalink
Fix to #12351 - Incorrect SQL generated for Count over Group By (EFCo…
Browse files Browse the repository at this point in the history
…re 2.1)

There was two problems here:
- we were not lifting a select expression that had a GroupBy-Aggregate pattern, and another result operator was composed on top,
- we were not propagating RequiresStreamingGroupBy flag to a parent QM, when we lifted a client group by subquery, which could result in the client methods being wiped out if Count/LongCount was composed on top.
  • Loading branch information
maumar committed Jul 1, 2018
1 parent 15bc370 commit 207083c
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ private static Expression HandleAll(HandlerContext handlerContext)
var sqlTranslatingVisitor
= handlerContext.CreateSqlTranslatingVisitor();

PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

var predicate
= sqlTranslatingVisitor.Visit(
Expand Down Expand Up @@ -224,7 +224,7 @@ private static Expression HandleAverage(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

var expression = handlerContext.SelectExpression.Projection.First();

Expand Down Expand Up @@ -344,7 +344,7 @@ var collectionSelectExpression

private static Expression HandleCount(HandlerContext handlerContext)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

handlerContext.SelectExpression
.SetProjectionExpression(
Expand Down Expand Up @@ -480,8 +480,7 @@ var sqlExpression
&& shapedQueryMethod.Method.MethodIsClosedFormOf(
handlerContext.QueryModelVisitor.QueryCompilationContext.QueryMethodProvider.ShapedQueryMethod))
{
var selectExpression = handlerContext.SelectExpression;
PrepareSelectExpressionForAggregate(selectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

// GroupBy Aggregate
// TODO: InjectParameters type Expression.
Expand Down Expand Up @@ -518,6 +517,8 @@ var sqlExpression
break;
}

var selectExpression = handlerContext.SelectExpression;

if (key != null
|| groupResultOperator.KeySelector is ConstantExpression
|| groupResultOperator.KeySelector is ParameterExpression)
Expand Down Expand Up @@ -594,16 +595,14 @@ var sqlExpression

if (sqlExpression != null)
{
var selectExpression = handlerContext.SelectExpression;

PrepareSelectExpressionForAggregate(selectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

sqlExpression
= sqlTranslatingExpressionVisitor.Visit(groupResultOperator.KeySelector);

var columns = (sqlExpression as ConstantExpression)?.Value as Expression[] ?? new[] { sqlExpression };

selectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));
handlerContext.SelectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));

handlerContext.QueryModelVisitor.RequiresStreamingGroupResultOperator = true;
}
Expand Down Expand Up @@ -834,7 +833,7 @@ private static Expression HandleLast(HandlerContext handlerContext)

private static Expression HandleLongCount(HandlerContext handlerContext)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);

handlerContext.SelectExpression
.SetProjectionExpression(
Expand All @@ -853,7 +852,7 @@ private static Expression HandleMin(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();

if (!(expression.RemoveConvert() is SelectExpression))
Expand Down Expand Up @@ -882,7 +881,7 @@ private static Expression HandleMax(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();

if (!(expression.RemoveConvert() is SelectExpression))
Expand Down Expand Up @@ -971,7 +970,7 @@ private static Expression HandleSum(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();

if (!(expression.RemoveConvert() is SelectExpression))
Expand Down Expand Up @@ -1046,17 +1045,26 @@ private static void SetConditionAsProjection(
typeof(bool)));
}

private static void PrepareSelectExpressionForAggregate(SelectExpression selectExpression)
private static void PrepareSelectExpressionForAggregate(HandlerContext handlerContext)
{
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null)
var legacyBehavior12351 = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12351", out var isEnabled) && isEnabled;

if (handlerContext.SelectExpression.IsDistinct
|| handlerContext.SelectExpression.Limit != null
|| handlerContext.SelectExpression.Offset != null
|| (handlerContext.SelectExpression.GroupBy.Any()
&& !IsGroupByAggregate(handlerContext.QueryModel)
&& !legacyBehavior12351))
{
selectExpression.PushDownSubquery();
selectExpression.ExplodeStarProjection();
handlerContext.SelectExpression.PushDownSubquery();
handlerContext.SelectExpression.ExplodeStarProjection();
}
}

private static bool IsGroupByAggregate(QueryModel queryModel)
=> queryModel.MainFromClause.FromExpression is QuerySourceReferenceExpression mainFromClauseQsre
&& mainFromClauseQsre.ReferencedQuerySource.ItemType.IsGrouping();

private static Expression UnwrapAliasExpression(Expression expression)
=> (expression as AliasExpression)?.Expression ?? expression;

Expand Down
2 changes: 2 additions & 0 deletions src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,8 @@ var subQueryModelVisitor
var subSelectExpression = subQueryModelVisitor.Queries.First();
AddQuery(querySource, subSelectExpression);

RequiresStreamingGroupResultOperator = true;

return subQueryModelVisitor.Expression;
}

Expand Down
14 changes: 7 additions & 7 deletions src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -518,19 +518,19 @@ private Expression ApplyNullSemantics(Expression expression)
/// <param name="projection"> The projection expression. </param>
protected virtual void GenerateProjection([NotNull] Expression projection)
{
if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled)
//if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled)
{
Visit(
ApplyOptimizations(
ApplyExplicitCastToBoolInProjectionOptimization(projection),
searchCondition: false));
}
else
{
Visit(
ApplyExplicitCastToBoolInProjectionOptimization(
ApplyOptimizations(projection, searchCondition: false)));
}
//else
//{
// Visit(
// ApplyExplicitCastToBoolInProjectionOptimization(
// ApplyOptimizations(projection, searchCondition: false)));
//}
}

/// <summary>
Expand Down
19 changes: 19 additions & 0 deletions src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1990,5 +1990,24 @@ public virtual async Task Double_GroupBy_with_aggregate()

#endregion

#region ResultOperatorsAfterGroupBy

[ConditionalFact]
public virtual async Task Count_after_GroupBy()
{
await AssertSingleResult<Order>(
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync());
}

[ConditionalFact]
public virtual async Task LongCount_after_client_GroupBy()
{
await AssertSingleResult<Order>(
os => (from o in os
group o by new { o.CustomerID } into g
select g.Where(e => e.OrderID < 10300).Count()).LongCountAsync());
}

#endregion
}
}
40 changes: 40 additions & 0 deletions src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1994,5 +1994,45 @@ public virtual void Double_GroupBy_with_aggregate()
}

#endregion

#region ResultOperatorsAfterGroupBy

[ConditionalFact]
public virtual void Count_after_GroupBy_aggregate()
{
AssertSingleResult<Order>(
os =>os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count());
}

[ConditionalFact]
public virtual void LongCount_after_client_GroupBy()
{
AssertSingleResult<Order>(
os => (from o in os
group o by new { o.CustomerID } into g
select g.Where(e => e.OrderID < 10300).Count()).LongCount());
}

[ConditionalFact]
public virtual void MinMax_after_GroupBy_aggregate()
{
AssertSingleResult<Order>(
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Min());

AssertSingleResult<Order>(
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Max());
}

[ConditionalFact]
public virtual void AllAny_after_GroupBy_aggregate()
{
AssertSingleResult<Order>(
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).All(ee => true));

AssertSingleResult<Order>(
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Any());
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7437,24 +7437,24 @@ WHERE [g0].[Discriminator] IN (N'Officer', N'Gear')
ORDER BY [t].[c] DESC, [t].[Nickname], [t].[SquadId], [t].[FullName]");
}

[ConditionalFact]
public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior()
{
var nicknames = new List<string>();
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true);
try
{
Assert.Throws<InvalidOperationException>(
() => AssertQuery<Gear>(
gs => from g in gs
orderby nicknames.Contains(g.Nickname) descending
select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Name).ToList() }));
}
finally
{
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false);
}
}
//[ConditionalFact]
//public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior()
//{
// var nicknames = new List<string>();
// AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true);
// try
// {
// Assert.Throws<InvalidOperationException>(
// () => AssertQuery<Gear>(
// gs => from g in gs
// orderby nicknames.Contains(g.Nickname) descending
// select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Id).ToList() }));
// }
// finally
// {
// AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false);
// }
//}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Microsoft.EntityFrameworkCore.TestUtilities.Xunit;
using Xunit;
using Xunit.Abstractions;

Expand Down Expand Up @@ -1759,6 +1763,97 @@ FROM [Orders] AS [o0]
ORDER BY [o0].[OrderID], [o0].[OrderDate]");
}

public override void Count_after_GroupBy_aggregate()
{
base.Count_after_GroupBy_aggregate();

AssertSql(
@"SELECT COUNT(*)
FROM (
SELECT SUM([o].[OrderID]) AS [c]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
) AS [t]");
}

public override void LongCount_after_client_GroupBy()
{
base.LongCount_after_client_GroupBy();

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
ORDER BY [o].[CustomerID]");
}

public override void MinMax_after_GroupBy_aggregate()
{
base.MinMax_after_GroupBy_aggregate();

AssertSql(
@"SELECT MIN([t].[c])
FROM (
SELECT SUM([o].[OrderID]) AS [c]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
) AS [t]",
//
@"SELECT MAX([t].[c])
FROM (
SELECT SUM([o].[OrderID]) AS [c]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
) AS [t]");
}

public override void AllAny_after_GroupBy_aggregate()
{
base.AllAny_after_GroupBy_aggregate();

AssertSql(
@"SELECT CASE
WHEN NOT EXISTS (
SELECT 1
FROM (
SELECT SUM([o].[OrderID]) AS [c]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
) AS [t]
WHERE 0 = 1)
THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
END",
//
@"SELECT CASE
WHEN EXISTS (
SELECT 1
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID])
THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
END");
}

[ConditionalFact]
public virtual void Count_after_GroupBy_aggregate_legacy_behavior()
{
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", true);

try
{
AssertSingleResult<Order>(
os => os.OrderBy(o => o.CustomerID).GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count(),
os => 6);

AssertSql(
@"SELECT COUNT(*)
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}
finally
{
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", false);
}
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit 207083c

Please sign in to comment.