Skip to content

Commit

Permalink
Fix to #12314 - SumAsync throw Exception when used over float?
Browse files Browse the repository at this point in the history
Problem was that in order to return 0 from empty Sum of nullable values, we convert them to non-nullable (to produce 0) and then convert back to nullable type. This works without issue for sync path, but in async simple casting like that doesn't work.

Fix is to call method that adds Task.ContinueWith() call that casts the result to the correct type in the async scenario.
  • Loading branch information
maumar committed Jun 29, 2018
1 parent 77512c4 commit 15bc370
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace System.Threading.Tasks
{
internal static class TaskExtensions
internal static class RelationalTaskExtensions
{
public static Task<TDerived> Cast<T, TDerived>(this Task<T> task)
where TDerived : T
Expand Down Expand Up @@ -31,5 +31,32 @@ public static Task<TDerived> Cast<T, TDerived>(this Task<T> task)

return taskCompletionSource.Task;
}

public static Task<T?> CastToNullable<T>(this Task<T> task)
where T : struct
{
var taskCompletionSource = new TaskCompletionSource<T?>();

task.ContinueWith(
t =>
{
if (t.IsFaulted)
{
// ReSharper disable once PossibleNullReferenceException
taskCompletionSource.TrySetException(t.Exception.InnerExceptions);
}
else if (t.IsCanceled)
{
taskCompletionSource.TrySetCanceled();
}
else
{
taskCompletionSource.TrySetResult((T?)t.Result);
}
},
TaskContinuationOptions.ExecuteSynchronously);

return taskCompletionSource.Task;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,9 @@ var sqlTranslatingExpressionVisitor
return handlerContext.EvalOnClient();
}

private static readonly MethodInfo _castToNullableMethodInfo
= typeof(RelationalTaskExtensions).GetRuntimeMethods().Where(m => m.Name == nameof(RelationalTaskExtensions.CastToNullable)).Single();

private static Expression HandleSum(HandlerContext handlerContext)
{
if (!handlerContext.QueryModelVisitor.RequiresClientProjection
Expand Down Expand Up @@ -990,10 +993,21 @@ var clientExpression
.MakeGenericMethod(sumExpression.Type.UnwrapNullableType())
.Invoke(null, new object[] { handlerContext, /*throwOnNullResult:*/ false });

return
sumExpression.Type.IsNullableType()
if (handlerContext.QueryModelVisitor.QueryCompilationContext.IsAsyncQuery
&& !(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue12314", out var isEnabled) && isEnabled))
{
return sumExpression.Type.IsNullableType()
? Expression.Call(
_castToNullableMethodInfo.MakeGenericMethod(sumExpression.Type.UnwrapNullableType()),
clientExpression)
: clientExpression;
}
else
{
return sumExpression.Type.IsNullableType()
? Expression.Convert(clientExpression, sumExpression.Type)
: clientExpression;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1691,5 +1691,15 @@ await AssertQuery<City>(
assertOrder: true,
elementAsserter: CollectionAsserter<Gear>(e => e.Nickname, (e, a) => Assert.Equal(e.Nickname, a.Nickname)));
}

[ConditionalFact]
public virtual async Task Sum_with_no_data_nullable_double()
{
using (var ctx = CreateContext())
{
var result = await ctx.Missions.Where(m => m.CodeName == "Operation Foobar").Select(m => m.Rating).SumAsync();
Assert.Equal(0, result);
}
}
}
}
210 changes: 120 additions & 90 deletions src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3808,5 +3808,11 @@ public virtual async Task Cast_to_same_Type_CountAsync_works()
{
await AssertSingleResult<Customer>(cs => cs.Cast<Customer>().CountAsync());
}

[ConditionalFact]
public virtual async Task Sum_with_no_data_nullable()
{
await AssertSingleResult<Order>(os => os.Where(o => o.OrderID < 0).Select(o => (int?)o.OrderID).SumAsync());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ protected GearsOfWarQueryFixtureBase()
{
Assert.Equal(e.Id, a.Id);
Assert.Equal(e.CodeName, a.CodeName);
Assert.Equal(e.Rating, a.Rating);
Assert.Equal(e.Timeline, a.Timeline);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ public static IReadOnlyList<Squad> CreateSquads()
public static IReadOnlyList<Mission> CreateMissions()
=> new List<Mission>
{
new Mission { Id = 1, CodeName = "Lightmass Offensive", Timeline = new DateTimeOffset(2, 1, 2, 10, 0, 0, new TimeSpan(1, 30, 0)) },
new Mission { Id = 2, CodeName = "Hollow Storm", Timeline = new DateTimeOffset(2, 3, 1, 8, 0, 0, new TimeSpan(-5, 0, 0)) },
new Mission { Id = 3, CodeName = "Halvo Bay defense", Timeline = new DateTimeOffset(10, 5, 3, 12, 0, 0, new TimeSpan()) }
new Mission { Id = 1, CodeName = "Lightmass Offensive", Rating = 2.1, Timeline = new DateTimeOffset(2, 1, 2, 10, 0, 0, new TimeSpan(1, 30, 0)) },
new Mission { Id = 2, CodeName = "Hollow Storm", Rating = 4.2, Timeline = new DateTimeOffset(2, 3, 1, 8, 0, 0, new TimeSpan(-5, 0, 0)) },
new Mission { Id = 3, CodeName = "Halvo Bay defense", Rating = null, Timeline = new DateTimeOffset(10, 5, 3, 12, 0, 0, new TimeSpan()) }
};

public static IReadOnlyList<SquadMission> CreateSquadMissions()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public class Mission
public int Id { get; set; }

public string CodeName { get; set; }
public double? Rating { get; set; }
public DateTimeOffset Timeline { get; set; }

public virtual ICollection<SquadMission> ParticipatingSquads { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,5 +214,21 @@ await AssertQuery<Order>(
os => os.Select(o => new { o.Customer.City, Count = o.OrderDetails.Count() }),
elementSorter: e => e.City + " " + e.Count);
}

[Fact]
public async Task Sum_with_no_data_nullable_legacy_behavior()
{
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12314", true);

try
{
await Assert.ThrowsAsync<InvalidOperationException>(
async () => await AssertSingleResult<Order>(os => os.Where(o => o.OrderID < 0).Select(o => (int?)o.OrderID).SumAsync()));
}
finally
{
AppContext.SetSwitch("Microsoft.EntityFrameworkCore.Issue12314", false);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3037,7 +3037,7 @@ public override void Where_datetimeoffset_now()
base.Where_datetimeoffset_now();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE [m].[Timeline] <> SYSDATETIMEOFFSET()");
}
Expand All @@ -3047,7 +3047,7 @@ public override void Where_datetimeoffset_utcnow()
base.Where_datetimeoffset_utcnow();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE [m].[Timeline] <> CAST(SYSUTCDATETIME() AS datetimeoffset)");
}
Expand All @@ -3060,7 +3060,7 @@ public override void Where_datetimeoffset_date_component()
AssertSql(
@"@__Date_0='0001-01-01T00:00:00'
SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE CONVERT(date, [m].[Timeline]) > @__Date_0");
}
Expand All @@ -3070,7 +3070,7 @@ public override void Where_datetimeoffset_year_component()
base.Where_datetimeoffset_year_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(year, [m].[Timeline]) = 2");
}
Expand All @@ -3080,7 +3080,7 @@ public override void Where_datetimeoffset_month_component()
base.Where_datetimeoffset_month_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(month, [m].[Timeline]) = 1");
}
Expand All @@ -3090,7 +3090,7 @@ public override void Where_datetimeoffset_dayofyear_component()
base.Where_datetimeoffset_dayofyear_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(dayofyear, [m].[Timeline]) = 2");
}
Expand All @@ -3100,7 +3100,7 @@ public override void Where_datetimeoffset_day_component()
base.Where_datetimeoffset_day_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(day, [m].[Timeline]) = 2");
}
Expand All @@ -3110,7 +3110,7 @@ public override void Where_datetimeoffset_hour_component()
base.Where_datetimeoffset_hour_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(hour, [m].[Timeline]) = 10");
}
Expand All @@ -3120,7 +3120,7 @@ public override void Where_datetimeoffset_minute_component()
base.Where_datetimeoffset_minute_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(minute, [m].[Timeline]) = 0");
}
Expand All @@ -3130,7 +3130,7 @@ public override void Where_datetimeoffset_second_component()
base.Where_datetimeoffset_second_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(second, [m].[Timeline]) = 0");
}
Expand All @@ -3140,7 +3140,7 @@ public override void Where_datetimeoffset_millisecond_component()
base.Where_datetimeoffset_millisecond_component();

AssertSql(
@"SELECT [m].[Id], [m].[CodeName], [m].[Timeline]
@"SELECT [m].[Id], [m].[CodeName], [m].[Rating], [m].[Timeline]
FROM [Missions] AS [m]
WHERE DATEPART(millisecond, [m].[Timeline]) = 0");
}
Expand Down

0 comments on commit 15bc370

Please sign in to comment.