From 8311dd32b84a65fdec1e31d62d8b84e60a2dea43 Mon Sep 17 00:00:00 2001 From: Avinash Jain Date: Sat, 13 Jul 2024 12:27:34 +0530 Subject: [PATCH] Add array_contains_all function Scalar function that takes two arrays as input and checks if all elements of the right array are present in the left array --- .../src/main/sphinx/functions/array.rst | 6 +- ...uiltInTypeAndFunctionNamespaceManager.java | 2 + .../scalar/ArrayContainsAllFunction.java | 52 +++++++ .../presto/type/TestArrayContainsAll.java | 129 ++++++++++++++++++ 4 files changed, 188 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContainsAllFunction.java create mode 100644 presto-main/src/test/java/com/facebook/presto/type/TestArrayContainsAll.java diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index 05dee6a99447b..04ccc703dd5c4 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -40,6 +40,10 @@ Array Functions Returns the average of all non-null elements of the ``array``. If there is no non-null elements, returns ``null``. +.. function:: array_contains_all(x, y) -> boolean + + Returns true if the array ``x`` contains all the elements of array ``y`` including ``nulls``. + .. function:: array_cum_sum(array(T)) -> array(T) Returns the array whose elements are the cumulative sum of the input array, i.e. result[i] = input[1]+input[2]+...+input[i]. @@ -78,7 +82,7 @@ Array Functions .. function:: array_has_duplicates(array(T)) -> boolean Returns a boolean: whether ``array`` has any elements that occur more than once. - Throws an exception if any of the elements are rows or arrays that contain nulls. + Throws an exception if any of the elements are rows or arrays that contain nulls.
SELECT array_has_duplicates(ARRAY[1, 2, null, 1, null, 3]) -- true SELECT array_has_duplicates(ARRAY[ROW(1, null), ROW(1, null)]) -- "map key cannot be null or contain nulls" diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 03dc76e6279c9..dba8fb90f9844 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -113,6 +113,7 @@ import com.facebook.presto.operator.scalar.ArrayCardinalityFunction; import com.facebook.presto.operator.scalar.ArrayCombinationsFunction; import com.facebook.presto.operator.scalar.ArrayContains; +import com.facebook.presto.operator.scalar.ArrayContainsAllFunction; import com.facebook.presto.operator.scalar.ArrayCumSum; import com.facebook.presto.operator.scalar.ArrayDistinctFromOperator; import com.facebook.presto.operator.scalar.ArrayDistinctFunction; @@ -854,6 +855,7 @@ private List getBuiltInFunctions(FeaturesConfig featuresC .scalars(DataSizeFunctions.class) .scalar(ArrayCardinalityFunction.class) .scalar(ArrayContains.class) + .scalar(ArrayContainsAllFunction.class) .scalar(ArrayFilterFunction.class) .scalar(ArrayPositionFunction.class) .scalar(ArrayPositionWithIndexFunction.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContainsAllFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContainsAllFunction.java new file mode 100644 index 0000000000000..01b738fbf02e5 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayContainsAllFunction.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.operator.scalar;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.operator.aggregation.TypedSet;
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.ScalarFunction;
+import com.facebook.presto.spi.function.SqlType;
+import com.facebook.presto.spi.function.TypeParameter;
+
+/**
+ * SQL scalar function {@code array_contains_all(array(T), array(T))} returning a boolean:
+ * whether every element of the second array is present in the first array.
+ */
+@ScalarFunction("array_contains_all")
+@Description("Returns true if all elements of the second array are present in the first array")
+public final class ArrayContainsAllFunction
+{
+    private ArrayContainsAllFunction() {}
+
+    /**
+     * Checks whether every element of {@code secondArray} also occurs in {@code firstArray}.
+     * Only membership is tested, so duplicates in either array do not affect the result.
+     *
+     * Every path returns {@code true} or {@code false}, so the method returns a primitive
+     * and is not marked {@code @SqlNullable}.
+     *
+     * @param elementType element type {@code T} shared by both arrays
+     * @param firstArray elements to search in
+     * @param secondArray elements that must all be present in {@code firstArray}
+     * @return {@code true} if {@code secondArray} is empty or all of its elements are found
+     *         in {@code firstArray}; {@code false} otherwise
+     */
+    @TypeParameter("T")
+    @SqlType(StandardTypes.BOOLEAN)
+    public static boolean containsAll(
+            @TypeParameter("T") Type elementType,
+            @SqlType("array(T)") Block firstArray,
+            @SqlType("array(T)") Block secondArray)
+    {
+        int wantedCount = secondArray.getPositionCount();
+        if (wantedCount == 0) {
+            // Nothing to look up, so skip building the lookup set entirely.
+            return true;
+        }
+
+        int availableCount = firstArray.getPositionCount();
+        TypedSet firstSet = new TypedSet(elementType, availableCount, "arrayContainsAll");
+        for (int i = 0; i < availableCount; i++) {
+            firstSet.add(firstArray, i);
+        }
+
+        for (int i = 0; i < wantedCount; i++) {
+            if (!firstSet.contains(secondArray, i)) {
+                return false;
+            }
+        }
+        return true;
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayContainsAll.java b/presto-main/src/test/java/com/facebook/presto/type/TestArrayContainsAll.java
new file mode 100644
index 0000000000000..041fbcc072b4c
---
/dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/type/TestArrayContainsAll.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.type;
+
+import com.facebook.presto.common.type.BooleanType;
+import com.facebook.presto.operator.scalar.AbstractTestFunctions;
+import org.testng.annotations.Test;
+
+/**
+ * Tests for the {@code array_contains_all} scalar function.
+ */
+public class TestArrayContainsAll
+        extends AbstractTestFunctions
+{
+    @Test
+    public void testOverlappingArrays()
+    {
+        // Every element of the second array, duplicates included, occurs in the first.
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [2])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [2, 3])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [1, 2, 3])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [1, 1, 2, 3])", BooleanType.BOOLEAN, true);
+    }
+
+    @Test
+    public void testDisjointArrays()
+    {
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [4])", BooleanType.BOOLEAN, false);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [4, 5])", BooleanType.BOOLEAN, false);
+    }
+
+    @Test
+    public void testEmptyArrays()
+    {
+        // An empty second array is trivially contained, even in an empty first array.
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [], ARRAY [])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [], ARRAY [1, 2, 3])", BooleanType.BOOLEAN, false);
+    }
+
+    @Test
+    public void testDifferentDataTypes()
+    {
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1.0, 2.0, 3.0], ARRAY [2.0])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1.0, 2.0, 3.0], ARRAY [4.0])", BooleanType.BOOLEAN, false);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY ['a', 'b', 'c'], ARRAY ['b'])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY ['a', 'b', 'c'], ARRAY ['d'])", BooleanType.BOOLEAN, false);
+    }
+
+    @Test
+    public void testNulls()
+    {
+        // NULL elements match NULL elements; a NULL array argument yields NULL.
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [NULL, 2, 3], ARRAY [2])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, NULL, 3], ARRAY [NULL])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1, 2, 3], ARRAY [NULL])", BooleanType.BOOLEAN, false);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [NULL, NULL], ARRAY [NULL])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(NULL, ARRAY [1])", BooleanType.BOOLEAN, null);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY [1], NULL)", BooleanType.BOOLEAN, null);
+        assertFunction("ARRAY_CONTAINS_ALL(NULL, NULL)", BooleanType.BOOLEAN, null);
+    }
+
+    @Test
+    public void testArrayOfArrays()
+    {
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY[ARRAY[1, 2], ARRAY[3, 4]], ARRAY[ARRAY[1, 2]])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY[ARRAY[1, 2], ARRAY[3, 4]], ARRAY[ARRAY[3, 4], ARRAY[5, 6]])", BooleanType.BOOLEAN, false);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY[ARRAY[1, 2], ARRAY[3, 4]], ARRAY[ARRAY[5, 6]])", BooleanType.BOOLEAN, false);
+    }
+
+    @Test
+    public void testArrayOfRows()
+    {
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY[ROW(1, 'a'), ROW(2, 'b')], ARRAY[ROW(1, 'a')])", BooleanType.BOOLEAN, true);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY[ROW(1, 'a'), ROW(2, 'b')], ARRAY[ROW(2, 'b'), ROW(3, 'c')])", BooleanType.BOOLEAN, false);
+        assertFunction("ARRAY_CONTAINS_ALL(ARRAY[ROW(1, 'a'), ROW(2, 'b')], ARRAY[ROW(3, 'c')])", BooleanType.BOOLEAN, false);
+    }
+}