diff --git a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs index 3008b3eac9..d68b1bcce0 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/BuiltinFunctionVisitor.cs @@ -60,6 +60,8 @@ public static SqlScalarExpression VisitBuiltinFunctionCall(MethodCallExpression case nameof(CosmosLinqExtensions.FullTextContainsAll): case nameof(CosmosLinqExtensions.FullTextContainsAny): return StringBuiltinFunctions.Visit(methodCallExpression, context); + case nameof(CosmosLinqExtensions.ArrayContainsAll): + case nameof(CosmosLinqExtensions.ArrayContainsAny): case nameof(CosmosLinqExtensions.DocumentId): case nameof(CosmosLinqExtensions.RRF): case nameof(CosmosLinqExtensions.FullTextScore): diff --git a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs index 1987679df2..21c097be54 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs @@ -17,9 +17,9 @@ namespace Microsoft.Azure.Cosmos.Linq internal static class OtherBuiltinSystemFunctions { - private class RRFVisit : SqlBuiltinFunctionVisitor + private class RrfVisitor : SqlBuiltinFunctionVisitor { - public RRFVisit() + public RrfVisitor() : base("RRF", true, new List() @@ -87,9 +87,9 @@ protected override SqlScalarExpression VisitExplicit(MethodCallExpression method } } - private class FullTextScoreVisit : SqlBuiltinFunctionVisitor + private class FullTextScoreVisitor : SqlBuiltinFunctionVisitor { - public FullTextScoreVisit() + public FullTextScoreVisitor() : base("FullTextScore", true, new List() @@ -124,9 +124,9 @@ protected override SqlScalarExpression VisitExplicit(MethodCallExpression method } } - private class VectorDistanceVisit : SqlBuiltinFunctionVisitor + private class VectorDistanceVisitor : SqlBuiltinFunctionVisitor { - public VectorDistanceVisit() + public VectorDistanceVisitor() : base("VectorDistance", true, new List() @@ -173,12 +173,74 @@ protected override SqlScalarExpression VisitExplicit(MethodCallExpression method } } + private class ArrayContainsAllAnyVisitor : SqlBuiltinFunctionVisitor + { + public ArrayContainsAllAnyVisitor(string sqlName) + : base(sqlName, + true, + new List() + { + new Type[]{typeof(object), typeof(object[])} + }) + { + } + + protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context) + { + if (methodCallExpression.Arguments.Count != 2) + { + return null; + } + + List arguments = new List + { + // First argument: the array to search in + ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[0], context) + }; + + // Unwrap the second argument based on its type + Expression secondArgument = methodCallExpression.Arguments[1]; + + switch (secondArgument) + { + case NewArrayExpression arrayExpression: + // Unwrap inline array initialization (e.g., new[] { 1, 2, 3 }) + foreach (Expression element in arrayExpression.Expressions) + { + arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(element, context)); + } + break; + + case ConstantExpression constantExpression when constantExpression.Value is Array constantArray: + // Unwrap constant array + foreach (object element in constantArray) + { + Expression constantElementExpression = Expression.Constant(element, element?.GetType() ?? typeof(object)); + arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(constantElementExpression, context)); + } + break; + + default: + return null; + } + + return SqlFunctionCallScalarExpression.CreateBuiltin(this.SqlName, arguments.ToArray()); + } + + protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context) + { + return null; + } + } + private static Dictionary FunctionsDefinitions { get; set; } static OtherBuiltinSystemFunctions() { FunctionsDefinitions = new Dictionary { + [nameof(CosmosLinqExtensions.ArrayContainsAll)] = new ArrayContainsAllAnyVisitor(sqlName: "ARRAY_CONTAINS_ALL"), + [nameof(CosmosLinqExtensions.ArrayContainsAny)] = new ArrayContainsAllAnyVisitor(sqlName: "ARRAY_CONTAINS_ANY"), [nameof(CosmosLinqExtensions.DocumentId)] = new SqlBuiltinFunctionVisitor( sqlName: "DOCUMENTID", isStatic: true, @@ -186,9 +248,9 @@ static OtherBuiltinSystemFunctions() { new Type[]{typeof(object)}, }), - [nameof(CosmosLinqExtensions.RRF)] = new RRFVisit(), - [nameof(CosmosLinqExtensions.FullTextScore)] = new FullTextScoreVisit(), - [nameof(CosmosLinqExtensions.VectorDistance)] = new VectorDistanceVisit(), + [nameof(CosmosLinqExtensions.FullTextScore)] = new FullTextScoreVisitor(), + [nameof(CosmosLinqExtensions.RRF)] = new RrfVisitor(), + [nameof(CosmosLinqExtensions.VectorDistance)] = new VectorDistanceVisitor(), }; } diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs index 12732395c0..36e078fc72 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs @@ -48,6 +48,42 @@ public sealed class VectorDistanceOptions public int? SearchListSizeMultiplier { get; set; } } + /// + /// Returns whether all values are present in the array. + /// + /// The array object + /// values to search + /// Returns true if all values are present; otherwise, false. + /// + /// + /// root.DocumentId()); + /// ]]> + /// + /// + public static bool ArrayContainsAll(this IEnumerable obj, params object[] values) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + + /// + /// Returns whether any values are present in the array. + /// + /// The array object + /// values to search + /// Returns true if any values are present; otherwise, false. + /// + /// + /// root.DocumentId()); + /// ]]> + /// + /// + public static bool ArrayContainsAny(this IEnumerable obj, params object[] values) + { + throw new NotImplementedException(ClientResources.TypeCheckExtensionFunctionsNotImplemented); + } + /// /// Returns the integer identifier corresponding to a specific item within a physical partition. /// This method is to be used in LINQ expressions only and will be evaluated on server. diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayContainsAll.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayContainsAll.xml new file mode 100644 index 0000000000..e81ebc419a --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayContainsAll.xml @@ -0,0 +1,895 @@ + + + + + doc.ArrayField.ArrayContainsAll(new [] {})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAll(new [] {Convert(1, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAll(new [] {Convert(1, Object), Convert(2, Object), Convert(3, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAll(new [] {Convert(1, Object), Convert(2, Object), Convert(3, Object), Convert(2147483647, Object), Convert(-2147483648, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAll(new [] {Convert(1, Object), "hello", null, Convert(55, Object), Convert(10.123456789, Object), new [] {"world", new [] {"nested array", new AnonymousType(A = 2147483647, B = -2147483648)}}})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAll(new [] {doc.ArrayField})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + new [] {Convert(1, Object), "hello", null, Convert(55, Object), Convert(10.123456789, Object), new [] {"world", new [] {"nested array", new AnonymousType(A = 2147483647, B = -2147483648)}}}.ArrayContainsAll(new [] {doc.ArrayField})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.EnumerableField.ArrayContainsAll(new [] {Convert(5, Object), Convert(6, Object), Convert(7, Object)})).Select(doc => doc.EnumerableField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAll(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAll(new [] {Convert(1, Object), doc.StringField, null, doc.VectorFloatField, Convert(10.123456789, Object), new [] {"world", new [] {doc.ToString(), new AnonymousType(A = doc.StringField2.StartsWith("abc"), B = -2147483648)}}})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)}.ArrayContainsAll(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})).Select(doc => new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})]]> + + + + + + + + + + new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object), doc.ArrayField}.ArrayContainsAll(new [] {doc.ArrayField})).Select(doc => new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object), doc.ArrayField})]]> + + + + + + + + + + doc.StringField.ArrayContainsAll(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})).Select(doc => doc.StringField)]]> + + + + + + + + + + doc.StringField.ArrayContainsAll(new [] {doc.VectorFloatField})).Select(doc => doc.StringField)]]> + + + + + + + + + + (doc.ArrayField.ArrayContainsAll(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)}) OrElse (doc.StringField.StartsWith("abc") AndAlso doc.EnumerableField.ArrayContainsAll(new [] {doc.ArrayField})))).Select(doc => new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object), doc.StringField, doc.EnumerableField})]]> + + + + + + + + + + new [] {doc.ArrayField, Convert(doc.ArrayField.ArrayContainsAll(new [] {Convert(1, Object), doc.StringField, null, doc.VectorFloatField, Convert(10.123456789, Object), new [] {"world", new [] {doc.ToString(), new AnonymousType(A = doc.StringField2.StartsWith("abc"), B = -2147483648)}}}), Object)})]]> + + + + + + + + + + doc.ArrayField).Select(item => new [] {Convert(item, Object), Convert(new [] {item}.ArrayContainsAll(new [] {item.ToString()}), Object)})]]> + + + + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayContainsAny.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayContainsAny.xml new file mode 100644 index 0000000000..3c6459486e --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayContainsAny.xml @@ -0,0 +1,1391 @@ + + + + + doc.ArrayField.ArrayContainsAny(new [] {})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAny(new [] {Convert(1, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAny(new [] {Convert(1, Object), Convert(2, Object), Convert(3, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAny(new [] {Convert(1, Object), Convert(2, Object), Convert(3, Object), Convert(2147483647, Object), Convert(-2147483648, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAny(new [] {Convert(1, Object), "hello", null, Convert(55, Object), Convert(10.123456789, Object), new [] {"world", new [] {"nested array", new AnonymousType(A = 2147483647, B = -2147483648)}}})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAny(new [] {doc.ArrayField})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + new [] {Convert(1, Object), "hello", null, Convert(55, Object), Convert(10.123456789, Object), new [] {"world", new [] {"nested array", new AnonymousType(A = 2147483647, B = -2147483648)}}}.ArrayContainsAny(new [] {doc.ArrayField})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.EnumerableField.ArrayContainsAny(new [] {Convert(5, Object), Convert(6, Object), Convert(7, Object), Convert(100, Object)})).Select(doc => doc.EnumerableField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAny(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + doc.ArrayField.ArrayContainsAny(new [] {Convert(1, Object), doc.StringField, null, doc.VectorFloatField, Convert(10.123456789, Object), new [] {"world", new [] {doc.ToString(), new AnonymousType(A = doc.StringField2.StartsWith("abc"), B = -2147483648)}}})).Select(doc => doc.ArrayField)]]> + + + + + + + + + + new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)}.ArrayContainsAny(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})).Select(doc => new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})]]> + + + + + + + + + + new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object), doc.ArrayField}.ArrayContainsAny(new [] {doc.ArrayField})).Select(doc => new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object), doc.ArrayField})]]> + + + + + + + + + + doc.StringField.ArrayContainsAny(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)})).Select(doc => doc.StringField)]]> + + + + + + + + + + doc.StringField.ArrayContainsAny(new [] {doc.VectorFloatField})).Select(doc => doc.StringField)]]> + + + + + + + + + + (doc.ArrayField.ArrayContainsAny(new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object)}) OrElse (doc.StringField.StartsWith("abc") AndAlso doc.EnumerableField.ArrayContainsAny(new [] {doc.ArrayField})))).Select(doc => new [] {doc.Id, Convert(doc.IntField, Object), Convert(doc.GuidField, Object), doc.StringField, doc.EnumerableField})]]> + + + + + + + + + + new [] {doc.ArrayField, Convert(doc.ArrayField.ArrayContainsAny(new [] {Convert(1, Object), doc.StringField, null, doc.VectorFloatField, Convert(10.123456789, Object), new [] {"world", new [] {doc.ToString(), new AnonymousType(A = doc.StringField2.StartsWith("abc"), B = -2147483648)}}}), Object)})]]> + + + + + + + + + + doc.ArrayField).Select(item => new [] {Convert(item, Object), Convert(new [] {item}.ArrayContainsAny(new [] {item.ToString()}), Object)})]]> + + + + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayFunctions.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayFunctions.xml index ecf71c89a6..9abb29b5bb 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayFunctions.xml +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqTranslationBaselineTests.TestArrayFunctions.xml @@ -212,4 +212,14 @@ SELECT VALUE ARRAY_LENGTH(root["EnumerableField"]) FROM root]]> + + + + Not(doc.EnumerableField.Contains(1, EqualityComparer`1.Default)))]]> + + + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs index 0cfd83e2af..b1faf73c28 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs @@ -105,6 +105,13 @@ public static bool ObjectEquals(object objA, object objB) return objA.Equals(objB); } + private static Guid GuidFromInt(int value) + { + byte[] bytes = new byte[16]; + BitConverter.GetBytes(value).CopyTo(bytes, 0); + return new Guid(bytes); + } + internal class DataObject : LinqTestObject { public double NumericField; @@ -1379,6 +1386,298 @@ public void TestStringFunctions() this.ExecuteTestSuite(inputs); } + [TestMethod] + public void TestArrayContainsAll() + { + const int Records = 10; + const int MaxAbsValue = 10; + const int MaxArraySize = 5; + + int index = 0; + Func createDataObj = (random) => + { + DataObject obj = new DataObject(); + obj.ArrayField = new int[index + 1]; + for (int i = 0; i < obj.ArrayField.Length; ++i) + { + obj.ArrayField[i] = i; + } + obj.EnumerableField = new List(); + for (int i = 0; i < index + MaxArraySize; ++i) + { + obj.EnumerableField.Add(i - MaxArraySize); + } + obj.NumericField = (index * MaxAbsValue * 2) - MaxAbsValue; + obj.Id = GuidFromInt(index).ToString(); + obj.Pk = "Test"; + index++; + return obj; + }; + Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, testContainer); + + List inputs = new List + { + new LinqTestInput("no parameters", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll()).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with 1 int", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll(1)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with 3 ints", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll(1, 2, 3)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with 5 ints", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll(1, 2, 3, int.MaxValue, int.MinValue )).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with mixed types", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll( + 1, + "hello", + null, + 55f, + 10.123456789d, + new object[] { + "world", + new object[] { "nested array", new { A = int.MaxValue, B = int.MinValue } } })).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("same field", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll(doc.ArrayField)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with mixed types", + b => getQuery(b).Where(doc => new object[] { + 1, + "hello", + null, + 55f, + 10.123456789d, + new object[] { + "world", + new object[] { "nested array", new { A = int.MaxValue, B = int.MinValue } } } + }.ArrayContainsAll(doc.ArrayField)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use enumerable field", + b => getQuery(b).Where(doc => doc.EnumerableField.ArrayContainsAll(5, 6, 7)).Select(doc => doc.EnumerableField), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("use document fields", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll(new object[] { doc.Id, doc.IntField, doc.GuidField })).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with mixed types using document fields", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAll( + 1, + doc.StringField, + null, + doc.VectorFloatField, + 10.123456789d, + new object[] { + "world", + new object[] { doc.ToString(), new { A = doc.StringField2.StartsWith("abc"), B = int.MinValue } } })) + .Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use document fields", + b => getQuery(b).Where(doc => new object[] { doc.Id, doc.IntField, doc.GuidField }.ArrayContainsAll(doc.Id, doc.IntField, doc.GuidField)).Select(doc => new object[] { doc.Id, doc.IntField, doc.GuidField }), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use document fields", + b => getQuery(b).Where(doc => new object[] { doc.Id, doc.IntField, doc.GuidField, doc.ArrayField }.ArrayContainsAll(new object[] { doc.ArrayField })).Select(doc => new object[] { doc.Id, doc.IntField, doc.GuidField, doc.ArrayField }), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("use non-array field", + b => getQuery(b).Where(doc => doc.StringField.ArrayContainsAll(new object[] { doc.Id, doc.IntField, doc.GuidField })).Select(doc => doc.StringField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use non-array field", + b => getQuery(b).Where(doc => doc.StringField.ArrayContainsAll(doc.VectorFloatField)).Select(doc => doc.StringField), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("composite condition", + b => getQuery(b).Where(doc => + doc.ArrayField.ArrayContainsAll(new object[] { doc.Id, doc.IntField, doc.GuidField }) || + doc.StringField.StartsWith("abc") && + doc.EnumerableField.ArrayContainsAll(doc.ArrayField)).Select(doc => new object[] { doc.Id, doc.IntField, doc.GuidField, doc.StringField, doc.EnumerableField }), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("projection", + b => getQuery(b).Select(doc => new object[] { doc.ArrayField, doc.ArrayField.ArrayContainsAll( + 1, + doc.StringField, + null, + doc.VectorFloatField, + 10.123456789d, + new object[] { + "world", + new object[] { doc.ToString(), new { A = doc.StringField2.StartsWith("abc"), B = int.MinValue } } }) }), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("join", + b => getQuery(b).SelectMany(doc => doc.ArrayField).Select(item => new object[] { item, new[] { item }.ArrayContainsAll(item.ToString()) }), + skipVerification: true, + serializeOutput: true), + }; + + this.ExecuteTestSuite(inputs); + } + + [TestMethod] + public void TestArrayContainsAny() + { + const int Records = 10; + const int MaxAbsValue = 10; + const int MaxArraySize = 5; + + int index = 0; + Func createDataObj = (random) => + { + DataObject obj = new DataObject(); + obj.ArrayField = new int[index + 1]; + for (int i = 0; i < obj.ArrayField.Length; ++i) + { + obj.ArrayField[i] = i; + } + obj.EnumerableField = new List(); + for (int i = 0; i < index + MaxArraySize; ++i) + { + obj.EnumerableField.Add(i - MaxArraySize); + } + obj.NumericField = (index * MaxAbsValue * 2) - MaxAbsValue; + obj.Id = GuidFromInt(index).ToString(); + obj.Pk = "Test"; + index++; + return obj; + }; + Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, testContainer); + + List inputs = new List + { + new LinqTestInput("no parameters", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny()).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with 1 int", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny(1)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with 3 ints", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny(1, 2, 3)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with 5 ints", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny(1, 2, 3, int.MaxValue, int.MinValue )).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with mixed types", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny( + 1, + "hello", + null, + 55f, + 10.123456789d, + new object[] { + "world", + new object[] { "nested array", new { A = int.MaxValue, B = int.MinValue } } })).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("same field", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny(doc.ArrayField)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with mixed types", + b => getQuery(b).Where(doc => new object[] { + 1, + "hello", + null, + 55f, + 10.123456789d, + new object[] { + "world", + new object[] { "nested array", new { A = int.MaxValue, B = int.MinValue } } } + }.ArrayContainsAny(doc.ArrayField)).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use enumerable field", + b => getQuery(b).Where(doc => doc.EnumerableField.ArrayContainsAny(5, 6, 7, 100)).Select(doc => doc.EnumerableField), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("use document fields", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny(new object[] { doc.Id, doc.IntField, doc.GuidField })).Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("params with mixed types using document fields", + b => getQuery(b).Where(doc => doc.ArrayField.ArrayContainsAny( + 1, + doc.StringField, + null, + doc.VectorFloatField, + 10.123456789d, + new object[] { + "world", + new object[] { doc.ToString(), new { A = doc.StringField2.StartsWith("abc"), B = int.MinValue } } })) + .Select(doc => doc.ArrayField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use document fields", + b => getQuery(b).Where(doc => new object[] { doc.Id, doc.IntField, doc.GuidField }.ArrayContainsAny(doc.Id, doc.IntField, doc.GuidField)).Select(doc => new object[] { doc.Id, doc.IntField, doc.GuidField }), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use document fields", + b => getQuery(b).Where(doc => new object[] { doc.Id, doc.IntField, doc.GuidField, doc.ArrayField }.ArrayContainsAny(new object[] { doc.ArrayField })).Select(doc => new object[] { doc.Id, doc.IntField, doc.GuidField, doc.ArrayField }), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("use non-array field", + b => getQuery(b).Where(doc => doc.StringField.ArrayContainsAny(new object[] { doc.Id, doc.IntField, doc.GuidField })).Select(doc => doc.StringField), + skipVerification: true, + serializeOutput: true), + new LinqTestInput("use non-array field", + b => getQuery(b).Where(doc => doc.StringField.ArrayContainsAny(doc.VectorFloatField)).Select(doc => doc.StringField), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("composite condition", + b => getQuery(b).Where(doc => + doc.ArrayField.ArrayContainsAny(new object[] { doc.Id, doc.IntField, doc.GuidField }) || + doc.StringField.StartsWith("abc") && + doc.EnumerableField.ArrayContainsAny(doc.ArrayField)).Select(doc => new object[] { doc.Id, doc.IntField, doc.GuidField, doc.StringField, doc.EnumerableField }), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("projection", + b => getQuery(b).Select(doc => new object[] { doc.ArrayField, doc.ArrayField.ArrayContainsAny( + 1, + doc.StringField, + null, + doc.VectorFloatField, + 10.123456789d, + new object[] { + "world", + new object[] { doc.ToString(), new { A = doc.StringField2.StartsWith("abc"), B = int.MinValue } } }) }), + skipVerification: true, + serializeOutput: true), + + new LinqTestInput("join", + b => getQuery(b).SelectMany(doc => doc.ArrayField).Select(item => new object[] { item, new[] { item }.ArrayContainsAny(item.ToString()) }), + skipVerification: true, + serializeOutput: true), + }; + + this.ExecuteTestSuite(inputs); + } + [TestMethod] public void TestArrayFunctions() { @@ -1435,7 +1734,11 @@ public void TestArrayFunctions() new LinqTestInput("Empty list not contains", b => getQuery(b).Select(doc => !emptyList.Contains((int)doc.NumericField))), // Count new LinqTestInput("Count ArrayField", b => getQuery(b).Select(doc => doc.ArrayField.Count())), - new LinqTestInput("Count EnumerableField", b => getQuery(b).Select(doc => doc.EnumerableField.Count())) + new LinqTestInput("Count EnumerableField", b => getQuery(b).Select(doc => doc.EnumerableField.Count())), + + // Unsupported: + // Contains + new LinqTestInput("Contains with EqualityComparer", b => getQuery(b).Select(doc => !doc.EnumerableField.Contains(1, EqualityComparer.Default))), }; this.ExecuteTestSuite(inputs); } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj index fdde935e9f..35f81927d9 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj @@ -39,6 +39,8 @@ + + @@ -113,6 +115,12 @@ PreserveNewest + + PreserveNewest + + + PreserveNewest + PreserveNewest diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.net6.json b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.net6.json index a5478efaf7..8d61d3bd19 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.net6.json +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.net6.json @@ -6294,6 +6294,20 @@ ], "MethodInfo": "Microsoft.Azure.Cosmos.QueryDefinition Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions.ToQueryDefinition[T](System.Linq.IQueryable`1[T]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" }, + "System.Boolean Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions.ArrayContainsAll[T](System.Collections.Generic.IEnumerable`1[T], System.Object[])[System.Runtime.CompilerServices.ExtensionAttribute()]": { + "Type": "Method", + "Attributes": [ + "ExtensionAttribute" + ], + "MethodInfo": "System.Boolean Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions.ArrayContainsAll[T](System.Collections.Generic.IEnumerable`1[T], System.Object[]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" + }, + "System.Boolean Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions.ArrayContainsAny[T](System.Collections.Generic.IEnumerable`1[T], System.Object[])[System.Runtime.CompilerServices.ExtensionAttribute()]": { + "Type": "Method", + "Attributes": [ + "ExtensionAttribute" + ], + "MethodInfo": "System.Boolean Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions.ArrayContainsAny[T](System.Collections.Generic.IEnumerable`1[T], System.Object[]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;" + }, "System.Boolean Microsoft.Azure.Cosmos.Linq.CosmosLinqExtensions.FullTextContains(System.Object, System.String)[System.Runtime.CompilerServices.ExtensionAttribute()]": { "Type": "Method", "Attributes": [