Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.squareup</groupId>
<artifactId>javapoet</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
Expand Down
12 changes: 12 additions & 0 deletions core/trino-main/src/main/java/io/trino/FeaturesConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public class FeaturesConfig
private boolean legacyCatalogRoles;
private boolean incrementalHashArrayLoadFactorEnabled = true;
private boolean allowSetViewAuthorization;
private boolean useEnhancedGroupBy = true;

private boolean hideInaccessibleColumns;

Expand Down Expand Up @@ -216,6 +217,17 @@ public FeaturesConfig setWriterMinSize(DataSize writerMinSize)
return this;
}

/**
 * Returns whether the enhanced {@code GROUP BY} execution path is enabled.
 * Backed by the {@code optimizer.use-enhanced-group-by} config property; defaults to {@code true}.
 */
public boolean isUseEnhancedGroupBy()
{
    return this.useEnhancedGroupBy;
}

@Config("optimizer.use-enhanced-group-by")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"enhanced" is not descriptive enough.

/**
 * Enables or disables the enhanced {@code GROUP BY} execution path.
 */
public void setUseEnhancedGroupBy(boolean useEnhancedGroupBy)
{
    this.useEnhancedGroupBy = useEnhancedGroupBy;
}

@Min(2)
public int getRe2JDfaStatesLimit()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ public final class SystemSessionProperties
public static final String ADAPTIVE_PARTIAL_AGGREGATION_ENABLED = "adaptive_partial_aggregation_enabled";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_MIN_ROWS = "adaptive_partial_aggregation_min_rows";
public static final String ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold";
public static final String TASK_MAX_PARTIAL_AGGREGATION_MEMORY = "task_max_partial_aggregation_memory";
public static final String USE_ENHANCED_GROUP_BY = "use_enhanced_group_by";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -791,6 +793,16 @@ public SystemSessionProperties(
ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD,
"Ratio between aggregation output and input rows above which partial aggregation might be adaptively turned off",
optimizerConfig.getAdaptivePartialAggregationUniqueRowsRatioThreshold(),
false),
dataSizeProperty(
TASK_MAX_PARTIAL_AGGREGATION_MEMORY,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be done in a separate PR

"Maximum size of partial aggregation results for distributed aggregations.",
taskManagerConfig.getMaxPartialAggregationMemoryUsage(),
false),
booleanProperty(
USE_ENHANCED_GROUP_BY,
"Enable optimization for aggregations",
featuresConfig.isUseEnhancedGroupBy(),
false));
}

Expand Down Expand Up @@ -1428,4 +1440,14 @@ public static double getAdaptivePartialAggregationUniqueRowsRatioThreshold(Sessi
{
return session.getSystemProperty(ADAPTIVE_PARTIAL_AGGREGATION_UNIQUE_ROWS_RATIO_THRESHOLD, Double.class);
}

/**
 * Reads the {@code task_max_partial_aggregation_memory} session property:
 * the maximum size of partial aggregation results for distributed aggregations.
 */
public static DataSize getMaxPartialAggregationMemoryUsage(Session session)
{
    DataSize maxMemory = session.getSystemProperty(TASK_MAX_PARTIAL_AGGREGATION_MEMORY, DataSize.class);
    return maxMemory;
}

/**
 * Reads the {@code use_enhanced_group_by} session property controlling
 * the aggregation optimization toggle.
 */
public static boolean isUseEnhancedGroupByEnabled(Session session)
{
    Boolean enabled = session.getSystemProperty(USE_ENHANCED_GROUP_BY, Boolean.class);
    return enabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public Work<GroupByIdBlock> getGroupIds(Page page)
}

@Override
public boolean contains(int position, Page page, int[] hashChannels)
public boolean contains(int position, Page page)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can extract the first commit to a separate PR and merge it.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I'm gonna prepare the PR

{
Block block = page.getBlock(hashChannel);
if (block.isNull(position)) {
Expand Down
38 changes: 24 additions & 14 deletions core/trino-main/src/main/java/io/trino/operator/ChannelSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,26 @@
import com.google.common.collect.ImmutableList;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.gen.JoinCompiler;
import io.trino.type.BlockTypeOperators;

import java.util.List;
import java.util.Optional;

import static io.trino.operator.GroupByHash.createGroupByHash;
import static io.trino.type.TypeUtils.NULL_HASH_CODE;
import static io.trino.type.UnknownType.UNKNOWN;
import static java.util.Objects.requireNonNull;

public class ChannelSet
{
private final GroupByHash hash;
private final boolean containsNull;
private final int[] hashChannels;

public ChannelSet(GroupByHash hash, boolean containsNull, int[] hashChannels)
public ChannelSet(GroupByHash hash, boolean containsNull)
{
this.hash = hash;
this.containsNull = containsNull;
this.hashChannels = hashChannels;
}

public Type getType()
Expand Down Expand Up @@ -68,12 +66,12 @@ public boolean containsNull()

public boolean contains(int position, Page page)
{
return hash.contains(position, page, hashChannels);
return hash.contains(position, page);
}

public boolean contains(int position, Page page, long rawHash)
{
return hash.contains(position, page, hashChannels, rawHash);
return hash.contains(position, page, rawHash);
}

public static class ChannelSetBuilder
Expand All @@ -85,26 +83,38 @@ public static class ChannelSetBuilder
private final OperatorContext operatorContext;
private final LocalMemoryContext localMemoryContext;

public ChannelSetBuilder(Type type, Optional<Integer> hashChannel, int expectedPositions, OperatorContext operatorContext, JoinCompiler joinCompiler, BlockTypeOperators blockTypeOperators)
public ChannelSetBuilder(Type type, boolean hashPresent, int expectedPositions, OperatorContext operatorContext, GroupByHashFactory groupByHashFactory)
{
// Set builder has a single channel which goes in channel 0, if hash is present, add a hashBlock to channel 1
Optional<Integer> hashChannel = hashPresent ? Optional.of(1) : Optional.empty();
List<Type> types = ImmutableList.of(type);
this.hash = createGroupByHash(
this.hash = groupByHashFactory.createGroupByHash(
operatorContext.getSession(),
types,
HASH_CHANNELS,
hashChannel,
expectedPositions,
joinCompiler,
blockTypeOperators,
this::updateMemoryReservation);
this.nullBlockPage = new Page(type.createBlockBuilder(null, 1, UNKNOWN.getFixedSize()).appendNull().build());
this.nullBlockPage = createNullPage(type, hashPresent);
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.localMemoryContext = operatorContext.localUserMemoryContext();
}

/**
 * Builds a single-row page whose only value channel holds a null of the given type.
 * When a precomputed-hash channel is present, a second BIGINT channel carrying
 * {@code NULL_HASH_CODE} is appended so the page layout matches the group-by hash.
 */
private static Page createNullPage(Type type, boolean hashPresent)
{
    Block nullValue = type.createBlockBuilder(null, 1, UNKNOWN.getFixedSize()).appendNull().build();
    if (!hashPresent) {
        return new Page(nullValue);
    }
    // hash channel goes in position 1, mirroring the builder's channel layout
    Block nullHash = BigintType.BIGINT.createBlockBuilder(null, 1).writeLong(NULL_HASH_CODE).build();
    return new Page(nullValue, nullHash);
}

public ChannelSet build()
{
return new ChannelSet(hash, hash.contains(0, nullBlockPage, HASH_CHANNELS), HASH_CHANNELS);
return new ChannelSet(hash, hash.contains(0, nullBlockPage));
}

public long getEstimatedSize()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import io.trino.memory.context.LocalMemoryContext;
import io.trino.spi.Page;
import io.trino.spi.type.Type;
import io.trino.sql.gen.JoinCompiler;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.type.BlockTypeOperators;

import java.util.Arrays;
import java.util.List;
Expand All @@ -32,7 +30,6 @@
import static com.google.common.base.Verify.verify;
import static com.google.common.base.Verify.verifyNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.operator.GroupByHash.createGroupByHash;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
Expand All @@ -50,8 +47,7 @@ public static class DistinctLimitOperatorFactory
private final long limit;
private final Optional<Integer> hashChannel;
private boolean closed;
private final JoinCompiler joinCompiler;
private final BlockTypeOperators blockTypeOperators;
private final GroupByHashFactory groupByHashFactory;

public DistinctLimitOperatorFactory(
int operatorId,
Expand All @@ -60,8 +56,7 @@ public DistinctLimitOperatorFactory(
List<Integer> distinctChannels,
long limit,
Optional<Integer> hashChannel,
JoinCompiler joinCompiler,
BlockTypeOperators blockTypeOperators)
GroupByHashFactory groupByHashFactory)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
Expand All @@ -71,8 +66,7 @@ public DistinctLimitOperatorFactory(
checkArgument(limit >= 0, "limit must be at least zero");
this.limit = limit;
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null");
this.blockTypeOperators = requireNonNull(blockTypeOperators, "blockTypeOperators is null");
this.groupByHashFactory = requireNonNull(groupByHashFactory, "groupByHashFactory is null");
}

@Override
Expand All @@ -83,7 +77,7 @@ public Operator createOperator(DriverContext driverContext)
List<Type> distinctTypes = distinctChannels.stream()
.map(sourceTypes::get)
.collect(toImmutableList());
return new DistinctLimitOperator(operatorContext, distinctChannels, distinctTypes, limit, hashChannel, joinCompiler, blockTypeOperators);
return new DistinctLimitOperator(operatorContext, distinctChannels, distinctTypes, limit, hashChannel, groupByHashFactory);
}

@Override
Expand All @@ -95,7 +89,7 @@ public void noMoreOperators()
@Override
public OperatorFactory duplicate()
{
return new DistinctLimitOperatorFactory(operatorId, planNodeId, sourceTypes, distinctChannels, limit, hashChannel, joinCompiler, blockTypeOperators);
return new DistinctLimitOperatorFactory(operatorId, planNodeId, sourceTypes, distinctChannels, limit, hashChannel, groupByHashFactory);
}
}

Expand All @@ -115,7 +109,7 @@ public OperatorFactory duplicate()
private GroupByIdBlock groupByIds;
private Work<GroupByIdBlock> unfinishedWork;

public DistinctLimitOperator(OperatorContext operatorContext, List<Integer> distinctChannels, List<Type> distinctTypes, long limit, Optional<Integer> hashChannel, JoinCompiler joinCompiler, BlockTypeOperators blockTypeOperators)
public DistinctLimitOperator(OperatorContext operatorContext, List<Integer> distinctChannels, List<Type> distinctTypes, long limit, Optional<Integer> hashChannel, GroupByHashFactory groupByHashFactory)
{
this.operatorContext = requireNonNull(operatorContext, "operatorContext is null");
this.localUserMemoryContext = operatorContext.localUserMemoryContext();
Expand All @@ -131,14 +125,12 @@ public DistinctLimitOperator(OperatorContext operatorContext, List<Integer> dist
outputChannels = distinctChannelInts.clone(); // defensive copy since this is passed into createGroupByHash
}

this.groupByHash = createGroupByHash(
this.groupByHash = groupByHashFactory.createGroupByHash(
operatorContext.getSession(),
distinctTypes,
distinctChannelInts,
hashChannel,
toIntExact(Math.min(limit, 10_000)),
joinCompiler,
blockTypeOperators,
this::updateMemoryReservation);
remainingLimit = limit;
}
Expand Down
72 changes: 34 additions & 38 deletions core/trino-main/src/main/java/io/trino/operator/GroupByHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,50 +14,19 @@
package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import io.trino.Session;
import io.trino.operator.aggregation.GroupedAggregator;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.Type;
import io.trino.sql.gen.JoinCompiler;
import io.trino.type.BlockTypeOperators;
import it.unimi.dsi.fastutil.ints.IntIterator;

import java.util.List;
import java.util.Optional;

import static io.trino.SystemSessionProperties.isDictionaryAggregationEnabled;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.operator.WorkProcessor.ProcessState;

public interface GroupByHash
{
static GroupByHash createGroupByHash(
Session session,
List<? extends Type> hashTypes,
int[] hashChannels,
Optional<Integer> inputHashChannel,
int expectedSize,
JoinCompiler joinCompiler,
BlockTypeOperators blockTypeOperators,
UpdateMemory updateMemory)
{
return createGroupByHash(hashTypes, hashChannels, inputHashChannel, expectedSize, isDictionaryAggregationEnabled(session), joinCompiler, blockTypeOperators, updateMemory);
}

static GroupByHash createGroupByHash(
List<? extends Type> hashTypes,
int[] hashChannels,
Optional<Integer> inputHashChannel,
int expectedSize,
boolean processDictionary,
JoinCompiler joinCompiler,
BlockTypeOperators blockTypeOperators,
UpdateMemory updateMemory)
{
if (hashTypes.size() == 1 && hashTypes.get(0).equals(BIGINT) && hashChannels.length == 1) {
return new BigintGroupByHash(hashChannels[0], inputHashChannel.isPresent(), expectedSize, updateMemory);
}
return new MultiChannelGroupByHash(hashTypes, hashChannels, inputHashChannel, expectedSize, processDictionary, joinCompiler, blockTypeOperators, updateMemory);
}

long getEstimatedSize();

long getHashCollisions();
Expand All @@ -74,15 +43,42 @@ static GroupByHash createGroupByHash(

Work<GroupByIdBlock> getGroupIds(Page page);

boolean contains(int position, Page page, int[] hashChannels);
boolean contains(int position, Page page);

default boolean contains(int position, Page page, int[] hashChannels, long rawHash)
default boolean contains(int position, Page page, long rawHash)
{
return contains(position, page, hashChannels);
return contains(position, page);
}

long getRawHash(int groupyId);

@VisibleForTesting
int getCapacity();

/**
 * Streams the aggregation output as pages: for each group id produced by
 * {@code groupIds}, appends the grouping key columns (via {@link #appendValuesTo})
 * followed by one evaluated column per aggregator. Pages are emitted whenever the
 * builder fills; the processor finishes when the iterator is exhausted.
 */
default WorkProcessor<Page> buildResult(IntIterator groupIds, PageBuilder pageBuilder, List<GroupedAggregator> groupedAggregators)
{
    return WorkProcessor.create(() -> {
        if (!groupIds.hasNext()) {
            return ProcessState.finished();
        }

        pageBuilder.reset();

        // aggregator outputs start right after the grouping-key columns
        int aggregationOffset = getTypes().size();
        while (groupIds.hasNext() && !pageBuilder.isFull()) {
            int groupId = groupIds.nextInt();
            appendValuesTo(groupId, pageBuilder);
            pageBuilder.declarePosition();

            for (int index = 0; index < groupedAggregators.size(); index++) {
                BlockBuilder output = pageBuilder.getBlockBuilder(aggregationOffset + index);
                groupedAggregators.get(index).evaluate(groupId, output);
            }
        }

        return ProcessState.ofResult(pageBuilder.build());
    });
}
}
Loading