diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java index 9fe86559908a..7e3892f88a3a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/multimapagg/MultimapAggregationFunction.java @@ -30,8 +30,8 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.type.BlockTypeOperators; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -40,7 +40,7 @@ import static io.trino.metadata.FunctionKind.AGGREGATE; import static io.trino.metadata.Signature.comparableTypeParameter; import static io.trino.metadata.Signature.typeVariable; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.type.TypeSignature.arrayType; import static io.trino.spi.type.TypeSignature.mapType; import static io.trino.spi.type.TypeSignature.rowType; @@ -57,7 +57,7 @@ public class MultimapAggregationFunction MultimapAggregationFunction.class, "output", Type.class, - BlockPositionEqual.class, + BlockPositionIsDistinctFrom.class, BlockPositionHashCode.class, Type.class, MultimapAggregationState.class, @@ -103,7 +103,7 @@ public MultimapAggregationFunction(BlockTypeOperators blockTypeOperators) public AggregationMetadata specialize(BoundSignature boundSignature) { Type keyType = boundSignature.getArgumentType(0); - BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType); + BlockPositionIsDistinctFrom keyDistinctOperator = blockTypeOperators.getDistinctFromOperator(keyType); BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType); Type valueType = boundSignature.getArgumentType(1); @@ -114,7 +114,7 @@ public AggregationMetadata specialize(BoundSignature boundSignature) INPUT_FUNCTION, Optional.empty(), Optional.of(COMBINE_FUNCTION), - MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyEqual, keyHashCode, valueType), + MethodHandles.insertArguments(OUTPUT_FUNCTION, 0, keyType, keyDistinctOperator, keyHashCode, valueType), ImmutableList.of(new AccumulatorStateDescriptor<>( MultimapAggregationState.class, stateSerializer, @@ -131,7 +131,7 @@ public static void combine(MultimapAggregationState state, MultimapAggregationSt state.merge(otherState); } - public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out) + public static void output(Type keyType, BlockPositionIsDistinctFrom keyDistinctOperator, BlockPositionHashCode keyHashCode, Type valueType, MultimapAggregationState state, BlockBuilder out) { if (state.isEmpty()) { out.appendNull(); @@ -141,7 +141,7 @@ public static void output(Type keyType, BlockPositionEqual keyEqual, BlockPositi ObjectBigArray valueArrayBlockBuilders = new ObjectBigArray<>(); valueArrayBlockBuilders.ensureCapacity(state.getEntryCount()); BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, state.getEntryCount(), expectedValueSize(keyType, 100)); - TypedSet keySet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, state.getEntryCount(), NAME); + TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctOperator, keyHashCode, state.getEntryCount(), NAME); state.forEach((key, value, keyValueIndex) -> { // Merge values of the same key into an array diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java index cb3cfb0a1c46..54d7e5afa385 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java @@ -23,15 +23,14 @@ import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.function.OperatorType.HASH_CODE; +import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; @ScalarFunction("array_except") @Description("Returns an array of elements that are in the first array but not the second, without duplicates.") @@ -44,9 +43,9 @@ private ArrayExceptFunction() {} public static Block except( @TypeParameter("E") Type type, @OperatorDependency( - operator = EQUAL, + operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}, - convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionEqual elementEqual, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockPositionIsDistinctFrom isDistinctOperator, @OperatorDependency( operator = HASH_CODE, argumentTypes = "E", @@ -61,7 +60,7 @@ public static Block except( return leftArray; } - TypedSet typedSet = createEqualityTypedSet(type, elementEqual, elementHashCode, leftPositionCount, "array_except"); + TypedSet typedSet = createDistinctTypedSet(type, isDistinctOperator, elementHashCode, leftPositionCount, "array_except"); BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftPositionCount); for (int i = 0; i < rightPositionCount; i++) { typedSet.add(rightArray, i); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java index 8596c53fe694..ac21673bce25 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java @@ -23,19 +23,18 @@ import io.trino.spi.function.SqlType; import io.trino.spi.function.TypeParameter; import io.trino.spi.type.Type; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import it.unimi.dsi.fastutil.longs.LongSet; import java.util.concurrent.atomic.AtomicBoolean; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.function.OperatorType.HASH_CODE; +import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; import static io.trino.spi.type.BigintType.BIGINT; @ScalarFunction("array_union") @@ -49,9 +48,9 @@ private ArrayUnionFunction() {} public static Block union( @TypeParameter("E") Type type, @OperatorDependency( - operator = EQUAL, + operator = IS_DISTINCT_FROM, argumentTypes = {"E", "E"}, - convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionEqual elementEqual, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockPositionIsDistinctFrom isDistinctOperator, @OperatorDependency( operator = HASH_CODE, argumentTypes = "E", @@ -62,9 +61,9 @@ public static Block union( int leftArrayCount = leftArray.getPositionCount(); int rightArrayCount = rightArray.getPositionCount(); BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftArrayCount + rightArrayCount); - TypedSet typedSet = createEqualityTypedSet( + TypedSet typedSet = createDistinctTypedSet( type, - elementEqual, + isDistinctOperator, elementHashCode, distinctElementBlockBuilder, leftArrayCount + rightArrayCount, diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java index 50ef32d39cde..6e59f32ea925 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapConcatFunction.java @@ -30,8 +30,8 @@ import io.trino.spi.type.TypeSignature; import io.trino.sql.gen.VarArgsToArrayAdapterGenerator.MethodHandleAndConstructor; import io.trino.type.BlockTypeOperators; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -39,7 +39,7 @@ import static io.trino.metadata.FunctionKind.SCALAR; import static io.trino.metadata.Signature.typeVariable; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; @@ -61,7 +61,7 @@ public final class MapConcatFunction MapConcatFunction.class, "mapConcat", MapType.class, - BlockPositionEqual.class, + BlockPositionIsDistinctFrom.class, BlockPositionHashCode.class, Object.class, Block[].class); @@ -95,14 +95,14 @@ protected ScalarFunctionImplementation specialize(BoundSignature boundSignature) MapType mapType = (MapType) boundSignature.getReturnType(); Type keyType = mapType.getKeyType(); - BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType); + BlockPositionIsDistinctFrom keysDistinctOperator = blockTypeOperators.getDistinctFromOperator(keyType); BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType); MethodHandleAndConstructor methodHandleAndConstructor = generateVarArgsToArrayAdapter( Block.class, Block.class, boundSignature.getArity(), - MethodHandles.insertArguments(METHOD_HANDLE, 0, mapType, keyEqual, keyHashCode), + MethodHandles.insertArguments(METHOD_HANDLE, 0, mapType, keysDistinctOperator, keyHashCode), USER_STATE_FACTORY.bindTo(mapType)); return new ChoicesScalarFunctionImplementation( @@ -120,7 +120,7 @@ public static Object createMapState(MapType mapType) } @UsedByGeneratedCode - public static Block mapConcat(MapType mapType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, Object state, Block[] maps) + public static Block mapConcat(MapType mapType, BlockPositionIsDistinctFrom keysDistinctOperator, BlockPositionHashCode keyHashCode, Object state, Block[] maps) { int entries = 0; int lastMapIndex = maps.length - 1; @@ -144,7 +144,7 @@ public static Block mapConcat(MapType mapType, BlockPositionEqual keyEqual, Bloc // TODO: we should move TypedSet into user state as well Type keyType = mapType.getKeyType(); Type valueType = mapType.getValueType(); - TypedSet typedSet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entries / 2, FUNCTION_NAME); + TypedSet typedSet = createDistinctTypedSet(keyType, keysDistinctOperator, keyHashCode, entries / 2, FUNCTION_NAME); BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java index aff0304d42ee..f4131c3d5ac6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java @@ -30,16 +30,15 @@ import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.function.OperatorType.HASH_CODE; +import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; import static java.lang.String.format; @ScalarFunction("map_from_entries") @@ -61,9 +60,9 @@ public MapFromEntriesFunction(@TypeParameter("map(K,V)") Type mapType) @SqlNullable public Block mapFromEntries( @OperatorDependency( - operator = EQUAL, + operator = IS_DISTINCT_FROM, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionEqual keyEqual, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockPositionIsDistinctFrom keysDistinctOperator, @OperatorDependency( operator = HASH_CODE, argumentTypes = "K", @@ -84,7 +83,7 @@ public Block mapFromEntries( BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); BlockBuilder resultBuilder = mapBlockBuilder.beginBlockEntry(); - TypedSet uniqueKeys = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entryCount, "map_from_entries"); + TypedSet uniqueKeys = createDistinctTypedSet(keyType, keysDistinctOperator, keyHashCode, entryCount, "map_from_entries"); for (int i = 0; i < entryCount; i++) { if (mapEntries.isNull(i)) { diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java index 0b40fc2a9f17..fe700eb40621 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java @@ -31,8 +31,8 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeSignature; import io.trino.type.BlockTypeOperators; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -41,7 +41,7 @@ import static com.google.common.base.Preconditions.checkState; import static io.trino.metadata.Signature.castableToTypeParameter; import static io.trino.metadata.Signature.typeVariable; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; import static io.trino.spi.block.MethodHandleUtil.compose; import static io.trino.spi.block.MethodHandleUtil.nativeValueWriter; @@ -66,7 +66,7 @@ public final class MapToMapCast MethodHandle.class, MethodHandle.class, Type.class, - BlockPositionEqual.class, + BlockPositionIsDistinctFrom.class, BlockPositionHashCode.class, ConnectorSession.class, Block.class); @@ -116,7 +116,7 @@ public ScalarFunctionImplementation specialize(BoundSignature boundSignature, Fu MethodHandle keyProcessor = buildProcessor(functionDependencies, fromKeyType, toKeyType, true); MethodHandle valueProcessor = buildProcessor(functionDependencies, fromValueType, toValueType, false); - BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(toKeyType); + BlockPositionIsDistinctFrom keyEqual = blockTypeOperators.getDistinctFromOperator(toKeyType); BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(toKeyType); MethodHandle target = MethodHandles.insertArguments(METHOD_HANDLE, 0, keyProcessor, valueProcessor, toMapType, keyEqual, keyHashCode); return new ChoicesScalarFunctionImplementation(boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), target); @@ -238,14 +238,14 @@ public static Block mapCast( MethodHandle keyProcessFunction, MethodHandle valueProcessFunction, Type targetType, - BlockPositionEqual keyEqual, + BlockPositionIsDistinctFrom keyDistinctOperator, BlockPositionHashCode keyHashCode, ConnectorSession session, Block fromMap) { checkState(targetType.getTypeParameters().size() == 2, "Expect two type parameters for targetType"); Type toKeyType = targetType.getTypeParameters().get(0); - TypedSet resultKeys = createEqualityTypedSet(toKeyType, keyEqual, keyHashCode, fromMap.getPositionCount() / 2, "map-to-map cast"); + TypedSet resultKeys = createDistinctTypedSet(toKeyType, keyDistinctOperator, keyHashCode, fromMap.getPositionCount() / 2, "map-to-map cast"); // Cast the keys into a new block BlockBuilder keyBlockBuilder = toKeyType.createBlockBuilder(null, fromMap.getPositionCount() / 2); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java index a6f9255431fa..cda225e87a3a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java @@ -30,19 +30,18 @@ import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; -import io.trino.type.BlockTypeOperators.BlockPositionEqual; import io.trino.type.BlockTypeOperators.BlockPositionHashCode; +import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntList; import static com.google.common.base.Verify.verify; -import static io.trino.operator.aggregation.TypedSet.createEqualityTypedSet; +import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet; import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.function.OperatorType.EQUAL; import static io.trino.spi.function.OperatorType.HASH_CODE; +import static io.trino.spi.function.OperatorType.IS_DISTINCT_FROM; @ScalarFunction("multimap_from_entries") @Description("Construct a multimap from an array of entries") @@ -69,9 +68,9 @@ public MultimapFromEntriesFunction(@TypeParameter("map(K,array(V))") Type mapTyp public Block multimapFromEntries( @TypeParameter("map(K,array(V))") MapType mapType, @OperatorDependency( - operator = EQUAL, + operator = IS_DISTINCT_FROM, argumentTypes = {"K", "K"}, - convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN)) BlockPositionEqual keyEqual, + convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = FAIL_ON_NULL)) BlockPositionIsDistinctFrom keysDistinctOperator, @OperatorDependency( operator = HASH_CODE, argumentTypes = "K", @@ -90,7 +89,7 @@ public Block multimapFromEntries( if (entryCount > entryIndicesList.length) { initializeEntryIndicesList(entryCount); } - TypedSet keySet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entryCount, NAME); + TypedSet keySet = createDistinctTypedSet(keyType, keysDistinctOperator, keyHashCode, entryCount, NAME); for (int i = 0; i < entryCount; i++) { if (mapEntries.isNull(i)) { diff --git a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java index c3d5174c62e8..0a6cc2d2d623 100644 --- a/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java +++ b/core/trino-main/src/test/java/io/trino/operator/aggregation/TestMultimapAggAggregation.java @@ -88,6 +88,13 @@ public void testNullMap() testMultimapAgg(DOUBLE, ImmutableList.of(), VARCHAR, ImmutableList.of()); } + @Test + public void testKeysUseIsDistinctSemantics() + { + testMultimapAgg(DOUBLE, ImmutableList.of(Double.NaN, Double.NaN), BIGINT, ImmutableList.of(1L, 1L)); + testMultimapAgg(DOUBLE, ImmutableList.of(Double.NaN, Double.NaN, Double.NaN), BIGINT, ImmutableList.of(2L, 1L, 2L)); + } + @Test public void testDoubleMapMultimap() { @@ -177,6 +184,11 @@ private static TestingAggregationFunction getAggregationFunction(Type keyType, T return FUNCTION_RESOLUTION.getAggregateFunction(QualifiedName.of(MultimapAggregationFunction.NAME), fromTypes(keyType, valueType)); } + /** + * Given a list of keys and a list of corresponding values, manually + * aggregate them into a map of list and check that Trino's aggregation has + * the same results. + */ private static void testMultimapAgg(Type keyType, List expectedKeys, Type valueType, List expectedValues) { checkState(expectedKeys.size() == expectedValues.size(), "expectedKeys and expectedValues should have equal size"); diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java index 9011d8a0a01a..011500c6e539 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestArrayExceptFunction.java @@ -66,4 +66,11 @@ public void testDuplicates() assertFunction("array_except(ARRAY[VARCHAR 'x', 'x', 'y', 'z'], ARRAY['x', 'y', 'x'])", new ArrayType(VARCHAR), ImmutableList.of("z")); assertFunction("array_except(ARRAY[true, false, null, true, false, null], ARRAY[true, true, true])", new ArrayType(BOOLEAN), asList(false, null)); } + + @Test + public void testNonDistinctNonEqualValues() + { + assertFunction("array_except(ARRAY[NaN()], ARRAY[NaN()])", new ArrayType(DOUBLE), ImmutableList.of()); + assertFunction("array_except(ARRAY[1, NaN(), 3], ARRAY[NaN(), 3])", new ArrayType(DOUBLE), ImmutableList.of(1.0)); + } } diff --git a/core/trino-main/src/test/java/io/trino/type/TestArrayOperators.java b/core/trino-main/src/test/java/io/trino/type/TestArrayOperators.java index a89db7c2ee2e..6fdb3de2c6d2 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestArrayOperators.java +++ b/core/trino-main/src/test/java/io/trino/type/TestArrayOperators.java @@ -1422,6 +1422,8 @@ public void testArrayUnion() assertFunction("ARRAY_UNION(ARRAY [8.3E0, 1.6E0, 4.1E0, 5.2E0], ARRAY [4.0E0, 5.2E0, 8.3E0, 9.7E0, 3.5E0])", new ArrayType(DOUBLE), ImmutableList.of(8.3, 1.6, 4.1, 5.2, 4.0, 9.7, 3.5)); assertFunction("ARRAY_UNION(ARRAY [5.1E0, 7, 3.0E0, 4.8E0, 10], ARRAY [6.5E0, 10.0E0, 1.9E0, 5.1E0, 3.9E0, 4.8E0])", new ArrayType(DOUBLE), ImmutableList.of(5.1, 7.0, 3.0, 4.8, 10.0, 6.5, 1.9, 3.9)); assertFunction("ARRAY_UNION(ARRAY [ARRAY [4, 5], ARRAY [6, 7]], ARRAY [ARRAY [4, 5], ARRAY [6, 8]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(4, 5), ImmutableList.of(6, 7), ImmutableList.of(6, 8))); + assertFunction("ARRAY_UNION(ARRAY [NaN()], ARRAY [NaN()])", new ArrayType(DOUBLE), ImmutableList.of(NaN)); // results unique based on IS DISTINCT semantics + assertFunction("ARRAY_UNION(ARRAY [1, NaN(), 3], ARRAY [1, NaN()])", new ArrayType(DOUBLE), ImmutableList.of(1.0, NaN, 3.0)); } @Test diff --git a/core/trino-main/src/test/java/io/trino/type/TestMapOperators.java b/core/trino-main/src/test/java/io/trino/type/TestMapOperators.java index 674d0ef08f1a..bc678f32d215 100644 --- a/core/trino-main/src/test/java/io/trino/type/TestMapOperators.java +++ b/core/trino-main/src/test/java/io/trino/type/TestMapOperators.java @@ -878,6 +878,14 @@ public void testMapConcat() decimal("2.2", createDecimalType(2, 1)), decimal("5.1", createDecimalType(2, 1)), decimal("2.2", createDecimalType(2, 1)))); + + // Compare keys with IS DISTINCT semantics + assertFunction("MAP_CONCAT(" + + "MAP(ARRAY[NaN()], ARRAY[1])," + + "MAP(ARRAY[NaN()], ARRAY[2])," + + "MAP(ARRAY[NaN()], ARRAY[3]))", + mapType(DOUBLE, INTEGER), + ImmutableMap.of(Double.NaN, 3)); } @Test @@ -909,6 +917,7 @@ public void testMapToMapCast() assertInvalidCast("CAST(MAP(ARRAY[1, 2], ARRAY[6, 9]) AS MAP)", "duplicate keys"); assertInvalidCast("CAST(MAP(ARRAY[json 'null'], ARRAY[1]) AS MAP)", "map key is null"); + assertInvalidCast("CAST(MAP(ARRAY['NaN', ' NaN '], ARRAY[1, 2]) AS MAP)", "duplicate keys"); } @Test @@ -950,6 +959,8 @@ public void testMapFromEntries() assertInvalidFunction("map_from_entries(ARRAY[(1.0, 1), (1.0, 2)])", "Duplicate keys (1.0) are not allowed"); assertInvalidFunction("map_from_entries(ARRAY[(ARRAY[1, 2], 1), (ARRAY[1, 2], 2)])", "Duplicate keys ([1, 2]) are not allowed"); assertInvalidFunction("map_from_entries(ARRAY[(MAP(ARRAY[1], ARRAY[2]), 1), (MAP(ARRAY[1], ARRAY[2]), 2)])", "Duplicate keys ({1=2}) are not allowed"); + assertInvalidFunction("map_from_entries(ARRAY[(NaN(), 1), (NaN(), 2)])", "Duplicate keys (NaN) are not allowed"); + assertInvalidFunction("map_from_entries(ARRAY[(null, 1), (null, 2)])", "map key cannot be null"); assertInvalidFunction("map_from_entries(ARRAY[null])", "map entry cannot be null"); assertInvalidFunction("map_from_entries(ARRAY[(1, 2), null])", "map entry cannot be null"); @@ -985,6 +996,11 @@ public void testMultimapFromEntries() "y", ImmutableList.of(2.0, 2.5), "z", singletonList(null))); + assertFunction( + "multimap_from_entries(ARRAY[(NaN(), 1), (NaN(), 2)])", + mapType(DOUBLE, new ArrayType(INTEGER)), + ImmutableMap.of(Double.NaN, ImmutableList.of(1, 2))); + // invalid invocation assertInvalidFunction("multimap_from_entries(ARRAY[(null, 1), (null, 2)])", "map key cannot be null"); assertInvalidFunction("multimap_from_entries(ARRAY[null])", "map entry cannot be null");