diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index a384d4fea3c0..bbc6e9cb5ec0 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -21,6 +21,20 @@ The ``||`` operator is used to concatenate an array with an array or an element Array Functions --------------- +.. function:: all_match(array(T), function(T,boolean)) -> boolean + + Returns whether all elements of an array match the given predicate. Returns ``true`` if all the elements + match the predicate (a special case is when the array is empty); ``false`` if one or more elements don't + match; ``NULL`` if the predicate function returns ``NULL`` for one or more elements and ``true`` for all + other elements. + +.. function:: any_match(array(T), function(T,boolean)) -> boolean + + Returns whether any elements of an array match the given predicate. Returns ``true`` if one or more + elements match the predicate; ``false`` if none of the elements matches (a special case is when the + array is empty); ``NULL`` if the predicate function returns ``NULL`` for one or more elements and ``false`` + for all other elements. + .. function:: array_distinct(x) -> array Remove duplicate values from the array ``x``. @@ -153,6 +167,12 @@ Array Functions SELECT ngrams(ARRAY['foo', 'bar', 'baz', 'foo'], 5); -- [['foo', 'bar', 'baz', 'foo']] SELECT ngrams(ARRAY[1, 2, 3, 4], 2); -- [[1, 2], [2, 3], [3, 4]] +.. function:: none_match(array(T), function(T,boolean)) -> boolean + + Returns whether no elements of an array match the given predicate. Returns ``true`` if none of the elements + matches the predicate (a special case is when the array is empty); ``false`` if one or more elements match; + ``NULL`` if the predicate function returns ``NULL`` for one or more elements and ``false`` for all other elements. + .. function:: reduce(array(T), initialState S, inputFunction(S,T,S), outputFunction(S,R)) -> R Returns a single value reduced from ``array``. ``inputFunction`` will diff --git a/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java b/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java index cdb60181d75e..5393f14ec534 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java @@ -67,6 +67,8 @@ import io.prestosql.operator.aggregation.arrayagg.ArrayAggregationFunction; import io.prestosql.operator.aggregation.histogram.Histogram; import io.prestosql.operator.aggregation.multimapagg.MultimapAggregationFunction; +import io.prestosql.operator.scalar.ArrayAllMatchFunction; +import io.prestosql.operator.scalar.ArrayAnyMatchFunction; import io.prestosql.operator.scalar.ArrayCardinalityFunction; import io.prestosql.operator.scalar.ArrayCombinationsFunction; import io.prestosql.operator.scalar.ArrayContains; @@ -87,6 +89,7 @@ import io.prestosql.operator.scalar.ArrayMaxFunction; import io.prestosql.operator.scalar.ArrayMinFunction; import io.prestosql.operator.scalar.ArrayNgramsFunction; +import io.prestosql.operator.scalar.ArrayNoneMatchFunction; import io.prestosql.operator.scalar.ArrayNotEqualOperator; import io.prestosql.operator.scalar.ArrayPositionFunction; import io.prestosql.operator.scalar.ArrayRemoveFunction; @@ -563,6 +566,9 @@ public FunctionRegistry(Metadata metadata, FeaturesConfig featuresConfig) .scalar(ArrayIndeterminateOperator.class) .scalar(ArrayCombinationsFunction.class) .scalar(ArrayNgramsFunction.class) + .scalar(ArrayAllMatchFunction.class) + .scalar(ArrayAnyMatchFunction.class) + .scalar(ArrayNoneMatchFunction.class) .scalar(MapDistinctFromOperator.class) .scalar(MapEqualOperator.class) .scalar(MapEntriesFunction.class) diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayAllMatchFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayAllMatchFunction.java new file mode 100644 index 000000000000..06018ea13ba3 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayAllMatchFunction.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.airlift.slice.Slice; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.function.Description; +import io.prestosql.spi.function.ScalarFunction; +import io.prestosql.spi.function.SqlNullable; +import io.prestosql.spi.function.SqlType; +import io.prestosql.spi.function.TypeParameter; +import io.prestosql.spi.function.TypeParameterSpecialization; +import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.Type; + +import static java.lang.Boolean.FALSE; + +@Description("Returns true if all elements of the array match the given predicate") +@ScalarFunction(value = "all_match") +public final class ArrayAllMatchFunction +{ + private ArrayAllMatchFunction() {} + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = Block.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean allMatchBlock( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BlockToBooleanFunction function) + { + boolean hasNullResult = false; + for (int i = 0; i < arrayBlock.getPositionCount(); i++) { + Block element = null; + if (!arrayBlock.isNull(i)) { + element = (Block) elementType.getObject(arrayBlock, i); + } + Boolean match = function.apply(element); + if (FALSE.equals(match)) { + return false; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return true; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = Slice.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean allMatchSlice( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") SliceToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Slice element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getSlice(arrayBlock, i); + } + Boolean match = function.apply(element); + if (FALSE.equals(match)) { + return false; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return true; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean allMatchLong( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") LongToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Long element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getLong(arrayBlock, i); + } + Boolean match = function.apply(element); + if (FALSE.equals(match)) { + return false; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return true; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean allMatchDouble( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") DoubleToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Double element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getDouble(arrayBlock, i); + } + Boolean match = function.apply(element); + if (FALSE.equals(match)) { + return false; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return true; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean allMatchBoolean( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BooleanToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Boolean element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getBoolean(arrayBlock, i); + } + Boolean match = function.apply(element); + if (FALSE.equals(match)) { + return false; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return true; + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayAnyMatchFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayAnyMatchFunction.java new file mode 100644 index 000000000000..1126e0776ace --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayAnyMatchFunction.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.airlift.slice.Slice; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.function.Description; +import io.prestosql.spi.function.ScalarFunction; +import io.prestosql.spi.function.SqlNullable; +import io.prestosql.spi.function.SqlType; +import io.prestosql.spi.function.TypeParameter; +import io.prestosql.spi.function.TypeParameterSpecialization; +import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.Type; + +import static java.lang.Boolean.TRUE; + +@Description("Returns true if the array contains one or more elements that match the given predicate") +@ScalarFunction(value = "any_match") +public final class ArrayAnyMatchFunction +{ + private ArrayAnyMatchFunction() {} + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = Block.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean anyMatchBlock( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BlockToBooleanFunction function) + { + boolean hasNullResult = false; + for (int i = 0; i < arrayBlock.getPositionCount(); i++) { + Block element = null; + if (!arrayBlock.isNull(i)) { + element = (Block) elementType.getObject(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return true; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return false; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = Slice.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean anyMatchSlice( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") SliceToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Slice element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getSlice(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return true; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return false; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean anyMatchLong( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") LongToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Long element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getLong(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return true; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return false; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean anyMatchDouble( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") DoubleToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Double element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getDouble(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return true; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return false; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean anyMatchBoolean( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BooleanToBooleanFunction function) + { + boolean hasNullResult = false; + int positionCount = arrayBlock.getPositionCount(); + for (int i = 0; i < positionCount; i++) { + Boolean element = null; + if (!arrayBlock.isNull(i)) { + element = elementType.getBoolean(arrayBlock, i); + } + Boolean match = function.apply(element); + if (TRUE.equals(match)) { + return true; + } + if (match == null) { + hasNullResult = true; + } + } + if (hasNullResult) { + return null; + } + return false; + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayFilterFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayFilterFunction.java index a524f098908c..535b6c6b0231 100644 --- a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayFilterFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayFilterFunction.java @@ -22,7 +22,6 @@ import io.prestosql.spi.function.TypeParameter; import io.prestosql.spi.function.TypeParameterSpecialization; import io.prestosql.spi.type.Type; -import io.prestosql.sql.gen.lambda.LambdaFunctionInterface; import static java.lang.Boolean.TRUE; @@ -38,7 +37,7 @@ private ArrayFilterFunction() {} public static Block filterLong( @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") FilterLongLambda function) + @SqlType("function(T, boolean)") LongToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); @@ -62,7 +61,7 @@ public static Block filterLong( public static Block filterDouble( @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") FilterDoubleLambda function) + @SqlType("function(T, boolean)") DoubleToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); @@ -86,7 +85,7 @@ public static Block filterDouble( public static Block filterBoolean( @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") FilterBooleanLambda function) + @SqlType("function(T, boolean)") BooleanToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); @@ -110,7 +109,7 @@ public static Block filterBoolean( public static Block filterSlice( @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") FilterSliceLambda function) + @SqlType("function(T, boolean)") SliceToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); @@ -134,7 +133,7 @@ public static Block filterSlice( public static Block filterBlock( @TypeParameter("T") Type elementType, @SqlType("array(T)") Block arrayBlock, - @SqlType("function(T, boolean)") FilterBlockLambda function) + @SqlType("function(T, boolean)") BlockToBooleanFunction function) { int positionCount = arrayBlock.getPositionCount(); BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount); @@ -151,46 +150,4 @@ public static Block filterBlock( } return resultBuilder.build(); } - - @FunctionalInterface - public interface FilterLongLambda - extends LambdaFunctionInterface - { - Boolean apply(Long x); - } - - @FunctionalInterface - public interface FilterDoubleLambda - extends LambdaFunctionInterface - { - Boolean apply(Double x); - } - - @FunctionalInterface - public interface FilterBooleanLambda - extends LambdaFunctionInterface - { - Boolean apply(Boolean x); - } - - @FunctionalInterface - public interface FilterSliceLambda - extends LambdaFunctionInterface - { - Boolean apply(Slice x); - } - - @FunctionalInterface - public interface FilterBlockLambda - extends LambdaFunctionInterface - { - Boolean apply(Block x); - } - - @FunctionalInterface - public interface FilterVoidLambda - extends LambdaFunctionInterface - { - Boolean apply(Void x); - } } diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayNoneMatchFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayNoneMatchFunction.java new file mode 100644 index 000000000000..052e3006575a --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/ArrayNoneMatchFunction.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.airlift.slice.Slice; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.function.Description; +import io.prestosql.spi.function.ScalarFunction; +import io.prestosql.spi.function.SqlNullable; +import io.prestosql.spi.function.SqlType; +import io.prestosql.spi.function.TypeParameter; +import io.prestosql.spi.function.TypeParameterSpecialization; +import io.prestosql.spi.type.StandardTypes; +import io.prestosql.spi.type.Type; + +@Description("Returns true if all elements of the array don't match the given predicate") +@ScalarFunction(value = "none_match") +public final class ArrayNoneMatchFunction +{ + private ArrayNoneMatchFunction() {} + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = Block.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean noneMatchBlock( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BlockToBooleanFunction function) + { + Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchBlock(elementType, arrayBlock, function); + if (anyMatchResult == null) { + return null; + } + return !anyMatchResult; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = Slice.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean noneMatchSlice( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") SliceToBooleanFunction function) + { + Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchSlice(elementType, arrayBlock, function); + if (anyMatchResult == null) { + return null; + } + return !anyMatchResult; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = long.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean noneMatchLong( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") LongToBooleanFunction function) + { + Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchLong(elementType, arrayBlock, function); + if (anyMatchResult == null) { + return null; + } + return !anyMatchResult; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = double.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean noneMatchDouble( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") DoubleToBooleanFunction function) + { + Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchDouble(elementType, arrayBlock, function); + if (anyMatchResult == null) { + return null; + } + return !anyMatchResult; + } + + @TypeParameter("T") + @TypeParameterSpecialization(name = "T", nativeContainerType = boolean.class) + @SqlType(StandardTypes.BOOLEAN) + @SqlNullable + public static Boolean noneMatchBoolean( + @TypeParameter("T") Type elementType, + @SqlType("array(T)") Block arrayBlock, + @SqlType("function(T, boolean)") BooleanToBooleanFunction function) + { + Boolean anyMatchResult = ArrayAnyMatchFunction.anyMatchBoolean(elementType, arrayBlock, function); + if (anyMatchResult == null) { + return null; + } + return !anyMatchResult; + } +} diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/BlockToBooleanFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/BlockToBooleanFunction.java new file mode 100644 index 000000000000..677de8f6b463 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/BlockToBooleanFunction.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.prestosql.spi.block.Block; +import io.prestosql.sql.gen.lambda.LambdaFunctionInterface; + +@FunctionalInterface +public interface BlockToBooleanFunction + extends LambdaFunctionInterface +{ + Boolean apply(Block x); +} diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/BooleanToBooleanFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/BooleanToBooleanFunction.java new file mode 100644 index 000000000000..980013e98a1c --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/BooleanToBooleanFunction.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.prestosql.sql.gen.lambda.LambdaFunctionInterface; + +@FunctionalInterface +public interface BooleanToBooleanFunction + extends LambdaFunctionInterface +{ + Boolean apply(Boolean x); +} diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/DoubleToBooleanFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/DoubleToBooleanFunction.java new file mode 100644 index 000000000000..32f1f046d512 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/DoubleToBooleanFunction.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.prestosql.sql.gen.lambda.LambdaFunctionInterface; + +@FunctionalInterface +public interface DoubleToBooleanFunction + extends LambdaFunctionInterface +{ + Boolean apply(Double x); +} diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/LongToBooleanFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/LongToBooleanFunction.java new file mode 100644 index 000000000000..ce9fad2b4f3c --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/LongToBooleanFunction.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.prestosql.sql.gen.lambda.LambdaFunctionInterface; + +@FunctionalInterface +public interface LongToBooleanFunction + extends LambdaFunctionInterface +{ + Boolean apply(Long x); +} diff --git a/presto-main/src/main/java/io/prestosql/operator/scalar/SliceToBooleanFunction.java b/presto-main/src/main/java/io/prestosql/operator/scalar/SliceToBooleanFunction.java new file mode 100644 index 000000000000..f76a210958eb --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/operator/scalar/SliceToBooleanFunction.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.airlift.slice.Slice; +import io.prestosql.sql.gen.lambda.LambdaFunctionInterface; + +@FunctionalInterface +public interface SliceToBooleanFunction + extends LambdaFunctionInterface +{ + Boolean apply(Slice x); +} diff --git a/presto-main/src/test/java/io/prestosql/operator/scalar/TestArrayMatchFunctions.java b/presto-main/src/test/java/io/prestosql/operator/scalar/TestArrayMatchFunctions.java new file mode 100644 index 000000000000..d21b515dec69 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/operator/scalar/TestArrayMatchFunctions.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.operator.scalar; + +import io.prestosql.spi.type.BooleanType; +import org.testng.annotations.Test; + +public class TestArrayMatchFunctions + extends AbstractTestFunctions +{ + @Test + public void testAllMatch() + { + assertFunction("all_match(ARRAY [5, 7, 9], x -> x % 2 = 1)", BooleanType.BOOLEAN, true); + assertFunction("all_match(ARRAY [true, false, true], x -> x)", BooleanType.BOOLEAN, false); + assertFunction("all_match(ARRAY ['abc', 'ade', 'afg'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, true); + assertFunction("all_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, true); + assertFunction("all_match(ARRAY [true, true, NULL], x -> x)", BooleanType.BOOLEAN, null); + assertFunction("all_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, false); + assertFunction("all_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); + assertFunction("all_match(ARRAY [NULL, NULL, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, true); + assertFunction("all_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 1)", BooleanType.BOOLEAN, true); + } + + @Test + public void testAnyMatch() + { + assertFunction("any_match(ARRAY [5, 8, 10], x -> x % 2 = 1)", BooleanType.BOOLEAN, true); + assertFunction("any_match(ARRAY [false, false, false], x -> x)", BooleanType.BOOLEAN, false); + assertFunction("any_match(ARRAY ['abc', 'def', 'ghi'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, true); + assertFunction("any_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, false); + assertFunction("any_match(ARRAY [false, false, NULL], x -> x)", BooleanType.BOOLEAN, null); + assertFunction("any_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, true); + assertFunction("any_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); + assertFunction("any_match(ARRAY [true, false, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, true); + assertFunction("any_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 4)", BooleanType.BOOLEAN, false); + } + + @Test + public void testNoneMatch() + { + assertFunction("none_match(ARRAY [5, 8, 10], x -> x % 2 = 1)", BooleanType.BOOLEAN, false); + assertFunction("none_match(ARRAY [false, false, false], x -> x)", BooleanType.BOOLEAN, true); + assertFunction("none_match(ARRAY ['abc', 'def', 'ghi'], x -> substr(x, 1, 1) = 'a')", BooleanType.BOOLEAN, false); + assertFunction("none_match(ARRAY [], x -> true)", BooleanType.BOOLEAN, true); + assertFunction("none_match(ARRAY [false, false, NULL], x -> x)", BooleanType.BOOLEAN, null); + assertFunction("none_match(ARRAY [true, false, NULL], x -> x)", BooleanType.BOOLEAN, false); + assertFunction("none_match(ARRAY [NULL, NULL, NULL], x -> x > 1)", BooleanType.BOOLEAN, null); + assertFunction("none_match(ARRAY [true, false, NULL], x -> x IS NULL)", BooleanType.BOOLEAN, false); + assertFunction("none_match(ARRAY [MAP(ARRAY[1,2], ARRAY[3,4]), MAP(ARRAY[1,2,3], ARRAY[3,4,5])], x -> cardinality(x) > 4)", BooleanType.BOOLEAN, true); + } +}