Skip to content

Commit

Permalink
Query: Extract join predicate from Having for group by aggregate queries (#26011)
Browse files Browse the repository at this point in the history

Resolves #24474
  • Loading branch information
smitpatel authored Sep 14, 2021
1 parent 2482ecf commit a6b7d59
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 46 deletions.
98 changes: 54 additions & 44 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2132,57 +2132,67 @@ static void GetPartitions(SelectExpression selectExpression, SqlExpression sqlEx

static SqlExpression? TryExtractJoinKey(SelectExpression outer, SelectExpression inner, bool allowNonEquality)
{
if (inner.Limit == null
&& inner.Offset == null
&& inner.Predicate != null)
{
var outerColumnExpressions = new List<SqlExpression>();
var joinPredicate = TryExtractJoinKey(
outer,
inner,
inner.Predicate,
outerColumnExpressions,
allowNonEquality,
out var predicate);
if (inner.Limit != null
|| inner.Offset != null)
{
return null;
}

if (joinPredicate != null)
{
joinPredicate = RemoveRedundantNullChecks(joinPredicate, outerColumnExpressions);
}
// TODO: verify the case for GroupBy. See issue#24474
// We extract join predicate from Predicate part but GroupBy would have last Having. Changing predicate can change groupings
var predicate = inner.GroupBy.Count > 0 ? inner.Having : inner.Predicate;
if (predicate == null)
{
return null;
}

// we can't convert apply to join in case of distinct and groupby, if the projection doesn't already contain the join keys
// since we can't add the missing keys to the projection - only convert to join if all the keys are already there
if (joinPredicate != null
&& (inner.IsDistinct
|| inner.GroupBy.Count > 0))
{
var innerKeyColumns = new List<ColumnExpression>();
InnerKeyColumns(inner.Tables, joinPredicate, innerKeyColumns);
var outerColumnExpressions = new List<SqlExpression>();
var joinPredicate = TryExtractJoinKey(
outer,
inner,
predicate,
outerColumnExpressions,
allowNonEquality,
out var updatedPredicate);

if (joinPredicate != null)
{
joinPredicate = RemoveRedundantNullChecks(joinPredicate, outerColumnExpressions);
}

// if projection has already been applied we can use it directly
// otherwise we extract future projection columns from projection mapping
// and based on that we determine whether we can convert from APPLY to JOIN
var projectionColumns = inner.Projection.Count > 0
? inner.Projection.Select(p => p.Expression)
: ExtractColumnsFromProjectionMapping(inner._projectionMapping);
// we can't convert apply to join in case of distinct and groupby, if the projection doesn't already contain the join keys
// since we can't add the missing keys to the projection - only convert to join if all the keys are already there
if (joinPredicate != null
&& (inner.IsDistinct
|| inner.GroupBy.Count > 0))
{
var innerKeyColumns = new List<ColumnExpression>();
PopulateInnerKeyColumns(inner.Tables, joinPredicate, innerKeyColumns);

// if projection has already been applied we can use it directly
// otherwise we extract future projection columns from projection mapping
// and based on that we determine whether we can convert from APPLY to JOIN
var projectionColumns = inner.Projection.Count > 0
? inner.Projection.Select(p => p.Expression)
: ExtractColumnsFromProjectionMapping(inner._projectionMapping);

foreach (var innerColumn in innerKeyColumns)
foreach (var innerColumn in innerKeyColumns)
{
if (!projectionColumns.Contains(innerColumn))
{
if (!projectionColumns.Contains(innerColumn))
{
return null;
}
return null;
}
}
}

inner.Predicate = predicate;

return joinPredicate;
if (inner.GroupBy.Count > 0)
{
inner.Having = updatedPredicate;
}
else
{
inner.Predicate = updatedPredicate;
}

return null;
return joinPredicate;

static SqlExpression? TryExtractJoinKey(
SelectExpression outer,
Expand Down Expand Up @@ -2310,12 +2320,12 @@ static bool IsContainedColumn(SelectExpression selectExpression, SqlExpression s
}
}

static void InnerKeyColumns(IEnumerable<TableExpressionBase> tables, SqlExpression joinPredicate, List<ColumnExpression> resultColumns)
static void PopulateInnerKeyColumns(IEnumerable<TableExpressionBase> tables, SqlExpression joinPredicate, List<ColumnExpression> resultColumns)
{
if (joinPredicate is SqlBinaryExpression sqlBinaryExpression)
{
InnerKeyColumns(tables, sqlBinaryExpression.Left, resultColumns);
InnerKeyColumns(tables, sqlBinaryExpression.Right, resultColumns);
PopulateInnerKeyColumns(tables, sqlBinaryExpression.Left, resultColumns);
PopulateInnerKeyColumns(tables, sqlBinaryExpression.Right, resultColumns);
}
else if (joinPredicate is ColumnExpression columnExpression
&& tables.Contains(columnExpression.Table))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,37 @@ join o in ss.Set<Order>() on a.LastOrderID equals o.OrderID
entryCount: 126);
}

// Selects every customer that has a matching row in a grouped order aggregate: orders are grouped
// by CustomerID, groups with more than 5 orders are kept, and the projected group key is correlated
// back to the outer customer inside the second 'from' (a SelectMany without DefaultIfEmpty).
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Aggregate_Join_converted_from_SelectMany(bool async)
    => AssertQuery(
        async,
        ss => from c in ss.Set<Customer>()
              from o in ss.Set<Order>()
                  .GroupBy(order => order.CustomerID)
                  .Where(grouping => grouping.Count() > 5)
                  .Select(grouping => new { CustomerID = grouping.Key, LastOrderID = grouping.Max(order => order.OrderID) })
                  .Where(summary => c.CustomerID == summary.CustomerID)
              select c,
        entryCount: 63);

// Same shape as the inner-join variant, but the correlated grouped aggregate ends with
// DefaultIfEmpty(), so customers without a qualifying group (more than 5 orders) are still returned.
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Aggregate_LeftJoin_converted_from_SelectMany(bool async)
    => AssertQuery(
        async,
        ss => from c in ss.Set<Customer>()
              from o in ss.Set<Order>()
                  .GroupBy(order => order.CustomerID)
                  .Where(grouping => grouping.Count() > 5)
                  .Select(grouping => new { CustomerID = grouping.Key, LastOrderID = grouping.Max(order => order.OrderID) })
                  .Where(summary => c.CustomerID == summary.CustomerID)
                  .DefaultIfEmpty()
              select c,
        entryCount: 91);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Join_GroupBy_Aggregate_multijoins(bool async)
Expand Down Expand Up @@ -2486,8 +2517,8 @@ public virtual Task GroupBy_aggregate_without_selectMany_selecting_first(bool as
async,
ss => from id in
(from o in ss.Set<Order>()
group o by o.CustomerID into g
select g.Min(x => x.OrderID))
group o by o.CustomerID into g
select g.Min(x => x.OrderID))
from o in ss.Set<Order>()
where o.OrderID == id
select o,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,36 @@ HAVING COUNT(*) > 5
INNER JOIN [Orders] AS [o0] ON [t].[LastOrderID] = [o0].[OrderID]");
}

// Runs the shared test from the base class, then checks the generated SQL against the baseline below:
// the correlated grouped subquery appears as an INNER JOIN with the filter kept in HAVING.
public override async Task GroupBy_Aggregate_Join_converted_from_SelectMany(bool async)
{
    await base.GroupBy_Aggregate_Join_converted_from_SelectMany(async);

    // The baseline is a verbatim string; its continuation lines must stay at column 0.
    const string expectedSql = @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
INNER JOIN (
SELECT [o].[CustomerID]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
HAVING COUNT(*) > 5
) AS [t] ON [c].[CustomerID] = [t].[CustomerID]";

    AssertSql(expectedSql);
}

// Runs the shared test from the base class, then checks the generated SQL against the baseline below:
// the correlated grouped subquery appears as a LEFT JOIN with the filter kept in HAVING.
public override async Task GroupBy_Aggregate_LeftJoin_converted_from_SelectMany(bool async)
{
    await base.GroupBy_Aggregate_LeftJoin_converted_from_SelectMany(async);

    // The baseline is a verbatim string; its continuation lines must stay at column 0.
    const string expectedSql = @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
LEFT JOIN (
SELECT [o].[CustomerID], MAX([o].[OrderID]) AS [LastOrderID]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
HAVING COUNT(*) > 5
) AS [t] ON [c].[CustomerID] = [t].[CustomerID]";

    AssertSql(expectedSql);
}

public override async Task Join_GroupBy_Aggregate_multijoins(bool async)
{
await base.Join_GroupBy_Aggregate_multijoins(async);
Expand Down

0 comments on commit a6b7d59

Please sign in to comment.