diff --git a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs index 3dbbedce16..7867846946 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs @@ -1,121 +1,124 @@ -//------------------------------------------------------------ -// Copyright (c) Microsoft Corporation. All rights reserved. -//------------------------------------------------------------ - -namespace Microsoft.Azure.Cosmos.Linq -{ - using System; - using System.Collections.Generic; - using System.Globalization; - using System.Linq.Expressions; - using Microsoft.Azure.Cosmos; - using Microsoft.Azure.Cosmos.Spatial; - using Microsoft.Azure.Cosmos.SqlObjects; - using Microsoft.Azure.Documents; - - internal abstract class BuiltinFunctionVisitor - { - public SqlScalarExpression Visit(MethodCallExpression methodCallExpression, TranslationContext context) - { - SqlScalarExpression result = this.VisitExplicit(methodCallExpression, context); - if (result != null) - { - return result; - } - - result = this.VisitImplicit(methodCallExpression, context); - if (result != null) - { - return result; - } - - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, methodCallExpression.Method.Name)); - } - - public static SqlScalarExpression VisitBuiltinFunctionCall(MethodCallExpression methodCallExpression, TranslationContext context) - { - Type declaringType; - - // Method could be an extension method - if (methodCallExpression.Method.IsStatic && methodCallExpression.Method.IsExtensionMethod()) - { - if (methodCallExpression.Arguments.Count < 1) - { - // Extension methods should has at least 1 argument, this should never happen - // Throwing ArgumentException instead of assert - throw new ArgumentException(); - } - - declaringType = methodCallExpression.Arguments[0].Type; - - if (methodCallExpression.Method.DeclaringType.GeUnderlyingSystemType() == typeof(CosmosLinqExtensions)) - { - // CosmosLinq Extensions can be RegexMatch, DocumentId or Type check functions (IsString, IsBool, etc.) - if ((methodCallExpression.Method.Name == nameof(CosmosLinqExtensions.RegexMatch)) || - (methodCallExpression.Method.Name == nameof(CosmosLinqExtensions.FullTextContains)) || - (methodCallExpression.Method.Name == nameof(CosmosLinqExtensions.FullTextContainsAll)) || - (methodCallExpression.Method.Name == nameof(CosmosLinqExtensions.FullTextContainsAny))) - { - return StringBuiltinFunctions.Visit(methodCallExpression, context); +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Linq +{ + using System; + using System.Collections.Generic; + using System.Globalization; + using System.Linq.Expressions; + using Microsoft.Azure.Cosmos; + using Microsoft.Azure.Cosmos.Spatial; + using Microsoft.Azure.Cosmos.SqlObjects; + using Microsoft.Azure.Documents; + + internal abstract class BuiltinFunctionVisitor + { + public SqlScalarExpression Visit(MethodCallExpression methodCallExpression, TranslationContext context) + { + SqlScalarExpression result = this.VisitExplicit(methodCallExpression, context); + if (result != null) + { + return result; + } + + result = this.VisitImplicit(methodCallExpression, context); + if (result != null) + { + return result; + } + + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, methodCallExpression.Method.Name)); + } + + public static SqlScalarExpression VisitBuiltinFunctionCall(MethodCallExpression methodCallExpression, TranslationContext context) + { + Type declaringType; + bool isExtensionMethod = methodCallExpression.Method.IsExtensionMethod(); + // Method could be an extension method + // RRF doesn't have "this" qualifier, so it's not considered an extension method by the compiler, and so needed to be checked separately + if (methodCallExpression.Method.IsStatic && + (methodCallExpression.Method.IsExtensionMethod() + || methodCallExpression.Method.Name.Equals(nameof(CosmosLinqExtensions.RRF)))) + { + if (methodCallExpression.Arguments.Count < 1) + { + // Extension methods should has at least 1 argument, this should never happen + // Throwing ArgumentException instead of assert + throw new ArgumentException(); + } + + declaringType = methodCallExpression.Arguments[0].Type; + + if (methodCallExpression.Method.DeclaringType.GeUnderlyingSystemType() == typeof(CosmosLinqExtensions)) + { + // CosmosLinq Extensions can be RegexMatch, DocumentId or Type check functions (IsString, IsBool, etc.) + switch (methodCallExpression.Method.Name) + { + case nameof(CosmosLinqExtensions.RegexMatch): + case nameof(CosmosLinqExtensions.FullTextContains): + case nameof(CosmosLinqExtensions.FullTextContainsAll): + case nameof(CosmosLinqExtensions.FullTextContainsAny): + return StringBuiltinFunctions.Visit(methodCallExpression, context); + case nameof(CosmosLinqExtensions.DocumentId): + case nameof(CosmosLinqExtensions.RRF): + case nameof(CosmosLinqExtensions.FullTextScore): + return OtherBuiltinSystemFunctions.Visit(methodCallExpression, context); + default: + return TypeCheckFunctions.Visit(methodCallExpression, context); } - - if (methodCallExpression.Method.Name == nameof(CosmosLinqExtensions.DocumentId)) - { - return OtherBuiltinSystemFunctions.Visit(methodCallExpression, context); - } - - return TypeCheckFunctions.Visit(methodCallExpression, context); - } - } - else - { - declaringType = methodCallExpression.Method.DeclaringType; - } - - // Check order matters, some extension methods work for both strings and arrays - - // Math functions - if (declaringType == typeof(Math)) - { - return MathBuiltinFunctions.Visit(methodCallExpression, context); - } - - // ToString with String and Guid only becomes passthrough - if (methodCallExpression.Method.Name == "ToString" && - methodCallExpression.Arguments.Count == 0 && - methodCallExpression.Object != null && - ((methodCallExpression.Object.Type == typeof(string)) || - (methodCallExpression.Object.Type == typeof(Guid)))) - { - return ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Object, context); - } - - // String functions or ToString with Objects that are not strings and guids - if ((declaringType == typeof(string)) || - (methodCallExpression.Method.Name == "ToString" && - methodCallExpression.Arguments.Count == 0 && - methodCallExpression.Object != null)) - { - return StringBuiltinFunctions.Visit(methodCallExpression, context); - } - - // Array functions - if (declaringType.IsEnumerable()) - { - return ArrayBuiltinFunctions.Visit(methodCallExpression, context); - } - - // Spatial functions - if (typeof(Geometry).IsAssignableFrom(declaringType)) - { - return SpatialBuiltinFunctions.Visit(methodCallExpression, context); - } - - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, methodCallExpression.Method.Name)); - } - - protected abstract SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context); - - protected abstract SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context); - } -} + } + } + else + { + declaringType = methodCallExpression.Method.DeclaringType; + } + + // Check order matters, some extension methods work for both strings and arrays + + // Math functions + if (declaringType == typeof(Math)) + { + return MathBuiltinFunctions.Visit(methodCallExpression, context); + } + + // ToString with String and Guid only becomes passthrough + if (methodCallExpression.Method.Name == "ToString" && + methodCallExpression.Arguments.Count == 0 && + methodCallExpression.Object != null && + ((methodCallExpression.Object.Type == typeof(string)) || + (methodCallExpression.Object.Type == typeof(Guid)))) + { + return ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Object, context); + } + + // String functions or ToString with Objects that are not strings and guids + if ((declaringType == typeof(string)) || + (methodCallExpression.Method.Name == "ToString" && + methodCallExpression.Arguments.Count == 0 && + methodCallExpression.Object != null)) + { + return StringBuiltinFunctions.Visit(methodCallExpression, context); + } + + // Array functions + if (declaringType.IsEnumerable()) + { + return ArrayBuiltinFunctions.Visit(methodCallExpression, context); + } + + // Spatial functions + if (typeof(Geometry).IsAssignableFrom(declaringType)) + { + return SpatialBuiltinFunctions.Visit(methodCallExpression, context); + } + + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, methodCallExpression.Method.Name)); + } + + protected abstract SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context); + + protected abstract SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context); + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs index ea748b9a74..684a68f770 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs @@ -6,12 +6,88 @@ namespace Microsoft.Azure.Cosmos.Linq { using System; using System.Collections.Generic; + using System.Collections.Immutable; + using System.Collections.ObjectModel; using System.Globalization; using System.Linq.Expressions; using Microsoft.Azure.Cosmos.SqlObjects; internal static class OtherBuiltinSystemFunctions { + private class RRFVisit : SqlBuiltinFunctionVisitor + { + public RRFVisit() + : base("RRF", + true, + new List() + { + new Type[]{typeof(Func[])} + }) + { + } + + protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context) + { + if (methodCallExpression.Arguments.Count == 1 + && methodCallExpression.Arguments[0] is NewArrayExpression argumentsExpressions) + { + // For RRF, We don't need to care about the first argument, it is the object itself and have no relevance to the computation + ReadOnlyCollection functionListExpression = argumentsExpressions.Expressions; + List arguments = new List(); + foreach (Expression argument in functionListExpression) + { + arguments.Add(ExpressionToSql.VisitScalarExpression(argument, context)); + } + + return SqlFunctionCallScalarExpression.CreateBuiltin(SqlFunctionCallScalarExpression.Names.RRF, arguments.ToImmutableArray()); + } + + return null; + } + + protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context) + { + return null; + } + } + + private class FullTextScoreVisit : SqlBuiltinFunctionVisitor + { + public FullTextScoreVisit() + : base("FullTextScore", + true, + new List() + { + new Type[]{typeof(object), typeof(string[])} + }) + { + } + + protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context) + { + if (methodCallExpression.Arguments.Count == 2 + && methodCallExpression.Arguments[1] is ConstantExpression stringListArgumentExpression + && ExpressionToSql.VisitConstant(stringListArgumentExpression, context) is SqlArrayCreateScalarExpression arrayScalarExpressions) + { + List arguments = new List + { + ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[0], context) + }; + + arguments.AddRange(arrayScalarExpressions.Items); + + return SqlFunctionCallScalarExpression.CreateBuiltin(SqlFunctionCallScalarExpression.Names.FullTextScore, arguments.ToImmutableArray()); + } + + return null; + } + + protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context) + { + return null; + } + } + private static Dictionary FunctionsDefinitions { get; set; } static OtherBuiltinSystemFunctions() @@ -24,7 +100,9 @@ static OtherBuiltinSystemFunctions() argumentLists: new List() { new Type[]{typeof(object)}, - }) + }), + [nameof(CosmosLinqExtensions.RRF)] = new RRFVisit(), + [nameof(CosmosLinqExtensions.FullTextScore)] = new FullTextScoreVisit(), }; } diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs index 6b84c04198..14eab9aa33 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs @@ -1,246 +1,248 @@ -//------------------------------------------------------------ -// Copyright (c) Microsoft Corporation. All rights reserved. -//------------------------------------------------------------ - -namespace Microsoft.Azure.Cosmos.Linq -{ - using System; - using System.Collections.Generic; - using System.Diagnostics; - using System.Linq; - using System.Linq.Expressions; - using System.Reflection; - using System.Threading; - using System.Threading.Tasks; - using Microsoft.Azure.Cosmos.Diagnostics; - using Microsoft.Azure.Cosmos.Tracing; - - /// - /// This class provides extension methods for cosmos LINQ code. - /// - public static class CosmosLinqExtensions +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Linq +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Diagnostics; + using System.Linq; + using System.Linq.Expressions; + using System.Reflection; + using System.Runtime.CompilerServices; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.Cosmos.Diagnostics; + using Microsoft.Azure.Cosmos.Tracing; + + /// + /// This class provides extension methods for cosmos LINQ code. + /// + public static class CosmosLinqExtensions { - /// - /// Returns the integer identifier corresponding to a specific item within a physical partition. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// The root object - /// Returns the integer identifier corresponding to a specific item within a physical partition. - /// - /// - /// root.DocumentId()); - /// ]]> - /// - /// - public static int DocumentId(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Returns a Boolean value indicating if the type of the specified expression is an array. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if the type of the specified expression is an array; otherwise, false. - /// - /// - /// document.Names.IsArray()); - /// ]]> - /// - /// - public static bool IsArray(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Returns a Boolean value indicating if the type of the specified expression is a boolean. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if the type of the specified expression is a boolean; otherwise, false. - /// - /// - /// document.IsRegistered.IsBool()); - /// ]]> - /// - /// - public static bool IsBool(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Determines if a certain property is defined or not. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if this property is defined otherwise returns false. - /// - /// - /// document.Name.IsDefined()); - /// ]]> - /// - /// - public static bool IsDefined(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Determines if a certain property is null or not. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if this property is null otherwise returns false. - /// - /// - /// document.Name.IsNull()); - /// ]]> - /// - /// - public static bool IsNull(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Returns a Boolean value indicating if the type of the specified expression is a number. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if the type of the specified expression is a number; otherwise, false. - /// - /// - /// document.Age.IsNumber()); - /// ]]> - /// - /// - public static bool IsNumber(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Returns a Boolean value indicating if the type of the specified expression is an object. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if the type of the specified expression is an object; otherwise, false. - /// - /// - /// document.Address.IsObject()); - /// ]]> - /// - /// - public static bool IsObject(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Determines if a certain property is of primitive JSON type. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if this property is null otherwise returns false. - /// - /// Primitive JSON types (Double, String, Boolean and Null) - /// - /// - /// - /// document.Name.IsPrimitive()); - /// ]]> - /// - /// - public static bool IsPrimitive(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Returns a Boolean value indicating if the type of the specified expression is a string. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// Returns true if the type of the specified expression is a string; otherwise, false. - /// - /// - /// document.Name.IsString()); - /// ]]> - /// - /// - public static bool IsString(this object obj) - { - throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); - } - - /// - /// Returns a Boolean value indicating if the specified expression matches the supplied regex pattern. - /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/regexmatch. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// A string expression with a regular expression defined to use when searching. - /// Returns true if the string matches the regex expressions; otherwise, false. - /// - /// - /// document.Name.RegexMatch()); - /// ]]> - /// - /// - public static bool RegexMatch(this object obj, string regularExpression) - { - throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented); - } - - /// - /// Returns a Boolean value indicating if the specified expression matches the supplied regex pattern. - /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/regexmatch. - /// This method is to be used in LINQ expressions only and will be evaluated on server. - /// There's no implementation provided in the client library. - /// - /// - /// A string expression with a regular expression defined to use when searching. - /// An optional string expression with the selected modifiers to use with the regular expression. - /// Returns true if the string matches the regex expressions; otherwise, false. - /// - /// - /// document.Name.RegexMatch(, )); - /// ]]> - /// - /// - public static bool RegexMatch(this object obj, string regularExpression, string searchModifier) - { - throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented); + /// + /// Returns the integer identifier corresponding to a specific item within a physical partition. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// The root object + /// Returns the integer identifier corresponding to a specific item within a physical partition. + /// + /// + /// root.DocumentId()); + /// ]]> + /// + /// + public static int DocumentId(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Returns a Boolean value indicating if the type of the specified expression is an array. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if the type of the specified expression is an array; otherwise, false. + /// + /// + /// document.Names.IsArray()); + /// ]]> + /// + /// + public static bool IsArray(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Returns a Boolean value indicating if the type of the specified expression is a boolean. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if the type of the specified expression is a boolean; otherwise, false. + /// + /// + /// document.IsRegistered.IsBool()); + /// ]]> + /// + /// + public static bool IsBool(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Determines if a certain property is defined or not. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if this property is defined otherwise returns false. + /// + /// + /// document.Name.IsDefined()); + /// ]]> + /// + /// + public static bool IsDefined(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Determines if a certain property is null or not. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if this property is null otherwise returns false. + /// + /// + /// document.Name.IsNull()); + /// ]]> + /// + /// + public static bool IsNull(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Returns a Boolean value indicating if the type of the specified expression is a number. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if the type of the specified expression is a number; otherwise, false. + /// + /// + /// document.Age.IsNumber()); + /// ]]> + /// + /// + public static bool IsNumber(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Returns a Boolean value indicating if the type of the specified expression is an object. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if the type of the specified expression is an object; otherwise, false. + /// + /// + /// document.Address.IsObject()); + /// ]]> + /// + /// + public static bool IsObject(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Determines if a certain property is of primitive JSON type. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if this property is null otherwise returns false. + /// + /// Primitive JSON types (Double, String, Boolean and Null) + /// + /// + /// + /// document.Name.IsPrimitive()); + /// ]]> + /// + /// + public static bool IsPrimitive(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Returns a Boolean value indicating if the type of the specified expression is a string. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// Returns true if the type of the specified expression is a string; otherwise, false. + /// + /// + /// document.Name.IsString()); + /// ]]> + /// + /// + public static bool IsString(this object obj) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Returns a Boolean value indicating if the specified expression matches the supplied regex pattern. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/regexmatch. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// A string expression with a regular expression defined to use when searching. + /// Returns true if the string matches the regex expressions; otherwise, false. + /// + /// + /// document.Name.RegexMatch()); + /// ]]> + /// + /// + public static bool RegexMatch(this object obj, string regularExpression) + { + throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented); + } + + /// + /// Returns a Boolean value indicating if the specified expression matches the supplied regex pattern. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/regexmatch. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// A string expression with a regular expression defined to use when searching. + /// An optional string expression with the selected modifiers to use with the regular expression. + /// Returns true if the string matches the regex expressions; otherwise, false. + /// + /// + /// document.Name.RegexMatch(, )); + /// ]]> + /// + /// + public static bool RegexMatch(this object obj, string regularExpression, string searchModifier) + { + throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented); } /// /// Returns a boolean indicating whether the keyword string expression is contained in a specified property path. - /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextcontains. - /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextcontains. + /// This method is to be used in LINQ expressions only and will be evaluated on server. /// There's no implementation provided in the client library. /// /// @@ -248,8 +250,8 @@ public static bool RegexMatch(this object obj, string regularExpression, string /// Returns true if a given string is contained in the specified property of a document. /// /// - /// document.Name.FullTextContains()); + /// document.Name.FullTextContains()); /// ]]> /// /// @@ -260,8 +262,8 @@ public static bool FullTextContains(this object obj, string search) /// /// Returns a boolean indicating whether all of the provided string expressions are contained in a specified property path. - /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextcontainsall. - /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextcontainsall. + /// This method is to be used in LINQ expressions only and will be evaluated on server. /// There's no implementation provided in the client library. /// /// @@ -269,8 +271,8 @@ public static bool FullTextContains(this object obj, string search) /// Returns true if all of the given strings are contained in the specified property of a document. /// /// - /// document.Name.FullTextContainsAll(, , , ...)); + /// document.Name.FullTextContainsAll(, , , ...)); /// ]]> /// /// @@ -281,8 +283,8 @@ public static bool FullTextContainsAll(this object obj, params string[] searches /// /// Returns a boolean indicating whether any of the provided string expressions are contained in a specified property path. - /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextcontainsany. - /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextcontainsany. + /// This method is to be used in LINQ expressions only and will be evaluated on server. /// There's no implementation provided in the client library. /// /// @@ -290,8 +292,8 @@ public static bool FullTextContainsAll(this object obj, params string[] searches /// Returns true if any of the given strings are contained in the specified property of a document. /// /// - /// document.Name.FullTextContainsAny(, , , ...)); + /// document.Name.FullTextContainsAny(, , , ...)); /// ]]> /// /// @@ -300,650 +302,723 @@ public static bool FullTextContainsAny(this object obj, params string[] searches throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented); } - /// - /// This method generate query definition from LINQ query. - /// - /// the type of object to query. - /// the IQueryable{T} to be converted. - /// Dictionary containing parameter value and name for parameterized query - /// The queryDefinition which can be used in query execution. - /// - /// This example shows how to generate query definition from LINQ. - /// - /// - /// queryable = container.GetItemsQueryIterator(allowSynchronousQueryExecution = true) - /// .Where(t => b.id.contains("test")); - /// QueryDefinition queryDefinition = queryable.ToQueryDefinition(); - /// ]]> - /// - /// -#if PREVIEW - public -#else - internal -#endif - static QueryDefinition ToQueryDefinition(this IQueryable query, IDictionary namedParameters) - { - if (namedParameters == null) - { - throw new ArgumentException("namedParameters dictionary cannot be empty for this overload, please use ToQueryDefinition(IQueryable query) instead", nameof(namedParameters)); - } - - if (query is CosmosLinqQuery linqQuery) - { - return linqQuery.ToQueryDefinition(namedParameters); - } - - throw new ArgumentException("ToQueryDefinition is only supported on Cosmos LINQ query operations", nameof(query)); - } - - /// - /// This method generate query definition from LINQ query. - /// - /// the type of object to query. - /// the IQueryable{T} to be converted. - /// The queryDefinition which can be used in query execution. - /// - /// This example shows how to generate query definition from LINQ. - /// - /// - /// queryable = container.GetItemsQueryIterator(allowSynchronousQueryExecution = true) - /// .Where(t => b.id.contains("test")); - /// QueryDefinition queryDefinition = queryable.ToQueryDefinition(); - /// ]]> - /// - /// - public static QueryDefinition ToQueryDefinition(this IQueryable query) - { - if (query is CosmosLinqQuery linqQuery) - { - return linqQuery.ToQueryDefinition(); - - } - - throw new ArgumentException("ToQueryDefinition is only supported on Cosmos LINQ query operations", nameof(query)); - } - - /// - /// This extension method gets the FeedIterator from LINQ IQueryable to execute query asynchronously. - /// This will create the fresh new FeedIterator when called. - /// - /// the type of object to query. - /// the IQueryable{T} to be converted. - /// An iterator to go through the items. - /// - /// This example shows how to get FeedIterator from LINQ. - /// - /// - /// linqQueryable = this.Container.GetItemLinqQueryable(); - /// using (FeedIterator setIterator = linqQueryable.Where(item => (item.taskNum < 100)).ToFeedIterator() - /// ]]> - /// - /// - public static FeedIterator ToFeedIterator(this IQueryable query) - { - if (!(query is CosmosLinqQuery linqQuery)) - { - throw new ArgumentOutOfRangeException(nameof(linqQuery), "ToFeedIterator is only supported on Cosmos LINQ query operations"); - } - - return linqQuery.ToFeedIterator(); - } - - /// - /// This extension method gets the FeedIterator from LINQ IQueryable to execute query asynchronously. - /// This will create the fresh new FeedIterator when called. - /// - /// the type of object to query. - /// the IQueryable{T} to be converted. - /// An iterator to go through the items. - /// - /// This example shows how to get FeedIterator from LINQ. - /// - /// - /// linqQueryable = this.Container.GetItemLinqQueryable(); - /// using (FeedIterator setIterator = linqQueryable.Where(item => (item.taskNum < 100)).ToStreamIterator()) - /// ]]> - /// - /// - public static FeedIterator ToStreamIterator(this IQueryable query) - { - if (!(query is CosmosLinqQuery linqQuery)) - { - throw new ArgumentOutOfRangeException(nameof(linqQuery), "ToStreamFeedIterator is only supported on cosmos LINQ query operations"); - } - - return linqQuery.ToStreamIterator(); - } - - /// - /// Returns the maximum value in a generic . - /// - /// The type of the elements of source. - /// A sequence of values to determine the maximum of. - /// The cancellation token. - /// The maximum value in the sequence. - public static Task> MaxAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Max()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, TSource>(Queryable.Max), - source.Expression), - cancellationToken); - } - - /// - /// Returns the minimum value in a generic . - /// - /// The type of the elements of source. - /// A sequence of values to determine the minimum of. - /// The cancellation token. - /// The minimum value in the sequence. - public static Task> MinAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Min()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, TSource>(Queryable.Min), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, decimal>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, decimal?>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double?>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, float>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, float?>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double?>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the average of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> AverageAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Average()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double?>(Queryable.Average), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, decimal>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, decimal?>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, double?>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, float>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, float?>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, int>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, int?>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, long>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Computes the sum of a sequence of values. - /// - /// A sequence of values to calculate the average of. - /// The cancellation token. - /// The average value in the sequence. - public static Task> SumAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Sum()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, long?>(Queryable.Sum), - source.Expression), - cancellationToken); - } - - /// - /// Returns the number of elements in a sequence. - /// - /// The type of the elements of source. - /// The sequence that contains the elements to be counted. - /// The cancellation token. - /// The number of elements in the input sequence. - public static Task> CountAsync( - this IQueryable source, - CancellationToken cancellationToken = default) - { - if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) - { - return ResponseHelperAsync(source.Count()); - } - - return cosmosLinqQueryProvider.ExecuteAggregateAsync( - Expression.Call( - GetMethodInfoOf, int>(Queryable.Count), - source.Expression), - cancellationToken); - } - - private static Task> ResponseHelperAsync(T value) - { - return Task.FromResult>( - new ItemResponse( - System.Net.HttpStatusCode.OK, - new Headers(), - value, - new CosmosTraceDiagnostics(NoOpTrace.Singleton), - null)); - } - - private static MethodInfo GetMethodInfoOf(Func func) - { - Debug.Assert(func != null); - return func.GetMethodInfo(); - } - } -} + /// + /// Returns a BM25 score value that can only be used in an ORDER BY RANK function to sort results from highest relevancy to lowest relevancy. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/fulltextscore. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// + /// A nonempty array of string literals. + /// Returns true BM25 score value that can only be used in an ORDER BY RANK clause. + /// + /// + /// document.Name.FullTextScore([], [keyword2])); + /// ]]> + /// + /// + public static Func FullTextScore(this TSource obj, params string[] terms) + { + throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented); + } + + /// + /// This optional ORDER BY RANK clause sorts scoring functions by their rank. It's used specifically for scoring functions like VectorDistance, FullTextScore, and RRF. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/order-by-rank. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// A sequence of values to order. + /// A scoring function. + /// Returns the sequence with the elements ordered according to the rank of the scoring functions. + /// + /// + /// document.Name.FullTextScore()); + /// ]]> + /// + /// + public static IOrderedQueryable OrderByRank(this IQueryable source, Expression> scoreFunction) + { + if (!(source is CosmosLinqQuery)) + { + throw new ArgumentException("OrderByRank is only supported on Cosmos LINQ query operations"); + } + + return (IOrderedQueryable)source.Provider.CreateQuery( + Expression.Call( + null, + typeof(CosmosLinqExtensions).GetMethod("OrderByRank").MakeGenericMethod(typeof(TSource)), + source.Expression, + Expression.Quote(scoreFunction))); + } + + /// + /// This system function is used to combine two or more scores provided by other scoring functions. + /// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/rrf. + /// This method is to be used in LINQ expressions only and will be evaluated on server. + /// There's no implementation provided in the client library. + /// + /// the scoring functions to combine. + /// Returns the the combined scores of the scoring functions. + /// + /// + /// document.RRF(document.Name.FullTextScore(), document.Address.FullTextScore())); + /// ]]> + /// + /// + public static Func RRF(params Func[] scoringFunctions) + { + // The reason for not defining "this" keyword is because this causes undesirable serialization when call Expression.ToString() on this method + throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented); + } + + /// + /// This method generate query definition from LINQ query. + /// + /// the type of object to query. + /// the IQueryable{T} to be converted. + /// Dictionary containing parameter value and name for parameterized query + /// The queryDefinition which can be used in query execution. + /// + /// This example shows how to generate query definition from LINQ. + /// + /// + /// queryable = container.GetItemsQueryIterator(allowSynchronousQueryExecution = true) + /// .Where(t => b.id.contains("test")); + /// QueryDefinition queryDefinition = queryable.ToQueryDefinition(); + /// ]]> + /// + /// +#if PREVIEW + public +#else + internal +#endif + static QueryDefinition ToQueryDefinition(this IQueryable query, IDictionary namedParameters) + { + if (namedParameters == null) + { + throw new ArgumentException("namedParameters dictionary cannot be empty for this overload, please use ToQueryDefinition(IQueryable query) instead", nameof(namedParameters)); + } + + if (query is CosmosLinqQuery linqQuery) + { + return linqQuery.ToQueryDefinition(namedParameters); + } + + throw new ArgumentException("ToQueryDefinition is only supported on Cosmos LINQ query operations", nameof(query)); + } + + /// + /// This method generate query definition from LINQ query. + /// + /// the type of object to query. + /// the IQueryable{T} to be converted. + /// The queryDefinition which can be used in query execution. + /// + /// This example shows how to generate query definition from LINQ. + /// + /// + /// queryable = container.GetItemsQueryIterator(allowSynchronousQueryExecution = true) + /// .Where(t => b.id.contains("test")); + /// QueryDefinition queryDefinition = queryable.ToQueryDefinition(); + /// ]]> + /// + /// + public static QueryDefinition ToQueryDefinition(this IQueryable query) + { + if (query is CosmosLinqQuery linqQuery) + { + return linqQuery.ToQueryDefinition(); + + } + + throw new ArgumentException("ToQueryDefinition is only supported on Cosmos LINQ query operations", nameof(query)); + } + + /// + /// This extension method gets the FeedIterator from LINQ IQueryable to execute query asynchronously. + /// This will create the fresh new FeedIterator when called. + /// + /// the type of object to query. + /// the IQueryable{T} to be converted. + /// An iterator to go through the items. + /// + /// This example shows how to get FeedIterator from LINQ. + /// + /// + /// linqQueryable = this.Container.GetItemLinqQueryable(); + /// using (FeedIterator setIterator = linqQueryable.Where(item => (item.taskNum < 100)).ToFeedIterator() + /// ]]> + /// + /// + public static FeedIterator ToFeedIterator(this IQueryable query) + { + if (!(query is CosmosLinqQuery linqQuery)) + { + throw new ArgumentOutOfRangeException(nameof(linqQuery), "ToFeedIterator is only supported on Cosmos LINQ query operations"); + } + + return linqQuery.ToFeedIterator(); + } + + /// + /// This extension method gets the FeedIterator from LINQ IQueryable to execute query asynchronously. + /// This will create the fresh new FeedIterator when called. + /// + /// the type of object to query. + /// the IQueryable{T} to be converted. + /// An iterator to go through the items. + /// + /// This example shows how to get FeedIterator from LINQ. + /// + /// + /// linqQueryable = this.Container.GetItemLinqQueryable(); + /// using (FeedIterator setIterator = linqQueryable.Where(item => (item.taskNum < 100)).ToStreamIterator()) + /// ]]> + /// + /// + public static FeedIterator ToStreamIterator(this IQueryable query) + { + if (!(query is CosmosLinqQuery linqQuery)) + { + throw new ArgumentOutOfRangeException(nameof(linqQuery), "ToStreamFeedIterator is only supported on cosmos LINQ query operations"); + } + + return linqQuery.ToStreamIterator(); + } + + /// + /// Returns the maximum value in a generic . + /// + /// The type of the elements of source. + /// A sequence of values to determine the maximum of. + /// The cancellation token. + /// The maximum value in the sequence. + public static Task> MaxAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Max()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, TSource>(Queryable.Max), + source.Expression), + cancellationToken); + } + + /// + /// Returns the minimum value in a generic . + /// + /// The type of the elements of source. + /// A sequence of values to determine the minimum of. + /// The cancellation token. + /// The minimum value in the sequence. + public static Task> MinAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Min()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, TSource>(Queryable.Min), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, decimal>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, decimal?>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double?>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, float>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, float?>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double?>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the average of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> AverageAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Average()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double?>(Queryable.Average), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, decimal>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, decimal?>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, double?>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, float>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, float?>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, int>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, int?>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, long>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Computes the sum of a sequence of values. + /// + /// A sequence of values to calculate the average of. + /// The cancellation token. + /// The average value in the sequence. + public static Task> SumAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Sum()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, long?>(Queryable.Sum), + source.Expression), + cancellationToken); + } + + /// + /// Returns the number of elements in a sequence. + /// + /// The type of the elements of source. + /// The sequence that contains the elements to be counted. + /// The cancellation token. + /// The number of elements in the input sequence. + public static Task> CountAsync( + this IQueryable source, + CancellationToken cancellationToken = default) + { + if (!(source.Provider is CosmosLinqQueryProvider cosmosLinqQueryProvider)) + { + return ResponseHelperAsync(source.Count()); + } + + return cosmosLinqQueryProvider.ExecuteAggregateAsync( + Expression.Call( + GetMethodInfoOf, int>(Queryable.Count), + source.Expression), + cancellationToken); + } + + private static Task> ResponseHelperAsync(T value) + { + return Task.FromResult>( + new ItemResponse( + System.Net.HttpStatusCode.OK, + new Headers(), + value, + new CosmosTraceDiagnostics(NoOpTrace.Singleton), + null)); + } + + private static MethodInfo GetMethodInfoOf(Func func) + { + Debug.Assert(func != null); + return func.GetMethodInfo(); + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs index bf68408710..1d763a1dba 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs @@ -1,1792 +1,1802 @@ -//------------------------------------------------------------ -// Copyright (c) Microsoft Corporation. All rights reserved. -//------------------------------------------------------------ -namespace Microsoft.Azure.Cosmos.Linq -{ - using System; - using System.Collections; - using System.Collections.Generic; - using System.Collections.Immutable; - using System.Collections.ObjectModel; - using System.Data.Common; - using System.Diagnostics; - using System.Globalization; - using System.Linq; - using System.Linq.Expressions; - using System.Reflection; - using System.Text.RegularExpressions; +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ +namespace Microsoft.Azure.Cosmos.Linq +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Collections.Immutable; + using System.Collections.ObjectModel; + using System.Data.Common; + using System.Diagnostics; + using System.Globalization; + using System.Linq; + using System.Linq.Expressions; + using System.Reflection; + using System.Text.RegularExpressions; using Microsoft.Azure.Cosmos.CosmosElements; using Microsoft.Azure.Cosmos.Query.Core.ClientDistributionPlan.Cql; - using Microsoft.Azure.Cosmos.Serialization.HybridRow; - using Microsoft.Azure.Cosmos.Serializer; - using Microsoft.Azure.Cosmos.Spatial; - using Microsoft.Azure.Cosmos.SqlObjects; - using Microsoft.Azure.Documents; - using static Microsoft.Azure.Cosmos.Linq.FromParameterBindings; - - // ReSharper disable UnusedParameter.Local - - ////////////////////////////////////////////////////////////////////// - // - // data.SelectMany(x => x.fields.SelectMany(y => y.fields.Select(z => f(x,y,z))) - // expression tree: - // SelectMany - lambda - Selectmany - lambda - Select - lambda - f(x,y,z) - // | | | | | | - // data x .- fields y .- fields z - // | | - // x y - // parameter bound_to - // x data - // y x.fields - // z y.fields - // - // data.Where(x => f(x)).Select(y => g(y)) - // expression tree: - // Select - lambda - g(y) - // | | - // | y - // Where - lambda - f(x) - // | | - // data x - // - // parameter bound_to - // x data - // y Where - - /// - /// Core Linq to DocDBSQL translator. - /// - internal static class ExpressionToSql - { - public static class LinqMethods - { - public const string Any = "Any"; - public const string Average = "Average"; - public const string Count = "Count"; - public const string Distinct = "Distinct"; - public const string First = "First"; - public const string FirstOrDefault = "FirstOrDefault"; - public const string Max = "Max"; - public const string Min = "Min"; - public const string GroupBy = "GroupBy"; - public const string OrderBy = "OrderBy"; - public const string OrderByDescending = "OrderByDescending"; - public const string Select = "Select"; - public const string SelectMany = "SelectMany"; - public const string Single = "Single"; - public const string SingleOrDefault = "SingleOrDefault"; - public const string Skip = "Skip"; - public const string Sum = "Sum"; - public const string ThenBy = "ThenBy"; - public const string ThenByDescending = "ThenByDescending"; - public const string Take = "Take"; - public const string Where = "Where"; - } - - private static readonly string SqlRoot = "root"; - private static readonly string DefaultParameterName = "v"; - private static readonly bool usePropertyRef = false; - - /// - /// Toplevel entry point. - /// - /// An Expression representing a Query on a IDocumentQuery object. - /// Optional dictionary for parameter name and value - /// Optional serializer options. - /// Indicates the client operation that needs to be performed on the results of SqlQuery. - /// The corresponding SQL query. - public static SqlQuery TranslateQuery( - Expression inputExpression, - IDictionary parameters, - CosmosLinqSerializerOptionsInternal linqSerializerOptions, - out ScalarOperationKind clientOperation) - { - TranslationContext context = new TranslationContext(linqSerializerOptions, parameters); - ExpressionToSql.Translate(inputExpression, context); // ignore result here - - QueryUnderConstruction query = context.CurrentQuery; - query = query.FlattenAsPossible(); - SqlQuery result = query.GetSqlQuery(); - clientOperation = context.ClientOperation; - - return result; - } - - /// - /// Translate an expression into a query. - /// Query is constructed as a side-effect in context.CurrentQuery. - /// - /// Expression to translate. - /// Context for translation. - private static Collection Translate(Expression inputExpression, TranslationContext context) - { - Debug.Assert(context != null, "Translation Context should not be null"); - - if (inputExpression == null) - { - throw new ArgumentNullException("inputExpression"); - } - - Collection collection; - switch (inputExpression.NodeType) - { - case ExpressionType.Call: - MethodCallExpression methodCallExpression = (MethodCallExpression)inputExpression; - bool shouldConvertToScalarAnyCollection = (context.PeekMethod() == null) && methodCallExpression.Method.Name.Equals(LinqMethods.Any); - collection = ExpressionToSql.VisitMethodCall(methodCallExpression, context); - if (shouldConvertToScalarAnyCollection) collection = ExpressionToSql.ConvertToScalarAnyCollection(context); - - break; - - case ExpressionType.Constant: - collection = ExpressionToSql.TranslateInput((ConstantExpression)inputExpression, context); - break; - - case ExpressionType.MemberAccess: - collection = ExpressionToSql.VisitMemberAccessCollectionExpression(inputExpression, context, ExpressionToSql.GetBindingParameterName(context)); - break; - - case ExpressionType.Parameter: - SqlScalarExpression scalar = ExpressionToSql.VisitNonSubqueryScalarExpression(inputExpression, context); - collection = ExpressionToSql.ConvertToCollection(scalar); - break; - - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); - } - return collection; - } - - private static Collection TranslateInput(ConstantExpression inputExpression, TranslationContext context) - { - if (!typeof(IDocumentQuery).IsAssignableFrom(inputExpression.Type)) - { - throw new DocumentQueryException(ClientResources.InputIsNotIDocumentQuery); - } - - // ExpressionToSql is the query input value: a IDocumentQuery - if (!(inputExpression.Value is IDocumentQuery input)) - { - throw new DocumentQueryException(ClientResources.InputIsNotIDocumentQuery); - } - - context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); - Type elemType = TypeSystem.GetElementType(inputExpression.Type); - context.SetInputParameter(elemType, ParameterSubstitution.InputParameterName); // ignore result - - // First outer collection - Collection result = new Collection(ExpressionToSql.SqlRoot); - return result; - } - - /// - /// Get a parameter name to be binded to the collection from the next lambda. - /// It's merely for readability purpose. If that is not possible, use a default - /// parameter name. - /// - /// The translation context - /// A parameter name - private static string GetBindingParameterName(TranslationContext context) - { - MethodCallExpression peekMethod = context.PeekMethod(); - - // The parameter name is the top method's parameter if applicable - string parameterName = null; - if (peekMethod.Arguments.Count > 1) - { - if (peekMethod.Arguments[1] is LambdaExpression lambda && lambda.Parameters.Count > 0) - { - parameterName = lambda.Parameters[0].Name; - } - } - - parameterName ??= ExpressionToSql.DefaultParameterName; - - return parameterName; - } - - #region VISITOR - - /// - /// Visitor which produces a SqlScalarExpression. - /// - /// Expression to visit. - /// Context information. - /// The translation as a ScalarExpression. - internal static SqlScalarExpression VisitNonSubqueryScalarExpression(Expression inputExpression, TranslationContext context) - { - if (inputExpression == null) - { - return null; - } - - switch (inputExpression.NodeType) - { - case ExpressionType.Negate: - case ExpressionType.NegateChecked: - case ExpressionType.Not: - case ExpressionType.Convert: - case ExpressionType.ConvertChecked: - case ExpressionType.ArrayLength: - case ExpressionType.Quote: - case ExpressionType.TypeAs: - return ExpressionToSql.VisitUnary((UnaryExpression)inputExpression, context); - case ExpressionType.Add: - case ExpressionType.AddChecked: - case ExpressionType.Subtract: - case ExpressionType.SubtractChecked: - case ExpressionType.Multiply: - case ExpressionType.MultiplyChecked: - case ExpressionType.Divide: - case ExpressionType.Modulo: - case ExpressionType.And: - case ExpressionType.AndAlso: - case ExpressionType.Or: - case ExpressionType.OrElse: - case ExpressionType.LessThan: - case ExpressionType.LessThanOrEqual: - case ExpressionType.GreaterThan: - case ExpressionType.GreaterThanOrEqual: - case ExpressionType.Equal: - case ExpressionType.NotEqual: - case ExpressionType.Coalesce: - case ExpressionType.ArrayIndex: - case ExpressionType.RightShift: - case ExpressionType.LeftShift: - case ExpressionType.ExclusiveOr: - return ExpressionToSql.VisitBinary((BinaryExpression)inputExpression, context); - case ExpressionType.TypeIs: - return ExpressionToSql.VisitTypeIs((TypeBinaryExpression)inputExpression, context); - case ExpressionType.Conditional: - return ExpressionToSql.VisitConditional((ConditionalExpression)inputExpression, context); - case ExpressionType.Constant: - return ExpressionToSql.VisitConstant((ConstantExpression)inputExpression, context); - case ExpressionType.Parameter: - return ExpressionToSql.VisitParameter((ParameterExpression)inputExpression, context); - case ExpressionType.MemberAccess: - return ExpressionToSql.VisitMemberAccess((MemberExpression)inputExpression, context); - case ExpressionType.New: - return ExpressionToSql.VisitNew((NewExpression)inputExpression, context); - case ExpressionType.NewArrayInit: - case ExpressionType.NewArrayBounds: - return ExpressionToSql.VisitNewArray((NewArrayExpression)inputExpression, context); - case ExpressionType.Invoke: - return ExpressionToSql.VisitInvocation((InvocationExpression)inputExpression, context); - case ExpressionType.MemberInit: - return ExpressionToSql.VisitMemberInit((MemberInitExpression)inputExpression, context); - case ExpressionType.ListInit: - return ExpressionToSql.VisitListInit((ListInitExpression)inputExpression, context); - case ExpressionType.Call: - return ExpressionToSql.VisitMethodCallScalar((MethodCallExpression)inputExpression, context); - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); - } - } - - private static SqlScalarExpression VisitMethodCallScalar(MethodCallExpression methodCallExpression, TranslationContext context) - { - // Check if it is a UDF method call - if (methodCallExpression.Method.Equals(typeof(CosmosLinq).GetMethod("InvokeUserDefinedFunction"))) - { - string udfName = ((ConstantExpression)methodCallExpression.Arguments[0]).Value as string; - if (string.IsNullOrEmpty(udfName)) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.UdfNameIsNullOrEmpty)); - } - - SqlIdentifier methodName = SqlIdentifier.Create(udfName); - List arguments = new List(); - - if (methodCallExpression.Arguments.Count == 2) - { - // We have two cases here, if the udf was expecting only one parameter and this parameter is an array - // then the second argument will be an expression of this array. - // else we will have a NewArrayExpression of the udf arguments - if (methodCallExpression.Arguments[1] is NewArrayExpression newArrayExpression) - { - ReadOnlyCollection argumentsExpressions = newArrayExpression.Expressions; - foreach (Expression argument in argumentsExpressions) - { - arguments.Add(ExpressionToSql.VisitScalarExpression(argument, context)); - } - } - else if (methodCallExpression.Arguments[1].NodeType == ExpressionType.Constant && - methodCallExpression.Arguments[1].Type == typeof(object[])) - { - object[] argumentsExpressions = (object[])((ConstantExpression)methodCallExpression.Arguments[1]).Value; - foreach (object argument in argumentsExpressions) - { - arguments.Add(ExpressionToSql.VisitConstant(Expression.Constant(argument), context)); - } - } - else - { - arguments.Add(ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[1], context)); - } - } - - return SqlFunctionCallScalarExpression.Create(methodName, true, arguments.ToImmutableArray()); - } - else - { - return BuiltinFunctionVisitor.VisitBuiltinFunctionCall(methodCallExpression, context); - } - } - - private static SqlObjectProperty VisitBinding(MemberBinding binding, TranslationContext context) - { - switch (binding.BindingType) - { - case MemberBindingType.Assignment: - return ExpressionToSql.VisitMemberAssignment((MemberAssignment)binding, context); - case MemberBindingType.MemberBinding: - return ExpressionToSql.VisitMemberMemberBinding((MemberMemberBinding)binding, context); - case MemberBindingType.ListBinding: - default: - return ExpressionToSql.VisitMemberListBinding((MemberListBinding)binding, context); - } - } - - private static SqlUnaryScalarOperatorKind GetUnaryOperatorKind(ExpressionType type) - { - switch (type) - { - case ExpressionType.UnaryPlus: - return SqlUnaryScalarOperatorKind.Plus; - case ExpressionType.Negate: - return SqlUnaryScalarOperatorKind.Minus; - case ExpressionType.OnesComplement: - return SqlUnaryScalarOperatorKind.BitwiseNot; - case ExpressionType.Not: - return SqlUnaryScalarOperatorKind.Not; - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.UnaryOperatorNotSupported, type)); - } - } - - private static SqlScalarExpression VisitUnary(UnaryExpression inputExpression, TranslationContext context) - { - SqlScalarExpression operand = ExpressionToSql.VisitScalarExpression(inputExpression.Operand, context); - - // handle NOT IN - if (operand is SqlInScalarExpression sqlInScalarExpression && inputExpression.NodeType == ExpressionType.Not) - { - SqlInScalarExpression inExpression = sqlInScalarExpression; - return SqlInScalarExpression.Create(inExpression.Needle, true, inExpression.Haystack); - } - - if (inputExpression.NodeType == ExpressionType.Quote) - { - return operand; - } - - if (inputExpression.NodeType == ExpressionType.Convert) - { - return operand; - } - - SqlUnaryScalarOperatorKind op = GetUnaryOperatorKind(inputExpression.NodeType); - return SqlUnaryScalarExpression.Create(op, operand); - } - - private static SqlBinaryScalarOperatorKind GetBinaryOperatorKind(ExpressionType expressionType, Type resultType) - { - switch (expressionType) - { - case ExpressionType.Add: - { - if (resultType == typeof(string)) - { - return SqlBinaryScalarOperatorKind.StringConcat; - } - return SqlBinaryScalarOperatorKind.Add; - } - case ExpressionType.AndAlso: - return SqlBinaryScalarOperatorKind.And; - case ExpressionType.And: - return SqlBinaryScalarOperatorKind.BitwiseAnd; - case ExpressionType.Or: - return SqlBinaryScalarOperatorKind.BitwiseOr; - case ExpressionType.ExclusiveOr: - return SqlBinaryScalarOperatorKind.BitwiseXor; - case ExpressionType.Divide: - return SqlBinaryScalarOperatorKind.Divide; - case ExpressionType.Equal: - return SqlBinaryScalarOperatorKind.Equal; - case ExpressionType.GreaterThan: - return SqlBinaryScalarOperatorKind.GreaterThan; - case ExpressionType.GreaterThanOrEqual: - return SqlBinaryScalarOperatorKind.GreaterThanOrEqual; - case ExpressionType.LessThan: - return SqlBinaryScalarOperatorKind.LessThan; - case ExpressionType.LessThanOrEqual: - return SqlBinaryScalarOperatorKind.LessThanOrEqual; - case ExpressionType.Modulo: - return SqlBinaryScalarOperatorKind.Modulo; - case ExpressionType.Multiply: - return SqlBinaryScalarOperatorKind.Multiply; - case ExpressionType.NotEqual: - return SqlBinaryScalarOperatorKind.NotEqual; - case ExpressionType.OrElse: - return SqlBinaryScalarOperatorKind.Or; - case ExpressionType.Subtract: - return SqlBinaryScalarOperatorKind.Subtract; - case ExpressionType.Coalesce: - return SqlBinaryScalarOperatorKind.Coalesce; - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.BinaryOperatorNotSupported, expressionType)); - } - } - - private static SqlScalarExpression VisitBinary(BinaryExpression inputExpression, TranslationContext context) - { - // Speical case for string.CompareTo - // if any of the left or right expression is MethodCallExpression - // the other expression should only be constant (integer) - MethodCallExpression methodCallExpression = null; - ConstantExpression constantExpression = null; - - bool reverseNodeType = false; - if (inputExpression.Left.NodeType == ExpressionType.Call && inputExpression.Right.NodeType == ExpressionType.Constant) - { - methodCallExpression = (MethodCallExpression)inputExpression.Left; - constantExpression = (ConstantExpression)inputExpression.Right; - } - else if (inputExpression.Right.NodeType == ExpressionType.Call && inputExpression.Left.NodeType == ExpressionType.Constant) - { - methodCallExpression = (MethodCallExpression)inputExpression.Right; - constantExpression = (ConstantExpression)inputExpression.Left; - reverseNodeType = true; - } - - if (methodCallExpression != null && constantExpression != null) - { - if (TryMatchStringCompareTo(methodCallExpression, constantExpression, inputExpression.NodeType)) - { - return ExpressionToSql.VisitStringCompareTo(methodCallExpression, inputExpression.NodeType, reverseNodeType, context); - } - - if (TryMatchStringCompare(methodCallExpression, constantExpression, inputExpression.NodeType)) - { - return ExpressionToSql.VisitStringCompare(methodCallExpression, inputExpression.NodeType, reverseNodeType, context); - } - } - - SqlScalarExpression left = ExpressionToSql.VisitScalarExpression(inputExpression.Left, context); - SqlScalarExpression right = ExpressionToSql.VisitScalarExpression(inputExpression.Right, context); - - if (inputExpression.NodeType == ExpressionType.ArrayIndex) - { - SqlMemberIndexerScalarExpression result = SqlMemberIndexerScalarExpression.Create(left, right); - return result; - } - - SqlBinaryScalarOperatorKind op = GetBinaryOperatorKind(inputExpression.NodeType, inputExpression.Type); - - if (left is SqlMemberIndexerScalarExpression && right is SqlLiteralScalarExpression literalScalarExpression) - { - right = ExpressionToSql.ApplyCustomConverters(inputExpression.Left, literalScalarExpression, context); - } - else if (right is SqlMemberIndexerScalarExpression && left is SqlLiteralScalarExpression sqlLiteralScalarExpression) - { - left = ExpressionToSql.ApplyCustomConverters(inputExpression.Right, sqlLiteralScalarExpression, context); - } - - return SqlBinaryScalarExpression.Create(op, left, right); - } - - private static SqlScalarExpression ApplyCustomConverters(Expression left, SqlLiteralScalarExpression right, TranslationContext context) - { - MemberExpression memberExpression; - if (left is UnaryExpression unaryExpression) - { - memberExpression = unaryExpression.Operand as MemberExpression; - } - else - { - memberExpression = left as MemberExpression; - } - - if (memberExpression != null && - right.Literal is not SqlNullLiteral) - { - Type memberType = memberExpression.Type; - if (memberType.IsNullable()) - { - memberType = memberType.NullableUnderlyingType(); - } - - bool requiresCustomSerialization = context.CosmosLinqSerializer.RequiresCustomSerialization(memberExpression, memberType); - if (requiresCustomSerialization) - { - object value = default(object); - // Enum - if (memberType.IsEnum()) - { - try - { - Number64 number64 = ((SqlNumberLiteral)right.Literal).Value; - if (number64.IsDouble) - { - value = Enum.ToObject(memberType, Number64.ToDouble(number64)); - } - else - { - value = Enum.ToObject(memberType, Number64.ToLong(number64)); - } - } - catch - { - value = ((SqlStringLiteral)right.Literal).Value; - } - - } - // DateTime - else if (memberType == typeof(DateTime)) - { - SqlStringLiteral serializedDateTime = (SqlStringLiteral)right.Literal; - value = DateTime.Parse(serializedDateTime.Value, provider: null, DateTimeStyles.RoundtripKind); - } - - if (value != default(object)) - { - string serializedValue = context.CosmosLinqSerializer.Serialize(value, memberExpression, memberType); - return CosmosElement.Parse(serializedValue).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton); - } - } - } - - return right; - } - - private static bool TryMatchStringCompareTo(MethodCallExpression left, ConstantExpression right, ExpressionType compareOperator) - { - if (left.Method.Equals(typeof(string).GetMethod("CompareTo", new Type[] { typeof(string) })) && left.Arguments.Count == 1) - { - // operator can only be =, >, >=, <, <= - switch (compareOperator) - { - case ExpressionType.Equal: - case ExpressionType.GreaterThan: - case ExpressionType.GreaterThanOrEqual: - case ExpressionType.LessThan: - case ExpressionType.LessThanOrEqual: - break; - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareToInvalidOperator)); - } - - // the constant value should be zero, otherwise we can't determine how to translate the expression - // it could be either integer or nullable integer - if (!(right.Type == typeof(int) && (int)right.Value == 0) && - !(right.Type == typeof(int?) && ((int?)right.Value).HasValue && ((int?)right.Value).Value == 0)) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareToInvalidConstant)); - } - - return true; - } - - return false; - } - - private static SqlScalarExpression VisitStringCompareTo( - MethodCallExpression left, - ExpressionType compareOperator, - bool reverseNodeType, - TranslationContext context) - { - if (reverseNodeType) - { - compareOperator = ReverseExpressionTypeForStrings(compareOperator, ClientResources.StringCompareToInvalidOperator); - } - - SqlBinaryScalarOperatorKind op = GetBinaryOperatorKind(compareOperator, null); - - SqlScalarExpression leftExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Object, context); - SqlScalarExpression rightExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Arguments[0], context); - - return SqlBinaryScalarExpression.Create(op, leftExpression, rightExpression); - } - - private static ExpressionType ReverseExpressionTypeForStrings(ExpressionType compareOperator, string errorMessage) - { - switch (compareOperator) - { - case ExpressionType.Equal: - // do nothing - break; - case ExpressionType.GreaterThan: - compareOperator = ExpressionType.LessThan; - break; - case ExpressionType.GreaterThanOrEqual: - compareOperator = ExpressionType.LessThanOrEqual; - break; - case ExpressionType.LessThan: - compareOperator = ExpressionType.GreaterThan; - break; - case ExpressionType.LessThanOrEqual: - compareOperator = ExpressionType.GreaterThanOrEqual; - break; - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, errorMessage)); - } - - return compareOperator; - } - - private static bool TryMatchStringCompare(MethodCallExpression left, ConstantExpression right, ExpressionType compareOperator) - { - if (left.Method.Equals(typeof(string).GetMethod("Compare", new Type[] { typeof(string), typeof(string) })) && left.Arguments.Count == 2) - { - // operator can only be =, >, >=, <, <= - switch (compareOperator) - { - case ExpressionType.Equal: - case ExpressionType.GreaterThan: - case ExpressionType.GreaterThanOrEqual: - case ExpressionType.LessThan: - case ExpressionType.LessThanOrEqual: - break; - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareInvalidOperator)); - } - - // the constant value should be zero, otherwise we can't determine how to translate the expression - // it could be either integer or nullable integer - if (!(right.Type == typeof(int) && (int)right.Value == 0) && - !(right.Type == typeof(int?) && ((int?)right.Value).HasValue && ((int?)right.Value).Value == 0)) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareInvalidConstant)); - } - - return true; - } - - return false; - } - - private static SqlScalarExpression VisitStringCompare( - MethodCallExpression left, - ExpressionType compareOperator, - bool reverseNodeType, - TranslationContext context) - { - if (reverseNodeType) - { - compareOperator = ReverseExpressionTypeForStrings(compareOperator, ClientResources.StringCompareInvalidOperator); - } - - SqlBinaryScalarOperatorKind op = GetBinaryOperatorKind(compareOperator, null); - - SqlScalarExpression leftExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Arguments[0], context); - SqlScalarExpression rightExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Arguments[1], context); - - return SqlBinaryScalarExpression.Create(op, leftExpression, rightExpression); - } - - private static SqlScalarExpression VisitTypeIs(TypeBinaryExpression inputExpression, TranslationContext context) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); - } - - public static SqlScalarExpression VisitConstant(ConstantExpression inputExpression, TranslationContext context) - { - if (inputExpression.Value == null) - { - return SqlLiteralScalarExpression.SqlNullLiteralScalarExpression; - } - - if (inputExpression.Type.IsNullable()) - { - return VisitConstant(Expression.Constant(inputExpression.Value, Nullable.GetUnderlyingType(inputExpression.Type)), context); - } - - if (context.Parameters != null && context.Parameters.TryGetValue(inputExpression.Value, out string paramName)) - { - SqlParameter sqlParameter = SqlParameter.Create(paramName); - return SqlParameterRefScalarExpression.Create(sqlParameter); - } - - Type constantType = inputExpression.Value.GetType(); - if (constantType.IsValueType) - { - if (inputExpression.Value is bool boolValue) - { - SqlBooleanLiteral literal = SqlBooleanLiteral.Create(boolValue); - return SqlLiteralScalarExpression.Create(literal); - } - - if (ExpressionToSql.TryGetSqlNumberLiteral(inputExpression.Value, out SqlNumberLiteral numberLiteral)) - { - return SqlLiteralScalarExpression.Create(numberLiteral); - } - - if (inputExpression.Value is Guid guidValue) - { - SqlStringLiteral literal = SqlStringLiteral.Create(guidValue.ToString()); - return SqlLiteralScalarExpression.Create(literal); - } - } - - if (inputExpression.Value is string stringValue) - { - SqlStringLiteral literal = SqlStringLiteral.Create(stringValue); - return SqlLiteralScalarExpression.Create(literal); - } - - if (typeof(Geometry).IsAssignableFrom(constantType)) - { - return GeometrySqlExpressionFactory.Construct(inputExpression); - } - - if (inputExpression.Value is IEnumerable enumerable) - { - List arrayItems = new List(); - - foreach (object item in enumerable) - { - arrayItems.Add(VisitConstant(Expression.Constant(item), context)); - } - - return SqlArrayCreateScalarExpression.Create(arrayItems.ToImmutableArray()); - } - - string serializedConstant = context.CosmosLinqSerializer.SerializeScalarExpression(inputExpression); - - return CosmosElement.Parse(serializedConstant).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton); - } - - private static SqlScalarExpression VisitConditional(ConditionalExpression inputExpression, TranslationContext context) - { - SqlScalarExpression conditionExpression = ExpressionToSql.VisitScalarExpression(inputExpression.Test, context); - SqlScalarExpression firstExpression = ExpressionToSql.VisitScalarExpression(inputExpression.IfTrue, context); - SqlScalarExpression secondExpression = ExpressionToSql.VisitScalarExpression(inputExpression.IfFalse, context); - - return SqlConditionalScalarExpression.Create(conditionExpression, firstExpression, secondExpression); - } - - private static SqlScalarExpression VisitParameter(ParameterExpression inputExpression, TranslationContext context) - { - Expression subst = context.LookupSubstitution(inputExpression); - if (subst != null) - { - return ExpressionToSql.VisitNonSubqueryScalarExpression(subst, context); - } - - string name = inputExpression.Name; - SqlIdentifier id = SqlIdentifier.Create(name); - return SqlPropertyRefScalarExpression.Create(null, id); - } - - private static SqlScalarExpression VisitMemberAccess(MemberExpression inputExpression, TranslationContext context) - { - SqlScalarExpression memberExpression = ExpressionToSql.VisitScalarExpression(inputExpression.Expression, context); - string memberName = inputExpression.Member.GetMemberName(context); - - // If the resulting memberName is null, then the indexer should be on the root of the object. - if (memberName == null) - { - return memberExpression; - } - - // if expression is nullable - if (inputExpression.Expression.Type.IsNullable()) - { - MemberNames memberNames = context.MemberNames; - - // ignore .Value - if (memberName == memberNames.Value) - { - return memberExpression; - } - - // convert .HasValue to IS_DEFINED expression - if (memberName == memberNames.HasValue) - { - return SqlFunctionCallScalarExpression.CreateBuiltin("IS_DEFINED", memberExpression); - } - } - - if (usePropertyRef) - { - SqlIdentifier propertyIdentifier = SqlIdentifier.Create(memberName); - SqlPropertyRefScalarExpression propertyRefExpression = SqlPropertyRefScalarExpression.Create(memberExpression, propertyIdentifier); - return propertyRefExpression; - } - else - { - SqlScalarExpression indexExpression = SqlLiteralScalarExpression.Create(SqlStringLiteral.Create(memberName)); - SqlMemberIndexerScalarExpression memberIndexerExpression = SqlMemberIndexerScalarExpression.Create(memberExpression, indexExpression); - return memberIndexerExpression; - } - } - - private static SqlScalarExpression[] VisitExpressionList(ReadOnlyCollection inputExpressionList, TranslationContext context) - { - SqlScalarExpression[] result = new SqlScalarExpression[inputExpressionList.Count]; - for (int i = 0; i < inputExpressionList.Count; i++) - { - SqlScalarExpression p = ExpressionToSql.VisitScalarExpression(inputExpressionList[i], context); - result[i] = p; - } - - return result; - } - - private static SqlObjectProperty VisitMemberAssignment(MemberAssignment inputExpression, TranslationContext context) - { - SqlScalarExpression assign = ExpressionToSql.VisitScalarExpression(inputExpression.Expression, context); - string memberName = inputExpression.Member.GetMemberName(context); - SqlPropertyName propName = SqlPropertyName.Create(memberName); - SqlObjectProperty prop = SqlObjectProperty.Create(propName, assign); - return prop; - } - - private static SqlObjectProperty VisitMemberMemberBinding(MemberMemberBinding inputExpression, TranslationContext context) - { - throw new DocumentQueryException(ClientResources.MemberBindingNotSupported); - } - - private static SqlObjectProperty VisitMemberListBinding(MemberListBinding inputExpression, TranslationContext context) - { - throw new DocumentQueryException(ClientResources.MemberBindingNotSupported); - } - - private static SqlObjectProperty[] VisitBindingList(ReadOnlyCollection inputExpressionList, TranslationContext context) - { - SqlObjectProperty[] list = new SqlObjectProperty[inputExpressionList.Count]; - for (int i = 0; i < inputExpressionList.Count; i++) - { - SqlObjectProperty b = ExpressionToSql.VisitBinding(inputExpressionList[i], context); - list[i] = b; - } - - return list; - } - - private static SqlObjectProperty[] CreateInitializers(ReadOnlyCollection arguments, ReadOnlyCollection members, TranslationContext context) - { - if (arguments.Count != members.Count) - { - throw new InvalidOperationException("Expected same number of arguments as members"); - } - - SqlObjectProperty[] result = new SqlObjectProperty[arguments.Count]; - for (int i = 0; i < arguments.Count; i++) - { - Expression arg = arguments[i]; - MemberInfo member = members[i]; - SqlScalarExpression value = ExpressionToSql.VisitScalarExpression(arg, context); - - string memberName = member.GetMemberName(context); - SqlPropertyName propName = SqlPropertyName.Create(memberName); - SqlObjectProperty prop = SqlObjectProperty.Create(propName, value); - result[i] = prop; - } - - return result; - } - - private static SqlSelectItem[] CreateSelectItems(ReadOnlyCollection arguments, ReadOnlyCollection members, TranslationContext context) - { - if (arguments.Count != members.Count) - { - throw new InvalidOperationException("Expected same number of arguments as members"); - } - - SqlSelectItem[] result = new SqlSelectItem[arguments.Count]; - for (int i = 0; i < arguments.Count; i++) - { - Expression arg = arguments[i]; - MemberInfo member = members[i]; - SqlScalarExpression selectExpression = ExpressionToSql.VisitScalarExpression(arg, context); - - string memberName = member.GetMemberName(context); - SqlIdentifier alias = SqlIdentifier.Create(memberName); - SqlSelectItem prop = SqlSelectItem.Create(selectExpression, alias); - result[i] = prop; - } - - return result; - } - - private static SqlScalarExpression VisitNew(NewExpression inputExpression, TranslationContext context) - { - if (typeof(Geometry).IsAssignableFrom(inputExpression.Type)) - { - return GeometrySqlExpressionFactory.Construct(inputExpression); - } - - if (inputExpression.Arguments.Count > 0) - { - if (inputExpression.Members == null) - { - throw new DocumentQueryException(ClientResources.ConstructorInvocationNotSupported); - } - - SqlObjectProperty[] propertyBindings = ExpressionToSql.CreateInitializers(inputExpression.Arguments, inputExpression.Members, context); - SqlObjectCreateScalarExpression create = SqlObjectCreateScalarExpression.Create(propertyBindings); - return create; - } - else - { - // no need to return anything; the initializer will generate the complete code - return null; - } - } - - private static SqlScalarExpression VisitMemberInit(MemberInitExpression inputExpression, TranslationContext context) - { - ExpressionToSql.VisitNew(inputExpression.NewExpression, context); // Return value is ignored - SqlObjectProperty[] propertyBindings = ExpressionToSql.VisitBindingList(inputExpression.Bindings, context); - SqlObjectCreateScalarExpression create = SqlObjectCreateScalarExpression.Create(propertyBindings); - return create; - } - - private static SqlScalarExpression VisitListInit(ListInitExpression inputExpression, TranslationContext context) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); - } - - private static SqlScalarExpression VisitNewArray(NewArrayExpression inputExpression, TranslationContext context) - { - SqlScalarExpression[] exprs = ExpressionToSql.VisitExpressionList(inputExpression.Expressions, context); - if (inputExpression.NodeType == ExpressionType.NewArrayInit) - { - SqlArrayCreateScalarExpression array = SqlArrayCreateScalarExpression.Create(exprs); - return array; - } - else - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); - } - } - - private static SqlScalarExpression VisitInvocation(InvocationExpression inputExpression, TranslationContext context) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); - } - - #endregion VISITOR - - #region Scalar and CollectionScalar Visitors - - private static Collection ConvertToCollection(SqlScalarExpression scalar) - { - if (usePropertyRef) - { - SqlPropertyRefScalarExpression propertyRefExpression = scalar as SqlPropertyRefScalarExpression; - if (propertyRefExpression == null) - { - throw new DocumentQueryException(ClientResources.PathExpressionsOnly); - } - - SqlInputPathCollection path = ConvertPropertyRefToPath(propertyRefExpression); - Collection result = new Collection(path); - return result; - } - else - { - SqlMemberIndexerScalarExpression memberIndexerExpression = scalar as SqlMemberIndexerScalarExpression; - if (memberIndexerExpression == null) - { - SqlPropertyRefScalarExpression propertyRefExpression = scalar as SqlPropertyRefScalarExpression; - if (propertyRefExpression == null) - { - throw new DocumentQueryException(ClientResources.PathExpressionsOnly); - } - - SqlInputPathCollection path = ConvertPropertyRefToPath(propertyRefExpression); - Collection result = new Collection(path); - return result; - } - else - { - SqlInputPathCollection path = ConvertMemberIndexerToPath(memberIndexerExpression); - Collection result = new Collection(path); - return result; - } - } - } - - /// - /// Convert the context's current query to a scalar Any collection - /// by wrapping it as following: SELECT VALUE COUNT(v0) > 0 FROM (current query) AS v0. - /// This is used in cases where LINQ expression ends with Any() which is a boolean scalar. - /// Normally Any would translate to SELECT VALUE EXISTS() subquery. However that wouldn't work - /// for these cases because it would result in a boolean value for each row instead of - /// one single "aggregated" boolean value. - /// - /// The translation context - /// The scalar Any collection - private static Collection ConvertToScalarAnyCollection(TranslationContext context) - { - SqlQuery query = context.CurrentQuery.FlattenAsPossible().GetSqlQuery(); - SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); - - ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); - Binding binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); - - context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); - context.CurrentQuery.AddBinding(binding); - - SqlSelectSpec selectSpec = SqlSelectValueSpec.Create( - SqlBinaryScalarExpression.Create( - SqlBinaryScalarOperatorKind.GreaterThan, - SqlFunctionCallScalarExpression.CreateBuiltin( - SqlFunctionCallScalarExpression.Names.Count, - SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterExpression.Name))), - SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create(0)))); - SqlSelectClause selectClause = SqlSelectClause.Create(selectSpec); - context.CurrentQuery.AddSelectClause(selectClause); - - return new Collection(LinqMethods.Any); - } - - private static SqlScalarExpression VisitNonSubqueryScalarExpression(Expression expression, ReadOnlyCollection parameters, TranslationContext context) - { - foreach (ParameterExpression par in parameters) - { - context.PushParameter(par, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); - } - - SqlScalarExpression scalarExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(expression, context); - - foreach (ParameterExpression par in parameters) - { - context.PopParameter(); - } - - return scalarExpression; - } - - private static SqlScalarExpression VisitNonSubqueryScalarLambda(LambdaExpression lambdaExpression, TranslationContext context) - { - ReadOnlyCollection parameters = lambdaExpression.Parameters; - if (parameters.Count != 1) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, lambdaExpression.Body, 1, parameters.Count)); - } - - return ExpressionToSql.VisitNonSubqueryScalarExpression(lambdaExpression.Body, parameters, context); - } - - private static Collection VisitCollectionExpression(Expression expression, ReadOnlyCollection parameters, TranslationContext context) - { - foreach (ParameterExpression par in parameters) - { - context.PushParameter(par, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); - } - - Collection collection = ExpressionToSql.VisitCollectionExpression(expression, context, parameters.Count > 0 ? parameters.First().Name : ExpressionToSql.DefaultParameterName); - - foreach (ParameterExpression par in parameters) - { - context.PopParameter(); - } - - return collection; - } - - private static Collection VisitCollectionExpression(Expression expression, TranslationContext context, string parameterName) - { - Collection result; - switch (expression.NodeType) - { - case ExpressionType.Call: - result = ExpressionToSql.Translate(expression, context); - break; - - case ExpressionType.MemberAccess: - result = ExpressionToSql.VisitMemberAccessCollectionExpression(expression, context, parameterName); - break; - - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, expression.NodeType)); - } - - return result; - } - - /// - /// Visit a lambda which is supposed to return a collection. - /// - /// LambdaExpression with a result which is a collection. - /// The translation context. - /// The collection computed by the lambda. - private static Collection VisitCollectionLambda(LambdaExpression lambdaExpression, TranslationContext context) - { - ReadOnlyCollection parameters = lambdaExpression.Parameters; - if (parameters.Count != 1) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, lambdaExpression.Body, 1, parameters.Count)); - } - - return ExpressionToSql.VisitCollectionExpression(lambdaExpression.Body, lambdaExpression.Parameters, context); - } - - /// - /// Visit an expression, usually a MemberAccess, then trigger parameter binding for that expression. - /// - /// The input expression - /// The current translation context - /// Parameter name is merely for readability - private static Collection VisitMemberAccessCollectionExpression(Expression inputExpression, TranslationContext context, string parameterName) - { - SqlScalarExpression body = ExpressionToSql.VisitNonSubqueryScalarExpression(inputExpression, context); - Type type = inputExpression.Type; - - Collection collection = ExpressionToSql.ConvertToCollection(body); - context.PushCollection(collection); - ParameterExpression parameter = context.GenerateFreshParameter(type, parameterName); - context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); - context.PopParameter(); - context.PopCollection(); - - return new Collection(parameter.Name); - } - - /// - /// Visit a method call, construct the corresponding query in context.CurrentQuery. - /// At ExpressionToSql point only LINQ method calls are allowed. - /// These methods are static extension methods of IQueryable or IEnumerable. - /// - /// Method to translate. - /// Query translation context. - private static Collection VisitMethodCall(MethodCallExpression inputExpression, TranslationContext context) - { - context.PushMethod(inputExpression); - - Type declaringType = inputExpression.Method.DeclaringType; - if ((declaringType != typeof(Queryable) && declaringType != typeof(Enumerable)) - || !inputExpression.Method.IsStatic) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.OnlyLINQMethodsAreSupported, inputExpression.Method.Name)); - } - - Type returnType = inputExpression.Method.ReturnType; - Type returnElementType = TypeSystem.GetElementType(returnType); - - if (inputExpression.Object != null) - { - throw new DocumentQueryException(ClientResources.ExpectedMethodCallsMethods); - } - - Expression inputCollection = inputExpression.Arguments[0]; // all these methods are static extension methods, so argument[0] is the collection - - Type inputElementType = TypeSystem.GetElementType(inputCollection.Type); - Collection collection = ExpressionToSql.Translate(inputCollection, context); - - context.PushCollection(collection); - - Collection result = new Collection(inputExpression.Method.Name); - bool shouldBeOnNewQuery = context.CurrentQuery.ShouldBeOnNewQuery(inputExpression.Method.Name, inputExpression.Arguments.Count); - context.PushSubqueryBinding(shouldBeOnNewQuery); - - if (context.LastExpressionIsGroupBy) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, "Group By cannot be followed by other methods")); - } - - switch (inputExpression.Method.Name) - { - case LinqMethods.Any: - { - result = new Collection(string.Empty); - - if (inputExpression.Arguments.Count == 2) - { - // Any is translated to an SELECT VALUE EXISTS() where Any operation itself is treated as a Where. - SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); - } - break; - } - case LinqMethods.Average: - { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Avg); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Count: - { - SqlSelectClause select = ExpressionToSql.VisitCount(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Distinct: - { - SqlSelectClause select = ExpressionToSql.VisitDistinct(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.FirstOrDefault: - { - if (inputExpression.Arguments.Count == 1) - { - // TOP is not allowed when OFFSET ... LIMIT is present. - if (!context.CurrentQuery.HasOffsetSpec()) - { - SqlNumberLiteral sqlNumberLiteral = SqlNumberLiteral.Create(1); - SqlTopSpec topSpec = SqlTopSpec.Create(sqlNumberLiteral); - context.CurrentQuery = context.CurrentQuery.AddTopSpec(topSpec); - } - - context.SetClientOperation(ScalarOperationKind.FirstOrDefault); - } - else - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, inputExpression.Method.Name, 0, inputExpression.Arguments.Count - 1)); - } - - break; - } - case LinqMethods.Max: - { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Max); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Min: - { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Min); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.GroupBy: - { - context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); - result = ExpressionToSql.VisitGroupBy(returnElementType, inputExpression.Arguments, context); - context.LastExpressionIsGroupBy = true; - break; - } - case LinqMethods.OrderBy: - { - SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); - context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); - break; - } - case LinqMethods.OrderByDescending: - { - SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); - context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); - break; - } - case LinqMethods.Select: - { - SqlSelectClause select = ExpressionToSql.VisitSelect(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.SelectMany: - { - context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); - result = ExpressionToSql.VisitSelectMany(inputExpression.Arguments, context); - break; - } - case LinqMethods.Skip: - { - SqlOffsetSpec offsetSpec = ExpressionToSql.VisitSkip(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddOffsetSpec(offsetSpec, context); - break; - } - case LinqMethods.Sum: - { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Sum); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Take: - { - if (context.CurrentQuery.HasOffsetSpec()) - { - SqlLimitSpec limitSpec = ExpressionToSql.VisitTakeLimit(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddLimitSpec(limitSpec, context); - } - else - { - SqlTopSpec topSpec = ExpressionToSql.VisitTakeTop(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddTopSpec(topSpec); - } - break; - } - case LinqMethods.ThenBy: - { - SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); - context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); - break; - } - case LinqMethods.ThenByDescending: - { - SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); - context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); - break; - } - case LinqMethods.Where: - { - SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); - break; - } - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, inputExpression.Method.Name)); - } - - context.PopSubqueryBinding(); - context.PopCollection(); - context.PopMethod(); - return result; - } - - /// - /// Visit a method call, construct the corresponding query and return the select clause for the aggregate function. - /// At ExpressionToSql point only LINQ method calls are allowed. - /// These methods are static extension methods of IQueryable or IEnumerable. - /// - /// Method to translate. - /// Query translation context. - private static SqlSelectClause VisitGroupByAggregateMethodCall(MethodCallExpression inputExpression, TranslationContext context) - { - context.PushMethod(inputExpression); - - Type declaringType = inputExpression.Method.DeclaringType; - if ((declaringType != typeof(Queryable) && declaringType != typeof(Enumerable)) - || !inputExpression.Method.IsStatic) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.OnlyLINQMethodsAreSupported, inputExpression.Method.Name)); - } - - if (inputExpression.Object != null) - { - throw new DocumentQueryException(ClientResources.ExpectedMethodCallsMethods); - } - - if (context.LastExpressionIsGroupBy) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, "Group By cannot be followed by other methods")); - } - - SqlSelectClause select; - switch (inputExpression.Method.Name) - { - case LinqMethods.Average: - { - select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Avg); - break; - } - case LinqMethods.Count: - { - select = ExpressionToSql.VisitCount(inputExpression.Arguments, context); - break; - } - case LinqMethods.Max: - { - select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Max); - break; - } - case LinqMethods.Min: - { - select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Min); - break; - } - case LinqMethods.Sum: - { - select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Sum); - break; - } - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, inputExpression.Method.Name)); - } - - context.PopMethod(); - return select; - } - - /// - /// Determine if an expression should be translated to a subquery. - /// This only applies to expression that is inside a lamda. - /// - /// The input expression - /// The expression object kind of the expression - /// True if the method is either Min, Max, or Avg - /// True if subquery is needed, otherwise false - private static bool IsSubqueryScalarExpression(Expression expression, out SubqueryKind? expressionObjKind, out bool isMinMaxAvgMethod) - { - if (!(expression is MethodCallExpression methodCallExpression)) - { - expressionObjKind = null; - isMinMaxAvgMethod = false; - return false; - } - - string methodName = methodCallExpression.Method.Name; - bool isSubqueryExpression; - - isMinMaxAvgMethod = false; - - switch (methodName) - { - case LinqMethods.Min: - case LinqMethods.Max: - case LinqMethods.Average: - isMinMaxAvgMethod = true; - isSubqueryExpression = true; - expressionObjKind = SubqueryKind.SubqueryScalarExpression; - break; - - case LinqMethods.Sum: - isSubqueryExpression = true; - expressionObjKind = SubqueryKind.SubqueryScalarExpression; - break; - - case LinqMethods.Count: - if (methodCallExpression.Arguments.Count > 1) - { - isSubqueryExpression = true; - expressionObjKind = SubqueryKind.SubqueryScalarExpression; - } - else - { - SubqueryKind? objKind; - bool isMinMaxAvg; - isSubqueryExpression = ExpressionToSql.IsSubqueryScalarExpression( - methodCallExpression.Arguments[0] as MethodCallExpression, - out objKind, out isMinMaxAvg); - - if (isSubqueryExpression) - { - isSubqueryExpression = true; - expressionObjKind = SubqueryKind.SubqueryScalarExpression; - } - else - { - isSubqueryExpression = false; - expressionObjKind = null; - } - } - break; - - case LinqMethods.Any: - isSubqueryExpression = true; - expressionObjKind = SubqueryKind.ExistsScalarExpression; - break; - - case LinqMethods.Select: - case LinqMethods.SelectMany: - case LinqMethods.Where: - case LinqMethods.OrderBy: - case LinqMethods.OrderByDescending: - case LinqMethods.ThenBy: - case LinqMethods.ThenByDescending: - case LinqMethods.Skip: - case LinqMethods.Take: - case LinqMethods.Distinct: - case LinqMethods.GroupBy: - isSubqueryExpression = true; - expressionObjKind = SubqueryKind.ArrayScalarExpression; - break; - - default: - isSubqueryExpression = false; - expressionObjKind = null; - break; - } - - return isSubqueryExpression; - } - - /// - /// Visit an lambda expression which is in side a lambda and translate it to a scalar expression or a subquery scalar expression. - /// See the other overload of this method for more details. - /// - /// The input lambda expression - /// The translation context - /// A scalar expression representing the input expression - private static SqlScalarExpression VisitScalarExpression(LambdaExpression lambda, TranslationContext context) - { - return ExpressionToSql.VisitScalarExpression( - lambda.Body, - lambda.Parameters, - context); - } - - /// - /// Visit an lambda expression which is inside a lambda and translate it to a scalar expression or a collection scalar expression. - /// If it is a collection scalar expression, e.g. should be translated to subquery such as SELECT VALUE ARRAY, SELECT VALUE EXISTS, - /// SELECT VALUE [aggregate], the subquery will be aliased to a new binding for the FROM clause. E.g. consider - /// Select(family => family.Children.Select(child => child.Grade)). Since the inner Select corresponds to a subquery, this method would - /// create a new binding of v0 to the subquery SELECT VALUE ARRAY(), and the inner expression will be just SELECT v0. - /// - /// The input expression - /// The translation context - /// A scalar expression representing the input expression - internal static SqlScalarExpression VisitScalarExpression(Expression expression, TranslationContext context) - { - return ExpressionToSql.VisitScalarExpression( - expression, - new ReadOnlyCollection(Array.Empty()), - context); - } - - internal static bool TryGetSqlNumberLiteral(object value, out SqlNumberLiteral sqlNumberLiteral) - { - sqlNumberLiteral = default(SqlNumberLiteral); - if (value is byte byteValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(byteValue); - } - else if (value is sbyte sbyteValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(sbyteValue); - } - else if (value is decimal decimalValue) - { - if ((decimalValue >= long.MinValue) && (decimalValue <= long.MaxValue) && (decimalValue % 1 == 0)) - { - sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToInt64(decimalValue)); - } - else - { - sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToDouble(decimalValue)); - } - } - else if (value is double doubleValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(doubleValue); - } - else if (value is float floatVlaue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(floatVlaue); - } - else if (value is int intValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(intValue); - } - else if (value is uint uintValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(uintValue); - } - else if (value is long longValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(longValue); - } - else if (value is ulong ulongValue) - { - if (ulongValue <= long.MaxValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToInt64(ulongValue)); - } - else - { - sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToDouble(ulongValue)); - } - } - else if (value is short shortValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(shortValue); - } - else if (value is ushort ushortValue) - { - sqlNumberLiteral = SqlNumberLiteral.Create(ushortValue); - } - - return sqlNumberLiteral != default(SqlNumberLiteral); - } - - /// - /// Visit an lambda expression which is in side a lambda and translate it to a scalar expression or a collection scalar expression. - /// See the other overload of this method for more details. - /// - private static SqlScalarExpression VisitScalarExpression(Expression expression, - ReadOnlyCollection parameters, - TranslationContext context) - { - SubqueryKind? expressionObjKind; - bool isMinMaxAvgMethod; - bool shouldUseSubquery = ExpressionToSql.IsSubqueryScalarExpression(expression, out expressionObjKind, out isMinMaxAvgMethod); - - SqlScalarExpression sqlScalarExpression; - if (!shouldUseSubquery) - { - sqlScalarExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(expression, parameters, context); - } - else - { - SqlQuery query = ExpressionToSql.CreateSubquery(expression, parameters, context); - - ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); - SqlCollection subqueryCollection = ExpressionToSql.CreateSubquerySqlCollection( - query, - isMinMaxAvgMethod ? SubqueryKind.ArrayScalarExpression : expressionObjKind.Value); - - Binding newBinding = new Binding(parameterExpression, subqueryCollection, - isInCollection: false, isInputParameter: context.IsInMainBranchSelect()); - - context.CurrentSubqueryBinding.NewBindings.Add(newBinding); - - if (isMinMaxAvgMethod) - { - sqlScalarExpression = SqlMemberIndexerScalarExpression.Create( - SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterExpression.Name)), - SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create(0))); - } - else - { - sqlScalarExpression = SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterExpression.Name)); - } - } - - return sqlScalarExpression; - } - - /// - /// Create a subquery SQL collection object for a SQL query - /// - /// The SQL query object - /// The subquery type - private static SqlCollection CreateSubquerySqlCollection(SqlQuery query, SubqueryKind subqueryType) - { - SqlCollection subqueryCollection; - switch (subqueryType) - { - case SubqueryKind.ArrayScalarExpression: - SqlArrayScalarExpression arrayScalarExpression = SqlArrayScalarExpression.Create(query); - query = SqlQuery.Create( - SqlSelectClause.Create(SqlSelectValueSpec.Create(arrayScalarExpression)), - fromClause: null, whereClause: null, groupByClause: null, orderByClause: null, offsetLimitClause: null); - break; - - case SubqueryKind.ExistsScalarExpression: - SqlExistsScalarExpression existsScalarExpression = SqlExistsScalarExpression.Create(query); - query = SqlQuery.Create( - SqlSelectClause.Create(SqlSelectValueSpec.Create(existsScalarExpression)), - fromClause: null, whereClause: null, groupByClause: null, orderByClause: null, offsetLimitClause: null); - break; - - case SubqueryKind.SubqueryScalarExpression: - // No need to wrap query as in ArrayScalarExpression, or ExistsScalarExpression - break; - - default: - throw new DocumentQueryException($"Unsupported subquery type {subqueryType}"); - } - - subqueryCollection = SqlSubqueryCollection.Create(query); - return subqueryCollection; - } - - /// - /// Create a subquery from a subquery scalar expression. - /// By visiting the collection expression, this builds a new QueryUnderConstruction on top of the current one - /// and then translate it to a SQL query while keeping the current QueryUnderConstruction in tact. - /// - /// The subquery scalar expression - /// The list of parameters of the expression - /// The translation context - /// A query corresponding to the collection expression - /// The QueryUnderConstruction remains unchanged after this. - private static SqlQuery CreateSubquery(Expression expression, ReadOnlyCollection parameters, TranslationContext context) - { - bool shouldBeOnNewQuery = context.CurrentSubqueryBinding.ShouldBeOnNewQuery; - - QueryUnderConstruction queryBeforeVisit = context.CurrentQuery; - QueryUnderConstruction packagedQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc(), context.CurrentQuery); - packagedQuery.FromParameters.SetInputParameter(typeof(object), context.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope); - context.CurrentQuery = packagedQuery; - - if (shouldBeOnNewQuery) context.CurrentSubqueryBinding.ShouldBeOnNewQuery = false; - - Collection collection = ExpressionToSql.VisitCollectionExpression(expression, parameters, context); - - QueryUnderConstruction subquery = context.CurrentQuery.GetSubquery(queryBeforeVisit); - context.CurrentSubqueryBinding.ShouldBeOnNewQuery = shouldBeOnNewQuery; - context.CurrentQuery = queryBeforeVisit; - - SqlQuery sqlSubquery = subquery.FlattenAsPossible().GetSqlQuery(); - return sqlSubquery; - } - - #endregion Scalar and CollectionScalar Visitors - - #region LINQ Specific Visitors - - private static SqlWhereClause VisitWhere(ReadOnlyCollection arguments, TranslationContext context) - { - if (arguments.Count != 2) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Where, 2, arguments.Count)); - } - - LambdaExpression function = Utilities.GetLambda(arguments[1]); - SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(function, context); - SqlWhereClause where = SqlWhereClause.Create(sqlfunc); - return where; - } - - private static SqlSelectClause VisitSelect(ReadOnlyCollection arguments, TranslationContext context) - { - if (arguments.Count != 2) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Select, 2, arguments.Count)); - } - - LambdaExpression lambda = Utilities.GetLambda(arguments[1]); - - SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(lambda, context); - SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(sqlfunc); - SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null); - return select; - } - - private static Collection VisitSelectMany(ReadOnlyCollection arguments, TranslationContext context) - { - if (arguments.Count != 2) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.SelectMany, 2, arguments.Count)); - } - - LambdaExpression lambda = Utilities.GetLambda(arguments[1]); - - // If there is Distinct, Take or OrderBy the lambda then it needs to be in a subquery. - bool requireLocalExecution = false; - - for (MethodCallExpression methodCall = lambda.Body as MethodCallExpression; - methodCall != null; - methodCall = methodCall.Arguments[0] as MethodCallExpression) - { - string methodName = methodCall.Method.Name; - requireLocalExecution |= methodName.Equals(LinqMethods.Distinct) || methodName.Equals(LinqMethods.Take) || methodName.Equals(LinqMethods.OrderBy) || methodName.Equals(LinqMethods.OrderByDescending); - } - - Collection collection; - if (!requireLocalExecution) - { - collection = ExpressionToSql.VisitCollectionLambda(lambda, context); - } - else - { - collection = new Collection(string.Empty); - Binding binding; - SqlQuery query = ExpressionToSql.CreateSubquery(lambda.Body, lambda.Parameters, context); - SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); - ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); - binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); - context.CurrentQuery.FromParameters.Add(binding); - } - - return collection; - } - - private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollection arguments, TranslationContext context) - { - if (arguments.Count != 3) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.GroupBy, 3, arguments.Count)); - } - - // Key Selector handling - // First argument is input, second is key selector and third is value selector + using Microsoft.Azure.Cosmos.Serialization.HybridRow; + using Microsoft.Azure.Cosmos.Serializer; + using Microsoft.Azure.Cosmos.Spatial; + using Microsoft.Azure.Cosmos.SqlObjects; + using Microsoft.Azure.Documents; + using static Microsoft.Azure.Cosmos.Linq.FromParameterBindings; + + // ReSharper disable UnusedParameter.Local + + ////////////////////////////////////////////////////////////////////// + // + // data.SelectMany(x => x.fields.SelectMany(y => y.fields.Select(z => f(x,y,z))) + // expression tree: + // SelectMany - lambda - Selectmany - lambda - Select - lambda - f(x,y,z) + // | | | | | | + // data x .- fields y .- fields z + // | | + // x y + // parameter bound_to + // x data + // y x.fields + // z y.fields + // + // data.Where(x => f(x)).Select(y => g(y)) + // expression tree: + // Select - lambda - g(y) + // | | + // | y + // Where - lambda - f(x) + // | | + // data x + // + // parameter bound_to + // x data + // y Where + + /// + /// Core Linq to DocDBSQL translator. + /// + internal static class ExpressionToSql + { + public static class LinqMethods + { + public const string Any = "Any"; + public const string Average = "Average"; + public const string Count = "Count"; + public const string Distinct = "Distinct"; + public const string First = "First"; + public const string FirstOrDefault = "FirstOrDefault"; + public const string Max = "Max"; + public const string Min = "Min"; + public const string GroupBy = "GroupBy"; + public const string OrderBy = "OrderBy"; + public const string OrderByDescending = "OrderByDescending"; + public const string Select = "Select"; + public const string SelectMany = "SelectMany"; + public const string Single = "Single"; + public const string SingleOrDefault = "SingleOrDefault"; + public const string Skip = "Skip"; + public const string Sum = "Sum"; + public const string ThenBy = "ThenBy"; + public const string ThenByDescending = "ThenByDescending"; + public const string Take = "Take"; + public const string Where = "Where"; + } + + private static readonly string SqlRoot = "root"; + private static readonly string DefaultParameterName = "v"; + private static readonly bool usePropertyRef = false; + + /// + /// Toplevel entry point. + /// + /// An Expression representing a Query on a IDocumentQuery object. + /// Optional dictionary for parameter name and value + /// Optional serializer options. + /// Indicates the client operation that needs to be performed on the results of SqlQuery. + /// The corresponding SQL query. + public static SqlQuery TranslateQuery( + Expression inputExpression, + IDictionary parameters, + CosmosLinqSerializerOptionsInternal linqSerializerOptions, + out ScalarOperationKind clientOperation) + { + TranslationContext context = new TranslationContext(linqSerializerOptions, parameters); + ExpressionToSql.Translate(inputExpression, context); // ignore result here + + QueryUnderConstruction query = context.CurrentQuery; + query = query.FlattenAsPossible(); + SqlQuery result = query.GetSqlQuery(); + clientOperation = context.ClientOperation; + + return result; + } + + /// + /// Translate an expression into a query. + /// Query is constructed as a side-effect in context.CurrentQuery. + /// + /// Expression to translate. + /// Context for translation. + private static Collection Translate(Expression inputExpression, TranslationContext context) + { + Debug.Assert(context != null, "Translation Context should not be null"); + + if (inputExpression == null) + { + throw new ArgumentNullException("inputExpression"); + } + + Collection collection; + switch (inputExpression.NodeType) + { + case ExpressionType.Call: + MethodCallExpression methodCallExpression = (MethodCallExpression)inputExpression; + bool shouldConvertToScalarAnyCollection = (context.PeekMethod() == null) && methodCallExpression.Method.Name.Equals(LinqMethods.Any); + collection = ExpressionToSql.VisitMethodCall(methodCallExpression, context); + if (shouldConvertToScalarAnyCollection) collection = ExpressionToSql.ConvertToScalarAnyCollection(context); + + break; + + case ExpressionType.Constant: + collection = ExpressionToSql.TranslateInput((ConstantExpression)inputExpression, context); + break; + + case ExpressionType.MemberAccess: + collection = ExpressionToSql.VisitMemberAccessCollectionExpression(inputExpression, context, ExpressionToSql.GetBindingParameterName(context)); + break; + + case ExpressionType.Parameter: + SqlScalarExpression scalar = ExpressionToSql.VisitNonSubqueryScalarExpression(inputExpression, context); + collection = ExpressionToSql.ConvertToCollection(scalar); + break; + + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); + } + return collection; + } + + private static Collection TranslateInput(ConstantExpression inputExpression, TranslationContext context) + { + if (!typeof(IDocumentQuery).IsAssignableFrom(inputExpression.Type)) + { + throw new DocumentQueryException(ClientResources.InputIsNotIDocumentQuery); + } + + // ExpressionToSql is the query input value: a IDocumentQuery + if (!(inputExpression.Value is IDocumentQuery input)) + { + throw new DocumentQueryException(ClientResources.InputIsNotIDocumentQuery); + } + + context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); + Type elemType = TypeSystem.GetElementType(inputExpression.Type); + context.SetInputParameter(elemType, ParameterSubstitution.InputParameterName); // ignore result + + // First outer collection + Collection result = new Collection(ExpressionToSql.SqlRoot); + return result; + } + + /// + /// Get a parameter name to be binded to the collection from the next lambda. + /// It's merely for readability purpose. If that is not possible, use a default + /// parameter name. + /// + /// The translation context + /// A parameter name + private static string GetBindingParameterName(TranslationContext context) + { + MethodCallExpression peekMethod = context.PeekMethod(); + + // The parameter name is the top method's parameter if applicable + string parameterName = null; + if (peekMethod.Arguments.Count > 1) + { + if (peekMethod.Arguments[1] is LambdaExpression lambda && lambda.Parameters.Count > 0) + { + parameterName = lambda.Parameters[0].Name; + } + } + + parameterName ??= ExpressionToSql.DefaultParameterName; + + return parameterName; + } + + #region VISITOR + + /// + /// Visitor which produces a SqlScalarExpression. + /// + /// Expression to visit. + /// Context information. + /// The translation as a ScalarExpression. + internal static SqlScalarExpression VisitNonSubqueryScalarExpression(Expression inputExpression, TranslationContext context) + { + if (inputExpression == null) + { + return null; + } + + switch (inputExpression.NodeType) + { + case ExpressionType.Negate: + case ExpressionType.NegateChecked: + case ExpressionType.Not: + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.ArrayLength: + case ExpressionType.Quote: + case ExpressionType.TypeAs: + return ExpressionToSql.VisitUnary((UnaryExpression)inputExpression, context); + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.And: + case ExpressionType.AndAlso: + case ExpressionType.Or: + case ExpressionType.OrElse: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.Coalesce: + case ExpressionType.ArrayIndex: + case ExpressionType.RightShift: + case ExpressionType.LeftShift: + case ExpressionType.ExclusiveOr: + return ExpressionToSql.VisitBinary((BinaryExpression)inputExpression, context); + case ExpressionType.TypeIs: + return ExpressionToSql.VisitTypeIs((TypeBinaryExpression)inputExpression, context); + case ExpressionType.Conditional: + return ExpressionToSql.VisitConditional((ConditionalExpression)inputExpression, context); + case ExpressionType.Constant: + return ExpressionToSql.VisitConstant((ConstantExpression)inputExpression, context); + case ExpressionType.Parameter: + return ExpressionToSql.VisitParameter((ParameterExpression)inputExpression, context); + case ExpressionType.MemberAccess: + return ExpressionToSql.VisitMemberAccess((MemberExpression)inputExpression, context); + case ExpressionType.New: + return ExpressionToSql.VisitNew((NewExpression)inputExpression, context); + case ExpressionType.NewArrayInit: + case ExpressionType.NewArrayBounds: + return ExpressionToSql.VisitNewArray((NewArrayExpression)inputExpression, context); + case ExpressionType.Invoke: + return ExpressionToSql.VisitInvocation((InvocationExpression)inputExpression, context); + case ExpressionType.MemberInit: + return ExpressionToSql.VisitMemberInit((MemberInitExpression)inputExpression, context); + case ExpressionType.ListInit: + return ExpressionToSql.VisitListInit((ListInitExpression)inputExpression, context); + case ExpressionType.Call: + return ExpressionToSql.VisitMethodCallScalar((MethodCallExpression)inputExpression, context); + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); + } + } + + private static SqlScalarExpression VisitMethodCallScalar(MethodCallExpression methodCallExpression, TranslationContext context) + { + // Check if it is a UDF method call + if (methodCallExpression.Method.Equals(typeof(CosmosLinq).GetMethod("InvokeUserDefinedFunction"))) + { + string udfName = ((ConstantExpression)methodCallExpression.Arguments[0]).Value as string; + if (string.IsNullOrEmpty(udfName)) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.UdfNameIsNullOrEmpty)); + } + + SqlIdentifier methodName = SqlIdentifier.Create(udfName); + List arguments = new List(); + + if (methodCallExpression.Arguments.Count == 2) + { + // We have two cases here, if the udf was expecting only one parameter and this parameter is an array + // then the second argument will be an expression of this array. + // else we will have a NewArrayExpression of the udf arguments + if (methodCallExpression.Arguments[1] is NewArrayExpression newArrayExpression) + { + ReadOnlyCollection argumentsExpressions = newArrayExpression.Expressions; + foreach (Expression argument in argumentsExpressions) + { + arguments.Add(ExpressionToSql.VisitScalarExpression(argument, context)); + } + } + else if (methodCallExpression.Arguments[1].NodeType == ExpressionType.Constant && + methodCallExpression.Arguments[1].Type == typeof(object[])) + { + object[] argumentsExpressions = (object[])((ConstantExpression)methodCallExpression.Arguments[1]).Value; + foreach (object argument in argumentsExpressions) + { + arguments.Add(ExpressionToSql.VisitConstant(Expression.Constant(argument), context)); + } + } + else + { + arguments.Add(ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[1], context)); + } + } + + return SqlFunctionCallScalarExpression.Create(methodName, true, arguments.ToImmutableArray()); + } + else + { + return BuiltinFunctionVisitor.VisitBuiltinFunctionCall(methodCallExpression, context); + } + } + + private static SqlObjectProperty VisitBinding(MemberBinding binding, TranslationContext context) + { + switch (binding.BindingType) + { + case MemberBindingType.Assignment: + return ExpressionToSql.VisitMemberAssignment((MemberAssignment)binding, context); + case MemberBindingType.MemberBinding: + return ExpressionToSql.VisitMemberMemberBinding((MemberMemberBinding)binding, context); + case MemberBindingType.ListBinding: + default: + return ExpressionToSql.VisitMemberListBinding((MemberListBinding)binding, context); + } + } + + private static SqlUnaryScalarOperatorKind GetUnaryOperatorKind(ExpressionType type) + { + switch (type) + { + case ExpressionType.UnaryPlus: + return SqlUnaryScalarOperatorKind.Plus; + case ExpressionType.Negate: + return SqlUnaryScalarOperatorKind.Minus; + case ExpressionType.OnesComplement: + return SqlUnaryScalarOperatorKind.BitwiseNot; + case ExpressionType.Not: + return SqlUnaryScalarOperatorKind.Not; + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.UnaryOperatorNotSupported, type)); + } + } + + private static SqlScalarExpression VisitUnary(UnaryExpression inputExpression, TranslationContext context) + { + SqlScalarExpression operand = ExpressionToSql.VisitScalarExpression(inputExpression.Operand, context); + + // handle NOT IN + if (operand is SqlInScalarExpression sqlInScalarExpression && inputExpression.NodeType == ExpressionType.Not) + { + SqlInScalarExpression inExpression = sqlInScalarExpression; + return SqlInScalarExpression.Create(inExpression.Needle, true, inExpression.Haystack); + } + + if (inputExpression.NodeType == ExpressionType.Quote) + { + return operand; + } + + if (inputExpression.NodeType == ExpressionType.Convert) + { + return operand; + } + + SqlUnaryScalarOperatorKind op = GetUnaryOperatorKind(inputExpression.NodeType); + return SqlUnaryScalarExpression.Create(op, operand); + } + + private static SqlBinaryScalarOperatorKind GetBinaryOperatorKind(ExpressionType expressionType, Type resultType) + { + switch (expressionType) + { + case ExpressionType.Add: + { + if (resultType == typeof(string)) + { + return SqlBinaryScalarOperatorKind.StringConcat; + } + return SqlBinaryScalarOperatorKind.Add; + } + case ExpressionType.AndAlso: + return SqlBinaryScalarOperatorKind.And; + case ExpressionType.And: + return SqlBinaryScalarOperatorKind.BitwiseAnd; + case ExpressionType.Or: + return SqlBinaryScalarOperatorKind.BitwiseOr; + case ExpressionType.ExclusiveOr: + return SqlBinaryScalarOperatorKind.BitwiseXor; + case ExpressionType.Divide: + return SqlBinaryScalarOperatorKind.Divide; + case ExpressionType.Equal: + return SqlBinaryScalarOperatorKind.Equal; + case ExpressionType.GreaterThan: + return SqlBinaryScalarOperatorKind.GreaterThan; + case ExpressionType.GreaterThanOrEqual: + return SqlBinaryScalarOperatorKind.GreaterThanOrEqual; + case ExpressionType.LessThan: + return SqlBinaryScalarOperatorKind.LessThan; + case ExpressionType.LessThanOrEqual: + return SqlBinaryScalarOperatorKind.LessThanOrEqual; + case ExpressionType.Modulo: + return SqlBinaryScalarOperatorKind.Modulo; + case ExpressionType.Multiply: + return SqlBinaryScalarOperatorKind.Multiply; + case ExpressionType.NotEqual: + return SqlBinaryScalarOperatorKind.NotEqual; + case ExpressionType.OrElse: + return SqlBinaryScalarOperatorKind.Or; + case ExpressionType.Subtract: + return SqlBinaryScalarOperatorKind.Subtract; + case ExpressionType.Coalesce: + return SqlBinaryScalarOperatorKind.Coalesce; + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.BinaryOperatorNotSupported, expressionType)); + } + } + + private static SqlScalarExpression VisitBinary(BinaryExpression inputExpression, TranslationContext context) + { + // Speical case for string.CompareTo + // if any of the left or right expression is MethodCallExpression + // the other expression should only be constant (integer) + MethodCallExpression methodCallExpression = null; + ConstantExpression constantExpression = null; + + bool reverseNodeType = false; + if (inputExpression.Left.NodeType == ExpressionType.Call && inputExpression.Right.NodeType == ExpressionType.Constant) + { + methodCallExpression = (MethodCallExpression)inputExpression.Left; + constantExpression = (ConstantExpression)inputExpression.Right; + } + else if (inputExpression.Right.NodeType == ExpressionType.Call && inputExpression.Left.NodeType == ExpressionType.Constant) + { + methodCallExpression = (MethodCallExpression)inputExpression.Right; + constantExpression = (ConstantExpression)inputExpression.Left; + reverseNodeType = true; + } + + if (methodCallExpression != null && constantExpression != null) + { + if (TryMatchStringCompareTo(methodCallExpression, constantExpression, inputExpression.NodeType)) + { + return ExpressionToSql.VisitStringCompareTo(methodCallExpression, inputExpression.NodeType, reverseNodeType, context); + } + + if (TryMatchStringCompare(methodCallExpression, constantExpression, inputExpression.NodeType)) + { + return ExpressionToSql.VisitStringCompare(methodCallExpression, inputExpression.NodeType, reverseNodeType, context); + } + } + + SqlScalarExpression left = ExpressionToSql.VisitScalarExpression(inputExpression.Left, context); + SqlScalarExpression right = ExpressionToSql.VisitScalarExpression(inputExpression.Right, context); + + if (inputExpression.NodeType == ExpressionType.ArrayIndex) + { + SqlMemberIndexerScalarExpression result = SqlMemberIndexerScalarExpression.Create(left, right); + return result; + } + + SqlBinaryScalarOperatorKind op = GetBinaryOperatorKind(inputExpression.NodeType, inputExpression.Type); + + if (left is SqlMemberIndexerScalarExpression && right is SqlLiteralScalarExpression literalScalarExpression) + { + right = ExpressionToSql.ApplyCustomConverters(inputExpression.Left, literalScalarExpression, context); + } + else if (right is SqlMemberIndexerScalarExpression && left is SqlLiteralScalarExpression sqlLiteralScalarExpression) + { + left = ExpressionToSql.ApplyCustomConverters(inputExpression.Right, sqlLiteralScalarExpression, context); + } + + return SqlBinaryScalarExpression.Create(op, left, right); + } + + private static SqlScalarExpression ApplyCustomConverters(Expression left, SqlLiteralScalarExpression right, TranslationContext context) + { + MemberExpression memberExpression; + if (left is UnaryExpression unaryExpression) + { + memberExpression = unaryExpression.Operand as MemberExpression; + } + else + { + memberExpression = left as MemberExpression; + } + + if (memberExpression != null && + right.Literal is not SqlNullLiteral) + { + Type memberType = memberExpression.Type; + if (memberType.IsNullable()) + { + memberType = memberType.NullableUnderlyingType(); + } + + bool requiresCustomSerialization = context.CosmosLinqSerializer.RequiresCustomSerialization(memberExpression, memberType); + if (requiresCustomSerialization) + { + object value = default(object); + // Enum + if (memberType.IsEnum()) + { + try + { + Number64 number64 = ((SqlNumberLiteral)right.Literal).Value; + if (number64.IsDouble) + { + value = Enum.ToObject(memberType, Number64.ToDouble(number64)); + } + else + { + value = Enum.ToObject(memberType, Number64.ToLong(number64)); + } + } + catch + { + value = ((SqlStringLiteral)right.Literal).Value; + } + + } + // DateTime + else if (memberType == typeof(DateTime)) + { + SqlStringLiteral serializedDateTime = (SqlStringLiteral)right.Literal; + value = DateTime.Parse(serializedDateTime.Value, provider: null, DateTimeStyles.RoundtripKind); + } + + if (value != default(object)) + { + string serializedValue = context.CosmosLinqSerializer.Serialize(value, memberExpression, memberType); + return CosmosElement.Parse(serializedValue).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton); + } + } + } + + return right; + } + + private static bool TryMatchStringCompareTo(MethodCallExpression left, ConstantExpression right, ExpressionType compareOperator) + { + if (left.Method.Equals(typeof(string).GetMethod("CompareTo", new Type[] { typeof(string) })) && left.Arguments.Count == 1) + { + // operator can only be =, >, >=, <, <= + switch (compareOperator) + { + case ExpressionType.Equal: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + break; + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareToInvalidOperator)); + } + + // the constant value should be zero, otherwise we can't determine how to translate the expression + // it could be either integer or nullable integer + if (!(right.Type == typeof(int) && (int)right.Value == 0) && + !(right.Type == typeof(int?) && ((int?)right.Value).HasValue && ((int?)right.Value).Value == 0)) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareToInvalidConstant)); + } + + return true; + } + + return false; + } + + private static SqlScalarExpression VisitStringCompareTo( + MethodCallExpression left, + ExpressionType compareOperator, + bool reverseNodeType, + TranslationContext context) + { + if (reverseNodeType) + { + compareOperator = ReverseExpressionTypeForStrings(compareOperator, ClientResources.StringCompareToInvalidOperator); + } + + SqlBinaryScalarOperatorKind op = GetBinaryOperatorKind(compareOperator, null); + + SqlScalarExpression leftExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Object, context); + SqlScalarExpression rightExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Arguments[0], context); + + return SqlBinaryScalarExpression.Create(op, leftExpression, rightExpression); + } + + private static ExpressionType ReverseExpressionTypeForStrings(ExpressionType compareOperator, string errorMessage) + { + switch (compareOperator) + { + case ExpressionType.Equal: + // do nothing + break; + case ExpressionType.GreaterThan: + compareOperator = ExpressionType.LessThan; + break; + case ExpressionType.GreaterThanOrEqual: + compareOperator = ExpressionType.LessThanOrEqual; + break; + case ExpressionType.LessThan: + compareOperator = ExpressionType.GreaterThan; + break; + case ExpressionType.LessThanOrEqual: + compareOperator = ExpressionType.GreaterThanOrEqual; + break; + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, errorMessage)); + } + + return compareOperator; + } + + private static bool TryMatchStringCompare(MethodCallExpression left, ConstantExpression right, ExpressionType compareOperator) + { + if (left.Method.Equals(typeof(string).GetMethod("Compare", new Type[] { typeof(string), typeof(string) })) && left.Arguments.Count == 2) + { + // operator can only be =, >, >=, <, <= + switch (compareOperator) + { + case ExpressionType.Equal: + case ExpressionType.GreaterThan: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + break; + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareInvalidOperator)); + } + + // the constant value should be zero, otherwise we can't determine how to translate the expression + // it could be either integer or nullable integer + if (!(right.Type == typeof(int) && (int)right.Value == 0) && + !(right.Type == typeof(int?) && ((int?)right.Value).HasValue && ((int?)right.Value).Value == 0)) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.StringCompareInvalidConstant)); + } + + return true; + } + + return false; + } + + private static SqlScalarExpression VisitStringCompare( + MethodCallExpression left, + ExpressionType compareOperator, + bool reverseNodeType, + TranslationContext context) + { + if (reverseNodeType) + { + compareOperator = ReverseExpressionTypeForStrings(compareOperator, ClientResources.StringCompareInvalidOperator); + } + + SqlBinaryScalarOperatorKind op = GetBinaryOperatorKind(compareOperator, null); + + SqlScalarExpression leftExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Arguments[0], context); + SqlScalarExpression rightExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(left.Arguments[1], context); + + return SqlBinaryScalarExpression.Create(op, leftExpression, rightExpression); + } + + private static SqlScalarExpression VisitTypeIs(TypeBinaryExpression inputExpression, TranslationContext context) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); + } + + public static SqlScalarExpression VisitConstant(ConstantExpression inputExpression, TranslationContext context) + { + if (inputExpression.Value == null) + { + return SqlLiteralScalarExpression.SqlNullLiteralScalarExpression; + } + + if (inputExpression.Type.IsNullable()) + { + return VisitConstant(Expression.Constant(inputExpression.Value, Nullable.GetUnderlyingType(inputExpression.Type)), context); + } + + if (context.Parameters != null && context.Parameters.TryGetValue(inputExpression.Value, out string paramName)) + { + SqlParameter sqlParameter = SqlParameter.Create(paramName); + return SqlParameterRefScalarExpression.Create(sqlParameter); + } + + Type constantType = inputExpression.Value.GetType(); + if (constantType.IsValueType) + { + if (inputExpression.Value is bool boolValue) + { + SqlBooleanLiteral literal = SqlBooleanLiteral.Create(boolValue); + return SqlLiteralScalarExpression.Create(literal); + } + + if (ExpressionToSql.TryGetSqlNumberLiteral(inputExpression.Value, out SqlNumberLiteral numberLiteral)) + { + return SqlLiteralScalarExpression.Create(numberLiteral); + } + + if (inputExpression.Value is Guid guidValue) + { + SqlStringLiteral literal = SqlStringLiteral.Create(guidValue.ToString()); + return SqlLiteralScalarExpression.Create(literal); + } + } + + if (inputExpression.Value is string stringValue) + { + SqlStringLiteral literal = SqlStringLiteral.Create(stringValue); + return SqlLiteralScalarExpression.Create(literal); + } + + if (typeof(Geometry).IsAssignableFrom(constantType)) + { + return GeometrySqlExpressionFactory.Construct(inputExpression); + } + + if (inputExpression.Value is IEnumerable enumerable) + { + List arrayItems = new List(); + + foreach (object item in enumerable) + { + arrayItems.Add(VisitConstant(Expression.Constant(item), context)); + } + + return SqlArrayCreateScalarExpression.Create(arrayItems.ToImmutableArray()); + } + + string serializedConstant = context.CosmosLinqSerializer.SerializeScalarExpression(inputExpression); + + return CosmosElement.Parse(serializedConstant).Accept(CosmosElementToSqlScalarExpressionVisitor.Singleton); + } + + private static SqlScalarExpression VisitConditional(ConditionalExpression inputExpression, TranslationContext context) + { + SqlScalarExpression conditionExpression = ExpressionToSql.VisitScalarExpression(inputExpression.Test, context); + SqlScalarExpression firstExpression = ExpressionToSql.VisitScalarExpression(inputExpression.IfTrue, context); + SqlScalarExpression secondExpression = ExpressionToSql.VisitScalarExpression(inputExpression.IfFalse, context); + + return SqlConditionalScalarExpression.Create(conditionExpression, firstExpression, secondExpression); + } + + private static SqlScalarExpression VisitParameter(ParameterExpression inputExpression, TranslationContext context) + { + Expression subst = context.LookupSubstitution(inputExpression); + if (subst != null) + { + return ExpressionToSql.VisitNonSubqueryScalarExpression(subst, context); + } + + string name = inputExpression.Name; + SqlIdentifier id = SqlIdentifier.Create(name); + return SqlPropertyRefScalarExpression.Create(null, id); + } + + private static SqlScalarExpression VisitMemberAccess(MemberExpression inputExpression, TranslationContext context) + { + SqlScalarExpression memberExpression = ExpressionToSql.VisitScalarExpression(inputExpression.Expression, context); + string memberName = inputExpression.Member.GetMemberName(context); + + // If the resulting memberName is null, then the indexer should be on the root of the object. + if (memberName == null) + { + return memberExpression; + } + + // if expression is nullable + if (inputExpression.Expression.Type.IsNullable()) + { + MemberNames memberNames = context.MemberNames; + + // ignore .Value + if (memberName == memberNames.Value) + { + return memberExpression; + } + + // convert .HasValue to IS_DEFINED expression + if (memberName == memberNames.HasValue) + { + return SqlFunctionCallScalarExpression.CreateBuiltin("IS_DEFINED", memberExpression); + } + } + + if (usePropertyRef) + { + SqlIdentifier propertyIdentifier = SqlIdentifier.Create(memberName); + SqlPropertyRefScalarExpression propertyRefExpression = SqlPropertyRefScalarExpression.Create(memberExpression, propertyIdentifier); + return propertyRefExpression; + } + else + { + SqlScalarExpression indexExpression = SqlLiteralScalarExpression.Create(SqlStringLiteral.Create(memberName)); + SqlMemberIndexerScalarExpression memberIndexerExpression = SqlMemberIndexerScalarExpression.Create(memberExpression, indexExpression); + return memberIndexerExpression; + } + } + + private static SqlScalarExpression[] VisitExpressionList(ReadOnlyCollection inputExpressionList, TranslationContext context) + { + SqlScalarExpression[] result = new SqlScalarExpression[inputExpressionList.Count]; + for (int i = 0; i < inputExpressionList.Count; i++) + { + SqlScalarExpression p = ExpressionToSql.VisitScalarExpression(inputExpressionList[i], context); + result[i] = p; + } + + return result; + } + + private static SqlObjectProperty VisitMemberAssignment(MemberAssignment inputExpression, TranslationContext context) + { + SqlScalarExpression assign = ExpressionToSql.VisitScalarExpression(inputExpression.Expression, context); + string memberName = inputExpression.Member.GetMemberName(context); + SqlPropertyName propName = SqlPropertyName.Create(memberName); + SqlObjectProperty prop = SqlObjectProperty.Create(propName, assign); + return prop; + } + + private static SqlObjectProperty VisitMemberMemberBinding(MemberMemberBinding inputExpression, TranslationContext context) + { + throw new DocumentQueryException(ClientResources.MemberBindingNotSupported); + } + + private static SqlObjectProperty VisitMemberListBinding(MemberListBinding inputExpression, TranslationContext context) + { + throw new DocumentQueryException(ClientResources.MemberBindingNotSupported); + } + + private static SqlObjectProperty[] VisitBindingList(ReadOnlyCollection inputExpressionList, TranslationContext context) + { + SqlObjectProperty[] list = new SqlObjectProperty[inputExpressionList.Count]; + for (int i = 0; i < inputExpressionList.Count; i++) + { + SqlObjectProperty b = ExpressionToSql.VisitBinding(inputExpressionList[i], context); + list[i] = b; + } + + return list; + } + + private static SqlObjectProperty[] CreateInitializers(ReadOnlyCollection arguments, ReadOnlyCollection members, TranslationContext context) + { + if (arguments.Count != members.Count) + { + throw new InvalidOperationException("Expected same number of arguments as members"); + } + + SqlObjectProperty[] result = new SqlObjectProperty[arguments.Count]; + for (int i = 0; i < arguments.Count; i++) + { + Expression arg = arguments[i]; + MemberInfo member = members[i]; + SqlScalarExpression value = ExpressionToSql.VisitScalarExpression(arg, context); + + string memberName = member.GetMemberName(context); + SqlPropertyName propName = SqlPropertyName.Create(memberName); + SqlObjectProperty prop = SqlObjectProperty.Create(propName, value); + result[i] = prop; + } + + return result; + } + + private static SqlSelectItem[] CreateSelectItems(ReadOnlyCollection arguments, ReadOnlyCollection members, TranslationContext context) + { + if (arguments.Count != members.Count) + { + throw new InvalidOperationException("Expected same number of arguments as members"); + } + + SqlSelectItem[] result = new SqlSelectItem[arguments.Count]; + for (int i = 0; i < arguments.Count; i++) + { + Expression arg = arguments[i]; + MemberInfo member = members[i]; + SqlScalarExpression selectExpression = ExpressionToSql.VisitScalarExpression(arg, context); + + string memberName = member.GetMemberName(context); + SqlIdentifier alias = SqlIdentifier.Create(memberName); + SqlSelectItem prop = SqlSelectItem.Create(selectExpression, alias); + result[i] = prop; + } + + return result; + } + + private static SqlScalarExpression VisitNew(NewExpression inputExpression, TranslationContext context) + { + if (typeof(Geometry).IsAssignableFrom(inputExpression.Type)) + { + return GeometrySqlExpressionFactory.Construct(inputExpression); + } + + if (inputExpression.Arguments.Count > 0) + { + if (inputExpression.Members == null) + { + throw new DocumentQueryException(ClientResources.ConstructorInvocationNotSupported); + } + + SqlObjectProperty[] propertyBindings = ExpressionToSql.CreateInitializers(inputExpression.Arguments, inputExpression.Members, context); + SqlObjectCreateScalarExpression create = SqlObjectCreateScalarExpression.Create(propertyBindings); + return create; + } + else + { + // no need to return anything; the initializer will generate the complete code + return null; + } + } + + private static SqlScalarExpression VisitMemberInit(MemberInitExpression inputExpression, TranslationContext context) + { + ExpressionToSql.VisitNew(inputExpression.NewExpression, context); // Return value is ignored + SqlObjectProperty[] propertyBindings = ExpressionToSql.VisitBindingList(inputExpression.Bindings, context); + SqlObjectCreateScalarExpression create = SqlObjectCreateScalarExpression.Create(propertyBindings); + return create; + } + + private static SqlScalarExpression VisitListInit(ListInitExpression inputExpression, TranslationContext context) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); + } + + private static SqlScalarExpression VisitNewArray(NewArrayExpression inputExpression, TranslationContext context) + { + SqlScalarExpression[] exprs = ExpressionToSql.VisitExpressionList(inputExpression.Expressions, context); + if (inputExpression.NodeType == ExpressionType.NewArrayInit) + { + SqlArrayCreateScalarExpression array = SqlArrayCreateScalarExpression.Create(exprs); + return array; + } + else + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); + } + } + + private static SqlScalarExpression VisitInvocation(InvocationExpression inputExpression, TranslationContext context) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, inputExpression.NodeType)); + } + + #endregion VISITOR + + #region Scalar and CollectionScalar Visitors + + private static Collection ConvertToCollection(SqlScalarExpression scalar) + { + if (usePropertyRef) + { + SqlPropertyRefScalarExpression propertyRefExpression = scalar as SqlPropertyRefScalarExpression; + if (propertyRefExpression == null) + { + throw new DocumentQueryException(ClientResources.PathExpressionsOnly); + } + + SqlInputPathCollection path = ConvertPropertyRefToPath(propertyRefExpression); + Collection result = new Collection(path); + return result; + } + else + { + SqlMemberIndexerScalarExpression memberIndexerExpression = scalar as SqlMemberIndexerScalarExpression; + if (memberIndexerExpression == null) + { + SqlPropertyRefScalarExpression propertyRefExpression = scalar as SqlPropertyRefScalarExpression; + if (propertyRefExpression == null) + { + throw new DocumentQueryException(ClientResources.PathExpressionsOnly); + } + + SqlInputPathCollection path = ConvertPropertyRefToPath(propertyRefExpression); + Collection result = new Collection(path); + return result; + } + else + { + SqlInputPathCollection path = ConvertMemberIndexerToPath(memberIndexerExpression); + Collection result = new Collection(path); + return result; + } + } + } + + /// + /// Convert the context's current query to a scalar Any collection + /// by wrapping it as following: SELECT VALUE COUNT(v0) > 0 FROM (current query) AS v0. + /// This is used in cases where LINQ expression ends with Any() which is a boolean scalar. + /// Normally Any would translate to SELECT VALUE EXISTS() subquery. However that wouldn't work + /// for these cases because it would result in a boolean value for each row instead of + /// one single "aggregated" boolean value. + /// + /// The translation context + /// The scalar Any collection + private static Collection ConvertToScalarAnyCollection(TranslationContext context) + { + SqlQuery query = context.CurrentQuery.FlattenAsPossible().GetSqlQuery(); + SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); + + ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + Binding binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); + + context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc()); + context.CurrentQuery.AddBinding(binding); + + SqlSelectSpec selectSpec = SqlSelectValueSpec.Create( + SqlBinaryScalarExpression.Create( + SqlBinaryScalarOperatorKind.GreaterThan, + SqlFunctionCallScalarExpression.CreateBuiltin( + SqlFunctionCallScalarExpression.Names.Count, + SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterExpression.Name))), + SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create(0)))); + SqlSelectClause selectClause = SqlSelectClause.Create(selectSpec); + context.CurrentQuery.AddSelectClause(selectClause); + + return new Collection(LinqMethods.Any); + } + + private static SqlScalarExpression VisitNonSubqueryScalarExpression(Expression expression, ReadOnlyCollection parameters, TranslationContext context) + { + foreach (ParameterExpression par in parameters) + { + context.PushParameter(par, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); + } + + SqlScalarExpression scalarExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(expression, context); + + foreach (ParameterExpression par in parameters) + { + context.PopParameter(); + } + + return scalarExpression; + } + + private static SqlScalarExpression VisitNonSubqueryScalarLambda(LambdaExpression lambdaExpression, TranslationContext context) + { + ReadOnlyCollection parameters = lambdaExpression.Parameters; + if (parameters.Count != 1) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, lambdaExpression.Body, 1, parameters.Count)); + } + + return ExpressionToSql.VisitNonSubqueryScalarExpression(lambdaExpression.Body, parameters, context); + } + + private static Collection VisitCollectionExpression(Expression expression, ReadOnlyCollection parameters, TranslationContext context) + { + foreach (ParameterExpression par in parameters) + { + context.PushParameter(par, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); + } + + Collection collection = ExpressionToSql.VisitCollectionExpression(expression, context, parameters.Count > 0 ? parameters.First().Name : ExpressionToSql.DefaultParameterName); + + foreach (ParameterExpression par in parameters) + { + context.PopParameter(); + } + + return collection; + } + + private static Collection VisitCollectionExpression(Expression expression, TranslationContext context, string parameterName) + { + Collection result; + switch (expression.NodeType) + { + case ExpressionType.Call: + result = ExpressionToSql.Translate(expression, context); + break; + + case ExpressionType.MemberAccess: + result = ExpressionToSql.VisitMemberAccessCollectionExpression(expression, context, parameterName); + break; + + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, expression.NodeType)); + } + + return result; + } + + /// + /// Visit a lambda which is supposed to return a collection. + /// + /// LambdaExpression with a result which is a collection. + /// The translation context. + /// The collection computed by the lambda. + private static Collection VisitCollectionLambda(LambdaExpression lambdaExpression, TranslationContext context) + { + ReadOnlyCollection parameters = lambdaExpression.Parameters; + if (parameters.Count != 1) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, lambdaExpression.Body, 1, parameters.Count)); + } + + return ExpressionToSql.VisitCollectionExpression(lambdaExpression.Body, lambdaExpression.Parameters, context); + } + + /// + /// Visit an expression, usually a MemberAccess, then trigger parameter binding for that expression. + /// + /// The input expression + /// The current translation context + /// Parameter name is merely for readability + private static Collection VisitMemberAccessCollectionExpression(Expression inputExpression, TranslationContext context, string parameterName) + { + SqlScalarExpression body = ExpressionToSql.VisitNonSubqueryScalarExpression(inputExpression, context); + Type type = inputExpression.Type; + + Collection collection = ExpressionToSql.ConvertToCollection(body); + context.PushCollection(collection); + ParameterExpression parameter = context.GenerateFreshParameter(type, parameterName); + context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); + context.PopParameter(); + context.PopCollection(); + + return new Collection(parameter.Name); + } + + /// + /// Visit a method call, construct the corresponding query in context.CurrentQuery. + /// At ExpressionToSql point only LINQ method calls are allowed. + /// These methods are static extension methods of IQueryable or IEnumerable. + /// + /// Method to translate. + /// Query translation context. + private static Collection VisitMethodCall(MethodCallExpression inputExpression, TranslationContext context) + { + context.PushMethod(inputExpression); + + Type declaringType = inputExpression.Method.DeclaringType; + + if ((declaringType != typeof(Queryable) + && declaringType != typeof(Enumerable) /*LINQ Methods*/ + && declaringType != typeof(CosmosLinqExtensions) /*OrderByRank*/) + || !inputExpression.Method.IsStatic /*Other extansion method*/) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.OnlyLINQMethodsAreSupported, inputExpression.Method.Name)); + } + + Type returnType = inputExpression.Method.ReturnType; + Type returnElementType = TypeSystem.GetElementType(returnType); + + if (inputExpression.Object != null) + { + throw new DocumentQueryException(ClientResources.ExpectedMethodCallsMethods); + } + + Expression inputCollection = inputExpression.Arguments[0]; // all these methods are static extension methods, so argument[0] is the collection + + Type inputElementType = TypeSystem.GetElementType(inputCollection.Type); + Collection collection = ExpressionToSql.Translate(inputCollection, context); + + context.PushCollection(collection); + + Collection result = new Collection(inputExpression.Method.Name); + bool shouldBeOnNewQuery = context.CurrentQuery.ShouldBeOnNewQuery(inputExpression.Method.Name, inputExpression.Arguments.Count); + context.PushSubqueryBinding(shouldBeOnNewQuery); + + if (context.LastExpressionIsGroupBy) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, "Group By cannot be followed by other methods")); + } + + switch (inputExpression.Method.Name) + { + case LinqMethods.Any: + { + result = new Collection(string.Empty); + + if (inputExpression.Arguments.Count == 2) + { + // Any is translated to an SELECT VALUE EXISTS() where Any operation itself is treated as a Where. + SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); + } + break; + } + case LinqMethods.Average: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Avg); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.Count: + { + SqlSelectClause select = ExpressionToSql.VisitCount(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.Distinct: + { + SqlSelectClause select = ExpressionToSql.VisitDistinct(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.FirstOrDefault: + { + if (inputExpression.Arguments.Count == 1) + { + // TOP is not allowed when OFFSET ... LIMIT is present. + if (!context.CurrentQuery.HasOffsetSpec()) + { + SqlNumberLiteral sqlNumberLiteral = SqlNumberLiteral.Create(1); + SqlTopSpec topSpec = SqlTopSpec.Create(sqlNumberLiteral); + context.CurrentQuery = context.CurrentQuery.AddTopSpec(topSpec); + } + + context.SetClientOperation(ScalarOperationKind.FirstOrDefault); + } + else + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, inputExpression.Method.Name, 0, inputExpression.Arguments.Count - 1)); + } + + break; + } + case LinqMethods.Max: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Max); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.Min: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Min); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.GroupBy: + { + context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); + result = ExpressionToSql.VisitGroupBy(returnElementType, inputExpression.Arguments, context); + context.LastExpressionIsGroupBy = true; + break; + } + case LinqMethods.OrderBy: + { + SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); + context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); + break; + } + case LinqMethods.OrderByDescending: + { + SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); + context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); + break; + } + case nameof(CosmosLinqExtensions.OrderByRank): + { + SqlOrderByClause orderBy = ExpressionToSql.VisitOrderByRank(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); + break; + } + case LinqMethods.Select: + { + SqlSelectClause select = ExpressionToSql.VisitSelect(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.SelectMany: + { + context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); + result = ExpressionToSql.VisitSelectMany(inputExpression.Arguments, context); + break; + } + case LinqMethods.Skip: + { + SqlOffsetSpec offsetSpec = ExpressionToSql.VisitSkip(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddOffsetSpec(offsetSpec, context); + break; + } + case LinqMethods.Sum: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Sum); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.Take: + { + if (context.CurrentQuery.HasOffsetSpec()) + { + SqlLimitSpec limitSpec = ExpressionToSql.VisitTakeLimit(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddLimitSpec(limitSpec, context); + } + else + { + SqlTopSpec topSpec = ExpressionToSql.VisitTakeTop(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddTopSpec(topSpec); + } + break; + } + case LinqMethods.ThenBy: + { + SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); + context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); + break; + } + case LinqMethods.ThenByDescending: + { + SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); + context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); + break; + } + case LinqMethods.Where: + { + SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); + break; + } + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, inputExpression.Method.Name)); + } + + context.PopSubqueryBinding(); + context.PopCollection(); + context.PopMethod(); + return result; + } + + /// + /// Visit a method call, construct the corresponding query and return the select clause for the aggregate function. + /// At ExpressionToSql point only LINQ method calls are allowed. + /// These methods are static extension methods of IQueryable or IEnumerable. + /// + /// Method to translate. + /// Query translation context. + private static SqlSelectClause VisitGroupByAggregateMethodCall(MethodCallExpression inputExpression, TranslationContext context) + { + context.PushMethod(inputExpression); + + Type declaringType = inputExpression.Method.DeclaringType; + if ((declaringType != typeof(Queryable) && declaringType != typeof(Enumerable)) + || !inputExpression.Method.IsStatic) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.OnlyLINQMethodsAreSupported, inputExpression.Method.Name)); + } + + if (inputExpression.Object != null) + { + throw new DocumentQueryException(ClientResources.ExpectedMethodCallsMethods); + } + + if (context.LastExpressionIsGroupBy) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, "Group By cannot be followed by other methods")); + } + + SqlSelectClause select; + switch (inputExpression.Method.Name) + { + case LinqMethods.Average: + { + select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Avg); + break; + } + case LinqMethods.Count: + { + select = ExpressionToSql.VisitCount(inputExpression.Arguments, context); + break; + } + case LinqMethods.Max: + { + select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Max); + break; + } + case LinqMethods.Min: + { + select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Min); + break; + } + case LinqMethods.Sum: + { + select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Sum); + break; + } + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, inputExpression.Method.Name)); + } + + context.PopMethod(); + return select; + } + + /// + /// Determine if an expression should be translated to a subquery. + /// This only applies to expression that is inside a lamda. + /// + /// The input expression + /// The expression object kind of the expression + /// True if the method is either Min, Max, or Avg + /// True if subquery is needed, otherwise false + private static bool IsSubqueryScalarExpression(Expression expression, out SubqueryKind? expressionObjKind, out bool isMinMaxAvgMethod) + { + if (!(expression is MethodCallExpression methodCallExpression)) + { + expressionObjKind = null; + isMinMaxAvgMethod = false; + return false; + } + + string methodName = methodCallExpression.Method.Name; + bool isSubqueryExpression; + + isMinMaxAvgMethod = false; + + switch (methodName) + { + case LinqMethods.Min: + case LinqMethods.Max: + case LinqMethods.Average: + isMinMaxAvgMethod = true; + isSubqueryExpression = true; + expressionObjKind = SubqueryKind.SubqueryScalarExpression; + break; + + case LinqMethods.Sum: + isSubqueryExpression = true; + expressionObjKind = SubqueryKind.SubqueryScalarExpression; + break; + + case LinqMethods.Count: + if (methodCallExpression.Arguments.Count > 1) + { + isSubqueryExpression = true; + expressionObjKind = SubqueryKind.SubqueryScalarExpression; + } + else + { + SubqueryKind? objKind; + bool isMinMaxAvg; + isSubqueryExpression = ExpressionToSql.IsSubqueryScalarExpression( + methodCallExpression.Arguments[0] as MethodCallExpression, + out objKind, out isMinMaxAvg); + + if (isSubqueryExpression) + { + isSubqueryExpression = true; + expressionObjKind = SubqueryKind.SubqueryScalarExpression; + } + else + { + isSubqueryExpression = false; + expressionObjKind = null; + } + } + break; + + case LinqMethods.Any: + isSubqueryExpression = true; + expressionObjKind = SubqueryKind.ExistsScalarExpression; + break; + + case LinqMethods.Select: + case LinqMethods.SelectMany: + case LinqMethods.Where: + case LinqMethods.OrderBy: + case LinqMethods.OrderByDescending: + case nameof(CosmosLinqExtensions.OrderByRank): + case LinqMethods.ThenBy: + case LinqMethods.ThenByDescending: + case LinqMethods.Skip: + case LinqMethods.Take: + case LinqMethods.Distinct: + case LinqMethods.GroupBy: + isSubqueryExpression = true; + expressionObjKind = SubqueryKind.ArrayScalarExpression; + break; + + default: + isSubqueryExpression = false; + expressionObjKind = null; + break; + } + + return isSubqueryExpression; + } + + /// + /// Visit an lambda expression which is in side a lambda and translate it to a scalar expression or a subquery scalar expression. + /// See the other overload of this method for more details. + /// + /// The input lambda expression + /// The translation context + /// A scalar expression representing the input expression + private static SqlScalarExpression VisitScalarExpression(LambdaExpression lambda, TranslationContext context) + { + return ExpressionToSql.VisitScalarExpression( + lambda.Body, + lambda.Parameters, + context); + } + + /// + /// Visit an lambda expression which is inside a lambda and translate it to a scalar expression or a collection scalar expression. + /// If it is a collection scalar expression, e.g. should be translated to subquery such as SELECT VALUE ARRAY, SELECT VALUE EXISTS, + /// SELECT VALUE [aggregate], the subquery will be aliased to a new binding for the FROM clause. E.g. consider + /// Select(family => family.Children.Select(child => child.Grade)). Since the inner Select corresponds to a subquery, this method would + /// create a new binding of v0 to the subquery SELECT VALUE ARRAY(), and the inner expression will be just SELECT v0. + /// + /// The input expression + /// The translation context + /// A scalar expression representing the input expression + internal static SqlScalarExpression VisitScalarExpression(Expression expression, TranslationContext context) + { + return ExpressionToSql.VisitScalarExpression( + expression, + new ReadOnlyCollection(Array.Empty()), + context); + } + + internal static bool TryGetSqlNumberLiteral(object value, out SqlNumberLiteral sqlNumberLiteral) + { + sqlNumberLiteral = default(SqlNumberLiteral); + if (value is byte byteValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(byteValue); + } + else if (value is sbyte sbyteValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(sbyteValue); + } + else if (value is decimal decimalValue) + { + if ((decimalValue >= long.MinValue) && (decimalValue <= long.MaxValue) && (decimalValue % 1 == 0)) + { + sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToInt64(decimalValue)); + } + else + { + sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToDouble(decimalValue)); + } + } + else if (value is double doubleValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(doubleValue); + } + else if (value is float floatVlaue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(floatVlaue); + } + else if (value is int intValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(intValue); + } + else if (value is uint uintValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(uintValue); + } + else if (value is long longValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(longValue); + } + else if (value is ulong ulongValue) + { + if (ulongValue <= long.MaxValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToInt64(ulongValue)); + } + else + { + sqlNumberLiteral = SqlNumberLiteral.Create(Convert.ToDouble(ulongValue)); + } + } + else if (value is short shortValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(shortValue); + } + else if (value is ushort ushortValue) + { + sqlNumberLiteral = SqlNumberLiteral.Create(ushortValue); + } + + return sqlNumberLiteral != default(SqlNumberLiteral); + } + + /// + /// Visit an lambda expression which is in side a lambda and translate it to a scalar expression or a collection scalar expression. + /// See the other overload of this method for more details. + /// + private static SqlScalarExpression VisitScalarExpression(Expression expression, + ReadOnlyCollection parameters, + TranslationContext context) + { + SubqueryKind? expressionObjKind; + bool isMinMaxAvgMethod; + bool shouldUseSubquery = ExpressionToSql.IsSubqueryScalarExpression(expression, out expressionObjKind, out isMinMaxAvgMethod); + + SqlScalarExpression sqlScalarExpression; + if (!shouldUseSubquery) + { + sqlScalarExpression = ExpressionToSql.VisitNonSubqueryScalarExpression(expression, parameters, context); + } + else + { + SqlQuery query = ExpressionToSql.CreateSubquery(expression, parameters, context); + + ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + SqlCollection subqueryCollection = ExpressionToSql.CreateSubquerySqlCollection( + query, + isMinMaxAvgMethod ? SubqueryKind.ArrayScalarExpression : expressionObjKind.Value); + + Binding newBinding = new Binding(parameterExpression, subqueryCollection, + isInCollection: false, isInputParameter: context.IsInMainBranchSelect()); + + context.CurrentSubqueryBinding.NewBindings.Add(newBinding); + + if (isMinMaxAvgMethod) + { + sqlScalarExpression = SqlMemberIndexerScalarExpression.Create( + SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterExpression.Name)), + SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create(0))); + } + else + { + sqlScalarExpression = SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterExpression.Name)); + } + } + + return sqlScalarExpression; + } + + /// + /// Create a subquery SQL collection object for a SQL query + /// + /// The SQL query object + /// The subquery type + private static SqlCollection CreateSubquerySqlCollection(SqlQuery query, SubqueryKind subqueryType) + { + SqlCollection subqueryCollection; + switch (subqueryType) + { + case SubqueryKind.ArrayScalarExpression: + SqlArrayScalarExpression arrayScalarExpression = SqlArrayScalarExpression.Create(query); + query = SqlQuery.Create( + SqlSelectClause.Create(SqlSelectValueSpec.Create(arrayScalarExpression)), + fromClause: null, whereClause: null, groupByClause: null, orderByClause: null, offsetLimitClause: null); + break; + + case SubqueryKind.ExistsScalarExpression: + SqlExistsScalarExpression existsScalarExpression = SqlExistsScalarExpression.Create(query); + query = SqlQuery.Create( + SqlSelectClause.Create(SqlSelectValueSpec.Create(existsScalarExpression)), + fromClause: null, whereClause: null, groupByClause: null, orderByClause: null, offsetLimitClause: null); + break; + + case SubqueryKind.SubqueryScalarExpression: + // No need to wrap query as in ArrayScalarExpression, or ExistsScalarExpression + break; + + default: + throw new DocumentQueryException($"Unsupported subquery type {subqueryType}"); + } + + subqueryCollection = SqlSubqueryCollection.Create(query); + return subqueryCollection; + } + + /// + /// Create a subquery from a subquery scalar expression. + /// By visiting the collection expression, this builds a new QueryUnderConstruction on top of the current one + /// and then translate it to a SQL query while keeping the current QueryUnderConstruction in tact. + /// + /// The subquery scalar expression + /// The list of parameters of the expression + /// The translation context + /// A query corresponding to the collection expression + /// The QueryUnderConstruction remains unchanged after this. + private static SqlQuery CreateSubquery(Expression expression, ReadOnlyCollection parameters, TranslationContext context) + { + bool shouldBeOnNewQuery = context.CurrentSubqueryBinding.ShouldBeOnNewQuery; + + QueryUnderConstruction queryBeforeVisit = context.CurrentQuery; + QueryUnderConstruction packagedQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc(), context.CurrentQuery); + packagedQuery.FromParameters.SetInputParameter(typeof(object), context.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope); + context.CurrentQuery = packagedQuery; + + if (shouldBeOnNewQuery) context.CurrentSubqueryBinding.ShouldBeOnNewQuery = false; + + Collection collection = ExpressionToSql.VisitCollectionExpression(expression, parameters, context); + + QueryUnderConstruction subquery = context.CurrentQuery.GetSubquery(queryBeforeVisit); + context.CurrentSubqueryBinding.ShouldBeOnNewQuery = shouldBeOnNewQuery; + context.CurrentQuery = queryBeforeVisit; + + SqlQuery sqlSubquery = subquery.FlattenAsPossible().GetSqlQuery(); + return sqlSubquery; + } + + #endregion Scalar and CollectionScalar Visitors + + #region LINQ Specific Visitors + + private static SqlWhereClause VisitWhere(ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Where, 2, arguments.Count)); + } + + LambdaExpression function = Utilities.GetLambda(arguments[1]); + SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(function, context); + SqlWhereClause where = SqlWhereClause.Create(sqlfunc); + return where; + } + + private static SqlSelectClause VisitSelect(ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Select, 2, arguments.Count)); + } + + LambdaExpression lambda = Utilities.GetLambda(arguments[1]); + + SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(lambda, context); + SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(sqlfunc); + SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null); + return select; + } + + private static Collection VisitSelectMany(ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.SelectMany, 2, arguments.Count)); + } + + LambdaExpression lambda = Utilities.GetLambda(arguments[1]); + + // If there is Distinct, Take or OrderBy the lambda then it needs to be in a subquery. + bool requireLocalExecution = false; + + for (MethodCallExpression methodCall = lambda.Body as MethodCallExpression; + methodCall != null; + methodCall = methodCall.Arguments[0] as MethodCallExpression) + { + string methodName = methodCall.Method.Name; + requireLocalExecution |= methodName.Equals(LinqMethods.Distinct) || methodName.Equals(LinqMethods.Take) || methodName.Equals(LinqMethods.OrderBy) || methodName.Equals(LinqMethods.OrderByDescending); + } + + Collection collection; + if (!requireLocalExecution) + { + collection = ExpressionToSql.VisitCollectionLambda(lambda, context); + } + else + { + collection = new Collection(string.Empty); + Binding binding; + SqlQuery query = ExpressionToSql.CreateSubquery(lambda.Body, lambda.Parameters, context); + SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query); + ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true); + context.CurrentQuery.FromParameters.Add(binding); + } + + return collection; + } + + private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 3) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.GroupBy, 3, arguments.Count)); + } + + // Key Selector handling + // First argument is input, second is key selector and third is value selector LambdaExpression keySelectorLambda = Utilities.GetLambda(arguments[1]); Collection collection = new Collection("Group By"); @@ -1794,7 +1804,7 @@ private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollectio SqlGroupByClause groupby; ParameterExpression parameterExpression; - switch (keySelectorLambda.Body.NodeType) + switch (keySelectorLambda.Body.NodeType) { case ExpressionType.Parameter: case ExpressionType.Call: @@ -1862,10 +1872,10 @@ private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollectio // Value Selector Handingling // Translate the body of the value selector lambda - Expression valueSelectorExpression = Utilities.GetLambda(arguments[2]).Body; - - // The value selector function needs to be either a MethodCall or an AnonymousType - switch (valueSelectorExpression.NodeType) + Expression valueSelectorExpression = Utilities.GetLambda(arguments[2]).Body; + + // The value selector function needs to be either a MethodCall or an AnonymousType + switch (valueSelectorExpression.NodeType) { case ExpressionType.MemberAccess: { @@ -1888,7 +1898,7 @@ private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollectio } } break; - } + } case ExpressionType.Constant: { ConstantExpression constantExpression = (ConstantExpression)valueSelectorExpression; @@ -1916,7 +1926,7 @@ private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollectio SqlSelectClause select = ExpressionToSql.VisitGroupByAggregateMethodCall(methodCallExpression, context); context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; - } + } case ExpressionType.New: { // Add select item clause at the end of this method @@ -1991,19 +2001,19 @@ private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollectio throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, arg.NodeType)); } } - - SqlSelectListSpec sqlSpec = SqlSelectListSpec.Create(selectItems); - SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null); + + SqlSelectListSpec sqlSpec = SqlSelectListSpec.Create(selectItems); + SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null); context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - + break; - } - default: - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, valueSelectorExpression.NodeType)); + } + default: + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, valueSelectorExpression.NodeType)); } // Pop the correct number of items off the parameter stack - switch (keySelectorLambda.Body.NodeType) + switch (keySelectorLambda.Body.NodeType) { case ExpressionType.Parameter: case ExpressionType.Call: @@ -2028,352 +2038,366 @@ private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollectio break; } - return collection; - } - - private static SqlOrderByClause VisitOrderBy(ReadOnlyCollection arguments, bool isDescending, TranslationContext context) - { - if (arguments.Count != 2) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.OrderBy, 2, arguments.Count)); - } - - LambdaExpression lambda = Utilities.GetLambda(arguments[1]); - SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(lambda, context); - SqlOrderByItem orderByItem = SqlOrderByItem.Create(sqlfunc, isDescending); - SqlOrderByClause orderby = SqlOrderByClause.Create(new SqlOrderByItem[] { orderByItem }); - return orderby; - } - - private static bool TryGetTopSkipTakeLiteral( - SqlScalarExpression scalarExpression, - TranslationContext context, - out SqlNumberLiteral literal) - { - literal = default(SqlNumberLiteral); - - if (scalarExpression is SqlLiteralScalarExpression literalScalarExpression) - { - if (literalScalarExpression.Literal is SqlNumberLiteral numberLiteral) - { - // After a member access in SelectMany's lambda, if there is only Top/Skip/Take then - // it is necessary to trigger the binding because Skip is just a spec with no binding on its own. - // This can be done by pushing and popping a temporary parameter. E.g. In SelectMany(f => f.Children.Skip(1)), - // it's necessary to consider Skip as Skip(x => x, 1) to bind x to f.Children. Similarly for Top and Limit. - ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); - context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); - context.PopParameter(); - - literal = numberLiteral; - } - } - - return (literal != default(SqlNumberLiteral)) && (literal.Value >= 0); - } - - private static bool TryGetTopSkipTakeParameter( - SqlScalarExpression scalarExpression, - TranslationContext context, - out SqlParameter sqlParameter) - { - sqlParameter = default(SqlParameter); - SqlParameterRefScalarExpression parameterRefScalarExpression = scalarExpression as SqlParameterRefScalarExpression; - if (parameterRefScalarExpression != null) - { - sqlParameter = parameterRefScalarExpression.Parameter; - } - - return (sqlParameter != default(SqlParameter)) && !string.IsNullOrEmpty(sqlParameter.Name); - } - - private static SqlOffsetSpec VisitSkip(ReadOnlyCollection arguments, TranslationContext context) - { - if (arguments.Count != 2) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Skip, 2, arguments.Count)); - } - - Expression expression = arguments[1]; - if (expression == null) - { - throw new ArgumentNullException(nameof(expression)); - } - - SqlScalarExpression scalarExpression = ExpressionToSql.VisitScalarExpression(expression, context); - SqlNumberLiteral offsetNumberLiteral; - SqlParameter sqlParameter; - SqlOffsetSpec offsetSpec; - - // skipExpression must be number literal - if (TryGetTopSkipTakeLiteral(scalarExpression, context, out offsetNumberLiteral)) - { - offsetSpec = SqlOffsetSpec.Create(offsetNumberLiteral); - } - else if (TryGetTopSkipTakeParameter(scalarExpression, context, out sqlParameter)) - { - offsetSpec = SqlOffsetSpec.Create(sqlParameter); - } - else - { - // .Skip() has only one overload that takes int - // so we really should always get a number (integer) literal here - // the below throw serves as assert - throw new ArgumentException(ClientResources.InvalidSkipValue); - } - - return offsetSpec; - } - - private static SqlLimitSpec VisitTakeLimit(ReadOnlyCollection arguments, TranslationContext context) - { - if (arguments.Count != 2) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Take, 2, arguments.Count)); - } - - Expression expression = arguments[1]; - if (expression == null) - { - throw new ArgumentNullException(nameof(expression)); - } - - SqlScalarExpression scalarExpression = ExpressionToSql.VisitScalarExpression(expression, context); - SqlNumberLiteral takeNumberLiteral; - SqlParameter sqlParameter; - SqlLimitSpec limitSpec; - - // takeExpression must be number literal - if (TryGetTopSkipTakeLiteral(scalarExpression, context, out takeNumberLiteral)) - { - limitSpec = SqlLimitSpec.Create(takeNumberLiteral); - } - else if (TryGetTopSkipTakeParameter(scalarExpression, context, out sqlParameter)) - { - limitSpec = SqlLimitSpec.Create(sqlParameter); - } - else - { - // .Take() has only one overload that takes int - // so we really should always get a number (integer) literal here - // the below throw serves as assert - throw new ArgumentException(ClientResources.InvalidTakeValue); - } - - return limitSpec; - } - - private static SqlTopSpec VisitTakeTop(ReadOnlyCollection arguments, TranslationContext context) - { - if (arguments.Count != 2) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Take, 2, arguments.Count)); - } - - Expression expression = arguments[1]; - if (expression == null) - { - throw new ArgumentNullException(nameof(expression)); - } - - SqlScalarExpression scalarExpression = ExpressionToSql.VisitScalarExpression(expression, context); - SqlNumberLiteral takeNumberLiteral; - SqlParameter sqlParameter; - SqlTopSpec topSpec; - - // takeExpression must be number literal - if (TryGetTopSkipTakeLiteral(scalarExpression, context, out takeNumberLiteral)) - { - topSpec = SqlTopSpec.Create(takeNumberLiteral); - } - else if (TryGetTopSkipTakeParameter(scalarExpression, context, out sqlParameter)) - { - topSpec = SqlTopSpec.Create(sqlParameter); - } - else - { - // .Take() has only one overload that takes int - // so we really should always get a number (integer) literal here - // the below throw serves as assert - throw new ArgumentException(ClientResources.InvalidTakeValue); - } - - return topSpec; - } - - private static SqlSelectClause VisitAggregateFunction( - ReadOnlyCollection arguments, - TranslationContext context, - string aggregateFunctionName) - { - SqlScalarExpression aggregateExpression; - if (arguments.Count == 1) - { - // Need to trigger parameter binding for cases where an aggregate function immediately follows a member access. - ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); - context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); - - // If there is a groupby, since there is no argument to the aggregate, we consider it to be invoked on the source collection, and not the group by keys - aggregateExpression = ExpressionToSql.VisitParameter(parameter, context); - context.PopParameter(); - } - else if (arguments.Count == 2) - { - LambdaExpression lambda = Utilities.GetLambda(arguments[1]); - + return collection; + } + + private static SqlOrderByClause VisitOrderBy(ReadOnlyCollection arguments, bool isDescending, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.OrderBy, 2, arguments.Count)); + } + + LambdaExpression lambda = Utilities.GetLambda(arguments[1]); + SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(lambda, context); + SqlOrderByItem orderByItem = SqlOrderByItem.Create(sqlfunc, isDescending); + SqlOrderByClause orderby = SqlOrderByClause.Create(new SqlOrderByItem[] { orderByItem }); + return orderby; + } + + private static SqlOrderByClause VisitOrderByRank(ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.OrderBy, 2, arguments.Count)); + } + LambdaExpression lambda = Utilities.GetLambda(arguments[1]); + SqlScalarExpression sqlfunc = ExpressionToSql.VisitScalarExpression(lambda, context); + SqlOrderByItem scoreFuncOrderByItem = SqlOrderByItem.Create(sqlfunc, isDescending: null); + SqlOrderByClause orderby = SqlOrderByClause.Create(rank: true, new SqlOrderByItem[] { scoreFuncOrderByItem }); + + return orderby; + } + + private static bool TryGetTopSkipTakeLiteral( + SqlScalarExpression scalarExpression, + TranslationContext context, + out SqlNumberLiteral literal) + { + literal = default(SqlNumberLiteral); + + if (scalarExpression is SqlLiteralScalarExpression literalScalarExpression) + { + if (literalScalarExpression.Literal is SqlNumberLiteral numberLiteral) + { + // After a member access in SelectMany's lambda, if there is only Top/Skip/Take then + // it is necessary to trigger the binding because Skip is just a spec with no binding on its own. + // This can be done by pushing and popping a temporary parameter. E.g. In SelectMany(f => f.Children.Skip(1)), + // it's necessary to consider Skip as Skip(x => x, 1) to bind x to f.Children. Similarly for Top and Limit. + ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); + context.PopParameter(); + + literal = numberLiteral; + } + } + + return (literal != default(SqlNumberLiteral)) && (literal.Value >= 0); + } + + private static bool TryGetTopSkipTakeParameter( + SqlScalarExpression scalarExpression, + TranslationContext context, + out SqlParameter sqlParameter) + { + sqlParameter = default(SqlParameter); + SqlParameterRefScalarExpression parameterRefScalarExpression = scalarExpression as SqlParameterRefScalarExpression; + if (parameterRefScalarExpression != null) + { + sqlParameter = parameterRefScalarExpression.Parameter; + } + + return (sqlParameter != default(SqlParameter)) && !string.IsNullOrEmpty(sqlParameter.Name); + } + + private static SqlOffsetSpec VisitSkip(ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Skip, 2, arguments.Count)); + } + + Expression expression = arguments[1]; + if (expression == null) + { + throw new ArgumentNullException(nameof(expression)); + } + + SqlScalarExpression scalarExpression = ExpressionToSql.VisitScalarExpression(expression, context); + SqlNumberLiteral offsetNumberLiteral; + SqlParameter sqlParameter; + SqlOffsetSpec offsetSpec; + + // skipExpression must be number literal + if (TryGetTopSkipTakeLiteral(scalarExpression, context, out offsetNumberLiteral)) + { + offsetSpec = SqlOffsetSpec.Create(offsetNumberLiteral); + } + else if (TryGetTopSkipTakeParameter(scalarExpression, context, out sqlParameter)) + { + offsetSpec = SqlOffsetSpec.Create(sqlParameter); + } + else + { + // .Skip() has only one overload that takes int + // so we really should always get a number (integer) literal here + // the below throw serves as assert + throw new ArgumentException(ClientResources.InvalidSkipValue); + } + + return offsetSpec; + } + + private static SqlLimitSpec VisitTakeLimit(ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Take, 2, arguments.Count)); + } + + Expression expression = arguments[1]; + if (expression == null) + { + throw new ArgumentNullException(nameof(expression)); + } + + SqlScalarExpression scalarExpression = ExpressionToSql.VisitScalarExpression(expression, context); + SqlNumberLiteral takeNumberLiteral; + SqlParameter sqlParameter; + SqlLimitSpec limitSpec; + + // takeExpression must be number literal + if (TryGetTopSkipTakeLiteral(scalarExpression, context, out takeNumberLiteral)) + { + limitSpec = SqlLimitSpec.Create(takeNumberLiteral); + } + else if (TryGetTopSkipTakeParameter(scalarExpression, context, out sqlParameter)) + { + limitSpec = SqlLimitSpec.Create(sqlParameter); + } + else + { + // .Take() has only one overload that takes int + // so we really should always get a number (integer) literal here + // the below throw serves as assert + throw new ArgumentException(ClientResources.InvalidTakeValue); + } + + return limitSpec; + } + + private static SqlTopSpec VisitTakeTop(ReadOnlyCollection arguments, TranslationContext context) + { + if (arguments.Count != 2) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Take, 2, arguments.Count)); + } + + Expression expression = arguments[1]; + if (expression == null) + { + throw new ArgumentNullException(nameof(expression)); + } + + SqlScalarExpression scalarExpression = ExpressionToSql.VisitScalarExpression(expression, context); + SqlNumberLiteral takeNumberLiteral; + SqlParameter sqlParameter; + SqlTopSpec topSpec; + + // takeExpression must be number literal + if (TryGetTopSkipTakeLiteral(scalarExpression, context, out takeNumberLiteral)) + { + topSpec = SqlTopSpec.Create(takeNumberLiteral); + } + else if (TryGetTopSkipTakeParameter(scalarExpression, context, out sqlParameter)) + { + topSpec = SqlTopSpec.Create(sqlParameter); + } + else + { + // .Take() has only one overload that takes int + // so we really should always get a number (integer) literal here + // the below throw serves as assert + throw new ArgumentException(ClientResources.InvalidTakeValue); + } + + return topSpec; + } + + private static SqlSelectClause VisitAggregateFunction( + ReadOnlyCollection arguments, + TranslationContext context, + string aggregateFunctionName) + { + SqlScalarExpression aggregateExpression; + if (arguments.Count == 1) + { + // Need to trigger parameter binding for cases where an aggregate function immediately follows a member access. + ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery); + + // If there is a groupby, since there is no argument to the aggregate, we consider it to be invoked on the source collection, and not the group by keys + aggregateExpression = ExpressionToSql.VisitParameter(parameter, context); + context.PopParameter(); + } + else if (arguments.Count == 2) + { + LambdaExpression lambda = Utilities.GetLambda(arguments[1]); + aggregateExpression = context.CurrentQuery.GroupByParameter != null ? ExpressionToSql.VisitNonSubqueryScalarLambda(lambda, context) - : ExpressionToSql.VisitScalarExpression(lambda, context); - } - else - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, aggregateFunctionName, 2, arguments.Count)); - } - - SqlFunctionCallScalarExpression aggregateFunctionCall; - aggregateFunctionCall = SqlFunctionCallScalarExpression.CreateBuiltin(aggregateFunctionName, aggregateExpression); - - SqlSelectSpec selectSpec = SqlSelectValueSpec.Create(aggregateFunctionCall); - SqlSelectClause selectClause = SqlSelectClause.Create(selectSpec, null); - return selectClause; - } - - private static SqlSelectClause VisitDistinct( - ReadOnlyCollection arguments, - TranslationContext context) - { - string functionName = LinqMethods.Distinct; - if (arguments.Count != 1) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, functionName, 1, arguments.Count)); - } - - // We consider Distinct as Distinct(v0 => v0) - // It's necessary to visit this identity method to replace the parameters names - ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); - LambdaExpression identityLambda = Expression.Lambda(parameter, parameter); - SqlScalarExpression sqlfunc = ExpressionToSql.VisitNonSubqueryScalarLambda(identityLambda, context); - SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(sqlfunc); - SqlSelectClause select = SqlSelectClause.Create(sqlSpec, topSpec: null, hasDistinct: true); - return select; - } - - private static SqlSelectClause VisitCount( - ReadOnlyCollection arguments, - TranslationContext context) - { - SqlScalarExpression countExpression; - countExpression = SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create((Number64)1)); - - if (arguments.Count == 2) - { - SqlWhereClause whereClause = ExpressionToSql.VisitWhere(arguments, context); - context.CurrentQuery = context.CurrentQuery.AddWhereClause(whereClause, context); - } - else if (arguments.Count != 1) - { - throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Count, 2, arguments.Count)); - } - - SqlSelectSpec selectSpec = SqlSelectValueSpec.Create(SqlFunctionCallScalarExpression.CreateBuiltin(SqlFunctionCallScalarExpression.Names.Count, countExpression)); - SqlSelectClause selectClause = SqlSelectClause.Create(selectSpec, null); - return selectClause; - } - - /// - /// Property references that refer to array-valued properties are converted to collection references. - /// - /// Property reference object. - /// An inputPathCollection which contains the same property path as the propRef. - private static SqlInputPathCollection ConvertPropertyRefToPath(SqlPropertyRefScalarExpression propRef) - { - List identifiers = new List(); - while (true) - { - identifiers.Add(propRef.Identifier); - SqlScalarExpression parent = propRef.Member; - if (parent == null) - { - break; - } - - if (parent is SqlPropertyRefScalarExpression) - { - propRef = parent as SqlPropertyRefScalarExpression; - } - else - { - throw new DocumentQueryException(ClientResources.NotSupported); - } - } - - if (identifiers.Count == 0) - { - throw new DocumentQueryException(ClientResources.NotSupported); - } - - SqlPathExpression path = null; - for (int i = identifiers.Count - 2; i >= 0; i--) - { - SqlIdentifier identifer = identifiers[i]; - path = SqlIdentifierPathExpression.Create(path, identifer); - } - - SqlIdentifier last = identifiers[identifiers.Count - 1]; - SqlInputPathCollection result = SqlInputPathCollection.Create(last, path); - return result; - } - - private static SqlInputPathCollection ConvertMemberIndexerToPath(SqlMemberIndexerScalarExpression memberIndexerExpression) - { - // root.Children.Age ==> root["Children"]["Age"] - List literals = new List(); - while (true) - { - literals.Add((SqlStringLiteral)((SqlLiteralScalarExpression)memberIndexerExpression.Indexer).Literal); - SqlScalarExpression parent = memberIndexerExpression.Member; - if (parent == null) - { - break; + : ExpressionToSql.VisitScalarExpression(lambda, context); + } + else + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, aggregateFunctionName, 2, arguments.Count)); + } + + SqlFunctionCallScalarExpression aggregateFunctionCall; + aggregateFunctionCall = SqlFunctionCallScalarExpression.CreateBuiltin(aggregateFunctionName, aggregateExpression); + + SqlSelectSpec selectSpec = SqlSelectValueSpec.Create(aggregateFunctionCall); + SqlSelectClause selectClause = SqlSelectClause.Create(selectSpec, null); + return selectClause; + } + + private static SqlSelectClause VisitDistinct( + ReadOnlyCollection arguments, + TranslationContext context) + { + string functionName = LinqMethods.Distinct; + if (arguments.Count != 1) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, functionName, 1, arguments.Count)); + } + + // We consider Distinct as Distinct(v0 => v0) + // It's necessary to visit this identity method to replace the parameters names + ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName); + LambdaExpression identityLambda = Expression.Lambda(parameter, parameter); + SqlScalarExpression sqlfunc = ExpressionToSql.VisitNonSubqueryScalarLambda(identityLambda, context); + SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(sqlfunc); + SqlSelectClause select = SqlSelectClause.Create(sqlSpec, topSpec: null, hasDistinct: true); + return select; + } + + private static SqlSelectClause VisitCount( + ReadOnlyCollection arguments, + TranslationContext context) + { + SqlScalarExpression countExpression; + countExpression = SqlLiteralScalarExpression.Create(SqlNumberLiteral.Create((Number64)1)); + + if (arguments.Count == 2) + { + SqlWhereClause whereClause = ExpressionToSql.VisitWhere(arguments, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(whereClause, context); + } + else if (arguments.Count != 1) + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.Count, 2, arguments.Count)); + } + + SqlSelectSpec selectSpec = SqlSelectValueSpec.Create(SqlFunctionCallScalarExpression.CreateBuiltin(SqlFunctionCallScalarExpression.Names.Count, countExpression)); + SqlSelectClause selectClause = SqlSelectClause.Create(selectSpec, null); + return selectClause; + } + + /// + /// Property references that refer to array-valued properties are converted to collection references. + /// + /// Property reference object. + /// An inputPathCollection which contains the same property path as the propRef. + private static SqlInputPathCollection ConvertPropertyRefToPath(SqlPropertyRefScalarExpression propRef) + { + List identifiers = new List(); + while (true) + { + identifiers.Add(propRef.Identifier); + SqlScalarExpression parent = propRef.Member; + if (parent == null) + { + break; + } + + if (parent is SqlPropertyRefScalarExpression) + { + propRef = parent as SqlPropertyRefScalarExpression; } - - if (parent is SqlPropertyRefScalarExpression sqlPropertyRefScalarExpression) - { - literals.Add(SqlStringLiteral.Create(sqlPropertyRefScalarExpression.Identifier.Value)); - break; - } - - if (parent is SqlMemberIndexerScalarExpression) - { - memberIndexerExpression = parent as SqlMemberIndexerScalarExpression; - } - else - { - throw new DocumentQueryException(ClientResources.NotSupported); - } - } - - if (literals.Count == 0) - { - throw new ArgumentException("memberIndexerExpression"); - } - - SqlPathExpression path = null; - for (int i = literals.Count - 2; i >= 0; i--) - { - path = SqlStringPathExpression.Create(path, literals[i]); - } - - SqlInputPathCollection result = SqlInputPathCollection.Create(SqlIdentifier.Create(literals[literals.Count - 1].Value), path); - return result; - } - - #endregion LINQ Specific Visitors - - private enum SubqueryKind - { - ArrayScalarExpression, - ExistsScalarExpression, - SubqueryScalarExpression, - } - } -} + else + { + throw new DocumentQueryException(ClientResources.NotSupported); + } + } + + if (identifiers.Count == 0) + { + throw new DocumentQueryException(ClientResources.NotSupported); + } + + SqlPathExpression path = null; + for (int i = identifiers.Count - 2; i >= 0; i--) + { + SqlIdentifier identifer = identifiers[i]; + path = SqlIdentifierPathExpression.Create(path, identifer); + } + + SqlIdentifier last = identifiers[identifiers.Count - 1]; + SqlInputPathCollection result = SqlInputPathCollection.Create(last, path); + return result; + } + + private static SqlInputPathCollection ConvertMemberIndexerToPath(SqlMemberIndexerScalarExpression memberIndexerExpression) + { + // root.Children.Age ==> root["Children"]["Age"] + List literals = new List(); + while (true) + { + literals.Add((SqlStringLiteral)((SqlLiteralScalarExpression)memberIndexerExpression.Indexer).Literal); + SqlScalarExpression parent = memberIndexerExpression.Member; + if (parent == null) + { + break; + } + + if (parent is SqlPropertyRefScalarExpression sqlPropertyRefScalarExpression) + { + literals.Add(SqlStringLiteral.Create(sqlPropertyRefScalarExpression.Identifier.Value)); + break; + } + + if (parent is SqlMemberIndexerScalarExpression) + { + memberIndexerExpression = parent as SqlMemberIndexerScalarExpression; + } + else + { + throw new DocumentQueryException(ClientResources.NotSupported); + } + } + + if (literals.Count == 0) + { + throw new ArgumentException("memberIndexerExpression"); + } + + SqlPathExpression path = null; + for (int i = literals.Count - 2; i >= 0; i--) + { + path = SqlStringPathExpression.Create(path, literals[i]); + } + + SqlInputPathCollection result = SqlInputPathCollection.Create(SqlIdentifier.Create(literals[literals.Count - 1].Value), path); + return result; + } + + #endregion LINQ Specific Visitors + + private enum SubqueryKind + { + ArrayScalarExpression, + ExistsScalarExpression, + SubqueryScalarExpression, + } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs index a5cfb4784a..fbb8d27872 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs @@ -458,7 +458,7 @@ private SqlOrderByClause Substitute(SqlSelectSpec spec, SqlIdentifier inputParam SqlScalarExpression substituted = SqlExpressionManipulation.Substitute(replaced, inputParam, orderByClause.OrderByItems[i].Expression); substitutedItems[i] = SqlOrderByItem.Create(substituted, orderByClause.OrderByItems[i].IsDescending); } - SqlOrderByClause result = SqlOrderByClause.Create(substitutedItems); + SqlOrderByClause result = SqlOrderByClause.Create(orderByClause.Rank, substitutedItems); return result; } @@ -550,17 +550,21 @@ public bool ShouldBeOnNewQuery(string methodName, int argumentCount) case LinqMethods.Where: // Where expression parameter needs to be substituted if necessary so // It is not needed in Select distinct because the Select distinct would have the necessary parameter name adjustment. - case LinqMethods.Any: + case LinqMethods.Any: + case nameof(CosmosLinqExtensions.OrderByRank): case LinqMethods.OrderBy: case LinqMethods.OrderByDescending: case LinqMethods.ThenBy: case LinqMethods.ThenByDescending: case LinqMethods.Distinct: - // New query is needed when there is already a Take or a non-distinct Select + // New query is needed when there is already a Take or a non-distinct Select + // Or when an Order By Rank is added to a query with an Order By clause (and vice versa) shouldPackage = (this.topSpec != null) || (this.offsetSpec != null) || (this.selectClause != null && !this.selectClause.HasDistinct) || - (this.groupByClause != null); + (this.groupByClause != null) || + (this.orderByClause != null && (methodName == nameof(CosmosLinqExtensions.OrderByRank))) || + (this.orderByClause != null && (this.orderByClause.Rank == true) && (methodName == LinqMethods.OrderBy)); break; case LinqMethods.GroupBy: @@ -648,7 +652,7 @@ public QueryUnderConstruction UpdateOrderByClause(SqlOrderByClause thenBy, Trans { List items = new List(context.CurrentQuery.orderByClause.OrderByItems); items.AddRange(thenBy.OrderByItems); - context.CurrentQuery.orderByClause = SqlOrderByClause.Create(items.ToImmutableArray()); + context.CurrentQuery.orderByClause = SqlOrderByClause.Create(context.CurrentQuery.orderByClause.Rank, items.ToImmutableArray()); foreach (Binding binding in context.CurrentSubqueryBinding.TakeBindings()) context.CurrentQuery.AddBinding(binding); diff --git a/Microsoft.Azure.Cosmos/src/SqlObjects/SqlFunctionCallScalarExpression.cs b/Microsoft.Azure.Cosmos/src/SqlObjects/SqlFunctionCallScalarExpression.cs index a4a3af39da..bae3995c8f 100644 --- a/Microsoft.Azure.Cosmos/src/SqlObjects/SqlFunctionCallScalarExpression.cs +++ b/Microsoft.Azure.Cosmos/src/SqlObjects/SqlFunctionCallScalarExpression.cs @@ -85,6 +85,7 @@ sealed class SqlFunctionCallScalarExpression : SqlScalarExpression { Names.Endswith, Identifiers.Endswith }, { Names.Exp, Identifiers.Exp }, { Names.Floor, Identifiers.Floor }, + { Names.FullTextScore, Identifiers.FullTextScore }, { Names.FullTextContains, Identifiers.FullTextContains }, { Names.FullTextContainsAll, Identifiers.FullTextContainsAll }, { Names.FullTextContainsAny, Identifiers.FullTextContainsAny }, @@ -122,6 +123,7 @@ sealed class SqlFunctionCallScalarExpression : SqlScalarExpression { Names.Right, Identifiers.Right }, { Names.Round, Identifiers.Round }, { Names.Rtrim, Identifiers.Rtrim }, + { Names.RRF, Identifiers.RRF }, { Names.Sign, Identifiers.Sign }, { Names.Sin, Identifiers.Sin }, { Names.Sqrt, Identifiers.Sqrt }, @@ -300,6 +302,7 @@ public static class Names public const string FullTextContains = "FullTextContains"; public const string FullTextContainsAll = "FullTextContainsAll"; public const string FullTextContainsAny = "FullTextContainsAny"; + public const string FullTextScore = "FullTextScore"; public const string GetCurrentDateTime = "GetCurrentDateTime"; public const string GetCurrentTicks = "GetCurrentTicks"; public const string GetCurrentTimestamp = "GetCurrentTimestamp"; @@ -345,6 +348,7 @@ public static class Names public const string Reverse = "REVERSE"; public const string Right = "RIGHT"; public const string Round = "ROUND"; + public const string RRF = "RRF"; public const string Rtrim = "RTRIM"; public const string Sign = "SIGN"; public const string Sin = "SIN"; @@ -446,6 +450,7 @@ public static class Identifiers public static readonly SqlIdentifier FullTextContains = SqlIdentifier.Create(Names.FullTextContains); public static readonly SqlIdentifier FullTextContainsAll = SqlIdentifier.Create(Names.FullTextContainsAll); public static readonly SqlIdentifier FullTextContainsAny = SqlIdentifier.Create(Names.FullTextContainsAny); + public static readonly SqlIdentifier FullTextScore = SqlIdentifier.Create(Names.FullTextScore); public static readonly SqlIdentifier GetCurrentDateTime = SqlIdentifier.Create(Names.GetCurrentDateTime); public static readonly SqlIdentifier GetCurrentTicks = SqlIdentifier.Create(Names.GetCurrentTicks); public static readonly SqlIdentifier GetCurrentTimestamp = SqlIdentifier.Create(Names.GetCurrentTimestamp); @@ -491,6 +496,7 @@ public static class Identifiers public static readonly SqlIdentifier Reverse = SqlIdentifier.Create(Names.Reverse); public static readonly SqlIdentifier Right = SqlIdentifier.Create(Names.Right); public static readonly SqlIdentifier Round = SqlIdentifier.Create(Names.Round); + public static readonly SqlIdentifier RRF = SqlIdentifier.Create(Names.RRF); public static readonly SqlIdentifier Rtrim = SqlIdentifier.Create(Names.Rtrim); public static readonly SqlIdentifier Sign = SqlIdentifier.Create(Names.Sign); public static readonly SqlIdentifier Sin = SqlIdentifier.Create(Names.Sin); diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestFullTextScoreOrderByRankFunction.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestFullTextScoreOrderByRankFunction.xml new file mode 100644 index 0000000000..ff3cc7f370 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestFullTextScoreOrderByRankFunction.xml @@ -0,0 +1,118 @@ + + + + + doc.StringField.FullTextScore(new [] {"test1"})).Select(doc => doc.Pk)]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1", "test2", "test3"})).Select(doc => doc.Pk)]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).Select(doc => doc.Pk)]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1", "test2", "test3"})).Select(doc => doc.Pk)]]> + + + + + + + + + + (doc.StringField.FullTextScore(new [] {"test1"}) != null))]]> + + + + + + + + + + (doc.StringField.FullTextScore(new [] {"test1", "test2", "test3"}) != null))]]> + + + + + + + + + + (doc.StringField.FullTextScore(new [] {"test1"}) != null))]]> + + + + + + + + + + (doc.StringField.FullTextScore(new [] {"test1", "test2", "test3"}) != null))]]> + + + + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestOrderByRankFunctionComposeWithOtherFunctions.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestOrderByRankFunctionComposeWithOtherFunctions.xml new file mode 100644 index 0000000000..b0126fc23c --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestOrderByRankFunctionComposeWithOtherFunctions.xml @@ -0,0 +1,233 @@ + + + + + doc.Pk).OrderByRank(doc => doc.FullTextScore(new [] {"test1"}))]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).Select(doc => doc.Pk)]]> + + + + + + + + + + new AnonymousType(stringField = doc.StringField, PartitionKey = doc.Pk)).OrderByRank(doc => doc.stringField.FullTextScore(new [] {"test1"})).Select(doc => doc.PartitionKey)]]> + + + + + + + + + + doc.ArrayField).OrderByRank(doc => doc.ToString().FullTextScore(new [] {"test1"}))]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).SelectMany(doc => doc.ArrayField)]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"}))]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).Skip(1).Select(doc => doc.Pk)]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"}))]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).Take(1).Select(doc => doc.Pk)]]> + + + + + + + + + + doc.StringField, (key, values) => values.Count()).OrderByRank(doc => doc.ToString().FullTextScore(new [] {"test1"}))]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).GroupBy(doc => doc.Pk, (key, values) => values.Count())]]> + + + + + + + + + + doc.NumericField).OrderByRank(doc => doc.StringField.FullTextScore(new [] {"test1"})).Select(doc => doc.Pk)]]> + + + + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).OrderBy(doc => doc.NumericField).Select(doc => doc.Pk)]]> + + + + + + + + + + (doc.NumericField > 0)).OrderByRank(doc => doc.StringField.FullTextScore(new [] {"test1"})).Select(doc => doc.Pk)]]> + + + 0) +ORDER BY RANK FullTextScore(root["StringField"], "test1")]]> + + + + + + + doc.StringField.FullTextScore(new [] {"test1"})).Where(doc => (doc.NumericField > 0)).Select(doc => doc.Pk)]]> + + + 0) +ORDER BY RANK FullTextScore(root["StringField"], "test1")]]> + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestRRFOrderByRankFunction.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestRRFOrderByRankFunction.xml new file mode 100644 index 0000000000..ccd04e4bbb --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestRRFOrderByRankFunction.xml @@ -0,0 +1,73 @@ + + + + + RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField2.FullTextScore(new [] {"test1", "test2", "test3"})})).Select(doc => doc.Pk)]]> + + + + + + + + + + RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField.FullTextScore(new [] {"test1", "text2"}), doc.StringField2.FullTextScore(new [] {"test1", "test2", "test3"})})).Select(doc => doc.Pk)]]> + + + + + + + + + + RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"})})).Select(doc => doc.Pk)]]> + + + + + + + + + + (RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"})}) != null))]]> + + + + + + + + + + (RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField2.FullTextScore(new [] {"test1", "test2", "test3"})}) != null))]]> + + + + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs index 0821dc3014..37d2b62e6a 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs @@ -678,6 +678,8 @@ public class LinqTestInput : BaselineTestInput // Ignore Ordering for AnonymousType object internal bool ignoreOrder; + + internal bool serializeOutput; internal LinqTestInput( string description, @@ -685,14 +687,16 @@ internal LinqTestInput( bool skipVerification = false, bool ignoreOrder = false, string expressionStr = null, - string inputData = null) + string inputData = null, + bool serializeOutput = false) : base(description) { this.Expression = expr ?? throw new ArgumentNullException($"{nameof(expr)} must not be null."); this.skipVerification = skipVerification; this.ignoreOrder = ignoreOrder; this.expressionStr = expressionStr; - this.inputData = inputData; + this.inputData = inputData; + this.serializeOutput = serializeOutput; } public static string FilterInputExpression(string input) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs index 19b131cd73..eb4f4c178f 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs @@ -18,7 +18,9 @@ namespace Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests using System.Linq.Dynamic; using System.Text; using Microsoft.Azure.Cosmos.SDK.EmulatorTests; - using System.Threading.Tasks; + using System.Threading.Tasks; + + using static Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions; [Microsoft.Azure.Cosmos.SDK.EmulatorTests.TestClass] public class LinqTranslationBaselineTests : BaselineTests @@ -344,8 +346,8 @@ public void TestRegexMatchFunction() new LinqTestInput("RegexMatch with 2nd argument invalid string options", b => getQuery(b).Where(doc => doc.StringField.RegexMatch("abcd", "this should error out on the back end"))), }; this.ExecuteTestSuite(inputs); - } - + } + [TestMethod] public void TestFullTextContainsFunction() { @@ -389,6 +391,221 @@ public void TestFullTextContainsFunction() this.ExecuteTestSuite(inputs); } + [TestMethod] + public void TestFullTextScoreOrderByRankFunction() + { + const int Records = 2; + const int MaxStringLength = 100; + static DataObject createDataObj(Random random) + { + DataObject obj = new DataObject + { + StringField = LinqTestsCommon.RandomString(random, random.Next(MaxStringLength)), + Id = Guid.NewGuid().ToString(), + Pk = "Test" + }; + return obj; + } + Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, testContainer); + + List inputs = new List + { + new LinqTestInput("FullTextScore with 1 element array", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore("test1")) + .Select(doc => doc.Pk)), + new LinqTestInput("FullTextScore with 3 element array", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore("test1", "test2", "test3")) + .Select(doc => doc.Pk)), + new LinqTestInput("FullTextScore with 1 element array", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .Select(doc => doc.Pk)), + new LinqTestInput("FullTextScore with 3 element array", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1", "test2", "test3" })) + .Select(doc => doc.Pk)), + + // Negative case: FullTextScore in non order by clause + new LinqTestInput("FullTextScore in WHERE clause", b => getQuery(b) + .Where(doc => doc.StringField.FullTextScore(new string[] { "test1" }) != null)), + new LinqTestInput("FullTextScore in WHERE clause 2", b => getQuery(b) + .Where(doc => doc.StringField.FullTextScore(new string[] { "test1", "test2", "test3" }) != null)), + + new LinqTestInput("FullTextScore in WHERE clause", b => getQuery(b) + .Where(doc => doc.StringField.FullTextScore("test1") != null)), + new LinqTestInput("FullTextScore in WHERE clause 2", b => getQuery(b) + .Where(doc => doc.StringField.FullTextScore("test1", "test2", "test3") != null)), + }; + + foreach (LinqTestInput input in inputs) + { + // OrderBy are not supported client side. + // Therefore this method is verified with baseline only. + input.skipVerification = true; + input.serializeOutput = true; + } + + this.ExecuteTestSuite(inputs); + } + + [TestMethod] + public void TestRRFOrderByRankFunction() + { + const int Records = 2; + const int MaxStringLength = 100; + static DataObject createDataObj(Random random) + { + DataObject obj = new DataObject + { + StringField = LinqTestsCommon.RandomString(random, random.Next(MaxStringLength)), + Id = Guid.NewGuid().ToString(), + Pk = "Test" + }; + return obj; + } + Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, testContainer); + + List inputs = new List + { + new LinqTestInput("RRF with 2 functions", b => getQuery(b) + .OrderByRank(doc => RRF(doc.StringField.FullTextScore(new string[] { "test1" }), doc.StringField2.FullTextScore(new string[] { "test1", "test2", "test3" }))) + .Select(doc => doc.Pk)), + + new LinqTestInput("RRF with 3 functions", b => getQuery(b) + .OrderByRank(doc => RRF(doc.StringField.FullTextScore(new string[] { "test1" }), + doc.StringField.FullTextScore(new string[] { "test1", "text2" }), + doc.StringField2.FullTextScore(new string[] { "test1", "test2", "test3" }))) + .Select(doc => doc.Pk)), + + // Negative case: FullTextScore in non order by clause + new LinqTestInput("RRF with 1 function", b => getQuery(b) + .OrderByRank(doc => CosmosLinqExtensions.RRF(doc.StringField.FullTextScore(new string[] { "test1" }))) + .Select(doc => doc.Pk)), + + new LinqTestInput("RRF in WHERE clause", b => getQuery(b) + .Where(doc => RRF(doc.StringField.FullTextScore(new string[] { "test1" })) != null)), + new LinqTestInput("RRF in WHERE clause 2", b => getQuery(b) + .Where(doc => RRF(doc.StringField.FullTextScore(new string[] { "test1" }), + doc.StringField2.FullTextScore(new string[] { "test1", "test2", "test3" })) != null)), + }; + + foreach (LinqTestInput input in inputs) + { + // OrderBy are not supported client side. + // Therefore this method is verified with baseline only. + input.skipVerification = true; + input.serializeOutput = true; + } + + this.ExecuteTestSuite(inputs); + } + + [TestMethod] + public void TestOrderByRankFunctionComposeWithOtherFunctions() + { + const int Records = 2; + const int MaxStringLength = 100; + static DataObject createDataObj(Random random) + { + DataObject obj = new DataObject + { + StringField = LinqTestsCommon.RandomString(random, random.Next(MaxStringLength)), + NumericField = 1, + Id = Guid.NewGuid().ToString(), + Pk = "Test" + }; + return obj; + } + Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, testContainer); + + List inputs = new List + { + // Select + new LinqTestInput("Select + Order By Rank", b => getQuery(b) + .Select(doc => doc.Pk) + .OrderByRank(doc => doc.FullTextScore(new string[] { "test1" }))), + + new LinqTestInput("Order By Rank + Select", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .Select(doc => doc.Pk)), + + new LinqTestInput("Select + Order By Rank + Select", b => getQuery(b) + .Select(doc => new { stringField = doc.StringField, PartitionKey = doc.Pk }) + .OrderByRank(doc => doc.stringField.FullTextScore(new string[] { "test1" })) + .Select(doc => doc.PartitionKey)), + + // Join + new LinqTestInput("SelectMany + Order By Rank", b => getQuery(b) + .SelectMany(doc => doc.ArrayField) + .OrderByRank(doc => doc.ToString().FullTextScore(new string[] { "test1" }))), + + new LinqTestInput("Order By Rank + SelectMany", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .SelectMany(doc => doc.ArrayField)), + + // Skip + new LinqTestInput("Skip + Order By Rank", b => getQuery(b) + .Skip(1) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" }))), + + new LinqTestInput("Order By Rank + Skip", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .Skip(1) + .Select(doc => doc.Pk)), + + // Take + new LinqTestInput("Take + Order By Rank", b => getQuery(b) + .Take(1) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" }))), + + new LinqTestInput("Order By Rank + Take", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .Take(1) + .Select(doc => doc.Pk)), + + // GroupBy + new LinqTestInput("GroupBy + Order By Rank", b => getQuery(b) + .GroupBy(doc => doc.StringField, (key, values) => values.Count()) + .OrderByRank(doc => doc.ToString().FullTextScore(new string[] { "test1" }))), + + new LinqTestInput("Order By Rank + GroupBy", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .GroupBy(doc => doc.Pk, (key, values) => values.Count())), + + // Order By + // $issue-todo-leminh-20250424: There's an issue with OrderBy follows by OrderBy - the most recent order by is the only one that is applied to the query. + // This will be addressed in a separate PR + new LinqTestInput("Order By + Order By Rank", b => getQuery(b) + .OrderBy(doc => doc.NumericField) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .Select(doc => doc.Pk)), + + new LinqTestInput("Order By Rank + Order By", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .OrderBy(doc => doc.NumericField) + .Select(doc => doc.Pk)), + + // Where + new LinqTestInput("Where + Order By Rank", b => getQuery(b) + .Where(doc => doc.NumericField > 0) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .Select(doc => doc.Pk)), + + new LinqTestInput("Order By Rank + Where", b => getQuery(b) + .OrderByRank(doc => doc.StringField.FullTextScore(new string[] { "test1" })) + .Where(doc => doc.NumericField > 0) + .Select(doc => doc.Pk)), + }; + + foreach (LinqTestInput input in inputs) + { + // OrderBy are not supported client side. + // Therefore this method is verified with baseline only. + input.skipVerification = true; + input.serializeOutput = true; + } + + this.ExecuteTestSuite(inputs); + } + [TestMethod] public void TestMemberInitializer() { @@ -1421,7 +1638,7 @@ public void TestSelectManyWithFilters() public override LinqTestOutput ExecuteTest(LinqTestInput input) { - return LinqTestsCommon.ExecuteTest(input); + return LinqTestsCommon.ExecuteTest(input, input.serializeOutput); } } } \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj index a334d6bdf2..44ee444d5b 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj @@ -241,6 +241,15 @@ PreserveNewest + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + PreserveNewest diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.json b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.json index e8c9c34eec..08870f891f 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.json +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.json @@ -6372,6 +6372,25 @@ ], "MethodInfo": "Microsoft.Azure.Cosmos.QueryDefinition ToQueryDefinition[T](System.Linq.IQueryable`1[T]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" }, + "System.Func`2[TSource,System.Object] FullTextScore[TSource](TSource, System.String[])[System.Runtime.CompilerServices.ExtensionAttribute()]": { + "Type": "Method", + "Attributes": [ + "ExtensionAttribute" + ], + "MethodInfo": "System.Func`2[TSource,System.Object] FullTextScore[TSource](TSource, System.String[]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" + }, + "System.Func`2[TSource,System.Object] RRF[TSource](System.Func`2[TSource,System.Object][])": { + "Type": "Method", + "Attributes": [], + "MethodInfo": "System.Func`2[TSource,System.Object] RRF[TSource](System.Func`2[TSource,System.Object][]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" + }, + "System.Linq.IOrderedQueryable`1[TSource] OrderByRank[TSource](System.Linq.IQueryable`1[TSource], System.Linq.Expressions.Expression`1[System.Func`2[TSource,System.Object]])[System.Runtime.CompilerServices.ExtensionAttribute()]": { + "Type": "Method", + "Attributes": [ + "ExtensionAttribute" + ], + "MethodInfo": "System.Linq.IOrderedQueryable`1[TSource] OrderByRank[TSource](System.Linq.IQueryable`1[TSource], System.Linq.Expressions.Expression`1[System.Func`2[TSource,System.Object]]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" + }, "System.Threading.Tasks.Task`1[Microsoft.Azure.Cosmos.Response`1[System.Decimal]] AverageAsync(System.Linq.IQueryable`1[System.Decimal], System.Threading.CancellationToken)[System.Runtime.CompilerServices.ExtensionAttribute()]": { "Type": "Method", "Attributes": [