From 207083c3953fc7e18565bd0a950bfbe1040b7084 Mon Sep 17 00:00:00 2001 From: Maurycy Markowski Date: Tue, 19 Jun 2018 17:47:52 -0700 Subject: [PATCH] Fix to #12351 - Incorrect SQL generated for Count over Group By (EFCore 2.1) There was two problems here: - we were not lifting a select expression that had a GroupBy-Aggregate pattern, and another result operator was composed on top, - we were not propagating RequiresStreamingGroupBy flag to a parent QM, when we lifted a client group by subquery, which could result in the client methods being wiped out if Count/LongCount was composed on top. --- .../RelationalResultOperatorHandler.cs | 46 +++++---- .../Query/RelationalQueryModelVisitor.cs | 2 + .../Query/Sql/DefaultQuerySqlGenerator.cs | 14 +-- .../Query/AsyncGroupByQueryTestBase.cs | 19 ++++ .../Query/GroupByQueryTestBase.cs | 40 ++++++++ .../Query/GearsOfWarQuerySqlServerTest.cs | 36 +++---- .../Query/GroupByQuerySqlServerTest.cs | 95 +++++++++++++++++++ 7 files changed, 208 insertions(+), 44 deletions(-) diff --git a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs index 5c20d5e7eea..7de27f487e6 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs @@ -169,7 +169,7 @@ private static Expression HandleAll(HandlerContext handlerContext) var sqlTranslatingVisitor = handlerContext.CreateSqlTranslatingVisitor(); - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var predicate = sqlTranslatingVisitor.Visit( @@ -224,7 +224,7 @@ private static Expression HandleAverage(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); @@ -344,7 +344,7 @@ var collectionSelectExpression private static Expression HandleCount(HandlerContext handlerContext) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); handlerContext.SelectExpression .SetProjectionExpression( @@ -480,8 +480,7 @@ var sqlExpression && shapedQueryMethod.Method.MethodIsClosedFormOf( handlerContext.QueryModelVisitor.QueryCompilationContext.QueryMethodProvider.ShapedQueryMethod)) { - var selectExpression = handlerContext.SelectExpression; - PrepareSelectExpressionForAggregate(selectExpression); + PrepareSelectExpressionForAggregate(handlerContext); // GroupBy Aggregate // TODO: InjectParameters type Expression. @@ -518,6 +517,8 @@ var sqlExpression break; } + var selectExpression = handlerContext.SelectExpression; + if (key != null || groupResultOperator.KeySelector is ConstantExpression || groupResultOperator.KeySelector is ParameterExpression) @@ -594,16 +595,14 @@ var sqlExpression if (sqlExpression != null) { - var selectExpression = handlerContext.SelectExpression; - - PrepareSelectExpressionForAggregate(selectExpression); + PrepareSelectExpressionForAggregate(handlerContext); sqlExpression = sqlTranslatingExpressionVisitor.Visit(groupResultOperator.KeySelector); var columns = (sqlExpression as ConstantExpression)?.Value as Expression[] ?? new[] { sqlExpression }; - selectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc))); + handlerContext.SelectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc))); handlerContext.QueryModelVisitor.RequiresStreamingGroupResultOperator = true; } @@ -834,7 +833,7 @@ private static Expression HandleLast(HandlerContext handlerContext) private static Expression HandleLongCount(HandlerContext handlerContext) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); handlerContext.SelectExpression .SetProjectionExpression( @@ -853,7 +852,7 @@ private static Expression HandleMin(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); if (!(expression.RemoveConvert() is SelectExpression)) @@ -882,7 +881,7 @@ private static Expression HandleMax(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); if (!(expression.RemoveConvert() is SelectExpression)) @@ -971,7 +970,7 @@ private static Expression HandleSum(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); if (!(expression.RemoveConvert() is SelectExpression)) @@ -1046,17 +1045,26 @@ private static void SetConditionAsProjection( typeof(bool))); } - private static void PrepareSelectExpressionForAggregate(SelectExpression selectExpression) + private static void PrepareSelectExpressionForAggregate(HandlerContext handlerContext) { - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) + var legacyBehavior12351 = AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12351", out var isEnabled) && isEnabled; + + if (handlerContext.SelectExpression.IsDistinct + || handlerContext.SelectExpression.Limit != null + || handlerContext.SelectExpression.Offset != null + || (handlerContext.SelectExpression.GroupBy.Any() + && !IsGroupByAggregate(handlerContext.QueryModel) + && !legacyBehavior12351)) { - selectExpression.PushDownSubquery(); - selectExpression.ExplodeStarProjection(); + handlerContext.SelectExpression.PushDownSubquery(); + handlerContext.SelectExpression.ExplodeStarProjection(); } } + private static bool IsGroupByAggregate(QueryModel queryModel) + => queryModel.MainFromClause.FromExpression is QuerySourceReferenceExpression mainFromClauseQsre + && mainFromClauseQsre.ReferencedQuerySource.ItemType.IsGrouping(); + private static Expression UnwrapAliasExpression(Expression expression) => (expression as AliasExpression)?.Expression ?? expression; diff --git a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs index adff96258d4..5a1439f3c9b 100644 --- a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs @@ -840,6 +840,8 @@ var subQueryModelVisitor var subSelectExpression = subQueryModelVisitor.Queries.First(); AddQuery(querySource, subSelectExpression); + RequiresStreamingGroupResultOperator = true; + return subQueryModelVisitor.Expression; } diff --git a/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs b/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs index 0096630e366..0b1f5cdfea3 100644 --- a/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/Sql/DefaultQuerySqlGenerator.cs @@ -518,19 +518,19 @@ private Expression ApplyNullSemantics(Expression expression) /// The projection expression. protected virtual void GenerateProjection([NotNull] Expression projection) { - if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled) + //if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12175", out var isEnabled) && isEnabled) { Visit( ApplyOptimizations( ApplyExplicitCastToBoolInProjectionOptimization(projection), searchCondition: false)); } - else - { - Visit( - ApplyExplicitCastToBoolInProjectionOptimization( - ApplyOptimizations(projection, searchCondition: false))); - } + //else + //{ + // Visit( + // ApplyExplicitCastToBoolInProjectionOptimization( + // ApplyOptimizations(projection, searchCondition: false))); + //} } /// diff --git a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs index 4becbac49a5..a875e91d8ea 100644 --- a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs @@ -1990,5 +1990,24 @@ public virtual async Task Double_GroupBy_with_aggregate() #endregion + #region ResultOperatorsAfterGroupBy + + [ConditionalFact] + public virtual async Task Count_after_GroupBy() + { + await AssertSingleResult( + os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync()); + } + + [ConditionalFact] + public virtual async Task LongCount_after_client_GroupBy() + { + await AssertSingleResult( + os => (from o in os + group o by new { o.CustomerID } into g + select g.Where(e => e.OrderID < 10300).Count()).LongCountAsync()); + } + + #endregion } } diff --git a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs index 048d4846401..4fcecb73b5a 100644 --- a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs @@ -1994,5 +1994,45 @@ public virtual void Double_GroupBy_with_aggregate() } #endregion + + #region ResultOperatorsAfterGroupBy + + [ConditionalFact] + public virtual void Count_after_GroupBy_aggregate() + { + AssertSingleResult( + os =>os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count()); + } + + [ConditionalFact] + public virtual void LongCount_after_client_GroupBy() + { + AssertSingleResult( + os => (from o in os + group o by new { o.CustomerID } into g + select g.Where(e => e.OrderID < 10300).Count()).LongCount()); + } + + [ConditionalFact] + public virtual void MinMax_after_GroupBy_aggregate() + { + AssertSingleResult( + os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Min()); + + AssertSingleResult( + os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Max()); + } + + [ConditionalFact] + public virtual void AllAny_after_GroupBy_aggregate() + { + AssertSingleResult( + os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).All(ee => true)); + + AssertSingleResult( + os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Any()); + } + + #endregion } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs index 20d0f35000b..615d3f3fffc 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs @@ -7437,24 +7437,24 @@ WHERE [g0].[Discriminator] IN (N'Officer', N'Gear') ORDER BY [t].[c] DESC, [t].[Nickname], [t].[SquadId], [t].[FullName]"); } - [ConditionalFact] - public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior() - { - var nicknames = new List(); - AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true); - try - { - Assert.Throws( - () => AssertQuery( - gs => from g in gs - orderby nicknames.Contains(g.Nickname) descending - select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Name).ToList() })); - } - finally - { - AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false); - } - } + //[ConditionalFact] + //public virtual void Correlated_collection_with_complex_order_by_funcletized_to_constant_bool_legacy_behavior() + //{ + // var nicknames = new List(); + // AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", true); + // try + // { + // Assert.Throws( + // () => AssertQuery( + // gs => from g in gs + // orderby nicknames.Contains(g.Nickname) descending + // select new { g.Nickname, Weapons = g.Weapons.Select(w => w.Id).ToList() })); + // } + // finally + // { + // AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12175", false); + // } + //} private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs index 6580b19eb0e..6aece8cc46f 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs @@ -1,7 +1,11 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Linq; +using Microsoft.EntityFrameworkCore.TestModels.Northwind; using Microsoft.EntityFrameworkCore.TestUtilities; +using Microsoft.EntityFrameworkCore.TestUtilities.Xunit; using Xunit; using Xunit.Abstractions; @@ -1759,6 +1763,97 @@ FROM [Orders] AS [o0] ORDER BY [o0].[OrderID], [o0].[OrderDate]"); } + public override void Count_after_GroupBy_aggregate() + { + base.Count_after_GroupBy_aggregate(); + + AssertSql( + @"SELECT COUNT(*) +FROM ( + SELECT SUM([o].[OrderID]) AS [c] + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] +) AS [t]"); + } + + public override void LongCount_after_client_GroupBy() + { + base.LongCount_after_client_GroupBy(); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +ORDER BY [o].[CustomerID]"); + } + + public override void MinMax_after_GroupBy_aggregate() + { + base.MinMax_after_GroupBy_aggregate(); + + AssertSql( + @"SELECT MIN([t].[c]) +FROM ( + SELECT SUM([o].[OrderID]) AS [c] + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] +) AS [t]", + // + @"SELECT MAX([t].[c]) +FROM ( + SELECT SUM([o].[OrderID]) AS [c] + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] +) AS [t]"); + } + + public override void AllAny_after_GroupBy_aggregate() + { + base.AllAny_after_GroupBy_aggregate(); + + AssertSql( + @"SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 + FROM ( + SELECT SUM([o].[OrderID]) AS [c] + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] + ) AS [t] + WHERE 0 = 1) + THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT) +END", + // + @"SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID]) + THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT) +END"); + } + + [ConditionalFact] + public virtual void Count_after_GroupBy_aggregate_legacy_behavior() + { + AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", true); + + try + { + AssertSingleResult( + os => os.OrderBy(o => o.CustomerID).GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count(), + os => 6); + + AssertSql( + @"SELECT COUNT(*) +FROM [Orders] AS [o] +GROUP BY [o].[CustomerID]"); + } + finally + { + AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12351", false); + } + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);