diff --git a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosDateTimeMethodTranslator.cs b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosDateTimeMethodTranslator.cs index c2e4504dcb3..00cf9837c3f 100644 --- a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosDateTimeMethodTranslator.cs +++ b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosDateTimeMethodTranslator.cs @@ -13,32 +13,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; /// public class CosmosDateTimeMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary MethodInfoDatePartMapping = new() - { - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddYears), [typeof(int)])!, "yyyy" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMonths), [typeof(int)])!, "mm" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddDays), [typeof(double)])!, "dd" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddHours), [typeof(double)])!, "hh" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMinutes), [typeof(double)])!, "mi" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddSeconds), [typeof(double)])!, "ss" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMilliseconds), [typeof(double)])!, "ms" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMicroseconds), [typeof(double)])!, "mcs" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddYears), [typeof(int)])!, "yyyy" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMonths), [typeof(int)])!, "mm" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddDays), [typeof(double)])!, "dd" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddHours), [typeof(double)])!, "hh" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMinutes), [typeof(double)])!, "mi" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddSeconds), [typeof(double)])!, "ss" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMilliseconds), [typeof(double)])!, "ms" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMicroseconds), [typeof(double)])!, "mcs" } - }; - - private static readonly Dictionary MethodInfoDateDiffMapping = new() - { - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.ToUnixTimeSeconds), Type.EmptyTypes)!, "second" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.ToUnixTimeMilliseconds), Type.EmptyTypes)!, "millisecond" } - }; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -56,16 +30,30 @@ public class CosmosDateTimeMethodTranslator(ISqlExpressionFactory sqlExpressionF return null; } - if (MethodInfoDatePartMapping.TryGetValue(method, out var datePart) - && instance != null) + if (instance is null || arguments is not [var arg]) { - return sqlExpressionFactory.Function( - "DateTimeAdd", - arguments: [sqlExpressionFactory.Constant(datePart), arguments[0], instance], - instance.Type, - instance.TypeMapping); + return null; } - return null; + var datePart = method.Name switch + { + nameof(DateTime.AddYears) => "yyyy", + nameof(DateTime.AddMonths) => "mm", + nameof(DateTime.AddDays) => "dd", + nameof(DateTime.AddHours) => "hh", + nameof(DateTime.AddMinutes) => "mi", + nameof(DateTime.AddSeconds) => "ss", + nameof(DateTime.AddMilliseconds) => "ms", + nameof(DateTime.AddMicroseconds) => "mcs", + _ => (string?)null + }; + + return datePart is not null + ? sqlExpressionFactory.Function( + "DateTimeAdd", + arguments: [sqlExpressionFactory.Constant(datePart), arg, instance], + instance.Type, + instance.TypeMapping) + : null; } } diff --git a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosMathTranslator.cs b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosMathTranslator.cs index 3dab9722b31..28c024d54b4 100644 --- a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosMathTranslator.cs +++ b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosMathTranslator.cs @@ -13,66 +13,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; /// public class CosmosMathTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary SupportedMethodTranslations = new() - { - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(decimal)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(double)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(float)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(int)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(long)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(sbyte)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(short)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(decimal)])!, "CEILING" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(double)])!, "CEILING" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(decimal)])!, "FLOOR" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(double)])!, "FLOOR" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Pow), [typeof(double), typeof(double)])!, "POWER" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Exp), [typeof(double)])!, "EXP" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log10), [typeof(double)])!, "LOG10" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double)])!, "LOG" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double), typeof(double)])!, "LOG" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), [typeof(double)])!, "SQRT" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Acos), [typeof(double)])!, "ACOS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Asin), [typeof(double)])!, "ASIN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Atan), [typeof(double)])!, "ATAN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), [typeof(double), typeof(double)])!, "ATN2" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Cos), [typeof(double)])!, "COS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sin), [typeof(double)])!, "SIN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Tan), [typeof(double)])!, "TAN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(decimal)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(double)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(float)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(int)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(long)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(sbyte)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(short)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(decimal)])!, "TRUNC" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(double)])!, "TRUNC" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(decimal)])!, "ROUND" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(double)])!, "ROUND" }, - { typeof(double).GetRuntimeMethod(nameof(double.DegreesToRadians), [typeof(double)])!, "RADIANS" }, - { typeof(double).GetRuntimeMethod(nameof(double.RadiansToDegrees), [typeof(double)])!, "DEGREES" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), [typeof(float)])!, "CEILING" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), [typeof(float)])!, "FLOOR" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), [typeof(float), typeof(float)])!, "POWER" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), [typeof(float)])!, "EXP" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), [typeof(float)])!, "LOG10" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float)])!, "LOG" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float), typeof(float)])!, "LOG" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), [typeof(float)])!, "SQRT" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), [typeof(float)])!, "ACOS" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), [typeof(float)])!, "ASIN" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), [typeof(float)])!, "ATAN" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), [typeof(float), typeof(float)])!, "ATN2" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), [typeof(float)])!, "COS" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), [typeof(float)])!, "SIN" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), [typeof(float)])!, "TAN" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Truncate), [typeof(float)])!, "TRUNC" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), [typeof(float)])!, "ROUND" }, - { typeof(float).GetRuntimeMethod(nameof(float.DegreesToRadians), [typeof(float)])!, "RADIANS" }, - { typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "DEGREES" }, - }; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -85,21 +25,85 @@ public class CosmosMathTranslator(ISqlExpressionFactory sqlExpressionFactory) : IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (SupportedMethodTranslations.TryGetValue(method, out var sqlFunctionName)) + if (method.DeclaringType != typeof(Math) + && method.DeclaringType != typeof(MathF) + && method.DeclaringType != typeof(double) + && method.DeclaringType != typeof(float)) + { + return null; + } + + var sqlFunctionName = method.Name switch { - var typeMapping = arguments.Count == 1 - ? ExpressionExtensions.InferTypeMapping(arguments[0]) - : ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]); + nameof(Math.Abs) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float) + || t == typeof(int) || t == typeof(long) || t == typeof(sbyte) || t == typeof(short)) + => "ABS", - var newArguments = arguments.Select(e => sqlExpressionFactory.ApplyTypeMapping(e, typeMapping!)); + nameof(Math.Ceiling) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float)) + => "CEILING", + nameof(Math.Floor) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float)) + => "FLOOR", + nameof(Math.Round) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float)) + => "ROUND", + nameof(Math.Truncate) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float)) + => "TRUNC", + nameof(Math.Sign) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(decimal) || t == typeof(double) || t == typeof(float) + || t == typeof(int) || t == typeof(long) || t == typeof(sbyte) || t == typeof(short)) + => "SIGN", - return sqlExpressionFactory.Function( - sqlFunctionName, - newArguments, - method.ReturnType, - typeMapping); + nameof(Math.Pow) when arguments is [_, _] + => "POWER", + nameof(Math.Exp) when arguments is [_] + => "EXP", + nameof(Math.Log10) when arguments is [_] + => "LOG10", + nameof(Math.Log) when arguments is [_] or [_, _] + => "LOG", + nameof(Math.Sqrt) when arguments is [_] + => "SQRT", + nameof(Math.Acos) when arguments is [_] + => "ACOS", + nameof(Math.Asin) when arguments is [_] + => "ASIN", + nameof(Math.Atan) when arguments is [_] + => "ATAN", + nameof(Math.Atan2) when arguments is [_, _] + => "ATN2", + nameof(Math.Cos) when arguments is [_] + => "COS", + nameof(Math.Sin) when arguments is [_] + => "SIN", + nameof(Math.Tan) when arguments is [_] + => "TAN", + nameof(double.DegreesToRadians) when arguments is [_] + => "RADIANS", + nameof(double.RadiansToDegrees) when arguments is [_] + => "DEGREES", + + _ => null + }; + + if (sqlFunctionName is null) + { + return null; } - return null; + var typeMapping = arguments.Count == 1 + ? ExpressionExtensions.InferTypeMapping(arguments[0]) + : ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]); + + var newArguments = arguments.Select(e => sqlExpressionFactory.ApplyTypeMapping(e, typeMapping!)); + + return sqlExpressionFactory.Function( + sqlFunctionName, + newArguments, + method.ReturnType, + typeMapping); } } diff --git a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRandomTranslator.cs b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRandomTranslator.cs index eaa785c1534..1d4bab65d10 100644 --- a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRandomTranslator.cs +++ b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRandomTranslator.cs @@ -13,9 +13,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; /// public class CosmosRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo MethodInfo = typeof(DbFunctionsExtensions).GetRuntimeMethod( - nameof(DbFunctionsExtensions.Random), [typeof(DbFunctions)])!; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -27,10 +24,11 @@ public class CosmosRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - => MethodInfo.Equals(method) - ? sqlExpressionFactory.Function( - "RAND", - [], - method.ReturnType) - : null; + => method.DeclaringType == typeof(DbFunctionsExtensions) + && method.Name == nameof(DbFunctionsExtensions.Random) + ? sqlExpressionFactory.Function( + "RAND", + [], + method.ReturnType) + : null; } diff --git a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRegexTranslator.cs b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRegexTranslator.cs index 7b11562abb3..a7741897b87 100644 --- a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRegexTranslator.cs +++ b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosRegexTranslator.cs @@ -15,12 +15,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; public class CosmosRegexTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo IsMatch = - typeof(Regex).GetRuntimeMethod(nameof(Regex.IsMatch), [typeof(string), typeof(string)])!; - - private static readonly MethodInfo IsMatchWithRegexOptions = - typeof(Regex).GetRuntimeMethod(nameof(Regex.IsMatch), [typeof(string), typeof(string), typeof(RegexOptions)])!; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -33,7 +27,13 @@ public class CosmosRegexTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method != IsMatch && method != IsMatchWithRegexOptions) + if (method.DeclaringType != typeof(Regex) + || method.Name != nameof(Regex.IsMatch)) + { + return null; + } + + if (arguments is not ([_, _] or [_, _, _])) { return null; } @@ -44,7 +44,7 @@ public class CosmosRegexTranslator(ISqlExpressionFactory sqlExpressionFactory) sqlExpressionFactory.ApplyTypeMapping(input, typeMapping), sqlExpressionFactory.ApplyTypeMapping(pattern, typeMapping)); - if (method == IsMatch || arguments[2] is SqlConstantExpression { Value: RegexOptions.None }) + if (arguments.Count == 2 || arguments[2] is SqlConstantExpression { Value: RegexOptions.None }) { return sqlExpressionFactory.Function("RegexMatch", [input, pattern], typeof(bool)); } diff --git a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosStringMethodTranslator.cs b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosStringMethodTranslator.cs index 2cbd2b84231..2d1a29e764d 100644 --- a/src/EFCore.Cosmos/Query/Internal/Translators/CosmosStringMethodTranslator.cs +++ b/src/EFCore.Cosmos/Query/Internal/Translators/CosmosStringMethodTranslator.cs @@ -13,107 +13,6 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; /// public class CosmosStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo IndexOfMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string)])!; - - private static readonly MethodInfo IndexOfMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char)])!; - - private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionString - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string), typeof(int)])!; - - private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionChar - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char), typeof(int)])!; - - private static readonly MethodInfo ReplaceMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(string), typeof(string)])!; - - private static readonly MethodInfo ReplaceMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(char), typeof(char)])!; - - private static readonly MethodInfo ContainsMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!; - - private static readonly MethodInfo ContainsMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!; - - private static readonly MethodInfo ContainsWithStringComparisonMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string), typeof(StringComparison)])!; - - private static readonly MethodInfo ContainsWithStringComparisonMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char), typeof(StringComparison)])!; - - private static readonly MethodInfo StartsWithMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!; - - private static readonly MethodInfo StartsWithMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!; - - private static readonly MethodInfo StartsWithWithStringComparisonMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string), typeof(StringComparison)])!; - - private static readonly MethodInfo EndsWithMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!; - - private static readonly MethodInfo EndsWithMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!; - - private static readonly MethodInfo EndsWithWithStringComparisonMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string), typeof(StringComparison)])!; - - private static readonly MethodInfo ToLowerMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.ToLower), [])!; - - private static readonly MethodInfo ToUpperMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.ToUpper), [])!; - - private static readonly MethodInfo TrimStartMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [])!; - - private static readonly MethodInfo TrimEndMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), [])!; - - private static readonly MethodInfo TrimMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.Trim), [])!; - - private static readonly MethodInfo TrimStartMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [typeof(char[])])!; - - private static readonly MethodInfo TrimEndMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), [typeof(char[])])!; - - private static readonly MethodInfo TrimMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.Trim), [typeof(char[])])!; - - private static readonly MethodInfo SubstringMethodInfoWithOneArg - = typeof(string).GetRuntimeMethod(nameof(string.Substring), [typeof(int)])!; - - private static readonly MethodInfo SubstringMethodInfoWithTwoArgs - = typeof(string).GetRuntimeMethod(nameof(string.Substring), [typeof(int), typeof(int)])!; - - private static readonly MethodInfo FirstOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single(m => m.Name == nameof(Enumerable.FirstOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - private static readonly MethodInfo LastOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single(m => m.Name == nameof(Enumerable.LastOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - private static readonly MethodInfo StringConcatWithTwoArguments = - typeof(string).GetRuntimeMethod(nameof(string.Concat), [typeof(string), typeof(string)])!; - - private static readonly MethodInfo StringConcatWithThreeArguments = - typeof(string).GetRuntimeMethod(nameof(string.Concat), [typeof(string), typeof(string), typeof(string)])!; - - private static readonly MethodInfo StringConcatWithFourArguments = - typeof(string).GetRuntimeMethod(nameof(string.Concat), [typeof(string), typeof(string), typeof(string), typeof(string)])!; - - private static readonly MethodInfo StringComparisonWithComparisonTypeArgumentInstance - = typeof(string).GetRuntimeMethod(nameof(string.Equals), [typeof(string), typeof(StringComparison)])!; - - private static readonly MethodInfo StringComparisonWithComparisonTypeArgumentStatic - = typeof(string).GetRuntimeMethod(nameof(string.Equals), [typeof(string), typeof(string), typeof(StringComparison)])!; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -126,210 +25,135 @@ private static readonly MethodInfo StringComparisonWithComparisonTypeArgumentSta IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (instance != null) + if (method.DeclaringType == typeof(string)) { - if (IndexOfMethodInfoString.Equals(method) || IndexOfMethodInfoChar.Equals(method)) - { - return TranslateSystemFunction("INDEX_OF", typeof(int), instance, arguments[0]); - } - - if (IndexOfMethodInfoWithStartingPositionString.Equals(method) || IndexOfMethodInfoWithStartingPositionChar.Equals(method)) - { - return TranslateSystemFunction("INDEX_OF", typeof(int), instance, arguments[0], arguments[1]); - } - - if (ReplaceMethodInfoString.Equals(method) || ReplaceMethodInfoChar.Equals(method)) - { - return TranslateSystemFunction("REPLACE", method.ReturnType, instance, arguments[0], arguments[1]); - } - - if (ContainsMethodInfoString.Equals(method) || ContainsMethodInfoChar.Equals(method)) + if (instance is not null) { - return TranslateSystemFunction("CONTAINS", typeof(bool), instance, arguments[0]); - } - - if (ContainsWithStringComparisonMethodInfoString.Equals(method) || ContainsWithStringComparisonMethodInfoChar.Equals(method)) - { - if (arguments[1] is SqlConstantExpression { Value: StringComparison comparisonType }) - { - return comparisonType switch - { - StringComparison.Ordinal - => TranslateSystemFunction( - "CONTAINS", typeof(bool), instance, arguments[0], sqlExpressionFactory.Constant(false)), - StringComparison.OrdinalIgnoreCase - => TranslateSystemFunction( - "CONTAINS", typeof(bool), instance, arguments[0], sqlExpressionFactory.Constant(true)), - - _ => null // TODO: Explicit translation error for unsupported StringComparison argument (depends on #26410) - }; - } - - // TODO: Explicit translation error for non-constant StringComparison argument (depends on #26410) - return null; - } - - if (StartsWithMethodInfoString.Equals(method) || StartsWithMethodInfoChar.Equals(method)) - { - return TranslateSystemFunction("STARTSWITH", typeof(bool), instance, arguments[0]); - } - - if (StartsWithWithStringComparisonMethodInfoString.Equals(method)) - { - if (arguments[1] is SqlConstantExpression { Value: StringComparison comparisonType }) + return method.Name switch { - return comparisonType switch + nameof(string.IndexOf) when arguments is [var arg] + => TranslateSystemFunction("INDEX_OF", typeof(int), instance, arg), + nameof(string.IndexOf) when arguments is [var arg, var startIndex] + => TranslateSystemFunction("INDEX_OF", typeof(int), instance, arg, startIndex), + nameof(string.Replace) when arguments is [var oldValue, var newValue] + => TranslateSystemFunction("REPLACE", method.ReturnType, instance, oldValue, newValue), + nameof(string.Contains) when arguments is [var arg] + => TranslateSystemFunction("CONTAINS", typeof(bool), instance, arg), + nameof(string.Contains) when arguments is [var arg, SqlConstantExpression { Value: StringComparison comparisonType }] + => comparisonType switch + { + StringComparison.Ordinal + => TranslateSystemFunction("CONTAINS", typeof(bool), instance, arg, sqlExpressionFactory.Constant(false)), + StringComparison.OrdinalIgnoreCase + => TranslateSystemFunction("CONTAINS", typeof(bool), instance, arg, sqlExpressionFactory.Constant(true)), + _ => null + }, + nameof(string.StartsWith) when arguments is [var arg] && arg.Type is { } t && (t == typeof(string) || t == typeof(char)) + => TranslateSystemFunction("STARTSWITH", typeof(bool), instance, arg), + nameof(string.StartsWith) when arguments is [var arg, SqlConstantExpression { Value: StringComparison comparisonType }] + && arg.Type == typeof(string) + => comparisonType switch + { + StringComparison.Ordinal + => TranslateSystemFunction( + "STARTSWITH", typeof(bool), instance, arg, sqlExpressionFactory.Constant(false)), + StringComparison.OrdinalIgnoreCase + => TranslateSystemFunction( + "STARTSWITH", typeof(bool), instance, arg, sqlExpressionFactory.Constant(true)), + _ => null + }, + nameof(string.EndsWith) when arguments is [var arg] && arg.Type is { } t && (t == typeof(string) || t == typeof(char)) + => TranslateSystemFunction("ENDSWITH", typeof(bool), instance, arg), + nameof(string.EndsWith) when arguments is [var arg, SqlConstantExpression { Value: StringComparison comparisonType }] + && arg.Type == typeof(string) + => comparisonType switch + { + StringComparison.Ordinal + => TranslateSystemFunction( + "ENDSWITH", typeof(bool), instance, arg, sqlExpressionFactory.Constant(false)), + StringComparison.OrdinalIgnoreCase + => TranslateSystemFunction( + "ENDSWITH", typeof(bool), instance, arg, sqlExpressionFactory.Constant(true)), + _ => null + }, + nameof(string.ToLower) when arguments is [] + => TranslateSystemFunction("LOWER", method.ReturnType, instance), + nameof(string.ToUpper) when arguments is [] + => TranslateSystemFunction("UPPER", method.ReturnType, instance), + nameof(string.TrimStart) when arguments is [] + => TranslateSystemFunction("LTRIM", method.ReturnType, instance), + nameof(string.TrimStart) when arguments is [SqlConstantExpression { Value: char[] { Length: 0 } }] + // Cosmos DB LTRIM does not take arguments + => TranslateSystemFunction("LTRIM", method.ReturnType, instance), + nameof(string.TrimEnd) when arguments is [] + => TranslateSystemFunction("RTRIM", method.ReturnType, instance), + nameof(string.TrimEnd) when arguments is [SqlConstantExpression { Value: char[] { Length: 0 } }] + // Cosmos DB RTRIM does not take arguments + => TranslateSystemFunction("RTRIM", method.ReturnType, instance), + nameof(string.Trim) when arguments is [] + => TranslateSystemFunction("TRIM", method.ReturnType, instance), + nameof(string.Trim) when arguments is [SqlConstantExpression { Value: char[] { Length: 0 } }] + // Cosmos DB TRIM does not take arguments + => TranslateSystemFunction("TRIM", method.ReturnType, instance), + nameof(string.Substring) when arguments is [var startIndex] + => TranslateSystemFunction( + "SUBSTRING", + method.ReturnType, + instance, + startIndex, + TranslateSystemFunction("LENGTH", typeof(int), instance)), + nameof(string.Substring) when arguments is [SqlConstantExpression { Value: 0 }, var length] + => TranslateSystemFunction("LEFT", method.ReturnType, instance, length), + nameof(string.Substring) when arguments is [var startIndex, var length] + => TranslateSystemFunction("SUBSTRING", method.ReturnType, instance, startIndex, length), + nameof(string.Equals) when arguments is [var other, SqlConstantExpression + { + Value: StringComparison comparisonTypeValue + and (StringComparison.OrdinalIgnoreCase or StringComparison.Ordinal) + }] + => comparisonTypeValue == StringComparison.OrdinalIgnoreCase + ? TranslateSystemFunction( + "STRINGEQUALS", typeof(bool), instance, other, sqlExpressionFactory.Constant(true)) + : TranslateSystemFunction("STRINGEQUALS", typeof(bool), instance, other), + _ => null + }; + } + + // Static string methods + return method.Name switch + { + nameof(string.Concat) when arguments is [var a, var b] + => sqlExpressionFactory.Add(a, b), + nameof(string.Concat) when arguments is [var a, var b, var c] + => sqlExpressionFactory.Add(a, sqlExpressionFactory.Add(b, c)), + nameof(string.Concat) when arguments is [var a, var b, var c, var d] + => sqlExpressionFactory.Add(a, sqlExpressionFactory.Add(b, sqlExpressionFactory.Add(c, d))), + nameof(string.Equals) when arguments is [var left, var right, SqlConstantExpression { - StringComparison.Ordinal - => TranslateSystemFunction( - "STARTSWITH", typeof(bool), instance, arguments[0], sqlExpressionFactory.Constant(false)), - StringComparison.OrdinalIgnoreCase - => TranslateSystemFunction( - "STARTSWITH", typeof(bool), instance, arguments[0], sqlExpressionFactory.Constant(true)), - - _ => null // TODO: Explicit translation error for unsupported StringComparison argument (depends on #26410) - }; - } - - // TODO: Explicit translation error for non-constant StringComparison argument (depends on #26410) - return null; - } - - if (EndsWithMethodInfoString.Equals(method) || EndsWithMethodInfoChar.Equals(method)) - { - return TranslateSystemFunction("ENDSWITH", typeof(bool), instance, arguments[0]); - } - - if (EndsWithWithStringComparisonMethodInfoString.Equals(method)) - { - if (arguments[1] is SqlConstantExpression { Value: StringComparison comparisonType }) - { - return comparisonType switch - { - StringComparison.Ordinal - => TranslateSystemFunction( - "ENDSWITH", typeof(bool), instance, arguments[0], sqlExpressionFactory.Constant(false)), - StringComparison.OrdinalIgnoreCase - => TranslateSystemFunction( - "ENDSWITH", typeof(bool), instance, arguments[0], sqlExpressionFactory.Constant(true)), - - _ => null // TODO: Explicit translation error for unsupported StringComparison argument (depends on #26410) - }; - } - - // TODO: Explicit translation error for non-constant StringComparison argument (depends on #26410) - return null; - } - - if (ToLowerMethodInfo.Equals(method)) - { - return TranslateSystemFunction("LOWER", method.ReturnType, instance); - } - - if (ToUpperMethodInfo.Equals(method)) - { - return TranslateSystemFunction("UPPER", method.ReturnType, instance); - } - - if (TrimStartMethodInfoWithoutArgs.Equals(method) - || (TrimStartMethodInfoWithCharArrayArg.Equals(method) - // Cosmos DB LTRIM does not take arguments - && ((arguments[0] as SqlConstantExpression)?.Value as Array)?.Length == 0)) - { - return TranslateSystemFunction("LTRIM", method.ReturnType, instance); - } - - if (TrimEndMethodInfoWithoutArgs.Equals(method) - || (TrimEndMethodInfoWithCharArrayArg.Equals(method) - // Cosmos DB RTRIM does not take arguments - && ((arguments[0] as SqlConstantExpression)?.Value as Array)?.Length == 0)) - { - return TranslateSystemFunction("RTRIM", method.ReturnType, instance); - } - - if (TrimMethodInfoWithoutArgs.Equals(method) - || (TrimMethodInfoWithCharArrayArg.Equals(method) - // Cosmos DB TRIM does not take arguments - && ((arguments[0] as SqlConstantExpression)?.Value as Array)?.Length == 0)) - { - return TranslateSystemFunction("TRIM", method.ReturnType, instance); - } - - if (SubstringMethodInfoWithOneArg.Equals(method)) - { - return TranslateSystemFunction( - "SUBSTRING", - method.ReturnType, - instance, - arguments[0], - TranslateSystemFunction("LENGTH", typeof(int), instance)); - } - - if (SubstringMethodInfoWithTwoArgs.Equals(method)) - { - return arguments[0] is SqlConstantExpression { Value: 0 } - ? TranslateSystemFunction("LEFT", method.ReturnType, instance, arguments[1]) - : TranslateSystemFunction("SUBSTRING", method.ReturnType, instance, arguments[0], arguments[1]); - } - } - - if (FirstOrDefaultMethodInfoWithoutArgs.Equals(method)) - { - return TranslateSystemFunction("LEFT", typeof(char), arguments[0], sqlExpressionFactory.Constant(1)); - } - - if (LastOrDefaultMethodInfoWithoutArgs.Equals(method)) - { - return TranslateSystemFunction("RIGHT", typeof(char), arguments[0], sqlExpressionFactory.Constant(1)); - } - - if (StringConcatWithTwoArguments.Equals(method)) - { - return sqlExpressionFactory.Add( - arguments[0], - arguments[1]); - } - - if (StringConcatWithThreeArguments.Equals(method)) - { - return sqlExpressionFactory.Add( - arguments[0], - sqlExpressionFactory.Add( - arguments[1], - arguments[2])); - } - - if (StringConcatWithFourArguments.Equals(method)) - { - return sqlExpressionFactory.Add( - arguments[0], - sqlExpressionFactory.Add( - arguments[1], - sqlExpressionFactory.Add( - arguments[2], - arguments[3]))); + Value: StringComparison comparisonTypeValue + and (StringComparison.OrdinalIgnoreCase or StringComparison.Ordinal) + }] + => comparisonTypeValue == StringComparison.OrdinalIgnoreCase + ? TranslateSystemFunction( + "STRINGEQUALS", typeof(bool), left, right, sqlExpressionFactory.Constant(true)) + : TranslateSystemFunction("STRINGEQUALS", typeof(bool), left, right), + _ => null + }; } - if (StringComparisonWithComparisonTypeArgumentInstance.Equals(method) - || StringComparisonWithComparisonTypeArgumentStatic.Equals(method)) + if (method.DeclaringType == typeof(Enumerable) + && method.IsGenericMethod + && method.GetGenericArguments()[0] == typeof(char) + && arguments is [var source]) { - var comparisonTypeArgument = arguments[^1]; - - if (comparisonTypeArgument is SqlConstantExpression - { - Value: StringComparison comparisonTypeArgumentValue and (StringComparison.OrdinalIgnoreCase or StringComparison.Ordinal) - }) - { - return StringComparisonWithComparisonTypeArgumentInstance.Equals(method) - ? comparisonTypeArgumentValue == StringComparison.OrdinalIgnoreCase - ? TranslateSystemFunction( - "STRINGEQUALS", typeof(bool), instance!, arguments[0], sqlExpressionFactory.Constant(true)) - : TranslateSystemFunction("STRINGEQUALS", typeof(bool), instance!, arguments[0]) - : comparisonTypeArgumentValue == StringComparison.OrdinalIgnoreCase - ? TranslateSystemFunction( - "STRINGEQUALS", typeof(bool), arguments[0], arguments[1], sqlExpressionFactory.Constant(true)) - : TranslateSystemFunction("STRINGEQUALS", typeof(bool), arguments[0], arguments[1]); - } + return method.Name switch + { + nameof(Enumerable.FirstOrDefault) + => TranslateSystemFunction("LEFT", typeof(char), source, sqlExpressionFactory.Constant(1)), + nameof(Enumerable.LastOrDefault) + => TranslateSystemFunction("RIGHT", typeof(char), source, sqlExpressionFactory.Constant(1)), + _ => null + }; } return null; diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs index 86c07a0aa4d..21ad745fb61 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerByteArrayMethodTranslator.cs @@ -12,19 +12,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerByteArrayMethodTranslator : IMethodCallTranslator +public class SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -37,40 +26,39 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method.IsGenericMethod) + if (method.IsGenericMethod + && method.DeclaringType == typeof(Enumerable)) { - var methodDefinition = method.GetGenericMethodDefinition(); - if (methodDefinition.Equals(EnumerableMethods.Contains) - && arguments[0].Type == typeof(byte[])) + switch (method.Name) { - var source = arguments[0]; - var sourceTypeMapping = source.TypeMapping; + case nameof(Enumerable.Contains) when arguments is [var source, var item] && source.Type == typeof(byte[]): + { + var sourceTypeMapping = source.TypeMapping; - var value = arguments[1] is SqlConstantExpression constantValue - ? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping) - : _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping); + var value = item is SqlConstantExpression constantValue + ? sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping) + : sqlExpressionFactory.Convert(item, typeof(byte[]), sourceTypeMapping); - return _sqlExpressionFactory.GreaterThan( - _sqlExpressionFactory.Function( - "CHARINDEX", - [value, source], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[2], - typeof(int)), - _sqlExpressionFactory.Constant(0)); - } + return sqlExpressionFactory.GreaterThan( + sqlExpressionFactory.Function( + "CHARINDEX", + [value, source], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[2], + typeof(int)), + sqlExpressionFactory.Constant(0)); + } - if (methodDefinition.Equals(EnumerableMethods.FirstWithoutPredicate) - && arguments[0].Type == typeof(byte[])) - { - return _sqlExpressionFactory.Convert( - _sqlExpressionFactory.Function( - "SUBSTRING", - [arguments[0], _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - typeof(byte[])), - method.ReturnType); + // First without a predicate + case nameof(Enumerable.First) when arguments is [var source] && source.Type == typeof(byte[]): + return sqlExpressionFactory.Convert( + sqlExpressionFactory.Function( + "SUBSTRING", + [source, sqlExpressionFactory.Constant(1), sqlExpressionFactory.Constant(1)], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[3], + typeof(byte[])), + method.ReturnType); } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerConvertTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerConvertTranslator.cs index 185f5d397df..bf3d3e98596 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerConvertTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerConvertTranslator.cs @@ -12,62 +12,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerConvertTranslator : IMethodCallTranslator +public class SqlServerConvertTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary TypeMapping = new() - { - [nameof(Convert.ToBoolean)] = "bit", - [nameof(Convert.ToByte)] = "tinyint", - [nameof(Convert.ToDecimal)] = "decimal(18, 2)", - [nameof(Convert.ToDouble)] = "float", - [nameof(Convert.ToInt16)] = "smallint", - [nameof(Convert.ToInt32)] = "int", - [nameof(Convert.ToInt64)] = "bigint", - [nameof(Convert.ToString)] = "nvarchar(max)" - }; - - private static readonly List SupportedTypes = - [ - typeof(bool), - typeof(byte), - typeof(DateTime), - typeof(decimal), - typeof(double), - typeof(float), - typeof(int), - typeof(long), - typeof(short), - typeof(string), - typeof(object) - ]; - - private static readonly MethodInfo[] SupportedMethods; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerConvertTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - - static SqlServerConvertTranslator() - { - var convertInfo = typeof(Convert).GetTypeInfo(); - SupportedMethods = TypeMapping.Keys - .SelectMany(name => convertInfo.GetDeclaredMethods(name) - .Where(method => - { - var parameters = method.GetParameters(); - return parameters.Length == 1 - && SupportedTypes.Contains(parameters[0].ParameterType); - })) - .ToArray(); - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -79,12 +25,50 @@ static SqlServerConvertTranslator() MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - => SupportedMethods.Contains(method) - ? _sqlExpressionFactory.Function( - "CONVERT", - [_sqlExpressionFactory.Fragment(TypeMapping[method.Name]), arguments[0]], - nullable: true, - argumentsPropagateNullability: Statics.FalseTrue, - method.ReturnType) - : null; + { + if (method.DeclaringType != typeof(Convert)) + { + return null; + } + + var sqlType = method.Name switch + { + nameof(Convert.ToBoolean) => "bit", + nameof(Convert.ToByte) => "tinyint", + nameof(Convert.ToDecimal) => "decimal(18, 2)", + nameof(Convert.ToDouble) => "float", + nameof(Convert.ToInt16) => "smallint", + nameof(Convert.ToInt32) => "int", + nameof(Convert.ToInt64) => "bigint", + nameof(Convert.ToString) => "nvarchar(max)", + _ => null + }; + + if (sqlType is null + || method.GetParameters() is not [{ ParameterType: var paramType }] + || !IsSupportedType(paramType)) + { + return null; + } + + return sqlExpressionFactory.Function( + "CONVERT", + [sqlExpressionFactory.Fragment(sqlType), arguments[0]], + nullable: true, + argumentsPropagateNullability: Statics.FalseTrue, + method.ReturnType); + } + + private static bool IsSupportedType(Type type) + => type == typeof(bool) + || type == typeof(byte) + || type == typeof(DateTime) + || type == typeof(decimal) + || type == typeof(double) + || type == typeof(float) + || type == typeof(int) + || type == typeof(long) + || type == typeof(short) + || type == typeof(string) + || type == typeof(object); } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDataLengthFunctionTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDataLengthFunctionTranslator.cs index 921a38627a7..6454f7e9d5d 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDataLengthFunctionTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDataLengthFunctionTranslator.cs @@ -12,49 +12,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerDataLengthFunctionTranslator : IMethodCallTranslator +public class SqlServerDataLengthFunctionTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly List LongReturningTypes = - [ - "nvarchar(max)", - "varchar(max)", - "varbinary(max)" - ]; - - private static readonly HashSet MethodInfoDataLengthMapping - = - [ - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(string)])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(bool?)])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(double?)])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(decimal?)])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(DateTime?)])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(TimeSpan?)])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(DateTimeOffset?)])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(byte[])])!, - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DataLength), [typeof(DbFunctions), typeof(Guid?)])! - ]; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerDataLengthFunctionTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -67,34 +26,35 @@ public SqlServerDataLengthFunctionTranslator(ISqlExpressionFactory sqlExpression IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (MethodInfoDataLengthMapping.Contains(method)) + if (method.DeclaringType != typeof(SqlServerDbFunctionsExtensions) + || method.Name != nameof(SqlServerDbFunctionsExtensions.DataLength)) { - var argument = arguments[1]; - if (argument.TypeMapping == null) - { - argument = _sqlExpressionFactory.ApplyDefaultTypeMapping(argument); - } - - if (LongReturningTypes.Contains(argument.TypeMapping!.StoreType)) - { - var result = _sqlExpressionFactory.Function( - "DATALENGTH", - arguments.Skip(1), - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - typeof(long)); + return null; + } - return _sqlExpressionFactory.Convert(result, method.ReturnType.UnwrapNullableType()); - } + var argument = arguments[1]; + if (argument.TypeMapping == null) + { + argument = sqlExpressionFactory.ApplyDefaultTypeMapping(argument); + } - return _sqlExpressionFactory.Function( + if (argument.TypeMapping!.StoreType is "nvarchar(max)" or "varchar(max)" or "varbinary(max)") + { + var result = sqlExpressionFactory.Function( "DATALENGTH", arguments.Skip(1), nullable: true, argumentsPropagateNullability: Statics.TrueArrays[1], - method.ReturnType.UnwrapNullableType()); + typeof(long)); + + return sqlExpressionFactory.Convert(result, method.ReturnType.UnwrapNullableType()); } - return null; + return sqlExpressionFactory.Function( + "DATALENGTH", + arguments.Skip(1), + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + method.ReturnType.UnwrapNullableType()); } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateDiffFunctionsTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateDiffFunctionsTranslator.cs index 6be965994f8..646adce541f 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateDiffFunctionsTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateDiffFunctionsTranslator.cs @@ -13,528 +13,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerDateDiffFunctionsTranslator : IMethodCallTranslator +public class SqlServerDateDiffFunctionsTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly Dictionary _methodInfoDateDiffMapping - = new() - { - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffYear), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "year" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffYear), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "year" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffYear), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "year" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffYear), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "year" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffYear), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "year" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffYear), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "year" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMonth), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "month" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMonth), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "month" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMonth), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "month" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMonth), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "month" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMonth), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "month" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMonth), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "month" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffDay), [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "day" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffDay), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "day" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffDay), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "day" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffDay), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "day" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffDay), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "day" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffDay), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "day" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(TimeSpan), typeof(TimeSpan)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(TimeSpan?), typeof(TimeSpan?)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(TimeOnly), typeof(TimeOnly)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(TimeOnly?), typeof(TimeOnly?)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffHour), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "hour" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(TimeSpan), typeof(TimeSpan)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(TimeSpan?), typeof(TimeSpan?)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(TimeOnly), typeof(TimeOnly)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(TimeOnly?), typeof(TimeOnly?)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMinute), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "minute" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(TimeSpan), typeof(TimeSpan)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(TimeSpan?), typeof(TimeSpan?)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(TimeOnly), typeof(TimeOnly)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(TimeOnly?), typeof(TimeOnly?)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffSecond), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "second" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(TimeSpan), typeof(TimeSpan)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(TimeSpan?), typeof(TimeSpan?)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(TimeOnly), typeof(TimeOnly)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(TimeOnly?), typeof(TimeOnly?)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "millisecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(TimeSpan), typeof(TimeSpan)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(TimeSpan?), typeof(TimeSpan?)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(TimeOnly), typeof(TimeOnly)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(TimeOnly?), typeof(TimeOnly?)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "microsecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(TimeSpan), typeof(TimeSpan)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(TimeSpan?), typeof(TimeSpan?)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(TimeOnly), typeof(TimeOnly)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(TimeOnly?), typeof(TimeOnly?)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "nanosecond" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffWeek), - [typeof(DbFunctions), typeof(DateTime), typeof(DateTime)])!, - "week" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffWeek), - [typeof(DbFunctions), typeof(DateTime?), typeof(DateTime?)])!, - "week" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffWeek), - [typeof(DbFunctions), typeof(DateTimeOffset), typeof(DateTimeOffset)])!, - "week" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffWeek), - [typeof(DbFunctions), typeof(DateTimeOffset?), typeof(DateTimeOffset?)])!, - "week" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffWeek), - [typeof(DbFunctions), typeof(DateOnly), typeof(DateOnly)])!, - "week" - }, - { - typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateDiffWeek), - [typeof(DbFunctions), typeof(DateOnly?), typeof(DateOnly?)])!, - "week" - } - }; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerDateDiffFunctionsTranslator( - ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -547,23 +27,43 @@ public SqlServerDateDiffFunctionsTranslator( IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (_methodInfoDateDiffMapping.TryGetValue(method, out var datePart)) + if (method.DeclaringType != typeof(SqlServerDbFunctionsExtensions)) { - var startDate = arguments[1]; - var endDate = arguments[2]; - var typeMapping = ExpressionExtensions.InferTypeMapping(startDate, endDate); + return null; + } - startDate = _sqlExpressionFactory.ApplyTypeMapping(startDate, typeMapping); - endDate = _sqlExpressionFactory.ApplyTypeMapping(endDate, typeMapping); + var datePart = method.Name switch + { + nameof(SqlServerDbFunctionsExtensions.DateDiffYear) => "year", + nameof(SqlServerDbFunctionsExtensions.DateDiffMonth) => "month", + nameof(SqlServerDbFunctionsExtensions.DateDiffDay) => "day", + nameof(SqlServerDbFunctionsExtensions.DateDiffHour) => "hour", + nameof(SqlServerDbFunctionsExtensions.DateDiffMinute) => "minute", + nameof(SqlServerDbFunctionsExtensions.DateDiffSecond) => "second", + nameof(SqlServerDbFunctionsExtensions.DateDiffMillisecond) => "millisecond", + nameof(SqlServerDbFunctionsExtensions.DateDiffMicrosecond) => "microsecond", + nameof(SqlServerDbFunctionsExtensions.DateDiffNanosecond) => "nanosecond", + nameof(SqlServerDbFunctionsExtensions.DateDiffWeek) => "week", + _ => null + }; - return _sqlExpressionFactory.Function( - "DATEDIFF", - [_sqlExpressionFactory.Fragment(datePart), startDate, endDate], - nullable: true, - argumentsPropagateNullability: [false, true, true], - typeof(int)); + if (datePart is null) + { + return null; } - return null; + var startDate = arguments[1]; + var endDate = arguments[2]; + var typeMapping = ExpressionExtensions.InferTypeMapping(startDate, endDate); + + startDate = sqlExpressionFactory.ApplyTypeMapping(startDate, typeMapping); + endDate = sqlExpressionFactory.ApplyTypeMapping(endDate, typeMapping); + + return sqlExpressionFactory.Function( + "DATEDIFF", + [sqlExpressionFactory.Fragment(datePart), startDate, endDate], + nullable: true, + argumentsPropagateNullability: [false, true, true], + typeof(int)); } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateOnlyMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateOnlyMethodTranslator.cs index cdbba2a8c99..4180bcb80c5 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateOnlyMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateOnlyMethodTranslator.cs @@ -12,29 +12,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerDateOnlyMethodTranslator : IMethodCallTranslator +public class SqlServerDateOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly Dictionary _methodInfoDatePartMapping = new() - { - { typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), [typeof(int)])!, "year" }, - { typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), [typeof(int)])!, "month" }, - { typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), [typeof(int)])!, "day" } - }; - - private static readonly MethodInfo ToDateTimeMethodInfo - = typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.ToDateTime), [typeof(TimeOnly)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerDateOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -47,12 +26,15 @@ public SqlServerDateOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFact IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (instance != null) + if (method.DeclaringType != typeof(DateOnly)) { - if (method == ToDateTimeMethodInfo) - { - var timeOnly = arguments[0]; + return null; + } + if (instance is not null) + { + if (method.Name == nameof(DateOnly.ToDateTime) && arguments is [var timeOnly]) + { // We need to refrain from doing the translation when either the DateOnly or the TimeOnly // are a complex SQL expression (anything other than a column/constant/parameter), to avoid evaluating them multiple // potentially expensive arbitrary expressions multiple times. @@ -62,7 +44,7 @@ public SqlServerDateOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFact return null; } - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "DATETIME2FROMPARTS", [ MapDatePartExpression("year", instance), @@ -72,20 +54,28 @@ public SqlServerDateOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFact MapDatePartExpression("minute", timeOnly), MapDatePartExpression("second", timeOnly), MapDatePartExpression("fraction", timeOnly), - _sqlExpressionFactory.Constant(7, typeof(int)), + sqlExpressionFactory.Constant(7, typeof(int)), ], nullable: true, argumentsPropagateNullability: [true, true, true, true, true, true, true, false], typeof(DateTime)); } - if (_methodInfoDatePartMapping.TryGetValue(method, out var datePart)) + var datePart = method.Name switch + { + nameof(DateOnly.AddYears) => "year", + nameof(DateOnly.AddMonths) => "month", + nameof(DateOnly.AddDays) => "day", + _ => (string?)null + }; + + if (datePart is not null) { - instance = _sqlExpressionFactory.ApplyDefaultTypeMapping(instance); + instance = sqlExpressionFactory.ApplyDefaultTypeMapping(instance); - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "DATEADD", - [_sqlExpressionFactory.Fragment(datePart), _sqlExpressionFactory.Convert(arguments[0], typeof(int)), instance], + [sqlExpressionFactory.Fragment(datePart), sqlExpressionFactory.Convert(arguments[0], typeof(int)), instance], nullable: true, argumentsPropagateNullability: [false, true, true], instance.Type, @@ -93,11 +83,9 @@ public SqlServerDateOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFact } } - if (method.DeclaringType == typeof(DateOnly) - && method.Name == nameof(DateOnly.FromDateTime) - && arguments.Count == 1) + if (method.Name == nameof(DateOnly.FromDateTime) && arguments is [_]) { - return _sqlExpressionFactory.Convert(arguments[0], typeof(DateOnly)); + return sqlExpressionFactory.Convert(arguments[0], typeof(DateOnly)); } return null; @@ -120,26 +108,26 @@ private SqlExpression MapDatePartExpression(string datepart, SqlExpression argum _ => throw new UnreachableException() }; - return _sqlExpressionFactory.Constant(constant, typeof(int)); + return sqlExpressionFactory.Constant(constant, typeof(int)); } if (datepart == "fraction") { - return _sqlExpressionFactory.Divide( - _sqlExpressionFactory.Function( + return sqlExpressionFactory.Divide( + sqlExpressionFactory.Function( "DATEPART", - [_sqlExpressionFactory.Fragment("nanosecond"), argument], + [sqlExpressionFactory.Fragment("nanosecond"), argument], nullable: true, argumentsPropagateNullability: [true, true], typeof(int) ), - _sqlExpressionFactory.Constant(100, typeof(int)) + sqlExpressionFactory.Constant(100, typeof(int)) ); } - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "DATEPART", - [_sqlExpressionFactory.Fragment(datepart), argument], + [sqlExpressionFactory.Fragment(datepart), argument], nullable: true, argumentsPropagateNullability: [true, true], typeof(int)); diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateTimeMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateTimeMethodTranslator.cs index b38e0261479..3c446344ae1 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateTimeMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerDateTimeMethodTranslator.cs @@ -17,38 +17,6 @@ public class SqlServerDateTimeMethodTranslator( IRelationalTypeMappingSource typeMappingSource) : IMethodCallTranslator { - private static readonly Dictionary MethodInfoDatePartMapping = new() - { - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddYears), [typeof(int)])!, "year" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMonths), [typeof(int)])!, "month" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddDays), [typeof(double)])!, "day" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddHours), [typeof(double)])!, "hour" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMinutes), [typeof(double)])!, "minute" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddSeconds), [typeof(double)])!, "second" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMilliseconds), [typeof(double)])!, "millisecond" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddYears), [typeof(int)])!, "year" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMonths), [typeof(int)])!, "month" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddDays), [typeof(double)])!, "day" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddHours), [typeof(double)])!, "hour" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMinutes), [typeof(double)])!, "minute" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddSeconds), [typeof(double)])!, "second" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMilliseconds), [typeof(double)])!, "millisecond" } - }; - - private static readonly Dictionary MethodInfoDateDiffMapping = new() - { - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.ToUnixTimeSeconds), Type.EmptyTypes)!, "second" }, - { typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.ToUnixTimeMilliseconds), Type.EmptyTypes)!, "millisecond" } - }; - - private static readonly MethodInfo AtTimeZoneDateTimeOffsetMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.AtTimeZone), [typeof(DbFunctions), typeof(DateTimeOffset), typeof(string)])!; - - private static readonly MethodInfo AtTimeZoneDateTimeMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.AtTimeZone), [typeof(DbFunctions), typeof(DateTime), typeof(string)])!; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -61,35 +29,83 @@ public class SqlServerDateTimeMethodTranslator( IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (MethodInfoDatePartMapping.TryGetValue(method, out var datePart) - && instance != null) + var declaringType = method.DeclaringType; + + if (declaringType == typeof(DateTime) || declaringType == typeof(DateTimeOffset)) { - // Some Add methods accept a double, and SQL Server DateAdd does not accept number argument outside of int range - if (arguments[0] is SqlConstantExpression { Value: double and (<= int.MinValue or >= int.MaxValue) }) + var datePart = method.Name switch + { + nameof(DateTime.AddYears) => "year", + nameof(DateTime.AddMonths) => "month", + nameof(DateTime.AddDays) => "day", + nameof(DateTime.AddHours) => "hour", + nameof(DateTime.AddMinutes) => "minute", + nameof(DateTime.AddSeconds) => "second", + nameof(DateTime.AddMilliseconds) => "millisecond", + _ => (string?)null + }; + + if (datePart is not null && instance is not null) { - return null; + // Some Add methods accept a double, and SQL Server DateAdd does not accept number argument outside of int range + if (arguments[0] is SqlConstantExpression { Value: double and (<= int.MinValue or >= int.MaxValue) }) + { + return null; + } + + // DATEADD defaults to interpreting its last argument as datetime, not datetime2. + // Our default mapping for DateTime is datetime2, so we force constants to be datetime instead here. + if (instance is SqlConstantExpression instanceConstant) + { + instance = instanceConstant.ApplyTypeMapping(typeMappingSource.FindMapping(typeof(DateTime), "datetime")); + } + + return sqlExpressionFactory.Function( + "DATEADD", + arguments: + [ + sqlExpressionFactory.Fragment(datePart), + sqlExpressionFactory.Convert(arguments[0], typeof(int)), + instance + ], + nullable: true, + argumentsPropagateNullability: [false, true, true], + instance.Type, + instance.TypeMapping); } - // DATEADD defaults to interpreting its last argument as datetime, not datetime2. - // Our default mapping for DateTime is datetime2, so we force constants to be datetime instead here. - if (instance is SqlConstantExpression instanceConstant) + if (declaringType == typeof(DateTimeOffset)) { - instance = instanceConstant.ApplyTypeMapping(typeMappingSource.FindMapping(typeof(DateTime), "datetime")); + var timePart = method.Name switch + { + nameof(DateTimeOffset.ToUnixTimeSeconds) => "second", + nameof(DateTimeOffset.ToUnixTimeMilliseconds) => "millisecond", + _ => null + }; + + if (timePart is not null) + { + return sqlExpressionFactory.Function( + "DATEDIFF_BIG", + arguments: + [ + sqlExpressionFactory.Fragment(timePart), + sqlExpressionFactory.Constant(DateTimeOffset.UnixEpoch, instance!.TypeMapping), + instance + ], + nullable: true, + argumentsPropagateNullability: [false, true, true], + typeof(long)); + } } - return sqlExpressionFactory.Function( - "DATEADD", - arguments: [sqlExpressionFactory.Fragment(datePart), sqlExpressionFactory.Convert(arguments[0], typeof(int)), instance], - nullable: true, - argumentsPropagateNullability: [false, true, true], - instance.Type, - instance.TypeMapping); + return null; } - if (method == AtTimeZoneDateTimeOffsetMethodInfo || method == AtTimeZoneDateTimeMethodInfo) + if (declaringType == typeof(SqlServerDbFunctionsExtensions) + && method.Name == nameof(SqlServerDbFunctionsExtensions.AtTimeZone) + && arguments is [_, var operand, var timeZone]) { - var (operand, timeZone) = (arguments[1], arguments[2]); - RelationalTypeMapping? resultTypeMapping = null; // The AT TIME ZONE construct bubbles up the precision of its operand, so when invoked over datetime2(2) it returns a @@ -126,21 +142,6 @@ public class SqlServerDateTimeMethodTranslator( resultTypeMapping); } - if (MethodInfoDateDiffMapping.TryGetValue(method, out var timePart)) - { - return sqlExpressionFactory.Function( - "DATEDIFF_BIG", - arguments: - [ - sqlExpressionFactory.Fragment(timePart), - sqlExpressionFactory.Constant(DateTimeOffset.UnixEpoch, instance!.TypeMapping), - instance - ], - nullable: true, - argumentsPropagateNullability: [false, true, true], - typeof(long)); - } - return null; } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFromPartsFunctionTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFromPartsFunctionTranslator.cs index 61e3ee436af..81539356973 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFromPartsFunctionTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFromPartsFunctionTranslator.cs @@ -12,87 +12,13 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerFromPartsFunctionTranslator : IMethodCallTranslator +public class SqlServerFromPartsFunctionTranslator( + ISqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource) + : IMethodCallTranslator { - private static readonly MethodInfo DateFromPartsMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateFromParts), - [typeof(DbFunctions), typeof(int), typeof(int), typeof(int)])!; - - private static readonly MethodInfo DateTimeFromPartsMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateTimeFromParts), - [typeof(DbFunctions), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int)])!; - - private static readonly MethodInfo DateTime2FromPartsMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateTime2FromParts), - [ - typeof(DbFunctions), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int) - ])!; - - private static readonly MethodInfo DateTimeOffsetFromPartsMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.DateTimeOffsetFromParts), - [ - typeof(DbFunctions), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int), - typeof(int) - ])!; - - private static readonly MethodInfo SmallDateTimeFromPartsMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.SmallDateTimeFromParts), - [typeof(DbFunctions), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int)])!; - - private static readonly MethodInfo TimeFromPartsMethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.TimeFromParts), - [typeof(DbFunctions), typeof(int), typeof(int), typeof(int), typeof(int), typeof(int)])!; - - private static readonly IDictionary MethodFunctionMapping - = new Dictionary - { - { DateFromPartsMethodInfo, ("DATEFROMPARTS", "date") }, - { DateTimeFromPartsMethodInfo, ("DATETIMEFROMPARTS", "datetime") }, - { DateTime2FromPartsMethodInfo, ("DATETIME2FROMPARTS", "datetime2") }, - { DateTimeOffsetFromPartsMethodInfo, ("DATETIMEOFFSETFROMPARTS", "datetimeoffset") }, - { SmallDateTimeFromPartsMethodInfo, ("SMALLDATETIMEFROMPARTS", "smalldatetime") }, - { TimeFromPartsMethodInfo, ("TIMEFROMPARTS", "time") } - }; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - private readonly IRelationalTypeMappingSource _typeMappingSource; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerFromPartsFunctionTranslator( - ISqlExpressionFactory sqlExpressionFactory, - IRelationalTypeMappingSource typeMappingSource) - { - _sqlExpressionFactory = sqlExpressionFactory; - _typeMappingSource = typeMappingSource; - } + private readonly ISqlExpressionFactory _sqlExpressionFactory = sqlExpressionFactory; + private readonly IRelationalTypeMappingSource _typeMappingSource = typeMappingSource; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -106,17 +32,33 @@ public SqlServerFromPartsFunctionTranslator( IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (MethodFunctionMapping.TryGetValue(method, out var value)) + if (method.DeclaringType != typeof(SqlServerDbFunctionsExtensions)) + { + return null; + } + + var (functionName, returnType) = method.Name switch + { + nameof(SqlServerDbFunctionsExtensions.DateFromParts) => ("DATEFROMPARTS", "date"), + nameof(SqlServerDbFunctionsExtensions.DateTimeFromParts) => ("DATETIMEFROMPARTS", "datetime"), + nameof(SqlServerDbFunctionsExtensions.DateTime2FromParts) => ("DATETIME2FROMPARTS", "datetime2"), + nameof(SqlServerDbFunctionsExtensions.DateTimeOffsetFromParts) => ("DATETIMEOFFSETFROMPARTS", "datetimeoffset"), + nameof(SqlServerDbFunctionsExtensions.SmallDateTimeFromParts) => ("SMALLDATETIMEFROMPARTS", "smalldatetime"), + nameof(SqlServerDbFunctionsExtensions.TimeFromParts) => ("TIMEFROMPARTS", "time"), + _ => (null, null) + }; + + if (functionName is null) { - return _sqlExpressionFactory.Function( - value.FunctionName, - arguments.Skip(1), - nullable: true, - argumentsPropagateNullability: arguments.Skip(1).Select(_ => true), - method.ReturnType, - _typeMappingSource.FindMapping(method.ReturnType, value.ReturnType)); + return null; } - return null; + return _sqlExpressionFactory.Function( + functionName, + arguments.Skip(1), + nullable: true, + argumentsPropagateNullability: arguments.Skip(1).Select(_ => true), + method.ReturnType, + _typeMappingSource.FindMapping(method.ReturnType, returnType)); } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFullTextSearchFunctionsTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFullTextSearchFunctionsTranslator.cs index f8bf8db4c7b..1789b57fa9c 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFullTextSearchFunctionsTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerFullTextSearchFunctionsTranslator.cs @@ -13,49 +13,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerFullTextSearchFunctionsTranslator : IMethodCallTranslator +public class SqlServerFullTextSearchFunctionsTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private const string FreeTextFunctionName = "FREETEXT"; - private const string ContainsFunctionName = "CONTAINS"; - - private static readonly MethodInfo FreeTextMethodInfo - = typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.FreeText), [typeof(DbFunctions), typeof(object), typeof(string)])!; - - private static readonly MethodInfo FreeTextMethodInfoWithLanguage - = typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.FreeText), - [typeof(DbFunctions), typeof(object), typeof(string), typeof(int)])!; - - private static readonly MethodInfo ContainsMethodInfo - = typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.Contains), [typeof(DbFunctions), typeof(object), typeof(string)])!; - - private static readonly MethodInfo ContainsMethodInfoWithLanguage - = typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.Contains), - [typeof(DbFunctions), typeof(object), typeof(string), typeof(int)])!; - - private static readonly IDictionary FunctionMapping - = new Dictionary - { - { FreeTextMethodInfo, FreeTextFunctionName }, - { FreeTextMethodInfoWithLanguage, FreeTextFunctionName }, - { ContainsMethodInfo, ContainsFunctionName }, - { ContainsMethodInfoWithLanguage, ContainsFunctionName } - }; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerFullTextSearchFunctionsTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -68,33 +27,46 @@ public SqlServerFullTextSearchFunctionsTranslator(ISqlExpressionFactory sqlExpre IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (FunctionMapping.TryGetValue(method, out var functionName)) + if (method.DeclaringType != typeof(SqlServerDbFunctionsExtensions)) { - var propertyReference = arguments[1]; - if (propertyReference is not ColumnExpression) - { - throw new InvalidOperationException(SqlServerStrings.InvalidColumnNameForFreeText); - } + return null; + } - var freeText = _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[2]); + // Note: the table-valued FREETEXTTABLE and CONTAINSTABLE functions are handled in SqlServerQueryableMethodTranslatingExpressionVisitor + var functionName = method.Name switch + { + nameof(SqlServerDbFunctionsExtensions.FreeText) => "FREETEXT", + nameof(SqlServerDbFunctionsExtensions.Contains) => "CONTAINS", + _ => null + }; - var functionArguments = new List { propertyReference, freeText }; + if (functionName is null) + { + return null; + } + + var propertyReference = arguments[1]; + if (propertyReference is not ColumnExpression) + { + throw new InvalidOperationException(SqlServerStrings.InvalidColumnNameForFreeText); + } + + var freeText = sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[2]); - if (arguments.Count == 4) - { - functionArguments.Add( - _sqlExpressionFactory.Fragment($"LANGUAGE {((SqlConstantExpression)arguments[3]).Value}")); - } + var functionArguments = new List { propertyReference, freeText }; - return _sqlExpressionFactory.Function( - functionName, - functionArguments, - nullable: true, - // TODO: don't propagate for now - argumentsPropagateNullability: functionArguments.Select(_ => false).ToList(), - typeof(bool)); + if (arguments.Count == 4) + { + functionArguments.Add( + sqlExpressionFactory.Fragment($"LANGUAGE {((SqlConstantExpression)arguments[3]).Value}")); } - return null; + return sqlExpressionFactory.Function( + functionName, + functionArguments, + nullable: true, + // TODO: don't propagate for now + argumentsPropagateNullability: functionArguments.Select(_ => false).ToList(), + typeof(bool)); } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsDateFunctionTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsDateFunctionTranslator.cs index 8e7d545daf6..a3842779cef 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsDateFunctionTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsDateFunctionTranslator.cs @@ -12,22 +12,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerIsDateFunctionTranslator : IMethodCallTranslator +public class SqlServerIsDateFunctionTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - private static readonly MethodInfo MethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod(nameof(SqlServerDbFunctionsExtensions.IsDate), [typeof(DbFunctions), typeof(string)])!; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerIsDateFunctionTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -39,14 +25,15 @@ public SqlServerIsDateFunctionTranslator(ISqlExpressionFactory sqlExpressionFact MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - => MethodInfo.Equals(method) - ? _sqlExpressionFactory.Convert( - _sqlExpressionFactory.Function( - "ISDATE", - [arguments[1]], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - MethodInfo.ReturnType), - MethodInfo.ReturnType) - : null; + => method.DeclaringType == typeof(SqlServerDbFunctionsExtensions) + && method.Name == nameof(SqlServerDbFunctionsExtensions.IsDate) + ? sqlExpressionFactory.Convert( + sqlExpressionFactory.Function( + "ISDATE", + [arguments[1]], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + method.ReturnType), + method.ReturnType) + : null; } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsNumericFunctionTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsNumericFunctionTranslator.cs index f9a24c9f7e3..fbd18e5da85 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsNumericFunctionTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerIsNumericFunctionTranslator.cs @@ -12,22 +12,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerIsNumericFunctionTranslator : IMethodCallTranslator +public class SqlServerIsNumericFunctionTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - private static readonly MethodInfo MethodInfo = typeof(SqlServerDbFunctionsExtensions) - .GetRuntimeMethod(nameof(SqlServerDbFunctionsExtensions.IsNumeric), [typeof(DbFunctions), typeof(string)])!; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerIsNumericFunctionTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -39,14 +25,15 @@ public SqlServerIsNumericFunctionTranslator(ISqlExpressionFactory sqlExpressionF MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - => MethodInfo.Equals(method) - ? _sqlExpressionFactory.Equal( - _sqlExpressionFactory.Function( - "ISNUMERIC", - [arguments[1]], - nullable: false, - argumentsPropagateNullability: Statics.FalseArrays[1], - typeof(int)), - _sqlExpressionFactory.Constant(1)) - : null; + => method.DeclaringType == typeof(SqlServerDbFunctionsExtensions) + && method.Name == nameof(SqlServerDbFunctionsExtensions.IsNumeric) + ? sqlExpressionFactory.Equal( + sqlExpressionFactory.Function( + "ISNUMERIC", + [arguments[1]], + nullable: false, + argumentsPropagateNullability: Statics.FalseArrays[1], + typeof(int)), + sqlExpressionFactory.Constant(1)) + : null; } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs index de4242b2a29..06f9e9b9201 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs @@ -13,92 +13,10 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerMathTranslator : IMethodCallTranslator +public class SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary SupportedMethodTranslations = new() - { - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(decimal)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(double)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(float)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(int)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(long)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(sbyte)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Abs), [typeof(short)])!, "ABS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(decimal)])!, "CEILING" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), [typeof(double)])!, "CEILING" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(decimal)])!, "FLOOR" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Floor), [typeof(double)])!, "FLOOR" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Pow), [typeof(double), typeof(double)])!, "POWER" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Exp), [typeof(double)])!, "EXP" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log10), [typeof(double)])!, "LOG10" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double)])!, "LOG" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Log), [typeof(double), typeof(double)])!, "LOG" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), [typeof(double)])!, "SQRT" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Acos), [typeof(double)])!, "ACOS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Asin), [typeof(double)])!, "ASIN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Atan), [typeof(double)])!, "ATAN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), [typeof(double), typeof(double)])!, "ATN2" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Cos), [typeof(double)])!, "COS" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sin), [typeof(double)])!, "SIN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Tan), [typeof(double)])!, "TAN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(decimal)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(double)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(float)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(int)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(long)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(sbyte)])!, "SIGN" }, - { typeof(Math).GetRuntimeMethod(nameof(Math.Sign), [typeof(short)])!, "SIGN" }, - { typeof(double).GetRuntimeMethod(nameof(double.DegreesToRadians), [typeof(double)])!, "RADIANS" }, - { typeof(double).GetRuntimeMethod(nameof(double.RadiansToDegrees), [typeof(double)])!, "DEGREES" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), [typeof(float)])!, "CEILING" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), [typeof(float)])!, "FLOOR" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), [typeof(float), typeof(float)])!, "POWER" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), [typeof(float)])!, "EXP" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), [typeof(float)])!, "LOG10" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float)])!, "LOG" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), [typeof(float), typeof(float)])!, "LOG" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), [typeof(float)])!, "SQRT" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), [typeof(float)])!, "ACOS" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), [typeof(float)])!, "ASIN" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), [typeof(float)])!, "ATAN" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), [typeof(float), typeof(float)])!, "ATN2" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), [typeof(float)])!, "COS" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), [typeof(float)])!, "SIN" }, - { typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), [typeof(float)])!, "TAN" }, - { typeof(float).GetRuntimeMethod(nameof(float.DegreesToRadians), [typeof(float)])!, "RADIANS" }, - { typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "DEGREES" } - }; - // Note: Math.Max/Min are handled in RelationalSqlTranslatingExpressionVisitor - private static readonly IEnumerable TruncateMethodInfos = - [ - typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(decimal)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(double)])!, - typeof(MathF).GetRuntimeMethod(nameof(MathF.Truncate), [typeof(float)])! - ]; - - private static readonly IEnumerable RoundMethodInfos = - [ - typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(decimal)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(double)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(decimal), typeof(int)])!, - typeof(Math).GetRuntimeMethod(nameof(Math.Round), [typeof(double), typeof(int)])!, - typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), [typeof(float)])!, - typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), [typeof(float), typeof(int)])! - ]; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -111,33 +29,122 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (SupportedMethodTranslations.TryGetValue(method, out var sqlFunctionName)) - { - var typeMapping = arguments.Count == 1 - ? ExpressionExtensions.InferTypeMapping(arguments[0]) - : ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]); + var declaringType = method.DeclaringType; - var newArguments = new SqlExpression[arguments.Count]; - newArguments[0] = _sqlExpressionFactory.ApplyTypeMapping(arguments[0], typeMapping); + if (declaringType != typeof(Math) + && declaringType != typeof(MathF) + && declaringType != typeof(double) + && declaringType != typeof(float)) + { + return null; + } - if (arguments.Count == 2) - { - newArguments[1] = _sqlExpressionFactory.ApplyTypeMapping(arguments[1], typeMapping); - } + return method.Name switch + { + nameof(Math.Abs) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float) + || arg.Type == typeof(int) || arg.Type == typeof(long) || arg.Type == typeof(sbyte) || arg.Type == typeof(short)) + => TranslateFunction("ABS", arg), + nameof(Math.Ceiling) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("CEILING", arg), + nameof(Math.Floor) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("FLOOR", arg), + nameof(Math.Exp) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("EXP", arg), + nameof(Math.Log10) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("LOG10", arg), + nameof(Math.Log) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("LOG", arg), + nameof(Math.Log) when arguments is [var arg1, var arg2] + && (arg1.Type == typeof(double) || arg1.Type == typeof(float)) + => TranslateBinaryFunction("LOG", arg1, arg2), + nameof(Math.Sqrt) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("SQRT", arg), + nameof(Math.Acos) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("ACOS", arg), + nameof(Math.Asin) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("ASIN", arg), + nameof(Math.Atan) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("ATAN", arg), + nameof(Math.Atan2) when arguments is [var arg1, var arg2] + && (arg1.Type == typeof(double) || arg1.Type == typeof(float)) + => TranslateBinaryFunction("ATN2", arg1, arg2), + nameof(Math.Cos) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("COS", arg), + nameof(Math.Sin) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("SIN", arg), + nameof(Math.Tan) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("TAN", arg), + nameof(Math.Pow) when arguments is [var arg1, var arg2] + && (arg1.Type == typeof(double) || arg1.Type == typeof(float)) + => TranslateBinaryFunction("POWER", arg1, arg2), + nameof(Math.Sign) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float) + || arg.Type == typeof(int) || arg.Type == typeof(long) || arg.Type == typeof(sbyte) || arg.Type == typeof(short)) + => TranslateFunction("SIGN", arg, nullTypeMapping: true), + nameof(double.DegreesToRadians) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("RADIANS", arg), + nameof(double.RadiansToDegrees) when arguments is [var arg] + && (arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateFunction("DEGREES", arg), + + nameof(Math.Truncate) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateTruncate(arg), + nameof(Math.Round) when arguments is [var arg] + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateRound(arg, digits: null), + nameof(Math.Round) when arguments is [var arg, var digits] + && digits.Type == typeof(int) + && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float)) + => TranslateRound(arg, digits), + + _ => null + }; + + SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg, bool nullTypeMapping = false) + { + var typeMapping = ExpressionExtensions.InferTypeMapping(arg); + return sqlExpressionFactory.Function( + sqlFunctionName, + [sqlExpressionFactory.ApplyTypeMapping(arg, typeMapping)], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + method.ReturnType, + nullTypeMapping ? null : typeMapping); + } - return _sqlExpressionFactory.Function( + SqlExpression TranslateBinaryFunction(string sqlFunctionName, SqlExpression arg1, SqlExpression arg2) + { + var typeMapping = ExpressionExtensions.InferTypeMapping(arg1, arg2); + return sqlExpressionFactory.Function( sqlFunctionName, - newArguments, + [ + sqlExpressionFactory.ApplyTypeMapping(arg1, typeMapping), + sqlExpressionFactory.ApplyTypeMapping(arg2, typeMapping) + ], nullable: true, - argumentsPropagateNullability: newArguments.Select(_ => true).ToArray(), + argumentsPropagateNullability: Statics.TrueArrays[2], method.ReturnType, - sqlFunctionName == "SIGN" ? null : typeMapping); + typeMapping); } - if (TruncateMethodInfos.Contains(method)) + SqlExpression TranslateTruncate(SqlExpression argument) { - var argument = arguments[0]; - // C# has Round over decimal/double/float only so our argument will be one of those types (compiler puts convert node) + // C# has Truncate over decimal/double/float only so our argument will be one of those types (compiler puts convert node) // In database result will be same type except for float which returns double which we need to cast back to float. var resultType = argument.Type; if (resultType == typeof(float)) @@ -145,25 +152,25 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) resultType = typeof(double); } - var result = _sqlExpressionFactory.Function( + var result = sqlExpressionFactory.Function( "ROUND", - [argument, _sqlExpressionFactory.Constant(0), _sqlExpressionFactory.Constant(1)], + [argument, sqlExpressionFactory.Constant(0), sqlExpressionFactory.Constant(1)], nullable: true, argumentsPropagateNullability: [true, false, false], resultType); if (argument.Type == typeof(float)) { - result = _sqlExpressionFactory.Convert(result, typeof(float)); + result = sqlExpressionFactory.Convert(result, typeof(float)); } - return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); + return sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); } - if (RoundMethodInfos.Contains(method)) + SqlExpression TranslateRound(SqlExpression argument, SqlExpression? digits) { - var argument = arguments[0]; - var digits = arguments.Count == 2 ? arguments[1] : _sqlExpressionFactory.Constant(0); + digits ??= sqlExpressionFactory.Constant(0); + // C# has Round over decimal/double/float only so our argument will be one of those types (compiler puts convert node) // In database result will be same type except for float which returns double which we need to cast back to float. var resultType = argument.Type; @@ -172,7 +179,7 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) resultType = typeof(double); } - var result = _sqlExpressionFactory.Function( + var result = sqlExpressionFactory.Function( "ROUND", [argument, digits], nullable: true, @@ -181,12 +188,10 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) if (argument.Type == typeof(float)) { - result = _sqlExpressionFactory.Convert(result, typeof(float)); + result = sqlExpressionFactory.Convert(result, typeof(float)); } - return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); + return sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); } - - return null; } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerNewGuidTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerNewGuidTranslator.cs index 95b412800bc..59e6923854c 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerNewGuidTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerNewGuidTranslator.cs @@ -12,20 +12,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerNewGuidTranslator : IMethodCallTranslator +public class SqlServerNewGuidTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo MethodInfo = typeof(Guid).GetRuntimeMethod(nameof(Guid.NewGuid), Type.EmptyTypes)!; - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerNewGuidTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -37,8 +25,8 @@ public SqlServerNewGuidTranslator(ISqlExpressionFactory sqlExpressionFactory) MethodInfo method, IReadOnlyList arguments, IDiagnosticsLogger logger) - => MethodInfo.Equals(method) - ? _sqlExpressionFactory.Function( + => method.DeclaringType == typeof(Guid) && method.Name == nameof(Guid.NewGuid) + ? sqlExpressionFactory.Function( "NEWID", [], nullable: false, diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringAggregateMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringAggregateMethodTranslator.cs index 8e7b01eae87..90d26c70d72 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringAggregateMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringAggregateMethodTranslator.cs @@ -12,31 +12,10 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerStringAggregateMethodTranslator : IAggregateMethodCallTranslator +public class SqlServerStringAggregateMethodTranslator( + ISqlExpressionFactory sqlExpressionFactory, + IRelationalTypeMappingSource typeMappingSource) : IAggregateMethodCallTranslator { - private static readonly MethodInfo StringConcatMethod - = typeof(string).GetRuntimeMethod(nameof(string.Concat), [typeof(IEnumerable)])!; - - private static readonly MethodInfo StringJoinMethod - = typeof(string).GetRuntimeMethod(nameof(string.Join), [typeof(string), typeof(IEnumerable)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - private readonly IRelationalTypeMappingSource _typeMappingSource; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerStringAggregateMethodTranslator( - ISqlExpressionFactory sqlExpressionFactory, - IRelationalTypeMappingSource typeMappingSource) - { - _sqlExpressionFactory = sqlExpressionFactory; - _typeMappingSource = typeMappingSource; - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -52,11 +31,24 @@ public SqlServerStringAggregateMethodTranslator( // Docs: https://docs.microsoft.com/sql/t-sql/functions/string-agg-transact-sql if (source.Selector is not SqlExpression sqlExpression - || (method != StringJoinMethod && method != StringConcatMethod)) + || method.DeclaringType != typeof(string)) { return null; } + SqlExpression separator; + switch (method.Name) + { + case nameof(string.Concat) when arguments is []: + separator = sqlExpressionFactory.Constant(string.Empty, typeof(string)); + break; + case nameof(string.Join) when arguments is [var sep]: + separator = sep; + break; + default: + return null; + } + // STRING_AGG enlarges the return type size (e.g. for input VARCHAR(5), it returns VARCHAR(8000)). // See https://docs.microsoft.com/sql/t-sql/functions/string-agg-transact-sql#return-types var resultTypeMapping = sqlExpression.TypeMapping; @@ -64,7 +56,7 @@ public SqlServerStringAggregateMethodTranslator( { if (resultTypeMapping is { IsUnicode: true, Size: < 4000 }) { - resultTypeMapping = _typeMappingSource.FindMapping( + resultTypeMapping = typeMappingSource.FindMapping( typeof(string), resultTypeMapping.StoreTypeNameBase, unicode: true, @@ -72,7 +64,7 @@ public SqlServerStringAggregateMethodTranslator( } else if (resultTypeMapping is { IsUnicode: false, Size: < 8000 }) { - resultTypeMapping = _typeMappingSource.FindMapping( + resultTypeMapping = typeMappingSource.FindMapping( typeof(string), resultTypeMapping.StoreTypeNameBase, unicode: false, @@ -81,28 +73,26 @@ public SqlServerStringAggregateMethodTranslator( } // STRING_AGG filters out nulls, but string.Join treats them as empty strings. - sqlExpression = _sqlExpressionFactory.Coalesce( + sqlExpression = sqlExpressionFactory.Coalesce( sqlExpression, - _sqlExpressionFactory.Constant(string.Empty, typeof(string))); + sqlExpressionFactory.Constant(string.Empty, typeof(string))); // STRING_AGG returns null when there are no rows (or non-null values), but string.Join returns an empty string. return - _sqlExpressionFactory.Coalesce( + sqlExpressionFactory.Coalesce( SqlServerExpression.AggregateFunctionWithOrdering( - _sqlExpressionFactory, + sqlExpressionFactory, "STRING_AGG", [ sqlExpression, - _sqlExpressionFactory.ApplyTypeMapping( - method == StringJoinMethod ? arguments[0] : _sqlExpressionFactory.Constant(string.Empty, typeof(string)), - sqlExpression.TypeMapping) + sqlExpressionFactory.ApplyTypeMapping(separator, sqlExpression.TypeMapping) ], source, enumerableArgumentIndex: 0, nullable: true, argumentsPropagateNullability: Statics.FalseArrays[2], typeof(string)), - _sqlExpressionFactory.Constant(string.Empty, typeof(string)), + sqlExpressionFactory.Constant(string.Empty, typeof(string)), resultTypeMapping); } } diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringMethodTranslator.cs index 9fbd664bcc4..fbc82a9f5bd 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerStringMethodTranslator.cs @@ -15,100 +15,11 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerStringMethodTranslator : IMethodCallTranslator +public class SqlServerStringMethodTranslator( + ISqlExpressionFactory sqlExpressionFactory, + ISqlServerSingletonOptions sqlServerSingletonOptions) + : IMethodCallTranslator { - private static readonly MethodInfo IndexOfMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string)])!; - - private static readonly MethodInfo IndexOfMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char)])!; - - private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionString - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string), typeof(int)])!; - - private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionChar - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char), typeof(int)])!; - - private static readonly MethodInfo ReplaceMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(string), typeof(string)])!; - - private static readonly MethodInfo ReplaceMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(char), typeof(char)])!; - - private static readonly MethodInfo ToLowerMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.ToLower), Type.EmptyTypes)!; - - private static readonly MethodInfo ToUpperMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.ToUpper), Type.EmptyTypes)!; - - private static readonly MethodInfo SubstringMethodInfoWithOneArg - = typeof(string).GetRuntimeMethod(nameof(string.Substring), [typeof(int)])!; - - private static readonly MethodInfo SubstringMethodInfoWithTwoArgs - = typeof(string).GetRuntimeMethod(nameof(string.Substring), [typeof(int), typeof(int)])!; - - private static readonly MethodInfo IsNullOrEmptyMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.IsNullOrEmpty), [typeof(string)])!; - - private static readonly MethodInfo IsNullOrWhiteSpaceMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.IsNullOrWhiteSpace), [typeof(string)])!; - - // Method defined in netcoreapp2.0 only - private static readonly MethodInfo TrimStartMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), Type.EmptyTypes)!; - - private static readonly MethodInfo TrimEndMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), Type.EmptyTypes)!; - - private static readonly MethodInfo TrimMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.Trim), Type.EmptyTypes)!; - - // Method defined in netstandard2.0 - private static readonly MethodInfo TrimStartMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [typeof(char[])])!; - - private static readonly MethodInfo TrimEndMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), [typeof(char[])])!; - - private static readonly MethodInfo TrimMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.Trim), [typeof(char[])])!; - - private static readonly MethodInfo TrimStartMethodInfoWithCharArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [typeof(char)])!; - - private static readonly MethodInfo TrimEndMethodInfoWithCharArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), [typeof(char)])!; - - private static readonly MethodInfo FirstOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single(m => m.Name == nameof(Enumerable.FirstOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - private static readonly MethodInfo LastOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single(m => m.Name == nameof(Enumerable.LastOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - private static readonly MethodInfo PatIndexMethodInfo - = typeof(SqlServerDbFunctionsExtensions).GetRuntimeMethod( - nameof(SqlServerDbFunctionsExtensions.PatIndex), - [typeof(DbFunctions), typeof(string), typeof(string)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - private readonly ISqlServerSingletonOptions _sqlServerSingletonOptions; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory, ISqlServerSingletonOptions sqlServerSingletonOptions) - { - _sqlExpressionFactory = sqlExpressionFactory; - - _sqlServerSingletonOptions = sqlServerSingletonOptions; - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -121,207 +32,188 @@ public SqlServerStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactor IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (instance != null) + if (method.DeclaringType == typeof(string)) { - if (IndexOfMethodInfoString.Equals(method) || IndexOfMethodInfoChar.Equals(method)) - { - return TranslateIndexOf(instance, method, arguments[0], null); - } - - if (IndexOfMethodInfoWithStartingPositionString.Equals(method) || IndexOfMethodInfoWithStartingPositionChar.Equals(method)) - { - return TranslateIndexOf(instance, method, arguments[0], arguments[1]); - } - - if (ReplaceMethodInfoString.Equals(method) || ReplaceMethodInfoChar.Equals(method)) + if (instance is not null) { - var firstArgument = arguments[0]; - var secondArgument = arguments[1]; - var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, firstArgument, secondArgument); - - instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping); - firstArgument = _sqlExpressionFactory.ApplyTypeMapping( - firstArgument, firstArgument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); - secondArgument = _sqlExpressionFactory.ApplyTypeMapping( - secondArgument, secondArgument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); - - return _sqlExpressionFactory.Function( - "REPLACE", - [instance, firstArgument, secondArgument], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType, - stringTypeMapping); - } - - if (ToLowerMethodInfo.Equals(method) - || ToUpperMethodInfo.Equals(method)) - { - return _sqlExpressionFactory.Function( - ToLowerMethodInfo.Equals(method) ? "LOWER" : "UPPER", - [instance], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - method.ReturnType, - instance.TypeMapping); - } - - if (SubstringMethodInfoWithOneArg.Equals(method)) - { - return _sqlExpressionFactory.Function( - "SUBSTRING", - [ - instance, - _sqlExpressionFactory.Add( - arguments[0], - _sqlExpressionFactory.Constant(1)), - _sqlExpressionFactory.Function( - "LEN", - [instance], - nullable: true, + return method.Name switch + { + nameof(string.IndexOf) when arguments is [var search] + => TranslateIndexOf(instance, method, search, null), + nameof(string.IndexOf) when arguments is [var search, var startIndex] && startIndex.Type == typeof(int) + => TranslateIndexOf(instance, method, search, startIndex), + + nameof(string.Replace) when arguments is [var oldValue, var newValue] + => TranslateReplace(instance, method, oldValue, newValue), + + nameof(string.ToLower) when arguments is [] + => sqlExpressionFactory.Function( + "LOWER", [instance], nullable: true, argumentsPropagateNullability: Statics.TrueArrays[1], - typeof(int)) - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType, - instance.TypeMapping); - } - - if (SubstringMethodInfoWithTwoArgs.Equals(method)) - { - return _sqlExpressionFactory.Function( - "SUBSTRING", - [ - instance, - _sqlExpressionFactory.Add( - arguments[0], - _sqlExpressionFactory.Constant(1)), - arguments[1] - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType, - instance.TypeMapping); - } - - // There's single-parameter LTRIM/RTRIM for all versions (trims whitespace), but startin with SQL Server 2022 there's also - // an overload that accepts the characters to trim. - if (method == TrimStartMethodInfoWithoutArgs - || (method == TrimStartMethodInfoWithCharArrayArg && arguments[0] is SqlConstantExpression { Value: char[] { Length: 0 } }) - || (((_sqlServerSingletonOptions.EngineType == SqlServerEngineType.SqlServer - && _sqlServerSingletonOptions.SqlServerCompatibilityLevel >= 160) - || (_sqlServerSingletonOptions.EngineType == SqlServerEngineType.AzureSql - && _sqlServerSingletonOptions.AzureSqlCompatibilityLevel >= 160) - || (_sqlServerSingletonOptions.EngineType == SqlServerEngineType.AzureSynapse)) - && (method == TrimStartMethodInfoWithCharArg || method == TrimStartMethodInfoWithCharArrayArg))) - { - return ProcessTrimStartEnd(instance, arguments, "LTRIM"); - } - - if (method == TrimEndMethodInfoWithoutArgs - || (method == TrimEndMethodInfoWithCharArrayArg && arguments[0] is SqlConstantExpression { Value: char[] { Length: 0 } }) - || (((_sqlServerSingletonOptions.EngineType == SqlServerEngineType.SqlServer - && _sqlServerSingletonOptions.SqlServerCompatibilityLevel >= 160) - || (_sqlServerSingletonOptions.EngineType == SqlServerEngineType.AzureSql - && _sqlServerSingletonOptions.AzureSqlCompatibilityLevel >= 160) - || (_sqlServerSingletonOptions.EngineType == SqlServerEngineType.AzureSynapse)) - && (method == TrimEndMethodInfoWithCharArg || method == TrimEndMethodInfoWithCharArrayArg))) - { - return ProcessTrimStartEnd(instance, arguments, "RTRIM"); - } - - if (method == TrimMethodInfoWithoutArgs - || (method == TrimMethodInfoWithCharArrayArg && arguments[0] is SqlConstantExpression { Value: char[] { Length: 0 } })) - { - return _sqlExpressionFactory.Function( - "LTRIM", - [ - _sqlExpressionFactory.Function( - "RTRIM", - [instance], + method.ReturnType, instance.TypeMapping), + nameof(string.ToUpper) when arguments is [] + => sqlExpressionFactory.Function( + "UPPER", [instance], nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + method.ReturnType, instance.TypeMapping), + + nameof(string.Substring) when arguments is [var startIndex] + => TranslateSubstring(instance, method, startIndex, length: null), + nameof(string.Substring) when arguments is [var startIndex, var length] + => TranslateSubstring(instance, method, startIndex, length), + + // There's single-parameter LTRIM/RTRIM for all versions (trims whitespace), but starting with SQL Server 2022 there's + // also an overload that accepts the characters to trim. + nameof(string.TrimStart) => TranslateTrimStartEnd(instance, arguments, "LTRIM"), + nameof(string.TrimEnd) => TranslateTrimStartEnd(instance, arguments, "RTRIM"), + nameof(string.Trim) when arguments is [] or [SqlConstantExpression { Value: char[] { Length: 0 } }] + => sqlExpressionFactory.Function( + "LTRIM", + [ + sqlExpressionFactory.Function( + "RTRIM", [instance], nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + instance.Type, instance.TypeMapping) + ], nullable: true, argumentsPropagateNullability: Statics.TrueArrays[1], instance.Type, - instance.TypeMapping) - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - instance.Type, - instance.TypeMapping); - } - } + instance.TypeMapping), - if (IsNullOrEmptyMethodInfo.Equals(method)) - { - var argument = arguments[0]; - - return _sqlExpressionFactory.OrElse( - _sqlExpressionFactory.IsNull(argument), - _sqlExpressionFactory.Like( - argument, - _sqlExpressionFactory.Constant(string.Empty))); - } - - if (IsNullOrWhiteSpaceMethodInfo.Equals(method)) - { - var argument = arguments[0]; - - return _sqlExpressionFactory.OrElse( - _sqlExpressionFactory.IsNull(argument), - _sqlExpressionFactory.Equal( - argument, - _sqlExpressionFactory.Constant(string.Empty, argument.TypeMapping))); - } + _ => null + }; + } - if (FirstOrDefaultMethodInfoWithoutArgs.Equals(method)) - { - var argument = arguments[0]; - return _sqlExpressionFactory.Function( - "SUBSTRING", - [argument, _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType); + return method.Name switch + { + nameof(string.IsNullOrEmpty) when arguments is [var argument] + => sqlExpressionFactory.OrElse( + sqlExpressionFactory.IsNull(argument), + sqlExpressionFactory.Like( + argument, + sqlExpressionFactory.Constant(string.Empty))), + + nameof(string.IsNullOrWhiteSpace) when arguments is [var argument] + => sqlExpressionFactory.OrElse( + sqlExpressionFactory.IsNull(argument), + sqlExpressionFactory.Equal( + argument, + sqlExpressionFactory.Constant(string.Empty, argument.TypeMapping))), + + _ => null + }; } - if (LastOrDefaultMethodInfoWithoutArgs.Equals(method)) + if (method.DeclaringType == typeof(Enumerable) + && method.IsGenericMethod + && arguments is [var source] + && source.Type == typeof(string)) { - var argument = arguments[0]; - return _sqlExpressionFactory.Function( - "SUBSTRING", - [ - argument, - _sqlExpressionFactory.Function( - "LEN", - [argument], + return method.Name switch + { + nameof(Enumerable.FirstOrDefault) + => sqlExpressionFactory.Function( + "SUBSTRING", + [source, sqlExpressionFactory.Constant(1), sqlExpressionFactory.Constant(1)], nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - typeof(int)), - _sqlExpressionFactory.Constant(1) - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType); + argumentsPropagateNullability: Statics.TrueArrays[3], + method.ReturnType), + + nameof(Enumerable.LastOrDefault) + => sqlExpressionFactory.Function( + "SUBSTRING", + [ + source, + sqlExpressionFactory.Function( + "LEN", [source], nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + typeof(int)), + sqlExpressionFactory.Constant(1) + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[3], + method.ReturnType), + + _ => null + }; } - if (PatIndexMethodInfo.Equals(method)) + if (method.DeclaringType == typeof(SqlServerDbFunctionsExtensions) + && method.Name == nameof(SqlServerDbFunctionsExtensions.PatIndex) + && arguments is [_, var pattern, var expression]) { - var pattern = arguments[1]; - var expression = arguments[2]; - - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "PATINDEX", [pattern, expression], nullable: true, argumentsPropagateNullability: Statics.TrueArrays[2], - method.ReturnType - ); + method.ReturnType); } return null; } + private SqlExpression TranslateReplace( + SqlExpression instance, + MethodInfo method, + SqlExpression oldValue, + SqlExpression newValue) + { + var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, oldValue, newValue); + + instance = sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping); + oldValue = sqlExpressionFactory.ApplyTypeMapping( + oldValue, oldValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); + newValue = sqlExpressionFactory.ApplyTypeMapping( + newValue, newValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); + + return sqlExpressionFactory.Function( + "REPLACE", + [instance, oldValue, newValue], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[3], + method.ReturnType, + stringTypeMapping); + } + + private SqlExpression TranslateSubstring( + SqlExpression instance, + MethodInfo method, + SqlExpression startIndex, + SqlExpression? length) + => sqlExpressionFactory.Function( + "SUBSTRING", + [ + instance, + sqlExpressionFactory.Add(startIndex, sqlExpressionFactory.Constant(1)), + length ?? sqlExpressionFactory.Function( + "LEN", [instance], nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + typeof(int)) + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[3], + method.ReturnType, + instance.TypeMapping); + + private SqlExpression? TranslateTrimStartEnd(SqlExpression instance, IReadOnlyList arguments, string functionName) + => arguments switch + { + // No args or empty char[] constant - whitespace trim, always supported + ([]) or ([SqlConstantExpression { Value: char[] { Length: 0 } }]) + => ProcessTrimStartEnd(instance, arguments, functionName), + + // Char or char[] argument - requires SQL Server 2022+ (compatibility level 160) + [_] when (sqlServerSingletonOptions.EngineType == SqlServerEngineType.SqlServer + && sqlServerSingletonOptions.SqlServerCompatibilityLevel >= 160) + || (sqlServerSingletonOptions.EngineType == SqlServerEngineType.AzureSql + && sqlServerSingletonOptions.AzureSqlCompatibilityLevel >= 160) + || sqlServerSingletonOptions.EngineType == SqlServerEngineType.AzureSynapse + => ProcessTrimStartEnd(instance, arguments, functionName), + + _ => null + }; + private SqlExpression TranslateIndexOf( SqlExpression instance, MethodInfo method, @@ -329,10 +221,10 @@ private SqlExpression TranslateIndexOf( SqlExpression? startIndex) { var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, searchExpression)!; - searchExpression = _sqlExpressionFactory.ApplyTypeMapping( + searchExpression = sqlExpressionFactory.ApplyTypeMapping( searchExpression, searchExpression.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); - instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping); + instance = sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping); var charIndexArguments = new List { searchExpression, instance }; @@ -340,8 +232,8 @@ private SqlExpression TranslateIndexOf( { charIndexArguments.Add( startIndex is SqlConstantExpression { Value : int constantStartIndex } - ? _sqlExpressionFactory.Constant(constantStartIndex + 1, typeof(int)) - : _sqlExpressionFactory.Add(startIndex, _sqlExpressionFactory.Constant(1))); + ? sqlExpressionFactory.Constant(constantStartIndex + 1, typeof(int)) + : sqlExpressionFactory.Add(startIndex, sqlExpressionFactory.Constant(1))); } var argumentsPropagateNullability = Enumerable.Repeat(true, charIndexArguments.Count); @@ -351,18 +243,18 @@ private SqlExpression TranslateIndexOf( if (string.Equals(storeType, "nvarchar(max)", StringComparison.OrdinalIgnoreCase) || string.Equals(storeType, "varchar(max)", StringComparison.OrdinalIgnoreCase)) { - charIndexExpression = _sqlExpressionFactory.Function( + charIndexExpression = sqlExpressionFactory.Function( "CHARINDEX", charIndexArguments, nullable: true, argumentsPropagateNullability, typeof(long)); - charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int)); + charIndexExpression = sqlExpressionFactory.Convert(charIndexExpression, typeof(int)); } else { - charIndexExpression = _sqlExpressionFactory.Function( + charIndexExpression = sqlExpressionFactory.Function( "CHARINDEX", charIndexArguments, nullable: true, @@ -374,25 +266,25 @@ private SqlExpression TranslateIndexOf( // -1). Handle separately for constant and non-constant patterns. if (searchExpression is SqlConstantExpression { Value: "" }) { - return _sqlExpressionFactory.Case( - [new CaseWhenClause(_sqlExpressionFactory.IsNotNull(instance), _sqlExpressionFactory.Constant(0))], + return sqlExpressionFactory.Case( + [new CaseWhenClause(sqlExpressionFactory.IsNotNull(instance), sqlExpressionFactory.Constant(0))], elseResult: null ); } var offsetExpression = searchExpression is SqlConstantExpression - ? _sqlExpressionFactory.Constant(1) - : _sqlExpressionFactory.Case( + ? sqlExpressionFactory.Constant(1) + : sqlExpressionFactory.Case( [ new CaseWhenClause( - _sqlExpressionFactory.Equal( + sqlExpressionFactory.Equal( searchExpression, - _sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)), - _sqlExpressionFactory.Constant(0)) + sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)), + sqlExpressionFactory.Constant(0)) ], - _sqlExpressionFactory.Constant(1)); + sqlExpressionFactory.Constant(1)); - return _sqlExpressionFactory.Subtract(charIndexExpression, offsetExpression); + return sqlExpressionFactory.Subtract(charIndexExpression, offsetExpression); } private SqlExpression? ProcessTrimStartEnd(SqlExpression instance, IReadOnlyList arguments, string functionName) @@ -402,13 +294,13 @@ [new CaseWhenClause(_sqlExpressionFactory.IsNotNull(instance), _sqlExpressionFac { charactersToTrim = charactersToTrimValue switch { - char singleChar => _sqlExpressionFactory.Constant(singleChar.ToString(), instance.TypeMapping), - char[] charArray => _sqlExpressionFactory.Constant(new string(charArray), instance.TypeMapping), + char singleChar => sqlExpressionFactory.Constant(singleChar.ToString(), instance.TypeMapping), + char[] charArray => sqlExpressionFactory.Constant(new string(charArray), instance.TypeMapping), _ => throw new UnreachableException("Invalid parameter type for string.TrimStart/TrimEnd") }; } - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( functionName, arguments: charactersToTrim is null ? [instance] : [instance, charactersToTrim], nullable: true, diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerTimeOnlyMethodTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerTimeOnlyMethodTranslator.cs index 4189be6a7d9..9e45428e27f 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerTimeOnlyMethodTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerTimeOnlyMethodTranslator.cs @@ -13,34 +13,8 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqlServerTimeOnlyMethodTranslator : IMethodCallTranslator +public class SqlServerTimeOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo AddHoursMethod = typeof(TimeOnly).GetRuntimeMethod( - nameof(TimeOnly.AddHours), [typeof(double)])!; - - private static readonly MethodInfo AddMinutesMethod = typeof(TimeOnly).GetRuntimeMethod( - nameof(TimeOnly.AddMinutes), [typeof(double)])!; - - private static readonly MethodInfo IsBetweenMethod = typeof(TimeOnly).GetRuntimeMethod( - nameof(TimeOnly.IsBetween), [typeof(TimeOnly), typeof(TimeOnly)])!; - - private static readonly MethodInfo FromDateTime = typeof(TimeOnly).GetRuntimeMethod( - nameof(TimeOnly.FromDateTime), [typeof(DateTime)])!; - - private static readonly MethodInfo FromTimeSpan = typeof(TimeOnly).GetRuntimeMethod( - nameof(TimeOnly.FromTimeSpan), [typeof(TimeSpan)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqlServerTimeOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -58,33 +32,36 @@ public SqlServerTimeOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFact return null; } - if ((method == FromDateTime || method == FromTimeSpan) - && instance is null - && arguments.Count == 1) - { - return _sqlExpressionFactory.Convert(arguments[0], typeof(TimeOnly)); - } - if (instance is null) { - return null; + return method.Name switch + { + nameof(TimeOnly.FromDateTime) or nameof(TimeOnly.FromTimeSpan) when arguments is [_] + => sqlExpressionFactory.Convert(arguments[0], typeof(TimeOnly)), + _ => null + }; } - if (method == AddHoursMethod || method == AddMinutesMethod) + var datePart = method.Name switch { - var datePart = method == AddHoursMethod ? "hour" : "minute"; + nameof(TimeOnly.AddHours) => "hour", + nameof(TimeOnly.AddMinutes) => "minute", + _ => null + }; + if (datePart is not null) + { // Some Add methods accept a double, and SQL Server DateAdd does not accept number argument outside of int range if (arguments[0] is SqlConstantExpression { Value: double and (<= int.MinValue or >= int.MaxValue) }) { return null; } - instance = _sqlExpressionFactory.ApplyDefaultTypeMapping(instance); + instance = sqlExpressionFactory.ApplyDefaultTypeMapping(instance); - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "DATEADD", - [_sqlExpressionFactory.Fragment(datePart), _sqlExpressionFactory.Convert(arguments[0], typeof(int)), instance], + [sqlExpressionFactory.Fragment(datePart), sqlExpressionFactory.Convert(arguments[0], typeof(int)), instance], nullable: true, argumentsPropagateNullability: [false, true, true], instance.Type, @@ -93,19 +70,19 @@ public SqlServerTimeOnlyMethodTranslator(ISqlExpressionFactory sqlExpressionFact // Translate TimeOnly.IsBetween to a >= b AND a < c. // Since a is evaluated multiple times, only translate for simple constructs (i.e. avoid duplicating complex subqueries). - if (method == IsBetweenMethod + if (method.Name == nameof(TimeOnly.IsBetween) && instance is ColumnExpression or SqlConstantExpression or SqlParameterExpression) { var typeMapping = ExpressionExtensions.InferTypeMapping(instance, arguments[0], arguments[1]); - instance = _sqlExpressionFactory.ApplyTypeMapping(instance, typeMapping); + instance = sqlExpressionFactory.ApplyTypeMapping(instance, typeMapping); - return _sqlExpressionFactory.And( - _sqlExpressionFactory.GreaterThanOrEqual( + return sqlExpressionFactory.And( + sqlExpressionFactory.GreaterThanOrEqual( instance, - _sqlExpressionFactory.ApplyTypeMapping(arguments[0], typeMapping)), - _sqlExpressionFactory.LessThan( + sqlExpressionFactory.ApplyTypeMapping(arguments[0], typeMapping)), + sqlExpressionFactory.LessThan( instance, - _sqlExpressionFactory.ApplyTypeMapping(arguments[1], typeMapping))); + sqlExpressionFactory.ApplyTypeMapping(arguments[1], typeMapping))); } return null; diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteByteArrayMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteByteArrayMethodTranslator.cs index e5e7af297c6..5bb58b85d7c 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteByteArrayMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteByteArrayMethodTranslator.cs @@ -12,19 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteByteArrayMethodTranslator : IMethodCallTranslator +public class SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -38,56 +27,56 @@ public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactor IDiagnosticsLogger logger) { if (method.IsGenericMethod - && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains) - && arguments[0].Type == typeof(byte[])) + && method.DeclaringType == typeof(Enumerable) + && method.Name == nameof(Enumerable.Contains) + && arguments is [var source, var item] + && source.Type == typeof(byte[])) { - var source = arguments[0]; - - var value = arguments[1] is SqlConstantExpression constantValue - ? (SqlExpression)_sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, source.TypeMapping) - : _sqlExpressionFactory.Function( + var value = item is SqlConstantExpression constantValue + ? sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, source.TypeMapping) + : sqlExpressionFactory.Function( "char", - [arguments[1]], + [item], nullable: false, argumentsPropagateNullability: Statics.FalseArrays[1], typeof(string)); - return _sqlExpressionFactory.GreaterThan( - _sqlExpressionFactory.Function( + return sqlExpressionFactory.GreaterThan( + sqlExpressionFactory.Function( "instr", [source, value], nullable: true, argumentsPropagateNullability: Statics.TrueArrays[2], typeof(int)), - _sqlExpressionFactory.Constant(0)); + sqlExpressionFactory.Constant(0)); } - // See issue#16428 - //if (method.IsGenericMethod - // && method.GetGenericMethodDefinition().Equals(EnumerableMethods.FirstWithoutPredicate) - // && arguments[0].Type == typeof(byte[])) - //{ - // return _sqlExpressionFactory.Function( - // "unicode", - // new SqlExpression[] - // { - // _sqlExpressionFactory.Function( - // "substr", - // new SqlExpression[] - // { - // arguments[0], - // _sqlExpressionFactory.Constant(1), - // _sqlExpressionFactory.Constant(1) - // }, - // nullable: true, - // argumentsPropagateNullability: new[] { true, true, true }, - // typeof(byte[])) - // }, - // nullable: true, - // argumentsPropagateNullability: new[] { true }, - // method.ReturnType); - //} - return null; } + + // See issue#16428 + //if (method.IsGenericMethod + // && method.GetGenericMethodDefinition().Equals(EnumerableMethods.FirstWithoutPredicate) + // && arguments[0].Type == typeof(byte[])) + //{ + // return _sqlExpressionFactory.Function( + // "unicode", + // new SqlExpression[] + // { + // _sqlExpressionFactory.Function( + // "substr", + // new SqlExpression[] + // { + // arguments[0], + // _sqlExpressionFactory.Constant(1), + // _sqlExpressionFactory.Constant(1) + // }, + // nullable: true, + // argumentsPropagateNullability: new[] { true, true, true }, + // typeof(byte[])) + // }, + // nullable: true, + // argumentsPropagateNullability: new[] { true }, + // method.ReturnType); + //} } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteCharMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteCharMethodTranslator.cs index 4e75f3de527..4ebe3bdc527 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteCharMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteCharMethodTranslator.cs @@ -12,25 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteCharMethodTranslator : IMethodCallTranslator +public class SqliteCharMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary SupportedMethods = new() - { - { typeof(char).GetRuntimeMethod(nameof(char.ToLower), [typeof(char)])!, "lower" }, - { typeof(char).GetRuntimeMethod(nameof(char.ToUpper), [typeof(char)])!, "upper" } - }; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteCharMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -43,17 +26,26 @@ public SqliteCharMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (SupportedMethods.TryGetValue(method, out var sqlFunctionName)) + if (method.DeclaringType != typeof(char) || arguments is not [var arg]) + { + return null; + } + + var sqlFunctionName = method.Name switch { - return _sqlExpressionFactory.Function( + nameof(char.ToLower) => "lower", + nameof(char.ToUpper) => "upper", + _ => (string?)null + }; + + return sqlFunctionName is null + ? null + : sqlExpressionFactory.Function( sqlFunctionName, - arguments, + [arg], nullable: true, - argumentsPropagateNullability: arguments.Select(_ => true).ToList(), + argumentsPropagateNullability: Statics.TrueArrays[1], method.ReturnType, - arguments[0].TypeMapping); - } - - return null; + arg.TypeMapping); } } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateOnlyMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateOnlyMethodTranslator.cs index 3f9aade9abb..69488d5dc7e 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateOnlyMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateOnlyMethodTranslator.cs @@ -12,19 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteDateOnlyMethodTranslator : IMethodCallTranslator +public class SqliteDateOnlyMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private readonly SqliteSqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteDateOnlyMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -39,9 +28,9 @@ public SqliteDateOnlyMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFa { if (method.DeclaringType == typeof(DateOnly) && method.Name == nameof(DateOnly.FromDateTime) - && arguments.Count == 1) + && arguments is [var arg]) { - return _sqlExpressionFactory.Date(method.ReturnType, arguments[0]); + return sqlExpressionFactory.Date(method.ReturnType, arg); } return null; diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateTimeMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateTimeMethodTranslator.cs index 9081423d5bc..4aab2e06d55 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateTimeMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteDateTimeMethodTranslator.cs @@ -14,25 +14,6 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// public class SqliteDateTimeMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo AddMilliseconds - = typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMilliseconds), [typeof(double)])!; - - private static readonly MethodInfo AddTicks - = typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddTicks), [typeof(long)])!; - - private static readonly Dictionary MethodInfoToUnitSuffix = new() - { - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddYears), [typeof(int)])!, " years" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMonths), [typeof(int)])!, " months" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddDays), [typeof(double)])!, " days" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddHours), [typeof(double)])!, " hours" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMinutes), [typeof(double)])!, " minutes" }, - { typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddSeconds), [typeof(double)])!, " seconds" }, - { typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), [typeof(int)])!, " years" }, - { typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), [typeof(int)])!, " months" }, - { typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), [typeof(int)])!, " days" } - }; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -55,60 +36,61 @@ private static readonly MethodInfo AddTicks MethodInfo method, IReadOnlyList arguments) { - SqlExpression? modifier = null; - if (AddMilliseconds.Equals(method)) + if (instance is null || arguments is not [var arg]) { - modifier = sqlExpressionFactory.Add( - sqlExpressionFactory.Convert( - sqlExpressionFactory.Divide( - arguments[0], - sqlExpressionFactory.Constant(1000.0)), - typeof(string)), - sqlExpressionFactory.Constant(" seconds")); + return null; } - else if (AddTicks.Equals(method)) + + var modifier = method.Name switch { - modifier = sqlExpressionFactory.Add( + nameof(DateTime.AddMilliseconds) => sqlExpressionFactory.Add( sqlExpressionFactory.Convert( - sqlExpressionFactory.Divide( - arguments[0], - sqlExpressionFactory.Constant((double)TimeSpan.TicksPerSecond)), + sqlExpressionFactory.Divide(arg, sqlExpressionFactory.Constant(1000.0)), typeof(string)), - sqlExpressionFactory.Constant(" seconds")); - } - else if (MethodInfoToUnitSuffix.TryGetValue(method, out var unitSuffix)) - { - modifier = sqlExpressionFactory.Add( - sqlExpressionFactory.Convert(arguments[0], typeof(string)), - sqlExpressionFactory.Constant(unitSuffix)); - } + sqlExpressionFactory.Constant(" seconds")), + + nameof(DateTime.AddTicks) => sqlExpressionFactory.Add( + sqlExpressionFactory.Convert( + sqlExpressionFactory.Divide(arg, sqlExpressionFactory.Constant((double)TimeSpan.TicksPerSecond)), + typeof(string)), + sqlExpressionFactory.Constant(" seconds")), + + nameof(DateTime.AddYears) => MakeModifier(arg, " years"), + nameof(DateTime.AddMonths) => MakeModifier(arg, " months"), + nameof(DateTime.AddDays) => MakeModifier(arg, " days"), + nameof(DateTime.AddHours) => MakeModifier(arg, " hours"), + nameof(DateTime.AddMinutes) => MakeModifier(arg, " minutes"), + nameof(DateTime.AddSeconds) => MakeModifier(arg, " seconds"), - if (modifier != null) + _ => (SqlExpression?)null + }; + + if (modifier is null) { - return sqlExpressionFactory.Function( - "rtrim", - [ - sqlExpressionFactory.Function( - "rtrim", - [ - sqlExpressionFactory.Strftime( - method.ReturnType, - "%Y-%m-%d %H:%M:%f", - instance!, - modifiers: [modifier]), - sqlExpressionFactory.Constant("0") - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueFalse, - method.ReturnType), - sqlExpressionFactory.Constant(".") - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueFalse, - method.ReturnType); + return null; } - return null; + return sqlExpressionFactory.Function( + "rtrim", + [ + sqlExpressionFactory.Function( + "rtrim", + [ + sqlExpressionFactory.Strftime( + method.ReturnType, + "%Y-%m-%d %H:%M:%f", + instance, + modifiers: [modifier]), + sqlExpressionFactory.Constant("0") + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueFalse, + method.ReturnType), + sqlExpressionFactory.Constant(".") + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueFalse, + method.ReturnType); } private SqlExpression? TranslateDateOnly( @@ -116,19 +98,29 @@ private static readonly MethodInfo AddTicks MethodInfo method, IReadOnlyList arguments) { - if (instance is not null && MethodInfoToUnitSuffix.TryGetValue(method, out var unitSuffix)) + if (instance is null || arguments is not [var arg]) { - return sqlExpressionFactory.Date( - method.ReturnType, - instance, - modifiers: - [ - sqlExpressionFactory.Add( - sqlExpressionFactory.Convert(arguments[0], typeof(string)), - sqlExpressionFactory.Constant(unitSuffix)) - ]); + return null; } - return null; + var unitSuffix = method.Name switch + { + nameof(DateOnly.AddYears) => " years", + nameof(DateOnly.AddMonths) => " months", + nameof(DateOnly.AddDays) => " days", + _ => (string?)null + }; + + return unitSuffix is not null + ? sqlExpressionFactory.Date( + method.ReturnType, + instance, + modifiers: [MakeModifier(arg, unitSuffix)]) + : null; } + + private SqlExpression MakeModifier(SqlExpression argument, string unitSuffix) + => sqlExpressionFactory.Add( + sqlExpressionFactory.Convert(argument, typeof(string)), + sqlExpressionFactory.Constant(unitSuffix)); } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteGlobMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteGlobMethodTranslator.cs index c69d1d189f4..68ea5b925bb 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteGlobMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteGlobMethodTranslator.cs @@ -12,22 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteGlobMethodTranslator : IMethodCallTranslator +public class SqliteGlobMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo MethodInfo = typeof(SqliteDbFunctionsExtensions) - .GetMethod(nameof(SqliteDbFunctionsExtensions.Glob), [typeof(DbFunctions), typeof(string), typeof(string)])!; - - private readonly SqliteSqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteGlobMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -40,12 +26,11 @@ public SqliteGlobMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactor IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method.Equals(MethodInfo)) + if (method.DeclaringType == typeof(SqliteDbFunctionsExtensions) + && method.Name == nameof(SqliteDbFunctionsExtensions.Glob) + && arguments is [_, var matchExpression, var pattern]) { - var matchExpression = arguments[1]; - var pattern = arguments[2]; - - return _sqlExpressionFactory.Glob(matchExpression, pattern); + return sqlExpressionFactory.Glob(matchExpression, pattern); } return null; diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteHexMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteHexMethodTranslator.cs index bd3a5886e86..ac958b5990e 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteHexMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteHexMethodTranslator.cs @@ -12,28 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteHexMethodTranslator : IMethodCallTranslator +public class SqliteHexMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo HexMethodInfo = typeof(SqliteDbFunctionsExtensions) - .GetMethod(nameof(SqliteDbFunctionsExtensions.Hex), [typeof(DbFunctions), typeof(byte[])])!; - - private static readonly MethodInfo UnhexMethodInfo = typeof(SqliteDbFunctionsExtensions) - .GetMethod(nameof(SqliteDbFunctionsExtensions.Unhex), [typeof(DbFunctions), typeof(string)])!; - - private static readonly MethodInfo UnhexWithIgnoreCharsMethodInfo = typeof(SqliteDbFunctionsExtensions) - .GetMethod(nameof(SqliteDbFunctionsExtensions.Unhex), [typeof(DbFunctions), typeof(string), typeof(string)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteHexMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -46,29 +26,40 @@ public SqliteHexMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method.Equals(HexMethodInfo)) + if (method.DeclaringType != typeof(SqliteDbFunctionsExtensions)) { - return _sqlExpressionFactory.Function( - "hex", - [arguments[1]], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - typeof(string)); + return null; } - if (method.Equals(UnhexMethodInfo) - || method.Equals(UnhexWithIgnoreCharsMethodInfo)) + return method.Name switch { + nameof(SqliteDbFunctionsExtensions.Hex) when arguments is [_, var arg] + => sqlExpressionFactory.Function( + "hex", + [arg], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + typeof(string)), + // unhex returns NULL whenever the decoding fails, hence mark as // nullable and use an all-false argumentsPropagateNullability - return _sqlExpressionFactory.Function( - "unhex", - arguments.Skip(1), - nullable: true, - argumentsPropagateNullability: arguments.Skip(1).Select(_ => false).ToArray(), - typeof(byte[])); - } + nameof(SqliteDbFunctionsExtensions.Unhex) when arguments is [_, var arg] + => sqlExpressionFactory.Function( + "unhex", + [arg], + nullable: true, + argumentsPropagateNullability: Statics.FalseArrays[1], + typeof(byte[])), + + nameof(SqliteDbFunctionsExtensions.Unhex) when arguments is [_, var arg, var ignoreChars] + => sqlExpressionFactory.Function( + "unhex", + [arg, ignoreChars], + nullable: true, + argumentsPropagateNullability: Statics.FalseArrays[2], + typeof(byte[])), - return null; + _ => null + }; } } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs index 6aecd3f8802..1c24b42571d 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs @@ -13,98 +13,10 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteMathTranslator : IMethodCallTranslator +public class SqliteMathTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly Dictionary SupportedMethods = new() - { - { typeof(Math).GetMethod(nameof(Math.Abs), [typeof(double)])!, "abs" }, - { typeof(Math).GetMethod(nameof(Math.Abs), [typeof(float)])!, "abs" }, - { typeof(Math).GetMethod(nameof(Math.Abs), [typeof(int)])!, "abs" }, - { typeof(Math).GetMethod(nameof(Math.Abs), [typeof(long)])!, "abs" }, - { typeof(Math).GetMethod(nameof(Math.Abs), [typeof(sbyte)])!, "abs" }, - { typeof(Math).GetMethod(nameof(Math.Abs), [typeof(short)])!, "abs" }, - { typeof(Math).GetMethod(nameof(Math.Acos), [typeof(double)])!, "acos" }, - { typeof(Math).GetMethod(nameof(Math.Acosh), [typeof(double)])!, "acosh" }, - { typeof(Math).GetMethod(nameof(Math.Asin), [typeof(double)])!, "asin" }, - { typeof(Math).GetMethod(nameof(Math.Asinh), [typeof(double)])!, "asinh" }, - { typeof(Math).GetMethod(nameof(Math.Atan), [typeof(double)])!, "atan" }, - { typeof(Math).GetMethod(nameof(Math.Atan2), [typeof(double), typeof(double)])!, "atan2" }, - { typeof(Math).GetMethod(nameof(Math.Atanh), [typeof(double)])!, "atanh" }, - { typeof(Math).GetMethod(nameof(Math.Ceiling), [typeof(double)])!, "ceiling" }, - { typeof(Math).GetMethod(nameof(Math.Cos), [typeof(double)])!, "cos" }, - { typeof(Math).GetMethod(nameof(Math.Cosh), [typeof(double)])!, "cosh" }, - { typeof(Math).GetMethod(nameof(Math.Exp), [typeof(double)])!, "exp" }, - { typeof(Math).GetMethod(nameof(Math.Floor), [typeof(double)])!, "floor" }, - { typeof(Math).GetMethod(nameof(Math.Log), [typeof(double)])!, "ln" }, - { typeof(Math).GetMethod(nameof(Math.Log2), [typeof(double)])!, "log2" }, - { typeof(Math).GetMethod(nameof(Math.Log10), [typeof(double)])!, "log10" }, - { typeof(Math).GetMethod(nameof(Math.Pow), [typeof(double), typeof(double)])!, "pow" }, - { typeof(Math).GetMethod(nameof(Math.Round), [typeof(double)])!, "round" }, - { typeof(Math).GetMethod(nameof(Math.Sign), [typeof(double)])!, "sign" }, - { typeof(Math).GetMethod(nameof(Math.Sign), [typeof(float)])!, "sign" }, - { typeof(Math).GetMethod(nameof(Math.Sign), [typeof(long)])!, "sign" }, - { typeof(Math).GetMethod(nameof(Math.Sign), [typeof(sbyte)])!, "sign" }, - { typeof(Math).GetMethod(nameof(Math.Sign), [typeof(short)])!, "sign" }, - { typeof(Math).GetMethod(nameof(Math.Sin), [typeof(double)])!, "sin" }, - { typeof(Math).GetMethod(nameof(Math.Sinh), [typeof(double)])!, "sinh" }, - { typeof(Math).GetMethod(nameof(Math.Sqrt), [typeof(double)])!, "sqrt" }, - { typeof(Math).GetMethod(nameof(Math.Tan), [typeof(double)])!, "tan" }, - { typeof(Math).GetMethod(nameof(Math.Tanh), [typeof(double)])!, "tanh" }, - { typeof(Math).GetMethod(nameof(Math.Truncate), [typeof(double)])!, "trunc" }, - { typeof(double).GetRuntimeMethod(nameof(double.DegreesToRadians), [typeof(double)])!, "radians" }, - { typeof(double).GetRuntimeMethod(nameof(double.RadiansToDegrees), [typeof(double)])!, "degrees" }, - { typeof(MathF).GetMethod(nameof(MathF.Acos), [typeof(float)])!, "acos" }, - { typeof(MathF).GetMethod(nameof(MathF.Acosh), [typeof(float)])!, "acosh" }, - { typeof(MathF).GetMethod(nameof(MathF.Asin), [typeof(float)])!, "asin" }, - { typeof(MathF).GetMethod(nameof(MathF.Asinh), [typeof(float)])!, "asinh" }, - { typeof(MathF).GetMethod(nameof(MathF.Atan), [typeof(float)])!, "atan" }, - { typeof(MathF).GetMethod(nameof(MathF.Atan2), [typeof(float), typeof(float)])!, "atan2" }, - { typeof(MathF).GetMethod(nameof(MathF.Atanh), [typeof(float)])!, "atanh" }, - { typeof(MathF).GetMethod(nameof(MathF.Ceiling), [typeof(float)])!, "ceiling" }, - { typeof(MathF).GetMethod(nameof(MathF.Cos), [typeof(float)])!, "cos" }, - { typeof(MathF).GetMethod(nameof(MathF.Cosh), [typeof(float)])!, "cosh" }, - { typeof(MathF).GetMethod(nameof(MathF.Exp), [typeof(float)])!, "exp" }, - { typeof(MathF).GetMethod(nameof(MathF.Floor), [typeof(float)])!, "floor" }, - { typeof(MathF).GetMethod(nameof(MathF.Log), [typeof(float)])!, "ln" }, - { typeof(MathF).GetMethod(nameof(MathF.Log10), [typeof(float)])!, "log10" }, - { typeof(MathF).GetMethod(nameof(MathF.Log2), [typeof(float)])!, "log2" }, - { typeof(MathF).GetMethod(nameof(MathF.Pow), [typeof(float), typeof(float)])!, "pow" }, - { typeof(MathF).GetMethod(nameof(MathF.Round), [typeof(float)])!, "round" }, - { typeof(MathF).GetMethod(nameof(MathF.Sin), [typeof(float)])!, "sin" }, - { typeof(MathF).GetMethod(nameof(MathF.Sinh), [typeof(float)])!, "sinh" }, - { typeof(MathF).GetMethod(nameof(MathF.Sqrt), [typeof(float)])!, "sqrt" }, - { typeof(MathF).GetMethod(nameof(MathF.Tan), [typeof(float)])!, "tan" }, - { typeof(MathF).GetMethod(nameof(MathF.Tanh), [typeof(float)])!, "tanh" }, - { typeof(MathF).GetMethod(nameof(MathF.Truncate), [typeof(float)])!, "trunc" }, - { typeof(float).GetRuntimeMethod(nameof(float.DegreesToRadians), [typeof(float)])!, "radians" }, - { typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "degrees" } - }; - // Note: Math.Max/Min are handled in RelationalSqlTranslatingExpressionVisitor - private static readonly List _roundWithDecimalMethods = - [ - typeof(Math).GetMethod(nameof(Math.Round), [typeof(double), typeof(int)])!, - typeof(MathF).GetMethod(nameof(MathF.Round), [typeof(float), typeof(int)])! - ]; - - private static readonly List _logWithBaseMethods = - [ - typeof(Math).GetMethod(nameof(Math.Log), [typeof(double), typeof(double)])!, - typeof(MathF).GetMethod(nameof(MathF.Log), [typeof(float), typeof(float)])! - ]; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteMathTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -117,50 +29,163 @@ public SqliteMathTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (SupportedMethods.TryGetValue(method, out var sqlFunctionName)) + if (method.DeclaringType != typeof(Math) + && method.DeclaringType != typeof(MathF) + && method.DeclaringType != typeof(double) + && method.DeclaringType != typeof(float)) { - var typeMapping = ExpressionExtensions.InferTypeMapping(arguments.ToArray()); - var newArguments = arguments - .Select(a => _sqlExpressionFactory.ApplyTypeMapping(a, typeMapping)) - .ToList(); - - return _sqlExpressionFactory.Function( - sqlFunctionName, - newArguments, - nullable: true, - argumentsPropagateNullability: newArguments.Select(_ => true).ToList(), - method.ReturnType, - typeMapping); + return null; } - if (_roundWithDecimalMethods.Contains(method)) + return method.Name switch { - return _sqlExpressionFactory.Function( - "round", - arguments, - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[2], - method.ReturnType, - arguments[0].TypeMapping); - } + nameof(Math.Abs) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float) || t == typeof(int) || t == typeof(long) || t == typeof(sbyte) || t == typeof(short)) + => TranslateFunction("abs", arg), - if (_logWithBaseMethods.Contains(method)) - { - var a = arguments[0]; - var newBase = arguments[1]; - var typeMapping = ExpressionExtensions.InferTypeMapping(a, newBase); - - return _sqlExpressionFactory.Function( - "log", - [ - _sqlExpressionFactory.ApplyTypeMapping(newBase, typeMapping), _sqlExpressionFactory.ApplyTypeMapping(a, typeMapping) - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[2], - method.ReturnType, - typeMapping); - } + nameof(Math.Acos) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("acos", arg), + nameof(Math.Acosh) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("acosh", arg), + nameof(Math.Asin) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("asin", arg), + nameof(Math.Asinh) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("asinh", arg), + nameof(Math.Atan) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("atan", arg), + nameof(Math.Atan2) when arguments is [var arg1, var arg2] + => TranslateFunction("atan2", arg1, arg2), + nameof(Math.Atanh) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("atanh", arg), + nameof(Math.Ceiling) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("ceiling", arg), + nameof(Math.Cos) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("cos", arg), + nameof(Math.Cosh) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("cosh", arg), + nameof(Math.Exp) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("exp", arg), + nameof(Math.Floor) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("floor", arg), + + nameof(Math.Log) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("ln", arg), + nameof(Math.Log) when arguments is [var a, var newBase] + => TranslateLogWithBase(a, newBase), + + nameof(Math.Log2) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("log2", arg), + nameof(Math.Log10) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("log10", arg), + nameof(Math.Pow) when arguments is [var arg1, var arg2] + => TranslateFunction("pow", arg1, arg2), + + nameof(Math.Round) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("round", arg), + nameof(Math.Round) when arguments is [var arg, var digits] + && digits.Type == typeof(int) + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateRoundWithDigits(arg, digits), + + nameof(Math.Sign) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float) || t == typeof(long) || t == typeof(sbyte) || t == typeof(short)) + => TranslateFunction("sign", arg), + + nameof(Math.Sin) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("sin", arg), + nameof(Math.Sinh) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("sinh", arg), + nameof(Math.Sqrt) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("sqrt", arg), + nameof(Math.Tan) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("tan", arg), + nameof(Math.Tanh) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("tanh", arg), + nameof(Math.Truncate) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("trunc", arg), + nameof(double.DegreesToRadians) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("radians", arg), + nameof(double.RadiansToDegrees) when arguments is [var arg] + && arg.Type is { } t && (t == typeof(double) || t == typeof(float)) + => TranslateFunction("degrees", arg), + + _ => null + }; + } + + private SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression argument) + { + var typeMapping = argument.TypeMapping; + argument = sqlExpressionFactory.ApplyTypeMapping(argument, typeMapping); + + return sqlExpressionFactory.Function( + sqlFunctionName, + [argument], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + argument.Type, + typeMapping); + } + + private SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg1, SqlExpression arg2) + { + var typeMapping = ExpressionExtensions.InferTypeMapping(arg1, arg2); + arg1 = sqlExpressionFactory.ApplyTypeMapping(arg1, typeMapping); + arg2 = sqlExpressionFactory.ApplyTypeMapping(arg2, typeMapping); + + return sqlExpressionFactory.Function( + sqlFunctionName, + [arg1, arg2], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[2], + arg1.Type, + typeMapping); + } + + private SqlExpression TranslateRoundWithDigits(SqlExpression arg, SqlExpression digits) + => sqlExpressionFactory.Function( + "round", + [arg, digits], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[2], + arg.Type, + arg.TypeMapping); + + private SqlExpression TranslateLogWithBase(SqlExpression a, SqlExpression newBase) + { + var typeMapping = ExpressionExtensions.InferTypeMapping(a, newBase); - return null; + return sqlExpressionFactory.Function( + "log", + [ + sqlExpressionFactory.ApplyTypeMapping(newBase, typeMapping), + sqlExpressionFactory.ApplyTypeMapping(a, typeMapping) + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[2], + a.Type, + typeMapping); } } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteObjectToStringTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteObjectToStringTranslator.cs index 61400abc90f..41ace797aa3 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteObjectToStringTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteObjectToStringTranslator.cs @@ -12,42 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteObjectToStringTranslator : IMethodCallTranslator +public class SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly HashSet TypeMapping = - [ - typeof(bool), - typeof(byte), - typeof(byte[]), - typeof(char), - typeof(DateOnly), - typeof(DateTime), - typeof(DateTimeOffset), - typeof(decimal), - typeof(double), - typeof(float), - typeof(Guid), - typeof(int), - typeof(long), - typeof(sbyte), - typeof(short), - typeof(TimeOnly), - typeof(TimeSpan), - typeof(uint), - typeof(ushort) - ]; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -74,34 +40,54 @@ public SqliteObjectToStringTranslator(ISqlExpressionFactory sqlExpressionFactory { if (instance is not ColumnExpression { IsNullable: false }) { - return _sqlExpressionFactory.Case( + return sqlExpressionFactory.Case( instance, [ new CaseWhenClause( - _sqlExpressionFactory.Constant(false), - _sqlExpressionFactory.Constant(false.ToString())), + sqlExpressionFactory.Constant(false), + sqlExpressionFactory.Constant(false.ToString())), new CaseWhenClause( - _sqlExpressionFactory.Constant(true), - _sqlExpressionFactory.Constant(true.ToString())) + sqlExpressionFactory.Constant(true), + sqlExpressionFactory.Constant(true.ToString())) ], - _sqlExpressionFactory.Constant(string.Empty)); + sqlExpressionFactory.Constant(string.Empty)); } - return _sqlExpressionFactory.Case( + return sqlExpressionFactory.Case( [ new CaseWhenClause( instance, - _sqlExpressionFactory.Constant(true.ToString())) + sqlExpressionFactory.Constant(true.ToString())) ], - _sqlExpressionFactory.Constant(false.ToString())); + sqlExpressionFactory.Constant(false.ToString())); } // Enums are handled by EnumMethodTranslator - return TypeMapping.Contains(instance.Type) - ? _sqlExpressionFactory.Coalesce( - _sqlExpressionFactory.Convert(instance, typeof(string)), - _sqlExpressionFactory.Constant(string.Empty)) + return IsSupportedType(instance.Type) + ? sqlExpressionFactory.Coalesce( + sqlExpressionFactory.Convert(instance, typeof(string)), + sqlExpressionFactory.Constant(string.Empty)) : null; } + + private static bool IsSupportedType(Type type) + => type == typeof(byte) + || type == typeof(byte[]) + || type == typeof(char) + || type == typeof(DateOnly) + || type == typeof(DateTime) + || type == typeof(DateTimeOffset) + || type == typeof(decimal) + || type == typeof(double) + || type == typeof(float) + || type == typeof(Guid) + || type == typeof(int) + || type == typeof(long) + || type == typeof(sbyte) + || type == typeof(short) + || type == typeof(TimeOnly) + || type == typeof(TimeSpan) + || type == typeof(uint) + || type == typeof(ushort); } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs index c9f11e150d2..af1e816e87c 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs @@ -13,19 +13,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator +public class SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IAggregateMethodCallTranslator { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -53,7 +42,7 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress if (averageArgumentType == typeof(decimal)) { averageSqlExpression = CombineTerms(source, averageSqlExpression); - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "ef_avg", [averageSqlExpression], nullable: true, @@ -80,7 +69,7 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress if (maxArgumentType == typeof(decimal)) { maxSqlExpression = CombineTerms(source, maxSqlExpression); - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "ef_max", [maxSqlExpression], nullable: true, @@ -107,7 +96,7 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress if (minArgumentType == typeof(decimal)) { minSqlExpression = CombineTerms(source, minSqlExpression); - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "ef_min", [minSqlExpression], nullable: true, @@ -126,7 +115,7 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress if (sumArgumentType == typeof(decimal)) { sumSqlExpression = CombineTerms(source, sumSqlExpression); - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( "ef_sum", [sumSqlExpression], nullable: true, @@ -151,8 +140,8 @@ private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, Sq { if (enumerableExpression.Predicate != null) { - sqlExpression = _sqlExpressionFactory.Case( - new List { new(enumerableExpression.Predicate, sqlExpression) }, + sqlExpression = sqlExpressionFactory.Case( + [new(enumerableExpression.Predicate, sqlExpression)], elseResult: null); } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRandomTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRandomTranslator.cs index b111e7de682..8d798587bd9 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRandomTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRandomTranslator.cs @@ -12,22 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteRandomTranslator : IMethodCallTranslator +public class SqliteRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo MethodInfo - = typeof(DbFunctionsExtensions).GetMethod(nameof(DbFunctionsExtensions.Random), [typeof(DbFunctions)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -40,21 +26,22 @@ public SqliteRandomTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) // Issue #15586: Query: TypeCompatibility chart for inference. - => MethodInfo.Equals(method) - ? _sqlExpressionFactory.Function( - "abs", - [ - _sqlExpressionFactory.Divide( - _sqlExpressionFactory.Function( - "random", - [], - nullable: false, - argumentsPropagateNullability: [], - method.ReturnType), - _sqlExpressionFactory.Constant(9223372036854780000.0)) - ], - nullable: false, - argumentsPropagateNullability: Statics.TrueArrays[1], - method.ReturnType) - : null; + => method.DeclaringType == typeof(DbFunctionsExtensions) + && method.Name == nameof(DbFunctionsExtensions.Random) + ? sqlExpressionFactory.Function( + "abs", + [ + sqlExpressionFactory.Divide( + sqlExpressionFactory.Function( + "random", + [], + nullable: false, + argumentsPropagateNullability: [], + method.ReturnType), + sqlExpressionFactory.Constant(9223372036854780000.0)) + ], + nullable: false, + argumentsPropagateNullability: Statics.TrueArrays[1], + method.ReturnType) + : null; } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRegexMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRegexMethodTranslator.cs index 6df83ebe284..dd2e1eebf75 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRegexMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteRegexMethodTranslator.cs @@ -13,22 +13,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteRegexMethodTranslator : IMethodCallTranslator +public class SqliteRegexMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo RegexIsMatchMethodInfo - = typeof(Regex).GetRuntimeMethod(nameof(Regex.IsMatch), [typeof(string), typeof(string)])!; - - private readonly SqliteSqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteRegexMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -41,12 +27,11 @@ public SqliteRegexMethodTranslator(SqliteSqlExpressionFactory sqlExpressionFacto IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method.Equals(RegexIsMatchMethodInfo)) + if (method.DeclaringType == typeof(Regex) + && method.Name == nameof(Regex.IsMatch) + && arguments is [var input, var pattern]) { - var input = arguments[0]; - var pattern = arguments[1]; - - return _sqlExpressionFactory.Regexp(input, pattern); + return sqlExpressionFactory.Regexp(input, pattern); } return null; diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs index 492539f26ed..1f048b447ea 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs @@ -12,25 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteStringAggregateMethodTranslator : IAggregateMethodCallTranslator +public class SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IAggregateMethodCallTranslator { - private static readonly MethodInfo StringConcatMethod - = typeof(string).GetRuntimeMethod(nameof(string.Concat), [typeof(IEnumerable)])!; - - private static readonly MethodInfo StringJoinMethod - = typeof(string).GetRuntimeMethod(nameof(string.Join), [typeof(string), typeof(IEnumerable)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -46,11 +29,24 @@ public SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpression // Docs: https://sqlite.org/lang_aggfunc.html#group_concat if (source.Selector is not SqlExpression sqlExpression - || (method != StringJoinMethod && method != StringConcatMethod)) + || method.DeclaringType != typeof(string)) { return null; } + SqlExpression separator; + switch (method.Name) + { + case nameof(string.Concat) when arguments is []: + separator = sqlExpressionFactory.Constant(string.Empty, typeof(string)); + break; + case nameof(string.Join) when arguments is [var sep]: + separator = sep; + break; + default: + return null; + } + // SQLite does not support input ordering on aggregate methods. Since ordering matters very much for translating, if the user // specified an ordering we refuse to translate (but to error than to ignore in this case). if (source.Orderings.Count > 0) @@ -58,18 +54,18 @@ public SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpression return null; } - sqlExpression = _sqlExpressionFactory.Coalesce( + sqlExpression = sqlExpressionFactory.Coalesce( sqlExpression, - _sqlExpressionFactory.Constant(string.Empty, typeof(string))); + sqlExpressionFactory.Constant(string.Empty, typeof(string))); if (source.Predicate != null) { if (sqlExpression is SqlFragmentExpression) { - sqlExpression = _sqlExpressionFactory.Constant(1); + sqlExpression = sqlExpressionFactory.Constant(1); } - sqlExpression = _sqlExpressionFactory.Case( + sqlExpression = sqlExpressionFactory.Case( new List { new(source.Predicate, sqlExpression) }, elseResult: null); } @@ -80,19 +76,17 @@ public SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpression } // group_concat returns null when there are no rows (or non-null values), but string.Join returns an empty string. - return _sqlExpressionFactory.Coalesce( - _sqlExpressionFactory.Function( + return sqlExpressionFactory.Coalesce( + sqlExpressionFactory.Function( "group_concat", [ sqlExpression, - _sqlExpressionFactory.ApplyTypeMapping( - method == StringJoinMethod ? arguments[0] : _sqlExpressionFactory.Constant(string.Empty, typeof(string)), - sqlExpression.TypeMapping) + sqlExpressionFactory.ApplyTypeMapping(separator, sqlExpression.TypeMapping) ], nullable: true, argumentsPropagateNullability: Statics.FalseArrays[2], typeof(string)), - _sqlExpressionFactory.Constant(string.Empty, typeof(string)), + sqlExpressionFactory.Constant(string.Empty, typeof(string)), sqlExpression.TypeMapping); } } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringLengthTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringLengthTranslator.cs index 10e8a4416f7..ada70dff1cf 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringLengthTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringLengthTranslator.cs @@ -12,19 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteStringLengthTranslator : IMemberTranslator +public class SqliteStringLengthTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMemberTranslator { - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteStringLengthTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -38,7 +27,7 @@ public SqliteStringLengthTranslator(ISqlExpressionFactory sqlExpressionFactory) IDiagnosticsLogger logger) => member.DeclaringType == typeof(string) && member.Name == nameof(string.Length) - ? _sqlExpressionFactory.Function( + ? sqlExpressionFactory.Function( "length", [instance!], nullable: true, diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringMethodTranslator.cs index c806954ee51..d98656ea86e 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringMethodTranslator.cs @@ -13,95 +13,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteStringMethodTranslator : IMethodCallTranslator +public class SqliteStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo IndexOfMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string)])!; - - private static readonly MethodInfo IndexOfMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char)])!; - - private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionString - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(string), typeof(int)])!; - - private static readonly MethodInfo IndexOfMethodInfoWithStartingPositionChar - = typeof(string).GetRuntimeMethod(nameof(string.IndexOf), [typeof(char), typeof(int)])!; - - private static readonly MethodInfo ReplaceMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(string), typeof(string)])!; - - private static readonly MethodInfo ReplaceMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.Replace), [typeof(char), typeof(char)])!; - - private static readonly MethodInfo ToLowerMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.ToLower), Type.EmptyTypes)!; - - private static readonly MethodInfo ToUpperMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.ToUpper), Type.EmptyTypes)!; - - private static readonly MethodInfo SubstringMethodInfoWithOneArg - = typeof(string).GetRuntimeMethod(nameof(string.Substring), [typeof(int)])!; - - private static readonly MethodInfo SubstringMethodInfoWithTwoArgs - = typeof(string).GetRuntimeMethod(nameof(string.Substring), [typeof(int), typeof(int)])!; - - private static readonly MethodInfo IsNullOrWhiteSpaceMethodInfo - = typeof(string).GetRuntimeMethod(nameof(string.IsNullOrWhiteSpace), [typeof(string)])!; - - // Method defined in netcoreapp2.0 only - private static readonly MethodInfo TrimStartMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), Type.EmptyTypes)!; - - private static readonly MethodInfo TrimStartMethodInfoWithCharArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [typeof(char)])!; - - private static readonly MethodInfo TrimEndMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), Type.EmptyTypes)!; - - private static readonly MethodInfo TrimEndMethodInfoWithCharArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), [typeof(char)])!; - - private static readonly MethodInfo TrimMethodInfoWithoutArgs - = typeof(string).GetRuntimeMethod(nameof(string.Trim), Type.EmptyTypes)!; - - private static readonly MethodInfo TrimMethodInfoWithCharArg - = typeof(string).GetRuntimeMethod(nameof(string.Trim), [typeof(char)])!; - - // Method defined in netstandard2.0 - private static readonly MethodInfo TrimStartMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimStart), [typeof(char[])])!; - - private static readonly MethodInfo TrimEndMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.TrimEnd), [typeof(char[])])!; - - private static readonly MethodInfo TrimMethodInfoWithCharArrayArg - = typeof(string).GetRuntimeMethod(nameof(string.Trim), [typeof(char[])])!; - - private static readonly MethodInfo ContainsMethodInfoString - = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!; - - private static readonly MethodInfo ContainsMethodInfoChar - = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!; - - private static readonly MethodInfo FirstOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single(m => m.Name == nameof(Enumerable.FirstOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - private static readonly MethodInfo LastOrDefaultMethodInfoWithoutArgs - = typeof(Enumerable).GetRuntimeMethods().Single(m => m.Name == nameof(Enumerable.LastOrDefault) - && m.GetParameters().Length == 1).MakeGenericMethod(typeof(char)); - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -114,201 +27,192 @@ public SqliteStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (instance != null) + if (method.DeclaringType == typeof(string)) { - if (IndexOfMethodInfoString.Equals(method) || IndexOfMethodInfoChar.Equals(method)) - { - var argument = arguments[0]; - var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, argument); - - return _sqlExpressionFactory.Subtract( - _sqlExpressionFactory.Function( - "instr", - [ - _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping), - _sqlExpressionFactory.ApplyTypeMapping( - argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping) - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[2], - method.ReturnType), - _sqlExpressionFactory.Constant(1)); - } - - if (IndexOfMethodInfoWithStartingPositionString.Equals(method) || IndexOfMethodInfoWithStartingPositionChar.Equals(method)) + if (instance is not null) { - var argument = arguments[0]; - var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, argument); - instance = _sqlExpressionFactory.Function( - "substr", - [instance, _sqlExpressionFactory.Add(arguments[1], _sqlExpressionFactory.Constant(1))], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[2], - method.ReturnType, - instance.TypeMapping); - - return _sqlExpressionFactory.Add( - _sqlExpressionFactory.Subtract( - _sqlExpressionFactory.Function( - "instr", - [ - _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping), - _sqlExpressionFactory.ApplyTypeMapping( - argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping) - ], + return method.Name switch + { + nameof(string.IndexOf) when arguments is [var arg] + => TranslateIndexOf(instance, arg), + nameof(string.IndexOf) when arguments is [var arg, var startIndex] + => TranslateIndexOfWithStartingPosition(instance, arg, startIndex), + nameof(string.Replace) when arguments is [var oldValue, var newValue] + => TranslateReplace(instance, oldValue, newValue), + nameof(string.ToLower) when arguments is [] + => TranslateSimpleFunction("lower", instance), + nameof(string.ToUpper) when arguments is [] + => TranslateSimpleFunction("upper", instance), + nameof(string.Substring) when arguments is [var startIndex] + => sqlExpressionFactory.Function( + "substr", + [instance, sqlExpressionFactory.Add(startIndex, sqlExpressionFactory.Constant(1))], nullable: true, argumentsPropagateNullability: Statics.TrueArrays[2], - method.ReturnType), - _sqlExpressionFactory.Constant(1)), - arguments[1]); - } - - if (ReplaceMethodInfoString.Equals(method) || ReplaceMethodInfoChar.Equals(method)) - { - var firstArgument = arguments[0]; - var secondArgument = arguments[1]; - var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, firstArgument, secondArgument); - - return _sqlExpressionFactory.Function( - "replace", - [ - _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping), - _sqlExpressionFactory.ApplyTypeMapping( - firstArgument, firstArgument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping), - _sqlExpressionFactory.ApplyTypeMapping( - secondArgument, secondArgument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping) - ], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType, - stringTypeMapping); + method.ReturnType, + instance.TypeMapping), + nameof(string.Substring) when arguments is [var startIndex, var length] + => sqlExpressionFactory.Function( + "substr", + [instance, sqlExpressionFactory.Add(startIndex, sqlExpressionFactory.Constant(1)), length], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[3], + method.ReturnType, + instance.TypeMapping), + nameof(string.TrimStart) when arguments is [] or [_] + => ProcessTrimMethod(instance, arguments, "ltrim"), + nameof(string.TrimEnd) when arguments is [] or [_] + => ProcessTrimMethod(instance, arguments, "rtrim"), + nameof(string.Trim) when arguments is [] or [_] + => ProcessTrimMethod(instance, arguments, "trim"), + nameof(string.Contains) when arguments is [var pattern] + => TranslateContains(instance, pattern), + _ => null + }; } - if (ToLowerMethodInfo.Equals(method) - || ToUpperMethodInfo.Equals(method)) + // Static string methods + return method.Name switch { - return _sqlExpressionFactory.Function( - ToLowerMethodInfo.Equals(method) ? "lower" : "upper", - [instance], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - method.ReturnType, - instance.TypeMapping); - } + nameof(string.IsNullOrWhiteSpace) when arguments is [var arg] + => sqlExpressionFactory.OrElse( + sqlExpressionFactory.IsNull(arg), + sqlExpressionFactory.Equal( + sqlExpressionFactory.Function( + "trim", + [arg], + nullable: true, + argumentsPropagateNullability: [true], + arg.Type, + arg.TypeMapping), + sqlExpressionFactory.Constant(string.Empty))), + _ => null + }; + } - if (SubstringMethodInfoWithOneArg.Equals(method)) - { - return _sqlExpressionFactory.Function( + if (method.DeclaringType == typeof(Enumerable) + && method.Name is nameof(Enumerable.FirstOrDefault) or nameof(Enumerable.LastOrDefault) + && method.IsGenericMethod + && method.GetGenericArguments()[0] == typeof(char) + && arguments is [var source]) + { + return method.Name == nameof(Enumerable.FirstOrDefault) + ? sqlExpressionFactory.Function( "substr", - [instance, _sqlExpressionFactory.Add(arguments[0], _sqlExpressionFactory.Constant(1))], + [source, sqlExpressionFactory.Constant(1), sqlExpressionFactory.Constant(1)], nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[2], - method.ReturnType, - instance.TypeMapping); - } - - if (SubstringMethodInfoWithTwoArgs.Equals(method)) - { - return _sqlExpressionFactory.Function( + argumentsPropagateNullability: Statics.TrueArrays[3], + method.ReturnType) + : sqlExpressionFactory.Function( "substr", - [instance, _sqlExpressionFactory.Add(arguments[0], _sqlExpressionFactory.Constant(1)), arguments[1]], + [ + source, + sqlExpressionFactory.Function( + "length", + [source], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + typeof(int)), + sqlExpressionFactory.Constant(1) + ], nullable: true, argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType, - instance.TypeMapping); - } + method.ReturnType); + } - if (TrimStartMethodInfoWithoutArgs.Equals(method) - || TrimStartMethodInfoWithCharArg.Equals(method) - || TrimStartMethodInfoWithCharArrayArg.Equals(method)) - { - return ProcessTrimMethod(instance, arguments, "ltrim"); - } + return null; + } - if (TrimEndMethodInfoWithoutArgs.Equals(method) - || TrimEndMethodInfoWithCharArg.Equals(method) - || TrimEndMethodInfoWithCharArrayArg.Equals(method)) - { - return ProcessTrimMethod(instance, arguments, "rtrim"); - } + private SqlExpression TranslateIndexOf(SqlExpression instance, SqlExpression argument) + { + var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, argument); - if (TrimMethodInfoWithoutArgs.Equals(method) - || TrimMethodInfoWithCharArg.Equals(method) - || TrimMethodInfoWithCharArrayArg.Equals(method)) - { - return ProcessTrimMethod(instance, arguments, "trim"); - } + return sqlExpressionFactory.Subtract( + sqlExpressionFactory.Function( + "instr", + [ + sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping), + sqlExpressionFactory.ApplyTypeMapping( + argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping) + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[2], + typeof(int)), + sqlExpressionFactory.Constant(1)); + } - if (ContainsMethodInfoString.Equals(method) || ContainsMethodInfoChar.Equals(method)) - { - var pattern = arguments[0]; - var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, pattern); + private SqlExpression TranslateIndexOfWithStartingPosition( + SqlExpression instance, SqlExpression argument, SqlExpression startIndex) + { + var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, argument); + instance = sqlExpressionFactory.Function( + "substr", + [instance, sqlExpressionFactory.Add(startIndex, sqlExpressionFactory.Constant(1))], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[2], + typeof(string), + instance.TypeMapping); - instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping); - pattern = _sqlExpressionFactory.ApplyTypeMapping( - pattern, pattern.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); + return sqlExpressionFactory.Add( + sqlExpressionFactory.Subtract( + sqlExpressionFactory.Function( + "instr", + [ + sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping), + sqlExpressionFactory.ApplyTypeMapping( + argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping) + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[2], + typeof(int)), + sqlExpressionFactory.Constant(1)), + startIndex); + } - return - _sqlExpressionFactory.GreaterThan( - _sqlExpressionFactory.Function( - "instr", - [instance, pattern], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[2], - typeof(int)), - _sqlExpressionFactory.Constant(0)); - } - } + private SqlExpression TranslateReplace(SqlExpression instance, SqlExpression oldValue, SqlExpression newValue) + { + var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, oldValue, newValue); + + return sqlExpressionFactory.Function( + "replace", + [ + sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping), + sqlExpressionFactory.ApplyTypeMapping( + oldValue, oldValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping), + sqlExpressionFactory.ApplyTypeMapping( + newValue, newValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping) + ], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[3], + typeof(string), + stringTypeMapping); + } - if (IsNullOrWhiteSpaceMethodInfo.Equals(method)) - { - var argument = arguments[0]; + private SqlExpression TranslateSimpleFunction(string functionName, SqlExpression instance) + => sqlExpressionFactory.Function( + functionName, + [instance], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + instance.Type, + instance.TypeMapping); - return _sqlExpressionFactory.OrElse( - _sqlExpressionFactory.IsNull(argument), - _sqlExpressionFactory.Equal( - _sqlExpressionFactory.Function( - "trim", - [argument], - nullable: true, - argumentsPropagateNullability: [true], - argument.Type, - argument.TypeMapping), - _sqlExpressionFactory.Constant(string.Empty))); - } + private SqlExpression TranslateContains(SqlExpression instance, SqlExpression pattern) + { + var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, pattern); - if (FirstOrDefaultMethodInfoWithoutArgs.Equals(method)) - { - var argument = arguments[0]; - return _sqlExpressionFactory.Function( - "substr", - [argument, _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType); - } + instance = sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping); + pattern = sqlExpressionFactory.ApplyTypeMapping( + pattern, pattern.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); - if (LastOrDefaultMethodInfoWithoutArgs.Equals(method)) - { - var argument = arguments[0]; - return _sqlExpressionFactory.Function( - "substr", - [ - argument, - _sqlExpressionFactory.Function( - "length", - [argument], - nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[1], - typeof(int)), - _sqlExpressionFactory.Constant(1) - ], + return sqlExpressionFactory.GreaterThan( + sqlExpressionFactory.Function( + "instr", + [instance, pattern], nullable: true, - argumentsPropagateNullability: Statics.TrueArrays[3], - method.ReturnType); - } - - return null; + argumentsPropagateNullability: Statics.TrueArrays[2], + typeof(int)), + sqlExpressionFactory.Constant(0)); } private SqlExpression? ProcessTrimMethod(SqlExpression instance, IReadOnlyList arguments, string functionName) @@ -340,11 +244,11 @@ public SqliteStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) if (charactersToTrim.Count > 0) { - sqlArguments.Add(_sqlExpressionFactory.Constant(new string(charactersToTrim.ToArray()), typeMapping)); + sqlArguments.Add(sqlExpressionFactory.Constant(new string(charactersToTrim.ToArray()), typeMapping)); } } - return _sqlExpressionFactory.Function( + return sqlExpressionFactory.Function( functionName, sqlArguments, nullable: true, diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteSubstrMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteSubstrMethodTranslator.cs index a9b2883a72b..92464ba5f99 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteSubstrMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteSubstrMethodTranslator.cs @@ -12,26 +12,8 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class SqliteSubstrMethodTranslator : IMethodCallTranslator +public class SqliteSubstrMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IMethodCallTranslator { - private static readonly MethodInfo MethodInfo = typeof(SqliteDbFunctionsExtensions) - .GetMethod(nameof(SqliteDbFunctionsExtensions.Substr), [typeof(DbFunctions), typeof(byte[]), typeof(int)])!; - - private static readonly MethodInfo MethodInfoWithLength = typeof(SqliteDbFunctionsExtensions) - .GetMethod( - nameof(SqliteDbFunctionsExtensions.Substr), [typeof(DbFunctions), typeof(byte[]), typeof(int), typeof(int)])!; - - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public SqliteSubstrMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) - => _sqlExpressionFactory = sqlExpressionFactory; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -44,16 +26,18 @@ public SqliteSubstrMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { - if (method.Equals(MethodInfo) - || method.Equals(MethodInfoWithLength)) + if (method.DeclaringType == typeof(SqliteDbFunctionsExtensions) + && method.Name == nameof(SqliteDbFunctionsExtensions.Substr) + && arguments is [_, var source, ..]) { - return _sqlExpressionFactory.Function( + var funcArgs = arguments.Skip(1).ToArray(); + return sqlExpressionFactory.Function( "substr", - arguments.Skip(1), + funcArgs, nullable: true, - arguments.Skip(1).Select(_ => true).ToArray(), + funcArgs.Select(_ => true).ToArray(), typeof(byte[]), - arguments[1].TypeMapping); + source.TypeMapping); } return null;