diff --git a/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java b/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java index 1e5baa0ba3ce..c9c8503ed1d0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java +++ b/core/trino-main/src/main/java/io/trino/operator/OperatorFactories.java @@ -13,11 +13,8 @@ */ package io.trino.operator; -import io.airlift.units.DataSize; -import io.trino.execution.buffer.OutputBuffer; import io.trino.operator.join.JoinBridgeManager; import io.trino.operator.join.LookupSourceFactory; -import io.trino.spi.predicate.NullableValue; import io.trino.spi.type.Type; import io.trino.spiller.PartitioningSpillerFactory; import io.trino.sql.planner.plan.PlanNodeId; @@ -84,14 +81,4 @@ OperatorFactory fullOuterJoin( OptionalInt totalOperatorsCount, PartitioningSpillerFactory partitioningSpillerFactory, BlockTypeOperators blockTypeOperators); - - OutputFactory partitionedOutput( - TaskContext taskContext, - PartitionFunction partitionFunction, - List partitionChannels, - List> partitionConstants, - boolean replicateNullsAndAny, - OptionalInt nullChannel, - OutputBuffer outputBuffer, - DataSize maxPagePartitioningBufferSize); } diff --git a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java index 1c0410a89f43..37269f681eb7 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java +++ b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java @@ -13,14 +13,11 @@ */ package io.trino.operator; -import io.airlift.units.DataSize; -import io.trino.execution.buffer.OutputBuffer; import io.trino.operator.join.JoinBridgeManager; import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.LookupJoinOperatorFactory; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupSourceFactory; -import io.trino.spi.predicate.NullableValue; import io.trino.spi.type.Type; import io.trino.spiller.PartitioningSpillerFactory; import io.trino.sql.planner.plan.PlanNodeId; @@ -36,7 +33,6 @@ import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.INNER; import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.LOOKUP_OUTER; import static io.trino.operator.join.LookupJoinOperatorFactory.JoinType.PROBE_OUTER; -import static io.trino.operator.output.PartitionedOutputOperator.PartitionedOutputFactory; public class TrinoOperatorFactories implements OperatorFactories @@ -165,27 +161,6 @@ public OperatorFactory fullOuterJoin( blockTypeOperators); } - @Override - public OutputFactory partitionedOutput( - TaskContext taskContext, - PartitionFunction partitionFunction, - List partitionChannels, - List> partitionConstants, - boolean replicateNullsAndAny, - OptionalInt nullChannel, - OutputBuffer outputBuffer, - DataSize maxPagePartitioningBufferSize) - { - return new PartitionedOutputFactory( - partitionFunction, - partitionChannels, - partitionConstants, - replicateNullsAndAny, - nullChannel, - outputBuffer, - maxPagePartitioningBufferSize); - } - private static List rangeList(int endExclusive) { return IntStream.range(0, endExclusive) diff --git a/core/trino-main/src/main/java/io/trino/operator/output/DefaultPagePartitioner.java b/core/trino-main/src/main/java/io/trino/operator/output/DefaultPagePartitioner.java deleted file mode 100644 index 465c15e932dc..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/output/DefaultPagePartitioner.java +++ /dev/null @@ -1,271 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.output; - -import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Ints; -import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.slice.Slice; -import io.airlift.units.DataSize; -import io.trino.execution.buffer.OutputBuffer; -import io.trino.execution.buffer.PagesSerde; -import io.trino.execution.buffer.PagesSerdeFactory; -import io.trino.operator.OperatorContext; -import io.trino.operator.PartitionFunction; -import io.trino.operator.output.PartitionedOutputOperator.PartitionedOutputInfo; -import io.trino.spi.Page; -import io.trino.spi.PageBuilder; -import io.trino.spi.block.Block; -import io.trino.spi.block.RunLengthEncodedBlock; -import io.trino.spi.predicate.NullableValue; -import io.trino.spi.type.Type; - -import javax.annotation.Nullable; - -import java.util.Arrays; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.OptionalInt; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Supplier; - -import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.execution.buffer.PageSplitterUtil.splitPage; -import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; -import static java.lang.Math.max; -import static java.lang.Math.min; -import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; - -public class DefaultPagePartitioner - implements PagePartitioner -{ - private final OutputBuffer outputBuffer; - private final Type[] sourceTypes; - private final PartitionFunction partitionFunction; - private final int[] partitionChannels; - @Nullable - private final Block[] partitionConstantBlocks; // when null, no constants are present. Only non-null elements are constants - private final PagesSerde serde; - private final PageBuilder[] pageBuilders; - private final boolean replicatesAnyRow; - private final int nullChannel; // when >= 0, send the position to every partition if this channel is null - private final AtomicLong rowsAdded = new AtomicLong(); - private final AtomicLong pagesAdded = new AtomicLong(); - private boolean hasAnyRowBeenReplicated; - private final OperatorContext operatorContext; - - public DefaultPagePartitioner( - PartitionFunction partitionFunction, - List partitionChannels, - List> partitionConstants, - boolean replicatesAnyRow, - OptionalInt nullChannel, - OutputBuffer outputBuffer, - PagesSerdeFactory serdeFactory, - List sourceTypes, - DataSize maxMemory, - OperatorContext operatorContext) - { - this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); - this.partitionChannels = Ints.toArray(requireNonNull(partitionChannels, "partitionChannels is null")); - Block[] partitionConstantBlocks = requireNonNull(partitionConstants, "partitionConstants is null").stream() - .map(constant -> constant.map(NullableValue::asBlock).orElse(null)) - .toArray(Block[]::new); - if (Arrays.stream(partitionConstantBlocks).anyMatch(Objects::nonNull)) { - this.partitionConstantBlocks = partitionConstantBlocks; - } - else { - this.partitionConstantBlocks = null; - } - this.replicatesAnyRow = replicatesAnyRow; - this.nullChannel = requireNonNull(nullChannel, "nullChannel is null").orElse(-1); - this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); - this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); - this.serde = requireNonNull(serdeFactory, "serdeFactory is null").createPagesSerde(); - this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - - // Ensure partition channels align with constant arguments provided - for (int i = 0; i < this.partitionChannels.length; i++) { - if (this.partitionChannels[i] < 0) { - checkArgument(this.partitionConstantBlocks != null && this.partitionConstantBlocks[i] != null, - "Expected constant for partitioning channel %s, but none was found", i); - } - } - - int partitionCount = partitionFunction.getPartitionCount(); - int pageSize = toIntExact(min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, maxMemory.toBytes() / partitionCount)); - pageSize = max(1, pageSize); - - this.pageBuilders = new PageBuilder[partitionCount]; - for (int i = 0; i < partitionCount; i++) { - pageBuilders[i] = PageBuilder.withMaxPageSize(pageSize, sourceTypes); - } - } - - @Override - public ListenableFuture isFull() - { - return outputBuffer.isFull(); - } - - @Override - public long getSizeInBytes() - { - // We use a foreach loop instead of streams - // as it has much better performance. - long sizeInBytes = 0; - for (PageBuilder pageBuilder : pageBuilders) { - sizeInBytes += pageBuilder.getSizeInBytes(); - } - return sizeInBytes; - } - - /** - * This method can be expensive for complex types. - */ - @Override - public long getRetainedSizeInBytes() - { - long sizeInBytes = 0; - for (PageBuilder pageBuilder : pageBuilders) { - sizeInBytes += pageBuilder.getRetainedSizeInBytes(); - } - return sizeInBytes; - } - - @Override - public Supplier getOperatorInfoSupplier() - { - return createPartitionedOutputOperatorInfoSupplier(rowsAdded, pagesAdded, outputBuffer); - } - - private static Supplier createPartitionedOutputOperatorInfoSupplier(AtomicLong rowsAdded, AtomicLong pagesAdded, OutputBuffer outputBuffer) - { - // Must be a separate static method to avoid embedding references to "this" in the supplier - requireNonNull(rowsAdded, "rowsAdded is null"); - requireNonNull(pagesAdded, "pagesAdded is null"); - requireNonNull(outputBuffer, "outputBuffer is null"); - return () -> new PartitionedOutputInfo(rowsAdded.get(), pagesAdded.get(), outputBuffer.getPeakMemoryUsage()); - } - - @Override - public void partitionPage(Page page) - { - requireNonNull(page, "page is null"); - if (page.getPositionCount() == 0) { - return; - } - - int position; - // Handle "any row" replication outside of the inner loop processing - if (replicatesAnyRow && !hasAnyRowBeenReplicated) { - for (PageBuilder pageBuilder : pageBuilders) { - appendRow(pageBuilder, page, 0); - } - hasAnyRowBeenReplicated = true; - position = 1; - } - else { - position = 0; - } - - Page partitionFunctionArgs = getPartitionFunctionArguments(page); - // Skip null block checks if mayHaveNull reports that no positions will be null - if (nullChannel >= 0 && page.getBlock(nullChannel).mayHaveNull()) { - Block nullsBlock = page.getBlock(nullChannel); - for (; position < page.getPositionCount(); position++) { - if (nullsBlock.isNull(position)) { - for (PageBuilder pageBuilder : pageBuilders) { - appendRow(pageBuilder, page, position); - } - } - else { - int partition = partitionFunction.getPartition(partitionFunctionArgs, position); - appendRow(pageBuilders[partition], page, position); - } - } - } - else { - for (; position < page.getPositionCount(); position++) { - int partition = partitionFunction.getPartition(partitionFunctionArgs, position); - appendRow(pageBuilders[partition], page, position); - } - } - - flush(false); - } - - private Page getPartitionFunctionArguments(Page page) - { - // Fast path for no constants - if (partitionConstantBlocks == null) { - return page.getColumns(partitionChannels); - } - - Block[] blocks = new Block[partitionChannels.length]; - for (int i = 0; i < blocks.length; i++) { - int channel = partitionChannels[i]; - if (channel < 0) { - blocks[i] = new RunLengthEncodedBlock(partitionConstantBlocks[i], page.getPositionCount()); - } - else { - blocks[i] = page.getBlock(channel); - } - } - return new Page(page.getPositionCount(), blocks); - } - - private void appendRow(PageBuilder pageBuilder, Page page, int position) - { - pageBuilder.declarePosition(); - - for (int channel = 0; channel < sourceTypes.length; channel++) { - Type type = sourceTypes[channel]; - type.appendTo(page.getBlock(channel), position, pageBuilder.getBlockBuilder(channel)); - } - } - - @Override - public void flush(boolean force) - { - try (PagesSerde.PagesSerdeContext context = serde.newContext()) { - // add all full pages to output buffer - for (int partition = 0; partition < pageBuilders.length; partition++) { - PageBuilder partitionPageBuilder = pageBuilders[partition]; - if (!partitionPageBuilder.isEmpty() && (force || partitionPageBuilder.isFull())) { - Page pagePartition = partitionPageBuilder.build(); - partitionPageBuilder.reset(); - - operatorContext.recordOutput(pagePartition.getSizeInBytes(), pagePartition.getPositionCount()); - - outputBuffer.enqueue(partition, splitAndSerializePage(context, pagePartition)); - pagesAdded.incrementAndGet(); - rowsAdded.addAndGet(pagePartition.getPositionCount()); - } - } - } - } - - private List splitAndSerializePage(PagesSerde.PagesSerdeContext context, Page page) - { - List split = splitPage(page, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); - ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(split.size()); - for (Page p : split) { - builder.add(serde.serialize(context, p)); - } - return builder.build(); - } -} diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java index 9311d40dce22..a62072256d36 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitioner.java @@ -13,22 +13,496 @@ */ package io.trino.operator.output; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.slice.Slice; +import io.airlift.units.DataSize; +import io.trino.execution.buffer.OutputBuffer; +import io.trino.execution.buffer.PagesSerde; +import io.trino.execution.buffer.PagesSerdeFactory; +import io.trino.operator.OperatorContext; +import io.trino.operator.PartitionFunction; import io.trino.spi.Page; +import io.trino.spi.PageBuilder; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.predicate.NullableValue; +import io.trino.spi.type.Type; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntList; +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.IntUnaryOperator; import java.util.function.Supplier; -public interface PagePartitioner +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static io.trino.execution.buffer.PageSplitterUtil.splitPage; +import static io.trino.operator.output.PartitionedOutputOperator.PartitionedOutputInfo; +import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +public class PagePartitioner { - void partitionPage(Page page); + private static final int COLUMNAR_STRATEGY_COEFFICIENT = 4; + private final OutputBuffer outputBuffer; + private final Type[] sourceTypes; + private final PartitionFunction partitionFunction; + private final int[] partitionChannels; + @Nullable + private final Block[] partitionConstantBlocks; // when null, no constants are present. Only non-null elements are constants + private final PagesSerde serde; + private final PageBuilder[] pageBuilders; + private final boolean replicatesAnyRow; + private final int nullChannel; // when >= 0, send the position to every partition if this channel is null + private final AtomicLong rowsAdded = new AtomicLong(); + private final AtomicLong pagesAdded = new AtomicLong(); + private final OperatorContext operatorContext; + private final PositionsAppenderFactory positionsAppenderFactory; + private final PositionsAppender[] positionsAppenders; + + private boolean hasAnyRowBeenReplicated; + + public PagePartitioner( + PartitionFunction partitionFunction, + List partitionChannels, + List> partitionConstants, + boolean replicatesAnyRow, + OptionalInt nullChannel, + OutputBuffer outputBuffer, + PagesSerdeFactory serdeFactory, + List sourceTypes, + DataSize maxMemory, + OperatorContext operatorContext, + PositionsAppenderFactory positionsAppenderFactory) + { + this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); + this.partitionChannels = Ints.toArray(requireNonNull(partitionChannels, "partitionChannels is null")); + this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); + Block[] partitionConstantBlocks = requireNonNull(partitionConstants, "partitionConstants is null").stream() + .map(constant -> constant.map(NullableValue::asBlock).orElse(null)) + .toArray(Block[]::new); + if (Arrays.stream(partitionConstantBlocks).anyMatch(Objects::nonNull)) { + this.partitionConstantBlocks = partitionConstantBlocks; + } + else { + this.partitionConstantBlocks = null; + } + this.replicatesAnyRow = replicatesAnyRow; + this.nullChannel = requireNonNull(nullChannel, "nullChannel is null").orElse(-1); + this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); + this.sourceTypes = requireNonNull(sourceTypes, "sourceTypes is null").toArray(new Type[0]); + this.serde = requireNonNull(serdeFactory, "serdeFactory is null").createPagesSerde(); + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + + // Ensure partition channels align with constant arguments provided + for (int i = 0; i < this.partitionChannels.length; i++) { + if (this.partitionChannels[i] < 0) { + checkArgument(this.partitionConstantBlocks != null && this.partitionConstantBlocks[i] != null, + "Expected constant for partitioning channel %s, but none was found", i); + } + } + + int partitionCount = partitionFunction.getPartitionCount(); + int pageSize = toIntExact(min(DEFAULT_MAX_PAGE_SIZE_IN_BYTES, maxMemory.toBytes() / partitionCount)); + pageSize = max(1, pageSize); + + this.pageBuilders = new PageBuilder[partitionCount]; + for (int i = 0; i < partitionCount; i++) { + pageBuilders[i] = PageBuilder.withMaxPageSize(pageSize, sourceTypes); + } + positionsAppenders = new PositionsAppender[sourceTypes.size()]; + } + + public ListenableFuture isFull() + { + return outputBuffer.isFull(); + } + + public long getSizeInBytes() + { + // We use a foreach loop instead of streams + // as it has much better performance. + long sizeInBytes = 0; + for (PageBuilder pageBuilder : pageBuilders) { + sizeInBytes += pageBuilder.getSizeInBytes(); + } + return sizeInBytes; + } + + /** + * This method can be expensive for complex types. + */ + public long getRetainedSizeInBytes() + { + long sizeInBytes = 0; + for (PageBuilder pageBuilder : pageBuilders) { + sizeInBytes += pageBuilder.getRetainedSizeInBytes(); + } + return sizeInBytes; + } + + public Supplier getOperatorInfoSupplier() + { + return createPartitionedOutputOperatorInfoSupplier(rowsAdded, pagesAdded, outputBuffer); + } + + private static Supplier createPartitionedOutputOperatorInfoSupplier(AtomicLong rowsAdded, AtomicLong pagesAdded, OutputBuffer outputBuffer) + { + // Must be a separate static method to avoid embedding references to "this" in the supplier + requireNonNull(rowsAdded, "rowsAdded is null"); + requireNonNull(pagesAdded, "pagesAdded is null"); + requireNonNull(outputBuffer, "outputBuffer is null"); + return () -> new PartitionedOutputInfo(rowsAdded.get(), pagesAdded.get(), outputBuffer.getPeakMemoryUsage()); + } + + public void partitionPage(Page page) + { + if (page.getPositionCount() == 0) { + return; + } + + if (page.getPositionCount() < partitionFunction.getPartitionCount() * COLUMNAR_STRATEGY_COEFFICIENT) { + // Partition will have on average less than COLUMNAR_STRATEGY_COEFFICIENT rows. + // Doing it column-wise would degrade performance, so we fall back to row-wise approach. + // Performance degradation is the worst in case of skewed hash distribution when only small subset + // of partitions is selected. + partitionPageByRow(page); + } + else { + partitionPageByColumn(page); + } + } + + public void partitionPageByRow(Page page) + { + requireNonNull(page, "page is null"); + if (page.getPositionCount() == 0) { + return; + } + + int position; + // Handle "any row" replication outside of the inner loop processing + if (replicatesAnyRow && !hasAnyRowBeenReplicated) { + for (PageBuilder pageBuilder : pageBuilders) { + appendRow(pageBuilder, page, 0); + } + hasAnyRowBeenReplicated = true; + position = 1; + } + else { + position = 0; + } + + Page partitionFunctionArgs = getPartitionFunctionArguments(page); + // Skip null block checks if mayHaveNull reports that no positions will be null + if (nullChannel >= 0 && page.getBlock(nullChannel).mayHaveNull()) { + Block nullsBlock = page.getBlock(nullChannel); + for (; position < page.getPositionCount(); position++) { + if (nullsBlock.isNull(position)) { + for (PageBuilder pageBuilder : pageBuilders) { + appendRow(pageBuilder, page, position); + } + } + else { + int partition = partitionFunction.getPartition(partitionFunctionArgs, position); + appendRow(pageBuilders[partition], page, position); + } + } + } + else { + for (; position < page.getPositionCount(); position++) { + int partition = partitionFunction.getPartition(partitionFunctionArgs, position); + appendRow(pageBuilders[partition], page, position); + } + } + + flush(false); + } + + private void appendRow(PageBuilder pageBuilder, Page page, int position) + { + pageBuilder.declarePosition(); + + for (int channel = 0; channel < sourceTypes.length; channel++) { + Type type = sourceTypes[channel]; + type.appendTo(page.getBlock(channel), position, pageBuilder.getBlockBuilder(channel)); + } + } + + public void partitionPageByColumn(Page page) + { + IntArrayList[] partitionedPositions = partitionPositions(page); + + PositionsAppender[] positionsAppenders = getAppenders(page); + + for (int i = 0; i < partitionFunction.getPartitionCount(); i++) { + IntArrayList partitionPositions = partitionedPositions[i]; + if (!partitionPositions.isEmpty()) { + appendToOutputPartition(pageBuilders[i], page, partitionPositions, positionsAppenders); + partitionPositions.clear(); + } + } + + flush(false); + } + + private PositionsAppender[] getAppenders(Page page) + { + for (int i = 0; i < positionsAppenders.length; i++) { + positionsAppenders[i] = positionsAppenderFactory.create(sourceTypes[i], page.getBlock(i).getClass()); + } + return positionsAppenders; + } + + private IntArrayList[] partitionPositions(Page page) + { + verify(page.getPositionCount() > 0, "position count is 0"); + IntArrayList[] partitionPositions = initPositions(page); + int position; + // Handle "any row" replication outside the inner loop processing + if (replicatesAnyRow && !hasAnyRowBeenReplicated) { + for (IntList partitionPosition : partitionPositions) { + partitionPosition.add(0); + } + hasAnyRowBeenReplicated = true; + position = 1; + } + else { + position = 0; + } + + Page partitionFunctionArgs = getPartitionFunctionArguments(page); + + if (partitionFunctionArgs.getChannelCount() > 0 && onlyRleBlocks(partitionFunctionArgs)) { + // we need at least one Rle block since with no blocks partition function + // can return a different value per invocation (e.g. RoundRobinBucketFunction) + partitionBySingleRleValue(page, position, partitionFunctionArgs, partitionPositions); + } + else if (partitionFunctionArgs.getChannelCount() == 1 && isDictionaryProcessingFaster(partitionFunctionArgs.getBlock(0))) { + partitionBySingleDictionary(page, position, partitionFunctionArgs, partitionPositions); + } + else { + partitionGeneric(page, position, aPosition -> partitionFunction.getPartition(partitionFunctionArgs, aPosition), partitionPositions); + } + return partitionPositions; + } + + private void appendToOutputPartition(PageBuilder outputPartition, Page page, IntArrayList positions, PositionsAppender[] positionsAppenders) + { + outputPartition.declarePositions(positions.size()); + + for (int channel = 0; channel < positionsAppenders.length; channel++) { + Block partitionBlock = page.getBlock(channel); + BlockBuilder target = outputPartition.getBlockBuilder(channel); + positionsAppenders[channel].appendTo(positions, partitionBlock, target); + } + } + + private IntArrayList[] initPositions(Page page) + { + // We allocate new arrays for every page (instead of caching them) because we don't + // want memory to explode in case there are input pages with many positions, where each page + // is assigned to a single partition entirely. + // For example this can happen for partition columns if they are represented by RLE blocks. + IntArrayList[] partitionPositions = new IntArrayList[partitionFunction.getPartitionCount()]; + for (int i = 0; i < partitionPositions.length; i++) { + partitionPositions[i] = new IntArrayList(initialPartitionSize(page.getPositionCount() / partitionFunction.getPartitionCount())); + } + return partitionPositions; + } + + private static int initialPartitionSize(int averagePositionsPerPartition) + { + // 1.1 coefficient compensates for the not perfect hash distribution. + // 32 compensates for the case when averagePositionsPerPartition is small, + // and we would see more variance in the hash distribution. + return (int) (averagePositionsPerPartition * 1.1) + 32; + } + + private boolean onlyRleBlocks(Page page) + { + for (int i = 0; i < page.getChannelCount(); i++) { + if (!(page.getBlock(i) instanceof RunLengthEncodedBlock)) { + return false; + } + } + return true; + } + + private void partitionBySingleRleValue(Page page, int position, Page partitionFunctionArgs, IntArrayList[] partitionPositions) + { + // copy all positions because all hash function args are the same for every position + if (nullChannel != -1 && page.getBlock(nullChannel).isNull(0)) { + verify(page.getBlock(nullChannel) instanceof RunLengthEncodedBlock, "null channel is not RunLengthEncodedBlock", page.getBlock(nullChannel)); + // all positions are null + int[] allPositions = integersInRange(position, page.getPositionCount()); + for (IntList partitionPosition : partitionPositions) { + partitionPosition.addElements(position, allPositions); + } + } + else { + // extract rle page to prevent JIT profile pollution + Page rlePage = extractRlePage(partitionFunctionArgs); + + int partition = partitionFunction.getPartition(rlePage, 0); + IntArrayList positions = partitionPositions[partition]; + for (int i = position; i < page.getPositionCount(); i++) { + positions.add(i); + } + } + } + + private Page extractRlePage(Page page) + { + Block[] valueBlocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < valueBlocks.length; ++channel) { + valueBlocks[channel] = ((RunLengthEncodedBlock) page.getBlock(channel)).getValue(); + } + return new Page(valueBlocks); + } + + private int[] integersInRange(int start, int endExclusive) + { + int[] array = new int[endExclusive - start]; + int current = start; + for (int i = 0; i < array.length; i++) { + array[i] = current++; + } + return array; + } + + private boolean isDictionaryProcessingFaster(Block block) + { + if (!(block instanceof DictionaryBlock)) { + return false; + } + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + // if dictionary block positionCount is greater than number of elements in the dictionary + // it will be faster to compute hash for the dictionary values only once and re-use it + // instead of recalculating it. + return dictionaryBlock.getPositionCount() > dictionaryBlock.getDictionary().getPositionCount(); + } + + private void partitionBySingleDictionary(Page page, int position, Page partitionFunctionArgs, IntArrayList[] partitionPositions) + { + DictionaryBlock dictionaryBlock = (DictionaryBlock) partitionFunctionArgs.getBlock(0); + Block dictionary = dictionaryBlock.getDictionary(); + int[] dictionaryPartitions = new int[dictionary.getPositionCount()]; + Page dictionaryPage = new Page(dictionary); + for (int i = 0; i < dictionary.getPositionCount(); i++) { + dictionaryPartitions[i] = partitionFunction.getPartition(dictionaryPage, i); + } + + partitionGeneric(page, position, aPosition -> dictionaryPartitions[dictionaryBlock.getId(aPosition)], partitionPositions); + } + + private void partitionGeneric(Page page, int position, IntUnaryOperator partitionFunction, IntArrayList[] partitionPositions) + { + // Skip null block checks if mayHaveNull reports that no positions will be null + if (nullChannel != -1 && page.getBlock(nullChannel).mayHaveNull()) { + partitionNullablePositions(page, position, partitionPositions, partitionFunction); + } + else { + partitionNotNullPositions(page, position, partitionPositions, partitionFunction); + } + } + + private IntArrayList[] partitionNullablePositions(Page page, int position, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) + { + Block nullsBlock = page.getBlock(nullChannel); + int[] nullPositions = new int[page.getPositionCount()]; + int[] nonNullPositions = new int[page.getPositionCount()]; + int nullCount = 0; + int nonNullCount = 0; + for (int i = position; i < page.getPositionCount(); i++) { + nullPositions[nullCount] = i; + nonNullPositions[nonNullCount] = i; + int isNull = nullsBlock.isNull(i) ? 1 : 0; + nullCount += isNull; + nonNullCount += isNull ^ 1; + } + for (IntArrayList positions : partitionPositions) { + positions.addElements(position, nullPositions, 0, nullCount); + } + for (int i = 0; i < nonNullCount; i++) { + int nonNullPosition = nonNullPositions[i]; + int partition = partitionFunction.applyAsInt(nonNullPosition); + partitionPositions[partition].add(nonNullPosition); + } + return partitionPositions; + } + + private IntArrayList[] partitionNotNullPositions(Page page, int startingPosition, IntArrayList[] partitionPositions, IntUnaryOperator partitionFunction) + { + for (int position = startingPosition; position < page.getPositionCount(); position++) { + int partition = partitionFunction.applyAsInt(position); + partitionPositions[partition].add(position); + } + + return partitionPositions; + } + + private Page getPartitionFunctionArguments(Page page) + { + // Fast path for no constants + if (partitionConstantBlocks == null) { + return page.getColumns(partitionChannels); + } - void flush(boolean force); + Block[] blocks = new Block[partitionChannels.length]; + for (int i = 0; i < blocks.length; i++) { + int channel = partitionChannels[i]; + if (channel < 0) { + blocks[i] = new RunLengthEncodedBlock(partitionConstantBlocks[i], page.getPositionCount()); + } + else { + blocks[i] = page.getBlock(channel); + } + } + return new Page(page.getPositionCount(), blocks); + } - ListenableFuture isFull(); + public void flush(boolean force) + { + try (PagesSerde.PagesSerdeContext context = serde.newContext()) { + // add all full pages to output buffer + for (int partition = 0; partition < pageBuilders.length; partition++) { + PageBuilder partitionPageBuilder = pageBuilders[partition]; + if (!partitionPageBuilder.isEmpty() && (force || partitionPageBuilder.isFull())) { + Page pagePartition = partitionPageBuilder.build(); + partitionPageBuilder.reset(); - long getSizeInBytes(); + operatorContext.recordOutput(pagePartition.getSizeInBytes(), pagePartition.getPositionCount()); - long getRetainedSizeInBytes(); + outputBuffer.enqueue(partition, splitAndSerializePage(context, pagePartition)); + pagesAdded.incrementAndGet(); + rowsAdded.addAndGet(pagePartition.getPositionCount()); + } + } + } + } - Supplier getOperatorInfoSupplier(); + private List splitAndSerializePage(PagesSerde.PagesSerdeContext context, Page page) + { + List split = splitPage(page, DEFAULT_MAX_PAGE_SIZE_IN_BYTES); + ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(split.size()); + for (Page chunk : split) { + builder.add(serde.serialize(context, chunk)); + } + return builder.build(); + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerFactory.java b/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerFactory.java deleted file mode 100644 index ed4036d92578..000000000000 --- a/core/trino-main/src/main/java/io/trino/operator/output/PagePartitionerFactory.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.operator.output; - -import io.airlift.units.DataSize; -import io.trino.execution.buffer.OutputBuffer; -import io.trino.execution.buffer.PagesSerdeFactory; -import io.trino.operator.OperatorContext; -import io.trino.operator.PartitionFunction; -import io.trino.spi.predicate.NullableValue; -import io.trino.spi.type.Type; - -import java.util.List; -import java.util.Optional; -import java.util.OptionalInt; - -public interface PagePartitionerFactory -{ - PagePartitioner create( - PartitionFunction partitionFunction, - List partitionChannels, - List> partitionConstants, - boolean replicatesAnyRow, - OptionalInt nullChannel, - OutputBuffer outputBuffer, - PagesSerdeFactory serdeFactory, - List sourceTypes, - DataSize maxMemory, - OperatorContext operatorContext); -} diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java b/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java index 43f5d4ddc649..355340b8c71c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/PartitionedOutputOperator.java @@ -54,26 +54,7 @@ public static class PartitionedOutputFactory private final boolean replicatesAnyRow; private final OptionalInt nullChannel; private final DataSize maxMemory; - private final PagePartitionerFactory pagePartitionerFactory; - - public PartitionedOutputFactory( - PartitionFunction partitionFunction, - List partitionChannels, - List> partitionConstants, - boolean replicatesAnyRow, - OptionalInt nullChannel, - OutputBuffer outputBuffer, - DataSize maxMemory) - { - this(partitionFunction, - partitionChannels, - partitionConstants, - replicatesAnyRow, - nullChannel, - outputBuffer, - maxMemory, - DefaultPagePartitioner::new); - } + private final PositionsAppenderFactory positionsAppenderFactory; public PartitionedOutputFactory( PartitionFunction partitionFunction, @@ -83,7 +64,7 @@ public PartitionedOutputFactory( OptionalInt nullChannel, OutputBuffer outputBuffer, DataSize maxMemory, - PagePartitionerFactory pagePartitionerFactory) + PositionsAppenderFactory positionsAppenderFactory) { this.partitionFunction = requireNonNull(partitionFunction, "partitionFunction is null"); this.partitionChannels = requireNonNull(partitionChannels, "partitionChannels is null"); @@ -92,7 +73,7 @@ public PartitionedOutputFactory( this.nullChannel = requireNonNull(nullChannel, "nullChannel is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.maxMemory = requireNonNull(maxMemory, "maxMemory is null"); - this.pagePartitionerFactory = requireNonNull(pagePartitionerFactory, "pagePartitionerFactory is null"); + this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); } @Override @@ -116,7 +97,7 @@ public OperatorFactory createOutputOperator( outputBuffer, serdeFactory, maxMemory, - pagePartitionerFactory); + positionsAppenderFactory); } } @@ -135,7 +116,7 @@ public static class PartitionedOutputOperatorFactory private final OutputBuffer outputBuffer; private final PagesSerdeFactory serdeFactory; private final DataSize maxMemory; - private final PagePartitionerFactory pagePartitionerFactory; + private final PositionsAppenderFactory positionsAppenderFactory; public PartitionedOutputOperatorFactory( int operatorId, @@ -150,7 +131,7 @@ public PartitionedOutputOperatorFactory( OutputBuffer outputBuffer, PagesSerdeFactory serdeFactory, DataSize maxMemory, - PagePartitionerFactory pagePartitionerFactory) + PositionsAppenderFactory positionsAppenderFactory) { this.operatorId = operatorId; this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); @@ -164,7 +145,7 @@ public PartitionedOutputOperatorFactory( this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); this.serdeFactory = requireNonNull(serdeFactory, "serdeFactory is null"); this.maxMemory = requireNonNull(maxMemory, "maxMemory is null"); - this.pagePartitionerFactory = requireNonNull(pagePartitionerFactory, "pagePartitionerFactory is null"); + this.positionsAppenderFactory = requireNonNull(positionsAppenderFactory, "positionsAppenderFactory is null"); } @Override @@ -183,7 +164,7 @@ public Operator createOperator(DriverContext driverContext) outputBuffer, serdeFactory, maxMemory, - pagePartitionerFactory); + positionsAppenderFactory); } @Override @@ -207,7 +188,7 @@ public OperatorFactory duplicate() outputBuffer, serdeFactory, maxMemory, - pagePartitionerFactory); + positionsAppenderFactory); } } @@ -231,11 +212,11 @@ public PartitionedOutputOperator( OutputBuffer outputBuffer, PagesSerdeFactory serdeFactory, DataSize maxMemory, - PagePartitionerFactory pagePartitionerFactory) + PositionsAppenderFactory positionsAppenderFactory) { this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.pagePreprocessor = requireNonNull(pagePreprocessor, "pagePreprocessor is null"); - this.partitionFunction = pagePartitionerFactory.create( + this.partitionFunction = new PagePartitioner( partitionFunction, partitionChannels, partitionConstants, @@ -245,7 +226,8 @@ public PartitionedOutputOperator( serdeFactory, sourceTypes, maxMemory, - operatorContext); + operatorContext, + positionsAppenderFactory); operatorContext.setInfoSupplier(this.partitionFunction.getOperatorInfoSupplier()); this.memoryContext = operatorContext.newLocalUserMemoryContext(PartitionedOutputOperator.class.getSimpleName()); diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java new file mode 100644 index 000000000000..0b5aa8414f8b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppender.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.trino.operator.output; + +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.Type; +import it.unimi.dsi.fastutil.ints.IntArrayList; + +import static java.util.Objects.requireNonNull; + +public interface PositionsAppender +{ + void appendTo(IntArrayList positions, Block source, BlockBuilder target); + + class TypedPositionsAppender + implements PositionsAppender + { + private final Type type; + + public TypedPositionsAppender(Type type) + { + this.type = requireNonNull(type, "type is null"); + } + + @Override + public void appendTo(IntArrayList positions, Block source, BlockBuilder target) + { + int[] positionArray = positions.elements(); + for (int i = 0; i < positions.size(); i++) { + type.appendTo(source, positionArray[i], target); + } + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java new file mode 100644 index 000000000000..41021d98b70e --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/output/PositionsAppenderFactory.java @@ -0,0 +1,353 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.output; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import io.airlift.bytecode.DynamicClassLoader; +import io.trino.collect.cache.NonEvictableLoadingCache; +import io.trino.operator.output.PositionsAppender.TypedPositionsAppender; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.Int128ArrayBlock; +import io.trino.spi.block.Int96ArrayBlock; +import io.trino.spi.type.FixedWidthType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VariableWidthType; +import io.trino.sql.gen.IsolatedClass; +import it.unimi.dsi.fastutil.ints.IntArrayList; + +import java.util.Objects; +import java.util.Optional; + +import static io.airlift.slice.SizeOf.SIZE_OF_LONG; +import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache; +import static java.util.Objects.requireNonNull; + +/** + * Isolates the {@code PositionsAppender} class per type and block tuples. + * Type specific {@code PositionsAppender} implementations manually inline {@code Type#appendTo} method inside the loop + * to avoid virtual(mega-morphic) calls and force jit to inline the {@code Block} and {@code BlockBuilder} methods. + * Ideally, {@code TypedPositionsAppender} could work instead of type specific {@code PositionsAppender}s, + * but in practice jit falls back to virtual calls in some cases (e.g. {@link Block#isNull}). + */ +public class PositionsAppenderFactory +{ + private final NonEvictableLoadingCache cache; + + public PositionsAppenderFactory() + { + this.cache = buildNonEvictableCache( + CacheBuilder.newBuilder().maximumSize(1000), + CacheLoader.from(key -> createAppender(key.type))); + } + + public PositionsAppender create(Type type, Class blockClass) + { + return cache.getUnchecked(new CacheKey(type, blockClass)); + } + + private PositionsAppender createAppender(Type type) + { + return Optional.ofNullable(findDedicatedAppenderClassFor(type)) + .map(this::isolateAppender) + .orElseGet(() -> isolateTypeAppender(type)); + } + + private Class findDedicatedAppenderClassFor(Type type) + { + if (type instanceof FixedWidthType) { + switch (((FixedWidthType) type).getFixedSize()) { + case Byte.BYTES: + return BytePositionsAppender.class; + case Short.BYTES: + return SmallintPositionsAppender.class; + case Integer.BYTES: + return IntPositionsAppender.class; + case Long.BYTES: + return LongPositionsAppender.class; + case Int96ArrayBlock.INT96_BYTES: + return Int96PositionsAppender.class; + case Int128ArrayBlock.INT128_BYTES: + return Int128PositionsAppender.class; + default: + // size not supported directly, fallback to the generic appender + } + } + else if (type instanceof VariableWidthType) { + return SlicePositionsAppender.class; + } + + return null; + } + + private PositionsAppender isolateTypeAppender(Type type) + { + Class isolatedAppenderClass = isolateAppenderClass(TypedPositionsAppender.class); + try { + return isolatedAppenderClass.getConstructor(Type.class).newInstance(type); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + + private PositionsAppender isolateAppender(Class appenderClass) + { + Class isolatedAppenderClass = isolateAppenderClass(appenderClass); + try { + return isolatedAppenderClass.getConstructor().newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + + private Class isolateAppenderClass(Class appenderClass) + { + DynamicClassLoader dynamicClassLoader = new DynamicClassLoader(PositionsAppender.class.getClassLoader()); + + Class isolatedBatchPositionsTransferClass = IsolatedClass.isolateClass( + dynamicClassLoader, + PositionsAppender.class, + appenderClass); + return isolatedBatchPositionsTransferClass; + } + + public static class LongPositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(block.getLong(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeLong(block.getLong(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class IntPositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeInt(block.getInt(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeInt(block.getInt(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class BytePositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeByte(block.getByte(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeByte(block.getByte(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class SlicePositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + block.writeBytesTo(position, 0, block.getSliceLength(position), blockBuilder); + blockBuilder.closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + block.writeBytesTo(position, 0, block.getSliceLength(position), blockBuilder); + blockBuilder.closeEntry(); + } + } + } + } + + public static class SmallintPositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeShort(block.getShort(position, 0)).closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + blockBuilder.writeShort(block.getShort(positionArray[i], 0)).closeEntry(); + } + } + } + } + + public static class Int96PositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeInt(block.getInt(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeInt(block.getInt(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + } + + public static class Int128PositionsAppender + implements PositionsAppender + { + @Override + public void appendTo(IntArrayList positions, Block block, BlockBuilder blockBuilder) + { + int[] positionArray = positions.elements(); + if (block.mayHaveNull()) { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + if (block.isNull(position)) { + blockBuilder.appendNull(); + } + else { + blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeLong(block.getLong(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + else { + for (int i = 0; i < positions.size(); i++) { + int position = positionArray[i]; + blockBuilder.writeLong(block.getLong(position, 0)); + blockBuilder.writeLong(block.getLong(position, SIZE_OF_LONG)); + blockBuilder.closeEntry(); + } + } + } + } + + private static class CacheKey + { + private final Type type; + private final Class blockClass; + + private CacheKey(Type type, Class blockClass) + { + this.type = requireNonNull(type, "type is null"); + this.blockClass = requireNonNull(blockClass, "blockClass is null"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CacheKey cacheKey = (CacheKey) o; + return type.equals(cacheKey.type) && blockClass.equals(cacheKey.blockClass); + } + + @Override + public int hashCode() + { + return Objects.hash(type, blockClass); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 9ffb9754f566..859a1ff2dcfd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -122,6 +122,8 @@ import io.trino.operator.join.NestedLoopJoinBridge; import io.trino.operator.join.NestedLoopJoinPagesSupplier; import io.trino.operator.join.PartitionedLookupSourceFactory; +import io.trino.operator.output.PartitionedOutputOperator.PartitionedOutputFactory; +import io.trino.operator.output.PositionsAppenderFactory; import io.trino.operator.output.TaskOutputOperator.TaskOutputFactory; import io.trino.operator.project.CursorProcessor; import io.trino.operator.project.PageProcessor; @@ -385,6 +387,7 @@ public class LocalExecutionPlanner private final BlockTypeOperators blockTypeOperators; private final TableExecuteContextManager tableExecuteContextManager; private final ExchangeManagerRegistry exchangeManagerRegistry; + private final PositionsAppenderFactory positionsAppenderFactory = new PositionsAppenderFactory(); @Inject public LocalExecutionPlanner( @@ -515,15 +518,15 @@ public LocalExecutionPlan plan( outputLayout, types, partitionedSourceOrder, - operatorFactories.partitionedOutput( - taskContext, + new PartitionedOutputFactory( partitionFunction, partitionChannels, partitionConstants, partitioningScheme.isReplicateNullsAndAny(), nullChannel, outputBuffer, - maxPagePartitioningBufferSize)); + maxPagePartitioningBufferSize, + positionsAppenderFactory)); } public LocalExecutionPlan plan( diff --git a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java index ea83827a88ab..fc6b4423efa5 100644 --- a/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java +++ b/core/trino-main/src/test/java/io/trino/block/BlockAssertions.java @@ -26,8 +26,11 @@ import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Int128; +import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; import java.math.BigDecimal; @@ -68,6 +71,7 @@ import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.testing.TestingConnectorSession.SESSION; import static io.trino.type.ColorType.COLOR; +import static io.trino.type.IpAddressType.IPADDRESS; import static java.lang.Float.floatToRawIntBits; import static java.lang.Math.multiplyExact; import static java.lang.String.format; @@ -174,9 +178,19 @@ public static Block createRandomBlockForType(Type type, int positionCount, float if (type == UUID) { return createRandomUUIDsBlock(positionCount, nullRate); } + if (type == IPADDRESS) { + return createRandomIpAddressesBlock(positionCount, nullRate); + } if (type == VARBINARY) { return createRandomVarbinariesBlock(positionCount, nullRate); } + if (type instanceof TimestampType) { + TimestampType timestampType = (TimestampType) type; + if (timestampType.isShort()) { + return createRandomShortTimestampBlock(timestampType, positionCount, nullRate); + } + return createRandomLongTimestampBlock(timestampType, positionCount, nullRate); + } return createRandomBlockForNestedType(type, positionCount, nullRate); } @@ -248,6 +262,28 @@ public static Block createRandomLongDecimalsBlock(int positionCount, float nullR () -> String.valueOf(RANDOM.nextLong()))); } + public static Block createRandomShortTimestampBlock(TimestampType type, int positionCount, float nullRate) + { + return createLongsBlock( + generateListWithNulls( + positionCount, + nullRate, + () -> SqlTimestamp.fromMillis(type.getPrecision(), RANDOM.nextLong()).getEpochMicros())); + } + + public static Block createRandomLongTimestampBlock(TimestampType type, int positionCount, float nullRate) + { + return createLongTimestampBlock( + type, + generateListWithNulls( + positionCount, + nullRate, + () -> { + SqlTimestamp sqlTimestamp = SqlTimestamp.fromMillis(type.getPrecision(), RANDOM.nextLong()); + return new LongTimestamp(sqlTimestamp.getEpochMicros(), sqlTimestamp.getPicosOfMicros()); + })); + } + public static Block createRandomLongsBlock(int positionCount, int numberOfUniqueValues) { checkArgument(positionCount >= numberOfUniqueValues, "numberOfUniqueValues must be between 1 and positionCount: %s but was %s", positionCount, numberOfUniqueValues); @@ -288,6 +324,11 @@ private static Block createRandomUUIDsBlock(int positionCount, float nullRate) return createSlicesBlock(UUID, generateListWithNulls(positionCount, nullRate, () -> Slices.wrappedLongArray(RANDOM.nextLong(), RANDOM.nextLong()))); } + private static Block createRandomIpAddressesBlock(int positionCount, float nullRate) + { + return createSlicesBlock(IPADDRESS, generateListWithNulls(positionCount, nullRate, () -> Slices.wrappedLongArray(RANDOM.nextLong(), RANDOM.nextLong()))); + } + private static Block createRandomTinyintsBlock(int positionCount, float nullRate) { return createTypedLongsBlock(TINYINT, generateListWithNulls(positionCount, nullRate, () -> (long) (byte) RANDOM.nextLong())); @@ -473,6 +514,22 @@ public static Block createLongDecimalsBlock(Iterable values) return builder.build(); } + public static Block createLongTimestampBlock(TimestampType type, Iterable values) + { + BlockBuilder builder = type.createBlockBuilder(null, 100); + + for (LongTimestamp value : values) { + if (value == null) { + builder.appendNull(); + } + else { + type.writeObject(builder, value); + } + } + + return builder.build(); + } + public static Block createCharsBlock(CharType charType, List values) { return createBlock(charType, charType::writeString, values); diff --git a/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java b/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java index d8385509eae2..74c815face77 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.units.DataSize; -import io.trino.Session; import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.execution.buffer.OutputBufferStateMachine; @@ -28,12 +27,9 @@ import io.trino.memory.context.SimpleLocalMemoryContext; import io.trino.operator.BucketPartitionFunction; import io.trino.operator.DriverContext; -import io.trino.operator.OperatorFactories; -import io.trino.operator.OutputFactory; import io.trino.operator.PartitionFunction; import io.trino.operator.PrecomputedHashGenerator; -import io.trino.operator.TaskContext; -import io.trino.operator.TrinoOperatorFactories; +import io.trino.operator.output.PartitionedOutputOperator.PartitionedOutputFactory; import io.trino.spi.Page; import io.trino.spi.QueryId; import io.trino.spi.block.Block; @@ -82,6 +78,7 @@ import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.units.DataSize.Unit.BYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.block.BlockAssertions.createLongDictionaryBlock; import static io.trino.block.BlockAssertions.createLongsBlock; import static io.trino.block.BlockAssertions.createRLEBlock; @@ -97,7 +94,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.Decimals.MAX_SHORT_PRECISION; -import static io.trino.testing.TestingSession.testSessionBuilder; import static java.util.Collections.nCopies; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; @@ -113,7 +109,7 @@ @BenchmarkMode(Mode.AverageTime) public class BenchmarkPartitionedOutputOperator { - private static final OperatorFactories OPERATOR_FACTORIES = new TrinoOperatorFactories(); + private static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(); @Benchmark public void addPage(BenchmarkData data) @@ -142,9 +138,6 @@ public static class BenchmarkData private static final ExecutorService EXECUTOR = newCachedThreadPool(daemonThreadsNamed("BenchmarkPartitionedOutputOperator-executor-%s")); private static final ScheduledExecutorService SCHEDULER = newScheduledThreadPool(1, daemonThreadsNamed("BenchmarkPartitionedOutputOperator-scheduledExecutor-%s")); - private final OperatorFactories operatorFactories; - private final Session session; - @Param({"2", "16", "256"}) private int partitionCount = 256; @@ -167,6 +160,7 @@ public static class BenchmarkData "BIGINT_DICTIONARY_PARTITION_CHANNEL_50_PERCENT", "BIGINT_DICTIONARY_PARTITION_CHANNEL_80_PERCENT", "BIGINT_DICTIONARY_PARTITION_CHANNEL_100_PERCENT", + "BIGINT_DICTIONARY_PARTITION_CHANNEL_100_PERCENT_MINUS_1", "RLE_PARTITION_BIGINT", "RLE_PARTITION_NULL_BIGINT", "LONG_DECIMAL", @@ -279,6 +273,17 @@ public Page createPage(List types, int positionCount, float nullRate) createLongDictionaryBlock(0, positionCount, positionCount)); } }, + BIGINT_DICTIONARY_PARTITION_CHANNEL_100_PERCENT_MINUS_1(BigintType.BIGINT, 3000) { + @Override + public Page createPage(List types, int positionCount, float nullRate) + { + return page( + positionCount, + types.size(), + () -> createRandomBlockForType(BigintType.BIGINT, positionCount, nullRate), + createLongDictionaryBlock(0, positionCount, positionCount - 1)); + } + }, RLE_PARTITION_BIGINT(BigintType.BIGINT, 5000) { @Override public Page createPage(List types, int positionCount, float nullRate) @@ -350,17 +355,6 @@ public List getTypes(int channelCount) } } - public BenchmarkData() - { - this(OPERATOR_FACTORIES, testSessionBuilder().build()); - } - - protected BenchmarkData(OperatorFactories operatorFactories, Session session) - { - this.operatorFactories = requireNonNull(operatorFactories, "operatorFactories is null"); - this.session = requireNonNull(session, "session is null"); - } - public int getPageCount() { return pageCount; @@ -383,6 +377,12 @@ public Page getDataPage() @Setup public void setup(Blackhole blackhole) + { + setupData(blackhole); + pollute(); + } + + private void setupData(Blackhole blackhole) { // We don't check blackhole is not null, because blackhole has to be injected by jmh (should not be created manually) // and in case of unit test it will be null @@ -429,34 +429,29 @@ private PartitionedOutputOperator createPartitionedOutputOperator() PagesSerdeFactory serdeFactory = new PagesSerdeFactory(new TestingBlockEncodingSerde(), enableCompression); PartitionedOutputBuffer buffer = createPartitionedOutputBuffer(); - TaskContext taskContext = createTaskContext(); - OutputFactory operatorFactory = operatorFactories.partitionedOutput( - taskContext, + PartitionedOutputFactory operatorFactory = new PartitionedOutputFactory( partitionFunction, ImmutableList.of(types.size() - 1), // hash block is at the last channel ImmutableList.of(Optional.empty()), false, - nullChannel, + OptionalInt.empty(), buffer, - MAX_PARTITION_BUFFER_SIZE); + MAX_PARTITION_BUFFER_SIZE, + POSITIONS_APPENDER_FACTORY); return (PartitionedOutputOperator) operatorFactory .createOutputOperator(0, new PlanNodeId("plan-node-0"), types, Function.identity(), serdeFactory) - .createOperator(createDriverContext(taskContext)); + .createOperator(createDriverContext()); } - private DriverContext createDriverContext(TaskContext taskContext) + private DriverContext createDriverContext() { - return taskContext + return TestingTaskContext.builder(EXECUTOR, SCHEDULER, TEST_SESSION) + .build() .addPipelineContext(0, true, true, false) .addDriverContext(); } - private TaskContext createTaskContext() - { - return TestingTaskContext.builder(EXECUTOR, SCHEDULER, session).build(); - } - private TestingPartitionedOutputBuffer createPartitionedBuffer(OutputBuffers buffers, DataSize dataSize) { return new TestingPartitionedOutputBuffer( @@ -519,7 +514,8 @@ private static MapType createMapType(Type keyType, Type valueType) new TypeOperators()); } - static { + private static void pollute() + { try { List types = List.of( TestType.BIGINT, @@ -536,7 +532,7 @@ private static MapType createMapType(Type keyType, Type valueType) types.forEach(type -> { BenchmarkData data = new BenchmarkData(); data.setType(type); - data.setup(null); + data.setupData(null); data.setPageCount(1); benchmark.addPage(data); }); diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java index 0e66e10aebf7..b5d591b38f32 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java @@ -19,7 +19,6 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.airlift.units.DataSize; -import io.trino.Session; import io.trino.execution.StateMachine; import io.trino.execution.buffer.BufferResult; import io.trino.execution.buffer.BufferState; @@ -31,11 +30,8 @@ import io.trino.operator.BucketPartitionFunction; import io.trino.operator.DriverContext; import io.trino.operator.OperatorContext; -import io.trino.operator.OperatorFactories; import io.trino.operator.OutputFactory; import io.trino.operator.PartitionFunction; -import io.trino.operator.TaskContext; -import io.trino.operator.TrinoOperatorFactories; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; @@ -44,6 +40,7 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Decimals; +import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.TestingTaskContext; @@ -67,6 +64,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.block.BlockAssertions.createLongDictionaryBlock; import static io.trino.block.BlockAssertions.createLongSequenceBlock; import static io.trino.block.BlockAssertions.createLongsBlock; @@ -85,7 +83,7 @@ import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; import static io.trino.sql.planner.SystemPartitioningHandle.SystemPartitionFunction.ROUND_ROBIN; -import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.type.IpAddressType.IPADDRESS; import static java.lang.Math.toIntExact; import static java.util.Collections.nCopies; import static java.util.Collections.unmodifiableList; @@ -98,8 +96,6 @@ @Test(singleThreaded = true) public class TestPartitionedOutputOperator { - private static final OperatorFactories TRINO_OPERATOR_FACTORIES = new TrinoOperatorFactories(); - private static final Session TEST_SESSION = testSessionBuilder().build(); private static final DataSize MAX_MEMORY = DataSize.of(50, MEGABYTE); private static final DataSize PARTITION_MAX_MEMORY = DataSize.of(5, MEGABYTE); @@ -109,24 +105,10 @@ public class TestPartitionedOutputOperator private static final PagesSerdeFactory PAGES_SERDE_FACTORY = new PagesSerdeFactory(new TestingBlockEncodingSerde(), false); private static final PagesSerde PAGES_SERDE = PAGES_SERDE_FACTORY.createPagesSerde(); - private final Session testSession; - private final OperatorFactories operatorFactories; - private ExecutorService executor; private ScheduledExecutorService scheduledExecutor; private TestOutputBuffer outputBuffer; - public TestPartitionedOutputOperator() - { - this(TEST_SESSION, TRINO_OPERATOR_FACTORIES); - } - - protected TestPartitionedOutputOperator(Session testSession, OperatorFactories operatorFactories) - { - this.testSession = testSession; - this.operatorFactories = operatorFactories; - } - @BeforeClass public void setUpClass() { @@ -392,7 +374,10 @@ public static Object[][] types() {VARBINARY}, {createDecimalType(1)}, {createDecimalType(Decimals.MAX_SHORT_PRECISION + 1)}, - {new ArrayType(BIGINT)} + {new ArrayType(BIGINT)}, + {TimestampType.createTimestampType(9)}, + {TimestampType.createTimestampType(3)}, + {IPADDRESS} }; } @@ -444,13 +429,12 @@ private PartitionedOutputOperatorBuilder partitionedOutputOperator(List ty private PartitionedOutputOperatorBuilder partitionedOutputOperator() { - return new PartitionedOutputOperatorBuilder(operatorFactories, testSession, executor, scheduledExecutor, outputBuffer); + return new PartitionedOutputOperatorBuilder(executor, scheduledExecutor, outputBuffer); } static class PartitionedOutputOperatorBuilder { - private final OperatorFactories operatorFactories; - private final Session testSession; + public static final PositionsAppenderFactory POSITIONS_APPENDER_FACTORY = new PositionsAppenderFactory(); private final ExecutorService executor; private final ScheduledExecutorService scheduledExecutor; private final OutputBuffer outputBuffer; @@ -462,10 +446,8 @@ static class PartitionedOutputOperatorBuilder private OptionalInt nullChannel = OptionalInt.empty(); private List types; - PartitionedOutputOperatorBuilder(OperatorFactories operatorFactories, Session testSession, ExecutorService executor, ScheduledExecutorService scheduledExecutor, OutputBuffer outputBuffer) + PartitionedOutputOperatorBuilder(ExecutorService executor, ScheduledExecutorService scheduledExecutor, OutputBuffer outputBuffer) { - this.operatorFactories = requireNonNull(operatorFactories, "operatorFactories is null"); - this.testSession = requireNonNull(testSession, "testSession is null"); this.executor = requireNonNull(executor, "executor is null"); this.scheduledExecutor = requireNonNull(scheduledExecutor, "scheduledExecutor is null"); this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null"); @@ -529,10 +511,9 @@ public PartitionedOutputOperatorBuilder withTypes(List types) public PartitionedOutputOperator build() { - TaskContext taskContext = TestingTaskContext.builder(executor, scheduledExecutor, testSession) + DriverContext driverContext = TestingTaskContext.builder(executor, scheduledExecutor, TEST_SESSION) .setMemoryPoolSize(MAX_MEMORY) - .build(); - DriverContext driverContext = taskContext + .build() .addPipelineContext(0, true, true, false) .addDriverContext(); @@ -541,15 +522,15 @@ public PartitionedOutputOperator build() buffers = buffers.withBuffer(new OutputBuffers.OutputBufferId(partition), partition); } - OutputFactory operatorFactory = operatorFactories.partitionedOutput( - taskContext, + OutputFactory operatorFactory = new PartitionedOutputOperator.PartitionedOutputFactory( partitionFunction, partitionChannels, partitionConstants, shouldReplicate, nullChannel, outputBuffer, - PARTITION_MAX_MEMORY); + PARTITION_MAX_MEMORY, + POSITIONS_APPENDER_FACTORY); return (PartitionedOutputOperator) operatorFactory .createOutputOperator(0, new PlanNodeId("plan-node-0"), types, Function.identity(), PAGES_SERDE_FACTORY)