From 1ebdab05fbaeaeaa783e6421bf3eb97475989315 Mon Sep 17 00:00:00 2001 From: Feilong Liu Date: Thu, 8 Sep 2022 13:21:41 -0700 Subject: [PATCH] Add array find UDF Add an UDF find for array, which returns the first array element which matches the predicate. Returns null if no match found. --- .../src/main/sphinx/functions/array.rst | 15 ++ ...uiltInTypeAndFunctionNamespaceManager.java | 4 + .../scalar/ArrayFindFirstFirstFunction.java | 86 +++++++ .../ArrayFindFirstWithOffsetFunction.java | 232 ++++++++++++++++++ .../scalar/TestArrayFindFirstFunction.java | 103 ++++++++ 5 files changed, 440 insertions(+) create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstFirstFunction.java create mode 100644 presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstWithOffsetFunction.java create mode 100644 presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFindFirstFunction.java diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index 2adf8dbb50765..3b91ef71d5572 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -208,6 +208,21 @@ Array Functions Flattens an ``array(array(T))`` to an ``array(T)`` by concatenating the contained arrays. +.. function:: find_first(array(E), function(T,boolean)) -> E + + Returns the first element of ``array`` which returns true for ``function(T,boolean)``. Returns ``NULL`` if no such element exists. + +.. function:: find_first(array(E), index, function(T,boolean)) -> E + + Returns the first element of ``array`` which returns true for ``function(T,boolean)``. Returns ``NULL`` if no such element exists. + If ``index`` > 0, the search for element starts at position ``index`` until the end of array. + If ``index`` < 0, the search for element starts at position ``abs(index)`` counting from last, until the start of array. :: + + SELECT find_first(ARRAY[3, 4, 5, 6], 2, x -> x > 0); -- 4 + SELECT find_first(ARRAY[3, 4, 5, 6], -2, x -> x > 0); -- 5 + SELECT find_first(ARRAY[3, 4, 5, 6], 2, x -> x < 4); -- NULL + SELECT find_first(ARRAY[3, 4, 5, 6], -2, x -> x > 5); -- NULL + .. function:: ngrams(array(T), n) -> array(array(T)) Returns ``n``-grams for the ``array``:: diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 8ba8851c0ba43..e7dd971da47de 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -101,6 +101,8 @@ import com.facebook.presto.operator.scalar.ArrayEqualOperator; import com.facebook.presto.operator.scalar.ArrayExceptFunction; import com.facebook.presto.operator.scalar.ArrayFilterFunction; +import com.facebook.presto.operator.scalar.ArrayFindFirstFirstFunction; +import com.facebook.presto.operator.scalar.ArrayFindFirstWithOffsetFunction; import com.facebook.presto.operator.scalar.ArrayFunctions; import com.facebook.presto.operator.scalar.ArrayGreaterThanOperator; import com.facebook.presto.operator.scalar.ArrayGreaterThanOrEqualOperator; @@ -808,6 +810,8 @@ private List getBuildInFunctions(FeaturesConfig featuresC .scalar(ArrayNgramsFunction.class) .scalar(ArrayAllMatchFunction.class) .scalar(ArrayAnyMatchFunction.class) + .scalar(ArrayFindFirstFirstFunction.class) + .scalar(ArrayFindFirstWithOffsetFunction.class) .scalar(ArrayNoneMatchFunction.class) .scalar(ArrayNormalizeFunction.class) .scalar(MapDistinctFromOperator.class) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstFirstFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstFirstFunction.java new file mode 100644 index 0000000000000..5d83c09c5a898 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstFirstFunction.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; + +@Description("Return the first element which matches the given predicate, null if no match") +@ScalarFunction(value = "find_first", deterministic = true) +public final class ArrayFindFirstFirstFunction + extends ArrayFindFirstWithOffsetFunction +{ + private ArrayFindFirstFirstFunction() {} + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Block findBlock( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BlockToBooleanFunction function) + { + return findBlockUtil(elementType, arrayBlock, 1, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Slice findSlice( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") SliceToBooleanFunction function) + { + return findSliceUtil(elementType, arrayBlock, 1, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Long findLong( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") LongToBooleanFunction function) + { + return findLongUtil(elementType, arrayBlock, 1, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Double findDouble( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") DoubleToBooleanFunction function) + { + return findDoubleUtil(elementType, arrayBlock, 1, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Boolean findBoolean( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BooleanToBooleanFunction function) + { + return findBooleanUtil(elementType, arrayBlock, 1, function); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstWithOffsetFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstWithOffsetFunction.java new file mode 100644 index 0000000000000..fa4a1864ce497 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayFindFirstWithOffsetFunction.java @@ -0,0 +1,232 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; +import io.airlift.slice.Slice; + +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static java.lang.Boolean.TRUE; +import static java.lang.Math.toIntExact; + +@Description("Return the first element which matches the given predicate, null if no match") +@ScalarFunction(value = "find_first", deterministic = true) +public class ArrayFindFirstWithOffsetFunction +{ + protected ArrayFindFirstWithOffsetFunction() {} + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Block findBlockWithOffset( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType(StandardTypes.BIGINT) long offset, + @SqlType("function(T, boolean)") BlockToBooleanFunction function) + { + return findBlockUtil(elementType, arrayBlock, offset, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Slice findSliceWithOffset( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType(StandardTypes.BIGINT) long offset, + @SqlType("function(T, boolean)") SliceToBooleanFunction function) + { + return findSliceUtil(elementType, arrayBlock, offset, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Long findLongWithOffset( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType(StandardTypes.BIGINT) long offset, + @SqlType("function(T, boolean)") LongToBooleanFunction function) + { + return findLongUtil(elementType, arrayBlock, offset, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Double findDoubleWithOffset( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType(StandardTypes.BIGINT) long offset, + @SqlType("function(T, boolean)") DoubleToBooleanFunction function) + { + return findDoubleUtil(elementType, arrayBlock, offset, function); + } + + @TypeParameter("T") + @SqlType("T") + @SqlNullable + public static Boolean findBooleanWithOffset( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType(StandardTypes.BIGINT) long offset, + @SqlType("function(T, boolean)") BooleanToBooleanFunction function) + { + return findBooleanUtil(elementType, arrayBlock, offset, function); + } + + public static Block findBlockUtil( + Type elementType, + Block arrayBlock, + long offset, + BlockToBooleanFunction function) + { + int startPosition = checkedIndexToBlockPosition(arrayBlock, offset); + if (startPosition < 0) { + return null; + } + int increment = offset > 0 ? 1 : -1; + for (int i = startPosition; i < arrayBlock.getPositionCount() && i >= 0; i += increment) { + Block element = null; + if (!arrayBlock.isNull(i)) { + element = (Block) elementType.getObject(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return element; + } + } + return null; + } + + public static Slice findSliceUtil( + Type elementType, + Block arrayBlock, + long offset, + SliceToBooleanFunction function) + { + int startPosition = checkedIndexToBlockPosition(arrayBlock, offset); + if (startPosition < 0) { + return null; + } + int increment = offset > 0 ? 1 : -1; + for (int i = startPosition; i < arrayBlock.getPositionCount() && i >= 0; i += increment) { + Slice element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getSlice(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return element; + } + } + return null; + } + + public static Long findLongUtil( + Type elementType, + Block arrayBlock, + long offset, + LongToBooleanFunction function) + { + int startPosition = checkedIndexToBlockPosition(arrayBlock, offset); + if (startPosition < 0) { + return null; + } + int increment = offset > 0 ? 1 : -1; + for (int i = startPosition; i < arrayBlock.getPositionCount() && i >= 0; i += increment) { + Long element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getLong(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return element; + } + } + return null; + } + + public static Double findDoubleUtil( + Type elementType, + Block arrayBlock, + long offset, + DoubleToBooleanFunction function) + { + int startPosition = checkedIndexToBlockPosition(arrayBlock, offset); + if (startPosition < 0) { + return null; + } + int increment = offset > 0 ? 1 : -1; + for (int i = startPosition; i < arrayBlock.getPositionCount() && i >= 0; i += increment) { + Double element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getDouble(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return element; + } + } + return null; + } + + public static Boolean findBooleanUtil( + Type elementType, + Block arrayBlock, + long offset, + BooleanToBooleanFunction function) + { + int startPosition = checkedIndexToBlockPosition(arrayBlock, offset); + if (startPosition < 0) { + return null; + } + int increment = offset > 0 ? 1 : -1; + for (int i = startPosition; i < arrayBlock.getPositionCount() && i >= 0; i += increment) { + Boolean element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getBoolean(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return element; + } + } + return null; + } + + /** + * @return PrestoException if the index is 0, -1 if the index is out of range (to tell the calling function to return null), and the element position otherwise. + */ + private static int checkedIndexToBlockPosition(Block block, long index) + { + int arrayLength = block.getPositionCount(); + if (index == 0) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "SQL array indices start at 1"); + } + if (Math.abs(index) > arrayLength) { + return -1; // -1 indicates that the element is out of range and "ELEMENT_AT" should return null + } + index = index > 0 ? index - 1 : arrayLength + index; + return toIntExact(index); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFindFirstFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFindFirstFunction.java new file mode 100644 index 0000000000000..30fce9a9a764d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayFindFirstFunction.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.type.ArrayType; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static java.util.Arrays.asList; + +public class TestArrayFindFirstFunction + extends AbstractTestFunctions +{ + @Test + public void testBasic() + { + assertFunction("find_first(ARRAY [5, 6], x -> x = 5)", INTEGER, 5); + assertFunction("find_first(ARRAY [BIGINT '5', BIGINT '6'], x -> x = 5)", BIGINT, 5L); + assertFunction("find_first(ARRAY [5, 6], x -> x > 5)", INTEGER, 6); + assertFunction("find_first(ARRAY [null, false, true, false, true, false], x -> nullif(x, false))", BOOLEAN, true); + assertFunction("find_first(ARRAY [null, true, false, null, true, false, null], x -> not x)", BOOLEAN, false); + assertFunction("find_first(ARRAY [4.8E0, 6.2E0], x -> x > 5)", DOUBLE, 6.2); + assertFunction("find_first(ARRAY ['abc', 'def', 'ayz'], x -> substr(x, 1, 1) = 'a')", createVarcharType(3), "abc"); + assertFunction( + "find_first(ARRAY [ARRAY ['abc', null, '123'], ARRAY ['def', 'x', '456']], x -> x[2] IS NULL)", + new ArrayType(createVarcharType(3)), + asList("abc", null, "123")); + } + + @Test + public void testPositiveOffset() + { + assertFunction("find_first(ARRAY [5, 6], 2, x -> x = 5)", INTEGER, null); + assertFunction("find_first(ARRAY [5, 6], 4, x -> x > 0)", INTEGER, null); + assertFunction("find_first(ARRAY [5, 6, 7, 8], 3, x -> x > 5)", INTEGER, 7); + assertFunction("find_first(ARRAY [3, 4, 5, 6], 2, x -> x > 0)", INTEGER, 4); + assertFunction("find_first(ARRAY [3, 4, 5, 6], 2, x -> x < 4)", INTEGER, null); + assertFunction("find_first(ARRAY [null, false, true, null, true, false], 4, x -> nullif(x, false))", BOOLEAN, true); + assertFunction("find_first(ARRAY [4.8E0, 6.2E0, 7.8E0], 3, x -> x > 5)", DOUBLE, 7.8); + assertFunction("find_first(ARRAY ['abc', 'def', 'ayz'], 2, x -> substr(x, 1, 1) = 'a')", createVarcharType(3), "ayz"); + assertFunction( + "find_first(ARRAY [ARRAY ['abc', null, '123'], ARRAY ['def', null, '456']], 2, x -> x[2] IS NULL)", + new ArrayType(createVarcharType(3)), + asList("def", null, "456")); + } + + @Test + public void testNegativeOffset() + { + assertFunction("find_first(ARRAY [5, 6], -2, x -> x > 5)", INTEGER, null); + assertFunction("find_first(ARRAY [5, 6], -4, x -> x > 0)", INTEGER, null); + assertFunction("find_first(ARRAY [5, 6, 7, 8], -2, x -> x > 5)", INTEGER, 7); + assertFunction("find_first(ARRAY [9, 6, 3, 8], -2, x -> x > 5)", INTEGER, 6); + assertFunction("find_first(ARRAY [3, 4, 5, 6], -2, x -> x > 0)", INTEGER, 5); + assertFunction("find_first(ARRAY [3, 4, 5, 6], -2, x -> x > 5)", INTEGER, null); + assertFunction("find_first(ARRAY [null, false, true, null, true, false], -3, x -> nullif(x, false))", BOOLEAN, true); + assertFunction("find_first(ARRAY [4.8E0, 6.2E0, 7.8E0], -2, x -> x > 5)", DOUBLE, 6.2); + assertFunction("find_first(ARRAY ['abc', 'def', 'ayz'], -2, x -> substr(x, 1, 1) = 'a')", createVarcharType(3), "abc"); + assertFunction( + "find_first(ARRAY [ARRAY ['abc', null, '123'], ARRAY ['def', null, '456']], -2, x -> x[2] IS NULL)", + new ArrayType(createVarcharType(3)), + asList("abc", null, "123")); + } + + @Test + public void testEmpty() + { + assertFunction("find_first(ARRAY [], x -> true)", UNKNOWN, null); + assertFunction("find_first(ARRAY [], x -> false)", UNKNOWN, null); + assertFunction("find_first(CAST (ARRAY [] AS ARRAY(INTEGER)), x -> true)", INTEGER, null); + } + + @Test + public void testNullArray() + { + assertFunction("find_first(ARRAY [NULL], x -> x IS NULL)", UNKNOWN, null); + assertFunction("find_first(ARRAY [NULL], x -> x IS NOT NULL)", UNKNOWN, null); + assertFunction("find_first(ARRAY [CAST (NULL AS INTEGER)], x -> x IS NULL)", INTEGER, null); + } + + @Test + public void testNull() + { + assertFunction("find_first(NULL, x -> true)", UNKNOWN, null); + assertFunction("find_first(NULL, x -> false)", UNKNOWN, null); + } +}