diff --git a/presto-docs/src/main/sphinx/functions/aggregate.rst b/presto-docs/src/main/sphinx/functions/aggregate.rst index 1ae441291dc55..ff407141bd273 100644 --- a/presto-docs/src/main/sphinx/functions/aggregate.rst +++ b/presto-docs/src/main/sphinx/functions/aggregate.rst @@ -184,6 +184,13 @@ General Aggregate Functions Returns an array created from the distinct input ``x`` elements. If the input includes ``NULL``, ``NULL`` will be included in the returned array. + If the input includes arrays with ``NULL`` elements or rows with ``NULL`` fields, they will + be included in the returned array. This function uses ``IS DISTINCT FROM`` to determine + distinctness. :: + + SELECT set_agg(x) FROM (VALUES(1), (2), (null), (2), (null)) t(x) -- ARRAY[1, 2, null] + SELECT set_agg(x) FROM (VALUES(ROW(ROW(1, null))), ROW((ROW(2, 'a'))), ROW((ROW(1, null))), (null)) t(x) -- ARRAY[ROW(1, null), ROW(2, 'a'), null] + .. function:: set_union(array(T)) -> array(T) @@ -191,6 +198,9 @@ General Aggregate Functions When all inputs are ``NULL``, this function returns an empty array. If ``NULL`` is an element of one of the input arrays, ``NULL`` will be included in the returned array. + If the input includes arrays with ``NULL`` elements or rows with ``NULL`` fields, they will + be included in the returned array. This function uses ``IS DISTINCT FROM`` to determine + distinctness. Example:: diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index e7cbe2390d3d6..b8a4261f1c39f 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -50,14 +50,25 @@ Array Functions .. function:: array_distinct(x) -> array Remove duplicate values from the array ``x``. + This function uses ``IS DISTINCT FROM`` to determine the distinct elements. 
:: + + SELECT array_distinct(ARRAY [1, 2, null, null, 2]) -- ARRAY[1, 2, null] + SELECT array_distinct(ARRAY [ROW(1, null), ROW (1, null)]) -- ARRAY[ROW(1, null)] .. function:: array_duplicates(array(T)) -> array(bigint/varchar) Returns a set of elements that occur more than once in ``array``. + Throws an exception if any of the elements are rows or arrays that contain nulls. :: + + SELECT array_duplicates(ARRAY[1, 2, null, 1, null, 3]) -- ARRAY[1, null] + SELECT array_duplicates(ARRAY[ROW(1, null), ROW(1, null)]) -- "map key cannot be null or contain nulls" .. function:: array_except(x, y) -> array Returns an array of elements in ``x`` but not in ``y``, without duplicates. + This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. :: + + SELECT array_except(ARRAY[1, 3, 3, 2, null], ARRAY[1,2, 2, 4]) -- ARRAY[3, null] .. function:: array_frequency(array(E)) -> map(E, int) @@ -67,14 +78,24 @@ Array Functions .. function:: array_has_duplicates(array(T)) -> boolean Returns a boolean: whether ``array`` has any elements that occur more than once. + Throws an exception if any of the elements are rows or arrays that contain nulls. :: + + SELECT array_has_duplicates(ARRAY[1, 2, null, 1, null, 3]) -- true + SELECT array_has_duplicates(ARRAY[ROW(1, null), ROW(1, null)]) -- "map key cannot be null or contain nulls" .. function:: array_intersect(x, y) -> array Returns an array of the elements in the intersection of ``x`` and ``y``, without duplicates. + This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. :: + + SELECT array_intersect(ARRAY[1, 2, 3, 2, null], ARRAY[1,2, 2, 4, null]) -- ARRAY[1, 2, null] .. function:: array_intersect(array(array(E))) -> array(E) Returns an array of the elements in the intersection of all arrays in the given array, without duplicates. + This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. 
:: + + SELECT array_intersect(ARRAY[ARRAY[1, 2, 3, 2, null], ARRAY[1,2,2, 4, null], ARRAY [1, 2, 3, 4, null]]) -- ARRAY[1, 2, null] .. function:: array_join(x, delimiter, null_replacement) -> varchar @@ -82,16 +103,22 @@ Array Functions .. function:: array_least_frequent(array(T)) -> array(T) - Returns the least frequent element of an array. If there are multiple elements with same frequency, the function returns the smallest element. :: + Returns the least frequent non-null element of an array. If there are multiple elements with the same frequency, the function returns the smallest element. + If the array has more than one element and any elements are ``ROWS`` with null fields or ``ARRAYS`` with null elements, an exception is thrown. :: SELECT array_least_frequent(ARRAY[1, 0 , 5]) -- ARRAY[0] + SELECT array_least_frequent(ARRAY[1, null, 1]) -- ARRAY[1] + SELECT array_least_frequent(ARRAY[ROW(1,null), ROW(1, null)]) -- "map key cannot be null or contain nulls" .. function:: array_least_frequent(array(T), n) -> array(T) - Returns ``n`` least frequent elements of an array. The elements are ordered in increasing order of their frequencies. - If two elements have same frequency, smaller elements will appear first. :: + Returns ``n`` least frequent non-null elements of an array. The elements are ordered in increasing order of their frequencies. + If two elements have the same frequency, smaller elements will appear first. + If the array has more than one element and any elements are ``ROWS`` with null fields or ``ARRAYS`` with null elements, an exception is thrown. :: SELECT array_least_frequent(ARRAY[3, 2, 2, 6, 6, 1, 1], 3) -- ARRAY[3, 1, 2] + SELECT array_least_frequent(ARRAY[1, null, 1], 2) -- ARRAY[1] + SELECT array_least_frequent(ARRAY[ROW(1,null), ROW(1, null)], 2) -- "map key cannot be null or contain nulls" .. function:: array_max(x) -> x @@ -139,7 +166,7 @@ Array Functions .. function:: array_sort(x) -> array Sorts and returns the array ``x``. 
The elements of ``x`` must be orderable. - Null elements will be placed at the end of the returned array. + Null elements are placed at the end of the returned array. .. function:: array_sort(array(T), function(T,T,int)) -> array(T) @@ -174,7 +201,7 @@ Array Functions .. function:: array_sort_desc(x) -> array Returns the ``array`` sorted in the descending order. Elements of the ``array`` must be orderable. - Null elements will be placed at the end of the returned array.:: + Null elements are placed at the end of the returned array. :: SELECT array_sort_desc(ARRAY [100, 1, 10, 50]); -- [100, 50, 10, 1] SELECT array_sort_desc(ARRAY [null, 100, null, 1, 10, 50]); -- [100, 50, 10, 1, null, null] @@ -201,10 +228,19 @@ Array Functions Tests if arrays ``x`` and ``y`` have any non-null elements in common. Returns null if there are no non-null elements in common but either array contains null. + Throws a ``NOT_SUPPORTED`` exception on elements of ``ROW`` or ``ARRAY`` type that contain null values. :: + + SELECT arrays_overlap(ARRAY [1, 2, null], ARRAY [2, 3, null]) -- true + SELECT arrays_overlap(ARRAY [1, 2], ARRAY [3, 4]) -- false + SELECT arrays_overlap(ARRAY [1, null], ARRAY[2]) -- null + SELECT arrays_overlap(ARRAY[ROW(1, null)], ARRAY[1, 2]) -- "ROW comparison not supported for fields with null elements" .. function:: array_union(x, y) -> array Returns an array of the elements in the union of ``x`` and ``y``, without duplicates. + This function uses ``IS NOT DISTINCT FROM`` to determine which elements are the same. :: + + SELECT array_union(ARRAY[1, 2, 3, 2, null], ARRAY[1,2, 2, 4, null]) -- ARRAY[1, 2, 3, 4, null] ..
function:: cardinality(x) -> bigint diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 9a6c01095e1fc..03dc76e6279c9 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -90,8 +90,6 @@ import com.facebook.presto.operator.aggregation.VarianceAggregation; import com.facebook.presto.operator.aggregation.approxmostfrequent.ApproximateMostFrequent; import com.facebook.presto.operator.aggregation.arrayagg.ArrayAggregationFunction; -import com.facebook.presto.operator.aggregation.arrayagg.SetAggregationFunction; -import com.facebook.presto.operator.aggregation.arrayagg.SetUnionFunction; import com.facebook.presto.operator.aggregation.differentialentropy.DifferentialEntropyAggregation; import com.facebook.presto.operator.aggregation.histogram.Histogram; import com.facebook.presto.operator.aggregation.multimapagg.AlternativeMultimapAggregationFunction; @@ -372,6 +370,8 @@ import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT; import static com.facebook.presto.operator.aggregation.TDigestAggregationFunction.TDIGEST_AGG_WITH_WEIGHT_AND_COMPRESSION; +import static com.facebook.presto.operator.aggregation.arrayagg.SetAggregationFunction.SET_AGG; +import static com.facebook.presto.operator.aggregation.arrayagg.SetUnionFunction.SET_UNION; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMaxByAggregationFunction.ALTERNATIVE_MAX_BY; import static com.facebook.presto.operator.aggregation.minmaxby.AlternativeMinByAggregationFunction.ALTERNATIVE_MIN_BY; import static 
com.facebook.presto.operator.aggregation.minmaxby.MaxByAggregationFunction.MAX_BY; @@ -932,8 +932,7 @@ private List getBuiltInFunctions(FeaturesConfig featuresC .function(ARRAY_FLATTEN_FUNCTION) .function(ARRAY_CONCAT_FUNCTION) .functions(ARRAY_CONSTRUCTOR, ARRAY_SUBSCRIPT, ARRAY_TO_JSON, JSON_TO_ARRAY, JSON_STRING_TO_ARRAY) - .aggregate(SetAggregationFunction.class) - .aggregate(SetUnionFunction.class) + .functions(SET_AGG, SET_UNION) .function(new ArrayAggregationFunction(featuresConfig.isLegacyArrayAgg(), featuresConfig.getArrayAggGroupImplementation())) .functions(new MapSubscriptOperator(featuresConfig.isLegacyMapSubscript())) .functions(MAP_CONSTRUCTOR, MAP_TO_JSON, JSON_TO_MAP, JSON_STRING_TO_MAP) diff --git a/presto-main/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java index 258b203bfe4e5..1f0fd5e9dd5db 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java @@ -203,6 +203,7 @@ private DynamicFilterSourceOperator( this.blockBuilders[channelIndex] = type.createBlockBuilder(null, EXPECTED_BLOCK_BUILDER_SIZE); this.valueSets[channelIndex] = new TypedSet( type, + Optional.empty(), blockBuilders[channelIndex], EXPECTED_BLOCK_BUILDER_SIZE, String.format("DynamicFilterSourceOperator_%s_%d", planNodeId, channelIndex), diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/OptimizedTypedSet.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/OptimizedTypedSet.java index 3d609503752bc..907d49d9a603b 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/OptimizedTypedSet.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/OptimizedTypedSet.java @@ -21,14 +21,19 @@ import 
com.facebook.presto.operator.project.SelectedPositions; import org.openjdk.jol.info.ClassLayout; +import java.lang.invoke.MethodHandle; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Optional; import static com.facebook.presto.common.array.Arrays.ensureCapacity; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; import static com.facebook.presto.operator.project.SelectedPositions.positionsList; import static com.facebook.presto.type.TypeUtils.hashPosition; import static com.facebook.presto.type.TypeUtils.positionEqualsPosition; +import static com.facebook.presto.util.Failures.internalError; +import static com.google.common.base.Defaults.defaultValue; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.sizeOf; import static it.unimi.dsi.fastutil.HashCommon.arraySize; @@ -47,6 +52,7 @@ public class OptimizedTypedSet private static final SelectedPositions EMPTY_SELECTED_POSITIONS = positionsList(new int[0], 0, 0); private final Type elementType; + private final Optional elementIsDistinctFrom; private final int hashCapacity; private final int hashMask; @@ -56,17 +62,18 @@ public class OptimizedTypedSet private long[] blockPositionByHash; // Each 64-bit long is 32-bit index for blocks + 32-bit position within block private int currentBlockIndex = -1; // The index into the blocks array and positionsForBlocks list - public OptimizedTypedSet(Type elementType, int maxPositionCount) + public OptimizedTypedSet(Type elementType, MethodHandle elementIsDistinctFrom, int maxPositionCount) { - this(elementType, INITIAL_BLOCK_COUNT, maxPositionCount); + this(elementType, Optional.of(elementIsDistinctFrom), INITIAL_BLOCK_COUNT, maxPositionCount); } - public OptimizedTypedSet(Type elementType, int expectedBlockCount, int maxPositionCount) + public OptimizedTypedSet(Type elementType, Optional elementIsDistinctFrom, int expectedBlockCount, int maxPositionCount) { 
checkArgument(expectedBlockCount >= 0, "expectedBlockCount must not be negative"); checkArgument(maxPositionCount >= 0, "maxPositionCount must not be negative"); this.elementType = requireNonNull(elementType, "elementType must not be null"); + this.elementIsDistinctFrom = requireNonNull(elementIsDistinctFrom, "elementIsDistinctFrom is null"); this.hashCapacity = arraySize(maxPositionCount, FILL_RATIO); this.hashMask = hashCapacity - 1; @@ -293,7 +300,7 @@ private int getInsertPosition(long[] hashtable, int hashPosition, Block block, i // Already has this element int blockIndex = (int) ((blockPosition & 0xffff_ffff_0000_0000L) >> 32); int positionWithinBlock = (int) (blockPosition & 0xffff_ffff); - if (positionEqualsPosition(elementType, blocks[blockIndex], positionWithinBlock, block, position)) { + if (isContainedAt(blocks[blockIndex], positionWithinBlock, block, position)) { return INVALID_POSITION; } @@ -301,6 +308,23 @@ private int getInsertPosition(long[] hashtable, int hashPosition, Block block, i } } + private boolean isContainedAt(Block firstBlock, int positionWithinFirstBlock, Block secondBlock, int positionWithinSecondBlock) + { + if (elementIsDistinctFrom.isPresent()) { + boolean firstValueNull = firstBlock.isNull(positionWithinFirstBlock); + Object firstValue = firstValueNull ? defaultValue(elementType.getJavaType()) : readNativeValue(elementType, firstBlock, positionWithinFirstBlock); + boolean secondValueNull = secondBlock.isNull(positionWithinSecondBlock); + Object secondValue = secondValueNull ? 
defaultValue(elementType.getJavaType()) : readNativeValue(elementType, secondBlock, positionWithinSecondBlock); + try { + return !(boolean) elementIsDistinctFrom.get().invoke(firstValue, firstValueNull, secondValue, secondValueNull); + } + catch (Throwable t) { + throw internalError(t); + } + } + return positionEqualsPosition(elementType, firstBlock, positionWithinFirstBlock, secondBlock, positionWithinSecondBlock); + } + /** * Add an element to the hash table if it's not already existed. * @@ -322,7 +346,7 @@ private boolean addElement(long[] hashtable, int hashPosition, Block block, int // Already has this element int blockIndex = (int) ((blockPosition & 0xffff_ffff_0000_0000L) >> 32); int positionWithinBlock = (int) (blockPosition & 0xffff_ffff); - if (positionEqualsPosition(elementType, blocks[blockIndex], positionWithinBlock, block, position)) { + if (isContainedAt(blocks[blockIndex], positionWithinBlock, block, position)) { return false; } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/SetOfValues.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/SetOfValues.java index 18c0bb45d303b..6164d244215d5 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/SetOfValues.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/SetOfValues.java @@ -19,12 +19,15 @@ import com.facebook.presto.spi.PrestoException; import org.openjdk.jol.info.ClassLayout; +import java.lang.invoke.MethodHandle; import java.util.Arrays; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; import static com.facebook.presto.type.TypeUtils.expectedValueSize; import static com.facebook.presto.type.TypeUtils.hashPosition; -import static com.facebook.presto.type.TypeUtils.positionEqualsPosition; +import static com.facebook.presto.util.Failures.internalError; +import static 
com.google.common.base.Defaults.defaultValue; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.sizeOf; import static it.unimi.dsi.fastutil.HashCommon.arraySize; @@ -40,15 +43,17 @@ public final class SetOfValues private final BlockBuilder valueBlockBuilder; private final Type valueType; + MethodHandle elementIsDistinctFrom; private int[] valuePositionByHash; private int hashCapacity; private int maxFill; private int hashMask; - public SetOfValues(Type valueType) + public SetOfValues(Type valueType, MethodHandle elementIsDistinctFrom) { this.valueType = requireNonNull(valueType, "valueType is null"); + this.elementIsDistinctFrom = requireNonNull(elementIsDistinctFrom, "elementIsDistinctFrom is null"); valueBlockBuilder = this.valueType.createBlockBuilder(null, EXPECTED_ENTRIES, expectedValueSize(valueType, EXPECTED_ENTRY_SIZE)); hashCapacity = arraySize(EXPECTED_ENTRIES, FILL_RATIO); this.maxFill = calculateMaxFill(hashCapacity); @@ -57,9 +62,9 @@ public SetOfValues(Type valueType) Arrays.fill(valuePositionByHash, EMPTY_SLOT); } - public SetOfValues(Block serialized, Type elementType) + public SetOfValues(Block serialized, Type elementType, MethodHandle elementIsDistinctFrom) { - this(elementType); + this(elementType, elementIsDistinctFrom); deserialize(requireNonNull(serialized, "serialized is null")); } @@ -111,13 +116,27 @@ private int getHashPositionOfValue(Block value, int position) if (valuePositionByHash[hashPosition] == EMPTY_SLOT) { return hashPosition; } - else if (positionEqualsPosition(valueType, valueBlockBuilder, valuePositionByHash[hashPosition], value, position)) { + else if (isContainedAt(valueBlockBuilder, valuePositionByHash[hashPosition], value, position)) { return hashPosition; } hashPosition = getMaskedHash(hashPosition + 1); } } + private boolean isContainedAt(Block firstBlock, int positionWithinFirstBlock, Block secondBlock, int positionWithinSecondBlock) + { + boolean firstValueNull = 
firstBlock.isNull(positionWithinFirstBlock); + Object firstValue = firstValueNull ? defaultValue(valueType.getJavaType()) : readNativeValue(valueType, firstBlock, positionWithinFirstBlock); + boolean secondValueNull = secondBlock.isNull(positionWithinSecondBlock); + Object secondValue = secondValueNull ? defaultValue(valueType.getJavaType()) : readNativeValue(valueType, secondBlock, positionWithinSecondBlock); + try { + return !(boolean) elementIsDistinctFrom.invoke(firstValue, firstValueNull, secondValue, secondValueNull); + } + catch (Throwable t) { + throw internalError(t); + } + } + private void rehash() { long newCapacityLong = hashCapacity * 2L; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java index 71c47386aa03f..0ddd89e74779d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java @@ -22,12 +22,16 @@ import it.unimi.dsi.fastutil.ints.IntArrayList; import org.openjdk.jol.info.ClassLayout; +import java.lang.invoke.MethodHandle; import java.util.Optional; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; import static com.facebook.presto.type.TypeUtils.hashPosition; import static com.facebook.presto.type.TypeUtils.positionEqualsPosition; +import static com.facebook.presto.util.Failures.internalError; +import static com.google.common.base.Defaults.defaultValue; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static it.unimi.dsi.fastutil.HashCommon.arraySize; @@ -44,6 +48,7 @@ public class TypedSet private static final float FILL_RATIO = 
0.75f; private final Type elementType; + private final Optional elementIsDistinctFrom; private final IntArrayList blockPositionByHash; private final BlockBuilder elementBlock; private final String functionName; @@ -63,18 +68,19 @@ public class TypedSet public TypedSet(Type elementType, int expectedSize, String functionName) { - this(elementType, elementType.createBlockBuilder(null, expectedSize), expectedSize, functionName); + this(elementType, Optional.empty(), elementType.createBlockBuilder(null, expectedSize), expectedSize, functionName, Optional.of(MAX_FUNCTION_MEMORY)); } - public TypedSet(Type elementType, BlockBuilder blockBuilder, int expectedSize, String functionName) + public TypedSet(Type elementType, MethodHandle elementIsDistinctFrom, int expectedSize, String functionName) { - this(elementType, blockBuilder, expectedSize, functionName, Optional.of(MAX_FUNCTION_MEMORY)); + this(elementType, Optional.of(elementIsDistinctFrom), elementType.createBlockBuilder(null, expectedSize), expectedSize, functionName, Optional.of(MAX_FUNCTION_MEMORY)); } - public TypedSet(Type elementType, BlockBuilder blockBuilder, int expectedSize, String functionName, Optional maxBlockMemory) + public TypedSet(Type elementType, Optional elementIsDistinctFrom, BlockBuilder blockBuilder, int expectedSize, String functionName, Optional maxBlockMemory) { checkArgument(expectedSize >= 0, "expectedSize must not be negative"); this.elementType = requireNonNull(elementType, "elementType must not be null"); + this.elementIsDistinctFrom = requireNonNull(elementIsDistinctFrom, "elementIsDistinctFrom is null"); this.elementBlock = requireNonNull(blockBuilder, "blockBuilder must not be null"); this.functionName = functionName; this.maxBlockMemoryInBytes = requireNonNull(maxBlockMemory, "maxBlockMemory must not be null").map(DataSize::toBytes).orElse(Long.MAX_VALUE); @@ -202,7 +208,7 @@ private int getHashPositionOfElement(Block block, int position) return hashPosition; } // Already has this 
element - else if (positionEqualsPosition(elementType, elementBlock, blockPosition, block, position)) { + else if (isContainedAt(block, position, blockPosition)) { return hashPosition; } @@ -210,6 +216,23 @@ else if (positionEqualsPosition(elementType, elementBlock, blockPosition, block, } } + private boolean isContainedAt(Block block, int blockPosition, int elementBlockPosition) + { + if (elementIsDistinctFrom.isPresent()) { + boolean firstValueNull = elementBlock.isNull(elementBlockPosition); + Object firstValue = firstValueNull ? defaultValue(elementType.getJavaType()) : readNativeValue(elementType, elementBlock, elementBlockPosition); + boolean secondValueNull = block.isNull(blockPosition); + Object secondValue = secondValueNull ? defaultValue(elementType.getJavaType()) : readNativeValue(elementType, block, blockPosition); + try { + return !(boolean) elementIsDistinctFrom.get().invoke(firstValue, firstValueNull, secondValue, secondValueNull); + } + catch (Throwable t) { + throw internalError(t); + } + } + return positionEqualsPosition(elementType, elementBlock, elementBlockPosition, block, blockPosition); + } + private void addNewElement(int hashPosition, Block block, int position) { elementType.appendTo(block, position, elementBlock); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationFunction.java index 3c8321abb5ac5..3158ab14d83dd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationFunction.java @@ -13,40 +13,130 @@ */ package com.facebook.presto.operator.aggregation.arrayagg; +import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import 
com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; -import com.facebook.presto.operator.aggregation.NullablePosition; +import com.facebook.presto.common.type.TypeSignatureParameter; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.SqlAggregationFunction; +import com.facebook.presto.operator.aggregation.AccumulatorCompiler; +import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; +import com.facebook.presto.operator.aggregation.MapAggregationFunction; import com.facebook.presto.operator.aggregation.SetOfValues; import com.facebook.presto.operator.aggregation.state.SetAggregationState; -import com.facebook.presto.spi.function.AggregationFunction; -import com.facebook.presto.spi.function.AggregationState; -import com.facebook.presto.spi.function.BlockIndex; -import com.facebook.presto.spi.function.BlockPosition; -import com.facebook.presto.spi.function.CombineFunction; -import com.facebook.presto.spi.function.InputFunction; -import com.facebook.presto.spi.function.OutputFunction; -import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.spi.function.TypeParameter; - -@AggregationFunction(value = "set_agg", isCalledOnNullInput = true) +import com.facebook.presto.operator.aggregation.state.SetAggregationStateFactory; +import com.facebook.presto.spi.function.aggregation.Accumulator; +import com.facebook.presto.spi.function.aggregation.AggregationMetadata; +import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; +import com.google.common.collect.ImmutableList; + +import java.lang.invoke.MethodHandle; +import java.util.List; + +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static 
com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; +import static com.facebook.presto.spi.function.Signature.typeVariable; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.util.Reflection.methodHandle; +import static com.google.common.collect.ImmutableList.toImmutableList; + public class SetAggregationFunction + extends SqlAggregationFunction { - private SetAggregationFunction() + public static final SetAggregationFunction SET_AGG = new SetAggregationFunction(); + public static final String NAME = "set_agg"; + private static final MethodHandle INPUT_FUNCTION = methodHandle(SetAggregationFunction.class, "input", Type.class, MethodHandle.class, SetAggregationState.class, Block.class, int.class); + private static final MethodHandle COMBINE_FUNCTION = methodHandle(SetAggregationFunction.class, "combine", SetAggregationState.class, SetAggregationState.class); + private static final MethodHandle OUTPUT_FUNCTION = methodHandle(SetAggregationFunction.class, "output", SetAggregationState.class, BlockBuilder.class); + + public SetAggregationFunction() + { + super(NAME, + ImmutableList.of(typeVariable("T")), + ImmutableList.of(), + parseTypeSignature("array(T)"), + ImmutableList.of(parseTypeSignature("T"))); + } + + @Override + public String getDescription() + { + return "Aggregates distinct values into a single array"; + } + + @Override + public boolean isCalledOnNullInput() + { + return true; + } + + @Override + public BuiltInAggregationFunctionImplementation specialize(BoundVariables 
boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) + { + Type inputType = boundVariables.getTypeVariable("T"); + ArrayType outputType = (ArrayType) functionAndTypeManager.getParameterizedType(StandardTypes.ARRAY, ImmutableList.of(TypeSignatureParameter.of(inputType.getTypeSignature()))); + MethodHandle distinctFromMethodHandle = functionAndTypeManager.getJavaScalarFunctionImplementation(functionAndTypeManager.resolveOperator(IS_DISTINCT_FROM, fromTypes(inputType, inputType))).getMethodHandle(); + return generateAggregation(inputType, outputType, distinctFromMethodHandle); + } + + private static BuiltInAggregationFunctionImplementation generateAggregation(Type inputType, ArrayType outputType, MethodHandle distinctFromMethodHandle) + { + DynamicClassLoader classLoader = new DynamicClassLoader(MapAggregationFunction.class.getClassLoader()); + List inputTypes = ImmutableList.of(inputType); + SetAggregationStateSerializer stateSerializer = new SetAggregationStateSerializer(inputType, distinctFromMethodHandle); + Type intermediateType = stateSerializer.getSerializedType(); + + AggregationMetadata metadata = new AggregationMetadata( + generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), + createInputParameterMetadata(inputType), + INPUT_FUNCTION.bindTo(inputType).bindTo(distinctFromMethodHandle), + COMBINE_FUNCTION, + OUTPUT_FUNCTION, + ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor( + SetAggregationState.class, + stateSerializer, + new SetAggregationStateFactory(inputType))), + outputType); + + Class accumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + Accumulator.class, + metadata, + classLoader); + Class groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + GroupedAccumulator.class, + metadata, + classLoader); + return new + + BuiltInAggregationFunctionImplementation(NAME, inputTypes, 
ImmutableList.of(intermediateType), outputType, + true, true, metadata, accumulatorClass, groupedAccumulatorClass); + } + + private static List createInputParameterMetadata(Type valueType) { + return ImmutableList.of(new AggregationMetadata.ParameterMetadata(STATE), + new AggregationMetadata.ParameterMetadata(NULLABLE_BLOCK_INPUT_CHANNEL, valueType), + new AggregationMetadata.ParameterMetadata(BLOCK_INDEX)); } - @InputFunction - @TypeParameter("T") public static void input( - @TypeParameter("T") Type type, - @AggregationState SetAggregationState state, - @BlockPosition @SqlType("T") @NullablePosition Block block, - @BlockIndex int position) + Type type, + MethodHandle distinctFromMethodHandle, + SetAggregationState state, + Block block, + int position) { SetOfValues set = state.get(); if (set == null) { - set = new SetOfValues(type); + set = new SetOfValues(type, distinctFromMethodHandle); state.set(set); } @@ -55,10 +145,9 @@ public static void input( state.addMemoryUsage(set.estimatedInMemorySize() - startSize); } - @CombineFunction public static void combine( - @AggregationState SetAggregationState state, - @AggregationState SetAggregationState otherState) + SetAggregationState state, + SetAggregationState otherState) { if (state.get() != null && otherState.get() != null) { SetOfValues otherSet = otherState.get(); @@ -76,9 +165,8 @@ else if (state.get() == null) { } } - @OutputFunction("array(T)") public static void output( - @AggregationState SetAggregationState state, + SetAggregationState state, BlockBuilder out) { SetOfValues set = state.get(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationStateSerializer.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationStateSerializer.java index e812d88eff0ab..b6182b2ee2d26 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationStateSerializer.java +++ 
b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetAggregationStateSerializer.java @@ -20,16 +20,21 @@ import com.facebook.presto.operator.aggregation.SetOfValues; import com.facebook.presto.operator.aggregation.state.SetAggregationState; import com.facebook.presto.spi.function.AccumulatorStateSerializer; -import com.facebook.presto.spi.function.TypeParameter; + +import java.lang.invoke.MethodHandle; + +import static java.util.Objects.requireNonNull; public class SetAggregationStateSerializer implements AccumulatorStateSerializer { private final ArrayType arrayType; + private final MethodHandle elementIsDistinctFrom; - public SetAggregationStateSerializer(@TypeParameter("T") Type elementType) + public SetAggregationStateSerializer(Type elementType, MethodHandle elementIsDistinctFrom) { - this.arrayType = new ArrayType(elementType); + this.arrayType = new ArrayType(requireNonNull(elementType, "elementType is null")); + this.elementIsDistinctFrom = requireNonNull(elementIsDistinctFrom, "element is distinct from is null"); } @Override @@ -52,6 +57,6 @@ public void serialize(SetAggregationState state, BlockBuilder out) @Override public void deserialize(Block block, int index, SetAggregationState state) { - state.set(new SetOfValues(arrayType.getObject(block, index), state.getElementType())); + state.set(new SetOfValues(arrayType.getObject(block, index), state.getElementType(), elementIsDistinctFrom)); } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetUnionFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetUnionFunction.java index 40add64c90c7e..1d4547dc937f0 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetUnionFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/arrayagg/SetUnionFunction.java @@ -13,44 +13,132 @@ */ package com.facebook.presto.operator.aggregation.arrayagg; 
+import com.facebook.presto.bytecode.DynamicClassLoader; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; -import com.facebook.presto.operator.aggregation.NullablePosition; +import com.facebook.presto.common.type.TypeSignatureParameter; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.SqlAggregationFunction; +import com.facebook.presto.operator.aggregation.AccumulatorCompiler; +import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation; +import com.facebook.presto.operator.aggregation.MapAggregationFunction; import com.facebook.presto.operator.aggregation.SetOfValues; import com.facebook.presto.operator.aggregation.state.SetAggregationState; -import com.facebook.presto.spi.function.AggregationFunction; -import com.facebook.presto.spi.function.AggregationState; -import com.facebook.presto.spi.function.BlockIndex; -import com.facebook.presto.spi.function.BlockPosition; -import com.facebook.presto.spi.function.CombineFunction; -import com.facebook.presto.spi.function.Description; -import com.facebook.presto.spi.function.InputFunction; -import com.facebook.presto.spi.function.OutputFunction; -import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.spi.function.TypeParameter; - -@AggregationFunction(value = "set_union", isCalledOnNullInput = true) -@Description("Given a column of array type, return an array of all the unique values contained in each of the arrays in the column") +import com.facebook.presto.operator.aggregation.state.SetAggregationStateFactory; +import com.facebook.presto.spi.function.aggregation.Accumulator; +import com.facebook.presto.spi.function.aggregation.AggregationMetadata; +import 
com.facebook.presto.spi.function.aggregation.GroupedAccumulator; +import com.google.common.collect.ImmutableList; + +import java.lang.invoke.MethodHandle; +import java.util.List; + +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName; +import static com.facebook.presto.spi.function.Signature.typeVariable; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL; +import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.util.Reflection.methodHandle; +import static com.google.common.collect.ImmutableList.toImmutableList; + public class SetUnionFunction + extends SqlAggregationFunction { - private SetUnionFunction() + public static final SetUnionFunction SET_UNION = new SetUnionFunction(); + public static final String NAME = "set_union"; + private static final MethodHandle INPUT_FUNCTION = methodHandle(SetUnionFunction.class, "input", Type.class, ArrayType.class, MethodHandle.class, SetAggregationState.class, Block.class, int.class); + private static final MethodHandle COMBINE_FUNCTION = methodHandle(SetUnionFunction.class, "combine", SetAggregationState.class, SetAggregationState.class); + private static final MethodHandle OUTPUT_FUNCTION = methodHandle(SetUnionFunction.class, "output", SetAggregationState.class, BlockBuilder.class); + + public SetUnionFunction() + { + super(NAME, + ImmutableList.of(typeVariable("T")), + ImmutableList.of(), + parseTypeSignature("array(T)"), + 
ImmutableList.of(parseTypeSignature("array(T)"))); + } + + @Override + public String getDescription() + { + return "Given a column of array type, return an array of all the unique values contained in each of the arrays in the column"; + } + + @Override + public boolean isCalledOnNullInput() + { + return true; + } + + @Override + public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) + { + Type elementType = boundVariables.getTypeVariable("T"); + ArrayType inputType = (ArrayType) functionAndTypeManager.getParameterizedType(StandardTypes.ARRAY, ImmutableList.of(TypeSignatureParameter.of(elementType.getTypeSignature()))); + ArrayType outputType = (ArrayType) functionAndTypeManager.getParameterizedType(StandardTypes.ARRAY, ImmutableList.of(TypeSignatureParameter.of(elementType.getTypeSignature()))); + MethodHandle distinctFromMethodHandle = functionAndTypeManager.getJavaScalarFunctionImplementation(functionAndTypeManager.resolveOperator(IS_DISTINCT_FROM, fromTypes(elementType, elementType))).getMethodHandle(); + return generateAggregation(elementType, inputType, outputType, distinctFromMethodHandle); + } + + private static BuiltInAggregationFunctionImplementation generateAggregation(Type elementType, ArrayType inputType, ArrayType outputType, MethodHandle distinctFromMethodHandle) + { + DynamicClassLoader classLoader = new DynamicClassLoader(MapAggregationFunction.class.getClassLoader()); + List inputTypes = ImmutableList.of(inputType); + SetAggregationStateSerializer stateSerializer = new SetAggregationStateSerializer(elementType, distinctFromMethodHandle); + Type intermediateType = stateSerializer.getSerializedType(); + + AggregationMetadata metadata = new AggregationMetadata( + generateAggregationName(NAME, outputType.getTypeSignature(), inputTypes.stream().map(Type::getTypeSignature).collect(toImmutableList())), + createInputParameterMetadata(inputType), + 
INPUT_FUNCTION.bindTo(elementType).bindTo(inputType).bindTo(distinctFromMethodHandle), + COMBINE_FUNCTION, + OUTPUT_FUNCTION, + ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor( + SetAggregationState.class, + stateSerializer, + new SetAggregationStateFactory(elementType))), + outputType); + + Class accumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + Accumulator.class, + metadata, + classLoader); + Class groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass( + GroupedAccumulator.class, + metadata, + classLoader); + return new + + BuiltInAggregationFunctionImplementation(NAME, inputTypes, ImmutableList.of(intermediateType), outputType, + true, true, metadata, accumulatorClass, groupedAccumulatorClass); + } + + private static List createInputParameterMetadata(Type valueType) { + return ImmutableList.of(new AggregationMetadata.ParameterMetadata(STATE), + new AggregationMetadata.ParameterMetadata(NULLABLE_BLOCK_INPUT_CHANNEL, valueType), + new AggregationMetadata.ParameterMetadata(BLOCK_INDEX)); } - @InputFunction - @TypeParameter("T") public static void input( - @TypeParameter("T") Type elementType, - @TypeParameter("array(T)") ArrayType arrayType, - @AggregationState SetAggregationState state, - @BlockPosition @SqlType("array(T)") @NullablePosition Block inputBlock, - @BlockIndex int position) + Type elementType, + ArrayType arrayType, + MethodHandle elementIsDistinctFrom, + SetAggregationState state, + Block inputBlock, + int position) { SetOfValues set = state.get(); if (set == null) { - set = new SetOfValues(elementType); + set = new SetOfValues(elementType, elementIsDistinctFrom); state.set(set); } @@ -62,10 +150,9 @@ public static void input( state.addMemoryUsage(set.estimatedInMemorySize() - startSize); } - @CombineFunction public static void combine( - @AggregationState SetAggregationState state, - @AggregationState SetAggregationState otherState) + SetAggregationState state, + SetAggregationState otherState) { 
if (state.get() != null && otherState.get() != null) { SetOfValues otherSet = otherState.get(); @@ -83,9 +170,8 @@ else if (state.get() == null) { } } - @OutputFunction("array(T)") public static void output( - @AggregationState SetAggregationState state, + SetAggregationState state, BlockBuilder out) { SetOfValues set = state.get(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFunction.java index 89f2d7645584c..aff06e0a85df3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayDistinctFunction.java @@ -18,14 +18,20 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.aggregation.TypedSet; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; -import com.facebook.presto.type.TypeUtils; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import it.unimi.dsi.fastutil.longs.LongSet; +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; +import static com.facebook.presto.util.Failures.internalError; +import static com.google.common.base.Defaults.defaultValue; @ScalarFunction("array_distinct") @Description("Remove duplicate values from the given array") @@ -35,7 +41,10 @@ private ArrayDistinctFunction() {} @TypeParameter("E") @SqlType("array(E)") - public static Block distinct(@TypeParameter("E") Type type, @SqlType("array(E)") Block array) + public static Block 
distinct( + @TypeParameter("E") Type type, + @OperatorDependency(operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}) MethodHandle elementIsDistinctFrom, + @SqlType("array(E)") Block array) { int arrayLength = array.getPositionCount(); if (arrayLength < 2) { @@ -43,15 +52,24 @@ public static Block distinct(@TypeParameter("E") Type type, @SqlType("array(E)") } if (arrayLength == 2) { - if (TypeUtils.positionEqualsPosition(type, array, 0, array, 1)) { - return array.getSingleValueBlock(0); - } - else { + boolean firstValueNull = array.isNull(0); + Object firstValue = firstValueNull ? defaultValue(type.getJavaType()) : readNativeValue(type, array, 0); + boolean secondValueNull = array.isNull(1); + Object secondValue = secondValueNull ? defaultValue(type.getJavaType()) : readNativeValue(type, array, 1); + boolean distinct; + try { + distinct = (boolean) elementIsDistinctFrom.invoke(firstValue, firstValueNull, secondValue, secondValueNull); + } + catch (Throwable t) { + throw internalError(t); + } + if (distinct) { return array; } + return array.getSingleValueBlock(0); } - TypedSet typedSet = new TypedSet(type, arrayLength, "array_distinct"); + TypedSet typedSet = new TypedSet(type, elementIsDistinctFrom, array.getPositionCount(), "array_distinct"); BlockBuilder distinctElementBlockBuilder; if (array.mayHaveNull()) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayExceptFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayExceptFunction.java index ae2ed727da1aa..1d891781cc9dd 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayExceptFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayExceptFunction.java @@ -17,10 +17,14 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.aggregation.OptimizedTypedSet; import com.facebook.presto.spi.function.Description; +import 
com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; import static java.lang.Math.max; @ScalarFunction("array_except") @@ -33,6 +37,7 @@ private ArrayExceptFunction() {} @SqlType("array(E)") public static Block except( @TypeParameter("E") Type type, + @OperatorDependency(operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}) MethodHandle elementIsDistinctFrom, @SqlType("array(E)") Block leftArray, @SqlType("array(E)") Block rightArray) { @@ -43,7 +48,7 @@ public static Block except( return leftArray; } - OptimizedTypedSet typedSet = new OptimizedTypedSet(type, max(leftPositionCount, rightPositionCount)); + OptimizedTypedSet typedSet = new OptimizedTypedSet(type, elementIsDistinctFrom, max(leftPositionCount, rightPositionCount)); typedSet.union(rightArray); typedSet.except(leftArray); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java index ceb42383e07cd..b9c7b1268f306 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java @@ -17,12 +17,17 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.aggregation.OptimizedTypedSet; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlInvokedScalarFunction; import com.facebook.presto.spi.function.SqlParameter; import com.facebook.presto.spi.function.SqlType; import 
com.facebook.presto.spi.function.TypeParameter; +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; + public final class ArrayIntersectFunction { private ArrayIntersectFunction() {} @@ -33,6 +38,7 @@ private ArrayIntersectFunction() {} @SqlType("array(E)") public static Block intersect( @TypeParameter("E") Type type, + @OperatorDependency(operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}) MethodHandle elementIsDistinctFrom, @SqlType("array(E)") Block leftArray, @SqlType("array(E)") Block rightArray) { @@ -48,7 +54,7 @@ public static Block intersect( return rightArray; } - OptimizedTypedSet typedSet = new OptimizedTypedSet(type, rightPositionCount); + OptimizedTypedSet typedSet = new OptimizedTypedSet(type, elementIsDistinctFrom, rightPositionCount); typedSet.union(rightArray); typedSet.intersect(leftArray); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayUnionFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayUnionFunction.java index 0a370dcca5b16..f7c16d451deb4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayUnionFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayUnionFunction.java @@ -17,10 +17,15 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.aggregation.OptimizedTypedSet; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.function.OperatorType.IS_DISTINCT_FROM; + @ScalarFunction("array_union") @Description("Union elements of the two given arrays") public final class ArrayUnionFunction @@ -31,12 +36,13 @@ private 
ArrayUnionFunction() {} @SqlType("array(E)") public static Block union( @TypeParameter("E") Type type, + @OperatorDependency(operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}) MethodHandle elementIsDistinctFrom, @SqlType("array(E)") Block leftArray, @SqlType("array(E)") Block rightArray) { int leftArrayCount = leftArray.getPositionCount(); int rightArrayCount = rightArray.getPositionCount(); - OptimizedTypedSet typedSet = new OptimizedTypedSet(type, leftArrayCount + rightArrayCount); + OptimizedTypedSet typedSet = new OptimizedTypedSet(type, elementIsDistinctFrom, leftArrayCount + rightArrayCount); typedSet.union(leftArray); typedSet.union(rightArray); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java index e0af8cb522616..1e56f3bd5fa63 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MapConcatFunction.java @@ -151,7 +151,7 @@ public static Block mapConcat(MapType mapType, Block[] maps) Type valueType = mapType.getValueType(); // We need to divide the entries by 2 because the maps array is SingleMapBlocks and it had the positionCount twice as large as a normal Block - OptimizedTypedSet typedSet = new OptimizedTypedSet(keyType, maps.length, entries / 2); + OptimizedTypedSet typedSet = new OptimizedTypedSet(keyType, Optional.empty(), maps.length, entries / 2); for (int i = lastMapIndex; i >= firstMapIndex; i--) { SingleMapBlock singleMapBlock = (SingleMapBlock) maps[i]; Block keyBlock = singleMapBlock.getKeyBlock(); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestOptimizedTypedSet.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestOptimizedTypedSet.java index 997fbfc5c015e..363b772713708 100644 --- 
a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestOptimizedTypedSet.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestOptimizedTypedSet.java @@ -14,19 +14,24 @@ package com.facebook.presto.operator.aggregation; import com.facebook.presto.common.block.Block; +import com.facebook.presto.type.BigintOperators; import org.testng.annotations.Test; +import java.lang.invoke.MethodHandle; +import java.util.Optional; + import static com.facebook.presto.block.BlockAssertions.assertBlockEquals; import static com.facebook.presto.block.BlockAssertions.createEmptyBlock; import static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock; import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; +import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; import static com.facebook.presto.common.type.BigintType.BIGINT; import static org.testng.Assert.fail; public class TestOptimizedTypedSet { - private static final String FUNCTION_NAME = "optimized_typed_set_test"; private static final int POSITIONS_PER_PAGE = 100; + private static final MethodHandle BIGINT_DISTINCT_METHOD_HANDLE = methodHandle(BigintOperators.BigintDistinctFromOperator.class, "isDistinctFrom", long.class, boolean.class, long.class, boolean.class); @Test public void testConstructor() @@ -34,7 +39,7 @@ public void testConstructor() for (int i = -2; i <= -1; i++) { try { //noinspection ResultOfObjectAllocationIgnored - new OptimizedTypedSet(BIGINT, 2, i); + new OptimizedTypedSet(BIGINT, Optional.of(BIGINT_DISTINCT_METHOD_HANDLE), 2, i); fail("Should throw exception if expectedSize < 0"); } catch (IllegalArgumentException e) { @@ -44,7 +49,7 @@ public void testConstructor() try { //noinspection ResultOfObjectAllocationIgnored - new OptimizedTypedSet(null, -1, 1); + new OptimizedTypedSet(null, Optional.of(BIGINT_DISTINCT_METHOD_HANDLE), -1, 1); fail("Should throw exception if expectedBlockCount is negative"); } 
catch (NullPointerException | IllegalArgumentException e) { @@ -53,7 +58,7 @@ public void testConstructor() try { //noinspection ResultOfObjectAllocationIgnored - new OptimizedTypedSet(null, 2, 1); + new OptimizedTypedSet(null, Optional.of(BIGINT_DISTINCT_METHOD_HANDLE), 2, 1); fail("Should throw exception if type is null"); } catch (NullPointerException | IllegalArgumentException e) { @@ -64,7 +69,7 @@ public void testConstructor() @Test public void testUnionWithDistinctValues() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE + 1); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE + 1); Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE / 2); testUnion(typedSet, block, block); @@ -80,7 +85,7 @@ public void testUnionWithDistinctValues() @Test public void testUnionWithRepeatingValues() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE); Block block = createLongRepeatBlock(0, POSITIONS_PER_PAGE); Block expectedBlock = createLongRepeatBlock(0, 1); @@ -95,14 +100,14 @@ public void testUnionWithRepeatingValues() @Test public void testIntersectWithEmptySet() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE); testIntersect(typedSet, createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull(), createEmptyBlock(BIGINT)); } @Test public void testIntersectWithDistinctValues() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE); Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull(); typedSet.union(block); @@ -119,7 
+124,7 @@ public void testIntersectWithDistinctValues() @Test public void testIntersectWithNonDistinctValues() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE); Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull(); typedSet.union(block); @@ -137,7 +142,7 @@ public void testIntersectWithNonDistinctValues() @Test public void testExceptWithDistinctValues() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE); Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE - 1).appendNull(); typedSet.union(block); @@ -149,7 +154,7 @@ public void testExceptWithDistinctValues() @Test public void testExceptWithRepeatingValues() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE); Block block = createLongRepeatBlock(0, POSITIONS_PER_PAGE - 1).appendNull(); testExcept(typedSet, block, createLongSequenceBlock(0, 1).appendNull()); @@ -158,7 +163,7 @@ public void testExceptWithRepeatingValues() @Test public void testMultipleOperations() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE + 1); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE + 1); Block block = createLongSequenceBlock(0, POSITIONS_PER_PAGE / 2).appendNull(); @@ -176,7 +181,7 @@ public void testMultipleOperations() @Test public void testNulls() { - OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, POSITIONS_PER_PAGE + 1); + OptimizedTypedSet typedSet = new OptimizedTypedSet(BIGINT, BIGINT_DISTINCT_METHOD_HANDLE, POSITIONS_PER_PAGE + 1); // Empty block Block 
emptyBlock = createLongSequenceBlock(0, 0); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedSet.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedSet.java index 9c1a6a00a83ec..efbcf7f0ad4f4 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedSet.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestTypedSet.java @@ -22,6 +22,7 @@ import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.block.BlockAssertions.createEmptyLongsBlock; @@ -29,6 +30,7 @@ import static com.facebook.presto.block.BlockAssertions.createLongsBlock; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.operator.aggregation.TypedSet.MAX_FUNCTION_MEMORY; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT; import static io.airlift.slice.Slices.utf8Slice; import static java.util.Collections.nCopies; @@ -129,7 +131,7 @@ public void testGetElementPositionWithProvidedEmptyBlockBuilder() int initialTypedSetEntryCount = 10; BlockBuilder emptyBlockBuilder = BIGINT.createFixedSizeBlockBuilder(elementCount); - TypedSet typedSet = new TypedSet(BIGINT, emptyBlockBuilder, initialTypedSetEntryCount, FUNCTION_NAME); + TypedSet typedSet = new TypedSet(BIGINT, Optional.empty(), emptyBlockBuilder, initialTypedSetEntryCount, FUNCTION_NAME, Optional.of(MAX_FUNCTION_MEMORY)); BlockBuilder externalBlockBuilder = BIGINT.createFixedSizeBlockBuilder(elementCount); for (int i = 0; i < elementCount; i++) { if (i % 10 == 0) { @@ -167,7 +169,7 @@ public void testGetElementPositionWithProvidedNonEmptyBlockBuilder() // The secondBlockBuilder should already have elementCount rows. 
BlockBuilder secondBlockBuilder = pageBuilder.getBlockBuilder(0); - TypedSet typedSet = new TypedSet(BIGINT, secondBlockBuilder, initialTypedSetEntryCount, FUNCTION_NAME); + TypedSet typedSet = new TypedSet(BIGINT, Optional.empty(), secondBlockBuilder, initialTypedSetEntryCount, FUNCTION_NAME, Optional.of(MAX_FUNCTION_MEMORY)); BlockBuilder externalBlockBuilder = BIGINT.createFixedSizeBlockBuilder(elementCount); for (int i = 0; i < elementCount; i++) { if (i % 10 == 0) { @@ -195,7 +197,7 @@ public void testGetElementPositionRandom() testGetElementPositionRandomFor(set); BlockBuilder emptyBlockBuilder = VARCHAR.createBlockBuilder(null, 3); - TypedSet setWithPassedInBuilder = new TypedSet(VARCHAR, emptyBlockBuilder, 1, FUNCTION_NAME); + TypedSet setWithPassedInBuilder = new TypedSet(VARCHAR, Optional.empty(), emptyBlockBuilder, 1, FUNCTION_NAME, Optional.of(MAX_FUNCTION_MEMORY)); testGetElementPositionRandomFor(setWithPassedInBuilder); } @@ -249,7 +251,7 @@ public void testMemoryExceeded() { try { TypedSet typedSet = new TypedSet(BIGINT, 10, FUNCTION_NAME); - for (int i = 0; i <= TypedSet.MAX_FUNCTION_MEMORY.toBytes() + 1; i++) { + for (int i = 0; i <= MAX_FUNCTION_MEMORY.toBytes() + 1; i++) { Block block = createLongsBlock(nCopies(1, (long) i)); typedSet.add(block, 0); } @@ -294,7 +296,7 @@ private static void testBigint(Block longBlock, int expectedSetSize) testBigintFor(typedSet, longBlock); BlockBuilder emptyBlockBuilder = BIGINT.createBlockBuilder(null, expectedSetSize); - TypedSet typedSetWithPassedInBuilder = new TypedSet(BIGINT, emptyBlockBuilder, expectedSetSize, FUNCTION_NAME); + TypedSet typedSetWithPassedInBuilder = new TypedSet(BIGINT, Optional.empty(), emptyBlockBuilder, expectedSetSize, FUNCTION_NAME, Optional.of(MAX_FUNCTION_MEMORY)); testBigintFor(typedSetWithPassedInBuilder, longBlock); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java 
b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java index c459226f6f0de..96c2db4dabc30 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; @@ -23,6 +24,7 @@ import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; @@ -66,4 +68,31 @@ public void testDuplicates() assertFunction("array_except(ARRAY[CAST('x' as VARCHAR), 'x', 'y', 'z'], ARRAY['x', 'y', 'x'])", new ArrayType(VARCHAR), ImmutableList.of("z")); assertFunction("array_except(ARRAY[true, false, null, true, false, null], ARRAY[true, true, true])", new ArrayType(BOOLEAN), asList(false, null)); } + + @Test + public void testIndeterminateRows() + { + // test unsupported + assertFunction( + "array_except(ARRAY[(123, 'abc'), (123, NULL)], ARRAY[(123, 'abc'), (123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of()); + assertFunction( + "array_except(ARRAY[(NULL, 'abc'), (123, null), (123, 'abc')], ARRAY[(456, 'def'),(NULL, 'abc')])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, null), asList(123, "abc"))); + } + + @Test + public void testIndeterminateArrays() + { + assertFunction( + "array_except(ARRAY[ARRAY[123, 456], ARRAY[123, NULL]], 
ARRAY[ARRAY[123, 456], ARRAY[123, NULL]])", + new ArrayType(new ArrayType(INTEGER)), + ImmutableList.of()); + assertFunction( + "array_except(ARRAY[ARRAY[NULL, 456], ARRAY[123, null], ARRAY[123, 456]], ARRAY[ARRAY[456, 456],ARRAY[NULL, 456]])", + new ArrayType(new ArrayType(INTEGER)), + ImmutableList.of(asList(123, null), asList(123, 456))); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java index 60f09dd353ed8..528eb9521144f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayIntersectFunction.java @@ -11,6 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.facebook.presto.operator.scalar; import com.facebook.presto.common.type.ArrayType; @@ -20,42 +21,185 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DecimalType.createDecimalType; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; public class TestArrayIntersectFunction extends AbstractTestFunctions { @Test - public void testBasic() + public void testVarchar() + { + assertFunction("ARRAY_INTERSECT(ARRAY[CAST('x' as VARCHAR), 'y', 'z'], ARRAY['x', 'y'])", new ArrayType(VARCHAR), ImmutableList.of("x", "y")); + 
assertFunction("ARRAY_INTERSECT(ARRAY ['abc'], ARRAY ['abc', 'bcd'])", new ArrayType(createVarcharType(3)), ImmutableList.of("abc")); + assertFunction("ARRAY_INTERSECT(ARRAY ['abc', 'abc'], ARRAY ['abc', 'abc'])", new ArrayType(createVarcharType(3)), ImmutableList.of("abc")); + assertFunction("ARRAY_INTERSECT(ARRAY ['foo', 'bar', 'baz'], ARRAY ['foo', 'test', 'bar'])", new ArrayType(createVarcharType(4)), ImmutableList.of("foo", "bar")); + } + + @Test + public void testBigint() + { + assertFunction("ARRAY_INTERSECT(ARRAY[CAST(1 as BIGINT), 5, 3], ARRAY[5])", new ArrayType(BIGINT), ImmutableList.of(5L)); + assertFunction("ARRAY_INTERSECT(ARRAY [CAST(5 AS BIGINT), CAST(5 AS BIGINT)], ARRAY [CAST(1 AS BIGINT), CAST(5 AS BIGINT)])", new ArrayType(BIGINT), ImmutableList.of(5L)); + } + + @Test + public void testInteger() + { + assertFunction("ARRAY_INTERSECT(ARRAY[1, 5, 3], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(3)); + assertFunction("ARRAY_INTERSECT(ARRAY [5], ARRAY [5])", new ArrayType(INTEGER), ImmutableList.of(5)); + assertFunction("ARRAY_INTERSECT(ARRAY [1, 2, 5, 5, 6], ARRAY [5, 5, 6, 6, 7, 8])", new ArrayType(INTEGER), ImmutableList.of(5, 6)); + assertFunction("ARRAY_INTERSECT(ARRAY [IF (RAND() < 1.0E0, 7, 1) , 2], ARRAY [7])", new ArrayType(INTEGER), ImmutableList.of(7)); + } + + @Test + public void testDouble() + { + assertFunction("ARRAY_INTERSECT(ARRAY [1, 5], ARRAY [1.0E0])", new ArrayType(DOUBLE), ImmutableList.of(1.0)); + assertFunction("ARRAY_INTERSECT(ARRAY [1.0E0, 5.0E0], ARRAY [5.0E0, 5.0E0, 6.0E0])", new ArrayType(DOUBLE), ImmutableList.of(5.0)); + assertFunction("ARRAY_INTERSECT(ARRAY [8.3E0, 1.6E0, 4.1E0, 5.2E0], ARRAY [4.0E0, 5.2E0, 8.3E0, 9.7E0, 3.5E0])", new ArrayType(DOUBLE), ImmutableList.of(5.2, 8.3)); + assertFunction("ARRAY_INTERSECT(ARRAY [5.1E0, 7, 3.0E0, 4.8E0, 10], ARRAY [6.5E0, 10.0E0, 1.9E0, 5.1E0, 3.9E0, 4.8E0])", new ArrayType(DOUBLE), ImmutableList.of(10.0, 5.1, 4.8)); + assertFunction("ARRAY_INTERSECT(ARRAY[1.1E0, 
5.4E0, 3.9E0], ARRAY[5, 5.4E0])", new ArrayType(DOUBLE), ImmutableList.of(5.4)); + } + + @Test + public void testDecimal() + { + assertFunction( + "ARRAY_INTERSECT(ARRAY [2.3, 2.3, 2.2], ARRAY[2.2, 2.3])", + new ArrayType(createDecimalType(2, 1)), + ImmutableList.of(decimal("2.3"), decimal("2.2"))); + assertFunction("ARRAY_INTERSECT(ARRAY [2.330, 1.900, 2.330], ARRAY [2.3300, 1.9000])", new ArrayType(createDecimalType(5, 4)), + ImmutableList.of(decimal("2.3300"), decimal("1.9000"))); + assertFunction("ARRAY_INTERSECT(ARRAY [2, 3], ARRAY[2.0, 3.0])", new ArrayType(createDecimalType(11, 1)), + ImmutableList.of(decimal("00000000002.0"), decimal("00000000003.0"))); + } + + @Test + public void testBoolean() { - assertFunction("array_intersect(ARRAY[1, 5, 3], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(3)); - assertFunction("array_intersect(ARRAY[CAST(1 as BIGINT), 5, 3], ARRAY[5])", new ArrayType(BIGINT), ImmutableList.of(5L)); - assertFunction("array_intersect(ARRAY[CAST('x' as VARCHAR), 'y', 'z'], ARRAY['x', 'y'])", new ArrayType(VARCHAR), ImmutableList.of("x", "y")); - assertFunction("array_intersect(ARRAY[true, false, null], ARRAY[true, null])", new ArrayType(BOOLEAN), asList(true, null)); - assertFunction("array_intersect(ARRAY[1.1E0, 5.4E0, 3.9E0], ARRAY[5, 5.4E0])", new ArrayType(DOUBLE), ImmutableList.of(5.4)); + assertFunction("ARRAY_INTERSECT(ARRAY [true], ARRAY [true])", new ArrayType(BOOLEAN), ImmutableList.of(true)); + assertFunction("ARRAY_INTERSECT(ARRAY [true, false], ARRAY [true])", new ArrayType(BOOLEAN), ImmutableList.of(true)); + assertFunction("ARRAY_INTERSECT(ARRAY [true, true], ARRAY [true, true])", new ArrayType(BOOLEAN), ImmutableList.of(true)); + assertFunction("ARRAY_INTERSECT(ARRAY[true, false, null], ARRAY[true, null])", new ArrayType(BOOLEAN), asList(true, null)); } @Test - public void testEmpty() + public void testRow() { - assertFunction("array_intersect(ARRAY[], ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); - 
assertFunction("array_intersect(ARRAY[], ARRAY[1, 3])", new ArrayType(INTEGER), ImmutableList.of()); - assertFunction("array_intersect(ARRAY[CAST('abc' as VARCHAR)], ARRAY[])", new ArrayType(VARCHAR), ImmutableList.of()); + assertFunction( + "ARRAY_INTERSECT(ARRAY[(123, 456), (123, 789)], ARRAY[(123, 456), (123, 456), (123, 789)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, INTEGER))), + ImmutableList.of(asList(123, 456), asList(123, 789))); + assertFunction( + "ARRAY_INTERSECT(ARRAY[ARRAY[123, 456], ARRAY[123, 789]], ARRAY[ARRAY[123, 456], ARRAY[123, 456], ARRAY[123, 789]])", + new ArrayType(new ArrayType((INTEGER))), + ImmutableList.of(asList(123, 456), asList(123, 789))); + assertFunction( + "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, 'cde')], ARRAY[(123, 'abc'), (123, 'cde')])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, "abc"), asList(123, "cde"))); + assertFunction( + "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, 'cde'), NULL], ARRAY[(123, 'abc'), (123, 'cde')])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, "abc"), asList(123, "cde"))); + assertFunction( + "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, 'cde'), NULL, NULL], ARRAY[(123, 'abc'), (123, 'cde'), NULL])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + asList(asList(123, "abc"), asList(123, "cde"), null)); + assertFunction( + "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, 'abc')], ARRAY[(123, 'abc'), (123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, "abc"))); + assertFunction( + "ARRAY_INTERSECT(ARRAY[(123, 'abc')], ARRAY[(123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of()); + } + + @Test + public void testIndeterminateRows() + { + assertFunction( + 
"ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, NULL)], ARRAY[(123, 'abc'), (123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, "abc"), asList(123, null))); + assertFunction( + "ARRAY_INTERSECT(ARRAY[(NULL, 'abc'), (123, 'abc')], ARRAY[(123, 'abc'),(NULL, 'abc')])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(null, "abc"), asList(123, "abc"))); + } + + @Test + public void testIndeterminateArrays() + { + assertFunction( + "ARRAY_INTERSECT(ARRAY[ARRAY[123, 456], ARRAY[123, NULL]], ARRAY[ARRAY[123, 456], ARRAY[123, NULL]])", + new ArrayType(new ArrayType(INTEGER)), + ImmutableList.of(asList(123, 456), asList(123, null))); + assertFunction( + "ARRAY_INTERSECT(ARRAY[ARRAY[NULL, 456], ARRAY[123, 456]], ARRAY[ARRAY[123, 456],ARRAY[NULL, 456]])", + new ArrayType(new ArrayType(INTEGER)), + ImmutableList.of(asList(null, 456), asList(123, 456))); + } + + @Test + public void testUnboundedRetainedSize() + { + assertCachedInstanceHasBoundedRetainedSize("ARRAY_INTERSECT(ARRAY ['foo', 'bar', 'baz'], ARRAY ['foo', 'test', 'bar'])"); + } + + @Test + public void testEmptyArrays() + { + assertFunction("ARRAY_INTERSECT(ARRAY[], ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY[], ARRAY[1, 3])", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [5])", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [5, 6], ARRAY [])", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY ['abc'], ARRAY [])", new ArrayType(createVarcharType(3)), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY ['abc', 'bcd'])", new ArrayType(createVarcharType(3)), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); + 
assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [false])", new ArrayType(BOOLEAN), ImmutableList.of()); + } + + @Test + public void testEmptyResults() + { + assertFunction("ARRAY_INTERSECT(ARRAY [1], ARRAY [5])", new ArrayType(INTEGER), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [CAST(1 AS BIGINT)], ARRAY [CAST(5 AS BIGINT)])", new ArrayType(BIGINT), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [true, true], ARRAY [false])", new ArrayType(BOOLEAN), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [5], ARRAY [1.0E0])", new ArrayType(DOUBLE), ImmutableList.of()); } @Test public void testNull() { - assertFunction("array_intersect(ARRAY[NULL], NULL)", new ArrayType(UNKNOWN), null); - assertFunction("array_intersect(NULL, NULL)", new ArrayType(UNKNOWN), null); - assertFunction("array_intersect(NULL, ARRAY[NULL])", new ArrayType(UNKNOWN), null); - assertFunction("array_intersect(ARRAY[NULL], ARRAY[NULL])", new ArrayType(UNKNOWN), asList(false ? 1 : null)); - assertFunction("array_intersect(ARRAY[], ARRAY[NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("array_intersect(ARRAY[NULL], ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(NULL, NULL)", new ArrayType(UNKNOWN), null); + assertFunction("ARRAY_INTERSECT(ARRAY[NULL], NULL)", new ArrayType(UNKNOWN), null); + assertFunction("ARRAY_INTERSECT(NULL, ARRAY[NULL])", new ArrayType(UNKNOWN), null); + assertFunction("ARRAY_INTERSECT(ARRAY[NULL], ARRAY[NULL])", new ArrayType(UNKNOWN), asList(false ? 
1 : null)); + assertFunction("ARRAY_INTERSECT(ARRAY[], ARRAY[NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY[NULL], ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [NULL], ARRAY [NULL, NULL])", new ArrayType(UNKNOWN), asList((Object) null)); + + assertFunction("ARRAY_INTERSECT(ARRAY [0, 0, 1, NULL], ARRAY [0, 0, 1, NULL])", new ArrayType(INTEGER), asList(0, 1, null)); + assertFunction("ARRAY_INTERSECT(ARRAY [0, 0], ARRAY [0, 0, NULL])", new ArrayType(INTEGER), ImmutableList.of(0)); + assertFunction("ARRAY_INTERSECT(ARRAY [CAST(0 AS BIGINT), CAST(0 AS BIGINT)], ARRAY [CAST(0 AS BIGINT), NULL])", new ArrayType(BIGINT), ImmutableList.of(0L)); + assertFunction("ARRAY_INTERSECT(ARRAY [0.0E0], ARRAY [NULL])", new ArrayType(DOUBLE), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [0.0E0, NULL], ARRAY [0.0E0, NULL])", new ArrayType(DOUBLE), asList(0.0, null)); + assertFunction("ARRAY_INTERSECT(ARRAY [true, true, false, false, NULL], ARRAY [true, false, false, NULL])", new ArrayType(BOOLEAN), asList(true, false, null)); + assertFunction("ARRAY_INTERSECT(ARRAY [false, false], ARRAY [false, false, NULL])", new ArrayType(BOOLEAN), ImmutableList.of(false)); + assertFunction("ARRAY_INTERSECT(ARRAY ['abc'], ARRAY [NULL])", new ArrayType(createVarcharType(3)), ImmutableList.of()); + assertFunction("ARRAY_INTERSECT(ARRAY [''], ARRAY ['', NULL])", new ArrayType(createVarcharType(0)), ImmutableList.of("")); + assertFunction("ARRAY_INTERSECT(ARRAY ['', NULL], ARRAY ['', NULL])", new ArrayType(createVarcharType(0)), asList("", null)); + assertFunction("ARRAY_INTERSECT(ARRAY [NULL], ARRAY ['abc', NULL])", new ArrayType(createVarcharType(3)), singletonList(null)); + assertFunction("ARRAY_INTERSECT(ARRAY ['abc', NULL, 'xyz', NULL], ARRAY [NULL, 'abc', NULL, NULL])", new ArrayType(createVarcharType(3)), asList("abc", null)); } @Test @@ -68,7 +212,7 @@ public void testDuplicates() } 
@Test - public void testSQLFunctions() + public void testSqlFunctions() { assertFunction("array_intersect(ARRAY[ARRAY[1, 3, 5], ARRAY[2, 3, 5], ARRAY[3, 3, 3, 6]])", new ArrayType(INTEGER), ImmutableList.of(3)); assertFunction("array_intersect(ARRAY[ARRAY[], ARRAY[1, 2, 3]])", new ArrayType(INTEGER), ImmutableList.of()); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayUnionFunction.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayUnionFunction.java new file mode 100644 index 0000000000000..095a624b32d84 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestArrayUnionFunction.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static java.util.Arrays.asList; + +public class TestArrayUnionFunction + extends AbstractTestFunctions +{ + @Test + public void testBigint() + { + assertFunction("ARRAY_UNION(ARRAY [cast(10 as bigint), NULL, cast(12 as bigint), NULL], ARRAY [NULL, cast(10 as bigint), NULL, NULL])", new ArrayType(BIGINT), asList(10L, null, 12L)); + } + + @Test + public void testInteger() + { + assertFunction("ARRAY_UNION(ARRAY [12], ARRAY [10])", new ArrayType(INTEGER), ImmutableList.of(12, 10)); + assertFunction("ARRAY_UNION(ARRAY [1, 5], ARRAY [1])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); + assertFunction("ARRAY_UNION(ARRAY [1, 1, 2, 4], ARRAY [1, 1, 4, 4])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 4)); + assertFunction("ARRAY_UNION(ARRAY [2, 8], ARRAY [8, 3])", new ArrayType(INTEGER), ImmutableList.of(2, 8, 3)); + assertFunction("ARRAY_UNION(ARRAY [IF (RAND() < 1.0E0, 7, 1) , 2], ARRAY [7])", new ArrayType(INTEGER), ImmutableList.of(7, 2)); + } + + @Test + public void testVarchar() + { + assertFunction("ARRAY_UNION(ARRAY ['foo', 'bar', 'baz'], ARRAY ['foo', 'test', 'bar'])", new ArrayType(createVarcharType(4)), ImmutableList.of("foo", "bar", "baz", "test")); + } + + @Test + public void testDouble() + { + assertFunction("ARRAY_UNION(ARRAY [1, 5], ARRAY [1.0E0])", new ArrayType(DOUBLE), ImmutableList.of(1.0, 5.0)); + assertFunction("ARRAY_UNION(ARRAY [8.3E0, 1.6E0, 4.1E0, 5.2E0], ARRAY
[4.0E0, 5.2E0, 8.3E0, 9.7E0, 3.5E0])", new ArrayType(DOUBLE), ImmutableList.of(8.3, 1.6, 4.1, 5.2, 4.0, 9.7, 3.5)); + assertFunction("ARRAY_UNION(ARRAY [5.1E0, 7, 3.0E0, 4.8E0, 10], ARRAY [6.5E0, 10.0E0, 1.9E0, 5.1E0, 3.9E0, 4.8E0])", new ArrayType(DOUBLE), ImmutableList.of(5.1, 7.0, 3.0, 4.8, 10.0, 6.5, 1.9, 3.9)); + } + + @Test + public void testArrayOfArrays() + { + assertFunction("ARRAY_UNION(ARRAY [ARRAY [4, 5], ARRAY [6, 7]], ARRAY [ARRAY [4, 5], ARRAY [6, 8]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(4, 5), ImmutableList.of(6, 7), ImmutableList.of(6, 8))); + } + + @Test + public void testNull() + { + assertFunction("ARRAY_UNION(ARRAY [NULL], ARRAY [NULL, NULL])", new ArrayType(UNKNOWN), asList((Object) null)); + assertFunction("ARRAY_UNION(ARRAY ['abc', NULL, 'xyz', NULL], ARRAY [NULL, 'abc', NULL, NULL])", new ArrayType(createVarcharType(3)), asList("abc", null, "xyz")); + } + + @Test + public void testIndeterminateRows() + { + // rows with NULL fields are supported: identical rows (NULL fields included) are de-duplicated + assertFunction( + "array_union(ARRAY[(123, 'abc'), (123, NULL)], ARRAY[(123, 'abc'), (123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, "abc"), asList(123, null))); + assertFunction( + "array_union(ARRAY[(NULL, 'abc'), (123, null), (123, 'abc')], ARRAY[(456, 'def'),(NULL, 'abc')])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(null, "abc"), asList(123, null), asList(123, "abc"), asList(456, "def"))); + } + + @Test + public void testIndeterminateArrays() + { + assertFunction( + "array_union(ARRAY[ARRAY[123, 456], ARRAY[123, NULL]], ARRAY[ARRAY[123, 456], ARRAY[123, NULL]])", + new ArrayType(new ArrayType(INTEGER)), + ImmutableList.of(asList(123, 456), asList(123, null))); + assertFunction( + "array_union(ARRAY[ARRAY[NULL, 456], ARRAY[123, 456]], ARRAY[ARRAY[123, 456],ARRAY[NULL, 456]])", + new ArrayType(new ArrayType(INTEGER)), +
ImmutableList.of(asList(null, 456), asList(123, 456))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java index e7618eb0c48d0..3f7b795895f35 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestArrayOperators.java @@ -102,7 +102,7 @@ public class TestArrayOperators { private static FunctionAssertions fieldNameInJsonCastEnabled; - public TestArrayOperators(){} + public TestArrayOperators() {} @BeforeClass public void setUp() @@ -1173,6 +1173,19 @@ public void testDistinct() assertCachedInstanceHasBoundedRetainedSize("ARRAY_DISTINCT(ARRAY['cat', 'dog', 'dog', 'coffee', 'apple'])"); } + @Test + public void testDistinctWithIndeterminateRows() + { + assertFunction( + "ARRAY_DISTINCT(ARRAY[(123, 'abc'), (123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, "abc"), asList(123, null))); + assertFunction( + "ARRAY_DISTINCT(ARRAY[(NULL, NULL), (42, 'def'), (NULL, 'abc'), (123, NULL), (42, 'def'), (NULL, NULL), (NULL, 'abc'), (123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(null, null), asList(42, "def"), asList(null, "abc"), asList(123, null))); + } + @Test public void testSlice() { @@ -1273,122 +1286,6 @@ public void testArraysOverlap() assertFunction("ARRAYS_OVERLAP(ARRAY [2.4, 9.0, 10.9999999, 9.1, 4.1, 8.1], ARRAY [2.1, 10.999])", BooleanType.BOOLEAN, false); } - @Test - public void testArrayIntersect() - { - // test basic - assertFunction("ARRAY_INTERSECT(ARRAY [5], ARRAY [5])", new ArrayType(INTEGER), ImmutableList.of(5)); - assertFunction("ARRAY_INTERSECT(ARRAY [1, 2, 5, 5, 6], ARRAY [5, 5, 6, 6, 7, 8])", new ArrayType(INTEGER), ImmutableList.of(5, 6)); - assertFunction("ARRAY_INTERSECT(ARRAY [IF 
(RAND() < 1.0E0, 7, 1) , 2], ARRAY [7])", new ArrayType(INTEGER), ImmutableList.of(7)); - assertFunction("ARRAY_INTERSECT(ARRAY [CAST(5 AS BIGINT), CAST(5 AS BIGINT)], ARRAY [CAST(1 AS BIGINT), CAST(5 AS BIGINT)])", new ArrayType(BIGINT), ImmutableList.of(5L)); - assertFunction("ARRAY_INTERSECT(ARRAY [1, 5], ARRAY [1.0E0])", new ArrayType(DOUBLE), ImmutableList.of(1.0)); - assertFunction("ARRAY_INTERSECT(ARRAY [1.0E0, 5.0E0], ARRAY [5.0E0, 5.0E0, 6.0E0])", new ArrayType(DOUBLE), ImmutableList.of(5.0)); - assertFunction("ARRAY_INTERSECT(ARRAY [8.3E0, 1.6E0, 4.1E0, 5.2E0], ARRAY [4.0E0, 5.2E0, 8.3E0, 9.7E0, 3.5E0])", new ArrayType(DOUBLE), ImmutableList.of(5.2, 8.3)); - assertFunction("ARRAY_INTERSECT(ARRAY [5.1E0, 7, 3.0E0, 4.8E0, 10], ARRAY [6.5E0, 10.0E0, 1.9E0, 5.1E0, 3.9E0, 4.8E0])", new ArrayType(DOUBLE), ImmutableList.of(10.0, 5.1, 4.8)); - assertFunction( - "ARRAY_INTERSECT(ARRAY [2.3, 2.3, 2.2], ARRAY[2.2, 2.3])", - new ArrayType(createDecimalType(2, 1)), - ImmutableList.of(decimal("2.3"), decimal("2.2"))); - assertFunction("ARRAY_INTERSECT(ARRAY [2.330, 1.900, 2.330], ARRAY [2.3300, 1.9000])", new ArrayType(createDecimalType(5, 4)), - ImmutableList.of(decimal("2.3300"), decimal("1.9000"))); - assertFunction("ARRAY_INTERSECT(ARRAY [2, 3], ARRAY[2.0, 3.0])", new ArrayType(createDecimalType(11, 1)), - ImmutableList.of(decimal("00000000002.0"), decimal("00000000003.0"))); - assertFunction("ARRAY_INTERSECT(ARRAY [true], ARRAY [true])", new ArrayType(BOOLEAN), ImmutableList.of(true)); - assertFunction("ARRAY_INTERSECT(ARRAY [true, false], ARRAY [true])", new ArrayType(BOOLEAN), ImmutableList.of(true)); - assertFunction("ARRAY_INTERSECT(ARRAY [true, true], ARRAY [true, true])", new ArrayType(BOOLEAN), ImmutableList.of(true)); - assertFunction("ARRAY_INTERSECT(ARRAY ['abc'], ARRAY ['abc', 'bcd'])", new ArrayType(createVarcharType(3)), ImmutableList.of("abc")); - assertFunction("ARRAY_INTERSECT(ARRAY ['abc', 'abc'], ARRAY ['abc', 'abc'])", new 
ArrayType(createVarcharType(3)), ImmutableList.of("abc")); - assertFunction("ARRAY_INTERSECT(ARRAY ['foo', 'bar', 'baz'], ARRAY ['foo', 'test', 'bar'])", new ArrayType(createVarcharType(4)), ImmutableList.of("foo", "bar")); - - // test empty results - assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [5])", new ArrayType(INTEGER), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [5, 6], ARRAY [])", new ArrayType(INTEGER), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [1], ARRAY [5])", new ArrayType(INTEGER), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [CAST(1 AS BIGINT)], ARRAY [CAST(5 AS BIGINT)])", new ArrayType(BIGINT), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [true, true], ARRAY [false])", new ArrayType(BOOLEAN), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [false])", new ArrayType(BOOLEAN), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [5], ARRAY [1.0E0])", new ArrayType(DOUBLE), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY ['abc'], ARRAY [])", new ArrayType(createVarcharType(3)), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY ['abc', 'bcd'])", new ArrayType(createVarcharType(3)), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); - - // test nulls - assertFunction("ARRAY_INTERSECT(ARRAY [NULL], ARRAY [NULL, NULL])", new ArrayType(UNKNOWN), asList((Object) null)); - assertFunction("ARRAY_INTERSECT(ARRAY [0, 0, 1, NULL], ARRAY [0, 0, 1, NULL])", new ArrayType(INTEGER), asList(0, 1, null)); - assertFunction("ARRAY_INTERSECT(ARRAY [0, 0], ARRAY [0, 0, NULL])", new ArrayType(INTEGER), ImmutableList.of(0)); - assertFunction("ARRAY_INTERSECT(ARRAY [CAST(0 AS BIGINT), CAST(0 AS BIGINT)], ARRAY [CAST(0 AS BIGINT), NULL])", new ArrayType(BIGINT), 
ImmutableList.of(0L)); - assertFunction("ARRAY_INTERSECT(ARRAY [0.0E0], ARRAY [NULL])", new ArrayType(DOUBLE), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [0.0E0, NULL], ARRAY [0.0E0, NULL])", new ArrayType(DOUBLE), asList(0.0, null)); - assertFunction("ARRAY_INTERSECT(ARRAY [true, true, false, false, NULL], ARRAY [true, false, false, NULL])", new ArrayType(BOOLEAN), asList(true, false, null)); - assertFunction("ARRAY_INTERSECT(ARRAY [false, false], ARRAY [false, false, NULL])", new ArrayType(BOOLEAN), ImmutableList.of(false)); - assertFunction("ARRAY_INTERSECT(ARRAY ['abc'], ARRAY [NULL])", new ArrayType(createVarcharType(3)), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [''], ARRAY ['', NULL])", new ArrayType(createVarcharType(0)), ImmutableList.of("")); - assertFunction("ARRAY_INTERSECT(ARRAY ['', NULL], ARRAY ['', NULL])", new ArrayType(createVarcharType(0)), asList("", null)); - assertFunction("ARRAY_INTERSECT(ARRAY [NULL], ARRAY ['abc', NULL])", new ArrayType(createVarcharType(3)), singletonList(null)); - assertFunction("ARRAY_INTERSECT(ARRAY ['abc', NULL, 'xyz', NULL], ARRAY [NULL, 'abc', NULL, NULL])", new ArrayType(createVarcharType(3)), asList("abc", null)); - assertFunction("ARRAY_INTERSECT(ARRAY [], ARRAY [NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); - assertFunction("ARRAY_INTERSECT(ARRAY [NULL], ARRAY [NULL])", new ArrayType(UNKNOWN), singletonList(null)); - - // test composite types - assertFunction( - "ARRAY_INTERSECT(ARRAY[(123, 456), (123, 789)], ARRAY[(123, 456), (123, 456), (123, 789)])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, INTEGER))), - ImmutableList.of(asList(123, 456), asList(123, 789))); - assertFunction( - "ARRAY_INTERSECT(ARRAY[ARRAY[123, 456], ARRAY[123, 789]], ARRAY[ARRAY[123, 456], ARRAY[123, 456], ARRAY[123, 789]])", - new ArrayType(new ArrayType((INTEGER))), - ImmutableList.of(asList(123, 456), asList(123, 789))); - assertFunction( - "ARRAY_INTERSECT(ARRAY[(123, 'abc'), 
(123, 'cde')], ARRAY[(123, 'abc'), (123, 'cde')])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), - ImmutableList.of(asList(123, "abc"), asList(123, "cde"))); - assertFunction( - "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, 'cde'), NULL], ARRAY[(123, 'abc'), (123, 'cde')])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), - ImmutableList.of(asList(123, "abc"), asList(123, "cde"))); - assertFunction( - "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, 'cde'), NULL, NULL], ARRAY[(123, 'abc'), (123, 'cde'), NULL])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), - asList(asList(123, "abc"), asList(123, "cde"), null)); - assertFunction( - "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, 'abc')], ARRAY[(123, 'abc'), (123, NULL)])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), - ImmutableList.of(asList(123, "abc"))); - assertFunction( - "ARRAY_INTERSECT(ARRAY[(123, 'abc')], ARRAY[(123, NULL)])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), - ImmutableList.of()); - - // test unsupported - assertNotSupported( - "ARRAY_INTERSECT(ARRAY[(123, 'abc'), (123, NULL)], ARRAY[(123, 'abc'), (123, NULL)])", - "ROW comparison not supported for fields with null elements"); - assertNotSupported( - "ARRAY_INTERSECT(ARRAY[(NULL, 'abc'), (123, 'abc')], ARRAY[(123, 'abc'),(NULL, 'abc')])", - "ROW comparison not supported for fields with null elements"); - - assertCachedInstanceHasBoundedRetainedSize("ARRAY_INTERSECT(ARRAY ['foo', 'bar', 'baz'], ARRAY ['foo', 'test', 'bar'])"); - } - - @Test - public void testArrayUnion() - { - assertFunction("ARRAY_UNION(ARRAY [cast(10 as bigint), NULL, cast(12 as bigint), NULL], ARRAY [NULL, cast(10 as bigint), NULL, NULL])", new ArrayType(BIGINT), asList(10L, null, 12L)); - assertFunction("ARRAY_UNION(ARRAY [12], ARRAY [10])", new ArrayType(INTEGER), ImmutableList.of(12, 10)); - 
assertFunction("ARRAY_UNION(ARRAY ['foo', 'bar', 'baz'], ARRAY ['foo', 'test', 'bar'])", new ArrayType(createVarcharType(4)), ImmutableList.of("foo", "bar", "baz", "test")); - assertFunction("ARRAY_UNION(ARRAY [NULL], ARRAY [NULL, NULL])", new ArrayType(UNKNOWN), asList((Object) null)); - assertFunction("ARRAY_UNION(ARRAY ['abc', NULL, 'xyz', NULL], ARRAY [NULL, 'abc', NULL, NULL])", new ArrayType(createVarcharType(3)), asList("abc", null, "xyz")); - assertFunction("ARRAY_UNION(ARRAY [1, 5], ARRAY [1])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); - assertFunction("ARRAY_UNION(ARRAY [1, 1, 2, 4], ARRAY [1, 1, 4, 4])", new ArrayType(INTEGER), ImmutableList.of(1, 2, 4)); - assertFunction("ARRAY_UNION(ARRAY [2, 8], ARRAY [8, 3])", new ArrayType(INTEGER), ImmutableList.of(2, 8, 3)); - assertFunction("ARRAY_UNION(ARRAY [IF (RAND() < 1.0E0, 7, 1) , 2], ARRAY [7])", new ArrayType(INTEGER), ImmutableList.of(7, 2)); - assertFunction("ARRAY_UNION(ARRAY [1, 5], ARRAY [1.0E0])", new ArrayType(DOUBLE), ImmutableList.of(1.0, 5.0)); - assertFunction("ARRAY_UNION(ARRAY [8.3E0, 1.6E0, 4.1E0, 5.2E0], ARRAY [4.0E0, 5.2E0, 8.3E0, 9.7E0, 3.5E0])", new ArrayType(DOUBLE), ImmutableList.of(8.3, 1.6, 4.1, 5.2, 4.0, 9.7, 3.5)); - assertFunction("ARRAY_UNION(ARRAY [5.1E0, 7, 3.0E0, 4.8E0, 10], ARRAY [6.5E0, 10.0E0, 1.9E0, 5.1E0, 3.9E0, 4.8E0])", new ArrayType(DOUBLE), ImmutableList.of(5.1, 7.0, 3.0, 4.8, 10.0, 6.5, 1.9, 3.9)); - assertFunction("ARRAY_UNION(ARRAY [ARRAY [4, 5], ARRAY [6, 7]], ARRAY [ARRAY [4, 5], ARRAY [6, 8]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(4, 5), ImmutableList.of(6, 7), ImmutableList.of(6, 8))); - } - @Test public void testComparison() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java index cd00af12e25d4..de42f02aefce1 100644 --- 
a/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/aggregation/AggregationMetadata.java @@ -152,7 +152,7 @@ private static void verifyInputFunctionSignature(MethodHandle method, List