Skip to content

Commit 15bc370

Browse files
committed
Fix to #12314 - SumAsync throw Exception when used over float?
Problem was that in order to return 0 from empty Sum of nullable values, we convert them to non-nullable (to produce 0) and then convert back to nullable type. This works without issue for sync path, but in async simple casting like that doesn't work. Fix is to call method that adds Task.ContinueWith() call that casts the result to the correct type in the async scenario.
1 parent 77512c4 commit 15bc370

File tree

10 files changed

+212
-107
lines changed

10 files changed

+212
-107
lines changed

src/EFCore.Relational/Extensions/TaskExtensions.cs renamed to src/EFCore.Relational/Extensions/RelationalTaskExtensions.cs

+28-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
namespace System.Threading.Tasks
55
{
6-
internal static class TaskExtensions
6+
internal static class RelationalTaskExtensions
77
{
88
public static Task<TDerived> Cast<T, TDerived>(this Task<T> task)
99
where TDerived : T
@@ -31,5 +31,32 @@ public static Task<TDerived> Cast<T, TDerived>(this Task<T> task)
3131

3232
return taskCompletionSource.Task;
3333
}
34+
35+
public static Task<T?> CastToNullable<T>(this Task<T> task)
36+
where T : struct
37+
{
38+
var taskCompletionSource = new TaskCompletionSource<T?>();
39+
40+
task.ContinueWith(
41+
t =>
42+
{
43+
if (t.IsFaulted)
44+
{
45+
// ReSharper disable once PossibleNullReferenceException
46+
taskCompletionSource.TrySetException(t.Exception.InnerExceptions);
47+
}
48+
else if (t.IsCanceled)
49+
{
50+
taskCompletionSource.TrySetCanceled();
51+
}
52+
else
53+
{
54+
taskCompletionSource.TrySetResult((T?)t.Result);
55+
}
56+
},
57+
TaskContinuationOptions.ExecuteSynchronously);
58+
59+
return taskCompletionSource.Task;
60+
}
3461
}
3562
}

src/EFCore.Relational/Query/Internal/RelationalResultOperatorHandler.cs

+16-2
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,9 @@ var sqlTranslatingExpressionVisitor
963963
return handlerContext.EvalOnClient();
964964
}
965965

966+
private static readonly MethodInfo _castToNullableMethodInfo
967+
= typeof(RelationalTaskExtensions).GetRuntimeMethods().Where(m => m.Name == nameof(RelationalTaskExtensions.CastToNullable)).Single();
968+
966969
private static Expression HandleSum(HandlerContext handlerContext)
967970
{
968971
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
@@ -990,10 +993,21 @@ var clientExpression
990993
.MakeGenericMethod(sumExpression.Type.UnwrapNullableType())
991994
.Invoke(null, new object[] { handlerContext, /*throwOnNullResult:*/ false });
992995

993-
return
994-
sumExpression.Type.IsNullableType()
996+
if (handlerContext.QueryModelVisitor.QueryCompilationContext.IsAsyncQuery
997+
&& !(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12314", out var isEnabled) && isEnabled))
998+
{
999+
return sumExpression.Type.IsNullableType()
1000+
? Expression.Call(
1001+
_castToNullableMethodInfo.MakeGenericMethod(sumExpression.Type.UnwrapNullableType()),
1002+
clientExpression)
1003+
: clientExpression;
1004+
}
1005+
else
1006+
{
1007+
return sumExpression.Type.IsNullableType()
9951008
? Expression.Convert(clientExpression, sumExpression.Type)
9961009
: clientExpression;
1010+
}
9971011
}
9981012
}
9991013

src/EFCore.Specification.Tests/Query/AsyncGearsOfWarQueryTestBase.cs

+10
Original file line numberDiff line numberDiff line change
@@ -1691,5 +1691,15 @@ await AssertQuery<City>(
16911691
assertOrder: true,
16921692
elementAsserter: CollectionAsserter<Gear>(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname)));
16931693
}
1694+
1695+
[ConditionalFact]
1696+
public virtual async Task Sum_with_no_data_nullable_double()
1697+
{
1698+
using (var ctx = CreateContext())
1699+
{
1700+
var result = await ctx.Missions.Where(m => m.CodeName == "Operation Foobar").Select(m => m.Rating).SumAsync();
1701+
Assert.Equal(0, result);
1702+
}
1703+
}
16941704
}
16951705
}

src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs

+120-90
Large diffs are not rendered by default.

src/EFCore.Specification.Tests/Query/AsyncSimpleQueryTestBase.cs

+6
Original file line numberDiff line numberDiff line change
@@ -3808,5 +3808,11 @@ public virtual async Task Cast_to_same_Type_CountAsync_works()
38083808
{
38093809
await AssertSingleResult<Customer>(cs => cs.Cast<Customer>().CountAsync());
38103810
}
3811+
3812+
[ConditionalFact]
3813+
public virtual async Task Sum_with_no_data_nullable()
3814+
{
3815+
await AssertSingleResult<Order>(os => os.Where(o => o.OrderID < 0).Select(o => (int?)o.OrderID).SumAsync());
3816+
}
38113817
}
38123818
}

src/EFCore.Specification.Tests/Query/GearsOfWarQueryFixtureBase.cs

+1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ protected GearsOfWarQueryFixtureBase()
150150
{
151151
Assert.Equal(e.Id, a.Id);
152152
Assert.Equal(e.CodeName, a.CodeName);
153+
Assert.Equal(e.Rating, a.Rating);
153154
Assert.Equal(e.Timeline, a.Timeline);
154155
}
155156
}

src/EFCore.Specification.Tests/TestModels/GearsOfWarModel/GearsOfWarData.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ public static IReadOnlyList<Squad> CreateSquads()
104104
public static IReadOnlyList<Mission> CreateMissions()
105105
=> new List<Mission>
106106
{
107-
new Mission { Id = 1, CodeName = "Lightmass Offensive", Timeline = new DateTimeOffset(2, 1, 2, 10, 0, 0, new TimeSpan(1, 30, 0)) },
108-
new Mission { Id = 2, CodeName = "Hollow Storm", Timeline = new DateTimeOffset(2, 3, 1, 8, 0, 0, new TimeSpan(-5, 0, 0)) },
109-
new Mission { Id = 3, CodeName = "Halvo Bay defense", Timeline = new DateTimeOffset(10, 5, 3, 12, 0, 0, new TimeSpan()) }
107+
new Mission { Id = 1, CodeName = "Lightmass Offensive", Rating = 2.1, Timeline = new DateTimeOffset(2, 1, 2, 10, 0, 0, new TimeSpan(1, 30, 0)) },
108+
new Mission { Id = 2, CodeName = "Hollow Storm", Rating = 4.2, Timeline = new DateTimeOffset(2, 3, 1, 8, 0, 0, new TimeSpan(-5, 0, 0)) },
109+
new Mission { Id = 3, CodeName = "Halvo Bay defense", Rating = null, Timeline = new DateTimeOffset(10, 5, 3, 12, 0, 0, new TimeSpan()) }
110110
};
111111

112112
public static IReadOnlyList<SquadMission> CreateSquadMissions()

src/EFCore.Specification.Tests/TestModels/GearsOfWarModel/Mission.cs

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public class Mission
1111
public int Id { get; set; }
1212

1313
public string CodeName { get; set; }
14+
public double? Rating { get; set; }
1415
public DateTimeOffset Timeline { get; set; }
1516

1617
public virtual ICollection<SquadMission> ParticipatingSquads { get; set; }

test/EFCore.SqlServer.FunctionalTests/Query/AsyncSimpleQuerySqlServerTest.cs

+16
Original file line numberDiff line numberDiff line change
@@ -214,5 +214,21 @@ await AssertQuery<Order>(
214214
os => os.Select(o => new { o.Customer.City, Count = o.OrderDetails.Count() }),
215215
elementSorter: e => e.City + " " + e.Count);
216216
}
217+
218+
[Fact]
219+
public async Task Sum_with_no_data_nullable_legacy_behavior()
220+
{
221+
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12314", true);
222+
223+
try
224+
{
225+
await Assert.ThrowsAsync<InvalidOperationException>(
226+
async () => await AssertSingleResult<Order>(os => os.Where(o => o.OrderID < 0).Select(o => (int?)o.OrderID).SumAsync()));
227+
}
228+
finally
229+
{
230+
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12314", false);
231+
}
232+
}
217233
}
218234
}

test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs

+11-11
Original file line numberDiff line numberDiff line change
@@ -3037,7 +3037,7 @@ public override void Where_datetimeoffset_now()
30373037
base.Where_datetimeoffset_now();
30383038

30393039
AssertSql(
3040-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3040+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
30413041
FROM [Missions] AS [m]
30423042
WHERE [m].[Timeline] <> SYSDATETIMEOFFSET()");
30433043
}
@@ -3047,7 +3047,7 @@ public override void Where_datetimeoffset_utcnow()
30473047
base.Where_datetimeoffset_utcnow();
30483048

30493049
AssertSql(
3050-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3050+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
30513051
FROM [Missions] AS [m]
30523052
WHERE [m].[Timeline] <> CAST(SYSUTCDATETIME() AS datetimeoffset)");
30533053
}
@@ -3060,7 +3060,7 @@ public override void Where_datetimeoffset_date_component()
30603060
AssertSql(
30613061
@"@__Date_0='0001-01-01T00:00:00'
30623062
3063-
SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3063+
SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
30643064
FROM [Missions] AS [m]
30653065
WHERE CONVERT(date, [m].[Timeline]) > @__Date_0");
30663066
}
@@ -3070,7 +3070,7 @@ public override void Where_datetimeoffset_year_component()
30703070
base.Where_datetimeoffset_year_component();
30713071

30723072
AssertSql(
3073-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3073+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
30743074
FROM [Missions] AS [m]
30753075
WHERE DATEPART(year, [m].[Timeline]) = 2");
30763076
}
@@ -3080,7 +3080,7 @@ public override void Where_datetimeoffset_month_component()
30803080
base.Where_datetimeoffset_month_component();
30813081

30823082
AssertSql(
3083-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3083+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
30843084
FROM [Missions] AS [m]
30853085
WHERE DATEPART(month, [m].[Timeline]) = 1");
30863086
}
@@ -3090,7 +3090,7 @@ public override void Where_datetimeoffset_dayofyear_component()
30903090
base.Where_datetimeoffset_dayofyear_component();
30913091

30923092
AssertSql(
3093-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3093+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
30943094
FROM [Missions] AS [m]
30953095
WHERE DATEPART(dayofyear, [m].[Timeline]) = 2");
30963096
}
@@ -3100,7 +3100,7 @@ public override void Where_datetimeoffset_day_component()
31003100
base.Where_datetimeoffset_day_component();
31013101

31023102
AssertSql(
3103-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3103+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
31043104
FROM [Missions] AS [m]
31053105
WHERE DATEPART(day, [m].[Timeline]) = 2");
31063106
}
@@ -3110,7 +3110,7 @@ public override void Where_datetimeoffset_hour_component()
31103110
base.Where_datetimeoffset_hour_component();
31113111

31123112
AssertSql(
3113-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3113+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
31143114
FROM [Missions] AS [m]
31153115
WHERE DATEPART(hour, [m].[Timeline]) = 10");
31163116
}
@@ -3120,7 +3120,7 @@ public override void Where_datetimeoffset_minute_component()
31203120
base.Where_datetimeoffset_minute_component();
31213121

31223122
AssertSql(
3123-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3123+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
31243124
FROM [Missions] AS [m]
31253125
WHERE DATEPART(minute, [m].[Timeline]) = 0");
31263126
}
@@ -3130,7 +3130,7 @@ public override void Where_datetimeoffset_second_component()
31303130
base.Where_datetimeoffset_second_component();
31313131

31323132
AssertSql(
3133-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3133+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
31343134
FROM [Missions] AS [m]
31353135
WHERE DATEPART(second, [m].[Timeline]) = 0");
31363136
}
@@ -3140,7 +3140,7 @@ public override void Where_datetimeoffset_millisecond_component()
31403140
base.Where_datetimeoffset_millisecond_component();
31413141

31423142
AssertSql(
3143-
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
3143+
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
31443144
FROM [Missions] AS [m]
31453145
WHERE DATEPART(millisecond, [m].[Timeline]) = 0");
31463146
}

0 commit comments

Comments
 (0)