Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

public class BucketPartitionFunction
Expand Down Expand Up @@ -46,4 +47,19 @@ public int getPartition(Page functionArguments, int position)
int bucket = bucketFunction.getBucket(functionArguments, position);
return bucketToPartition[bucket];
}

// Batched counterpart of getPartition: maps a run of consecutive positions to
// their output partitions in one call.
//
// For each of the `length` positions starting at `positionOffset` in
// `functionArguments`, the bucket is computed via bucketFunction.getBuckets and
// then translated through the bucketToPartition lookup table; results are
// written into the caller-supplied `partitions` array at indices [0, length).
@Override
public void getPartitions(Page functionArguments, int positionOffset, int length, int[] partitions)
{
// Validate the requested range against both the page and the output array.
checkArgument(positionOffset >= 0, "Invalid positionOffset: %s", positionOffset);
checkArgument(length >= 0, "Invalid length: %s", length);
checkArgument(positionOffset + length <= functionArguments.getPositionCount(), "End position exceeds page position count: %s > %s", positionOffset + length, functionArguments.getPositionCount());
checkArgument(length <= partitions.length, "Length exceeds partitions length: %s > %s", length, partitions.length);

// Fill `partitions` with raw bucket ids first, then remap each bucket to its
// partition in place — avoids allocating a temporary bucket buffer.
bucketFunction.getBuckets(functionArguments, positionOffset, length, partitions);
for (int i = 0; i < length; i++) {
int bucket = partitions[i];
partitions[i] = bucketToPartition[bucket];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class FlatGroupByHash
private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = 0.25;

private final FlatHash flatHash;
private final InterpretedHashGenerator hashGenerator;
private final int groupByChannelCount;

private final boolean processDictionary;
Expand All @@ -69,6 +70,7 @@ public FlatGroupByHash(
UpdateMemory checkMemoryReservation)
{
this.flatHash = new FlatHash(hashStrategyCompiler.getFlatHashStrategy(hashTypes), cacheHashValue, expectedSize, checkMemoryReservation);
this.hashGenerator = hashStrategyCompiler.getInterpretedHashGenerator(hashTypes);
this.groupByChannelCount = hashTypes.size();

checkArgument(expectedSize > 0, "expectedSize must be greater than zero");
Expand All @@ -83,6 +85,7 @@ public FlatGroupByHash(
public FlatGroupByHash(FlatGroupByHash other)
{
this.flatHash = other.flatHash.copy();
this.hashGenerator = other.hashGenerator;
groupByChannelCount = other.groupByChannelCount;
processDictionary = other.processDictionary;
dictionaryLookBack = other.dictionaryLookBack == null ? null : other.dictionaryLookBack.copy();
Expand Down Expand Up @@ -358,7 +361,7 @@ public boolean process()
return false;
}

flatHash.computeHashes(blocks, hashes, lastPosition, batchSize);
hashGenerator.hashBlocksBatched(blocks, hashes, lastPosition, batchSize);
for (int i = 0; i < batchSize; i++) {
flatHash.putIfAbsent(blocks, lastPosition + i, hashes[i]);
}
Expand Down Expand Up @@ -524,7 +527,7 @@ public boolean process()
return false;
}

flatHash.computeHashes(blocks, hashes, lastPosition, batchSize);
hashGenerator.hashBlocksBatched(blocks, hashes, lastPosition, batchSize);
for (int i = 0, position = lastPosition; i < batchSize; i++, position++) {
groupIds[position] = flatHash.putIfAbsent(blocks, position, hashes[i]);
}
Expand Down
5 changes: 0 additions & 5 deletions core/trino-main/src/main/java/io/trino/operator/FlatHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,6 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders)
}
}

// Computes one hash per position: hashes[i] receives the hash of `blocks` at
// position (offset + i), for i in [0, length). Delegates to the compiled flat
// hash strategy's batched implementation.
public void computeHashes(Block[] blocks, long[] hashes, int offset, int length)
{
flatHashStrategy.hashBlocksBatched(blocks, hashes, offset, length);
}

public int putIfAbsent(Block[] blocks, int position)
{
long hash = flatHashStrategy.hash(blocks, position);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,4 @@ boolean valueIdentical(
long hash(Block[] blocks, int position);

long hash(byte[] fixedChunk, int fixedOffset, byte[] variableChunk, int variableOffset);

void hashBlocksBatched(Block[] blocks, long[] hashes, int offset, int length);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.cache.CacheStatsMBean;
import io.trino.operator.scalar.CombineHashFunction;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.gen.CallSiteBinder;
Expand All @@ -45,11 +43,7 @@

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.getOnlyElement;
Expand All @@ -61,14 +55,12 @@
import static io.airlift.bytecode.Parameter.arg;
import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.bytecode.expression.BytecodeExpressions.add;
import static io.airlift.bytecode.expression.BytecodeExpressions.and;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantBoolean;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantNull;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue;
import static io.airlift.bytecode.expression.BytecodeExpressions.equal;
import static io.airlift.bytecode.expression.BytecodeExpressions.inlineIf;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic;
import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic;
Expand All @@ -78,6 +70,7 @@
import static io.airlift.bytecode.expression.BytecodeExpressions.notEqual;
import static io.trino.cache.SafeCaches.buildNonEvictableCache;
import static io.trino.operator.HashGenerator.INITIAL_HASH_VALUE;
import static io.trino.operator.InterpretedHashGenerator.createPagePrefixHashGenerator;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER;
Expand All @@ -90,17 +83,20 @@
import static io.trino.sql.gen.SqlTypeBytecodeExpression.constantType;
import static io.trino.util.CompilerUtils.defineClass;
import static io.trino.util.CompilerUtils.makeClassName;
import static java.util.Objects.requireNonNull;

public final class FlatHashStrategyCompiler
{
@VisibleForTesting
static final int COLUMNS_PER_CHUNK = 500;

private final LoadingCache<List<Type>, FlatHashStrategy> flatHashStrategies;
private final NullSafeHashCompiler nullSafeHashCompiler;

@Inject
public FlatHashStrategyCompiler(TypeOperators typeOperators)
public FlatHashStrategyCompiler(TypeOperators typeOperators, NullSafeHashCompiler nullSafeHashCompiler)
{
this.nullSafeHashCompiler = requireNonNull(nullSafeHashCompiler, "nullSafeHashCompiler is null");
this.flatHashStrategies = buildNonEvictableCache(
CacheBuilder.newBuilder()
.recordStats()
Expand All @@ -113,6 +109,11 @@ public FlatHashStrategy getFlatHashStrategy(List<Type> types)
return flatHashStrategies.getUnchecked(ImmutableList.copyOf(types));
}

// Builds an interpreted (non-bytecode-compiled) hash generator over a page
// prefix of the given types, backed by the injected NullSafeHashCompiler.
public InterpretedHashGenerator getInterpretedHashGenerator(List<Type> types)
{
return createPagePrefixHashGenerator(types, nullSafeHashCompiler);
}

@Managed
@Nested
public CacheStatsMBean getFlatHashStrategiesStats()
Expand Down Expand Up @@ -181,7 +182,6 @@ public static FlatHashStrategy compileFlatHashStrategy(List<Type> types, TypeOpe
generateIdenticalMethod(definition, chunkClasses);
generateHashBlock(definition, chunkClasses);
generateHashFlat(definition, chunkClasses, singleChunkClass);
generateHashBlocksBatched(definition, chunkClasses);

try {
DynamicClassLoader classLoader = new DynamicClassLoader(FlatHashStrategyCompiler.class.getClassLoader(), callSiteBinder.getBindings());
Expand Down Expand Up @@ -219,7 +219,6 @@ private static ChunkClass compileFlatHashStrategyChunk(CallSiteBinder callSiteBi
else {
hashFlatChunk = generateHashFlatMultiChunk(definition, keyFields, callSiteBinder);
}
MethodDefinition hashBlocksBatchedChunk = generateHashBlocksBatchedChunk(definition, keyFields, callSiteBinder);

return new ChunkClass(
definition,
Expand All @@ -228,8 +227,7 @@ private static ChunkClass compileFlatHashStrategyChunk(CallSiteBinder callSiteBi
writeFlatChunk,
identicalChunkMethod,
hashBlockChunk,
hashFlatChunk,
hashBlocksBatchedChunk);
hashFlatChunk);
}

private static void generateGetTotalVariableWidth(ClassDefinition definition, List<ChunkClass> chunkClasses)
Expand Down Expand Up @@ -686,165 +684,6 @@ private static MethodDefinition generateHashBlockChunk(ClassDefinition definitio
return methodDefinition;
}

// Generates the instance-level `hashBlocksBatched(Block[], long[], int, int)`
// method for the compiled FlatHashStrategy. The generated method bounds-checks
// the output array, then (for non-zero length) delegates to each chunk class's
// static batched-hash method in turn so every chunk of key columns contributes
// to the combined hashes.
private static void generateHashBlocksBatched(ClassDefinition definition, List<ChunkClass> chunkClasses)
{
Parameter blocks = arg("blocks", type(Block[].class));
Parameter hashes = arg("hashes", type(long[].class));
Parameter offset = arg("offset", type(int.class));
Parameter length = arg("length", type(int.class));

MethodDefinition methodDefinition = definition.declareMethod(
a(PUBLIC),
"hashBlocksBatched",
type(void.class),
blocks,
hashes,
offset,
length);

BytecodeBlock body = methodDefinition.getBody();
// Generated code: Objects.checkFromIndexSize(0, length, hashes.length) — the
// returned index is discarded (.pop()); only the bounds check matters.
body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop());

// Generated code executed only when length != 0: one static call per chunk.
BytecodeBlock nonEmptyLength = new BytecodeBlock();
for (ChunkClass chunkClass : chunkClasses) {
nonEmptyLength.append(invokeStatic(chunkClass.hashBlocksBatchedChunk(), blocks, hashes, offset, length));
}

body.append(new IfStatement("if (length != 0)")
.condition(equal(length, constantInt(0)))
.ifFalse(nonEmptyLength))
.ret();
}

// Generates one chunk class's static `hashBlocksBatched` method. For each key
// field in the chunk it emits a call to a per-column vectorized hash method
// (generated by generateHashBlockVectorized) that hashes a single Block and
// folds the result into the shared `hashes` array. Returns the generated
// MethodDefinition so the instance-level dispatcher can invoke it.
private static MethodDefinition generateHashBlocksBatchedChunk(ClassDefinition definition, List<KeyField> keyFields, CallSiteBinder callSiteBinder)
{
Parameter blocks = arg("blocks", type(Block[].class));
Parameter hashes = arg("hashes", type(long[].class));
Parameter offset = arg("offset", type(int.class));
Parameter length = arg("length", type(int.class));

MethodDefinition methodDefinition = definition.declareMethod(
a(PUBLIC, STATIC),
"hashBlocksBatched",
type(void.class),
blocks,
hashes,
offset,
length);

BytecodeBlock body = methodDefinition.getBody();
// Generated bounds check; the returned index from checkFromIndexSize is popped.
body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop());

BytecodeBlock nonEmptyLength = new BytecodeBlock();

// Cache of generated per-type vectorized methods so later columns of an
// already-seen type reuse the same static implementation.
Map<Type, MethodDefinition> typeMethods = new HashMap<>();
for (KeyField keyField : keyFields) {
MethodDefinition method;
// The first hash method implementation does not combine hashes, so it can't be reused
if (keyField.index() == 0) {
method = generateHashBlockVectorized(definition, keyField, callSiteBinder);
}
else {
// Columns of the same type can reuse the same static method implementation
method = typeMethods.get(keyField.type());
if (method == null) {
method = generateHashBlockVectorized(definition, keyField, callSiteBinder);
typeMethods.put(keyField.type(), method);
}
}
nonEmptyLength.append(invokeStatic(method, blocks.getElement(keyField.index()), hashes, offset, length));
}

// Skip all per-column work when the batch is empty.
body.append(new IfStatement("if (length != 0)")
.condition(equal(length, constantInt(0)))
.ifFalse(nonEmptyLength))
.ret();

return methodDefinition;
}

// Generates a static per-column method `hashBlockVectorized_<index>` that
// hashes every position of one Block over [offset, offset + length) into the
// `hashes` array. Two generated fast paths:
//  - RLE blocks: the single value is hashed once and either broadcast
//    (first column) or combined into every existing hash entry (later columns);
//  - general blocks: a position loop that hashes each value, substituting
//    NULL_HASH_CODE for nulls (null check skipped when mayHaveNull is false).
// The first key column (index 0) writes hashes directly; subsequent columns
// combine with the previously accumulated hash via CombineHashFunction.
private static MethodDefinition generateHashBlockVectorized(ClassDefinition definition, KeyField field, CallSiteBinder callSiteBinder)
{
Parameter block = arg("block", type(Block.class));
Parameter hashes = arg("hashes", type(long[].class));
Parameter offset = arg("offset", type(int.class));
Parameter length = arg("length", type(int.class));

MethodDefinition methodDefinition = definition.declareMethod(
a(PUBLIC, STATIC),
"hashBlockVectorized_" + field.index(),
type(void.class),
block,
hashes,
offset,
length);

Scope scope = methodDefinition.getScope();
BytecodeBlock body = methodDefinition.getBody();

// Generated locals for the position loop and the per-position hash value.
Variable index = scope.declareVariable(int.class, "index");
Variable position = scope.declareVariable(int.class, "position");
Variable mayHaveNull = scope.declareVariable(boolean.class, "mayHaveNull");
Variable hash = scope.declareVariable(long.class, "hash");

// Generated code: position = checkFromToIndex(offset, offset + length, positionCount)
// validates the range and seeds the running position with `offset`.
body.append(position.set(invokeStatic(Objects.class, "checkFromToIndex", int.class, offset, add(offset, length), block.invoke("getPositionCount", int.class))));
body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop());

// invokedynamic call site bound to this field's hash method.
BytecodeExpression computeHashNonNull = invokeDynamic(
BOOTSTRAP_METHOD,
ImmutableList.of(callSiteBinder.bind(field.hashBlockMethod()).getBindingId()),
"hash",
long.class,
block,
position);

// RLE fast path: hash the single repeated value once, then fan it out.
BytecodeBlock rleHandling = new BytecodeBlock()
.append(new IfStatement("hash = block.isNull(position) ? NULL_HASH_CODE : hash(block, position)")
.condition(block.invoke("isNull", boolean.class, position))
.ifTrue(hash.set(constantLong(NULL_HASH_CODE)))
.ifFalse(hash.set(computeHashNonNull)));
if (field.index() == 0) {
// Arrays.fill(hashes, 0, length, hash)
rleHandling.append(invokeStatic(Arrays.class, "fill", void.class, hashes, constantInt(0), length, hash));
}
else {
// CombineHashFunction.combineAllHashesWithConstant(hashes, 0, length, hash)
rleHandling.append(invokeStatic(CombineHashFunction.class, "combineAllHashesWithConstant", void.class, hashes, constantInt(0), length, hash));
}

// Per-position store: direct write for the first column, combine otherwise.
BytecodeExpression setHashExpression;
if (field.index() == 0) {
// hashes[index] = hash;
setHashExpression = hashes.setElement(index, hash);
}
else {
// hashes[index] = CombineHashFunction.getHash(hashes[index], hash);
setHashExpression = hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash));
}

// General path: loop over positions; hoisting mayHaveNull lets the generated
// code short-circuit the isNull check for blocks without nulls.
BytecodeBlock computeHashLoop = new BytecodeBlock()
.append(mayHaveNull.set(block.invoke("mayHaveNull", boolean.class)))
.append(new ForLoop("for (int index = 0; index < length; index++)")
.initialize(index.set(constantInt(0)))
.condition(lessThan(index, length))
.update(index.increment())
.body(new BytecodeBlock()
.append(new IfStatement("if (mayHaveNull && block.isNull(position))")
.condition(and(mayHaveNull, block.invoke("isNull", boolean.class, position)))
.ifTrue(hash.set(constantLong(NULL_HASH_CODE)))
.ifFalse(hash.set(computeHashNonNull)))
.append(setHashExpression)
.append(position.increment())));

body.append(new IfStatement("if (block instanceof RunLengthEncodedBlock)")
.condition(block.instanceOf(RunLengthEncodedBlock.class))
.ifTrue(rleHandling)
.ifFalse(computeHashLoop))
.ret();

return methodDefinition;
}

private static void generateHashFlat(ClassDefinition definition, List<ChunkClass> chunkClasses, boolean singleChunkClass)
{
Parameter fixedChunk = arg("fixedChunk", type(byte[].class));
Expand Down Expand Up @@ -1031,6 +870,5 @@ private record ChunkClass(
MethodDefinition writeFlatChunk,
MethodDefinition identicalMethodChunk,
MethodDefinition hashBlockChunk,
MethodDefinition hashFlatChunk,
MethodDefinition hashBlocksBatchedChunk) {}
MethodDefinition hashFlatChunk) {}
}
16 changes: 16 additions & 0 deletions core/trino-main/src/main/java/io/trino/operator/HashGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,25 @@ public interface HashGenerator

long hashPosition(int position, Page page);

void hash(Page page, int positionOffset, int length, long[] hashes);

default int getPartition(int partitionCount, int position, Page page)
{
long rawHash = hashPosition(position, page);
return processRawHash(rawHash, partitionCount);
}

default void getPartitions(int partitionCount, int positionOffset, Page page, int length, int[] partitions)
{
long[] hashes = new long[length];
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reusing this array is likely to be beneficial for most operator use cases, since allocating a new instance on each invocation is non-trivial allocation pressure. Having this default implementation seems like a performance hazard

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The allocation here is very short-lived and the JVM is pretty good at optimizing for that. Trying to reuse array adds some complexity as it has to be passed down from the calling operator, where it potentially needs to be tracked as a retained memory allocation. Since we didn't observe a problem with this in production for a while, I'm inclined to keep it simple for now and explore reuse as a follow-up.

hash(page, positionOffset, length, hashes);
for (int i = 0; i < length; i++) {
partitions[i] = processRawHash(hashes[i], partitionCount);
}
}

private static int processRawHash(long rawHash, int partitionCount)
{
// This function reduces the 64 bit rawHash to [0, partitionCount) uniformly. It first reduces the rawHash to 32 bit
// integer x then normalize it to x / 2^32 * partitionCount to reduce the range of x from [0, 2^32) to [0, partitionCount)
return (int) ((Integer.toUnsignedLong(Long.hashCode(rawHash)) * partitionCount) >>> 32);
Expand Down
Loading