Skip to content

Commit

Permalink
Translate Math.Min and Math.Max on SqlServer
Browse files Browse the repository at this point in the history
  • Loading branch information
yesmey committed Apr 10, 2022
1 parent 73db32d commit 909f90d
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 6 deletions.
60 changes: 60 additions & 0 deletions src/EFCore.SqlServer/Query/Internal/SqlServerMathTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,38 @@ public class SqlServerMathTranslator : IMethodCallTranslator
typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), new[] { typeof(float), typeof(int) })!
};

private static readonly IEnumerable<MethodInfo> MaxMethodInfos = new[]
{
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(double), typeof(double) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(float), typeof(float) })!,
typeof(Math).GetRuntimeMethod(nameof(MathF.Max), new[] { typeof(float), typeof(float) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(byte), typeof(byte) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(sbyte), typeof(sbyte) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(int), typeof(int) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(uint), typeof(uint) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(short), typeof(short) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(ushort), typeof(ushort) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(long), typeof(long) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(ulong), typeof(ulong) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(decimal), typeof(decimal) })!
};

private static readonly IEnumerable<MethodInfo> MinMethodInfos = new[]
{
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(double), typeof(double) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(float), typeof(float) })!,
typeof(Math).GetRuntimeMethod(nameof(MathF.Max), new[] { typeof(float), typeof(float) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(byte), typeof(byte) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(sbyte), typeof(sbyte) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(int), typeof(int) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(uint), typeof(uint) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(short), typeof(short) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(ushort), typeof(ushort) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(long), typeof(long) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(ulong), typeof(ulong) })!,
typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(decimal), typeof(decimal) })!
};

private readonly ISqlExpressionFactory _sqlExpressionFactory;

/// <summary>
Expand Down Expand Up @@ -170,6 +202,34 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory)
return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping);
}

if (MaxMethodInfos.Contains(method))
{
var left = arguments[0];
var right = arguments[1];

var typeMapping = ExpressionExtensions.InferTypeMapping(left, right);
left = _sqlExpressionFactory.ApplyTypeMapping(left, typeMapping);
right = _sqlExpressionFactory.ApplyTypeMapping(right, typeMapping);

return _sqlExpressionFactory.Case(
new[] { new CaseWhenClause(_sqlExpressionFactory.GreaterThanOrEqual(left, right), left) },
right);
}

if (MinMethodInfos.Contains(method))
{
var left = arguments[0];
var right = arguments[1];

var typeMapping = ExpressionExtensions.InferTypeMapping(left, right);
left = _sqlExpressionFactory.ApplyTypeMapping(left, typeMapping);
right = _sqlExpressionFactory.ApplyTypeMapping(right, typeMapping);

return _sqlExpressionFactory.Case(
new[] { new CaseWhenClause(_sqlExpressionFactory.LessThanOrEqual(left, right), left) },
right);
}

return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ public override async Task Where_math_min(bool async)
AssertSql();
}

public override async Task Where_math_min2(bool async)
{
// Cosmos client evaluation. Issue #17246.
await AssertTranslationFailed(() => base.Where_math_min2(async));

AssertSql();
}

public override async Task Where_math_max(bool async)
{
// Cosmos client evaluation. Issue #17246.
Expand All @@ -510,6 +518,14 @@ public override async Task Where_math_max(bool async)
AssertSql();
}

public override async Task Where_math_max2(bool async)
{
// Cosmos client evaluation. Issue #17246.
await AssertTranslationFailed(() => base.Where_math_max2(async));

AssertSql();
}

public override async Task Where_mathf_abs1(bool async)
{
// Cosmos client evaluation. Issue #17246.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,14 @@ public virtual Task Where_math_max(bool async)
ss => ss.Set<OrderDetail>().Where(od => od.OrderID == 11077).Where(od => Math.Max(od.OrderID, od.ProductID) == od.OrderID),
entryCount: 25);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_math_max2(bool async)
=> AssertQuery(
async,
ss => ss.Set<OrderDetail>().Where(od => od.OrderID == 11077).Where(od => Math.Max(od.Discount, 0) == od.Discount),
entryCount: 25);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_math_min(bool async)
Expand All @@ -960,6 +968,15 @@ public virtual Task Where_math_min(bool async)
.Where(od => Math.Min(od.OrderID, od.ProductID) == od.ProductID),
entryCount: 25);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_math_min2(bool async)
=> AssertQuery(
async,
ss => ss.Set<OrderDetail>().Where(od => od.OrderID == 11077)
.Where(od => Math.Min(od.Discount, 0) == od.Discount),
entryCount: 12);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_mathf_abs1(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -946,18 +946,50 @@ FROM [Order Details] AS [o]

public override async Task Where_math_min(bool async)
{
// Translate Math.Min.
await AssertTranslationFailed(() => base.Where_math_min(async));
await base.Where_math_min(async);

AssertSql();
AssertSql(@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[OrderID] = 11077 AND CASE
WHEN [o].[OrderID] <= [o].[ProductID] THEN [o].[OrderID]
ELSE [o].[ProductID]
END = [o].[ProductID]");
}

public override async Task Where_math_min2(bool async)
{
await base.Where_math_min2(async);

AssertSql(@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[OrderID] = 11077 AND CASE
WHEN [o].[Discount] <= CAST(0 AS real) THEN [o].[Discount]
ELSE CAST(0 AS real)
END = [o].[Discount]");
}

public override async Task Where_math_max(bool async)
{
// Translate Math.Max.
await AssertTranslationFailed(() => base.Where_math_max(async));
await base.Where_math_max(async);

AssertSql();
AssertSql(@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[OrderID] = 11077 AND CASE
WHEN [o].[OrderID] >= [o].[ProductID] THEN [o].[OrderID]
ELSE [o].[ProductID]
END = [o].[OrderID]");
}

public override async Task Where_math_max2(bool async)
{
await base.Where_math_max2(async);

AssertSql(@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
WHERE [o].[OrderID] = 11077 AND CASE
WHEN [o].[Discount] >= CAST(0 AS real) THEN [o].[Discount]
ELSE CAST(0 AS real)
END = [o].[Discount]");
}

public override async Task Where_mathf_abs1(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,16 @@ public override async Task Where_math_min(bool async)
WHERE ""o"".""OrderID"" = 11077 AND min(""o"".""OrderID"", ""o"".""ProductID"") = ""o"".""ProductID""");
}

public override async Task Where_math_min2(bool async)
{
await base.Where_math_min2(async);

AssertSql(
@"SELECT ""o"".""OrderID"", ""o"".""ProductID"", ""o"".""Discount"", ""o"".""Quantity"", ""o"".""UnitPrice""
FROM ""Order Details"" AS ""o""
WHERE ""o"".""OrderID"" = 11077 AND min(""o"".""Discount"", 0) = ""o"".""Discount""");
}

public override async Task Where_math_max(bool async)
{
await base.Where_math_max(async);
Expand All @@ -486,6 +496,16 @@ public override async Task Where_math_max(bool async)
WHERE ""o"".""OrderID"" = 11077 AND max(""o"".""OrderID"", ""o"".""ProductID"") = ""o"".""OrderID""");
}

public override async Task Where_math_max2(bool async)
{
await base.Where_math_max2(async);

AssertSql(
@"SELECT ""o"".""OrderID"", ""o"".""ProductID"", ""o"".""Discount"", ""o"".""Quantity"", ""o"".""UnitPrice""
FROM ""Order Details"" AS ""o""
WHERE ""o"".""OrderID"" = 11077 AND max(""o"".""Discount"", 0) = ""o"".""Discount""");
}

public override async Task Where_string_to_lower(bool async)
{
await base.Where_string_to_lower(async);
Expand Down

0 comments on commit 909f90d

Please sign in to comment.