Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import io.trino.array.ObjectBigArray;
import io.trino.operator.aggregation.NullablePosition;
import io.trino.operator.aggregation.TypedSet;
import io.trino.operator.scalar.BlockSet;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.MapBlockBuilder;
Expand All @@ -37,7 +37,6 @@
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
Expand Down Expand Up @@ -96,7 +95,7 @@ public static void output(
ObjectBigArray<BlockBuilder> valueArrayBlockBuilders = new ObjectBigArray<>();
valueArrayBlockBuilders.ensureCapacity(state.getEntryCount());
BlockBuilder distinctKeyBlockBuilder = keyType.createBlockBuilder(null, state.getEntryCount(), expectedValueSize(keyType, 100));
TypedSet keySet = createDistinctTypedSet(keyType, keyDistinctFrom, keyHashCode, state.getEntryCount(), "multimap_agg");
BlockSet keySet = new BlockSet(keyType, keyDistinctFrom, keyHashCode, state.getEntryCount());

state.forEach((key, value, keyValueIndex) -> {
// Merge values of the same key into an array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,22 @@
*/
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.TypedSet;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BufferedArrayValueBuilder;
import io.trino.spi.function.Convention;
import io.trino.spi.function.Description;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;

import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.operator.scalar.BlockSet.MAX_FUNCTION_MEMORY;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
Expand All @@ -42,12 +40,12 @@
public final class ArrayDistinctFunction
{
public static final String NAME = "array_distinct";
private final PageBuilder pageBuilder;
private final BufferedArrayValueBuilder arrayValueBuilder;

@TypeParameter("E")
public ArrayDistinctFunction(@TypeParameter("E") Type elementType)
{
pageBuilder = new PageBuilder(ImmutableList.of(elementType));
arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(elementType));
}

@TypeParameter("E")
Expand Down Expand Up @@ -76,27 +74,19 @@ public Block distinct(
return array.getSingleValueBlock(0);
}

if (pageBuilder.isFull()) {
pageBuilder.reset();
}

BlockBuilder distinctElementsBlockBuilder = pageBuilder.getBlockBuilder(0);
TypedSet distinctElements = createDistinctTypedSet(
BlockSet distinctElements = new BlockSet(
type,
elementIsDistinctFrom,
elementHashCode,
distinctElementsBlockBuilder,
array.getPositionCount(),
"array_distinct");
array.getPositionCount());

for (int i = 0; i < array.getPositionCount(); i++) {
distinctElements.add(array, i);
}

pageBuilder.declarePositions(distinctElements.size());
return distinctElementsBlockBuilder.getRegion(
distinctElementsBlockBuilder.getPositionCount() - distinctElements.size(),
distinctElements.size());
return arrayValueBuilder.build(
distinctElements.size(),
blockBuilder -> distinctElements.getAllWithSizeLimit(blockBuilder, "array_distinct", MAX_FUNCTION_MEMORY));
}

@SqlType("array(bigint)")
Expand All @@ -106,36 +96,23 @@ public Block bigintDistinct(@SqlType("array(bigint)") Block array)
return array;
}

boolean containsNull = false;
LongSet set = new LongOpenHashSet(array.getPositionCount());
int distinctCount = 0;

if (pageBuilder.isFull()) {
pageBuilder.reset();
}

BlockBuilder distinctElementBlockBuilder = pageBuilder.getBlockBuilder(0);
for (int i = 0; i < array.getPositionCount(); i++) {
if (array.isNull(i)) {
if (!containsNull) {
containsNull = true;
distinctElementBlockBuilder.appendNull();
distinctCount++;
return arrayValueBuilder.build(array.getPositionCount(), distinctElementBlockBuilder -> {
boolean containsNull = false;
LongSet set = new LongOpenHashSet(array.getPositionCount());

for (int i = 0; i < array.getPositionCount(); i++) {
if (array.isNull(i)) {
if (!containsNull) {
containsNull = true;
distinctElementBlockBuilder.appendNull();
}
continue;
}
long value = BIGINT.getLong(array, i);
if (set.add(value)) {
BIGINT.writeLong(distinctElementBlockBuilder, value);
}
continue;
}
long value = BIGINT.getLong(array, i);
if (!set.contains(value)) {
set.add(value);
distinctCount++;
BIGINT.appendTo(array, i, distinctElementBlockBuilder);
}
}

pageBuilder.declarePositions(distinctCount);

return distinctElementBlockBuilder.getRegion(
distinctElementBlockBuilder.getPositionCount() - distinctCount,
distinctCount);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.operator.scalar;

import io.trino.operator.aggregation.TypedSet;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.Convention;
Expand All @@ -26,7 +25,6 @@
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
Expand Down Expand Up @@ -60,13 +58,13 @@ public static Block except(
return leftArray;
}

TypedSet typedSet = createDistinctTypedSet(type, isDistinctOperator, elementHashCode, leftPositionCount, "array_except");
BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftPositionCount);
BlockSet set = new BlockSet(type, isDistinctOperator, elementHashCode, rightPositionCount + leftPositionCount);
for (int i = 0; i < rightPositionCount; i++) {
typedSet.add(rightArray, i);
set.add(rightArray, i);
}
BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftPositionCount);
for (int i = 0; i < leftPositionCount; i++) {
if (typedSet.add(leftArray, i)) {
if (set.add(leftArray, i)) {
type.appendTo(leftArray, i, distinctElementBlockBuilder);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@
*/
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.TypedSet;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BufferedArrayValueBuilder;
import io.trino.spi.function.Convention;
import io.trino.spi.function.Description;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;
import io.trino.type.BlockTypeOperators.BlockPositionIsDistinctFrom;

import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.operator.scalar.BlockSet.MAX_FUNCTION_MEMORY;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
Expand All @@ -38,12 +36,12 @@
@Description("Intersects elements of the two given arrays")
public final class ArrayIntersectFunction
{
private final PageBuilder pageBuilder;
private final BufferedArrayValueBuilder arrayValueBuilder;

@TypeParameter("E")
public ArrayIntersectFunction(@TypeParameter("E") Type elementType)
{
pageBuilder = new PageBuilder(ImmutableList.of(elementType));
arrayValueBuilder = BufferedArrayValueBuilder.createBuffered(new ArrayType(elementType));
}

@TypeParameter("E")
Expand Down Expand Up @@ -74,27 +72,21 @@ public Block intersect(
return rightArray;
}

if (pageBuilder.isFull()) {
pageBuilder.reset();
}

TypedSet rightTypedSet = createDistinctTypedSet(type, elementIsDistinctFrom, elementHashCode, rightPositionCount, "array_intersect");
BlockSet rightSet = new BlockSet(type, elementIsDistinctFrom, elementHashCode, rightPositionCount);
for (int i = 0; i < rightPositionCount; i++) {
rightTypedSet.add(rightArray, i);
rightSet.add(rightArray, i);
}

BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0);

// The intersected set can have at most rightPositionCount elements
TypedSet intersectTypedSet = createDistinctTypedSet(type, elementIsDistinctFrom, elementHashCode, blockBuilder, rightPositionCount, "array_intersect");
BlockSet intersectSet = new BlockSet(type, elementIsDistinctFrom, elementHashCode, rightSet.size());
for (int i = 0; i < leftPositionCount; i++) {
if (rightTypedSet.contains(leftArray, i)) {
intersectTypedSet.add(leftArray, i);
if (rightSet.contains(leftArray, i)) {
intersectSet.add(leftArray, i);
}
}

pageBuilder.declarePositions(intersectTypedSet.size());

return blockBuilder.getRegion(blockBuilder.getPositionCount() - intersectTypedSet.size(), intersectTypedSet.size());
return arrayValueBuilder.build(
intersectSet.size(),
blockBuilder -> intersectSet.getAllWithSizeLimit(blockBuilder, "array_intersect", MAX_FUNCTION_MEMORY));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.trino.operator.scalar;

import io.trino.operator.aggregation.TypedSet;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.Convention;
Expand All @@ -30,7 +29,7 @@

import java.util.concurrent.atomic.AtomicBoolean;

import static io.trino.operator.aggregation.TypedSet.createDistinctTypedSet;
import static io.trino.operator.scalar.BlockSet.MAX_FUNCTION_MEMORY;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.HASH_CODE;
Expand Down Expand Up @@ -58,26 +57,23 @@ public static Block union(
@SqlType("array(E)") Block leftArray,
@SqlType("array(E)") Block rightArray)
{
int leftArrayCount = leftArray.getPositionCount();
int rightArrayCount = rightArray.getPositionCount();
BlockBuilder distinctElementBlockBuilder = type.createBlockBuilder(null, leftArrayCount + rightArrayCount);
TypedSet typedSet = createDistinctTypedSet(
BlockSet set = new BlockSet(
type,
isDistinctOperator,
elementHashCode,
distinctElementBlockBuilder,
leftArrayCount + rightArrayCount,
"array_union");
leftArray.getPositionCount() + rightArray.getPositionCount());

for (int i = 0; i < leftArray.getPositionCount(); i++) {
typedSet.add(leftArray, i);
set.add(leftArray, i);
}

for (int i = 0; i < rightArray.getPositionCount(); i++) {
typedSet.add(rightArray, i);
set.add(rightArray, i);
}

return distinctElementBlockBuilder.build();
BlockBuilder blockBuilder = type.createBlockBuilder(null, set.size());
set.getAllWithSizeLimit(blockBuilder, "array_union", MAX_FUNCTION_MEMORY);
return blockBuilder.build();
}

@SqlType("array(bigint)")
Expand Down
Loading