diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java index bee405f94ff7..8287c6bf3135 100644 --- a/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/TypedSet.java @@ -79,13 +79,38 @@ public static TypedSet createEqualityTypedSet( BlockPositionHashCode elementHashCodeOperator, int expectedSize, String functionName) + { + return createEqualityTypedSet( + elementType, + elementEqualOperator, + elementHashCodeOperator, + elementType.createBlockBuilder(null, expectedSize), + expectedSize, + functionName); + } + + /** + * Create a {@code TypedSet} that compares its elements using SQL equality + * comparison. + * + *

<p>The elements of the set will be written in the given {@code BlockBuilder}. + * If the {@code BlockBuilder} is modified by the caller, the set will stop + * functioning correctly. + */ + public static TypedSet createEqualityTypedSet( + Type elementType, + BlockPositionEqual elementEqualOperator, + BlockPositionHashCode elementHashCodeOperator, + BlockBuilder elementBlock, + int expectedSize, + String functionName) { return new TypedSet( elementType, elementEqualOperator, null, elementHashCodeOperator, - elementType.createBlockBuilder(null, expectedSize), + elementBlock, expectedSize, functionName, false); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayDistinctFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayDistinctFunction.java index db26ebb03c41..0a30fd62cf50 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayDistinctFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayDistinctFunction.java @@ -75,24 +75,27 @@ public Block distinct( return array.getSingleValueBlock(0); } - TypedSet typedSet = createDistinctTypedSet(type, elementIsDistinctFrom, elementHashCode, array.getPositionCount(), "array_distinct"); - int distinctCount = 0; - if (pageBuilder.isFull()) { pageBuilder.reset(); } - BlockBuilder distinctElementBlockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder distinctElementsBlockBuilder = pageBuilder.getBlockBuilder(0); + TypedSet distinctElements = createDistinctTypedSet( + type, + elementIsDistinctFrom, + elementHashCode, + distinctElementsBlockBuilder, + array.getPositionCount(), + "array_distinct"); + for (int i = 0; i < array.getPositionCount(); i++) { - if (typedSet.add(array, i)) { - distinctCount++; - type.appendTo(array, i, distinctElementBlockBuilder); - } + distinctElements.add(array, i); } - pageBuilder.declarePositions(distinctCount); - - return distinctElementBlockBuilder.getRegion(distinctElementBlockBuilder.getPositionCount() - 
distinctCount, distinctCount); + pageBuilder.declarePositions(distinctElements.size()); + return distinctElementsBlockBuilder.getRegion( + distinctElementsBlockBuilder.getPositionCount() - distinctElements.size(), + distinctElements.size()); } @SqlType("array(bigint)") @@ -130,6 +133,8 @@ public Block bigintDistinct(@SqlType("array(bigint)") Block array) pageBuilder.declarePositions(distinctCount); - return distinctElementBlockBuilder.getRegion(distinctElementBlockBuilder.getPositionCount() - distinctCount, distinctCount); + return distinctElementBlockBuilder.getRegion( + distinctElementBlockBuilder.getPositionCount() - distinctCount, + distinctCount); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java index c3686e9270a9..cb3cfb0a1c46 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayExceptFunction.java @@ -57,10 +57,11 @@ public static Block except( int leftPositionCount = leftArray.getPositionCount(); int rightPositionCount = rightArray.getPositionCount(); - if (leftPositionCount == 0) { + if (leftPositionCount == 0 || rightPositionCount == 0) { return leftArray; } - TypedSet typedSet = createEqualityTypedSet(type, elementEqual, elementHashCode, leftPositionCount + rightPositionCount, "array_except"); + + TypedSet typedSet = createEqualityTypedSet(type, elementEqual, elementHashCode, leftPositionCount, "array_except"); BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftPositionCount); for (int i = 0; i < rightPositionCount; i++) { typedSet.add(rightArray, i); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java index 61e7b74e2931..8596c53fe694 100644 --- 
a/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/ArrayUnionFunction.java @@ -61,21 +61,24 @@ public static Block union( { int leftArrayCount = leftArray.getPositionCount(); int rightArrayCount = rightArray.getPositionCount(); - TypedSet typedSet = createEqualityTypedSet(type, elementEqual, elementHashCode, leftArrayCount + rightArrayCount, "array_union"); BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftArrayCount + rightArrayCount); - appendTypedArray(leftArray, type, typedSet, distinctElementBlockBuilder); - appendTypedArray(rightArray, type, typedSet, distinctElementBlockBuilder); + TypedSet typedSet = createEqualityTypedSet( + type, + elementEqual, + elementHashCode, + distinctElementBlockBuilder, + leftArrayCount + rightArrayCount, + "array_union"); - return distinctElementBlockBuilder.build(); - } + for (int i = 0; i < leftArray.getPositionCount(); i++) { + typedSet.add(leftArray, i); + } - private static void appendTypedArray(Block array, Type type, TypedSet typedSet, BlockBuilder blockBuilder) - { - for (int i = 0; i < array.getPositionCount(); i++) { - if (typedSet.add(array, i)) { - type.appendTo(array, i, blockBuilder); - } + for (int i = 0; i < rightArray.getPositionCount(); i++) { + typedSet.add(rightArray, i); } + + return distinctElementBlockBuilder.build(); } @SqlType("array(bigint)") diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java index 95f80e4927b1..aff0304d42ee 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapFromEntriesFunction.java @@ -70,44 +70,44 @@ public Block mapFromEntries( convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode 
keyHashCode, @TypeParameter("map(K,V)") MapType mapType, ConnectorSession session, - @SqlType("array(row(K,V))") Block block) + @SqlType("array(row(K,V))") Block mapEntries) { Type keyType = mapType.getKeyType(); Type valueType = mapType.getValueType(); - RowType rowType = RowType.anonymous(ImmutableList.of(keyType, valueType)); + RowType mapEntryType = RowType.anonymous(ImmutableList.of(keyType, valueType)); if (pageBuilder.isFull()) { pageBuilder.reset(); } - int entryCount = block.getPositionCount(); + int entryCount = mapEntries.getPositionCount(); BlockBuilder mapBlockBuilder = pageBuilder.getBlockBuilder(0); BlockBuilder resultBuilder = mapBlockBuilder.beginBlockEntry(); TypedSet uniqueKeys = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entryCount, "map_from_entries"); for (int i = 0; i < entryCount; i++) { - if (block.isNull(i)) { + if (mapEntries.isNull(i)) { mapBlockBuilder.closeEntry(); pageBuilder.declarePosition(); throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "map entry cannot be null"); } - Block rowBlock = rowType.getObject(block, i); + Block mapEntryBlock = mapEntryType.getObject(mapEntries, i); - if (rowBlock.isNull(0)) { + if (mapEntryBlock.isNull(0)) { mapBlockBuilder.closeEntry(); pageBuilder.declarePosition(); throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "map key cannot be null"); } - if (!uniqueKeys.add(rowBlock, 0)) { + if (!uniqueKeys.add(mapEntryBlock, 0)) { mapBlockBuilder.closeEntry(); pageBuilder.declarePosition(); - throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Duplicate keys (%s) are not allowed", keyType.getObjectValue(session, rowBlock, 0))); + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("Duplicate keys (%s) are not allowed", keyType.getObjectValue(session, mapEntryBlock, 0))); } - keyType.appendTo(rowBlock, 0, resultBuilder); - valueType.appendTo(rowBlock, 1, resultBuilder); + keyType.appendTo(mapEntryBlock, 0, resultBuilder); + valueType.appendTo(mapEntryBlock, 1, resultBuilder); 
} mapBlockBuilder.closeEntry(); diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java index 3e90a60390bf..ffe419bdf4cd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MapToMapCast.java @@ -235,15 +235,17 @@ public static Block checkBlockIsNotNull(Block value) public static Block mapCast( MethodHandle keyProcessFunction, MethodHandle valueProcessFunction, - Type toMapType, + Type targetType, BlockPositionEqual keyEqual, BlockPositionHashCode keyHashCode, ConnectorSession session, Block fromMap) { - checkState(toMapType.getTypeParameters().size() == 2, "Expect two type parameters for toMapType"); - Type toKeyType = toMapType.getTypeParameters().get(0); - TypedSet typedSet = createEqualityTypedSet(toKeyType, keyEqual, keyHashCode, fromMap.getPositionCount() / 2, "map-to-map cast"); + checkState(targetType.getTypeParameters().size() == 2, "Expect two type parameters for targetType"); + Type toKeyType = targetType.getTypeParameters().get(0); + TypedSet resultKeys = createEqualityTypedSet(toKeyType, keyEqual, keyHashCode, fromMap.getPositionCount() / 2, "map-to-map cast"); + + // Cast the keys into a new block BlockBuilder keyBlockBuilder = toKeyType.createBlockBuilder(null, fromMap.getPositionCount() / 2); for (int i = 0; i < fromMap.getPositionCount(); i += 2) { try { @@ -255,10 +257,10 @@ public static Block mapCast( } Block keyBlock = keyBlockBuilder.build(); - BlockBuilder mapBlockBuilder = toMapType.createBlockBuilder(null, 1); + BlockBuilder mapBlockBuilder = targetType.createBlockBuilder(null, 1); BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry(); for (int i = 0; i < fromMap.getPositionCount(); i += 2) { - if (typedSet.add(keyBlock, i / 2)) { + if (resultKeys.add(keyBlock, i / 2)) { toKeyType.appendTo(keyBlock, i / 2, blockBuilder); if 
(fromMap.isNull(i + 1)) { blockBuilder.appendNull(); @@ -279,6 +281,6 @@ public static Block mapCast( } mapBlockBuilder.closeEntry(); - return (Block) toMapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); + return (Block) targetType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java b/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java index 9c0f979705c1..a6f9255431fa 100644 --- a/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/scalar/MultimapFromEntriesFunction.java @@ -76,51 +76,51 @@ public Block multimapFromEntries( operator = HASH_CODE, argumentTypes = "K", convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL)) BlockPositionHashCode keyHashCode, - @SqlType("array(row(K,V))") Block block) + @SqlType("array(row(K,V))") Block mapEntries) { Type keyType = mapType.getKeyType(); Type valueType = ((ArrayType) mapType.getValueType()).getElementType(); - RowType rowType = RowType.anonymous(ImmutableList.of(keyType, valueType)); + RowType mapEntryType = RowType.anonymous(ImmutableList.of(keyType, valueType)); if (pageBuilder.isFull()) { pageBuilder.reset(); } - int entryCount = block.getPositionCount(); + int entryCount = mapEntries.getPositionCount(); if (entryCount > entryIndicesList.length) { initializeEntryIndicesList(entryCount); } TypedSet keySet = createEqualityTypedSet(keyType, keyEqual, keyHashCode, entryCount, NAME); for (int i = 0; i < entryCount; i++) { - if (block.isNull(i)) { + if (mapEntries.isNull(i)) { clearEntryIndices(keySet.size()); throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "map entry cannot be null"); } - Block rowBlock = rowType.getObject(block, i); + Block mapEntryBlock = mapEntryType.getObject(mapEntries, i); - if (rowBlock.isNull(0)) { + if 
(mapEntryBlock.isNull(0)) { clearEntryIndices(keySet.size()); throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "map key cannot be null"); } - if (keySet.add(rowBlock, 0)) { + if (keySet.add(mapEntryBlock, 0)) { entryIndicesList[keySet.size() - 1].add(i); } else { - entryIndicesList[keySet.positionOf(rowBlock, 0)].add(i); + entryIndicesList[keySet.positionOf(mapEntryBlock, 0)].add(i); } } BlockBuilder multimapBlockBuilder = pageBuilder.getBlockBuilder(0); - BlockBuilder singleMapWriter = multimapBlockBuilder.beginBlockEntry(); + BlockBuilder mapWriter = multimapBlockBuilder.beginBlockEntry(); for (int i = 0; i < keySet.size(); i++) { - keyType.appendTo(rowType.getObject(block, entryIndicesList[i].getInt(0)), 0, singleMapWriter); - BlockBuilder singleArrayWriter = singleMapWriter.beginBlockEntry(); + keyType.appendTo(mapEntryType.getObject(mapEntries, entryIndicesList[i].getInt(0)), 0, mapWriter); + BlockBuilder valuesArray = mapWriter.beginBlockEntry(); for (int entryIndex : entryIndicesList[i]) { - valueType.appendTo(rowType.getObject(block, entryIndex), 1, singleArrayWriter); + valueType.appendTo(mapEntryType.getObject(mapEntries, entryIndex), 1, valuesArray); } - singleMapWriter.closeEntry(); + mapWriter.closeEntry(); } multimapBlockBuilder.closeEntry();