From eb07c40924dc3e1f1ba069f47b131e847500f92e Mon Sep 17 00:00:00 2001 From: Maurycy Markowski Date: Tue, 19 Jun 2018 17:47:52 -0700 Subject: [PATCH] Fix to #12351 - Incorrect SQL generated for Count over Group By (EFCore 2.1) There was two problems here: - we were not lifting a select expression that had a GroupBy-Aggregate pattern, and another result operator was composed on top, - we were not propagating RequiresStreamingGroupBy flag to a parent QM, when we lifted a client group by subquery, which could result in the client methods being wiped out if Count/LongCount was composed on top. --- .../RelationalResultOperatorHandler.cs | 43 +++++++++++-------- .../Query/RelationalQueryModelVisitor.cs | 2 + .../Query/AsyncGroupByQueryTestBase.cs | 19 ++++++++ .../Query/GroupByQueryTestBase.cs | 20 +++++++++ .../Query/GroupByQuerySqlServerTest.cs | 23 ++++++++++ 5 files changed, 88 insertions(+), 19 deletions(-) diff --git a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs index f916b25176c..cc1e69737f5 100644 --- a/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs +++ b/src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs @@ -169,7 +169,7 @@ private static Expression HandleAll(HandlerContext handlerContext) var sqlTranslatingVisitor = handlerContext.CreateSqlTranslatingVisitor(); - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var predicate = sqlTranslatingVisitor.Visit( @@ -224,7 +224,7 @@ private static Expression HandleAverage(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); @@ -344,7 +344,7 @@ var collectionSelectExpression private static Expression HandleCount(HandlerContext handlerContext) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); handlerContext.SelectExpression .SetProjectionExpression( @@ -480,8 +480,7 @@ var sqlExpression && shapedQueryMethod.Method.MethodIsClosedFormOf( handlerContext.QueryModelVisitor.QueryCompilationContext.QueryMethodProvider.ShapedQueryMethod)) { - var selectExpression = handlerContext.SelectExpression; - PrepareSelectExpressionForAggregate(selectExpression); + PrepareSelectExpressionForAggregate(handlerContext); // GroupBy Aggregate // TODO: InjectParameters type Expression. @@ -518,6 +517,8 @@ var sqlExpression break; } + var selectExpression = handlerContext.SelectExpression; + if (key != null || groupResultOperator.KeySelector is ConstantExpression || groupResultOperator.KeySelector is ParameterExpression) @@ -594,16 +595,14 @@ var sqlExpression if (sqlExpression != null) { - var selectExpression = handlerContext.SelectExpression; - - PrepareSelectExpressionForAggregate(selectExpression); + PrepareSelectExpressionForAggregate(handlerContext); sqlExpression = sqlTranslatingExpressionVisitor.Visit(groupResultOperator.KeySelector); var columns = (sqlExpression as ConstantExpression)?.Value as Expression[] ?? new[] { sqlExpression }; - selectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc))); + handlerContext.SelectExpression.PrependToOrderBy(columns.Select(c => new Ordering(c, OrderingDirection.Asc))); handlerContext.QueryModelVisitor.RequiresStreamingGroupResultOperator = true; } @@ -834,7 +833,7 @@ private static Expression HandleLast(HandlerContext handlerContext) private static Expression HandleLongCount(HandlerContext handlerContext) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); handlerContext.SelectExpression .SetProjectionExpression( @@ -853,7 +852,7 @@ private static Expression HandleMin(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); if (!(expression.RemoveConvert() is SelectExpression)) @@ -882,7 +881,7 @@ private static Expression HandleMax(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); if (!(expression.RemoveConvert() is SelectExpression)) @@ -968,7 +967,7 @@ private static Expression HandleSum(HandlerContext handlerContext) if (!handlerContext.QueryModelVisitor.RequiresClientProjection && handlerContext.SelectExpression.Projection.Count == 1) { - PrepareSelectExpressionForAggregate(handlerContext.SelectExpression); + PrepareSelectExpressionForAggregate(handlerContext); var expression = handlerContext.SelectExpression.Projection.First(); if (!(expression.RemoveConvert() is SelectExpression)) @@ -1032,17 +1031,23 @@ private static void SetConditionAsProjection( typeof(bool))); } - private static void PrepareSelectExpressionForAggregate(SelectExpression selectExpression) + private static void PrepareSelectExpressionForAggregate(HandlerContext handlerContext) { - if (selectExpression.IsDistinct - || selectExpression.Limit != null - || selectExpression.Offset != null) + if (handlerContext.SelectExpression.IsDistinct + || handlerContext.SelectExpression.Limit != null + || handlerContext.SelectExpression.Offset != null + || (handlerContext.SelectExpression.GroupBy.Any() + && !IsGroupByAggregate(handlerContext.QueryModel))) { - selectExpression.PushDownSubquery(); - selectExpression.ExplodeStarProjection(); + handlerContext.SelectExpression.PushDownSubquery(); + handlerContext.SelectExpression.ExplodeStarProjection(); } } + private static bool IsGroupByAggregate(QueryModel queryModel) + => queryModel.MainFromClause.FromExpression is QuerySourceReferenceExpression mainFromClauseQsre + && mainFromClauseQsre.ReferencedQuerySource.ItemType.IsGrouping(); + private static Expression UnwrapAliasExpression(Expression expression) => (expression as AliasExpression)?.Expression ?? expression; diff --git a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs index adff96258d4..5a1439f3c9b 100644 --- a/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryModelVisitor.cs @@ -840,6 +840,8 @@ var subQueryModelVisitor var subSelectExpression = subQueryModelVisitor.Queries.First(); AddQuery(querySource, subSelectExpression); + RequiresStreamingGroupResultOperator = true; + return subQueryModelVisitor.Expression; } diff --git a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs index 4becbac49a5..a875e91d8ea 100644 --- a/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/AsyncGroupByQueryTestBase.cs @@ -1990,5 +1990,24 @@ public virtual async Task Double_GroupBy_with_aggregate() #endregion + #region ResultOperatorsAfterGroupBy + + [ConditionalFact] + public virtual async Task Count_after_GroupBy() + { + await AssertSingleResult( + os => os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync()); + } + + [ConditionalFact] + public virtual async Task LongCount_after_client_GroupBy() + { + await AssertSingleResult( + os => (from o in os + group o by new { o.CustomerID } into g + select g.Where(e => e.OrderID < 10300).Count()).LongCountAsync()); + } + + #endregion } } diff --git a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs index 048d4846401..35f2b0d0b8d 100644 --- a/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs @@ -1994,5 +1994,25 @@ public virtual void Double_GroupBy_with_aggregate() } #endregion + + #region ResultOperatorsAfterGroupBy + + [ConditionalFact] + public virtual void Count_after_GroupBy() + { + AssertSingleResult( + os =>os.GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).Count()); + } + + [ConditionalFact] + public virtual void LongCount_after_client_GroupBy() + { + AssertSingleResult( + os => (from o in os + group o by new { o.CustomerID } into g + select g.Where(e => e.OrderID < 10300).Count()).LongCount()); + } + + #endregion } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs index 6580b19eb0e..f557865069c 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs @@ -1759,6 +1759,29 @@ FROM [Orders] AS [o0] ORDER BY [o0].[OrderID], [o0].[OrderDate]"); } + public override void Count_after_GroupBy() + { + base.Count_after_GroupBy(); + + AssertSql( + @"SELECT COUNT(*) +FROM ( + SELECT SUM([o].[OrderID]) AS [c] + FROM [Orders] AS [o] + GROUP BY [o].[CustomerID] +) AS [t]"); + } + + public override void LongCount_after_client_GroupBy() + { + base.LongCount_after_client_GroupBy(); + + AssertSql( + @"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] +FROM [Orders] AS [o] +ORDER BY [o].[CustomerID]"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);