Skip to content

Commit

Permalink
Support GroupBy over complex type
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Apr 9, 2024
1 parent 45448ef commit 1e73ce1
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -735,20 +735,23 @@ protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression s

var remappedKeySelector = RemapLambdaBody(source, keySelector);
var translatedKey = TranslateGroupingKey(remappedKeySelector);
if (translatedKey == null)
switch (translatedKey)
{
// This could be group by entity type
if (remappedKeySelector is not StructuralTypeShaperExpression
// Special handling for GroupBy over entity type: get the entity projection expression out.
// For GroupBy over a complex type, we already get the projection expression out.
case StructuralTypeShaperExpression { StructuralType: IEntityType } shaper:
if (shaper.ValueBufferExpression is not ProjectionBindingExpression pbe)
{
ValueBufferExpression: ProjectionBindingExpression pbe
} shaper)
{
// ValueBufferExpression can be JsonQuery, ProjectionBindingExpression, EntityProjection
// We only allow ProjectionBindingExpression which represents a regular entity
return null;
}
// ValueBufferExpression can be JsonQuery, ProjectionBindingExpression, EntityProjection
// We only allow ProjectionBindingExpression which represents a regular entity
return null;
}

translatedKey = shaper.Update(((SelectExpression)pbe.QueryExpression).GetProjection(pbe));
translatedKey = shaper.Update(((SelectExpression)pbe.QueryExpression).GetProjection(pbe));
break;

case null:
return null;
}

if (elementSelector != null)
Expand Down Expand Up @@ -823,7 +826,7 @@ protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression s
return memberInitExpression.Update(updatedNewExpression, newBindings);

default:
var translation = TranslateExpression(expression);
var translation = TranslateProjection(expression);
if (translation == null)
{
return null;
Expand Down Expand Up @@ -1325,6 +1328,21 @@ protected override ShapedQueryExpression TranslateUnion(ShapedQueryExpression so
return translation;
}

private Expression? TranslateProjection(Expression expression, bool applyDefaultTypeMapping = true)
{
var translation = _sqlTranslator.TranslateProjection(expression, applyDefaultTypeMapping);

if (translation is null)
{
if (_sqlTranslator.TranslationErrorDetails != null)
{
AddTranslationErrorDetails(_sqlTranslator.TranslationErrorDetails);
}
}

return translation;
}

/// <summary>
/// Translates the given lambda expression for the <see cref="ShapedQueryExpression" /> source into equivalent SQL representation.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,11 @@ private static void PopulateGroupByTerms(
projection.DiscriminatorExpression, groupByTerms, groupByAliases, name: DiscriminatorColumnAlias);
}

foreach (var complexProperty in projection.StructuralType.GetComplexProperties())
{
PopulateGroupByTerms(projection.BindComplexProperty(complexProperty), groupByTerms, groupByAliases, name: null);
}

break;

default:
Expand Down
33 changes: 33 additions & 0 deletions test/EFCore.Specification.Tests/Query/ComplexTypeQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -812,13 +812,46 @@ from c2 in ss.Set<Customer>()
AssertEqual(e.Complex?.Two, a.Complex?.Two);
});

#region GroupBy

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_over_property_in_nested_complex_type(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.ShippingAddress.Country.Code).Select(g => new { Code = g.Key, Count = g.Count() }),
elementSorter: g => g.Code);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_over_complex_type(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.ShippingAddress).Select(g => new { Address = g.Key, Count = g.Count() }),
elementSorter: g => g.Address.ZipCode,
elementAsserter: (e, a) =>
{
AssertEqual(e.Address, a.Address);
Assert.Equal(e.Count, a.Count);
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_over_nested_complex_type(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.ShippingAddress.Country).Select(g => new { Country = g.Key, Count = g.Count() }),
elementSorter: g => g.Country.Code);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_with_complex_type_with_group_by_and_first(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupBy(x => x.Id).Select(x => x.First()));

#endregion GroupBy

protected DbContext CreateContext()
=> Fixture.CreateContext();
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private static IReadOnlyList<Customer> CreateCustomers()
AddressLine1 = "804 S. Lakeshore Road",
ZipCode = 38654,
Country = new Country { FullName = "United States", Code = "US" },
Tags = new List<string> { "foo", "bar" }
Tags = ["foo", "bar"]
};

var customer1 = new Customer
Expand All @@ -71,19 +71,14 @@ private static IReadOnlyList<Customer> CreateCustomers()
AddressLine1 = "72 Hickory Rd.",
ZipCode = 07728,
Country = new Country { FullName = "Germany", Code = "DE" },
Tags = new List<string> { "baz" }
Tags = ["baz"]
},
BillingAddress = new Address
{
AddressLine1 = "79 Main St.",
ZipCode = 29293,
Country = new Country { FullName = "Germany", Code = "DE" },
Tags = new List<string>
{
"a1",
"a2",
"a3"
}
Tags = ["a1", "a2", "a3"]
}
};

Expand All @@ -92,7 +87,7 @@ private static IReadOnlyList<Customer> CreateCustomers()
AddressLine1 = "79 Main St.",
ZipCode = 29293,
Country = new Country { FullName = "Germany", Code = "DE" },
Tags = new List<string> { "foo", "moo" }
Tags = ["foo", "moo"]
};

var customer3 = new Customer
Expand All @@ -103,12 +98,7 @@ private static IReadOnlyList<Customer> CreateCustomers()
BillingAddress = address3
};

return new List<Customer>
{
customer1,
customer2,
customer3
};
return [customer1, customer2, customer3];
}

private static IReadOnlyList<CustomerGroup> CreateCustomerGroups(IReadOnlyList<Customer> customers)
Expand All @@ -134,12 +124,7 @@ private static IReadOnlyList<CustomerGroup> CreateCustomerGroups(IReadOnlyList<C
OptionalCustomer = null
};

return new List<CustomerGroup>
{
group1,
group2,
group3
};
return [group1, group2, group3];
}

private static IReadOnlyList<ValuedCustomer> CreateValuedCustomers()
Expand Down Expand Up @@ -192,12 +177,7 @@ private static IReadOnlyList<ValuedCustomer> CreateValuedCustomers()
BillingAddress = address3
};

return new List<ValuedCustomer>
{
customer1,
customer2,
customer3
};
return [customer1, customer2, customer3];
}

private static IReadOnlyList<ValuedCustomerGroup> CreateValuedCustomerGroups(IReadOnlyList<ValuedCustomer> customers)
Expand All @@ -223,12 +203,7 @@ private static IReadOnlyList<ValuedCustomerGroup> CreateValuedCustomerGroups(IRe
OptionalCustomer = null
};

return new List<ValuedCustomerGroup>
{
group1,
group2,
group3
};
return [group1, group2, group3];
}

public static Task SeedAsync(PoolableDbContext context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1131,12 +1131,50 @@ public override async Task Same_complex_type_projected_twice_with_pushdown_as_pa
AssertSql("");
}

#region GroupBy

public override async Task GroupBy_over_property_in_nested_complex_type(bool async)
{
await base.GroupBy_over_property_in_nested_complex_type(async);

AssertSql(
"""
SELECT [c].[ShippingAddress_Country_Code] AS [Code], COUNT(*) AS [Count]
FROM [Customer] AS [c]
GROUP BY [c].[ShippingAddress_Country_Code]
""");
}

public override async Task GroupBy_over_complex_type(bool async)
{
await base.GroupBy_over_complex_type(async);

AssertSql(
"""
SELECT [c].[ShippingAddress_AddressLine1], [c].[ShippingAddress_AddressLine2], [c].[ShippingAddress_Tags], [c].[ShippingAddress_ZipCode], [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName], COUNT(*) AS [Count]
FROM [Customer] AS [c]
GROUP BY [c].[ShippingAddress_AddressLine1], [c].[ShippingAddress_AddressLine2], [c].[ShippingAddress_Tags], [c].[ShippingAddress_ZipCode], [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName]
""");
}

public override async Task GroupBy_over_nested_complex_type(bool async)
{
await base.GroupBy_over_nested_complex_type(async);

AssertSql(
"""
SELECT [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName], COUNT(*) AS [Count]
FROM [Customer] AS [c]
GROUP BY [c].[ShippingAddress_Country_Code], [c].[ShippingAddress_Country_FullName]
""");
}

public override async Task Entity_with_complex_type_with_group_by_and_first(bool async)
{
await base.Entity_with_complex_type_with_group_by_and_first(async);

AssertSql(
"""
"""
SELECT [c3].[Id], [c3].[Name], [c3].[BillingAddress_AddressLine1], [c3].[BillingAddress_AddressLine2], [c3].[BillingAddress_Tags], [c3].[BillingAddress_ZipCode], [c3].[BillingAddress_Country_Code], [c3].[BillingAddress_Country_FullName], [c3].[ShippingAddress_AddressLine1], [c3].[ShippingAddress_AddressLine2], [c3].[ShippingAddress_Tags], [c3].[ShippingAddress_ZipCode], [c3].[ShippingAddress_Country_Code], [c3].[ShippingAddress_Country_FullName]
FROM (
SELECT [c].[Id]
Expand All @@ -1154,6 +1192,8 @@ FROM [Customer] AS [c0]
""");
}

#endregion GroupBy

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1014,12 +1014,50 @@ public override async Task Same_complex_type_projected_twice_with_pushdown_as_pa
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Same_complex_type_projected_twice_with_pushdown_as_part_of_another_projection(async))).Message);

#region GroupBy

public override async Task GroupBy_over_property_in_nested_complex_type(bool async)
{
await base.GroupBy_over_property_in_nested_complex_type(async);

AssertSql(
"""
SELECT "c"."ShippingAddress_Country_Code" AS "Code", COUNT(*) AS "Count"
FROM "Customer" AS "c"
GROUP BY "c"."ShippingAddress_Country_Code"
""");
}

public override async Task GroupBy_over_complex_type(bool async)
{
await base.GroupBy_over_complex_type(async);

AssertSql(
"""
SELECT "c"."ShippingAddress_AddressLine1", "c"."ShippingAddress_AddressLine2", "c"."ShippingAddress_Tags", "c"."ShippingAddress_ZipCode", "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName", COUNT(*) AS "Count"
FROM "Customer" AS "c"
GROUP BY "c"."ShippingAddress_AddressLine1", "c"."ShippingAddress_AddressLine2", "c"."ShippingAddress_Tags", "c"."ShippingAddress_ZipCode", "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName"
""");
}

public override async Task GroupBy_over_nested_complex_type(bool async)
{
await base.GroupBy_over_nested_complex_type(async);

AssertSql(
"""
SELECT "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName", COUNT(*) AS "Count"
FROM "Customer" AS "c"
GROUP BY "c"."ShippingAddress_Country_Code", "c"."ShippingAddress_Country_FullName"
""");
}

public override async Task Entity_with_complex_type_with_group_by_and_first(bool async)
{
await base.Entity_with_complex_type_with_group_by_and_first(async);

AssertSql(
"""
"""
SELECT "c3"."Id", "c3"."Name", "c3"."BillingAddress_AddressLine1", "c3"."BillingAddress_AddressLine2", "c3"."BillingAddress_Tags", "c3"."BillingAddress_ZipCode", "c3"."BillingAddress_Country_Code", "c3"."BillingAddress_Country_FullName", "c3"."ShippingAddress_AddressLine1", "c3"."ShippingAddress_AddressLine2", "c3"."ShippingAddress_Tags", "c3"."ShippingAddress_ZipCode", "c3"."ShippingAddress_Country_Code", "c3"."ShippingAddress_Country_FullName"
FROM (
SELECT "c"."Id"
Expand All @@ -1037,6 +1075,8 @@ LEFT JOIN (
""");
}

#endregion GroupBy

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down

0 comments on commit 1e73ce1

Please sign in to comment.