Skip to content
Merged
16 changes: 16 additions & 0 deletions core/trino-main/src/main/java/io/trino/FeaturesConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ public class FeaturesConfig

private boolean faultTolerantExecutionExchangeEncryptionEnabled = true;

private boolean flatGroupByHash = true;

public enum DataIntegrityVerification
{
NONE,
Expand Down Expand Up @@ -512,4 +514,18 @@ public void applyFaultTolerantExecutionDefaults()
{
exchangeCompressionEnabled = true;
}

// Returns whether the new flat group-by hash implementation is enabled.
// Defaults to true (see the field initializer). Deprecated because this is a
// legacy escape hatch; the value is read by SystemSessionProperties to seed
// the default of the "flat_group_by_hash" session property.
@Deprecated
public boolean isFlatGroupByHash()
{
return flatGroupByHash;
}

// Config setter bound to "legacy.flat-group-by-hash". Deprecated: the
// "legacy." prefix indicates this toggle exists only as a temporary
// rollback switch for the flat group-by hash. Returns this for the
// fluent-config style used throughout this class.
@Deprecated
@Config("legacy.flat-group-by-hash")
public FeaturesConfig setFlatGroupByHash(boolean flatGroupByHash)
{
this.flatGroupByHash = flatGroupByHash;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ public final class SystemSessionProperties
public static final String FORCE_SPILLING_JOIN = "force_spilling_join";
public static final String FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED = "fault_tolerant_execution_force_preferred_write_partitioning_enabled";
public static final String PAGE_PARTITIONING_BUFFER_POOL_SIZE = "page_partitioning_buffer_pool_size";
public static final String FLAT_GROUP_BY_HASH = "flat_group_by_hash";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -1070,6 +1071,11 @@ public SystemSessionProperties(
integerProperty(PAGE_PARTITIONING_BUFFER_POOL_SIZE,
"Maximum number of free buffers in the per task partitioned page buffer pool. Setting this to zero effectively disables the pool",
taskManagerConfig.getPagePartitioningBufferPoolSize(),
true),
booleanProperty(
FLAT_GROUP_BY_HASH,
"Enable new flat group by hash",
featuresConfig.isFlatGroupByHash(),
true));
}

Expand Down Expand Up @@ -1918,4 +1924,9 @@ public static int getPagePartitioningBufferPoolSize(Session session)
{
return session.getSystemProperty(PAGE_PARTITIONING_BUFFER_POOL_SIZE, Integer.class);
}

// Reads the "flat_group_by_hash" boolean session property for the given
// session. The property's default comes from FeaturesConfig.isFlatGroupByHash()
// (see the booleanProperty registration in the constructor above).
public static boolean isFlatGroupByHash(Session session)
{
return session.getSystemProperty(FLAT_GROUP_BY_HASH, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

Expand Down Expand Up @@ -239,6 +240,12 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, S
verifyFunctionSignature(parameterType.equals(Block.class) && methodType.parameterType(parameterIndex + 1).equals(int.class),
"Expected %s argument types to be Block and int".formatted(argumentConvention));
break;
case FLAT:
verifyFunctionSignature(parameterType.equals(byte[].class) &&
methodType.parameterType(parameterIndex + 1).equals(int.class) &&
methodType.parameterType(parameterIndex + 2).equals(byte[].class),
"Expected FLAT argument types to be byte[], int, byte[]");
break;
case IN_OUT:
verifyFunctionSignature(parameterType.equals(InOut.class), "Expected IN_OUT argument type to be InOut");
break;
Expand Down Expand Up @@ -270,6 +277,14 @@ private static void verifyMethodHandleSignature(BoundSignature boundSignature, S
verifyFunctionSignature(methodType.returnType().equals(void.class),
"Expected return type to be void, but is %s", methodType.returnType());
break;
case FLAT_RETURN:
List<Class<?>> parameters = methodType.parameterList();
parameters = parameters.subList(parameters.size() - 4, parameters.size());
verifyFunctionSignature(parameters.equals(List.of(byte[].class, int.class, byte[].class, int.class)),
"Expected last argument types to be (byte[], int, byte[], int), but is %s", methodType);
verifyFunctionSignature(methodType.returnType().equals(void.class),
"Expected return type to be void, but is %s", methodType.returnType());
break;
default:
throw new UnsupportedOperationException("Unknown return convention: " + convention.getReturnConvention());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ private static Class<?> getNullAwareContainerType(Class<?> clazz, InvocationRetu
return switch (returnConvention) {
case NULLABLE_RETURN -> Primitives.wrap(clazz);
case DEFAULT_ON_NULL, FAIL_ON_NULL -> clazz;
case BLOCK_BUILDER -> void.class;
case BLOCK_BUILDER, FLAT_RETURN -> void.class;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.TrinoException;
Expand All @@ -24,10 +23,8 @@
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.AbstractLongType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;

import java.util.Arrays;
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -50,10 +47,7 @@ public class BigintGroupByHash
private static final int BATCH_SIZE = 1024;

private static final float FILL_RATIO = 0.75f;
private static final List<Type> TYPES = ImmutableList.of(BIGINT);
private static final List<Type> TYPES_WITH_RAW_HASH = ImmutableList.of(BIGINT, BIGINT);

private final int hashChannel;
private final boolean outputRawHash;

private int hashCapacity;
Expand All @@ -78,12 +72,10 @@ public class BigintGroupByHash
private long preallocatedMemoryInBytes;
private long currentPageSizeInBytes;

public BigintGroupByHash(int hashChannel, boolean outputRawHash, int expectedSize, UpdateMemory updateMemory)
public BigintGroupByHash(boolean outputRawHash, int expectedSize, UpdateMemory updateMemory)
{
checkArgument(hashChannel >= 0, "hashChannel must be at least zero");
checkArgument(expectedSize > 0, "expectedSize must be greater than zero");

this.hashChannel = hashChannel;
this.outputRawHash = outputRawHash;

hashCapacity = arraySize(expectedSize, FILL_RATIO);
Expand Down Expand Up @@ -111,12 +103,6 @@ public long getEstimatedSize()
preallocatedMemoryInBytes;
}

@Override
public List<Type> getTypes()
{
return outputRawHash ? TYPES_WITH_RAW_HASH : TYPES;
}

@Override
public int getGroupCount()
{
Expand Down Expand Up @@ -150,7 +136,7 @@ public void appendValuesTo(int groupId, PageBuilder pageBuilder)
public Work<?> addPage(Page page)
{
currentPageSizeInBytes = page.getRetainedSizeInBytes();
Block block = page.getBlock(hashChannel);
Block block = page.getBlock(0);
if (block instanceof RunLengthEncodedBlock rleBlock) {
return new AddRunLengthEncodedPageWork(rleBlock);
}
Expand All @@ -165,7 +151,7 @@ public Work<?> addPage(Page page)
public Work<int[]> getGroupIds(Page page)
{
currentPageSizeInBytes = page.getRetainedSizeInBytes();
Block block = page.getBlock(hashChannel);
Block block = page.getBlock(0);
if (block instanceof RunLengthEncodedBlock rleBlock) {
return new GetRunLengthEncodedGroupIdsWork(rleBlock);
}
Expand All @@ -177,9 +163,9 @@ public Work<int[]> getGroupIds(Page page)
}

@Override
public boolean contains(int position, Page page, int[] hashChannels)
public boolean contains(int position, Page page)
{
Block block = page.getBlock(hashChannel);
Block block = page.getBlock(0);
if (block.isNull(position)) {
return nullGroupId >= 0;
}
Expand Down
54 changes: 16 additions & 38 deletions core/trino-main/src/main/java/io/trino/operator/ChannelSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
import io.trino.memory.context.LocalMemoryContext;
import io.trino.spi.Page;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.gen.JoinCompiler;
import io.trino.type.BlockTypeOperators;

import java.util.List;
import java.util.Optional;

import static io.trino.operator.GroupByHash.createGroupByHash;
import static io.trino.type.UnknownType.UNKNOWN;
Expand All @@ -32,18 +29,11 @@ public class ChannelSet
{
private final GroupByHash hash;
private final boolean containsNull;
private final int[] hashChannels;

public ChannelSet(GroupByHash hash, boolean containsNull, int[] hashChannels)
private ChannelSet(GroupByHash hash, boolean containsNull)
{
this.hash = hash;
this.containsNull = containsNull;
this.hashChannels = hashChannels;
}

public Type getType()
{
return hash.getTypes().get(0);
}

public long getEstimatedSizeInBytes()
Expand All @@ -68,53 +58,41 @@ public boolean containsNull()

public boolean contains(int position, Page page)
{
return hash.contains(position, page, hashChannels);
return hash.contains(position, page);
}

public boolean contains(int position, Page page, long rawHash)
{
return hash.contains(position, page, hashChannels, rawHash);
return hash.contains(position, page, rawHash);
}

public static class ChannelSetBuilder
{
private static final int[] HASH_CHANNELS = {0};

private final GroupByHash hash;
private final Page nullBlockPage;
private final Type type;
private final OperatorContext operatorContext;
private final LocalMemoryContext localMemoryContext;
private final GroupByHash hash;

public ChannelSetBuilder(Type type, Optional<Integer> hashChannel, int expectedPositions, OperatorContext operatorContext, JoinCompiler joinCompiler, BlockTypeOperators blockTypeOperators)
public ChannelSetBuilder(Type type, boolean hasPrecomputedHash, int expectedPositions, OperatorContext operatorContext, JoinCompiler joinCompiler, TypeOperators typeOperators)
{
List<Type> types = ImmutableList.of(type);
this.type = requireNonNull(type, "type is null");
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.localMemoryContext = operatorContext.localUserMemoryContext();
this.hash = createGroupByHash(
operatorContext.getSession(),
types,
HASH_CHANNELS,
hashChannel,
ImmutableList.of(type),
hasPrecomputedHash,
expectedPositions,
joinCompiler,
blockTypeOperators,
typeOperators,
this::updateMemoryReservation);
this.nullBlockPage = new Page(type.createBlockBuilder(null, 1, UNKNOWN.getFixedSize()).appendNull().build());
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.localMemoryContext = operatorContext.localUserMemoryContext();
}

public ChannelSet build()
{
return new ChannelSet(hash, hash.contains(0, nullBlockPage, HASH_CHANNELS), HASH_CHANNELS);
}

public long getEstimatedSize()
{
return hash.getEstimatedSize();
}

public int size()
{
return hash.getGroupCount();
Page nullBlockPage = new Page(type.createBlockBuilder(null, 1, UNKNOWN.getFixedSize()).appendNull().build());
boolean containsNull = hash.contains(0, nullBlockPage);
return new ChannelSet(hash, containsNull);
}

public Work<?> addPage(Page page)
Expand Down
Loading