Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Add support for SQL GROUP BY #16381

Merged
merged 2 commits into from
Jul 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ protected virtual void GenerateSelect(SelectExpression selectExpression)
Visit(selectExpression.Predicate);
}

if (selectExpression.GroupBy.Count > 0)
{
_relationalCommandBuilder.AppendLine().Append("GROUP BY ");

GenerateList(selectExpression.GroupBy, e => Visit(e));
}

if (selectExpression.HavingExpression != null)
{
_relationalCommandBuilder.AppendLine().Append("HAVING ");

Visit(selectExpression.HavingExpression);
}

GenerateOrderings(selectExpression);
GenerateLimitOffset(selectExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;
using Microsoft.EntityFrameworkCore.Query.Pipeline;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,34 +151,12 @@ protected override ShapedQueryExpression TranslateAverage(ShapedQueryExpression
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

var inputType = projection.Type.UnwrapNullableType();
if (inputType == typeof(int)
|| inputType == typeof(long))
{
projection = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(projection, typeof(double)));
}

if (inputType == typeof(float))
{
projection = _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"AVG", new[] { projection }, typeof(double), null),
projection.Type,
projection.TypeMapping);
}
else
{
projection = _sqlExpressionFactory.Function(
"AVG", new[] { projection }, projection.Type, projection.TypeMapping);
}
var projection = _sqlTranslator.TranslateAverage(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand Down Expand Up @@ -237,8 +215,7 @@ protected override ShapedQueryExpression TranslateCount(ShapedQueryExpression so
source = TranslateWhere(source, predicate);
}

var translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(int)));
var translation = _sqlTranslator.TranslateCount();

var projectionMapping = new Dictionary<ProjectionMember, Expression>
{
Expand Down Expand Up @@ -289,7 +266,105 @@ protected override ShapedQueryExpression TranslateFirstOrDefault(ShapedQueryExpr
return source;
}

protected override ShapedQueryExpression TranslateGroupBy(ShapedQueryExpression source, LambdaExpression keySelector, LambdaExpression elementSelector, LambdaExpression resultSelector) => throw new NotImplementedException();
protected override ShapedQueryExpression TranslateGroupBy(
ShapedQueryExpression source,
LambdaExpression keySelector,
LambdaExpression elementSelector,
LambdaExpression resultSelector)
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

var remappedKeySelector = RemapLambdaBody(source.ShaperExpression, keySelector);

var translatedKey = TranslateGroupingKey(remappedKeySelector)
?? (remappedKeySelector is ConstantExpression ? remappedKeySelector : null);
if (translatedKey != null)
{
if (elementSelector != null)
{
source = TranslateSelect(source, elementSelector);
}

var sqlKeySelector = translatedKey is ConstantExpression
? _sqlExpressionFactory.ApplyDefaultTypeMapping(_sqlExpressionFactory.Constant(1))
: translatedKey;

var appliedKeySelector = selectExpression.ApplyGrouping(sqlKeySelector);
translatedKey = translatedKey is ConstantExpression ? translatedKey : appliedKeySelector;

source.ShaperExpression = new GroupByShaperExpression(translatedKey, source.ShaperExpression);

if (resultSelector == null)
{
return source;
}

var keyAccessExpression = Expression.MakeMemberAccess(
source.ShaperExpression,
source.ShaperExpression.Type.GetTypeInfo().GetMember(nameof(IGrouping<int, int>.Key))[0]);

var newResultSelectorBody = ReplacingExpressionVisitor.Replace(
resultSelector.Parameters[0], keyAccessExpression,
resultSelector.Parameters[1], source.ShaperExpression,
resultSelector.Body);

source.ShaperExpression = _projectionBindingExpressionVisitor.Translate(selectExpression, newResultSelectorBody);

return source;
}

throw new InvalidOperationException();
}

private Expression TranslateGroupingKey(Expression expression)
{
if (expression is NewExpression newExpression)
{
if (newExpression.Arguments.Count == 0)
{
return newExpression;
}

var newArguments = new Expression[newExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
newArguments[i] = TranslateGroupingKey(newExpression.Arguments[i]);
if (newArguments[i] == null)
{
return null;
}
}

return newExpression.Update(newArguments);
}

if (expression is MemberInitExpression memberInitExpression)
{
var updatedNewExpression = (NewExpression)TranslateGroupingKey(memberInitExpression.NewExpression);
if (updatedNewExpression == null)
{
return null;
}

var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count];
for (var i = 0; i < newBindings.Length; i++)
{
var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i];
var visitedExpression = TranslateGroupingKey(memberAssignment.Expression);
if (visitedExpression == null)
{
return null;
}

newBindings[i] = memberAssignment.Update(visitedExpression);
}

return memberInitExpression.Update(updatedNewExpression, newBindings);
}

return _sqlTranslator.Translate(expression);
}

protected override ShapedQueryExpression TranslateGroupJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector)
{
Expand Down Expand Up @@ -480,8 +555,7 @@ protected override ShapedQueryExpression TranslateLongCount(ShapedQueryExpressio
source = TranslateWhere(source, predicate);
}

var translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(long)));
var translation = _sqlTranslator.TranslateLongCount();
var projectionMapping = new Dictionary<ProjectionMember, Expression>
{
{ new ProjectionMember(), translation }
Expand All @@ -499,14 +573,12 @@ protected override ShapedQueryExpression TranslateMax(ShapedQueryExpression sour
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

projection = _sqlExpressionFactory.Function("MAX", new[] { projection }, resultType, projection.TypeMapping);
var projection = _sqlTranslator.TranslateMax(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand All @@ -516,14 +588,12 @@ protected override ShapedQueryExpression TranslateMin(ShapedQueryExpression sour
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

if (selector != null)
{
source = TranslateSelect(source, selector);
}
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

projection = _sqlExpressionFactory.Function("MIN", new[] { projection }, resultType, projection.TypeMapping);
var projection = _sqlTranslator.TranslateMin(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: true, resultType);
}
Expand Down Expand Up @@ -617,17 +687,16 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
}

var newSelectorBody = ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

source.ShaperExpression = _projectionBindingExpressionVisitor
.Translate(selectExpression, newSelectorBody);
source.ShaperExpression = _projectionBindingExpressionVisitor.Translate(selectExpression, newSelectorBody);

return source;
}

private static readonly MethodInfo _defaultIfEmptyWithoutArgMethodInfo = typeof(Enumerable).GetTypeInfo()
.GetDeclaredMethods(nameof(Enumerable.DefaultIfEmpty)).Single(mi => mi.GetParameters().Length == 1);

protected override ShapedQueryExpression TranslateSelectMany(ShapedQueryExpression source, LambdaExpression collectionSelector, LambdaExpression resultSelector)
protected override ShapedQueryExpression TranslateSelectMany(
ShapedQueryExpression source, LambdaExpression collectionSelector, LambdaExpression resultSelector)
{
var collectionSelectorBody = collectionSelector.Body;
//var defaultIfEmpty = false;
Expand Down Expand Up @@ -737,27 +806,12 @@ protected override ShapedQueryExpression TranslateSum(ShapedQueryExpression sour
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();
var newSelector = selector == null
|| selector.Body == selector.Parameters[0]
? selectExpression.GetMappedProjection(new ProjectionMember())
: ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

if (selector != null)
{
source = TranslateSelect(source, selector);
}

var serverOutputType = resultType.UnwrapNullableType();
var projection = (SqlExpression)selectExpression.GetMappedProjection(new ProjectionMember());

if (serverOutputType == typeof(float))
{
projection = _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function("SUM", new[] { projection }, typeof(double)),
serverOutputType,
projection.TypeMapping);
}
else
{
projection = _sqlExpressionFactory.Function(
"SUM", new[] { projection }, serverOutputType, projection.TypeMapping);
}
var projection = _sqlTranslator.TranslateSum(newSelector);

return AggregateResultShaper(source, projection, throwOnNullResult: false, resultType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ protected override Expression VisitExtension(Expression extensionExpression)
var collectionId = _collectionId++;
var selectExpression = (SelectExpression)collectionShaperExpression.Projection.QueryExpression;
// Do pushdown beforehand so it updates all pending collections first
if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null)
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null
|| selectExpression.GroupBy.Count > 1)
{
selectExpression.PushdownIntoSubquery();
}
Expand Down
Loading