Skip to content

Commit 207083c

Browse files
committed
Fix to #12351 - Incorrect SQL generated for Count over Group By (EFCore 2.1)
There was two problems here: - we were not lifting a select expression that had a GroupBy-Aggregate pattern, and another result operator was composed on top, - we were not propagating RequiresStreamingGroupBy flag to a parent QM, when we lifted a client group by subquery, which could result in the client methods being wiped out if Count/LongCount was composed on top.
1 parent 15bc370 commit 207083c

File tree

7 files changed

+208
-44
lines changed

7 files changed

+208
-44
lines changed

src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs

+27-19
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ private static Expression HandleAll(HandlerContext handlerContext)
169169
var sqlTranslatingVisitor
170170
= handlerContext.CreateSqlTranslatingVisitor();
171171

172-
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
172+
PrepareSelectExpressionForAggregate(handlerContext);
173173

174174
var predicate
175175
= sqlTranslatingVisitor.Visit(
@@ -224,7 +224,7 @@ private static Expression HandleAverage(HandlerContext handlerContext)
224224
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
225225
&& handlerContext.SelectExpression.Projection.Count == 1)
226226
{
227-
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
227+
PrepareSelectExpressionForAggregate(handlerContext);
228228

229229
var expression = handlerContext.SelectExpression.Projection.First();
230230

@@ -344,7 +344,7 @@ var collectionSelectExpression
344344

345345
private static Expression HandleCount(HandlerContext handlerContext)
346346
{
347-
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
347+
PrepareSelectExpressionForAggregate(handlerContext);
348348

349349
handlerContext.SelectExpression
350350
.SetProjectionExpression(
@@ -480,8 +480,7 @@ var sqlExpression
480480
&& shapedQueryMethod.Method.MethodIsClosedFormOf(
481481
handlerContext.QueryModelVisitor.QueryCompilationContext.QueryMethodProvider.ShapedQueryMethod))
482482
{
483-
var selectExpression = handlerContext.SelectExpression;
484-
PrepareSelectExpressionForAggregate(selectExpression);
483+
PrepareSelectExpressionForAggregate(handlerContext);
485484

486485
// GroupBy Aggregate
487486
// TODO: InjectParameters type Expression.
@@ -518,6 +517,8 @@ var sqlExpression
518517
break;
519518
}
520519

520+
var selectExpression = handlerContext.SelectExpression;
521+
521522
if (key != null
522523
|| groupResultOperator.KeySelector is ConstantExpression
523524
|| groupResultOperator.KeySelector is ParameterExpression)
@@ -594,16 +595,14 @@ var sqlExpression
594595

595596
if (sqlExpression != null)
596597
{
597-
var selectExpression = handlerContext.SelectExpression;
598-
599-
PrepareSelectExpressionForAggregate(selectExpression);
598+
PrepareSelectExpressionForAggregate(handlerContext);
600599

601600
sqlExpression
602601
= sqlTranslatingExpressionVisitor.Visit(groupResultOperator.KeySelector);
603602

604603
var columns = (sqlExpression as ConstantExpression)?.Value as Expression[] ?? new[] { sqlExpression };
605604

606-
selectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));
605+
handlerContext.SelectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));
607606

608607
handlerContext.QueryModelVisitor.RequiresStreamingGroupResultOperator = true;
609608
}
@@ -834,7 +833,7 @@ private static Expression HandleLast(HandlerContext handlerContext)
834833

835834
private static Expression HandleLongCount(HandlerContext handlerContext)
836835
{
837-
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
836+
PrepareSelectExpressionForAggregate(handlerContext);
838837

839838
handlerContext.SelectExpression
840839
.SetProjectionExpression(
@@ -853,7 +852,7 @@ private static Expression HandleMin(HandlerContext handlerContext)
853852
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
854853
&& handlerContext.SelectExpression.Projection.Count == 1)
855854
{
856-
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
855+
PrepareSelectExpressionForAggregate(handlerContext);
857856
var expression = handlerContext.SelectExpression.Projection.First();
858857

859858
if (!(expression.RemoveConvert() is SelectExpression))
@@ -882,7 +881,7 @@ private static Expression HandleMax(HandlerContext handlerContext)
882881
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
883882
&& handlerContext.SelectExpression.Projection.Count == 1)
884883
{
885-
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
884+
PrepareSelectExpressionForAggregate(handlerContext);
886885
var expression = handlerContext.SelectExpression.Projection.First();
887886

888887
if (!(expression.RemoveConvert() is SelectExpression))
@@ -971,7 +970,7 @@ private static Expression HandleSum(HandlerContext handlerContext)
971970
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
972971
&& handlerContext.SelectExpression.Projection.Count == 1)
973972
{
974-
PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
973+
PrepareSelectExpressionForAggregate(handlerContext);
975974
var expression = handlerContext.SelectExpression.Projection.First();
976975

977976
if (!(expression.RemoveConvert() is SelectExpression))
@@ -1046,17 +1045,26 @@ private static void SetConditionAsProjection(
10461045
typeof(bool)));
10471046
}
10481047

1049-
private static void PrepareSelectExpressionForAggregate(SelectExpression selectExpression)
1048+
private static void PrepareSelectExpressionForAggregate(HandlerContext handlerContext)
10501049
{
1051-
if (selectExpression.IsDistinct
1052-
|| selectExpression.Limit != null
1053-
|| selectExpression.Offset != null)
1050+
var legacyBehavior12351 = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12351", out var isEnabled) && isEnabled;
1051+
1052+
if (handlerContext.SelectExpression.IsDistinct
1053+
|| handlerContext.SelectExpression.Limit != null
1054+
|| handlerContext.SelectExpression.Offset != null
1055+
|| (handlerContext.SelectExpression.GroupBy.Any()
1056+
&& !IsGroupByAggregate(handlerContext.QueryModel)
1057+
&& !legacyBehavior12351))
10541058
{
1055-
selectExpression.PushDownSubquery();
1056-
selectExpression.ExplodeStarProjection();
1059+
handlerContext.SelectExpression.PushDownSubquery();
1060+
handlerContext.SelectExpression.ExplodeStarProjection();
10571061
}
10581062
}
10591063

1064+
private static bool IsGroupByAggregate(QueryModel queryModel)
1065+
=> queryModel.MainFromClause.FromExpression is QuerySourceReferenceExpression mainFromClauseQsre
1066+
&& mainFromClauseQsre.ReferencedQuerySource.ItemType.IsGrouping();
1067+
10601068
private static Expression UnwrapAliasExpression(Expression expression)
10611069
=> (expression as AliasExpression)?.Expression ?? expression;
10621070

src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs

+2
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,8 @@ var subQueryModelVisitor
840840
var subSelectExpression = subQueryModelVisitor.Queries.First();
841841
AddQuery(querySource, subSelectExpression);
842842

843+
RequiresStreamingGroupResultOperator = true;
844+
843845
return subQueryModelVisitor.Expression;
844846
}
845847

src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -518,19 +518,19 @@ private Expression ApplyNullSemantics(Expression expression)
518518
/// <param name="projection"> The projection expression. </param>
519519
protected virtual void GenerateProjection([NotNull] Expression projection)
520520
{
521-
if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled)
521+
//if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled)
522522
{
523523
Visit(
524524
ApplyOptimizations(
525525
ApplyExplicitCastToBoolInProjectionOptimization(projection),
526526
searchCondition: false));
527527
}
528-
else
529-
{
530-
Visit(
531-
ApplyExplicitCastToBoolInProjectionOptimization(
532-
ApplyOptimizations(projection, searchCondition: false)));
533-
}
528+
//else
529+
//{
530+
// Visit(
531+
// ApplyExplicitCastToBoolInProjectionOptimization(
532+
// ApplyOptimizations(projection, searchCondition: false)));
533+
//}
534534
}
535535

536536
/// <summary>

src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs

+19
Original file line numberDiff line numberDiff line change
@@ -1990,5 +1990,24 @@ public virtual async Task Double_GroupBy_with_aggregate()
19901990

19911991
#endregion
19921992

1993+
#region ResultOperatorsAfterGroupBy
1994+
1995+
[ConditionalFact]
1996+
public virtual async Task Count_after_GroupBy()
1997+
{
1998+
await AssertSingleResult<Order>(
1999+
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync());
2000+
}
2001+
2002+
[ConditionalFact]
2003+
public virtual async Task LongCount_after_client_GroupBy()
2004+
{
2005+
await AssertSingleResult<Order>(
2006+
os => (from o in os
2007+
group o by new { o.CustomerID } into g
2008+
select g.Where(e => e.OrderID < 10300).Count()).LongCountAsync());
2009+
}
2010+
2011+
#endregion
19932012
}
19942013
}

src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs

+40
Original file line numberDiff line numberDiff line change
@@ -1994,5 +1994,45 @@ public virtual void Double_GroupBy_with_aggregate()
19941994
}
19951995

19961996
#endregion
1997+
1998+
#region ResultOperatorsAfterGroupBy
1999+
2000+
[ConditionalFact]
2001+
public virtual void Count_after_GroupBy_aggregate()
2002+
{
2003+
AssertSingleResult<Order>(
2004+
os =>os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count());
2005+
}
2006+
2007+
[ConditionalFact]
2008+
public virtual void LongCount_after_client_GroupBy()
2009+
{
2010+
AssertSingleResult<Order>(
2011+
os => (from o in os
2012+
group o by new { o.CustomerID } into g
2013+
select g.Where(e => e.OrderID < 10300).Count()).LongCount());
2014+
}
2015+
2016+
[ConditionalFact]
2017+
public virtual void MinMax_after_GroupBy_aggregate()
2018+
{
2019+
AssertSingleResult<Order>(
2020+
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Min());
2021+
2022+
AssertSingleResult<Order>(
2023+
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Max());
2024+
}
2025+
2026+
[ConditionalFact]
2027+
public virtual void AllAny_after_GroupBy_aggregate()
2028+
{
2029+
AssertSingleResult<Order>(
2030+
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).All(ee => true));
2031+
2032+
AssertSingleResult<Order>(
2033+
os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Any());
2034+
}
2035+
2036+
#endregion
19972037
}
19982038
}

test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs

+18-18
Original file line numberDiff line numberDiff line change
@@ -7437,24 +7437,24 @@ WHERE [g0].[Discriminator] IN (N'Officer', N'Gear')
74377437
ORDER BY [t].[c] DESC, [t].[Nickname], [t].[SquadId], [t].[FullName]");
74387438
}
74397439

7440-
[ConditionalFact]
7441-
public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior()
7442-
{
7443-
var nicknames = new List<string>();
7444-
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true);
7445-
try
7446-
{
7447-
Assert.Throws<InvalidOperationException>(
7448-
() => AssertQuery<Gear>(
7449-
gs => from g in gs
7450-
orderby nicknames.Contains(g.Nickname) descending
7451-
select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Name).ToList() }));
7452-
}
7453-
finally
7454-
{
7455-
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false);
7456-
}
7457-
}
7440+
//[ConditionalFact]
7441+
//public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior()
7442+
//{
7443+
// var nicknames = new List<string>();
7444+
// AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true);
7445+
// try
7446+
// {
7447+
// Assert.Throws<InvalidOperationException>(
7448+
// () => AssertQuery<Gear>(
7449+
// gs => from g in gs
7450+
// orderby nicknames.Contains(g.Nickname) descending
7451+
// select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Id).ToList() }));
7452+
// }
7453+
// finally
7454+
// {
7455+
// AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false);
7456+
// }
7457+
//}
74587458

74597459
private void AssertSql(params string[] expected)
74607460
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs

+95
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

4+
using System;
5+
using System.Linq;
6+
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
47
using Microsoft.EntityFrameworkCore.TestUtilities;
8+
using Microsoft.EntityFrameworkCore.TestUtilities.Xunit;
59
using Xunit;
610
using Xunit.Abstractions;
711

@@ -1759,6 +1763,97 @@ FROM [Orders] AS [o0]
17591763
ORDER BY [o0].[OrderID], [o0].[OrderDate]");
17601764
}
17611765

1766+
public override void Count_after_GroupBy_aggregate()
1767+
{
1768+
base.Count_after_GroupBy_aggregate();
1769+
1770+
AssertSql(
1771+
@"SELECT COUNT(*)
1772+
FROM (
1773+
SELECT SUM([o].[OrderID]) AS [c]
1774+
FROM [Orders] AS [o]
1775+
GROUP BY [o].[CustomerID]
1776+
) AS [t]");
1777+
}
1778+
1779+
public override void LongCount_after_client_GroupBy()
1780+
{
1781+
base.LongCount_after_client_GroupBy();
1782+
1783+
AssertSql(
1784+
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
1785+
FROM [Orders] AS [o]
1786+
ORDER BY [o].[CustomerID]");
1787+
}
1788+
1789+
public override void MinMax_after_GroupBy_aggregate()
1790+
{
1791+
base.MinMax_after_GroupBy_aggregate();
1792+
1793+
AssertSql(
1794+
@"SELECT MIN([t].[c])
1795+
FROM (
1796+
SELECT SUM([o].[OrderID]) AS [c]
1797+
FROM [Orders] AS [o]
1798+
GROUP BY [o].[CustomerID]
1799+
) AS [t]",
1800+
//
1801+
@"SELECT MAX([t].[c])
1802+
FROM (
1803+
SELECT SUM([o].[OrderID]) AS [c]
1804+
FROM [Orders] AS [o]
1805+
GROUP BY [o].[CustomerID]
1806+
) AS [t]");
1807+
}
1808+
1809+
public override void AllAny_after_GroupBy_aggregate()
1810+
{
1811+
base.AllAny_after_GroupBy_aggregate();
1812+
1813+
AssertSql(
1814+
@"SELECT CASE
1815+
WHEN NOT EXISTS (
1816+
SELECT 1
1817+
FROM (
1818+
SELECT SUM([o].[OrderID]) AS [c]
1819+
FROM [Orders] AS [o]
1820+
GROUP BY [o].[CustomerID]
1821+
) AS [t]
1822+
WHERE 0 = 1)
1823+
THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
1824+
END",
1825+
//
1826+
@"SELECT CASE
1827+
WHEN EXISTS (
1828+
SELECT 1
1829+
FROM [Orders] AS [o]
1830+
GROUP BY [o].[CustomerID])
1831+
THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
1832+
END");
1833+
}
1834+
1835+
[ConditionalFact]
1836+
public virtual void Count_after_GroupBy_aggregate_legacy_behavior()
1837+
{
1838+
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", true);
1839+
1840+
try
1841+
{
1842+
AssertSingleResult<Order>(
1843+
os => os.OrderBy(o => o.CustomerID).GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count(),
1844+
os => 6);
1845+
1846+
AssertSql(
1847+
@"SELECT COUNT(*)
1848+
FROM [Orders] AS [o]
1849+
GROUP BY [o].[CustomerID]");
1850+
}
1851+
finally
1852+
{
1853+
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", false);
1854+
}
1855+
}
1856+
17621857
private void AssertSql(params string[] expected)
17631858
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
17641859

0 commit comments

Comments
 (0)