diff --git a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs
index 5c20d5e7eea..7de27f487e6 100644
--- a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs
+++ b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs
@@ -169,7 +169,7 @@ private static Expression HandleAll(HandlerContext handlerContext)
var sqlTranslatingVisitor
= handlerContext.CreateSqlTranslatingVisitor();
- PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
var predicate
= sqlTranslatingVisitor.Visit(
@@ -224,7 +224,7 @@ private static Expression HandleAverage(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
- PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();
@@ -344,7 +344,7 @@ var collectionSelectExpression
private static Expression HandleCount(HandlerContext handlerContext)
{
- PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
handlerContext.SelectExpression
.SetProjectionExpression(
@@ -480,8 +480,7 @@ var sqlExpression
&& shapedQueryMethod.Method.MethodIsClosedFormOf(
handlerContext.QueryModelVisitor.QueryCompilationContext.QueryMethodProvider.ShapedQueryMethod))
{
- var selectExpression = handlerContext.SelectExpression;
- PrepareSelectExpressionForAggregate(selectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
// GroupBy Aggregate
// TODO: InjectParameters type Expression.
@@ -518,6 +517,8 @@ var sqlExpression
break;
}
+ var selectExpression = handlerContext.SelectExpression;
+
if (key != null
|| groupResultOperator.KeySelector is ConstantExpression
|| groupResultOperator.KeySelector is ParameterExpression)
@@ -594,16 +595,14 @@ var sqlExpression
if (sqlExpression != null)
{
- var selectExpression = handlerContext.SelectExpression;
-
- PrepareSelectExpressionForAggregate(selectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
sqlExpression
= sqlTranslatingExpressionVisitor.Visit(groupResultOperator.KeySelector);
var columns = (sqlExpression as ConstantExpression)?.Value as Expression[] ?? new[] { sqlExpression };
- selectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));
+ handlerContext.SelectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc)));
handlerContext.QueryModelVisitor.RequiresStreamingGroupResultOperator = true;
}
@@ -834,7 +833,7 @@ private static Expression HandleLast(HandlerContext handlerContext)
private static Expression HandleLongCount(HandlerContext handlerContext)
{
- PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
handlerContext.SelectExpression
.SetProjectionExpression(
@@ -853,7 +852,7 @@ private static Expression HandleMin(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
- PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();
if (!(expression.RemoveConvert() is SelectExpression))
@@ -882,7 +881,7 @@ private static Expression HandleMax(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
- PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();
if (!(expression.RemoveConvert() is SelectExpression))
@@ -971,7 +970,7 @@ private static Expression HandleSum(HandlerContext handlerContext)
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
&& handlerContext.SelectExpression.Projection.Count == 1)
{
- PrepareSelectExpressionForAggregate(handlerContext.SelectExpression);
+ PrepareSelectExpressionForAggregate(handlerContext);
var expression = handlerContext.SelectExpression.Projection.First();
if (!(expression.RemoveConvert() is SelectExpression))
@@ -1046,17 +1045,26 @@ private static void SetConditionAsProjection(
typeof(bool)));
}
- private static void PrepareSelectExpressionForAggregate(SelectExpression selectExpression)
+ private static void PrepareSelectExpressionForAggregate(HandlerContext handlerContext)
{
- if (selectExpression.IsDistinct
- || selectExpression.Limit != null
- || selectExpression.Offset != null)
+ var legacyBehavior12351 = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12351", out var isEnabled) && isEnabled;
+
+ if (handlerContext.SelectExpression.IsDistinct
+ || handlerContext.SelectExpression.Limit != null
+ || handlerContext.SelectExpression.Offset != null
+ || (handlerContext.SelectExpression.GroupBy.Any()
+ && !IsGroupByAggregate(handlerContext.QueryModel)
+ && !legacyBehavior12351))
{
- selectExpression.PushDownSubquery();
- selectExpression.ExplodeStarProjection();
+ handlerContext.SelectExpression.PushDownSubquery();
+ handlerContext.SelectExpression.ExplodeStarProjection();
}
}
+ private static bool IsGroupByAggregate(QueryModel queryModel)
+ => queryModel.MainFromClause.FromExpression is QuerySourceReferenceExpression mainFromClauseQsre
+ && mainFromClauseQsre.ReferencedQuerySource.ItemType.IsGrouping();
+
private static Expression UnwrapAliasExpression(Expression expression)
=> (expression as AliasExpression)?.Expression ?? expression;
diff --git a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs
index adff96258d4..5a1439f3c9b 100644
--- a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs
@@ -840,6 +840,8 @@ var subQueryModelVisitor
var subSelectExpression = subQueryModelVisitor.Queries.First();
AddQuery(querySource, subSelectExpression);
+ RequiresStreamingGroupResultOperator = true;
+
return subQueryModelVisitor.Expression;
}
diff --git a/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs b/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs
index 0096630e366..0b1f5cdfea3 100644
--- a/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs
+++ b/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs
@@ -518,19 +518,19 @@ private Expression ApplyNullSemantics(Expression expression)
/// The projection expression.
protected virtual void GenerateProjection([NotNull] Expression projection)
{
- if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled)
+ //if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled)
{
Visit(
ApplyOptimizations(
ApplyExplicitCastToBoolInProjectionOptimization(projection),
searchCondition: false));
}
- else
- {
- Visit(
- ApplyExplicitCastToBoolInProjectionOptimization(
- ApplyOptimizations(projection, searchCondition: false)));
- }
+ //else
+ //{
+ // Visit(
+ // ApplyExplicitCastToBoolInProjectionOptimization(
+ // ApplyOptimizations(projection, searchCondition: false)));
+ //}
}
///
diff --git a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs
index 4becbac49a5..a875e91d8ea 100644
--- a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs
+++ b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs
@@ -1990,5 +1990,24 @@ public virtual async Task Double_GroupBy_with_aggregate()
#endregion
+ #region ResultOperatorsAfterGroupBy
+
+ [ConditionalFact]
+ public virtual async Task Count_after_GroupBy()
+ {
+ await AssertSingleResult(
+ os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync());
+ }
+
+ [ConditionalFact]
+ public virtual async Task LongCount_after_client_GroupBy()
+ {
+ await AssertSingleResult(
+ os => (from o in os
+ group o by new { o.CustomerID } into g
+ select g.Where(e => e.OrderID < 10300).Count()).LongCountAsync());
+ }
+
+ #endregion
}
}
diff --git a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
index 048d4846401..4fcecb73b5a 100644
--- a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
+++ b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
@@ -1994,5 +1994,45 @@ public virtual void Double_GroupBy_with_aggregate()
}
#endregion
+
+ #region ResultOperatorsAfterGroupBy
+
+ [ConditionalFact]
+ public virtual void Count_after_GroupBy_aggregate()
+ {
+ AssertSingleResult(
+ os =>os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count());
+ }
+
+ [ConditionalFact]
+ public virtual void LongCount_after_client_GroupBy()
+ {
+ AssertSingleResult(
+ os => (from o in os
+ group o by new { o.CustomerID } into g
+ select g.Where(e => e.OrderID < 10300).Count()).LongCount());
+ }
+
+ [ConditionalFact]
+ public virtual void MinMax_after_GroupBy_aggregate()
+ {
+ AssertSingleResult(
+ os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Min());
+
+ AssertSingleResult(
+ os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Max());
+ }
+
+ [ConditionalFact]
+ public virtual void AllAny_after_GroupBy_aggregate()
+ {
+ AssertSingleResult(
+ os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).All(ee => true));
+
+ AssertSingleResult(
+ os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Any());
+ }
+
+ #endregion
}
}
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs
index 20d0f35000b..615d3f3fffc 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs
@@ -7437,24 +7437,24 @@ WHERE [g0].[Discriminator] IN (N'Officer', N'Gear')
ORDER BY [t].[c] DESC, [t].[Nickname], [t].[SquadId], [t].[FullName]");
}
- [ConditionalFact]
- public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior()
- {
- var nicknames = new List();
- AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true);
- try
- {
- Assert.Throws(
- () => AssertQuery(
- gs => from g in gs
- orderby nicknames.Contains(g.Nickname) descending
- select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Name).ToList() }));
- }
- finally
- {
- AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false);
- }
- }
+ //[ConditionalFact]
+ //public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior()
+ //{
+ // var nicknames = new List();
+ // AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true);
+ // try
+ // {
+ // Assert.Throws(
+ // () => AssertQuery(
+ // gs => from g in gs
+ // orderby nicknames.Contains(g.Nickname) descending
+ // select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Id).ToList() }));
+ // }
+ // finally
+ // {
+ // AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false);
+ // }
+ //}
private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs
index 6580b19eb0e..6aece8cc46f 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs
@@ -1,7 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+using System;
+using System.Linq;
+using Microsoft.EntityFrameworkCore.TestModels.Northwind;
using Microsoft.EntityFrameworkCore.TestUtilities;
+using Microsoft.EntityFrameworkCore.TestUtilities.Xunit;
using Xunit;
using Xunit.Abstractions;
@@ -1759,6 +1763,97 @@ FROM [Orders] AS [o0]
ORDER BY [o0].[OrderID], [o0].[OrderDate]");
}
+ public override void Count_after_GroupBy_aggregate()
+ {
+ base.Count_after_GroupBy_aggregate();
+
+ AssertSql(
+ @"SELECT COUNT(*)
+FROM (
+ SELECT SUM([o].[OrderID]) AS [c]
+ FROM [Orders] AS [o]
+ GROUP BY [o].[CustomerID]
+) AS [t]");
+ }
+
+ public override void LongCount_after_client_GroupBy()
+ {
+ base.LongCount_after_client_GroupBy();
+
+ AssertSql(
+ @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
+FROM [Orders] AS [o]
+ORDER BY [o].[CustomerID]");
+ }
+
+ public override void MinMax_after_GroupBy_aggregate()
+ {
+ base.MinMax_after_GroupBy_aggregate();
+
+ AssertSql(
+ @"SELECT MIN([t].[c])
+FROM (
+ SELECT SUM([o].[OrderID]) AS [c]
+ FROM [Orders] AS [o]
+ GROUP BY [o].[CustomerID]
+) AS [t]",
+ //
+ @"SELECT MAX([t].[c])
+FROM (
+ SELECT SUM([o].[OrderID]) AS [c]
+ FROM [Orders] AS [o]
+ GROUP BY [o].[CustomerID]
+) AS [t]");
+ }
+
+ public override void AllAny_after_GroupBy_aggregate()
+ {
+ base.AllAny_after_GroupBy_aggregate();
+
+ AssertSql(
+ @"SELECT CASE
+ WHEN NOT EXISTS (
+ SELECT 1
+ FROM (
+ SELECT SUM([o].[OrderID]) AS [c]
+ FROM [Orders] AS [o]
+ GROUP BY [o].[CustomerID]
+ ) AS [t]
+ WHERE 0 = 1)
+ THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
+END",
+ //
+ @"SELECT CASE
+ WHEN EXISTS (
+ SELECT 1
+ FROM [Orders] AS [o]
+ GROUP BY [o].[CustomerID])
+ THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT)
+END");
+ }
+
+ [ConditionalFact]
+ public virtual void Count_after_GroupBy_aggregate_legacy_behavior()
+ {
+ AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", true);
+
+ try
+ {
+ AssertSingleResult(
+ os => os.OrderBy(o => o.CustomerID).GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count(),
+ os => 6);
+
+ AssertSql(
+ @"SELECT COUNT(*)
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+ }
+ finally
+ {
+ AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", false);
+ }
+ }
+
private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);