Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public static SqlScalarExpression VisitBuiltinFunctionCall(MethodCallExpression
case nameof(CosmosLinqExtensions.FullTextContainsAll):
case nameof(CosmosLinqExtensions.FullTextContainsAny):
return StringBuiltinFunctions.Visit(methodCallExpression, context);
case nameof(CosmosLinqExtensions.ArrayContainsAll):
case nameof(CosmosLinqExtensions.ArrayContainsAny):
case nameof(CosmosLinqExtensions.DocumentId):
case nameof(CosmosLinqExtensions.RRF):
case nameof(CosmosLinqExtensions.FullTextScore):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace Microsoft.Azure.Cosmos.Linq

internal static class OtherBuiltinSystemFunctions
{
private class RRFVisit : SqlBuiltinFunctionVisitor
private class RrfVisitor : SqlBuiltinFunctionVisitor
{
public RRFVisit()
public RrfVisitor()
: base("RRF",
true,
new List<Type[]>()
Expand Down Expand Up @@ -87,9 +87,9 @@ protected override SqlScalarExpression VisitExplicit(MethodCallExpression method
}
}

private class FullTextScoreVisit : SqlBuiltinFunctionVisitor
private class FullTextScoreVisitor : SqlBuiltinFunctionVisitor
{
public FullTextScoreVisit()
public FullTextScoreVisitor()
: base("FullTextScore",
true,
new List<Type[]>()
Expand Down Expand Up @@ -124,9 +124,9 @@ protected override SqlScalarExpression VisitExplicit(MethodCallExpression method
}
}

private class VectorDistanceVisit : SqlBuiltinFunctionVisitor
private class VectorDistanceVisitor : SqlBuiltinFunctionVisitor
{
public VectorDistanceVisit()
public VectorDistanceVisitor()
: base("VectorDistance",
true,
new List<Type[]>()
Expand Down Expand Up @@ -173,22 +173,84 @@ protected override SqlScalarExpression VisitExplicit(MethodCallExpression method
}
}

private class ArrayContainsAllAnyVisitor : SqlBuiltinFunctionVisitor
{
public ArrayContainsAllAnyVisitor(string sqlName)
: base(sqlName,
true,
new List<Type[]>()
{
new Type[]{typeof(object), typeof(object[])}
})
{
}

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
if (methodCallExpression.Arguments.Count != 2)
{
return null;
Comment thread
adityasa marked this conversation as resolved.
}

List<SqlScalarExpression> arguments = new List<SqlScalarExpression>
{
// First argument: the array to search in
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[0], context)
};

// Unwrap the second argument based on its type
Expression secondArgument = methodCallExpression.Arguments[1];

switch (secondArgument)
{
case NewArrayExpression arrayExpression:
// Unwrap inline array initialization (e.g., new[] { 1, 2, 3 })
foreach (Expression element in arrayExpression.Expressions)
{
arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(element, context));
}
break;

case ConstantExpression constantExpression when constantExpression.Value is Array constantArray:
// Unwrap constant array
foreach (object element in constantArray)
{
Expression constantElementExpression = Expression.Constant(element, element?.GetType() ?? typeof(object));
arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(constantElementExpression, context));
}
break;

default:
return null;
}

return SqlFunctionCallScalarExpression.CreateBuiltin(this.SqlName, arguments.ToArray());
}

protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
return null;
}
}

private static Dictionary<string, BuiltinFunctionVisitor> FunctionsDefinitions { get; set; }

static OtherBuiltinSystemFunctions()
{
FunctionsDefinitions = new Dictionary<string, BuiltinFunctionVisitor>
{
[nameof(CosmosLinqExtensions.ArrayContainsAll)] = new ArrayContainsAllAnyVisitor(sqlName: "ARRAY_CONTAINS_ALL"),
[nameof(CosmosLinqExtensions.ArrayContainsAny)] = new ArrayContainsAllAnyVisitor(sqlName: "ARRAY_CONTAINS_ANY"),
[nameof(CosmosLinqExtensions.DocumentId)] = new SqlBuiltinFunctionVisitor(
sqlName: "DOCUMENTID",
isStatic: true,
argumentLists: new List<Type[]>()
{
new Type[]{typeof(object)},
}),
[nameof(CosmosLinqExtensions.RRF)] = new RRFVisit(),
[nameof(CosmosLinqExtensions.FullTextScore)] = new FullTextScoreVisit(),
[nameof(CosmosLinqExtensions.VectorDistance)] = new VectorDistanceVisit(),
[nameof(CosmosLinqExtensions.FullTextScore)] = new FullTextScoreVisitor(),
[nameof(CosmosLinqExtensions.RRF)] = new RrfVisitor(),
[nameof(CosmosLinqExtensions.VectorDistance)] = new VectorDistanceVisitor(),
};
}

Expand Down
36 changes: 36 additions & 0 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,42 @@ public sealed class VectorDistanceOptions
public int? SearchListSizeMultiplier { get; set; }
}

/// <summary>
/// Returns whether all values are present in the array.
/// </summary>
/// <param name="obj">The array object</param>
/// <param name="values">values to search</param>
/// <returns>Returns true if all values are present; otherwise, false.</returns>
/// <example>
/// <code>
/// <![CDATA[
/// var documentIdQuery = documents.Where(root => root.DocumentId());
/// ]]>
/// </code>
/// </example>
public static bool ArrayContainsAll<T>(this IEnumerable<T> obj, params object[] values)
{
throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented);
}

/// <summary>
/// Returns whether any values are present in the array.
/// </summary>
/// <param name="obj">The array object</param>
/// <param name="values">values to search</param>
/// <returns>Returns true if any values are present; otherwise, false.</returns>
/// <example>
/// <code>
/// <![CDATA[
/// var documentIdQuery = documents.Where(root => root.DocumentId());
/// ]]>
/// </code>
/// </example>
public static bool ArrayContainsAny<T>(this IEnumerable<T> obj, params object[] values)
{
throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented);
}

/// <summary>
/// Returns the integer identifier corresponding to a specific item within a physical partition.
/// This method is to be used in LINQ expressions only and will be evaluated on server.
Expand Down
Loading
Loading