diff --git a/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java index 37676c4f9ebc..457169d6bb51 100644 --- a/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/BucketPartitionFunction.java @@ -46,4 +46,12 @@ public int getPartition(Page functionArguments, int position) int bucket = bucketFunction.getBucket(functionArguments, position); return bucketToPartition[bucket]; } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + for (int i = 0; i < length; i++) { + partitions[i] = getPartition(page, offset + i); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java index bfe32991c863..795865bc8762 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java @@ -28,15 +28,15 @@ import io.airlift.bytecode.Parameter; import io.airlift.bytecode.Scope; import io.airlift.bytecode.Variable; -import io.airlift.bytecode.control.ForLoop; import io.airlift.bytecode.control.IfStatement; import io.airlift.bytecode.expression.BytecodeExpression; import io.trino.annotation.UsedByGeneratedCode; import io.trino.cache.CacheStatsMBean; +import io.trino.operator.HashStrategyCompilerUtils.ChunkClass; +import io.trino.operator.HashStrategyCompilerUtils.HashGeneratorKeyField; import io.trino.operator.scalar.CombineHashFunction; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.CallSiteBinder; @@ -45,10 +45,8 @@ import java.lang.invoke.MethodHandle; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; +import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Objects; import static com.google.common.base.Preconditions.checkArgument; @@ -61,7 +59,6 @@ import static io.airlift.bytecode.Parameter.arg; import static io.airlift.bytecode.ParameterizedType.type; import static io.airlift.bytecode.expression.BytecodeExpressions.add; -import static io.airlift.bytecode.expression.BytecodeExpressions.and; import static io.airlift.bytecode.expression.BytecodeExpressions.constantBoolean; import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; @@ -78,6 +75,9 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.notEqual; import static io.trino.cache.SafeCaches.buildNonEvictableCache; import static io.trino.operator.HashGenerator.INITIAL_HASH_VALUE; +import static io.trino.operator.HashStrategyCompilerUtils.generateHashBlock; +import static io.trino.operator.HashStrategyCompilerUtils.generateHashBlockChunk; +import static io.trino.operator.HashStrategyCompilerUtils.generateHashBlocksBatchedChunk; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; @@ -124,6 +124,7 @@ 
public CacheStatsMBean getFlatHashStrategiesStats() public static FlatHashStrategy compileFlatHashStrategy(List types, TypeOperators typeOperators) { List keyFields = new ArrayList<>(); + List hashGeneratorKeyFields = new ArrayList<>(); int fixedOffset = 0; for (int i = 0; i < types.size(); i++) { Type type = types.get(i); @@ -136,7 +137,13 @@ public static FlatHashStrategy compileFlatHashStrategy(List types, TypeOpe typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)), typeOperators.getIdenticalOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)), typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)), - typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)))); + typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)), + null)); + hashGeneratorKeyFields.add(new HashGeneratorKeyField( + i, + type, + typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)), + i)); fixedOffset += 1 + type.getFlatFixedSize(); } @@ -145,8 +152,13 @@ public static FlatHashStrategy compileFlatHashStrategy(List types, TypeOpe int chunkNumber = 0; // generate a separate class for each chunk of 500 types to avoid hitting the JVM method size and constant pool limits boolean singleChunkClass = keyFields.size() <= COLUMNS_PER_CHUNK; + Iterator hashGeneratorKeyFieldsIterator = hashGeneratorKeyFields.iterator(); for (List chunk : Lists.partition(keyFields, COLUMNS_PER_CHUNK)) { - chunkClasses.add(compileFlatHashStrategyChunk(callSiteBinder, chunk, chunkNumber, singleChunkClass)); + List hashGeneratorKeyFieldChunk = new ArrayList<>(); + for (int i = 0; i < chunk.size(); i++) { + hashGeneratorKeyFieldChunk.add(hashGeneratorKeyFieldsIterator.next()); + } + chunkClasses.add(compileFlatHashStrategyChunk(callSiteBinder, chunk, hashGeneratorKeyFieldChunk, chunkNumber, singleChunkClass)); chunkNumber++; } @@ -197,7 +209,7 @@ public static FlatHashStrategy compileFlatHashStrategy(List types, TypeOpe } } - private static ChunkClass compileFlatHashStrategyChunk(CallSiteBinder callSiteBinder, List keyFields, int chunkNumber, boolean singleChunkClass) + private static ChunkClass compileFlatHashStrategyChunk(CallSiteBinder callSiteBinder, List keyFields, List hashGeneratorKeyFields, int chunkNumber, boolean singleChunkClass) { ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), @@ -211,7 +223,7 @@ private static ChunkClass compileFlatHashStrategyChunk(CallSiteBinder callSiteBi MethodDefinition readFlatChunk = generateReadFlatChunk(definition, keyFields, callSiteBinder); MethodDefinition writeFlatChunk = generateWriteFlatChunk(definition, keyFields, callSiteBinder); MethodDefinition identicalChunkMethod = generateIdenticalChunkMethod(definition, keyFields, callSiteBinder); - MethodDefinition hashBlockChunk = generateHashBlockChunk(definition, keyFields, callSiteBinder); + MethodDefinition hashBlockChunk = generateHashBlockChunk(definition, hashGeneratorKeyFields, callSiteBinder); MethodDefinition hashFlatChunk; if (singleChunkClass) { hashFlatChunk = generateHashFlatSingleChunk(definition, keyFields, callSiteBinder); @@ -219,7 +231,7 @@ private static ChunkClass compileFlatHashStrategyChunk(CallSiteBinder callSiteBi else { hashFlatChunk = generateHashFlatMultiChunk(definition, keyFields, callSiteBinder); } - MethodDefinition hashBlocksBatchedChunk = generateHashBlocksBatchedChunk(definition, keyFields, 
callSiteBinder); + MethodDefinition hashBlocksBatchedChunk = generateHashBlocksBatchedChunk(definition, hashGeneratorKeyFields, callSiteBinder); return new ChunkClass( definition, @@ -629,63 +641,6 @@ private static MethodDefinition generateVariableWidthIdenticalMethod(ClassDefini return methodDefinition; } - private static void generateHashBlock(ClassDefinition definition, List chunkClasses) - { - Parameter blocks = arg("blocks", type(Block[].class)); - Parameter position = arg("position", type(int.class)); - MethodDefinition methodDefinition = definition.declareMethod( - a(PUBLIC), - "hash", - type(long.class), - blocks, - position); - BytecodeBlock body = methodDefinition.getBody(); - - Scope scope = methodDefinition.getScope(); - Variable result = scope.declareVariable("result", body, constantLong(INITIAL_HASH_VALUE)); - for (ChunkClass chunkClass : chunkClasses) { - body.append(result.set(invokeStatic(chunkClass.hashBlockChunk(), blocks, position, result))); - } - body.append(result.ret()); - } - - private static MethodDefinition generateHashBlockChunk(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) - { - Parameter blocks = arg("blocks", type(Block[].class)); - Parameter position = arg("position", type(int.class)); - Parameter seed = arg("seed", type(long.class)); - MethodDefinition methodDefinition = definition.declareMethod( - a(PUBLIC, STATIC), - "hashBlocks", - type(long.class), - blocks, - position, - seed); - BytecodeBlock body = methodDefinition.getBody(); - - Scope scope = methodDefinition.getScope(); - Variable result = scope.declareVariable("result", body, seed); - Variable hash = scope.declareVariable(long.class, "hash"); - Variable block = scope.declareVariable(Block.class, "block"); - - for (KeyField keyField : keyFields) { - body.append(block.set(blocks.getElement(keyField.index()))); - body.append(new IfStatement() - .condition(block.invoke("isNull", boolean.class, position)) - .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) - .ifFalse(hash.set(invokeDynamic( - BOOTSTRAP_METHOD, - ImmutableList.of(callSiteBinder.bind(keyField.hashBlockMethod()).getBindingId()), - "hash", - long.class, - block, - position)))); - body.append(result.set(invokeStatic(CombineHashFunction.class, "getHash", long.class, result, hash))); - } - body.append(result.ret()); - return methodDefinition; - } - private static void generateHashBlocksBatched(ClassDefinition definition, List chunkClasses) { Parameter blocks = arg("blocks", type(Block[].class)); @@ -716,135 +671,6 @@ private static void generateHashBlocksBatched(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) - { - Parameter blocks = arg("blocks", type(Block[].class)); - Parameter hashes = arg("hashes", type(long[].class)); - Parameter offset = arg("offset", type(int.class)); - Parameter length = arg("length", type(int.class)); - - MethodDefinition methodDefinition = definition.declareMethod( - a(PUBLIC, STATIC), - "hashBlocksBatched", - type(void.class), - blocks, - hashes, - offset, - length); - - BytecodeBlock body = methodDefinition.getBody(); - body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop()); - - BytecodeBlock nonEmptyLength = new BytecodeBlock(); - - Map typeMethods = new HashMap<>(); - for (KeyField keyField : keyFields) { - MethodDefinition method; - // The first hash method implementation does not combine hashes, so it can't be reused - if (keyField.index() == 0) { - method = 
generateHashBlockVectorized(definition, keyField, callSiteBinder); - } - else { - // Columns of the same type can reuse the same static method implementation - method = typeMethods.get(keyField.type()); - if (method == null) { - method = generateHashBlockVectorized(definition, keyField, callSiteBinder); - typeMethods.put(keyField.type(), method); - } - } - nonEmptyLength.append(invokeStatic(method, blocks.getElement(keyField.index()), hashes, offset, length)); - } - - body.append(new IfStatement("if (length != 0)") - .condition(equal(length, constantInt(0))) - .ifFalse(nonEmptyLength)) - .ret(); - - return methodDefinition; - } - - private static MethodDefinition generateHashBlockVectorized(ClassDefinition definition, KeyField field, CallSiteBinder callSiteBinder) - { - Parameter block = arg("block", type(Block.class)); - Parameter hashes = arg("hashes", type(long[].class)); - Parameter offset = arg("offset", type(int.class)); - Parameter length = arg("length", type(int.class)); - - MethodDefinition methodDefinition = definition.declareMethod( - a(PUBLIC, STATIC), - "hashBlockVectorized_" + field.index(), - type(void.class), - block, - hashes, - offset, - length); - - Scope scope = methodDefinition.getScope(); - BytecodeBlock body = methodDefinition.getBody(); - - Variable index = scope.declareVariable(int.class, "index"); - Variable position = scope.declareVariable(int.class, "position"); - Variable mayHaveNull = scope.declareVariable(boolean.class, "mayHaveNull"); - Variable hash = scope.declareVariable(long.class, "hash"); - - body.append(position.set(invokeStatic(Objects.class, "checkFromToIndex", int.class, offset, add(offset, length), block.invoke("getPositionCount", int.class)))); - body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop()); - - BytecodeExpression computeHashNonNull = invokeDynamic( - BOOTSTRAP_METHOD, - ImmutableList.of(callSiteBinder.bind(field.hashBlockMethod()).getBindingId()), - "hash", - long.class, - block, - position); - - BytecodeBlock rleHandling = new BytecodeBlock() - .append(new IfStatement("hash = block.isNull(position) ? 
NULL_HASH_CODE : hash(block, position)") - .condition(block.invoke("isNull", boolean.class, position)) - .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) - .ifFalse(hash.set(computeHashNonNull))); - if (field.index() == 0) { - // Arrays.fill(hashes, 0, length, hash) - rleHandling.append(invokeStatic(Arrays.class, "fill", void.class, hashes, constantInt(0), length, hash)); - } - else { - // CombineHashFunction.combineAllHashesWithConstant(hashes, 0, length, hash) - rleHandling.append(invokeStatic(CombineHashFunction.class, "combineAllHashesWithConstant", void.class, hashes, constantInt(0), length, hash)); - } - - BytecodeExpression setHashExpression; - if (field.index() == 0) { - // hashes[index] = hash; - setHashExpression = hashes.setElement(index, hash); - } - else { - // hashes[index] = CombineHashFunction.getHash(hashes[index], hash); - setHashExpression = hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash)); - } - - BytecodeBlock computeHashLoop = new BytecodeBlock() - .append(mayHaveNull.set(block.invoke("mayHaveNull", boolean.class))) - .append(new ForLoop("for (int index = 0; index < length; index++)") - .initialize(index.set(constantInt(0))) - .condition(lessThan(index, length)) - .update(index.increment()) - .body(new BytecodeBlock() - .append(new IfStatement("if (mayHaveNull && block.isNull(position))") - .condition(and(mayHaveNull, block.invoke("isNull", boolean.class, position))) - .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) - .ifFalse(hash.set(computeHashNonNull))) - .append(setHashExpression) - .append(position.increment()))); - - body.append(new IfStatement("if (block instanceof RunLengthEncodedBlock)") - .condition(block.instanceOf(RunLengthEncodedBlock.class)) - .ifTrue(rleHandling) - .ifFalse(computeHashLoop)) - .ret(); - - return methodDefinition; - } - private static void generateHashFlat(ClassDefinition definition, List chunkClasses, boolean singleChunkClass) { Parameter fixedChunk = arg("fixedChunk", type(byte[].class)); @@ -1013,7 +839,7 @@ public static int checkVariableWidthOffsetArgument(int variableWidthOffset) return variableWidthOffset; } - private record KeyField( + public record KeyField( int index, Type type, int fieldIsNullOffset, @@ -1022,15 +848,6 @@ private record KeyField( MethodHandle writeFlatMethod, MethodHandle identicalFlatBlockMethod, MethodHandle hashFlatMethod, - MethodHandle hashBlockMethod) {} - - private record ChunkClass( - ClassDefinition definition, - MethodDefinition getTotalVariableWidth, - MethodDefinition readFlatChunk, - MethodDefinition writeFlatChunk, - MethodDefinition identicalMethodChunk, - MethodDefinition hashBlockChunk, - MethodDefinition hashFlatChunk, - MethodDefinition hashBlocksBatchedChunk) {} + MethodHandle hashBlockMethod, + Integer hashChannelIndex) {} } diff --git a/core/trino-main/src/main/java/io/trino/operator/HashGenerator.java b/core/trino-main/src/main/java/io/trino/operator/HashGenerator.java index 7bc300457dfc..564cccf288a0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HashGenerator.java +++ b/core/trino-main/src/main/java/io/trino/operator/HashGenerator.java @@ -28,4 +28,6 @@ default int getPartition(int partitionCount, int position, Page page) // integer x then normalize it to x / 2^32 * partitionCount to reduce the range of x from [0, 2^32) to [0, partitionCount) return (int) ((Integer.toUnsignedLong(Long.hashCode(rawHash)) * partitionCount) >>> 32); } + + void hashBlocksBatched(Page page, long[] hashes, int offset, int 
length); } diff --git a/core/trino-main/src/main/java/io/trino/operator/HashStrategyCompilerUtils.java b/core/trino-main/src/main/java/io/trino/operator/HashStrategyCompilerUtils.java new file mode 100644 index 000000000000..9d4b62536ba9 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/HashStrategyCompilerUtils.java @@ -0,0 +1,261 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.ForLoop; +import io.airlift.bytecode.control.IfStatement; +import io.airlift.bytecode.expression.BytecodeExpression; +import io.trino.operator.scalar.CombineHashFunction; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.type.Type; +import io.trino.sql.gen.CallSiteBinder; + +import java.lang.invoke.MethodHandle; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.STATIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.add; +import static io.airlift.bytecode.expression.BytecodeExpressions.and; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; +import static io.airlift.bytecode.expression.BytecodeExpressions.equal; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan; +import static io.trino.operator.HashGenerator.INITIAL_HASH_VALUE; +import static io.trino.spi.type.TypeUtils.NULL_HASH_CODE; +import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD; + +public final class HashStrategyCompilerUtils +{ + private HashStrategyCompilerUtils() {} + + public static void generateHashBlock(ClassDefinition definition, List chunkClasses) + { + Parameter blocks = arg("blocks", type(Block[].class)); + Parameter position = arg("position", type(int.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "hash", + type(long.class), + blocks, + position); + BytecodeBlock body = methodDefinition.getBody(); + + Scope scope = methodDefinition.getScope(); + Variable result = scope.declareVariable("result", body, constantLong(INITIAL_HASH_VALUE)); + for (ChunkClass chunkClass : chunkClasses) { + 
body.append(result.set(invokeStatic(chunkClass.hashBlockChunk(), blocks, position, result))); + } + body.append(result.ret()); + } + + public static MethodDefinition generateHashBlockChunk(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter blocks = arg("blocks", type(Block[].class)); + Parameter position = arg("position", type(int.class)); + Parameter seed = arg("seed", type(long.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC, STATIC), + "hashBlocks", + type(long.class), + blocks, + position, + seed); + BytecodeBlock body = methodDefinition.getBody(); + + Scope scope = methodDefinition.getScope(); + Variable result = scope.declareVariable("result", body, seed); + Variable hash = scope.declareVariable(long.class, "hash"); + Variable block = scope.declareVariable(Block.class, "block"); + + for (HashGeneratorKeyField keyField : keyFields) { + body.append(block.set(blocks.getElement(keyField.index()))); + body.append(new IfStatement() + .condition(block.invoke("isNull", boolean.class, position)) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(keyField.hashBlockMethod()).getBindingId()), + "hash", + long.class, + block, + position)))); + body.append(result.set(invokeStatic(CombineHashFunction.class, "getHash", long.class, result, hash))); + } + body.append(result.ret()); + return methodDefinition; + } + + public static MethodDefinition generateHashBlocksBatchedChunk(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter blocks = arg("blocks", type(Block[].class)); + Parameter hashes = arg("hashes", type(long[].class)); + Parameter offset = arg("offset", type(int.class)); + Parameter length = arg("length", type(int.class)); + + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC, STATIC), + "hashBlocksBatched", + type(void.class), + blocks, + hashes, + offset, + length); + + BytecodeBlock body = methodDefinition.getBody(); + body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop()); + + BytecodeBlock nonEmptyLength = new BytecodeBlock(); + + Map typeMethods = new HashMap<>(); + for (HashGeneratorKeyField keyField : keyFields) { + MethodDefinition method; + // The first hash method implementation does not combine hashes, so it can't be reused + if (keyField.index() == 0) { + method = generateHashBlockVectorized(definition, keyField, callSiteBinder); + } + else { + // Columns of the same type can reuse the same static method implementation + method = typeMethods.get(keyField.type()); + if (method == null) { + method = generateHashBlockVectorized(definition, keyField, callSiteBinder); + typeMethods.put(keyField.type(), method); + } + } + nonEmptyLength.append(invokeStatic(method, blocks.getElement(keyField.hashChannelIndex()), hashes, offset, length)); + } + + body.append(new IfStatement("if (length != 0)") + .condition(equal(length, constantInt(0))) + .ifFalse(nonEmptyLength)) + .ret(); + + return methodDefinition; + } + + private static MethodDefinition generateHashBlockVectorized(ClassDefinition definition, HashGeneratorKeyField field, CallSiteBinder callSiteBinder) + { + Parameter block = arg("block", type(Block.class)); + Parameter hashes = arg("hashes", type(long[].class)); + Parameter offset = arg("offset", type(int.class)); + Parameter length = arg("length", type(int.class)); + + MethodDefinition 
methodDefinition = definition.declareMethod( + a(PUBLIC, STATIC), + "hashBlockVectorized_" + field.index(), + type(void.class), + block, + hashes, + offset, + length); + + Scope scope = methodDefinition.getScope(); + BytecodeBlock body = methodDefinition.getBody(); + + Variable index = scope.declareVariable(int.class, "index"); + Variable position = scope.declareVariable(int.class, "position"); + Variable mayHaveNull = scope.declareVariable(boolean.class, "mayHaveNull"); + Variable hash = scope.declareVariable(long.class, "hash"); + + body.append(position.set(invokeStatic(Objects.class, "checkFromToIndex", int.class, offset, add(offset, length), block.invoke("getPositionCount", int.class)))); + body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop()); + + BytecodeExpression computeHashNonNull = invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(field.hashBlockMethod()).getBindingId()), + "hash", + long.class, + block, + position); + + BytecodeBlock rleHandling = new BytecodeBlock() + .append(new IfStatement("hash = block.isNull(position) ? NULL_HASH_CODE : hash(block, position)") + .condition(block.invoke("isNull", boolean.class, position)) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(computeHashNonNull))); + if (field.index() == 0) { + // Arrays.fill(hashes, 0, length, hash) + rleHandling.append(invokeStatic(Arrays.class, "fill", void.class, hashes, constantInt(0), length, hash)); + } + else { + // CombineHashFunction.combineAllHashesWithConstant(hashes, 0, length, hash) + rleHandling.append(invokeStatic(CombineHashFunction.class, "combineAllHashesWithConstant", void.class, hashes, constantInt(0), length, hash)); + } + + BytecodeExpression setHashExpression; + if (field.index() == 0) { + // hashes[index] = hash; + setHashExpression = hashes.setElement(index, hash); + } + else { + // hashes[index] = CombineHashFunction.getHash(hashes[index], hash); + setHashExpression = hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash)); + } + + BytecodeBlock computeHashLoop = new BytecodeBlock() + .append(mayHaveNull.set(block.invoke("mayHaveNull", boolean.class))) + .append(new ForLoop("for (int index = 0; index < length; index++)") + .initialize(index.set(constantInt(0))) + .condition(lessThan(index, length)) + .update(index.increment()) + .body(new BytecodeBlock() + .append(new IfStatement("if (mayHaveNull && block.isNull(position))") + .condition(and(mayHaveNull, block.invoke("isNull", boolean.class, position))) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(computeHashNonNull))) + .append(setHashExpression) + .append(position.increment()))); + + body.append(new IfStatement("if (block instanceof RunLengthEncodedBlock)") + .condition(block.instanceOf(RunLengthEncodedBlock.class)) + .ifTrue(rleHandling) + .ifFalse(computeHashLoop)) + .ret(); + + return methodDefinition; + } + + public record HashGeneratorKeyField( + int index, + Type type, + MethodHandle hashBlockMethod, + int hashChannelIndex) {} + + public record ChunkClass( + ClassDefinition definition, + MethodDefinition getTotalVariableWidth, + MethodDefinition readFlatChunk, + MethodDefinition writeFlatChunk, + MethodDefinition identicalMethodChunk, + MethodDefinition hashBlockChunk, + MethodDefinition hashFlatChunk, + MethodDefinition hashBlocksBatchedChunk) {} +} diff --git 
a/core/trino-main/src/main/java/io/trino/operator/InterpretedHashGenerator.java b/core/trino-main/src/main/java/io/trino/operator/InterpretedHashGenerator.java index 72614aaa0c47..1de4f97e5292 100644 --- a/core/trino-main/src/main/java/io/trino/operator/InterpretedHashGenerator.java +++ b/core/trino-main/src/main/java/io/trino/operator/InterpretedHashGenerator.java @@ -84,6 +84,12 @@ public long hashPosition(int position, Page page) return result; } + @Override + public void hashBlocksBatched(Page page, long[] hashes, int offset, int length) + { + throw new UnsupportedOperationException("This method is not supported"); + } + private long nullSafeHash(int operatorIndex, Block block, int position) { try { @@ -104,7 +110,7 @@ public String toString() .toString(); } - private static boolean isPositionalChannels(int[] hashChannels) + public static boolean isPositionalChannels(int[] hashChannels) { for (int i = 0; i < hashChannels.length; i++) { if (hashChannels[i] != i) { diff --git a/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java index ef1627843f86..bc17db0183f9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/PartitionFunction.java @@ -23,4 +23,6 @@ public interface PartitionFunction * @param page the arguments to bucketing function in order (no extra columns) */ int getPartition(Page page, int position); + + void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length); } diff --git a/core/trino-main/src/main/java/io/trino/operator/PartitionHashGeneratorCompiler.java b/core/trino-main/src/main/java/io/trino/operator/PartitionHashGeneratorCompiler.java new file mode 100644 index 000000000000..db7c9ec9e813 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/PartitionHashGeneratorCompiler.java @@ -0,0 +1,274 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.inject.Inject; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.DynamicClassLoader; +import io.airlift.bytecode.FieldDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.IfStatement; +import io.trino.cache.CacheStatsMBean; +import io.trino.operator.HashStrategyCompilerUtils.ChunkClass; +import io.trino.operator.HashStrategyCompilerUtils.HashGeneratorKeyField; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; +import io.trino.sql.gen.CallSiteBinder; +import jakarta.annotation.Nullable; +import org.assertj.core.util.VisibleForTesting; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PRIVATE; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; +import static io.airlift.bytecode.expression.BytecodeExpressions.equal; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static io.trino.cache.SafeCaches.buildNonEvictableCache; +import static io.trino.operator.HashGenerator.INITIAL_HASH_VALUE; +import static io.trino.operator.HashStrategyCompilerUtils.generateHashBlock; +import static io.trino.operator.HashStrategyCompilerUtils.generateHashBlockChunk; +import static io.trino.operator.HashStrategyCompilerUtils.generateHashBlocksBatchedChunk; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.sql.gen.BytecodeUtils.loadConstant; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; + +public final class PartitionHashGeneratorCompiler +{ + @VisibleForTesting + static final int COLUMNS_PER_CHUNK = 500; + + private final LoadingCache partitionHashGeneratorLoadingCache; + + @Inject + public PartitionHashGeneratorCompiler(TypeOperators typeOperators) + { + this.partitionHashGeneratorLoadingCache = buildNonEvictableCache( + CacheBuilder.newBuilder() + .recordStats() + .maximumSize(1000), + CacheLoader.from(key -> compilePartitionHashGenerator(key.getTypes(), key.getHashChannels(), typeOperators))); + } + + public HashGenerator getPartitionHashGenerator(List types, @Nullable int[] hashChannels) + { + return partitionHashGeneratorLoadingCache.getUnchecked(new CacheKey(types, hashChannels)); + } + + @Managed + @Nested + public CacheStatsMBean 
getPartitionHashGeneratorStats() + { + return new CacheStatsMBean(partitionHashGeneratorLoadingCache); + } + + @VisibleForTesting + public static HashGenerator compilePartitionHashGenerator(List types, @Nullable int[] hashChannels, TypeOperators typeOperators) + { + List keyFields = new ArrayList<>(); + for (int i = 0; i < types.size(); i++) { + Type type = types.get(i); + int hashChannelIndex = hashChannels == null ? i : hashChannels[i]; + keyFields.add(new HashGeneratorKeyField( + i, + type, + typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)), + hashChannelIndex)); + } + + CallSiteBinder callSiteBinder = new CallSiteBinder(); + List chunkClasses = new ArrayList<>(); + int chunkNumber = 0; + for (List chunk : Lists.partition(keyFields, COLUMNS_PER_CHUNK)) { + chunkClasses.add(compilePartitionHashGeneratorChunk(callSiteBinder, chunk, chunkNumber)); + chunkNumber++; + } + + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("HashGenerator"), + type(Object.class), + type(HashGenerator.class)); + + FieldDefinition typesField = definition.declareField(a(PRIVATE, FINAL), "types", type(List.class, Type.class)); + MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); + constructor + .getBody() + .append(constructor.getThis()) + .invokeConstructor(Object.class) + .append(constructor.getThis().setField(typesField, loadConstant(callSiteBinder, ImmutableList.copyOf(types), List.class))) + .ret(); + + generateHashBlock(definition, chunkClasses); + generateHashBlocksBatched(definition, chunkClasses); + generateHashPosition(definition, chunkClasses); + + try { + DynamicClassLoader classLoader = new DynamicClassLoader(PartitionHashGeneratorCompiler.class.getClassLoader(), callSiteBinder.getBindings()); + for (ChunkClass chunkClass : chunkClasses) { + defineClass(chunkClass.definition(), Object.class, classLoader); + } + return defineClass(definition, HashGenerator.class, classLoader) + .getConstructor() + .newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + + private static ChunkClass compilePartitionHashGeneratorChunk(CallSiteBinder callSiteBinder, List keyFields, int chunkNumber) + { + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("PartitionHashGeneratorChunk$" + chunkNumber), + type(Object.class), + type(HashGenerator.class)); + + definition.declareDefaultConstructor(a(PRIVATE)); + + MethodDefinition hashBlockChunk = generateHashBlockChunk(definition, keyFields, callSiteBinder); + MethodDefinition hashBlocksBatchedChunk = generateHashBlocksBatchedChunk(definition, keyFields, callSiteBinder); + + return new ChunkClass( + definition, + null, + null, + null, + null, + hashBlockChunk, + null, + hashBlocksBatchedChunk); + } + + private static void generateHashBlocksBatched(ClassDefinition definition, List chunkClasses) + { + Parameter page = arg("page", type(Page.class)); + Parameter hashes = arg("hashes", type(long[].class)); + Parameter offset = arg("offset", type(int.class)); + Parameter length = arg("length", type(int.class)); + + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "hashBlocksBatched", + type(void.class), + page, + hashes, + offset, + length); + + Scope scope = methodDefinition.getScope(); + BytecodeBlock body = methodDefinition.getBody(); + Variable blocks = scope.declareVariable(Block[].class, "blocks"); + body.append(invokeStatic(Objects.class, "checkFromIndexSize", 
int.class, constantInt(0), length, hashes.length()).pop()); + body.append(blocks.set(page.invoke("getBlocks", Block[].class))); + + BytecodeBlock nonEmptyLength = new BytecodeBlock(); + for (ChunkClass chunkClass : chunkClasses) { + nonEmptyLength.append(invokeStatic(chunkClass.hashBlocksBatchedChunk(), blocks, hashes, offset, length)); + } + + body.append(new IfStatement("if (length != 0)") + .condition(equal(length, constantInt(0))) + .ifFalse(nonEmptyLength)) + .ret(); + } + + private static void generateHashPosition(ClassDefinition definition, List chunkClasses) + { + Parameter position = arg("position", type(int.class)); + Parameter page = arg("page", type(Page.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "hashPosition", + type(long.class), + position, + page); + + Scope scope = methodDefinition.getScope(); + BytecodeBlock body = methodDefinition.getBody(); + Variable blocks = scope.declareVariable(Block[].class, "blocks"); + Variable result = scope.declareVariable("result", body, constantLong(INITIAL_HASH_VALUE)); + body.append(blocks.set(page.invoke("getBlocks", Block[].class))); + for (ChunkClass chunkClass : chunkClasses) { + body.append(result.set(invokeStatic(chunkClass.hashBlockChunk(), blocks, position, result))); + } + body.append(result.ret()); + } + + private record CacheKey( + List types, + int[] hashChannels) + { + private CacheKey(List types, @Nullable int[] hashChannels) + { + this.types = ImmutableList.copyOf(types); + this.hashChannels = hashChannels; + } + + private List getTypes() + { + return types; + } + + private int[] getHashChannels() + { + return hashChannels; + } + + @Override + public int hashCode() + { + return Objects.hash(types, Arrays.hashCode(hashChannels)); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + CacheKey other = (CacheKey) obj; + return Objects.equals(this.types, other.types) && + Arrays.equals(this.hashChannels, other.hashChannels); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/PrecomputedHashGenerator.java b/core/trino-main/src/main/java/io/trino/operator/PrecomputedHashGenerator.java index 032239596599..74284dec5831 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PrecomputedHashGenerator.java +++ b/core/trino-main/src/main/java/io/trino/operator/PrecomputedHashGenerator.java @@ -36,6 +36,15 @@ public long hashPosition(int position, Page page) return BIGINT.getLong(hashBlock, position); } + @Override + public void hashBlocksBatched(Page page, long[] hashes, int offset, int length) + { + Block hashBlock = page.getBlock(hashChannel); + for (int i = 0; i < length; i++) { + hashes[i] = BIGINT.getLong(hashBlock, i + offset); + } + } + @Override public String toString() { diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java index 3adbae96fa5c..56d5c1dceb10 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java @@ -23,10 +23,10 @@ import io.trino.Session; import io.trino.operator.HashGenerator; import io.trino.operator.PartitionFunction; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.output.SkewedPartitionRebalancer; import io.trino.spi.Page; import io.trino.spi.type.Type; -import 
io.trino.spi.type.TypeOperators; import io.trino.sql.planner.MergePartitioningHandle; import io.trino.sql.planner.PartitionFunctionProvider; import io.trino.sql.planner.PartitioningHandle; @@ -49,7 +49,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.SystemSessionProperties.getQueryMaxMemoryPerNode; import static io.trino.SystemSessionProperties.getSkewedPartitionMinDataProcessedRebalanceThreshold; -import static io.trino.operator.InterpretedHashGenerator.createChannelsHashGenerator; +import static io.trino.operator.InterpretedHashGenerator.isPositionalChannels; import static io.trino.operator.exchange.LocalExchangeSink.finishedLocalExchangeSink; import static io.trino.sql.planner.PartitioningHandle.isScaledWriterHashDistribution; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; @@ -94,11 +94,12 @@ public LocalExchange( List partitionChannels, List partitionChannelTypes, DataSize maxBufferedBytes, - TypeOperators typeOperators, DataSize writerScalingMinDataProcessed, - Supplier totalMemoryUsed) + Supplier totalMemoryUsed, + PartitionHashGeneratorCompiler partitionHashGeneratorCompiler) { int bufferCount = computeBufferCount(partitioning, defaultConcurrency, partitionChannels); + requireNonNull(partitionHashGeneratorCompiler, "partitionHashGeneratorCompiler is null"); if (partitioning.equals(SINGLE_DISTRIBUTION) || partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)) { LocalExchangeMemoryManager memoryManager = new LocalExchangeMemoryManager(maxBufferedBytes.toBytes()); @@ -153,12 +154,12 @@ else if (isScaledWriterHashDistribution(partitioning)) { PartitionFunction partitionFunction = createPartitionFunction( partitionFunctionProvider, session, - typeOperators, bucketCount, partitioning, partitionCount, partitionChannels, - partitionChannelTypes); + partitionChannelTypes, + partitionHashGeneratorCompiler); return new ScaleWriterPartitioningExchanger( asPageConsumers(sources), memoryManager, @@ -181,12 +182,12 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalog PartitionFunction partitionFunction = createPartitionFunction( partitionFunctionProvider, session, - typeOperators, bucketCount, partitioning, bufferCount, partitionChannels, - partitionChannelTypes); + partitionChannelTypes, + partitionHashGeneratorCompiler); return new PartitioningExchanger( asPageConsumers(sources), memoryManager, @@ -236,17 +237,24 @@ private static Function createPartitionPagePreparer(PartitioningHand private static PartitionFunction createPartitionFunction( PartitionFunctionProvider partitionFunctionProvider, Session session, - TypeOperators typeOperators, OptionalInt optionalBucketCount, PartitioningHandle partitioningHandle, int partitionCount, List partitionChannels, - List partitionChannelTypes) + List partitionChannelTypes, + PartitionHashGeneratorCompiler partitionHashGeneratorCompiler) { checkArgument(Integer.bitCount(partitionCount) == 1, "partitionCount must be a power of 2"); if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle) { - HashGenerator hashGenerator = createChannelsHashGenerator(partitionChannelTypes, Ints.toArray(partitionChannels), typeOperators); + HashGenerator hashGenerator; + int[] partitionChannelsArray = partitionChannels == null ? 
null : Ints.toArray(partitionChannels); + if (partitionChannelsArray != null && !isPositionalChannels(partitionChannelsArray)) { + hashGenerator = partitionHashGeneratorCompiler.getPartitionHashGenerator(partitionChannelTypes, partitionChannelsArray); + } + else { + hashGenerator = partitionHashGeneratorCompiler.getPartitionHashGenerator(partitionChannelTypes, null); + } return new LocalPartitionGenerator(hashGenerator, partitionCount); } @@ -263,7 +271,7 @@ private static PartitionFunction createPartitionFunction( bucketToPartition[bucket] = hashedBucket & (partitionCount - 1); } - return partitionFunctionProvider.getPartitionFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition); + return partitionFunctionProvider.getPartitionFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition, partitionHashGeneratorCompiler); } private void checkAllSourcesFinished() diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java index 2c065f182238..d9851f19bd39 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalPartitionGenerator.java @@ -48,6 +48,16 @@ public int getPartition(Page page, int position) return processRawHash(rawHash) & hashMask; } + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + hashGenerator.hashBlocksBatched(page, rawHashes, offset, length); + for (int i = 0; i < length; i++) { + long rawHash = rawHashes[i]; + partitions[i] = processRawHash(rawHash) & hashMask; + } + } + public long getRawHash(Page page, int position) { return hashGenerator.hashPosition(position, page); diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java index 1407b5f9908f..8ac156ae8346 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java @@ -24,17 +24,22 @@ import java.util.function.Consumer; import java.util.function.Function; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @NotThreadSafe class PartitioningExchanger implements LocalExchanger { + private static final int BATCH_SIZE = 1024; private final List> buffers; private final LocalExchangeMemoryManager memoryManager; private final Function partitionedPagePreparer; private final PartitionFunction partitionFunction; private final IntArrayList[] partitionAssignments; + private long[] hashesBufferArray; public PartitioningExchanger( List> partitions, @@ -53,15 +58,34 @@ public PartitioningExchanger( } } + public long[] getHashesBufferArray() + { + if (hashesBufferArray == null) { + hashesBufferArray = new long[BATCH_SIZE]; + } + return hashesBufferArray; + } + @Override public void accept(Page page) { Page partitionPage = partitionedPagePreparer.apply(page); - // assign each row to a partition. 
The assignments lists are all expected to cleared by the previous iterations - for (int position = 0; position < partitionPage.getPositionCount(); position++) { - int partition = partitionFunction.getPartition(partitionPage, position); - partitionAssignments[partition].add(position); + int positionCount = partitionPage.getPositionCount(); + int lastPosition = 0; + checkState(lastPosition <= positionCount, "position count out of bound"); + int remainingPositions = positionCount - lastPosition; + long[] hashes = getHashesBufferArray(); + int[] partitions = new int[BATCH_SIZE]; + while (remainingPositions != 0) { + int batchSize = min(remainingPositions, hashes.length); + partitionFunction.getPartitions(partitionPage, partitions, hashes, lastPosition, batchSize); + for (int i = 0; i < batchSize; i++) { + partitionAssignments[partitions[i]].add(lastPosition + i); + } + lastPosition += batchSize; + remainingPositions -= batchSize; } + verify(lastPosition == positionCount); // build a page for each partition for (int partition = 0; partition < partitionAssignments.length; partition++) { @@ -96,4 +120,10 @@ public ListenableFuture waitForWriting() { return memoryManager.getNotFullFuture(); } + + @Override + public void finish() + { + hashesBufferArray = null; + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/ScaleWriterPartitioningExchanger.java b/core/trino-main/src/main/java/io/trino/operator/exchange/ScaleWriterPartitioningExchanger.java index cdee92ecaea4..cc29f5f42a1c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/ScaleWriterPartitioningExchanger.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/ScaleWriterPartitioningExchanger.java @@ -25,6 +25,9 @@ import java.util.function.Function; import java.util.function.Supplier; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static java.lang.Math.min; import static java.util.Arrays.fill; import static java.util.Objects.requireNonNull; @@ -32,6 +35,7 @@ public class ScaleWriterPartitioningExchanger implements LocalExchanger { private static final double SCALE_WRITER_MEMORY_PERCENTAGE = 0.7; + private static final int BATCH_SIZE = 1024; private final List> buffers; private final LocalExchangeMemoryManager memoryManager; private final long maxBufferedBytes; @@ -45,6 +49,7 @@ public class ScaleWriterPartitioningExchanger private final int[] partitionWriterIndexes; private final Supplier totalMemoryUsed; private final long maxMemoryPerNode; + private long[] hashesBufferArray; public ScaleWriterPartitioningExchanger( List> buffers, @@ -81,6 +86,14 @@ public ScaleWriterPartitioningExchanger( fill(partitionWriterIds, -1); } + public long[] getHashesBufferArray() + { + if (hashesBufferArray == null) { + hashesBufferArray = new long[BATCH_SIZE]; + } + return hashesBufferArray; + } + @Override public void accept(Page page) { @@ -99,24 +112,30 @@ public void accept(Page page) } Page partitionPage = partitionedPagePreparer.apply(page); - - // Assign each row to a writer by looking at partitions scaling state using partitionRebalancer - for (int position = 0; position < partitionPage.getPositionCount(); position++) { - // Get row partition id (or bucket id) which limits to the partitionCount. If there are more physical partitions than - // this artificial partition limit, then it is possible that multiple physical partitions will get assigned the same - // bucket id. 
Thus, multiple partitions will be scaled together since we track partition physicalWrittenBytes - // using the artificial limit (partitionCount). - int partitionId = partitionFunction.getPartition(partitionPage, position); - partitionRowCounts[partitionId] += 1; - - // Get writer id for this partition by looking at the scaling state - int writerId = partitionWriterIds[partitionId]; - if (writerId == -1) { - writerId = getNextWriterId(partitionId); - partitionWriterIds[partitionId] = writerId; + int positionCount = partitionPage.getPositionCount(); + int lastPosition = 0; + checkState(lastPosition <= positionCount, "position count out of bound"); + int remainingPositions = positionCount - lastPosition; + long[] hashes = getHashesBufferArray(); + int[] partitions = new int[BATCH_SIZE]; + while (remainingPositions != 0) { + int batchSize = min(remainingPositions, hashes.length); + partitionFunction.getPartitions(partitionPage, partitions, hashes, lastPosition, batchSize); + for (int i = 0; i < batchSize; i++) { + int partitionId = partitions[i]; + partitionRowCounts[partitionId] += 1; + // Get writer id for this partition by looking at the scaling state + int writerId = partitionWriterIds[partitionId]; + if (writerId == -1) { + writerId = getNextWriterId(partitionId); + partitionWriterIds[partitionId] = writerId; + } + writerAssignments[writerId].add(lastPosition + i); } - writerAssignments[writerId].add(position); + lastPosition += batchSize; + remainingPositions -= batchSize; } + verify(lastPosition == positionCount); // build a page for each writer for (int bucket = 0; bucket < writerAssignments.length; bucket++) { @@ -160,6 +179,12 @@ public ListenableFuture waitForWriting() return memoryManager.getNotFullFuture(); } + @Override + public void finish() + { + hashesBufferArray = null; + } + private int getNextWriterId(int partitionId) { return partitionRebalancer.getTaskId(partitionId, partitionWriterIndexes[partitionId]++); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java index bf896571e460..b4ffa87859ec 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java @@ -25,7 +25,6 @@ import io.trino.operator.PartitionFunction; import io.trino.spi.Page; import io.trino.spi.block.Block; -import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.metrics.Metrics; import io.trino.spi.predicate.NullableValue; @@ -41,7 +40,6 @@ import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; -import java.util.function.IntUnaryOperator; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -57,6 +55,7 @@ public class PagePartitioner implements Closeable { private static final int COLUMNAR_STRATEGY_COEFFICIENT = 2; + private static final int BATCH_SIZE = 1024; private final OutputBuffer outputBuffer; private final PartitionFunction partitionFunction; private final int[] partitionChannels; @@ -66,13 +65,13 @@ public class PagePartitioner private final PageSerializer serializer; private final PositionsAppenderPageBuilder[] positionsAppenders; private final boolean replicatesAnyRow; - private final boolean partitionProcessRleAndDictionaryBlocks; private final int nullChannel; // when >= 0, send the position to every partition if 
this channel is null private boolean hasAnyRowBeenReplicated; // outputSizeInBytes that has already been reported to the operator stats during release and should be subtracted // from future stats reporting to avoid double counting private long outputSizeReportedBeforeRelease; + private long[] hashesBufferArray; public PagePartitioner( PartitionFunction partitionFunction, @@ -86,8 +85,7 @@ public PagePartitioner( DataSize maxMemory, PositionsAppenderFactory positionsAppenderFactory, Optional exchangeEncryptionKey, - AggregatedMemoryContext aggregatedMemoryContext, - boolean partitionProcessRleAndDictionaryBlocks) + AggregatedMemoryContext aggregatedMemoryContext) { this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.partitionChannels = Ints.toArray(requireNonNull(partitionChannels, "partitionChannels is null")); @@ -105,7 +103,6 @@ public PagePartitioner( this.nullChannel = nullChannel.orElse(-1); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.serializer = serdeFactory.createSerializer(exchangeEncryptionKey.map(Ciphers::deserializeAesEncryptionKey)); - this.partitionProcessRleAndDictionaryBlocks = partitionProcessRleAndDictionaryBlocks; // Ensure partition channels align with constant arguments provided for (int i = 0; i < this.partitionChannels.length; i++) { @@ -132,6 +129,14 @@ public PartitionFunction getPartitionFunction() return partitionFunction; } + public long[] getHashesBufferArray() + { + if (hashesBufferArray == null) { + hashesBufferArray = new long[BATCH_SIZE]; + } + return hashesBufferArray; + } + public void partitionPage(Page page, OperatorContext operatorContext) { if (page.getPositionCount() == 0) { @@ -221,6 +226,7 @@ public Metrics prepareForRelease(OperatorContext operatorContext) } } updateMemoryUsage(); + hashesBufferArray = null; // Adjust flushed and buffered values against the previously eagerly reported sizes outputSizeInBytes = adjustFlushedOutputSizeWithEagerlyReportedBytes(outputSizeInBytes); bufferedSizeInBytes = adjustEagerlyReportedBytesWithBufferedBytesOnRelease(bufferedSizeInBytes); @@ -249,27 +255,38 @@ public void partitionPageByRow(Page page) } Page partitionFunctionArgs = getPartitionFunctionArguments(page); - // Skip null block checks if mayHaveNull reports that no positions will be null - if (nullChannel >= 0 && page.getBlock(nullChannel).mayHaveNull()) { - Block nullsBlock = page.getBlock(nullChannel); - for (; position < page.getPositionCount(); position++) { - if (nullsBlock.isNull(position)) { - for (PositionsAppenderPageBuilder pageBuilder : positionsAppenders) { - pageBuilder.appendToOutputPartition(page, position); + int positionCount = partitionFunctionArgs.getPositionCount(); + int lastPosition = position; + checkState(lastPosition <= positionCount, "position count out of bound"); + int remainingPositions = positionCount - lastPosition; + long[] hashes = getHashesBufferArray(); + int[] partitions = new int[BATCH_SIZE]; + while (remainingPositions != 0) { + int batchSize = min(remainingPositions, hashes.length); + partitionFunction.getPartitions(partitionFunctionArgs, partitions, hashes, lastPosition, batchSize); + // Skip null block checks if mayHaveNull reports that no positions will be null + if (nullChannel >= 0 && page.getBlock(nullChannel).mayHaveNull()) { + Block nullsBlock = page.getBlock(nullChannel); + for (int i = 0; i < batchSize; i++) { + if (nullsBlock.isNull(lastPosition + i)) { + for (PositionsAppenderPageBuilder pageBuilder : positionsAppenders) { + 
pageBuilder.appendToOutputPartition(page, lastPosition + i); + } + } + else { + positionsAppenders[partitions[i]].appendToOutputPartition(page, lastPosition + i); } - } - else { - int partition = partitionFunction.getPartition(partitionFunctionArgs, position); - positionsAppenders[partition].appendToOutputPartition(page, position); } } - } - else { - for (; position < page.getPositionCount(); position++) { - int partition = partitionFunction.getPartition(partitionFunctionArgs, position); - positionsAppenders[partition].appendToOutputPartition(page, position); + else { + for (int i = 0; i < batchSize; i++) { + positionsAppenders[partitions[i]].appendToOutputPartition(page, lastPosition + i); + } } + lastPosition += batchSize; + remainingPositions -= batchSize; } + verify(lastPosition == positionCount); } public void partitionPageByColumn(Page page) @@ -304,17 +321,7 @@ private IntArrayList[] partitionPositions(Page page) Page partitionFunctionArgs = getPartitionFunctionArguments(page); - if (partitionProcessRleAndDictionaryBlocks && partitionFunctionArgs.getChannelCount() > 0 && onlyRleBlocks(partitionFunctionArgs)) { - // we need at least one Rle block since with no blocks partition function - // can return a different value per invocation (e.g. RoundRobinBucketFunction) - partitionBySingleRleValue(page, position, partitionFunctionArgs, partitionPositions); - } - else if (partitionProcessRleAndDictionaryBlocks && partitionFunctionArgs.getChannelCount() == 1 && isDictionaryProcessingFaster(partitionFunctionArgs.getBlock(0))) { - partitionBySingleDictionary(page, position, partitionFunctionArgs, partitionPositions); - } - else { - partitionGeneric(page, position, aPosition -> partitionFunction.getPartition(partitionFunctionArgs, aPosition), partitionPositions); - } + partitionGeneric(page, position, partitionPositions, partitionFunctionArgs); return partitionPositions; } @@ -340,129 +347,46 @@ private static int initialPartitionSize(int averagePositionsPerPartition) return (int) (averagePositionsPerPartition * 1.1) + 32; } - private static boolean onlyRleBlocks(Page page) + private void partitionGeneric(Page page, int position, IntArrayList[] partitionPositions, Page partitionFunctionArgs) { - for (int i = 0; i < page.getChannelCount(); i++) { - if (!(page.getBlock(i) instanceof RunLengthEncodedBlock)) { - return false; + boolean mayHaveNullBlock = nullChannel != -1 && page.getBlock(nullChannel).mayHaveNull(); + Block nullsBlock = null; + if (mayHaveNullBlock) { + nullsBlock = page.getBlock(nullChannel); + int[] nullPositions = new int[page.getPositionCount()]; + int nullCount = 0; + for (int i = position; i < page.getPositionCount(); i++) { + nullPositions[nullCount] = i; + int isNull = nullsBlock.isNull(i) ? 
1 : 0; + nullCount += isNull; } - } - return true; - } - - private void partitionBySingleRleValue(Page page, int position, Page partitionFunctionArgs, IntArrayList[] partitionPositions) - { - // copy all positions because all hash function args are the same for every position - if (nullChannel != -1 && page.getBlock(nullChannel).isNull(0)) { - verify(page.getBlock(nullChannel) instanceof RunLengthEncodedBlock, "null channel is not RunLengthEncodedBlock, found instead: %s", page.getBlock(nullChannel)); - // all positions are null - int[] allPositions = integersInRange(position, page.getPositionCount()); - for (IntList partitionPosition : partitionPositions) { - partitionPosition.addElements(position, allPositions); + for (IntArrayList positions : partitionPositions) { + positions.addElements(position, nullPositions, 0, nullCount); } } - else { - // extract rle page to prevent JIT profile pollution - Page rlePage = extractRlePage(partitionFunctionArgs); - - int partition = partitionFunction.getPartition(rlePage, 0); - IntArrayList positions = partitionPositions[partition]; - for (int i = position; i < page.getPositionCount(); i++) { - positions.add(i); + int positionCount = partitionFunctionArgs.getPositionCount(); + int lastPosition = position; + checkState(lastPosition <= positionCount, "position count out of bound"); + int remainingPositions = positionCount - lastPosition; + long[] hashes = getHashesBufferArray(); + int[] partitions = new int[BATCH_SIZE]; + while (remainingPositions != 0) { + int batchSize = min(remainingPositions, hashes.length); + partitionFunction.getPartitions(partitionFunctionArgs, partitions, hashes, lastPosition, batchSize); + for (int i = 0; i < batchSize; i++) { + if (mayHaveNullBlock) { + if (nullsBlock != null && !nullsBlock.isNull(lastPosition + i)) { + partitionPositions[partitions[i]].add(lastPosition + i); + } + } + else { + partitionPositions[partitions[i]].add(lastPosition + i); + } } + lastPosition += batchSize; + remainingPositions -= batchSize; } - } - - private static Page extractRlePage(Page page) - { - Block[] valueBlocks = new Block[page.getChannelCount()]; - for (int channel = 0; channel < valueBlocks.length; ++channel) { - valueBlocks[channel] = ((RunLengthEncodedBlock) page.getBlock(channel)).getValue(); - } - return new Page(valueBlocks); - } - - private static int[] integersInRange(int start, int endExclusive) - { - int[] array = new int[endExclusive - start]; - int current = start; - for (int i = 0; i < array.length; i++) { - array[i] = current++; - } - return array; - } - - private static boolean isDictionaryProcessingFaster(Block block) - { - if (!(block instanceof DictionaryBlock dictionaryBlock)) { - return false; - } - // if dictionary block positionCount is greater than number of elements in the dictionary - // it will be faster to compute hash for the dictionary values only once and re-use it - // instead of recalculating it. 
- return dictionaryBlock.getPositionCount() > dictionaryBlock.getDictionary().getPositionCount(); - } - - private void partitionBySingleDictionary(Page page, int position, Page partitionFunctionArgs, IntArrayList[] partitionPositions) - { - DictionaryBlock dictionaryBlock = (DictionaryBlock) partitionFunctionArgs.getBlock(0); - Block dictionary = dictionaryBlock.getDictionary(); - int[] dictionaryPartitions = new int[dictionary.getPositionCount()]; - Page dictionaryPage = new Page(dictionary); - for (int i = 0; i < dictionary.getPositionCount(); i++) { - dictionaryPartitions[i] = partitionFunction.getPartition(dictionaryPage, i); - } - - partitionGeneric(page, position, aPosition -> dictionaryPartitions[dictionaryBlock.getId(aPosition)], partitionPositions); - } - - private void partitionGeneric(Page page, int position, IntUnaryOperator partitionFunction, IntArrayList[] partitionPositions) - { - // Skip null block checks if mayHaveNull reports that no positions will be null - if (nullChannel != -1 && page.getBlock(nullChannel).mayHaveNull()) { - partitionNullablePositions(page, position, partitionPositions, partitionFunction); - } - else { - partitionNotNullPositions(page, position, partitionPositions, partitionFunction); - } - } - - private void partitionNullablePositions(Page page, int position, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) - { - Block nullsBlock = page.getBlock(nullChannel); - int[] nullPositions = new int[page.getPositionCount()]; - int[] nonNullPositions = new int[page.getPositionCount()]; - int nullCount = 0; - int nonNullCount = 0; - for (int i = position; i < page.getPositionCount(); i++) { - nullPositions[nullCount] = i; - nonNullPositions[nonNullCount] = i; - int isNull = nullsBlock.isNull(i) ? 1 : 0; - nullCount += isNull; - nonNullCount += isNull ^ 1; - } - for (IntArrayList positions : partitionPositions) { - positions.addElements(position, nullPositions, 0, nullCount); - } - for (int i = 0; i < nonNullCount; i++) { - int nonNullPosition = nonNullPositions[i]; - int partition = partitionFunction.applyAsInt(nonNullPosition); - partitionPositions[partition].add(nonNullPosition); - } - } - - private static void partitionNotNullPositions(Page page, int startingPosition, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) - { - int positionCount = page.getPositionCount(); - int[] partitionPerPosition = new int[positionCount]; - for (int position = startingPosition; position < positionCount; position++) { - int partition = partitionFunction.applyAsInt(position); - partitionPerPosition[position] = partition; - } - - for (int position = startingPosition; position < positionCount; position++) { - partitionPositions[partitionPerPosition[position]].add(position); - } + verify(lastPosition == positionCount); } private Page getPartitionFunctionArguments(Page page) diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java b/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java index 43cc7e9404be..a4e945f52418 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java @@ -182,13 +182,9 @@ public PartitionedOutputOperatorFactory( this.pagePartitionerPool = new PagePartitionerPool( pagePartitionerPoolSize, () -> { - boolean partitionProcessRleAndDictionaryBlocks = true; PartitionFunction function = partitionFunction; if 
(skewedPartitionRebalancer.isPresent()) { function = new SkewedPartitionFunction(partitionFunction, skewedPartitionRebalancer.get()); - // Partition flattened Rle and Dictionary blocks since if they are scaled then we want to - // round-robin the entire block to increase the writing parallelism across tasks/workers. - partitionProcessRleAndDictionaryBlocks = false; } return new PagePartitioner( function, @@ -202,8 +198,7 @@ public PartitionedOutputOperatorFactory( maxMemory, positionsAppenderFactory, exchangeEncryptionKey, - memoryContext, - partitionProcessRleAndDictionaryBlocks); + memoryContext); }); } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java index 638fc54f3b16..40b7fca9e508 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionFunction.java @@ -48,6 +48,20 @@ public int getPartition(Page page, int position) return skewedPartitionRebalancer.getTaskId(partition, partitionRowCount[partition]++); } + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + partitionFunction.getPartitions(page, partitions, rawHashes, offset, length); + for (int i = 0; i < length; i++) { + partitions[i] = skewedPartitionRebalancer.getTaskId(partitions[i], partitionRowCount[partitions[i]]++); + } + } + + public PartitionFunction getPartitionFunction() + { + return partitionFunction; + } + public void flushPartitionRowCountToRebalancer() { for (int partition = 0; partition < partitionFunction.partitionCount(); partition++) { diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java index 65b0329df811..d251a7d3780c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java @@ -22,6 +22,7 @@ import io.trino.SystemSessionProperties; import io.trino.execution.resourcegroups.IndexedPriorityQueue; import io.trino.operator.PartitionFunction; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.spi.type.Type; import io.trino.sql.planner.NodePartitionMap.BucketToPartition; import io.trino.sql.planner.NodePartitioningManager; @@ -141,7 +142,8 @@ public static PartitionFunction createPartitionFunction( PartitionFunctionProvider partitionFunctionProvider, PartitioningHandle partitioningHandle, int bucketCount, - List partitionChannelTypes) + List partitionChannelTypes, + PartitionHashGeneratorCompiler partitionHashGeneratorCompiler) { // In case of SystemPartitioningHandle we can use arbitrary bucket count so that skewness mitigation // is more granular. @@ -161,7 +163,7 @@ public static PartitionFunction createPartitionFunction( // compared to only a single hive bucket reaching the min limit.
int[] bucketToPartition = IntStream.range(0, bucketCount).toArray(); - return partitionFunctionProvider.getPartitionFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition); + return partitionFunctionProvider.getPartitionFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition, partitionHashGeneratorCompiler); } public static int getMaxWritersBasedOnMemory(Session session) diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index 49ba585e942b..5706ed016d1a 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -88,6 +88,7 @@ import io.trino.operator.GroupByHashPageIndexerFactory; import io.trino.operator.PagesIndex; import io.trino.operator.PagesIndexPageSorter; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.RetryPolicy; import io.trino.operator.index.IndexJoinLookupStats; import io.trino.operator.index.IndexManager; @@ -316,6 +317,8 @@ protected void setup(Binder binder) newExporter(binder).export(FlatHashStrategyCompiler.class).withGeneratedName(); binder.bind(OrderingCompiler.class).in(Scopes.SINGLETON); newExporter(binder).export(OrderingCompiler.class).withGeneratedName(); + binder.bind(PartitionHashGeneratorCompiler.class).in(Scopes.SINGLETON); + newExporter(binder).export(PartitionHashGeneratorCompiler.class).withGeneratedName(); binder.bind(PagesIndex.Factory.class).to(PagesIndex.DefaultFactory.class); binder.bind(PagesInputStreamFactory.class); jaxrsBinder(binder).bind(IoExceptionSuppressingWriterInterceptor.class); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/HashBucketFunction.java b/core/trino-main/src/main/java/io/trino/sql/planner/HashBucketFunction.java deleted file mode 100644 index 2368873ad64d..000000000000 --- a/core/trino-main/src/main/java/io/trino/sql/planner/HashBucketFunction.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.sql.planner; - -import io.trino.operator.HashGenerator; -import io.trino.spi.Page; -import io.trino.spi.connector.BucketFunction; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; - -public class HashBucketFunction - implements BucketFunction -{ - private final HashGenerator generator; - private final int bucketCount; - - public HashBucketFunction(HashGenerator generator, int bucketCount) - { - checkArgument(bucketCount > 0, "partitionCount must be at least 1"); - this.generator = generator; - this.bucketCount = bucketCount; - } - - @Override - public int getBucket(Page page, int position) - { - return generator.getPartition(bucketCount, position, page); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("generator", generator) - .add("bucketCount", bucketCount) - .toString(); - } -} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/HashPartitionFunction.java b/core/trino-main/src/main/java/io/trino/sql/planner/HashPartitionFunction.java new file mode 100644 index 000000000000..284306134a73 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/HashPartitionFunction.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import io.trino.operator.HashGenerator; +import io.trino.operator.PartitionFunction; +import io.trino.spi.Page; + +import java.util.stream.IntStream; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +public class HashPartitionFunction + implements PartitionFunction +{ + private final HashGenerator generator; + private final int bucketCount; + private final int[] bucketToPartition; + private final int partitionCount; + + public HashPartitionFunction(HashGenerator generator, int bucketCount, int[] bucketToPartition) + { + checkArgument(bucketCount > 0, "partitionCount must be at least 1"); + this.generator = generator; + this.bucketCount = bucketCount; + this.bucketToPartition = bucketToPartition.clone(); + partitionCount = IntStream.of(bucketToPartition).max().getAsInt() + 1; + } + + @Override + public int partitionCount() + { + return partitionCount; + } + + @Override + public int getPartition(Page functionArguments, int position) + { + int bucket = generator.getPartition(bucketCount, position, functionArguments); + return bucketToPartition[bucket]; + } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + generator.hashBlocksBatched(page, rawHashes, offset, length); + for (int i = 0; i < length; i++) { + long rawHash = rawHashes[i]; + // This function reduces the 64 bit rawHash to [0, partitionCount) uniformly. 
It first reduces the rawHash to 32 bit + // integer x then normalize it to x / 2^32 * partitionCount to reduce the range of x from [0, 2^32) to [0, partitionCount) + int bucket = (int) ((Integer.toUnsignedLong(Long.hashCode(rawHash)) * bucketToPartition.length) >>> 32); + partitions[i] = bucketToPartition[bucket]; + } + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("generator", generator) + .add("bucketCount", bucketCount) + .toString(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index bc5284ada326..71f1f96106a2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -78,6 +78,7 @@ import io.trino.operator.PagesIndex; import io.trino.operator.PagesSpatialIndexFactory; import io.trino.operator.PartitionFunction; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.RefreshMaterializedViewOperator.RefreshMaterializedViewOperatorFactory; import io.trino.operator.RowNumberOperator; import io.trino.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; @@ -458,6 +459,7 @@ public class LocalExecutionPlanner private final PositionsAppenderFactory positionsAppenderFactory; private final NodeVersion version; private final boolean specializeAggregationLoops; + private final PartitionHashGeneratorCompiler partitionHashGeneratorCompiler; private final NonEvictableCache accumulatorFactoryCache = buildNonEvictableCache(CacheBuilder.newBuilder() .maximumSize(1000) @@ -495,7 +497,8 @@ public LocalExecutionPlanner( TableExecuteContextManager tableExecuteContextManager, ExchangeManagerRegistry exchangeManagerRegistry, NodeVersion version, - CompilerConfig compilerConfig) + CompilerConfig compilerConfig, + PartitionHashGeneratorCompiler partitionHashGeneratorCompiler) { this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); this.metadata = plannerContext.getMetadata(); @@ -545,6 +548,7 @@ public LocalExecutionPlanner( this.positionsAppenderFactory = new PositionsAppenderFactory(blockTypeOperators); this.version = requireNonNull(version, "version is null"); this.specializeAggregationLoops = compilerConfig.isSpecializeAggregationLoops(); + this.partitionHashGeneratorCompiler = requireNonNull(partitionHashGeneratorCompiler, "partitionHashGeneratorCompiler is null"); } public LocalExecutionPlan plan( @@ -595,7 +599,7 @@ public LocalExecutionPlan plan( Optional skewedPartitionRebalancer = Optional.empty(); int taskCount = getTaskCount(partitioningScheme); if (outputSkewedBucketCount.isPresent()) { - partitionFunction = createPartitionFunction(taskContext.getSession(), partitionFunctionProvider, partitioningScheme.getPartitioning().getHandle(), outputSkewedBucketCount.getAsInt(), partitionChannelTypes); + partitionFunction = createPartitionFunction(taskContext.getSession(), partitionFunctionProvider, partitioningScheme.getPartitioning().getHandle(), outputSkewedBucketCount.getAsInt(), partitionChannelTypes, partitionHashGeneratorCompiler); int partitionedWriterCount = getPartitionedWriterCountBasedOnMemory(taskContext.getSession()); // Keep the task bucket count to 50% of total local writers int taskBucketCount = (int) ceil(0.5 * partitionedWriterCount); @@ -612,7 +616,8 @@ public LocalExecutionPlan plan( 
partitioningScheme.getPartitioning().getHandle(), partitionChannelTypes, partitioningScheme.getBucketToPartition() - .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before a partition function can be created"))); + .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before a partition function can be created")), + partitionHashGeneratorCompiler); } OptionalInt nullChannel = OptionalInt.empty(); Set partitioningColumns = partitioningScheme.getPartitioning().getColumns(); @@ -3705,9 +3710,9 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlan ImmutableList.of(), ImmutableList.of(), maxLocalExchangeBufferSize, - typeOperators, getWriterScalingMinDataProcessed(session), - () -> context.getTaskContext().getQueryMemoryReservation().toBytes()); + () -> context.getTaskContext().getQueryMemoryReservation().toBytes(), + partitionHashGeneratorCompiler); List expectedLayout = getOnlyElement(node.getInputs()); Function pagePreprocessor = enforceLoadedLayoutProcessor(expectedLayout, source.getLayout()); @@ -3780,9 +3785,9 @@ else if (context.getDriverInstanceCount().isPresent()) { partitionChannels, partitionChannelTypes, maxLocalExchangeBufferSize, - typeOperators, getWriterScalingMinDataProcessed(session), - () -> context.getTaskContext().getQueryMemoryReservation().toBytes()); + () -> context.getTaskContext().getQueryMemoryReservation().toBytes(), + partitionHashGeneratorCompiler); for (int i = 0; i < node.getSources().size(); i++) { DriverFactoryParameters driverFactoryParameters = driverFactoryParametersList.get(i); PhysicalOperation source = driverFactoryParameters.getSource(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java index 386078120729..e6ef40e42c33 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java @@ -243,5 +243,21 @@ public int getPartition(Page page, int position) default -> throw new VerifyException("Invalid merge operation number: " + operation); }; } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + Block operationBlock = page.getBlock(0); + for (int i = 0; i < length; i++) { + byte operation = TINYINT.getByte(operationBlock, offset + i); + partitions[i] = switch (operation) { + case INSERT_OPERATION_NUMBER, UPDATE_INSERT_OPERATION_NUMBER -> + insertFunction.getPartition(page.getColumns(insertColumns), offset + i); + case UPDATE_OPERATION_NUMBER, DELETE_OPERATION_NUMBER, UPDATE_DELETE_OPERATION_NUMBER -> + updateFunction.getPartition(page.getColumns(updateColumns), offset + i); + default -> throw new VerifyException("Invalid merge operation number: " + operation); + }; + } + } } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartitionFunctionProvider.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartitionFunctionProvider.java index e50ab5c1b935..2887eec96fb0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PartitionFunctionProvider.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartitionFunctionProvider.java @@ -19,6 +19,7 @@ import io.trino.connector.CatalogServiceProvider; import io.trino.operator.BucketPartitionFunction; import io.trino.operator.PartitionFunction; +import 
io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.spi.connector.BucketFunction; import io.trino.spi.connector.ConnectorNodePartitioningProvider; import io.trino.spi.type.Type; @@ -41,15 +42,15 @@ public PartitionFunctionProvider(TypeOperators typeOperators, CatalogServiceProv this.partitioningProvider = requireNonNull(partitioningProvider, "partitioningProvider is null"); } - public PartitionFunction getPartitionFunction(Session session, PartitioningHandle partitioningHandle, List partitionChannelTypes, int[] bucketToPartition) + public PartitionFunction getPartitionFunction(Session session, PartitioningHandle partitioningHandle, List partitionChannelTypes, int[] bucketToPartition, PartitionHashGeneratorCompiler partitionHashGeneratorCompiler) { if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle handle) { - return handle.getPartitionFunction(partitionChannelTypes, bucketToPartition, typeOperators); + return handle.getPartitionFunction(partitionChannelTypes, bucketToPartition, typeOperators, partitionHashGeneratorCompiler); } if (partitioningHandle.getConnectorHandle() instanceof MergePartitioningHandle handle) { return handle.getPartitionFunction( - (scheme, types) -> getPartitionFunction(session, scheme.getPartitioning().getHandle(), types, bucketToPartition), + (scheme, types) -> getPartitionFunction(session, scheme.getPartitioning().getHandle(), types, bucketToPartition, partitionHashGeneratorCompiler), partitionChannelTypes, bucketToPartition); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java index b60183f9b701..a5c079f80e99 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SystemPartitioningHandle.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.operator.BucketPartitionFunction; import io.trino.operator.PartitionFunction; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.spi.Page; import io.trino.spi.connector.BucketFunction; import io.trino.spi.connector.ConnectorPartitioningHandle; @@ -29,7 +30,6 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.operator.InterpretedHashGenerator.createPagePrefixHashGenerator; import static java.util.Objects.requireNonNull; public final class SystemPartitioningHandle @@ -135,13 +135,18 @@ public String toString() return partitioning.toString(); } - public PartitionFunction getPartitionFunction(List partitionChannelTypes, int[] bucketToPartition, TypeOperators typeOperators) + public PartitionFunction getPartitionFunction(List partitionChannelTypes, int[] bucketToPartition, TypeOperators typeOperators, PartitionHashGeneratorCompiler partitionHashGeneratorCompiler) { requireNonNull(partitionChannelTypes, "partitionChannelTypes is null"); requireNonNull(bucketToPartition, "bucketToPartition is null"); - BucketFunction bucketFunction = function.createBucketFunction(partitionChannelTypes, bucketToPartition.length, typeOperators); - return new BucketPartitionFunction(bucketFunction, bucketToPartition); + if (function == SystemPartitionFunction.HASH) { + return new HashPartitionFunction(partitionHashGeneratorCompiler.getPartitionHashGenerator(partitionChannelTypes, null), bucketToPartition.length, 
bucketToPartition); + } + else { + BucketFunction bucketFunction = function.createBucketFunction(partitionChannelTypes, bucketToPartition.length, typeOperators); + return new BucketPartitionFunction(bucketFunction, bucketToPartition); + } } public enum SystemPartitionFunction @@ -158,7 +163,7 @@ public BucketFunction createBucketFunction(List partitionChannelTypes, int @Override public BucketFunction createBucketFunction(List partitionChannelTypes, int bucketCount, TypeOperators typeOperators) { - return new HashBucketFunction(createPagePrefixHashGenerator(partitionChannelTypes, typeOperators), bucketCount); + throw new UnsupportedOperationException(); } }, ROUND_ROBIN { diff --git a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java index 82cca09a7894..f05c2d536403 100644 --- a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java +++ b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java @@ -119,6 +119,7 @@ import io.trino.operator.OutputFactory; import io.trino.operator.PagesIndex; import io.trino.operator.PagesIndexPageSorter; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.TaskContext; import io.trino.operator.index.IndexJoinLookupStats; import io.trino.operator.index.IndexManager; @@ -320,6 +321,7 @@ public class PlanTester private final TaskManagerConfig taskManagerConfig; private final OptimizerConfig optimizerConfig; private final StatementAnalyzerFactory statementAnalyzerFactory; + private final PartitionHashGeneratorCompiler partitionHashGeneratorCompiler; private boolean printPlan; public static PlanTester create(Session defaultSession) @@ -389,6 +391,7 @@ private PlanTester(Session defaultSession, int nodeCountForStats) typeRegistry.addType(new JsonPath2016Type(new TypeDeserializer(typeManager), blockEncodingSerde)); this.joinCompiler = new JoinCompiler(typeOperators); this.hashStrategyCompiler = new FlatHashStrategyCompiler(typeOperators); + this.partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(typeOperators); PageIndexerFactory pageIndexerFactory = new GroupByHashPageIndexerFactory(hashStrategyCompiler); EventListenerManager eventListenerManager = new EventListenerManager(new EventListenerConfig(), secretsResolver, noop(), tracer, CURRENT_NODE.getNodeVersion()); this.accessControl = new TestingAccessControlManager(transactionManager, eventListenerManager, secretsResolver); @@ -793,7 +796,8 @@ private List createDrivers(Session session, Plan plan, OutputFactory out tableExecuteContextManager, exchangeManagerRegistry, CURRENT_NODE.getNodeVersion(), - new CompilerConfig()); + new CompilerConfig(), + partitionHashGeneratorCompiler); // plan query LocalExecutionPlan localExecutionPlan = executionPlanner.plan( diff --git a/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java b/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java index 75c8685a0cb7..5983229dfa06 100644 --- a/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java +++ b/core/trino-main/src/main/java/io/trino/type/BlockTypeOperators.java @@ -72,6 +72,11 @@ public BlockTypeOperators(TypeOperators typeOperators) .expireAfterWrite(2, TimeUnit.HOURS)); } + public TypeOperators getTypeOperators() + { + return typeOperators; + } + public BlockPositionEqual getEqualOperator(Type type) { return getBlockOperator(type, BlockPositionEqual.class, () -> typeOperators.getEqualOperator(type, BLOCK_EQUAL_CONVENTION)); diff --git 
a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index 8028b022ddf6..c0be45b5e5c5 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -29,6 +29,7 @@ import io.trino.metadata.Split; import io.trino.operator.FlatHashStrategyCompiler; import io.trino.operator.PagesIndex; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.index.IndexJoinLookupStats; import io.trino.operator.index.IndexManager; import io.trino.server.protocol.spooling.QueryDataEncoders; @@ -183,7 +184,8 @@ public static LocalExecutionPlanner createTestingPlanner() new TableExecuteContextManager(), new ExchangeManagerRegistry(noop(), noopTracer(), new SecretsResolver(ImmutableMap.of()), new ExchangeManagerConfig()), new NodeVersion("test"), - new CompilerConfig()); + new CompilerConfig(), + new PartitionHashGeneratorCompiler(PLANNER_CONTEXT.getTypeOperators())); } public static TaskInfo updateTask(SqlTask sqlTask, List splitAssignments, OutputBuffers outputBuffers) diff --git a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java index ff93415bd28a..1a480d3e134e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java +++ b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java @@ -21,6 +21,7 @@ import io.trino.block.BlockAssertions; import io.trino.connector.CatalogHandle; import io.trino.operator.PageAssertions; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.exchange.LocalExchange.LocalExchangeSinkFactory; import io.trino.spi.Page; import io.trino.spi.block.Block; @@ -55,7 +56,6 @@ import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.SystemSessionProperties.QUERY_MAX_MEMORY_PER_NODE; import static io.trino.SystemSessionProperties.SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD; -import static io.trino.operator.InterpretedHashGenerator.createChannelsHashGenerator; import static io.trino.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -85,6 +85,7 @@ public class TestLocalExchange private static final DataSize WRITER_SCALING_MIN_DATA_PROCESSED = DataSize.of(32, MEGABYTE); private static final Supplier TOTAL_MEMORY_USED = () -> 0L; private static final OptionalInt BUCKET_COUNT = OptionalInt.of(8); + private static final PartitionHashGeneratorCompiler partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(TYPE_OPERATORS); private final ConcurrentMap partitionManagers = new ConcurrentHashMap<>(); private PartitionFunctionProvider functionProvider; @@ -114,9 +115,9 @@ public void testGatherSingleWriter() ImmutableList.of(), ImmutableList.of(), DataSize.ofBytes(retainedSizeOfPages(99)), - TYPE_OPERATORS, WRITER_SCALING_MIN_DATA_PROCESSED, - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(1); @@ -188,9 +189,9 @@ public void testRandom() ImmutableList.of(), ImmutableList.of(), LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, - TYPE_OPERATORS, WRITER_SCALING_MIN_DATA_PROCESSED, - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); 
run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(2); @@ -238,9 +239,9 @@ public void testScaleWriter() ImmutableList.of(), ImmutableList.of(), DataSize.ofBytes(retainedSizeOfPages(4)), - TYPE_OPERATORS, DataSize.ofBytes(sizeOfPages(2)), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(3); @@ -298,9 +299,9 @@ public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded() ImmutableList.of(), ImmutableList.of(), DataSize.ofBytes(retainedSizeOfPages(4)), - TYPE_OPERATORS, DataSize.ofBytes(sizeOfPages(10)), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(3); @@ -349,9 +350,9 @@ private void testScalingWithTwoDifferentPartitions(PartitioningHandle partitioni ImmutableList.of(0), TYPES, DataSize.ofBytes(retainedSizeOfPages(2)), - TYPE_OPERATORS, DataSize.of(10, KILOBYTE), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(4); @@ -458,9 +459,9 @@ public void testScaledWriterRoundRobinExchangerWhenTotalMemoryUsedIsGreaterThanL ImmutableList.of(), ImmutableList.of(), DataSize.ofBytes(retainedSizeOfPages(4)), - TYPE_OPERATORS, DataSize.ofBytes(sizeOfPages(2)), - totalMemoryUsed::get); + totalMemoryUsed::get, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(3); @@ -502,9 +503,9 @@ public void testNoWriterScalingWhenOnlyWriterScalingMinDataProcessedLimitIsExcee ImmutableList.of(), ImmutableList.of(), DataSize.ofBytes(retainedSizeOfPages(20)), - TYPE_OPERATORS, DataSize.ofBytes(sizeOfPages(2)), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(3); @@ -555,9 +556,9 @@ private void testScalingForSkewedWriters(PartitioningHandle partitioningHandle) ImmutableList.of(0), TYPES, DataSize.ofBytes(retainedSizeOfPages(2)), - TYPE_OPERATORS, DataSize.of(10, KILOBYTE), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(4); @@ -651,9 +652,9 @@ private void testNoScalingWhenDataWrittenIsLessThanMinFileSize(PartitioningHandl ImmutableList.of(0), TYPES, DataSize.ofBytes(retainedSizeOfPages(2)), - TYPE_OPERATORS, DataSize.of(50, MEGABYTE), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(4); @@ -721,9 +722,9 @@ private void testNoScalingWhenBufferUtilizationIsLessThanLimit(PartitioningHandl ImmutableList.of(0), TYPES, DataSize.of(50, MEGABYTE), - TYPE_OPERATORS, DataSize.of(10, KILOBYTE), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(4); @@ -793,9 +794,9 @@ private void testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit(PartitioningHand ImmutableList.of(0), TYPES, DataSize.ofBytes(retainedSizeOfPages(2)), - TYPE_OPERATORS, DataSize.of(10, KILOBYTE), - totalMemoryUsed::get); + totalMemoryUsed::get, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(4); @@ -880,9 +881,9 @@ private void 
testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(PartitioningHandl ImmutableList.of(0), TYPES, DataSize.ofBytes(retainedSizeOfPages(2)), - TYPE_OPERATORS, DataSize.of(10, KILOBYTE), - totalMemoryUsed::get); + totalMemoryUsed::get, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(4); @@ -973,9 +974,9 @@ public void testNoScalingWhenNoWriterSkewness() ImmutableList.of(0), TYPES, DataSize.ofBytes(retainedSizeOfPages(2)), - TYPE_OPERATORS, DataSize.of(50, KILOBYTE), - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(2); @@ -1021,9 +1022,9 @@ public void testPassthrough() ImmutableList.of(), ImmutableList.of(), DataSize.ofBytes(retainedSizeOfPages(1)), - TYPE_OPERATORS, WRITER_SCALING_MIN_DATA_PROCESSED, - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(2); @@ -1089,9 +1090,9 @@ public void testPartition() ImmutableList.of(0), TYPES, LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, - TYPE_OPERATORS, WRITER_SCALING_MIN_DATA_PROCESSED, - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(2); @@ -1186,9 +1187,9 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa ImmutableList.of(1), ImmutableList.of(BIGINT), LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, - TYPE_OPERATORS, WRITER_SCALING_MIN_DATA_PROCESSED, - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(2); @@ -1238,9 +1239,9 @@ public void writeUnblockWhenAllReadersFinish() ImmutableList.of(), ImmutableList.of(), LOCAL_EXCHANGE_MAX_BUFFERED_BYTES, - TYPE_OPERATORS, WRITER_SCALING_MIN_DATA_PROCESSED, - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(2); @@ -1286,9 +1287,9 @@ public void writeUnblockWhenAllReadersFinishAndPagesConsumed() ImmutableList.of(), ImmutableList.of(), DataSize.ofBytes(2), - TYPE_OPERATORS, WRITER_SCALING_MIN_DATA_PROCESSED, - TOTAL_MEMORY_USED); + TOTAL_MEMORY_USED, + partitionHashGeneratorCompiler); run(localExchange, exchange -> { assertThat(exchange.getBufferCount()).isEqualTo(2); @@ -1438,8 +1439,8 @@ private static void assertPartitionedRemovePage(LocalExchangeSource source, int assertThat(source.waitForReading().isDone()).isTrue(); Page page = source.removePage(); assertThat(page).isNotNull(); - - LocalPartitionGenerator partitionGenerator = new LocalPartitionGenerator(createChannelsHashGenerator(TYPES, new int[] {0}, TYPE_OPERATORS), partitionCount); + PartitionHashGeneratorCompiler partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(TYPE_OPERATORS); + LocalPartitionGenerator partitionGenerator = new LocalPartitionGenerator(partitionHashGeneratorCompiler.getPartitionHashGenerator(TYPES, new int[] {0}), partitionCount); for (int position = 0; position < page.getPositionCount(); position++) { assertThat(partitionGenerator.getPartition(page, position)).isEqualTo(partition); } diff --git a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java index 87c47c021714..67fe9ced5589 100644 --- 
a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java @@ -23,6 +23,7 @@ import io.trino.operator.DriverContext; import io.trino.operator.OperatorFactory; import io.trino.operator.PagesIndex; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.PipelineContext; import io.trino.operator.SpillContext; import io.trino.operator.TaskContext; @@ -70,6 +71,7 @@ public final class JoinTestUtils { private static final int PARTITION_COUNT = 4; private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); + private static final PartitionHashGeneratorCompiler partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(TYPE_OPERATORS); private JoinTestUtils() {} @@ -146,9 +148,9 @@ public static BuildSideSetup setupBuildSide( hashChannels, hashChannelTypes, DataSize.of(32, DataSize.Unit.MEGABYTE), - TYPE_OPERATORS, DataSize.of(32, DataSize.Unit.MEGABYTE), - () -> 0L); + () -> 0L, + partitionHashGeneratorCompiler); // collect input data into the partitioned exchange DriverContext collectDriverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java index b1af1927e08f..d0e0fd6ab18c 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java @@ -20,6 +20,7 @@ import io.trino.operator.DriverContext; import io.trino.operator.OperatorFactory; import io.trino.operator.PagesIndex; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.PipelineContext; import io.trino.operator.TaskContext; import io.trino.operator.ValuesOperator; @@ -62,6 +63,7 @@ public final class JoinTestUtils { private static final int PARTITION_COUNT = 4; private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); + private static final PartitionHashGeneratorCompiler partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(TYPE_OPERATORS); private JoinTestUtils() {} @@ -145,9 +147,9 @@ public static BuildSideSetup setupBuildSide( hashChannels, hashChannelTypes, DataSize.of(32, DataSize.Unit.MEGABYTE), - TYPE_OPERATORS, DataSize.of(32, DataSize.Unit.MEGABYTE), - () -> 0L); + () -> 0L, + partitionHashGeneratorCompiler); // collect input data into the partitioned exchange DriverContext collectDriverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); diff --git a/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java b/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java index 66e6ab4371db..6f8e923d03ab 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java @@ -27,11 +27,10 @@ import io.trino.jmh.Benchmarks; import io.trino.memory.context.LocalMemoryContext; import io.trino.memory.context.SimpleLocalMemoryContext; -import io.trino.operator.BucketPartitionFunction; import io.trino.operator.DriverContext; import io.trino.operator.PageTestUtils; import io.trino.operator.PartitionFunction; -import io.trino.operator.PrecomputedHashGenerator; +import 
io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.operator.output.PartitionedOutputOperator.PartitionedOutputFactory; import io.trino.spi.Page; import io.trino.spi.QueryId; @@ -48,7 +47,7 @@ import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.spi.type.VarcharType; -import io.trino.sql.planner.HashBucketFunction; +import io.trino.sql.planner.HashPartitionFunction; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.TestingTaskContext; import io.trino.type.BlockTypeOperators; @@ -95,7 +94,6 @@ import static io.trino.execution.buffer.TestingPagesSerdes.createTestingPagesSerdeFactory; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.operator.output.BenchmarkPartitionedOutputOperator.BenchmarkData.TestType; -import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION; import static java.util.Collections.nCopies; @@ -113,7 +111,9 @@ @BenchmarkMode(Mode.AverageTime) public class BenchmarkPartitionedOutputOperator { - private static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(new BlockTypeOperators()); + private static final BlockTypeOperators BLOCK_TYPE_OPERATORS = new BlockTypeOperators(); + private static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(BLOCK_TYPE_OPERATORS); + private static final PartitionHashGeneratorCompiler partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(BLOCK_TYPE_OPERATORS.getTypeOperators()); @Benchmark public void addPage(BenchmarkData data) @@ -153,6 +153,9 @@ public static class BenchmarkData @Param("8192") private int positionCount = DEFAULT_POSITION_COUNT; + @Param({"5", "50", "100"}) + private int columns = 5; + @Param({ // Flat BIGINT data channel, flat BIGINT partition channel. 
"BIGINT", @@ -382,6 +385,11 @@ public void setPositionCount(int positionCount) this.positionCount = positionCount; } + public void setColumns(int columns) + { + this.columns = columns; + } + public void setType(TestType type) { this.type = requireNonNull(type, "type is null"); @@ -405,11 +413,17 @@ private void setupData(Blackhole blackhole) // and in case of unit test it will be null this.blackhole = blackhole; types = type.getTypes(channelCount); + List types2 = new ArrayList(); + for (int i = 0; i < columns; i++) { + types2.add(type.type); + } + types = ImmutableList.builder() + .addAll(types2) + .build(); dataPage = type.getPageGenerator().createPage(types, positionCount, nullRate); pageCount = type.getPageCount(); types = ImmutableList.builder() .addAll(types) - .add(BIGINT) // dataPage has pre-computed hash block at the last channel .build(); } @@ -439,8 +453,9 @@ private PartitionedOutputBuffer createPartitionedOutputBuffer() private PartitionedOutputOperator createPartitionedOutputOperator() { - PartitionFunction partitionFunction = new BucketPartitionFunction( - new HashBucketFunction(new PrecomputedHashGenerator(0), partitionCount), + PartitionFunction partitionFunction = new HashPartitionFunction( + partitionHashGeneratorCompiler.getPartitionHashGenerator(types, null), + partitionCount, IntStream.range(0, partitionCount).toArray()); PagesSerdeFactory serdeFactory = createTestingPagesSerdeFactory(compressionCodec); @@ -448,7 +463,7 @@ private PartitionedOutputOperator createPartitionedOutputOperator() PartitionedOutputFactory operatorFactory = new PartitionedOutputFactory( partitionFunction, - ImmutableList.of(types.size() - 1), // hash block is at the last channel + IntStream.rangeClosed(0, types.size() - 1).boxed().toList(), ImmutableList.of(Optional.empty()), false, OptionalInt.empty(), @@ -562,6 +577,7 @@ private static void pollute() data.setPositionCount(256); data.setupData(null); data.setPageCount(50); + data.setColumns(50); benchmark.addPage(data); }); } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java index a7b8c81589ce..d36cf5da2d42 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPagePartitioner.java @@ -36,6 +36,7 @@ import io.trino.operator.OperatorFactory; import io.trino.operator.OutputFactory; import io.trino.operator.PartitionFunction; +import io.trino.operator.PartitionHashGeneratorCompiler; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; @@ -45,6 +46,7 @@ import io.trino.spi.type.Decimals; import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; +import io.trino.sql.planner.HashPartitionFunction; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.TestingTaskContext; import io.trino.type.BlockTypeOperators; @@ -726,6 +728,68 @@ private void testMemoryReleasedOnFailure(PartitioningMode partitioningMode) assertThat(memoryContext.getBytes()).isEqualTo(0); } + @Test + public void testOutputForVectorizedHashGeneration() + { + testVectorizedHashGeneration(PartitioningMode.ROW_WISE); + testVectorizedHashGeneration(PartitioningMode.COLUMNAR); + } + + private void testVectorizedHashGeneration(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + BlockTypeOperators blockTypeOperators = new BlockTypeOperators(); 
+ PartitionHashGeneratorCompiler partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(blockTypeOperators.getTypeOperators()); + List types = new ArrayList<>(); + types.add(BIGINT); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, types.get(0)) + .withPartitionFunction(new HashPartitionFunction( + partitionHashGeneratorCompiler.getPartitionHashGenerator(types, null), + PARTITION_COUNT, + IntStream.range(0, PARTITION_COUNT).toArray())) + .withPartitionChannels(ImmutableList.of(0)) + .build(); + Page page = new Page(createLongSequenceBlock(0, POSITIONS_PER_PAGE)); + + processPages(pagePartitioner, partitioningMode, page); + + List partition0 = readLongs(outputBuffer.getEnqueuedDeserialized(0), 0); + assertThat(partition0).containsExactly(0L, 2L, 6L); + List partition1 = readLongs(outputBuffer.getEnqueuedDeserialized(1), 0); + assertThat(partition1).containsExactly(1L, 3L, 4L, 5L, 7L); + } + + @Test + public void testOutputForVectorizedHashGenerationForDictionaryBlock() + { + testVectorizedHashGenerationForDictionaryBlock(PartitioningMode.ROW_WISE); + testVectorizedHashGenerationForDictionaryBlock(PartitioningMode.COLUMNAR); + } + + private void testVectorizedHashGenerationForDictionaryBlock(PartitioningMode partitioningMode) + { + TestOutputBuffer outputBuffer = new TestOutputBuffer(); + BlockTypeOperators blockTypeOperators = new BlockTypeOperators(); + PartitionHashGeneratorCompiler partitionHashGeneratorCompiler = new PartitionHashGeneratorCompiler(blockTypeOperators.getTypeOperators()); + List types = new ArrayList<>(); + types.add(BIGINT); + PagePartitioner pagePartitioner = pagePartitioner(outputBuffer, types.get(0)) + .withPartitionFunction(new HashPartitionFunction( + partitionHashGeneratorCompiler.getPartitionHashGenerator(types, null), + PARTITION_COUNT, + IntStream.range(0, PARTITION_COUNT).toArray())) + .withPartitionChannels(ImmutableList.of(0)) + .build(); + Page page = new Page(createLongDictionaryBlock(0, 10)); + + processPages(pagePartitioner, partitioningMode, page); + + List partition0 = readLongs(outputBuffer.getEnqueuedDeserialized(0), 0); + assertThat(partition0).containsExactlyElementsOf(nCopies(5, 0L)); + List partition1 = readLongs(outputBuffer.getEnqueuedDeserialized(1), 0); + assertThat(partition1).containsExactlyElementsOf(nCopies(5, 1L)); + } + private void testOutputEqualsInput(Type type, PartitioningMode mode1, PartitioningMode mode2) { TestOutputBuffer outputBuffer = new TestOutputBuffer(); @@ -955,8 +1019,7 @@ public PagePartitioner build() PARTITION_MAX_MEMORY, POSITIONS_APPENDER_FACTORY, Optional.empty(), - memoryContext, - true); + memoryContext); } } @@ -1118,6 +1181,18 @@ public int getPartition(Page page, int position) return toIntExact(Math.abs(value) % partitionCount); } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + for (int i = 0; i < length; i++) { + long value = 0; + for (int hashChannel : hashChannels) { + value += BIGINT.getLong(page.getBlock(hashChannel), offset + i); + } + partitions[i] = toIntExact(Math.abs(value) % partitionCount); + } + } } private static final class SinglePartitionFailIfCalled @@ -1134,5 +1209,11 @@ public int getPartition(Page page, int position) { throw new UnsupportedOperationException("getPartition should not be called on single partitioned outputs"); } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + throw new 
UnsupportedOperationException("getPartitions should not be called on single partitioned outputs"); + } } } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java b/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java index 548f25b6da20..28cd18be6b4b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestSkewedPartitionRebalancer.java @@ -301,5 +301,13 @@ public int getPartition(Page page, int position) { return position % partitionCount; } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + for (int i = 0; i < length; i++) { + partitions[i] = (offset + i) % partitionCount; + } + } } } diff --git a/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java b/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java index 78866d16b688..391de442c0ad 100644 --- a/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java +++ b/core/trino-main/src/test/java/io/trino/spiller/TestGenericPartitioningSpiller.java @@ -265,6 +265,26 @@ public int getPartition(Page page, int position) } return 0; } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + for (int i = 0; i < length; i++) { + long value = BIGINT.getLong(page.getBlock(valueChannel), offset + i); + if (value >= FOURTH_PARTITION_START) { + partitions[i] = 3; + } + else if (value >= THIRD_PARTITION_START) { + partitions[i] = 2; + } + else if (value >= SECOND_PARTITION_START) { + partitions[i] = 1; + } + else { + partitions[i] = 0; + } + } + } } private static class ModuloPartitionFunction @@ -292,5 +312,14 @@ public int getPartition(Page page, int position) long value = BIGINT.getLong(page.getBlock(valueChannel), position); return toIntExact(Math.abs(value) % partitionCount); } + + @Override + public void getPartitions(Page page, int[] partitions, long[] rawHashes, int offset, int length) + { + for (int i = 0; i < length; i++) { + long value = BIGINT.getLong(page.getBlock(valueChannel), offset + i); + partitions[i] = toIntExact(Math.abs(value) % partitionCount); + } + } } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/Page.java b/core/trino-spi/src/main/java/io/trino/spi/Page.java index 39fec58c8dde..4dd02421eca5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/Page.java +++ b/core/trino-spi/src/main/java/io/trino/spi/Page.java @@ -292,6 +292,11 @@ private long updateRetainedSize() return retainedSizeInBytes; } + public Block[] getBlocks() + { + return blocks; + } + private static class DictionaryBlockIndexes { private final List blocks = new ArrayList<>();
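
A minimal sketch (not part of the patch) of the batched calling pattern that the PagePartitioner and scale-writer changes above rely on: positions are processed in fixed-size batches, and one getPartitions call fills reusable partition-id and raw-hash buffers instead of making a virtual getPartition call per row. PartitionFunction and Page are the real classes touched by this change; the helper class name and harness below are illustrative assumptions, and the 1024 batch size simply mirrors the BATCH_SIZE constant introduced above.

// Illustrative only: shows the batched PartitionFunction.getPartitions calling convention.
import io.trino.operator.PartitionFunction;
import io.trino.spi.Page;

import static java.lang.Math.min;

final class BatchedPartitionLoop
{
    private static final int BATCH_SIZE = 1024; // same batch size as the patch

    private BatchedPartitionLoop() {}

    static int[] partitionPositions(PartitionFunction function, Page partitionArgs)
    {
        int positionCount = partitionArgs.getPositionCount();
        int[] result = new int[positionCount];
        long[] rawHashes = new long[BATCH_SIZE];   // reusable raw hash buffer
        int[] partitions = new int[BATCH_SIZE];    // reusable partition id buffer
        int lastPosition = 0;
        while (lastPosition < positionCount) {
            int batchSize = min(positionCount - lastPosition, BATCH_SIZE);
            // one call computes partition ids (and raw hashes) for the whole batch
            function.getPartitions(partitionArgs, partitions, rawHashes, lastPosition, batchSize);
            System.arraycopy(partitions, 0, result, lastPosition, batchSize);
            lastPosition += batchSize;
        }
        return result;
    }
}

The same loop shape appears in partitionPageByRow, partitionGeneric, and the skewed-writer exchanger hunks above; only what is done with each partitions[i] differs.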
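
The bucket selection inside HashPartitionFunction.getPartitions above reduces a 64-bit raw hash to a bucket index without a modulo: Long.hashCode folds the hash to 32 bits, the result is treated as an unsigned value x, multiplied by the bucket count, and the top 32 bits are kept, which equals floor(x * bucketCount / 2^32) and always lands in [0, bucketCount). A standalone sketch of just that reduction, with illustrative class and method names that are not part of the patch:

// Illustrative only: the multiply-then-shift range reduction used by HashPartitionFunction above.
final class HashRangeReduction
{
    private HashRangeReduction() {}

    // Maps a 64-bit raw hash onto [0, bucketCount) without division or modulo.
    static int toBucket(long rawHash, int bucketCount)
    {
        long x = Integer.toUnsignedLong(Long.hashCode(rawHash)); // x in [0, 2^32)
        return (int) ((x * bucketCount) >>> 32);                 // floor(x * bucketCount / 2^32)
    }

    public static void main(String[] args)
    {
        int bucketCount = 7;
        for (long rawHash : new long[] {0L, 42L, -1L, Long.MIN_VALUE, Long.MAX_VALUE}) {
            int bucket = toBucket(rawHash, bucketCount);
            assert bucket >= 0 && bucket < bucketCount;
            System.out.println(rawHash + " -> bucket " + bucket);
        }
    }
}

Since x is below 2^32 and bucketCount fits in an int, the product stays below 2^63, so the multiplication cannot overflow into the sign bit and the unsigned shift is safe.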