diff --git a/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java b/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java
index a28449bfe5ad..e405110d18be 100644
--- a/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java
+++ b/core/trino-main/src/main/java/io/trino/connector/informationschema/InformationSchemaPageSource.java
@@ -27,6 +27,7 @@
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.RelationType;
 import io.trino.spi.connector.SchemaTableName;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.security.AccessDeniedException;
 import io.trino.spi.security.GrantInfo;
 import io.trino.spi.security.RoleGrant;
@@ -168,7 +169,7 @@ public boolean isFinished()
     }
 
     @Override
-    public Page getNextPage()
+    public SourcePage getNextSourcePage()
     {
         if (isFinished()) {
             return null;
@@ -187,7 +188,7 @@ public Page getNextPage()
         memoryUsageBytes -= page.getRetainedSizeInBytes();
         Page outputPage = projection.apply(page);
         completedBytes += outputPage.getSizeInBytes();
-        return outputPage;
+        return SourcePage.create(outputPage);
     }
 
     @Override
diff --git a/core/trino-main/src/main/java/io/trino/operator/FilterAndProjectOperator.java b/core/trino-main/src/main/java/io/trino/operator/FilterAndProjectOperator.java
index 975e1380673f..bbc9e88b0d49 100644
--- a/core/trino-main/src/main/java/io/trino/operator/FilterAndProjectOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/FilterAndProjectOperator.java
@@ -23,6 +23,7 @@
 import io.trino.operator.project.PageProcessorMetrics;
 import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.metrics.Metrics;
 import io.trino.spi.type.Type;
 import io.trino.sql.planner.plan.PlanNodeId;
@@ -62,7 +63,7 @@ private FilterAndProjectOperator(
                         yieldSignal,
                         outputMemoryContext,
                         metrics,
-                        page))
+                        SourcePage.create(page)))
                 .transformProcessor(processor -> mergePages(types, minOutputPageSize.toBytes(), minOutputPageRowCount, processor, localAggregatedMemoryContext))
                 .blocking(() -> memoryTrackingContext.localUserMemoryContext().setBytes(localAggregatedMemoryContext.getBytes()));
     }
diff --git a/core/trino-main/src/main/java/io/trino/operator/PageSourceOperator.java b/core/trino-main/src/main/java/io/trino/operator/PageSourceOperator.java
index bfb99131586c..a3747027f615 100644
--- a/core/trino-main/src/main/java/io/trino/operator/PageSourceOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/PageSourceOperator.java
@@ -17,6 +17,7 @@
 import com.google.common.util.concurrent.ListenableFuture;
 import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -90,11 +91,13 @@ public void addInput(Page page)
     @Override
     public Page getOutput()
     {
-        Page page = pageSource.getNextPage();
-        if (page == null) {
+        SourcePage sourcePage = pageSource.getNextSourcePage();
+        if (sourcePage == null) {
             return null;
         }
 
+        Page page = sourcePage.getPage();
+
         // update operator stats
         long endCompletedBytes = pageSource.getCompletedBytes();
         long endReadTimeNanos = pageSource.getReadTimeNanos();
@@ -103,9 +106,6 @@ public Page getOutput()
         completedBytes = endCompletedBytes;
         readTimeNanos = endReadTimeNanos;
 
-        // assure the page is in memory before handing to another operator
-        page = page.getLoadedPage();
-
         return page;
     }
 
diff --git a/core/trino-main/src/main/java/io/trino/operator/PageUtils.java b/core/trino-main/src/main/java/io/trino/operator/PageUtils.java
deleted file mode 100644
index 82c07319173d..000000000000
--- a/core/trino-main/src/main/java/io/trino/operator/PageUtils.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.operator;
-
-import io.trino.spi.Page;
-import io.trino.spi.block.Block;
-
-import java.util.function.LongConsumer;
-
-import static io.trino.spi.block.LazyBlock.listenForLoads;
-
-public final class PageUtils
-{
-    private PageUtils() {}
-
-    public static void recordMaterializedBytes(Page page, LongConsumer sizeInBytesConsumer)
-    {
-        // account processed bytes from lazy blocks only when they are loaded
-        long loadedBlocksSizeInBytes = 0;
-
-        for (int i = 0; i < page.getChannelCount(); ++i) {
-            Block block = page.getBlock(i);
-            long initialSize = block.getSizeInBytes();
-            loadedBlocksSizeInBytes += initialSize;
-            listenForLoads(block, loadedBlock -> sizeInBytesConsumer.accept(loadedBlock.getSizeInBytes()));
-        }
-
-        if (loadedBlocksSizeInBytes > 0) {
-            sizeInBytesConsumer.accept(loadedBlocksSizeInBytes);
-        }
-    }
-}
diff --git a/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java b/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java
index ba154d044e0b..0e498c926885 100644
--- a/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/ScanFilterAndProjectOperator.java
@@ -39,6 +39,7 @@
 import io.trino.spi.connector.EmptyPageSource;
 import io.trino.spi.connector.RecordCursor;
 import io.trino.spi.connector.RecordPageSource;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.metrics.Metrics;
 import io.trino.spi.type.Type;
 import io.trino.split.EmptySplit;
@@ -51,14 +52,15 @@
 import java.io.UncheckedIOException;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.LongConsumer;
 import java.util.function.Supplier;
 
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 import static io.airlift.concurrent.MoreFutures.toListenableFuture;
 import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
-import static io.trino.operator.PageUtils.recordMaterializedBytes;
 import static io.trino.operator.WorkProcessor.TransformationState.finished;
 import static io.trino.operator.WorkProcessor.TransformationState.ofResult;
 import static io.trino.operator.project.MergePages.mergePages;
@@ -283,12 +285,16 @@ WorkProcessor<Page> processPageSource()
         return WorkProcessor
                 .create(new ConnectorPageSourceToPages(pageSourceMemoryContext))
                 .yielding(yieldSignal::isSet)
-                .flatMap(page -> pageProcessor.createWorkProcessor(
-                        connectorSession,
-                        yieldSignal,
-                        outputMemoryContext,
-                        pageProcessorMetrics,
-                        page))
+                .flatMap(page -> {
+                    WorkProcessor<Page> workProcessor = pageProcessor.createWorkProcessor(
+                            connectorSession,
+                            yieldSignal,
+                            outputMemoryContext,
+                            pageProcessorMetrics,
+                            page);
+                    // Note this is monitoring the original source page, not the result page
+                    return workProcessor.withProcessStateMonitor(new ProcessedBytesMonitor(page, bytes -> processedBytes += bytes));
+                })
                 .transformProcessor(processor -> mergePages(types, minOutputPageSize.toBytes(), minOutputPageRowCount, processor, localAggregatedMemoryContext))
                 .blocking(() -> memoryContext.setBytes(localAggregatedMemoryContext.getBytes()));
     }
@@ -356,8 +362,37 @@ public ProcessState<Page> process()
         }
     }
 
+    static class ProcessedBytesMonitor
+            implements Consumer<ProcessState<Page>>
+    {
+        private final SourcePage page;
+        private final LongConsumer processedBytesConsumer;
+        private long localProcessedBytes;
+
+        public ProcessedBytesMonitor(SourcePage page, LongConsumer processedBytesConsumer)
+        {
+            this.page = requireNonNull(page, "page is null");
+            this.processedBytesConsumer = requireNonNull(processedBytesConsumer, "processedBytesConsumer is null");
+            localProcessedBytes = page.getSizeInBytes();
+            processedBytesConsumer.accept(localProcessedBytes);
+        }
+
+        @Override
+        public void accept(ProcessState<Page> state)
+        {
+            update();
+        }
+
+        void update()
+        {
+            long newProcessedBytes = page.getSizeInBytes();
+            processedBytesConsumer.accept(newProcessedBytes - localProcessedBytes);
+            localProcessedBytes = newProcessedBytes;
+        }
+    }
+
     private class ConnectorPageSourceToPages
-            implements WorkProcessor.Process<Page>
+            implements WorkProcessor.Process<SourcePage>
     {
         final LocalMemoryContext pageSourceMemoryContext;
@@ -367,7 +402,7 @@ private class ConnectorPageSourceToPages
         }
         @Override
-        public ProcessState<Page> process()
+        public ProcessState<SourcePage> process()
        {
             if (pageSource.isFinished()) {
                 return ProcessState.finished();
             }
@@ -378,7 +413,7 @@ public ProcessState<Page> process()
                 return ProcessState.blocked(asVoid(toListenableFuture(isBlocked)));
             }
 
-            Page page = pageSource.getNextPage();
+            SourcePage page = pageSource.getNextSourcePage();
             pageSourceMemoryContext.setBytes(pageSource.getMemoryUsage());
 
             if (page == null) {
@@ -388,8 +423,6 @@ public ProcessState<Page> process()
                 return ProcessState.yielded();
             }
 
-            recordMaterializedBytes(page, sizeInBytes -> processedBytes += sizeInBytes);
-
             // update operator stats
             processedPositions += page.getPositionCount();
             physicalBytes = pageSource.getCompletedBytes();
diff --git a/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java b/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java
index 277a71af1dba..cfe03a05fe02 100644
--- a/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java
+++ b/core/trino-main/src/main/java/io/trino/operator/TableScanOperator.java
@@ -26,6 +26,7 @@
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.DynamicFilter;
 import io.trino.spi.connector.EmptyPageSource;
+import io.trino.spi.connector.SourcePage;
 import io.trino.split.EmptySplit;
 import io.trino.split.PageSourceProvider;
 import io.trino.split.PageSourceProviderFactory;
@@ -265,8 +266,10 @@ public Page getOutput()
             source = pageSourceProvider.createPageSource(operatorContext.getSession(), split, table, columns, dynamicFilter);
         }
 
-        Page page = source.getNextPage();
-        if (page != null) {
+        SourcePage sourcePage = source.getNextSourcePage();
+        Page page = null;
+        if (sourcePage != null) {
+            page = sourcePage.getPage();
             // assure the page is in memory before handing to another operator
             page = page.getLoadedPage();
 
diff --git a/core/trino-main/src/main/java/io/trino/operator/index/TuplePageFilter.java b/core/trino-main/src/main/java/io/trino/operator/index/TuplePageFilter.java
index c4a77023ad1e..19f711f4d708 100644
--- a/core/trino-main/src/main/java/io/trino/operator/index/TuplePageFilter.java
+++ b/core/trino-main/src/main/java/io/trino/operator/index/TuplePageFilter.java
@@ -20,6 +20,7 @@
 import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.type.BlockTypeOperators.BlockPositionEqual;
 
 import java.util.List;
@@ -66,7 +67,7 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public SelectedPositions filter(ConnectorSession session, Page page)
+    public SelectedPositions filter(ConnectorSession session, SourcePage page)
     {
         if (selectedPositions.length < page.getPositionCount()) {
             selectedPositions = new boolean[page.getPositionCount()];
@@ -79,7 +80,7 @@ public SelectedPositions filter(ConnectorSession session, Page page)
         return PageFilter.positionsArrayToSelectedPositions(selectedPositions, page.getPositionCount());
     }
 
-    private boolean matches(Page page, int position)
+    private boolean matches(SourcePage page, int position)
     {
         for (int channel = 0; channel < inputChannels.size(); channel++) {
             BlockPositionEqual equalOperator = equalOperators.get(channel);
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/ConstantPageProjection.java b/core/trino-main/src/main/java/io/trino/operator/project/ConstantPageProjection.java
index 4365f594f76a..d3001d07ece4 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/ConstantPageProjection.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/ConstantPageProjection.java
@@ -17,10 +17,10 @@
 import io.trino.operator.CompletedWork;
 import io.trino.operator.DriverYieldSignal;
 import io.trino.operator.Work;
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.RunLengthEncodedBlock;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 
 import static io.trino.spi.type.TypeUtils.writeNativeValue;
@@ -58,7 +58,7 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions)
+    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions)
     {
         return new CompletedWork<>(RunLengthEncodedBlock.create(value, selectedPositions.size()));
     }
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageFilter.java b/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageFilter.java
index 584083397388..0d540b07d40f 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageFilter.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageFilter.java
@@ -13,11 +13,11 @@
  */
 package io.trino.operator.project;
 
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.DictionaryBlock;
 import io.trino.spi.block.RunLengthEncodedBlock;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 import java.util.Optional;
 
@@ -54,7 +54,7 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public SelectedPositions filter(ConnectorSession session, Page page)
+    public SelectedPositions filter(ConnectorSession session, SourcePage page)
     {
         Block block = page.getBlock(0).getLoadedBlock();
 
@@ -79,7 +79,7 @@ public SelectedPositions filter(ConnectorSession session, Page page)
             }
         }
-        return filter.filter(session, new Page(block));
+        return filter.filter(session, SourcePage.create(block));
     }
 
     private Optional<boolean[]> processDictionary(ConnectorSession session, Block dictionary, int blockPositionsCount)
     {
@@ -99,7 +99,7 @@ private Optional<boolean[]> processDictionary(ConnectorSession session, Block di
 
         if (shouldProcessDictionary) {
             try {
-                SelectedPositions selectedDictionaryPositions = filter.filter(session, new Page(dictionary));
+                SelectedPositions selectedDictionaryPositions = filter.filter(session, SourcePage.create(dictionary));
                 lastOutputDictionary = Optional.of(toPositionsMask(selectedDictionaryPositions, dictionary.getPositionCount()));
             }
             catch (Exception _) {
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java b/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java
index 7e6922bcec4a..4e51a8fb4aab 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/DictionaryAwarePageProjection.java
@@ -16,13 +16,13 @@
 import io.trino.operator.CompletedWork;
 import io.trino.operator.DriverYieldSignal;
 import io.trino.operator.Work;
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.DictionaryBlock;
 import io.trino.spi.block.DictionaryId;
 import io.trino.spi.block.LazyBlock;
 import io.trino.spi.block.RunLengthEncodedBlock;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 import jakarta.annotation.Nullable;
 
@@ -75,9 +75,9 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions)
+    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions)
     {
-        return new DictionaryAwarePageProjectionWork(session, yieldSignal, page, selectedPositions);
+        return new DictionaryAwarePageProjectionWork(session, yieldSignal, page.getBlock(0), selectedPositions);
     }
 
     private class DictionaryAwarePageProjectionWork
@@ -95,10 +95,10 @@ private class DictionaryAwarePageProjectionWork
         // always prepare to fall back to a general block in case the dictionary does not apply or fails
         private Work<Block> fallbackProcessingProjectionWork;
 
-        public DictionaryAwarePageProjectionWork(@Nullable ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions)
+        public DictionaryAwarePageProjectionWork(@Nullable ConnectorSession session, DriverYieldSignal yieldSignal, Block block, SelectedPositions selectedPositions)
         {
             this.session = session;
-            this.block = page.getBlock(0);
+            this.block = block;
             this.selectedPositions = requireNonNull(selectedPositions, "selectedPositions is null");
             this.produceLazyBlock = DictionaryAwarePageProjection.this.produceLazyBlock && !block.isLoaded();
 
@@ -178,7 +178,7 @@ private boolean processInternal()
             // there is no dictionary handling or dictionary handling failed; fall back to general projection
             verify(dictionaryProcessingProjectionWork == null);
             verify(fallbackProcessingProjectionWork == null);
-            fallbackProcessingProjectionWork = projection.project(session, yieldSignal, new Page(block), selectedPositions);
+            fallbackProcessingProjectionWork = projection.project(session, yieldSignal, SourcePage.create(block), selectedPositions);
             if (fallbackProcessingProjectionWork.process()) {
                 result = fallbackProcessingProjectionWork.getResult();
                 return true;
@@ -241,7 +241,7 @@ private Work<Block> createDictionaryBlockProjection(Optional<Block> dictionary,
             lastOutputDictionary = Optional.empty();
 
             if (shouldProcessDictionary) {
-                return projection.project(session, yieldSignal, new Page(lastInputDictionary), SelectedPositions.positionsRange(0, lastInputDictionary.getPositionCount()));
+                return projection.project(session, yieldSignal, SourcePage.create(lastInputDictionary), SelectedPositions.positionsRange(0, lastInputDictionary.getPositionCount()));
             }
             return null;
         }
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java b/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java
index 85bedc543813..6e1b325409b4 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/GeneratedPageProjection.java
@@ -15,10 +15,10 @@
 
 import io.trino.operator.DriverYieldSignal;
 import io.trino.operator.Work;
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 import io.trino.sql.relational.RowExpression;
 
@@ -65,7 +65,7 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions)
+    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions)
     {
         blockBuilder = blockBuilder.newBlockBuilderLike(selectedPositions.size(), null);
         try {
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java b/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java
index 4c4ffcd979e3..a53f582cb6a1 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/InputChannels.java
@@ -13,36 +13,46 @@
  */
 package io.trino.operator.project;
 
-import com.google.common.collect.ImmutableList;
 import com.google.common.primitives.Ints;
 import io.trino.spi.Page;
+import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
+import jakarta.annotation.Nullable;
 
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Set;
+import java.util.function.ObjLongConsumer;
 
 import static com.google.common.base.MoreObjects.toStringHelper;
+import static java.util.Objects.requireNonNull;
 
 public class InputChannels
 {
     private final int[] inputChannels;
-    private final int[] eagerlyLoadedChannels;
+    @Nullable
+    private final boolean[] eagerlyLoad;
 
     public InputChannels(int... inputChannels)
     {
         this.inputChannels = inputChannels.clone();
-        this.eagerlyLoadedChannels = new int[0];
+        this.eagerlyLoad = null;
     }
 
     public InputChannels(List<Integer> inputChannels)
     {
-        this(inputChannels, ImmutableList.of());
+        this.inputChannels = inputChannels.stream().mapToInt(Integer::intValue).toArray();
+        this.eagerlyLoad = null;
     }
 
-    public InputChannels(List<Integer> inputChannels, List<Integer> eagerlyLoadedChannels)
+    public InputChannels(List<Integer> inputChannels, Set<Integer> eagerlyLoadedChannels)
     {
         this.inputChannels = inputChannels.stream().mapToInt(Integer::intValue).toArray();
-        this.eagerlyLoadedChannels = eagerlyLoadedChannels.stream().mapToInt(Integer::intValue).toArray();
+        this.eagerlyLoad = new boolean[this.inputChannels.length];
+        for (int i = 0; i < this.inputChannels.length; i++) {
+            eagerlyLoad[i] = eagerlyLoadedChannels.contains(this.inputChannels[i]);
+        }
     }
 
     public int size()
@@ -55,9 +65,9 @@ public List<Integer> getInputChannels()
         return Collections.unmodifiableList(Ints.asList(inputChannels));
     }
 
-    public Page getInputChannels(Page page)
+    public SourcePage getInputChannels(SourcePage page)
     {
-        return page.getLoadedPage(inputChannels, eagerlyLoadedChannels);
+        return new InputChannelsSourcePage(page, inputChannels, eagerlyLoad);
     }
 
     @Override
@@ -67,4 +77,106 @@ public String toString()
                 .addValue(Arrays.toString(inputChannels))
                 .toString();
     }
+
+    private static final class InputChannelsSourcePage
+            implements SourcePage
+    {
+        private final SourcePage sourcePage;
+        private final int[] channels;
+        private final Block[] blocks;
+
+        private InputChannelsSourcePage(SourcePage sourcePage, int[] channels, @Nullable boolean[] eagerlyLoad)
+        {
+            requireNonNull(sourcePage, "sourcePage is null");
+            requireNonNull(channels, "channels is null");
+
+            this.sourcePage = sourcePage;
+            this.channels = channels;
+            this.blocks = new Block[channels.length];
+
+            if (eagerlyLoad != null) {
+                for (int channel = 0; channel < eagerlyLoad.length; channel++) {
+                    if (eagerlyLoad[channel]) {
+                        this.blocks[channel] = sourcePage.getBlock(channels[channel]).getLoadedBlock();
+                    }
+                }
+            }
+        }
+
+        @Override
+        public int getPositionCount()
+        {
+            return sourcePage.getPositionCount();
+        }
+
+        @Override
+        public long getSizeInBytes()
+        {
+            return sourcePage.getSizeInBytes();
+        }
+
+        @Override
+        public long getRetainedSizeInBytes()
+        {
+            return sourcePage.getRetainedSizeInBytes();
+        }
+
+        @Override
+        public void retainedBytesForEachPart(ObjLongConsumer<Object> consumer)
+        {
+            for (Block block : blocks) {
+                if (block != null) {
+                    block.retainedBytesForEachPart(consumer);
+                }
+            }
+        }
+
+        @Override
+        public int getChannelCount()
+        {
+            return blocks.length;
+        }
+
+        @Override
+        public Block getBlock(int channel)
+        {
+            Block block = blocks[channel];
+            if (block == null) {
+                block = sourcePage.getBlock(channels[channel]);
+                blocks[channel] = block;
+            }
+            return block;
+        }
+
+        @Override
+        public Page getPage()
+        {
+            for (int i = 0; i < blocks.length; i++) {
+                getBlock(i);
+            }
+            return new Page(getPositionCount(), blocks);
+        }
+
+        @Override
+        public Page getColumns(int[] channels)
+        {
+            Block[] blocks = new Block[channels.length];
+            for (int i = 0; i < channels.length; i++) {
+                blocks[i] = getBlock(channels[i]);
+            }
+            return new Page(getPositionCount(), blocks);
+        }
+
+        @Override
+        public void selectPositions(int[] positions, int offset, int size)
+        {
+            sourcePage.selectPositions(positions, offset, size);
+            for (int i = 0; i < blocks.length; i++) {
+                Block block = blocks[i];
+                if (block != null) {
+                    blocks[i] = block.getPositions(positions, offset, size);
+                }
+            }
+        }
+    }
 }
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/InputPageProjection.java b/core/trino-main/src/main/java/io/trino/operator/project/InputPageProjection.java
index 2670100008ae..3c9a68142303 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/InputPageProjection.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/InputPageProjection.java
@@ -16,9 +16,9 @@
 import io.trino.operator.CompletedWork;
 import io.trino.operator.DriverYieldSignal;
 import io.trino.operator.Work;
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 
 import static java.util.Objects.requireNonNull;
@@ -54,12 +54,11 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions)
+    public Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions)
     {
         Block block = page.getBlock(0);
         requireNonNull(selectedPositions, "selectedPositions is null");
 
-        // TODO: make it lazy when MergePages have better merging heuristics for small lazy pages
         if (selectedPositions.isList()) {
             block = block.copyPositions(selectedPositions.getPositions(), selectedPositions.getOffset(), selectedPositions.size());
         }
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java b/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java
index d5410fbe9309..418416b145b3 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/PageFieldsToInputParametersRewriter.java
@@ -14,6 +14,7 @@
 package io.trino.operator.project;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import io.trino.sql.relational.CallExpression;
 import io.trino.sql.relational.ConstantExpression;
 import io.trino.sql.relational.InputReferenceExpression;
@@ -63,9 +64,9 @@ public List<Integer> getInputChannels()
         return ImmutableList.copyOf(inputChannels);
     }
 
-    public List<Integer> getEagerlyLoadedChannels()
+    public Set<Integer> getEagerlyLoadedChannels()
     {
-        return ImmutableList.copyOf(eagerlyLoadedChannels);
+        return ImmutableSet.copyOf(eagerlyLoadedChannels);
     }
 
     @Override
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageFilter.java b/core/trino-main/src/main/java/io/trino/operator/project/PageFilter.java
index 7fe96d09efb8..d5eb17b47f72 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/PageFilter.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/PageFilter.java
@@ -13,8 +13,8 @@
  */
 package io.trino.operator.project;
 
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 public interface PageFilter
 {
@@ -22,7 +22,7 @@ public interface PageFilter
 
     InputChannels getInputChannels();
 
-    SelectedPositions filter(ConnectorSession session, Page page);
+    SelectedPositions filter(ConnectorSession session, SourcePage page);
 
     static SelectedPositions positionsArrayToSelectedPositions(boolean[] selectedPositions, int size)
     {
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java
index 13f265cac3af..06317e5c27e2 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java
@@ -26,6 +26,7 @@
 import io.trino.spi.block.DictionaryBlock;
 import io.trino.spi.block.DictionaryId;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.sql.gen.ExpressionProfiler;
 import io.trino.sql.gen.columnar.FilterEvaluator;
 
@@ -92,7 +93,7 @@ public PageProcessor(Optional<FilterEvaluator> filterEvaluator, List
-    public Iterator<Optional<Page>> process(ConnectorSession session, DriverYieldSignal yieldSignal, LocalMemoryContext memoryContext, Page page)
+    public Iterator<Optional<Page>> process(ConnectorSession session, DriverYieldSignal yieldSignal, LocalMemoryContext memoryContext, SourcePage page)
     {
         WorkProcessor<Page> processor = createWorkProcessor(session, yieldSignal, memoryContext, new PageProcessorMetrics(), page);
         return processor.yieldingIterator();
     }
@@ -103,7 +104,7 @@ public WorkProcessor<Page> createWorkProcessor(
             DriverYieldSignal yieldSignal,
             LocalMemoryContext memoryContext,
             PageProcessorMetrics metrics,
-            Page page)
+            SourcePage page)
     {
         // limit the scope of the dictionary ids to just one page
         dictionarySourceIdFunction.reset();
@@ -146,7 +147,7 @@ private class ProjectSelectedPositions
         private final LocalMemoryContext memoryContext;
         private final PageProcessorMetrics metrics;
 
-        private Page page;
+        private SourcePage page;
         private final Block[] previouslyComputedResults;
         private SelectedPositions selectedPositions;
         private long retainedSizeInBytes;
@@ -161,7 +162,7 @@ private ProjectSelectedPositions(
                 DriverYieldSignal yieldSignal,
                 LocalMemoryContext memoryContext,
                 PageProcessorMetrics metrics,
-                Page page,
+                SourcePage page,
                 SelectedPositions selectedPositions)
         {
             checkArgument(!selectedPositions.isEmpty(), "selectedPositions is empty");
@@ -260,19 +261,12 @@ private void updateBatchSize(int positionCount, long pageSize)
         private void updateRetainedSize()
         {
             // increment the size only when it is the first reference
-            retainedSizeInBytes = Page.getInstanceSizeInBytes(page.getChannelCount());
             ReferenceCountMap referenceCountMap = new ReferenceCountMap();
-            for (int channel = 0; channel < page.getChannelCount(); channel++) {
-                Block block = page.getBlock(channel);
-                // TODO: block might be partially loaded
-                if (block.isLoaded()) {
-                    block.retainedBytesForEachPart((object, size) -> {
-                        if (referenceCountMap.incrementAndGet(object) == 1) {
-                            retainedSizeInBytes += size;
-                        }
-                    });
+            page.retainedBytesForEachPart((object, size) -> {
+                if (referenceCountMap.incrementAndGet(object) == 1) {
+                    retainedSizeInBytes += size;
                 }
-            }
+            });
             for (Block previouslyComputedResult : previouslyComputedResults) {
                 if (previouslyComputedResult != null) {
                     previouslyComputedResult.retainedBytesForEachPart((object, size) -> {
diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageProjection.java b/core/trino-main/src/main/java/io/trino/operator/project/PageProjection.java
index 2d96da62e4f0..a802ba39ad69 100644
--- a/core/trino-main/src/main/java/io/trino/operator/project/PageProjection.java
+++ b/core/trino-main/src/main/java/io/trino/operator/project/PageProjection.java
@@ -15,9 +15,9 @@
 
 import io.trino.operator.DriverYieldSignal;
 import io.trino.operator.Work;
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 
 public interface PageProjection
@@ -28,5 +28,5 @@ public interface PageProjection
 
     InputChannels getInputChannels();
 
-    Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions);
+    Work<Block> project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions);
 }
diff --git a/core/trino-main/src/main/java/io/trino/operator/window/pattern/ArgumentComputation.java b/core/trino-main/src/main/java/io/trino/operator/window/pattern/ArgumentComputation.java
index 38b1a4b8f363..7b4a9fb371bd 100644
--- a/core/trino-main/src/main/java/io/trino/operator/window/pattern/ArgumentComputation.java
+++ b/core/trino-main/src/main/java/io/trino/operator/window/pattern/ArgumentComputation.java
@@ -19,6 +19,7 @@
 import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 
 import java.util.List;
@@ -55,7 +56,7 @@ public List<Integer> getInputChannels()
         public Block compute(Block[] blocks)
         {
             // wrap block array into a single-row page
-            Page page = new Page(1, blocks);
+            SourcePage page = SourcePage.create(new Page(1, blocks));
 
             // evaluate expression
             Work<Block> work = projection.project(session, new DriverYieldSignal(), projection.getInputChannels().getInputChannels(page), positionsRange(0, 1));
diff --git a/core/trino-main/src/main/java/io/trino/operator/window/pattern/MeasureComputation.java b/core/trino-main/src/main/java/io/trino/operator/window/pattern/MeasureComputation.java
index 69a45649c9e2..1f5f6ba51d86 100644
--- a/core/trino-main/src/main/java/io/trino/operator/window/pattern/MeasureComputation.java
+++ b/core/trino-main/src/main/java/io/trino/operator/window/pattern/MeasureComputation.java
@@ -20,6 +20,7 @@
 import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 
 import java.util.List;
@@ -104,7 +105,7 @@ public Block computeEmpty(long matchNumber)
         }
 
         // wrap block array into a single-row page
-        Page page = new Page(1, blocks);
+        SourcePage page = SourcePage.create(new Page(1, blocks));
 
         // evaluate expression
         Work<Block> work = projection.project(session, new DriverYieldSignal(), projection.getInputChannels().getInputChannels(page), positionsRange(0, 1));
@@ -174,7 +175,7 @@ public static Block compute(
         }
 
         // wrap block array into a single-row page
-        Page page = new Page(1, blocks);
+        SourcePage page = SourcePage.create(new Page(1, blocks));
 
         // evaluate expression
         Work<Block> work = projection.project(session, new DriverYieldSignal(), projection.getInputChannels().getInputChannels(page), positionsRange(0, 1));
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java
index a2918be1dab7..d2f6fcc994cb 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/PageFunctionCompiler.java
@@ -44,11 +44,11 @@
 import io.trino.operator.project.PageFilter;
 import io.trino.operator.project.PageProjection;
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.TrinoException;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.sql.gen.LambdaBytecodeGenerator.CompiledLambda;
 import io.trino.sql.planner.CompilerConfig;
 import io.trino.sql.relational.ConstantExpression;
@@ -213,7 +213,7 @@ private Supplier<PageProjection> compileProjectionInternal(RowExpression project
             throw new TrinoException(COMPILER_ERROR, e);
         }
 
-        MethodHandle pageProjectionConstructor = constructorMethodHandle(pageProjectionWorkClass, BlockBuilder.class, ConnectorSession.class, Page.class, SelectedPositions.class);
+        MethodHandle pageProjectionConstructor = constructorMethodHandle(pageProjectionWorkClass, BlockBuilder.class, ConnectorSession.class, SourcePage.class, SelectedPositions.class);
         return () -> new GeneratedPageProjection(
                 result.getRewrittenExpression(),
                 isExpressionDeterministic,
@@ -256,7 +256,7 @@ private ClassDefinition definePageProjectWorkClass(RowExpression projection, Cal
         // constructor
         Parameter blockBuilder = arg("blockBuilder", BlockBuilder.class);
         Parameter session = arg("session", ConnectorSession.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
         Parameter selectedPositions = arg("selectedPositions", SelectedPositions.class);
 
         MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), blockBuilder, session, page, selectedPositions);
@@ -475,7 +475,7 @@ private ClassDefinition defineFilterClass(RowExpression filter, InputChannels in
     private static MethodDefinition generatePageFilterMethod(ClassDefinition classDefinition, FieldDefinition selectedPositionsField)
     {
         Parameter session = arg("session", ConnectorSession.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
 
         MethodDefinition method = classDefinition.declareMethod(
                 a(PUBLIC),
@@ -523,7 +523,7 @@ private MethodDefinition generateFilterMethod(
             RowExpression filter)
     {
         Parameter session = arg("session", ConnectorSession.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
         Parameter position = arg("position", int.class);
 
         MethodDefinition method = classDefinition.declareMethod(
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/AndFilterEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/AndFilterEvaluator.java
index b61e50622e74..e280e142493f 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/AndFilterEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/AndFilterEvaluator.java
@@ -15,8 +15,8 @@
 
 import com.google.common.collect.ImmutableList;
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.sql.relational.RowExpression;
 import io.trino.sql.relational.SpecialForm;
 
@@ -57,7 +57,7 @@ private AndFilterEvaluator(List<FilterEvaluator> subFilterEvaluators)
     }
 
     @Override
-    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page)
+    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page)
     {
         long filterTimeNanos = 0;
         for (FilterEvaluator evaluator : subFilterEvaluators) {
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BetweenInlineColumnarFilterGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BetweenInlineColumnarFilterGenerator.java
index 600e7b281674..ae9651b540e6 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BetweenInlineColumnarFilterGenerator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/BetweenInlineColumnarFilterGenerator.java
@@ -24,8 +24,8 @@
 import io.airlift.bytecode.control.IfStatement;
 import io.trino.metadata.FunctionManager;
 import io.trino.metadata.ResolvedFunction;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.sql.gen.CallSiteBinder;
 import io.trino.sql.relational.CallExpression;
 import io.trino.sql.relational.InputReferenceExpression;
@@ -104,7 +104,7 @@ private void generateFilterRangeMethod(CallSiteBinder binder, ClassDefinition cl
         Parameter outputPositions = arg("outputPositions", int[].class);
         Parameter offset = arg("offset", int.class);
         Parameter size = arg("size", int.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
 
         MethodDefinition method = classDefinition.declareMethod(
                 a(PUBLIC),
@@ -170,7 +170,7 @@ private void generateFilterListMethod(CallSiteBinder binder, ClassDefinition cla
         Parameter activePositions = arg("activePositions", int[].class);
         Parameter offset = arg("offset", int.class);
         Parameter size = arg("size", int.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
 
         MethodDefinition method = classDefinition.declareMethod(
                 a(PUBLIC),
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java
index b32fdb8cb740..1119dc2a47a5 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/CallColumnarFilterGenerator.java
@@ -27,8 +27,8 @@
 import io.airlift.bytecode.expression.BytecodeExpression;
 import io.trino.metadata.FunctionManager;
 import io.trino.metadata.ResolvedFunction;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.function.FunctionNullability;
 import io.trino.spi.function.InvocationConvention;
 import io.trino.spi.function.ScalarFunctionImplementation;
@@ -129,7 +129,7 @@ private void generateFilterRangeMethod(ClassDefinition classDefinition, CallSite
         Parameter outputPositions = arg("outputPositions", int[].class);
         Parameter offset = arg("offset", int.class);
         Parameter size = arg("size", int.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
 
         MethodDefinition method = classDefinition.declareMethod(
                 a(PUBLIC),
@@ -197,7 +197,7 @@ private void generateFilterListMethod(ClassDefinition classDefinition, CallSiteB
         Parameter activePositions = arg("activePositions", int[].class);
         Parameter offset = arg("offset", int.class);
         Parameter size = arg("size", int.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
 
         MethodDefinition method = classDefinition.declareMethod(
                 a(PUBLIC),
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilter.java
index 7b9a8aa27b8b..b6062caa7895 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilter.java
@@ -14,8 +14,8 @@
 package io.trino.sql.gen.columnar;
 
 import io.trino.operator.project.InputChannels;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 /**
  * Implementations of this interface evaluate a filter on the input Page.
@@ -40,7 +40,7 @@ public interface ColumnarFilter
      * @param loadedPage input Page after using {@link ColumnarFilter#getInputChannels} to load only the required channels
      * @return count of positions active after evaluating this filter on the input loadedPage
      */
-    int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page loadedPage);
+    int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, SourcePage loadedPage);
 
     /**
     * @param outputPositions list of positions active after evaluating this filter on the input loadedPage
@@ -50,7 +50,7 @@ public interface ColumnarFilter
      * @param loadedPage input Page after using {@link ColumnarFilter#getInputChannels} to load only the required channels
     * @return count of positions active after evaluating this filter on the input loadedPage
      */
-    int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, Page loadedPage);
+    int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, SourcePage loadedPage);
 
     /**
     * @return InputChannels of input Page that this filter operates on
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilterEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilterEvaluator.java
index 08f39e74cf23..f2444c69925d 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilterEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/ColumnarFilterEvaluator.java
@@ -14,8 +14,8 @@
 package io.trino.sql.gen.columnar;
 
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 import static io.trino.operator.project.SelectedPositions.positionsList;
 import static io.trino.operator.project.SelectedPositions.positionsRange;
@@ -33,13 +33,13 @@ public ColumnarFilterEvaluator(ColumnarFilter filter)
     }
 
     @Override
-    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page)
+    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page)
     {
         if (activePositions.isEmpty()) {
             return new SelectionResult(activePositions, 0);
         }
         // Should load only the blocks necessary for evaluating the kernel and unwrap lazy blocks
-        Page loadedPage = filter.getInputChannels().getInputChannels(page);
+        SourcePage loadedPage = filter.getInputChannels().getInputChannels(page);
         if (outputPositions.length < activePositions.size()) {
             outputPositions = new int[activePositions.size()];
         }
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DictionaryAwareColumnarFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DictionaryAwareColumnarFilter.java
index 920d055cfaf0..eef33abb9b9a 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DictionaryAwareColumnarFilter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DictionaryAwareColumnarFilter.java
@@ -14,11 +14,11 @@
 package io.trino.sql.gen.columnar;
 
 import io.trino.operator.project.InputChannels;
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.DictionaryBlock;
 import io.trino.spi.block.RunLengthEncodedBlock;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 import static com.google.common.base.Verify.verify;
 import static java.lang.System.arraycopy;
@@ -38,7 +38,7 @@ public DictionaryAwareColumnarFilter(ColumnarFilter columnarFilter)
     }
 
     @Override
-    public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page loadedPage)
+    public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, SourcePage loadedPage)
     {
         Block block = loadedPage.getBlock(0);
         if (block instanceof RunLengthEncodedBlock runLengthEncodedBlock) {
@@ -60,7 +60,7 @@ else if (block instanceof DictionaryBlock dictionaryBlock) {
     }
 
     @Override
-    public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, Page loadedPage)
+    public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, SourcePage loadedPage)
     {
         Block block = loadedPage.getBlock(0);
         if (block instanceof RunLengthEncodedBlock runLengthEncodedBlock) {
@@ -142,7 +142,7 @@ private boolean[] selectedDictionaryMask(ConnectorSession session, Block diction
 
         int positionCount = dictionary.getPositionCount();
         int[] selectedPositions = new int[positionCount];
-        int selectedPositionsCount = columnarFilter.filterPositionsRange(session, selectedPositions, 0, positionCount, new Page(positionCount, dictionary));
+        int selectedPositionsCount = columnarFilter.filterPositionsRange(session, selectedPositions, 0, positionCount, SourcePage.create(dictionary));
 
         boolean[] positionsMask = new boolean[positionCount];
         for (int index = 0; index < selectedPositionsCount; index++) {
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java
index 7e3814426ba2..b9abeed4ed32 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/DynamicPageFilter.java
@@ -18,10 +18,10 @@
 import io.trino.Session;
 import io.trino.metadata.Metadata;
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ColumnHandle;
 import io.trino.spi.connector.ConnectorSession;
 import io.trino.spi.connector.DynamicFilter;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.predicate.TupleDomain;
 import io.trino.spi.type.TypeManager;
 import io.trino.sql.PlannerContext;
@@ -150,7 +150,7 @@ private DynamicFilterEvaluator(List<FilterEvaluator> subFilterEvaluators, double
         }
 
         @Override
-        public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page)
+        public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page)
         {
             long filterTimeNanos = 0;
             for (int filterIndex = 0; filterIndex < subFilterEvaluators.size(); filterIndex++) {
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/FilterEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/FilterEvaluator.java
index cb0ef354e653..983c7787e18c 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/FilterEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/FilterEvaluator.java
@@ -16,8 +16,8 @@
 import com.google.common.collect.ImmutableList;
 import io.trino.metadata.ResolvedFunction;
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.function.CatalogSchemaFunctionName;
 import io.trino.spi.type.Type;
 import io.trino.sql.relational.CallExpression;
@@ -60,7 +60,7 @@ public sealed interface FilterEvaluator
                 SelectNoneEvaluator, DynamicFilterEvaluator
 {
-    SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page);
+    SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page);
 
     record SelectionResult(SelectedPositions selectedPositions, long filterTimeNanos) {}
 
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/InColumnarFilterGenerator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/InColumnarFilterGenerator.java
index d9776c0d10a8..014beb8e9d60 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/InColumnarFilterGenerator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/InColumnarFilterGenerator.java
@@ -30,8 +30,8 @@
 import io.airlift.slice.Slice;
 import io.trino.metadata.FunctionManager;
 import io.trino.metadata.ResolvedFunction;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 import io.trino.sql.gen.Binding;
 import io.trino.sql.gen.CallSiteBinder;
@@ -157,7 +157,7 @@ private void generateFilterRangeMethod(CallSiteBinder binder, ClassDefinition cl
         Parameter outputPositions = arg("outputPositions", int[].class);
         Parameter offset = arg("offset", int.class);
         Parameter size = arg("size", int.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
 
         MethodDefinition method = classDefinition.declareMethod(
                 a(PUBLIC),
@@ -205,7 +205,7 @@ private void generateFilterListMethod(CallSiteBinder binder, ClassDefinition cla
         Parameter activePositions = arg("activePositions", int[].class);
         Parameter offset = arg("offset", int.class);
         Parameter size = arg("size", int.class);
-        Parameter page = arg("page", Page.class);
+        Parameter page = arg("page", SourcePage.class);
 
         MethodDefinition method = classDefinition.declareMethod(
                 a(PUBLIC),
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNotNullColumnarFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNotNullColumnarFilter.java
index fc824a82b339..9fb03793de97 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNotNullColumnarFilter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNotNullColumnarFilter.java
@@ -14,15 +14,15 @@
 package io.trino.sql.gen.columnar;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import io.trino.operator.project.InputChannels;
-import io.trino.spi.Page;
 import io.trino.spi.block.ByteArrayBlock;
 import io.trino.spi.block.ValueBlock;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.sql.relational.InputReferenceExpression;
 import io.trino.sql.relational.SpecialForm;
 
-import java.util.List;
 import java.util.Optional;
 import java.util.function.Supplier;
 
@@ -47,8 +47,7 @@ public static Supplier<ColumnarFilter> createIsNotNullColumnarFilter(SpecialForm
 
     private IsNotNullColumnarFilter(InputReferenceExpression inputReference)
     {
-        List<Integer> channels = ImmutableList.of(inputReference.field());
-        this.inputChannels = new InputChannels(channels, channels);
+        this.inputChannels = new InputChannels(ImmutableList.of(inputReference.field()), ImmutableSet.of(inputReference.field()));
     }
 
     @Override
@@ -58,7 +57,7 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page page)
+    public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, SourcePage page)
     {
         ValueBlock block = (ValueBlock) page.getBlock(0);
         int nonNullPositionsCount = 0;
@@ -83,7 +82,7 @@ public int filterPositionsRange(ConnectorSession session, int[] outputPositions,
     }
 
     @Override
-    public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, Page page)
+    public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, SourcePage page)
     {
         ValueBlock block = (ValueBlock) page.getBlock(0);
         if (block.mayHaveNull()) {
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNullColumnarFilter.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNullColumnarFilter.java
index 941b4edf10d5..0b56fb670651 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNullColumnarFilter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/IsNullColumnarFilter.java
@@ -14,15 +14,15 @@
 package io.trino.sql.gen.columnar;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
 import io.trino.operator.project.InputChannels;
-import io.trino.spi.Page;
 import io.trino.spi.block.ByteArrayBlock;
 import io.trino.spi.block.ValueBlock;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.sql.relational.InputReferenceExpression;
 import io.trino.sql.relational.SpecialForm;
 
-import java.util.List;
 import java.util.Optional;
 import java.util.function.Supplier;
 
@@ -46,8 +46,7 @@ public static Supplier<ColumnarFilter> createIsNullColumnarFilter(SpecialForm sp
 
     private IsNullColumnarFilter(InputReferenceExpression inputReference)
     {
-        List<Integer> channels = ImmutableList.of(inputReference.field());
-        this.inputChannels = new InputChannels(channels, channels);
+        this.inputChannels = new InputChannels(ImmutableList.of(inputReference.field()), ImmutableSet.of(inputReference.field()));
     }
 
     @Override
@@ -57,7 +56,7 @@ public InputChannels getInputChannels()
     }
 
     @Override
-    public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page page)
+    public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, SourcePage page)
     {
         ValueBlock block = (ValueBlock) page.getBlock(0);
         if (!block.mayHaveNull()) {
@@ -80,7 +79,7 @@ public int filterPositionsRange(ConnectorSession session, int[] outputPositions,
     }
 
     @Override
-    public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, Page page)
+    public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, SourcePage page)
     {
         ValueBlock block = (ValueBlock) page.getBlock(0);
         if (!block.mayHaveNull()) {
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/OrFilterEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/OrFilterEvaluator.java
index 4a9e2569bfce..bbdd80b078e2 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/OrFilterEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/OrFilterEvaluator.java
@@ -15,8 +15,8 @@
 
 import com.google.common.collect.ImmutableList;
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.sql.relational.RowExpression;
 import io.trino.sql.relational.SpecialForm;
 
@@ -57,7 +57,7 @@ private OrFilterEvaluator(List<FilterEvaluator> subFilterEvaluators)
     }
 
     @Override
-    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page)
+    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page)
     {
         long filterTimeNanos = 0;
         SelectionResult result = subFilterEvaluators.getFirst().evaluate(session, activePositions, page);
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/PageFilterEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/PageFilterEvaluator.java
index 3331712fa810..415055a82050 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/PageFilterEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/PageFilterEvaluator.java
@@ -16,8 +16,8 @@
 import io.trino.operator.project.DictionaryAwarePageFilter;
 import io.trino.operator.project.PageFilter;
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 public final class PageFilterEvaluator
         implements FilterEvaluator
@@ -35,9 +35,9 @@ public PageFilterEvaluator(PageFilter filter)
     }
 
     @Override
-    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page)
+    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page)
     {
-        Page inputPage = filter.getInputChannels().getInputChannels(page);
+        SourcePage inputPage = filter.getInputChannels().getInputChannels(page);
         long start = System.nanoTime();
         SelectedPositions selectedPositions = filter.filter(session, inputPage);
         return new SelectionResult(selectedPositions, System.nanoTime() - start);
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectAllEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectAllEvaluator.java
index ccede31cf162..e7187f7f023a 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectAllEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectAllEvaluator.java
@@ -14,14 +14,14 @@
 package io.trino.sql.gen.columnar;
 
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 public final class SelectAllEvaluator
         implements FilterEvaluator
 {
     @Override
-    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page)
+    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page)
     {
         return new SelectionResult(activePositions, 0);
     }
diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectNoneEvaluator.java b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectNoneEvaluator.java
index cbe0005b8ca3..ceb2bd4aaeff 100644
--- a/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectNoneEvaluator.java
+++ b/core/trino-main/src/main/java/io/trino/sql/gen/columnar/SelectNoneEvaluator.java
@@ -14,8 +14,8 @@
 package io.trino.sql.gen.columnar;
 
 import io.trino.operator.project.SelectedPositions;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 
 import static io.trino.operator.project.SelectedPositions.positionsRange;
 
@@ -23,7 +23,7 @@ public final class SelectNoneEvaluator
         implements FilterEvaluator
 {
     @Override
-    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, Page page)
+    public SelectionResult evaluate(ConnectorSession session, SelectedPositions activePositions, SourcePage page)
     {
         return new SelectionResult(positionsRange(0, 0), 0);
     }
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java
index c85c794f0108..9b0436bdf5f2 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java
@@ -32,11 +32,11 @@
 import io.trino.metadata.ResolvedFunction;
 import io.trino.metadata.Split;
 import io.trino.metadata.TableHandle;
-import io.trino.spi.Page;
 import io.trino.spi.TrinoException;
 import io.trino.spi.connector.ColumnHandle;
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.DynamicFilter;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeSignature;
 import io.trino.split.PageSourceManager;
@@ -468,7 +468,7 @@ private static KdbTree loadKdbTree(String tableName, Session session, Metadata m
         try (ConnectorPageSource pageSource = statefulPageSourceProvider.createPageSource(session, split, tableHandle, ImmutableList.of(kdbTreeColumn), DynamicFilter.EMPTY)) {
             do {
                 getFutureValue(pageSource.isBlocked());
-                Page page = pageSource.getNextPage();
+                SourcePage page = pageSource.getNextSourcePage();
                 if (page != null && page.getPositionCount() > 0) {
                     checkSpatialPartitioningTable(kdbTree.isEmpty(), "Expected exactly one row for table %s, but found more", name);
                     checkSpatialPartitioningTable(page.getPositionCount() == 1, "Expected exactly one row for table %s, but found %s rows", name, page.getPositionCount());
diff --git a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java
index 6c66cc7d4806..dd3db1d5510e 100644
--- a/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java
+++ b/core/trino-main/src/main/java/io/trino/testing/MaterializedResult.java
@@ -24,6 +24,7 @@
 import io.trino.spi.block.Block;
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.SqlDate;
 import io.trino.spi.type.SqlDecimal;
 import io.trino.spi.type.SqlTime;
@@ -336,11 +337,11 @@ public static MaterializedResult materializeSourceDataStream(ConnectorSession se
     {
         MaterializedResult.Builder builder = resultBuilder(session, types);
         while (!pageSource.isFinished()) {
-            Page outputPage = pageSource.getNextPage();
+            SourcePage outputPage = pageSource.getNextSourcePage();
             if (outputPage == null) {
                 continue;
             }
-            builder.page(outputPage);
+            builder.page(outputPage.getPage());
         }
         return builder.build();
     }
diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorPageSource.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorPageSource.java
index b78322440d3a..f16ae6d43e08 100644
--- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorPageSource.java
+++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorPageSource.java
@@ -15,6 +15,7 @@
 
 import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.metrics.Metrics;
 
 import java.io.IOException;
@@ -60,11 +61,18 @@ public boolean isFinished()
     }
 
     @Override
+    @SuppressWarnings("removal")
     public Page getNextPage()
     {
         return delegate.getNextPage();
     }
 
+    @Override
+    public SourcePage getNextSourcePage()
+    {
+        return delegate.getNextSourcePage();
+    }
+
     @Override
     public long getMemoryUsage()
     {
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestColumnarPageProcessor.java b/core/trino-main/src/test/java/io/trino/operator/TestColumnarPageProcessor.java
index 2e4da92eaa9f..80f091d90645 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestColumnarPageProcessor.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestColumnarPageProcessor.java
@@ -17,6 +17,7 @@
 import io.trino.metadata.FunctionManager;
 import io.trino.operator.project.PageProcessor;
 import io.trino.spi.Page;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.Type;
 import io.trino.sql.gen.CursorProcessorCompiler;
 import io.trino.sql.gen.ExpressionCompiler;
@@ -55,7 +56,7 @@ public void testProcess()
                 SESSION,
                 new DriverYieldSignal(),
                 newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()),
-                page))
+                SourcePage.create(page)))
                 .orElseThrow(() -> new AssertionError("page is not present"));
         assertPageEquals(types, outputPage, page);
    }
@@ -70,7 +71,7 @@ public void testProcessWithDictionary()
                 SESSION,
                 new DriverYieldSignal(),
                 newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()),
-                page))
+                SourcePage.create(page)))
                 .orElseThrow(() -> new AssertionError("page is not present"));
         assertPageEquals(types, outputPage, page);
     }
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java b/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java
deleted file mode 100644
index 821c2f5369d2..000000000000
--- a/core/trino-main/src/test/java/io/trino/operator/TestPageUtils.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -package io.trino.operator; - -import io.trino.spi.Page; -import io.trino.spi.block.ArrayBlock; -import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; -import org.junit.jupiter.api.Test; - -import java.util.Optional; -import java.util.concurrent.atomic.AtomicLong; - -import static io.trino.block.BlockAssertions.createIntsBlock; -import static io.trino.operator.PageUtils.recordMaterializedBytes; -import static org.assertj.core.api.Assertions.assertThat; - -public class TestPageUtils -{ - @Test - public void testRecordMaterializedBytes() - { - Block first = createIntsBlock(1, 2, 3); - LazyBlock second = lazyWrapper(first); - LazyBlock third = lazyWrapper(first); - Page page = new Page(3, first, second, third); - - second.getLoadedBlock(); - - AtomicLong sizeInBytes = new AtomicLong(); - recordMaterializedBytes(page, sizeInBytes::getAndAdd); - - assertThat(sizeInBytes.get()).isEqualTo(first.getSizeInBytes() * 2); - - page.getBlock(2).getLoadedBlock(); - assertThat(sizeInBytes.get()).isEqualTo(first.getSizeInBytes() * 3); - } - - @Test - public void testNestedBlocks() - { - Block elements = lazyWrapper(createIntsBlock(1, 2, 3)); - Block arrayBlock = ArrayBlock.fromElementBlock(2, Optional.empty(), new int[] {0, 1, 3}, elements); - long initialArraySize = arrayBlock.getSizeInBytes(); - Page page = new Page(2, arrayBlock); - - AtomicLong sizeInBytes = new AtomicLong(); - recordMaterializedBytes(page, sizeInBytes::getAndAdd); - - assertThat(arrayBlock.getSizeInBytes()).isEqualTo(initialArraySize); - assertThat(sizeInBytes.get()).isEqualTo(arrayBlock.getSizeInBytes()); - - // dictionary block caches size in bytes - arrayBlock.getLoadedBlock(); - assertThat(sizeInBytes.get()).isEqualTo(arrayBlock.getSizeInBytes()); - assertThat(sizeInBytes.get()).isEqualTo(initialArraySize + elements.getSizeInBytes()); - } - - private static LazyBlock lazyWrapper(Block block) - { - return new LazyBlock(block.getPositionCount(), block::getLoadedBlock); - } -} diff --git a/core/trino-main/src/test/java/io/trino/operator/TestScanFilterAndProjectOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestScanFilterAndProjectOperator.java index 2fe77398980c..625834c5cf18 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestScanFilterAndProjectOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestScanFilterAndProjectOperator.java @@ -29,12 +29,14 @@ import io.trino.operator.project.TestPageProcessor.LazyPagePageProjection; import io.trino.operator.project.TestPageProcessor.SelectAllFilter; import io.trino.spi.Page; +import io.trino.spi.block.ArrayBlock; import io.trino.spi.block.Block; import io.trino.spi.block.LazyBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.FixedPageSource; import io.trino.spi.connector.RecordPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.sql.gen.CursorProcessorCompiler; import io.trino.sql.gen.ExpressionCompiler; import io.trino.sql.gen.PageFunctionCompiler; @@ -56,6 +58,7 @@ import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import static io.airlift.concurrent.Threads.daemonThreadsNamed; @@ -63,6 +66,7 @@ import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.trino.RowPagesBuilder.rowPagesBuilder; import static io.trino.SessionTestUtils.TEST_SESSION; +import 
static io.trino.block.BlockAssertions.createIntsBlock; import static io.trino.block.BlockAssertions.toValues; import static io.trino.operator.OperatorAssertion.toMaterializedResult; import static io.trino.operator.PageAssertions.assertPageEquals; @@ -206,9 +210,7 @@ public void testPageSourceLazyLoad() { Block inputBlock = BlockAssertions.createLongSequenceBlock(0, 100); // If column 1 is loaded, test will fail - Page input = new Page(100, inputBlock, new LazyBlock(100, () -> { - throw new AssertionError("Lazy block should not be loaded"); - })); + TestingSourcePage input = new TestingSourcePage(100, inputBlock, null); DriverContext driverContext = newDriverContext(); List projections = ImmutableList.of(field(0, VARCHAR)); @@ -405,6 +407,52 @@ public void testRecordCursorYield() assertThat(toValues(BIGINT, output.getBlock(0))).isEqualTo(toValues(BIGINT, input.getBlock(0))); } + @Test + public void testRecordMaterializedBytes() + { + Block first = createIntsBlock(1, 2, 3); + LazyBlock second = lazyWrapper(first); + LazyBlock third = lazyWrapper(first); + SourcePage page = new TestingSourcePage(3, first, second, third); + + second.getLoadedBlock(); + + AtomicLong sizeInBytes = new AtomicLong(); + ScanFilterAndProjectOperator.ProcessedBytesMonitor monitor = new ScanFilterAndProjectOperator.ProcessedBytesMonitor(page, sizeInBytes::getAndAdd); + + assertThat(sizeInBytes.get()).isEqualTo(first.getSizeInBytes() * 2); + + page.getBlock(2).getLoadedBlock(); + monitor.update(); + assertThat(sizeInBytes.get()).isEqualTo(first.getSizeInBytes() * 3); + } + + @Test + public void testNestedBlocks() + { + Block elements = lazyWrapper(createIntsBlock(1, 2, 3)); + Block arrayBlock = ArrayBlock.fromElementBlock(2, Optional.empty(), new int[] {0, 1, 3}, elements); + long initialArraySize = arrayBlock.getSizeInBytes(); + SourcePage page = new TestingSourcePage(2, arrayBlock); + + AtomicLong sizeInBytes = new AtomicLong(); + ScanFilterAndProjectOperator.ProcessedBytesMonitor monitor = new ScanFilterAndProjectOperator.ProcessedBytesMonitor(page, sizeInBytes::getAndAdd); + + assertThat(arrayBlock.getSizeInBytes()).isEqualTo(initialArraySize); + assertThat(sizeInBytes.get()).isEqualTo(arrayBlock.getSizeInBytes()); + + // dictionary block caches size in bytes + arrayBlock.getLoadedBlock(); + monitor.update(); + assertThat(sizeInBytes.get()).isEqualTo(arrayBlock.getSizeInBytes()); + assertThat(sizeInBytes.get()).isEqualTo(initialArraySize + elements.getSizeInBytes()); + } + + private static LazyBlock lazyWrapper(Block block) + { + return new LazyBlock(block.getPositionCount(), block::getLoadedBlock); + } + private static List toPages(Operator operator) { ImmutableList.Builder outputPages = ImmutableList.builder(); @@ -439,9 +487,9 @@ private DriverContext newDriverContext() public static class SinglePagePageSource implements ConnectorPageSource { - private Page page; + private SourcePage page; - public SinglePagePageSource(Page page) + public SinglePagePageSource(SourcePage page) { this.page = page; } @@ -477,9 +525,12 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { - Page page = this.page; + SourcePage page = this.page; + if (page == null) { + return null; + } this.page = null; return page; } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestingSourcePage.java b/core/trino-main/src/test/java/io/trino/operator/TestingSourcePage.java new file mode 100644 index 000000000000..6dd8c2b15c60 --- /dev/null +++ 
b/core/trino-main/src/test/java/io/trino/operator/TestingSourcePage.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; + +import java.util.Arrays; +import java.util.function.ObjLongConsumer; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class TestingSourcePage + implements SourcePage +{ + private final int positionCount; + private final Block[] blocks; + private final boolean[] loaded; + + public TestingSourcePage(int positionCount, Block... blocks) + { + this.positionCount = positionCount; + this.blocks = requireNonNull(blocks, "blocks is null"); + this.loaded = new boolean[blocks.length]; + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public long getSizeInBytes() + { + long sizeInBytes = 0; + for (Block block : blocks) { + if (block != null) { + sizeInBytes += block.getSizeInBytes(); + } + } + return sizeInBytes; + } + + @Override + public long getRetainedSizeInBytes() + { + long retainedSizeInBytes = 0; + for (Block block : blocks) { + if (block != null) { + retainedSizeInBytes += block.getRetainedSizeInBytes(); + } + } + return retainedSizeInBytes; + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer consumer) + { + for (Block block : blocks) { + if (block != null) { + block.retainedBytesForEachPart(consumer); + } + } + } + + @Override + public int getChannelCount() + { + return blocks.length; + } + + public boolean wasLoaded(int channel) + { + return loaded[channel]; + } + + @Override + public Block getBlock(int channel) + { + Block block = blocks[channel]; + checkArgument(block != null, "Block %s should not be accessed", channel); + loaded[channel] = true; + return block; + } + + @Override + public Page getPage() + { + for (Block block : blocks) { + checkArgument(block != null, "Page cannot be created because block is null"); + } + Arrays.fill(loaded, true); + Block[] blocks = this.blocks.clone(); + return new Page(positionCount, blocks); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + for (int i = 0; i < blocks.length; i++) { + Block block = blocks[i]; + if (block != null) { + blocks[i] = block.getPositions(positions, offset, size); + } + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/index/TestTupleFilterProcessor.java b/core/trino-main/src/test/java/io/trino/operator/index/TestTupleFilterProcessor.java index d20889b1bf02..1ae4d955e2b4 100644 --- a/core/trino-main/src/test/java/io/trino/operator/index/TestTupleFilterProcessor.java +++ b/core/trino-main/src/test/java/io/trino/operator/index/TestTupleFilterProcessor.java @@ -19,6 +19,7 @@ import io.trino.operator.DriverYieldSignal; import io.trino.operator.project.PageProcessor; import io.trino.spi.Page; +import io.trino.spi.connector.SourcePage; import 
io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.PageFunctionCompiler; @@ -75,7 +76,7 @@ public void testFilter() SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - inputPage)) + SourcePage.create(inputPage))) .orElseThrow(() -> new AssertionError("page is not present")); Page expectedPage = Iterables.getOnlyElement(rowPagesBuilder(outputTypes) diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageFilter.java b/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageFilter.java index c8b2e1e5d024..2e567cff170f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageFilter.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageFilter.java @@ -14,13 +14,13 @@ package io.trino.operator.project; import com.google.common.collect.ImmutableList; -import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.LazyBlock; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntSet; @@ -190,7 +190,7 @@ private static DictionaryAwarePageFilter createDictionaryAwarePageFilter(boolean private static void testFilter(DictionaryAwarePageFilter filter, Block block, boolean filterRange) { - IntSet actualSelectedPositions = toSet(filter.filter(null, new Page(block))); + IntSet actualSelectedPositions = toSet(filter.filter(null, SourcePage.create(block))); block = block.getLoadedBlock(); @@ -272,7 +272,7 @@ public InputChannels getInputChannels() } @Override - public SelectedPositions filter(ConnectorSession session, Page page) + public SelectedPositions filter(ConnectorSession session, SourcePage page) { assertThat(page.getChannelCount()).isEqualTo(1); Block block = page.getBlock(0); diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java b/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java index aa0ea1a9c747..0659bea5c9b2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestDictionaryAwarePageProjection.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import io.trino.operator.DriverYieldSignal; import io.trino.operator.Work; -import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; @@ -25,6 +24,7 @@ import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.ValueBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; @@ -198,13 +198,13 @@ public void testPreservesDictionaryInstance() Block secondDictionaryBlock = DictionaryBlock.create(4, dictionary, new int[] {3, 2, 1, 0}); DriverYieldSignal yieldSignal = new DriverYieldSignal(); - Work firstWork = projection.project(null, yieldSignal, new Page(firstDictionaryBlock), SelectedPositions.positionsList(new int[] {0, 1}, 
0, 2)); + Work firstWork = projection.project(null, yieldSignal, SourcePage.create(firstDictionaryBlock), SelectedPositions.positionsList(new int[] {0, 1}, 0, 2)); assertThat(firstWork.process()).isTrue(); Block firstOutputBlock = firstWork.getResult(); assertThat(firstOutputBlock).isInstanceOf(DictionaryBlock.class); - Work secondWork = projection.project(null, yieldSignal, new Page(secondDictionaryBlock), SelectedPositions.positionsList(new int[] {0, 1}, 0, 2)); + Work secondWork = projection.project(null, yieldSignal, SourcePage.create(secondDictionaryBlock), SelectedPositions.positionsList(new int[] {0, 1}, 0, 2)); assertThat(secondWork.process()).isTrue(); Block secondOutputBlock = secondWork.getResult(); @@ -290,7 +290,7 @@ private static void testProjectRange(Block block, Class expecte } DriverYieldSignal yieldSignal = new DriverYieldSignal(); - Work work = projection.project(null, yieldSignal, new Page(block), SelectedPositions.positionsRange(5, 10)); + Work work = projection.project(null, yieldSignal, SourcePage.create(block), SelectedPositions.positionsRange(5, 10)); Block result; if (forceYield) { result = projectWithYield(work, yieldSignal); @@ -322,7 +322,7 @@ private static void testProjectList(Block block, Class expected DriverYieldSignal yieldSignal = new DriverYieldSignal(); int[] positions = {0, 2, 4, 6, 8, 10}; - Work work = projection.project(null, yieldSignal, new Page(block), SelectedPositions.positionsList(positions, 0, positions.length)); + Work work = projection.project(null, yieldSignal, SourcePage.create(block), SelectedPositions.positionsList(positions, 0, positions.length)); Block result; if (forceYield) { result = projectWithYield(work, yieldSignal); @@ -353,7 +353,7 @@ private static void testProjectFastReturnIgnoreYield(Block block, DictionaryAwar } DriverYieldSignal yieldSignal = new DriverYieldSignal(); - Work work = projection.project(null, yieldSignal, new Page(block), SelectedPositions.positionsRange(5, 10)); + Work work = projection.project(null, yieldSignal, SourcePage.create(block), SelectedPositions.positionsRange(5, 10)); yieldSignal.setWithDelay(1, executor); yieldSignal.forceYieldForTesting(); @@ -411,7 +411,7 @@ public InputChannels getInputChannels() } @Override - public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) + public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions) { return new TestPageProjectionWork(yieldSignal, page, selectedPositions); } @@ -427,7 +427,7 @@ private static class TestPageProjectionWork private int nextIndexOrPosition; private Block result; - public TestPageProjectionWork(DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) + public TestPageProjectionWork(DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions) { this.yieldSignal = yieldSignal; this.block = page.getBlock(0); diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestInputPageProjection.java b/core/trino-main/src/test/java/io/trino/operator/project/TestInputPageProjection.java index e5e5cf3f1d52..31e6ff7647ff 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestInputPageProjection.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestInputPageProjection.java @@ -14,9 +14,9 @@ package io.trino.operator.project; import io.trino.operator.DriverYieldSignal; -import io.trino.spi.Page; import 
io.trino.spi.block.Block; import io.trino.spi.block.LazyBlock; +import io.trino.spi.connector.SourcePage; import org.junit.jupiter.api.Test; import static io.trino.block.BlockAssertions.createLongSequenceBlock; @@ -31,11 +31,11 @@ public void testLazyInputPage() { InputPageProjection projection = new InputPageProjection(0, BIGINT); Block block = createLongSequenceBlock(0, 100); - Block result = projection.project(SESSION, new DriverYieldSignal(), new Page(block), SelectedPositions.positionsRange(0, 100)).getResult(); + Block result = projection.project(SESSION, new DriverYieldSignal(), SourcePage.create(block), SelectedPositions.positionsRange(0, 100)).getResult(); assertThat(result).isNotInstanceOf(LazyBlock.class); block = lazyWrapper(block); - result = projection.project(SESSION, new DriverYieldSignal(), new Page(block), SelectedPositions.positionsRange(0, 100)).getResult(); + result = projection.project(SESSION, new DriverYieldSignal(), SourcePage.create(block), SelectedPositions.positionsRange(0, 100)).getResult(); assertThat(result).isInstanceOf(LazyBlock.class); assertThat(result.isLoaded()).isFalse(); } diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java index ea2ffe28f15e..579fdbb89237 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestPageFieldsToInputParametersRewriter.java @@ -17,9 +17,8 @@ import com.google.common.collect.ImmutableSet; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TestingFunctionResolution; -import io.trino.spi.Page; +import io.trino.operator.TestingSourcePage; import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; import io.trino.spi.function.OperatorType; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; @@ -132,19 +131,15 @@ private static void verifyEagerlyLoadedColumns(RowExpression rowExpression, int Result result = rewritePageFieldsToInputParameters(rowExpression); Block[] blocks = new Block[columnCount]; for (int channel = 0; channel < columnCount; channel++) { - blocks[channel] = lazyWrapper(createLongSequenceBlock(0, 100)); + blocks[channel] = createLongSequenceBlock(0, 100); } - Page page = result.getInputChannels().getInputChannels(new Page(blocks)); + TestingSourcePage inputPage = new TestingSourcePage(100, blocks); + result.getInputChannels().getInputChannels(inputPage); for (int channel = 0; channel < columnCount; channel++) { - assertThat(page.getBlock(channel).isLoaded()).isEqualTo(eagerlyLoadedChannels.contains(channel)); + assertThat(inputPage.wasLoaded(channel)).isEqualTo(eagerlyLoadedChannels.contains(channel)); } } - private static LazyBlock lazyWrapper(Block block) - { - return new LazyBlock(block.getPositionCount(), block::getLoadedBlock); - } - private static class RowExpressionBuilder { private final Map sourceLayout = new HashMap<>(); diff --git a/core/trino-main/src/test/java/io/trino/operator/project/TestPageProcessor.java b/core/trino-main/src/test/java/io/trino/operator/project/TestPageProcessor.java index ee2af66e22e7..131c80680824 100644 --- a/core/trino-main/src/test/java/io/trino/operator/project/TestPageProcessor.java +++ b/core/trino-main/src/test/java/io/trino/operator/project/TestPageProcessor.java @@ -23,17 +23,19 @@ import io.trino.metadata.TestingFunctionResolution; import 
io.trino.operator.CompletedWork; import io.trino.operator.DriverYieldSignal; +import io.trino.operator.TestingSourcePage; import io.trino.operator.Work; import io.trino.spi.Page; import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; import io.trino.spi.block.VariableWidthBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionProfiler; import io.trino.sql.gen.PageFunctionCompiler; import io.trino.sql.gen.columnar.PageFilterEvaluator; import io.trino.sql.relational.CallExpression; +import org.assertj.core.data.Offset; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -94,7 +96,7 @@ public void testProjectNoColumns() { PageProcessor pageProcessor = new PageProcessor(Optional.empty(), Optional.empty(), ImmutableList.of(), OptionalInt.of(MAX_BATCH_SIZE)); - Page inputPage = new Page(createLongSequenceBlock(0, 100)); + SourcePage inputPage = SourcePage.create(createLongSequenceBlock(0, 100)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, inputPage); @@ -110,7 +112,7 @@ public void testFilterNoColumns() { PageProcessor pageProcessor = new PageProcessor(Optional.of(new PageFilterEvaluator(new TestingPageFilter(positionsRange(0, 50)))), ImmutableList.of()); - Page inputPage = new Page(createLongSequenceBlock(0, 100)); + SourcePage inputPage = SourcePage.create(createLongSequenceBlock(0, 100)); LocalMemoryContext memoryContext = newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()); Iterator> output = pageProcessor.process(SESSION, new DriverYieldSignal(), memoryContext, inputPage); @@ -132,7 +134,7 @@ public void testPartialFilter() ImmutableList.of(new InputPageProjection(0, BIGINT)), OptionalInt.of(MAX_BATCH_SIZE)); - Page inputPage = new Page(createLongSequenceBlock(0, 100)); + SourcePage inputPage = SourcePage.create(createLongSequenceBlock(0, 100)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, inputPage); @@ -150,7 +152,7 @@ public void testSelectAllFilter() ImmutableList.of(new InputPageProjection(0, BIGINT)), OptionalInt.of(MAX_BATCH_SIZE)); - Page inputPage = new Page(createLongSequenceBlock(0, 100)); + SourcePage inputPage = SourcePage.create(createLongSequenceBlock(0, 100)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, inputPage); @@ -164,7 +166,7 @@ public void testSelectNoneFilter() { PageProcessor pageProcessor = new PageProcessor(Optional.of(new PageFilterEvaluator(new SelectNoneFilter())), ImmutableList.of(new InputPageProjection(0, BIGINT))); - Page inputPage = new Page(createLongSequenceBlock(0, 100)); + SourcePage inputPage = SourcePage.create(createLongSequenceBlock(0, 100)); LocalMemoryContext memoryContext = newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()); Iterator> output = pageProcessor.process(SESSION, new DriverYieldSignal(), memoryContext, inputPage); @@ -179,7 +181,7 @@ public void testProjectEmptyPage() { PageProcessor pageProcessor = new PageProcessor(Optional.of(new PageFilterEvaluator(new SelectAllFilter())), ImmutableList.of(new InputPageProjection(0, BIGINT))); - Page inputPage = new Page(createLongSequenceBlock(0, 0)); + SourcePage inputPage = SourcePage.create(createLongSequenceBlock(0, 0)); LocalMemoryContext memoryContext = newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()); 
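// Every test hunk in this file follows the same conversion: eager data is wrapped
// with SourcePage.create(...) before being passed to the SourcePage-based
// process(...) overload. A self-contained sketch of the two wrapping forms this
// patch uses (a single Block and a whole Page); createLongSequenceBlock is the
// test helper these files already import from BlockAssertions.

import io.trino.spi.Page;
import io.trino.spi.connector.SourcePage;

import static io.trino.block.BlockAssertions.createLongSequenceBlock;

final class SourcePageWrappingSketch
{
    private SourcePageWrappingSketch() {}

    static SourcePage singleChannelInput()
    {
        // one block becomes a one-channel SourcePage
        return SourcePage.create(createLongSequenceBlock(0, 100));
    }

    static SourcePage multiChannelInput()
    {
        // an already-materialized Page is wrapped as-is, keeping all channels
        return SourcePage.create(new Page(createLongSequenceBlock(0, 100), createLongSequenceBlock(100, 200)));
    }
}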
Iterator> output = pageProcessor.process(SESSION, new DriverYieldSignal(), memoryContext, inputPage); @@ -196,9 +198,7 @@ public void testSelectNoneFilterLazyLoad() PageProcessor pageProcessor = new PageProcessor(Optional.of(new PageFilterEvaluator(new SelectNoneFilter())), ImmutableList.of(new InputPageProjection(1, BIGINT))); // if channel 1 is loaded, test will fail - Page inputPage = new Page(createLongSequenceBlock(0, 100), new LazyBlock(100, () -> { - throw new AssertionError("Lazy block should not be loaded"); - })); + SourcePage inputPage = new TestingSourcePage(100, createLongSequenceBlock(0, 100), null); LocalMemoryContext memoryContext = newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()); Iterator> output = pageProcessor.process(SESSION, new DriverYieldSignal(), memoryContext, inputPage); @@ -217,9 +217,7 @@ public void testProjectLazyLoad() OptionalInt.of(MAX_BATCH_SIZE)); // if channel 1 is loaded, test will fail - Page inputPage = new Page(createLongSequenceBlock(0, 100), new LazyBlock(100, () -> { - throw new AssertionError("Lazy block should not be loaded"); - })); + SourcePage inputPage = new TestingSourcePage(100, createLongSequenceBlock(0, 100), null); LocalMemoryContext memoryContext = newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()); Iterator> output = pageProcessor.process(SESSION, new DriverYieldSignal(), memoryContext, inputPage); @@ -238,7 +236,7 @@ public void testBatchedOutput() ImmutableList.of(new InputPageProjection(0, BIGINT)), OptionalInt.of(MAX_BATCH_SIZE)); - Page inputPage = new Page(createLongSequenceBlock(0, (int) (MAX_BATCH_SIZE * 2.5))); + SourcePage inputPage = SourcePage.create(createLongSequenceBlock(0, (int) (MAX_BATCH_SIZE * 2.5))); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, inputPage); @@ -264,7 +262,7 @@ public void testAdaptiveBatchSize() // process large page which will reduce batch size Slice[] slices = new Slice[(int) (MAX_BATCH_SIZE * 2.5)]; Arrays.fill(slices, Slices.allocate(4096)); - Page inputPage = new Page(createSlicesBlock(slices)); + SourcePage inputPage = SourcePage.create(createSlicesBlock(slices)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, new DriverYieldSignal(), inputPage); @@ -280,7 +278,7 @@ public void testAdaptiveBatchSize() // process small page which will increase batch size Arrays.fill(slices, Slices.allocate(128)); - inputPage = new Page(createSlicesBlock(slices)); + inputPage = SourcePage.create(createSlicesBlock(slices)); output = processAndAssertRetainedPageSize(pageProcessor, new DriverYieldSignal(), inputPage); @@ -310,7 +308,7 @@ public void testOptimisticProcessing() // process large page which will reduce batch size Slice[] slices = new Slice[(int) (MAX_BATCH_SIZE * 2.5)]; Arrays.fill(slices, Slices.allocate(4096)); - Page inputPage = new Page(createSlicesBlock(slices)); + SourcePage inputPage = SourcePage.create(createSlicesBlock(slices)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, inputPage); @@ -354,7 +352,7 @@ public void testRetainedSize() // this can force previouslyComputedResults to be saved given the page is 48MB in size String value = join("", nCopies(30_000, "a")); List values = nCopies(800, value); - Page inputPage = new Page(createStringsBlock(values), createStringsBlock(values)); + SourcePage inputPage = SourcePage.create(new Page(createStringsBlock(values), createStringsBlock(values))); AggregatedMemoryContext memoryContext = 
newSimpleAggregatedMemoryContext(); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, new DriverYieldSignal(), memoryContext, inputPage); @@ -365,7 +363,7 @@ public void testRetainedSize() // verify we do not count block sizes twice // comparing with the input page, the output page also contains an extra instance size for previouslyComputedResults - assertThat(memoryContext.getBytes() - instanceSize(VariableWidthBlock.class)).isEqualTo(inputPage.getRetainedSizeInBytes()); + assertThat(memoryContext.getBytes() - instanceSize(VariableWidthBlock.class)).isCloseTo(inputPage.getRetainedSizeInBytes(), Offset.offset(200L)); } @Test @@ -384,7 +382,7 @@ public void testYieldProjection() Slice[] slices = new Slice[rows]; Arrays.fill(slices, Slices.allocate(rows)); - Page inputPage = new Page(createSlicesBlock(slices)); + SourcePage inputPage = SourcePage.create(createSlicesBlock(slices)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, yieldSignal, inputPage); @@ -425,7 +423,7 @@ public void testExpressionProfiler() PageFunctionCompiler functionCompiler = functionResolution.getPageFunctionCompiler(); Supplier projectionSupplier = functionCompiler.compileProjection(add10Expression, Optional.empty()); PageProjection projection = projectionSupplier.get(); - Page page = new Page(createLongSequenceBlock(1, 11)); + SourcePage page = SourcePage.create(createLongSequenceBlock(1, 11)); ExpressionProfiler profiler = new ExpressionProfiler(testingTicker, SPLIT_RUN_QUANTA); for (int i = 0; i < 100; i++) { profiler.start(); @@ -462,7 +460,7 @@ public void testIncreasingBatchSize() Slice[] slices = new Slice[rows]; Arrays.fill(slices, Slices.allocate(rows)); - Page inputPage = new Page(createSlicesBlock(slices)); + SourcePage inputPage = SourcePage.create(createSlicesBlock(slices)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, inputPage); long previousPositionCount = 1; @@ -497,7 +495,7 @@ public void testDecreasingBatchSize() Slice[] slices = new Slice[rows]; Arrays.fill(slices, Slices.allocate(rows)); - Page inputPage = new Page(createSlicesBlock(slices)); + SourcePage inputPage = SourcePage.create(createSlicesBlock(slices)); Iterator> output = processAndAssertRetainedPageSize(pageProcessor, inputPage); long previousPositionCount = 1; @@ -515,17 +513,17 @@ public void testDecreasingBatchSize() } } - private Iterator> processAndAssertRetainedPageSize(PageProcessor pageProcessor, Page inputPage) + private Iterator> processAndAssertRetainedPageSize(PageProcessor pageProcessor, SourcePage inputPage) { return processAndAssertRetainedPageSize(pageProcessor, new DriverYieldSignal(), inputPage); } - private Iterator> processAndAssertRetainedPageSize(PageProcessor pageProcessor, DriverYieldSignal yieldSignal, Page inputPage) + private Iterator> processAndAssertRetainedPageSize(PageProcessor pageProcessor, DriverYieldSignal yieldSignal, SourcePage inputPage) { return processAndAssertRetainedPageSize(pageProcessor, yieldSignal, newSimpleAggregatedMemoryContext(), inputPage); } - private Iterator> processAndAssertRetainedPageSize(PageProcessor pageProcessor, DriverYieldSignal yieldSignal, AggregatedMemoryContext memoryContext, Page inputPage) + private Iterator> processAndAssertRetainedPageSize(PageProcessor pageProcessor, DriverYieldSignal yieldSignal, AggregatedMemoryContext memoryContext, SourcePage inputPage) { Iterator> output = pageProcessor.process( SESSION, @@ -536,11 +534,6 @@ private Iterator> processAndAssertRetainedPageSize(PageProcessor return output; } 
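// The lazyWrapper helper removed below marks the end of the LazyBlock-based
// "must not be loaded" pattern in this file: the tests now express the same
// intent with TestingSourcePage, where a null channel fails on access and
// wasLoaded(...) records which channels were actually read. A minimal sketch,
// assuming only the TestingSourcePage API introduced earlier in this patch:

import io.trino.operator.TestingSourcePage;
import io.trino.spi.block.Block;

import static io.trino.block.BlockAssertions.createLongSequenceBlock;

final class LazinessAssertionSketch
{
    private LazinessAssertionSketch() {}

    static void checkOnlyChannelZeroIsRead()
    {
        // channel 1 is null, so any access to it fails the checkArgument in getBlock
        TestingSourcePage page = new TestingSourcePage(100, createLongSequenceBlock(0, 100), null);
        Block block = page.getBlock(0); // allowed: channel 0 exists and is marked loaded
        if (page.wasLoaded(1)) {
            throw new AssertionError("channel 1 should not have been accessed");
        }
    }
}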
- private static LazyBlock lazyWrapper(Block block) - { - return new LazyBlock(block.getPositionCount(), block::getLoadedBlock); - } - private static class InvocationCountPageProjection implements PageProjection { @@ -571,7 +564,7 @@ public InputChannels getInputChannels() } @Override - public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) + public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions) { setInvocationCount(getInvocationCount() + 1); return delegate.project(session, yieldSignal, page, selectedPositions); @@ -597,7 +590,7 @@ public YieldPageProjection(PageProjection delegate) } @Override - public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) + public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions) { return new YieldPageProjectionWork(session, yieldSignal, page, selectedPositions); } @@ -608,7 +601,7 @@ private class YieldPageProjectionWork private final DriverYieldSignal yieldSignal; private final Work work; - public YieldPageProjectionWork(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) + public YieldPageProjectionWork(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions) { this.yieldSignal = yieldSignal; this.work = delegate.project(session, yieldSignal, page, selectedPositions); @@ -653,7 +646,7 @@ public InputChannels getInputChannels() } @Override - public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, Page page, SelectedPositions selectedPositions) + public Work project(ConnectorSession session, DriverYieldSignal yieldSignal, SourcePage page, SelectedPositions selectedPositions) { return new CompletedWork<>(page.getBlock(0).getLoadedBlock()); } @@ -682,7 +675,7 @@ public InputChannels getInputChannels() } @Override - public SelectedPositions filter(ConnectorSession session, Page page) + public SelectedPositions filter(ConnectorSession session, SourcePage page) { return selectedPositions; } @@ -704,7 +697,7 @@ public InputChannels getInputChannels() } @Override - public SelectedPositions filter(ConnectorSession session, Page page) + public SelectedPositions filter(ConnectorSession session, SourcePage page) { return positionsRange(0, page.getPositionCount()); } @@ -726,7 +719,7 @@ public InputChannels getInputChannels() } @Override - public SelectedPositions filter(ConnectorSession session, Page page) + public SelectedPositions filter(ConnectorSession session, SourcePage page) { return positionsRange(0, 0); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDistinct.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDistinct.java index 95fc7573dd6e..5afb5b7e9e2e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDistinct.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDistinct.java @@ -23,6 +23,7 @@ import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.type.ArrayType; @@ -90,7 +91,7 @@ public List> arrayDistinct(BenchmarkData 
data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDotProduct.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDotProduct.java index df6716860fd5..3b4a6aeecc83 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDotProduct.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayDotProduct.java @@ -20,6 +20,7 @@ import io.trino.spi.Page; import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -70,7 +71,7 @@ public List> arrayIntersect(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java index 947bf47ac7b2..dbef90ba6aff 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayFilter.java @@ -25,6 +25,7 @@ import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.BoundSignature; import io.trino.spi.function.FunctionMetadata; import io.trino.spi.function.Signature; @@ -112,7 +113,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @Benchmark @@ -124,7 +125,7 @@ public List> benchmarkObject(RowBenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayIntersect.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayIntersect.java index 351318f6bda2..6b3ba0fa9e3d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayIntersect.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayIntersect.java @@ -21,6 +21,7 @@ import io.trino.spi.Page; import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -74,7 +75,7 @@ public List> arrayIntersect(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayJoin.java 
b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayJoin.java index 24d8b3df47f7..3603f5324fd3 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayJoin.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayJoin.java @@ -22,6 +22,7 @@ import io.trino.spi.Page; import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.sql.relational.CallExpression; import io.trino.sql.relational.RowExpression; @@ -69,7 +70,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySort.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySort.java index 353f539d713f..f70c775242ee 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySort.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySort.java @@ -24,6 +24,7 @@ import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlType; import io.trino.spi.type.ArrayType; @@ -87,7 +88,7 @@ public List> arraySort(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySubscript.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySubscript.java index d5181db6e43b..8b70806d6bc6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySubscript.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraySubscript.java @@ -24,6 +24,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -81,7 +82,7 @@ public List> arraySubscript(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayTransform.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayTransform.java index 2d13885e9c6c..68204191863f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayTransform.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArrayTransform.java @@ -22,6 +22,7 @@ import io.trino.spi.PageBuilder; import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -87,7 +88,7 @@ public Object 
benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraysOverlap.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraysOverlap.java index 251e7e6be6af..f9cc9e0f1df8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraysOverlap.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkArraysOverlap.java @@ -22,6 +22,7 @@ import io.trino.spi.Page; import io.trino.spi.block.ArrayBlockBuilder; import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -72,7 +73,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkEqualsConjunctsOperator.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkEqualsConjunctsOperator.java index 818316202e9e..1d5171a66d20 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkEqualsConjunctsOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkEqualsConjunctsOperator.java @@ -21,6 +21,7 @@ import io.trino.spi.PageBuilder; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; import io.trino.sql.relational.RowExpression; @@ -118,7 +119,7 @@ public List processPage(BenchmarkData data) SESSION, SIGNAL, newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.page); + SourcePage.create(data.page)); while (pageProcessorOutput.hasNext()) { pageProcessorOutput.next().ifPresent(output::add); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonFunctions.java index cfc0a3818efb..2e5831fe8e55 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonFunctions.java @@ -28,6 +28,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.security.ConnectorIdentity; import io.trino.spi.type.Type; import io.trino.spi.type.TypeId; @@ -106,7 +107,7 @@ public List> benchmarkJsonValueFunction(BenchmarkData data) FULL_CONNECTOR_SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @Benchmark @@ -118,7 +119,7 @@ public List> benchmarkJsonExtractScalarFunction(BenchmarkData dat FULL_CONNECTOR_SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @Benchmark @@ 
-130,7 +131,7 @@ public List> benchmarkJsonQueryFunction(BenchmarkData data) FULL_CONNECTOR_SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @Benchmark @@ -142,7 +143,7 @@ public List> benchmarkJsonExtractFunction(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonPathBinaryOperators.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonPathBinaryOperators.java index 90a4e4d01cf6..ca5f9296984d 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonPathBinaryOperators.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonPathBinaryOperators.java @@ -29,6 +29,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.security.ConnectorIdentity; import io.trino.spi.type.Decimals; import io.trino.spi.type.Type; @@ -92,7 +93,7 @@ public List> benchmarkJsonValueFunctionConstantTypes(BenchmarkDat FULL_CONNECTOR_SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPageConstantTypes())); + SourcePage.create(data.getPageConstantTypes()))); } @Benchmark @@ -104,7 +105,7 @@ public List> benchmarkJsonValueFunctionVaryingTypes(BenchmarkData FULL_CONNECTOR_SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPageVaryingTypes())); + SourcePage.create(data.getPageVaryingTypes()))); } @Benchmark @@ -116,7 +117,7 @@ public List> benchmarkJsonValueFunctionMultipleVaryingTypes(Bench FULL_CONNECTOR_SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPageMultipleVaryingTypes())); + SourcePage.create(data.getPageMultipleVaryingTypes()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java index 1274391cd103..d0704e1dafcb 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToArrayCast.java @@ -23,6 +23,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.Type; import io.trino.sql.relational.CallExpression; @@ -73,7 +74,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java index c2a958142ae9..4fbb98e837a4 100644 --- 
a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkJsonToMapCast.java @@ -23,6 +23,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.sql.relational.CallExpression; @@ -74,7 +75,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapConcat.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapConcat.java index 77a73488bc4d..9bf9e2d9ab17 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapConcat.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapConcat.java @@ -22,6 +22,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.MapType; import io.trino.sql.gen.ExpressionCompiler; import io.trino.sql.relational.CallExpression; @@ -76,7 +77,7 @@ public List> mapConcat(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapSubscript.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapSubscript.java index 61d03292241c..d9f2761dc8b6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapSubscript.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapSubscript.java @@ -23,6 +23,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.DictionaryBlock; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -81,7 +82,7 @@ public List> mapSubscript(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapToMapCast.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapToMapCast.java index fa4d147a99e5..99560df1efcb 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapToMapCast.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkMapToMapCast.java @@ -21,6 +21,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.MapType; import io.trino.sql.relational.CallExpression; import io.trino.sql.relational.RowExpression; @@ -70,7 +71,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - 
data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkRowToRowCast.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkRowToRowCast.java index 8718efcaaae6..f14ece4e72f3 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkRowToRowCast.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkRowToRowCast.java @@ -21,6 +21,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; @@ -68,7 +69,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformKey.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformKey.java index fa9307c8938e..c853826a1e3e 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformKey.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformKey.java @@ -23,6 +23,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -83,7 +84,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformValue.java b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformValue.java index 47c0435f61ca..8fd0110a5fd7 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformValue.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/BenchmarkTransformValue.java @@ -24,6 +24,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.MapBlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; import io.trino.sql.gen.ExpressionCompiler; @@ -85,7 +86,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.getPage())); + SourcePage.create(data.getPage()))); } @SuppressWarnings("FieldMayBeFinal") diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestPageProcessorCompiler.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestPageProcessorCompiler.java index 7969bd2d71e7..64b3b97b3d29 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestPageProcessorCompiler.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestPageProcessorCompiler.java @@ -25,6 +25,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.RunLengthEncodedBlock; +import 
io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.sql.gen.ExpressionCompiler; import io.trino.sql.relational.CallExpression; @@ -81,7 +82,7 @@ public void testSanityRLE() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - page)) + SourcePage.create(page))) .orElseThrow(() -> new AssertionError("page is not present")); assertThat(outputPage.getPositionCount()).isEqualTo(100); @@ -112,7 +113,7 @@ public void testSanityFilterOnDictionary() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - page)) + SourcePage.create(page))) .orElseThrow(() -> new AssertionError("page is not present")); assertThat(outputPage.getPositionCount()).isEqualTo(100); @@ -127,7 +128,7 @@ public void testSanityFilterOnDictionary() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - page)).orElseThrow(() -> new AssertionError("page is not present")); + SourcePage.create(page))).orElseThrow(() -> new AssertionError("page is not present")); assertThat(outputPage2.getPositionCount()).isEqualTo(100); assertThat(outputPage2.getBlock(0) instanceof DictionaryBlock).isTrue(); @@ -150,7 +151,7 @@ public void testSanityFilterOnRLE() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - page)) + SourcePage.create(page))) .orElseThrow(() -> new AssertionError("page is not present")); assertThat(outputPage.getPositionCount()).isEqualTo(100); @@ -171,7 +172,7 @@ public void testSanityColumnarDictionary() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - page)) + SourcePage.create(page))) .orElseThrow(() -> new AssertionError("page is not present")); assertThat(outputPage.getPositionCount()).isEqualTo(100); @@ -201,7 +202,7 @@ public void testNonDeterministicProject() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - page)) + SourcePage.create(page))) .orElseThrow(() -> new AssertionError("page is not present")); assertThat(outputPage.getBlock(0) instanceof DictionaryBlock).isFalse(); } diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/BenchmarkCastTimestampToVarchar.java b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/BenchmarkCastTimestampToVarchar.java index c481c6f7ddf0..c94ce94b7ad6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/BenchmarkCastTimestampToVarchar.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/timestamp/BenchmarkCastTimestampToVarchar.java @@ -23,6 +23,7 @@ import io.trino.operator.scalar.timetz.TimeWithTimeZoneToTimeWithTimeZoneCast; import io.trino.spi.Page; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.LongTimeWithTimeZone; import io.trino.spi.type.LongTimestamp; import io.trino.spi.type.LongTimestampWithTimeZone; @@ -73,7 +74,7 @@ public class BenchmarkCastTimestampToVarchar @Benchmark public List> benchmarkCastToVarchar(BenchmarkData data) { - return ImmutableList.copyOf(data.pageProcessor.process(SESSION, data.yieldSignal, data.localMemoryContext, data.page)); + return ImmutableList.copyOf(data.pageProcessor.process(SESSION, data.yieldSignal, data.localMemoryContext, 
SourcePage.create(data.page))); } @State(Scope.Thread) diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkAndColumnarFilterTpchData.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkAndColumnarFilterTpchData.java index 4dafeade620f..0e36d9595415 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkAndColumnarFilterTpchData.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkAndColumnarFilterTpchData.java @@ -22,6 +22,7 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.SourcePage; import io.trino.sql.relational.CallExpression; import io.trino.sql.relational.InputReferenceExpression; import io.trino.sql.relational.RowExpression; @@ -108,7 +109,7 @@ public List> compiled() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - inputPage)); + SourcePage.create(inputPage))); } private static Page createInputPage() diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkColumnarFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkColumnarFilter.java index 778f3f8001f0..614267351593 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkColumnarFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkColumnarFilter.java @@ -26,6 +26,7 @@ import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.ShortArrayBlock; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.OperatorType; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; @@ -180,7 +181,7 @@ public long evaluateFilter() new DriverYieldSignal(), context, new PageProcessorMetrics(), - inputPage); + SourcePage.create(inputPage)); if (workProcessor.process() && !workProcessor.isFinished()) { outputRows += workProcessor.getResult().getPositionCount(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java index dfebef659217..33be36fec48f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkDynamicPageFilter.java @@ -22,6 +22,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.LazyBlock; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SourcePage; import io.trino.spi.connector.TestingColumnHandle; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; @@ -173,7 +174,7 @@ public double filterPages() long rowsProcessed = 0; long rowsFiltered = 0; for (Page page : inputData) { - FilterEvaluator.SelectionResult result = filterEvaluator.evaluate(FULL_CONNECTOR_SESSION, positionsRange(0, page.getPositionCount()), page); + FilterEvaluator.SelectionResult result = filterEvaluator.evaluate(FULL_CONNECTOR_SESSION, positionsRange(0, page.getPositionCount()), SourcePage.create(page)); SelectedPositions selectedPositions = result.selectedPositions(); int selectedPositionCount = selectedPositions.size(); rowsProcessed += page.getPositionCount(); diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkInCodeGenerator.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkInCodeGenerator.java index d56a19b6b7b6..1d2392011c89 100644 --- 
a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkInCodeGenerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkInCodeGenerator.java @@ -25,6 +25,7 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.OperatorType; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.Type; @@ -209,7 +210,7 @@ public List> benchmark(BenchmarkData data) SESSION, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - data.inputPage)); + SourcePage.create(data.inputPage))); } @Test diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java index 82c1e44c9deb..517787fc2ffe 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor.java @@ -23,6 +23,7 @@ import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.connector.SourcePage; import io.trino.sql.relational.CallExpression; import io.trino.sql.relational.RowExpression; import io.trino.sql.relational.SpecialForm; @@ -106,7 +107,7 @@ public List> compiled() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - inputPage)); + SourcePage.create(inputPage))); } public static void main(String[] args) diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java index 0156829a411a..6e9793294823 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/BenchmarkPageProcessor2.java @@ -27,6 +27,7 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.connector.RecordSet; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.OperatorType; import io.trino.spi.type.Type; import io.trino.sql.PlannerContext; @@ -153,7 +154,7 @@ public List> columnOriented() null, new DriverYieldSignal(), newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()), - inputPage)); + SourcePage.create(inputPage))); } private RowExpression getFilter(Type type) diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestColumnarFilters.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestColumnarFilters.java index 336bca5a9989..df64f35cfd6d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestColumnarFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestColumnarFilters.java @@ -38,6 +38,7 @@ import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.LiteralParameters; import io.trino.spi.function.ScalarFunction; import io.trino.spi.function.SqlNullable; @@ -670,7 +671,7 @@ private static List processFilter(List inputPages, boolean columnarE new DriverYieldSignal(), context, new PageProcessorMetrics(), - inputPage); + SourcePage.create(inputPage)); if (workProcessor.process() && !workProcessor.isFinished()) { 
outputPagesBuilder.add(workProcessor.getResult()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestDictionaryAwareColumnarFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestDictionaryAwareColumnarFilter.java index 24ea80dec7ef..d01564b6efff 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestDictionaryAwareColumnarFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestDictionaryAwareColumnarFilter.java @@ -16,12 +16,12 @@ import com.google.common.collect.ImmutableList; import io.trino.FullConnectorSession; import io.trino.operator.project.InputChannels; -import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.DictionaryBlock; import io.trino.spi.block.LongArrayBlock; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import io.trino.spi.security.ConnectorIdentity; import io.trino.sql.gen.columnar.ColumnarFilter; import io.trino.sql.gen.columnar.DictionaryAwareColumnarFilter; @@ -51,13 +51,13 @@ public void testGetInputChannels() { DictionaryAwareColumnarFilter filter = new DictionaryAwareColumnarFilter(new ColumnarFilter() { @Override - public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page loadedPage) + public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, SourcePage loadedPage) { throw new UnsupportedOperationException(); } @Override - public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, Page loadedPage) + public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, SourcePage loadedPage) { throw new UnsupportedOperationException(); } @@ -190,10 +190,10 @@ private static void testFilter(DictionaryAwareColumnarFilter filter, Block block int[] outputPositions = new int[block.getPositionCount()]; int outputPositionsCount; if (usePositionsList) { - outputPositionsCount = filter.filterPositionsList(FULL_CONNECTOR_SESSION, outputPositions, toPositionsList(0, block.getPositionCount()), 0, block.getPositionCount(), new Page(block)); + outputPositionsCount = filter.filterPositionsList(FULL_CONNECTOR_SESSION, outputPositions, toPositionsList(0, block.getPositionCount()), 0, block.getPositionCount(), SourcePage.create(block)); } else { - outputPositionsCount = filter.filterPositionsRange(FULL_CONNECTOR_SESSION, outputPositions, 0, block.getPositionCount(), new Page(block)); + outputPositionsCount = filter.filterPositionsRange(FULL_CONNECTOR_SESSION, outputPositions, 0, block.getPositionCount(), SourcePage.create(block)); } IntSet actualSelectedPositions = new IntArraySet(Arrays.copyOfRange(outputPositions, 0, outputPositionsCount)); IntSet expectedSelectedPositions = new IntArraySet(block.getPositionCount()); @@ -257,7 +257,7 @@ public InputChannels getInputChannels() } @Override - public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, Page loadedPage) + public int filterPositionsRange(ConnectorSession session, int[] outputPositions, int offset, int size, SourcePage loadedPage) { assertThat(loadedPage.getChannelCount()).isEqualTo(1); Block block = loadedPage.getBlock(0); @@ -283,7 +283,7 @@ public int filterPositionsRange(ConnectorSession session, int[] outputPositions, } @Override - public int filterPositionsList(ConnectorSession session, int[] 
outputPositions, int[] activePositions, int offset, int size, Page loadedPage) + public int filterPositionsList(ConnectorSession session, int[] outputPositions, int[] activePositions, int offset, int size, SourcePage loadedPage) { assertThat(loadedPage.getChannelCount()).isEqualTo(1); Block block = loadedPage.getBlock(0); diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java index f164723e9329..a48f84cfb67c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestDynamicPageFilter.java @@ -17,13 +17,14 @@ import com.google.common.collect.ImmutableMap; import io.trino.FullConnectorSession; import io.trino.Session; +import io.trino.operator.TestingSourcePage; import io.trino.operator.project.SelectedPositions; import io.trino.spi.Page; import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; import io.trino.spi.block.RowBlock; import io.trino.spi.block.SqlRow; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SourcePage; import io.trino.spi.connector.TestingColumnHandle; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; @@ -83,9 +84,9 @@ public class TestDynamicPageFilter @Test public void testAllPageFilter() { - Page page = new Page( + SourcePage page = SourcePage.create(new Page( createLongsBlock(1L, 2L, null, 5L, null), - createLongsBlock(null, 102L, 135L, null, 3L)); + createLongsBlock(null, 102L, 135L, null, 3L))); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator(TupleDomain.all(), ImmutableMap.of()); verifySelectedPositions(filterPage(page, filterEvaluator), page.getPositionCount()); } @@ -93,9 +94,9 @@ public void testAllPageFilter() @Test public void testNonePageFilter() { - Page page = new Page( + SourcePage page = SourcePage.create(new Page( createLongsBlock(1L, 2L, null, 5L, null), - createLongsBlock(null, 102L, 135L, null, 3L)); + createLongsBlock(null, 102L, 135L, null, 3L))); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator(TupleDomain.none(), ImmutableMap.of()); verifySelectedPositions(filterPage(page, filterEvaluator), 0); } @@ -107,9 +108,9 @@ public void testStringFilter() FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of(column, onlyNull(VARCHAR))), ImmutableMap.of(column, 0)); - Page page = new Page( + SourcePage page = SourcePage.create(new Page( createStringsBlock("ab", "bc", null, "cd", null), - createStringsBlock(null, "de", "ef", null, "fg")); + createStringsBlock(null, "de", "ef", null, "fg"))); verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {2, 4}); filterEvaluator = createDynamicFilterEvaluator( @@ -134,9 +135,9 @@ public void testLongBlockFilter() FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of(column, onlyNull(INTEGER))), ImmutableMap.of(column, 0)); - Page page = new Page( + SourcePage page = SourcePage.create(new Page( createTypedLongsBlock(INTEGER, 1L, 2L, null, 5L, null), - createTypedLongsBlock(INTEGER, null, 102L, 135L, null, 3L)); + createTypedLongsBlock(INTEGER, null, 102L, 135L, null, 3L))); verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {2, 4}); filterEvaluator = createDynamicFilterEvaluator( @@ -168,7 +169,7 @@ public void testStructuralTypeFilter() column, multipleValues(rowType, ImmutableList.of(new 
SqlRow(0, filterBlocks), new SqlRow(1, filterBlocks))))), ImmutableMap.of(column, 0)); - Page page = new Page(rowBlock); + SourcePage page = SourcePage.create(new Page(rowBlock)); // Columnar filter evaluation does not support IN on structural types, therefore this is a no-op filter // This should change to filter rows when the above is resolved verifySelectedPositions(filterPage(page, filterEvaluator), page.getPositionCount()); @@ -184,13 +185,13 @@ public void testSelectivePageFilter() ImmutableMap.of(columnB, 1)); // page without null - Page page = new Page(createLongSequenceBlock(0, 101), createLongSequenceBlock(100, 201)); + SourcePage page = SourcePage.create(new Page(createLongSequenceBlock(0, 101), createLongSequenceBlock(100, 201))); verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {35, 85}); // page with null - page = new Page( + page = SourcePage.create(new Page( createLongsBlock(1L, 2L, null, 5L, null), - createLongsBlock(null, 102L, 135L, null, 3L)); + createLongsBlock(null, 102L, 135L, null, 3L))); verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {2}); } @@ -205,15 +206,15 @@ public void testNonSelectivePageFilter() ImmutableMap.of(columnB, 1)); // page without null - Page page = new Page( + SourcePage page = SourcePage.create(new Page( createLongSequenceBlock(0, 101), - createLongSequenceBlock(100, 201)); + createLongSequenceBlock(100, 201))); verifySelectedPositions(filterPage(page, filterEvaluator), 101); // page with null - page = new Page( + page = SourcePage.create(new Page( createLongsBlock(1L, 2L, null, 5L, null), - createLongsBlock(null, 102L, 135L, null, 3L)); + createLongsBlock(null, 102L, 135L, null, 3L))); verifySelectedPositions(filterPage(page, filterEvaluator), new int[] {1, 2, 4}); } @@ -232,9 +233,9 @@ public void testPageFilterWithPositionsList() // block with nulls is second column (positions list instead of range) verifySelectedPositions( filterPage( - new Page( + SourcePage.create(new Page( createLongsBlock(3, 1, 5), - createLongsBlock(3L, null, 1L)), + createLongsBlock(3L, null, 1L))), filterEvaluator), new int[] {0, 1}); } @@ -253,7 +254,7 @@ public void testPageFilterWithRealNaN() verifySelectedPositions( filterPage( - new Page(createBlockOfReals(42.0f, Float.NaN, 32.0f, null, 53.1f)), + SourcePage.create(new Page(createBlockOfReals(42.0f, Float.NaN, 32.0f, null, 53.1f))), filterEvaluator), new int[] {2}); } @@ -274,10 +275,10 @@ public void testDynamicFilterUpdates() ImmutableMap.of(symbolA, columnA, symbolB, columnB, symbolC, columnC), ImmutableMap.of(symbolA, 0, symbolB, 1, symbolC, 2), 1); - Page page = new Page( + SourcePage page = SourcePage.create(new Page( createLongSequenceBlock(0, 101), createLongSequenceBlock(100, 201), - createLongSequenceBlock(200, 301)); + createLongSequenceBlock(200, 301))); FilterEvaluator filterEvaluator = pageFilter.createDynamicPageFilterEvaluator(COMPILER, dynamicFilter).get(); verifySelectedPositions(filterPage(page, filterEvaluator), 101); @@ -317,10 +318,10 @@ public void testDifferentDynamicFilterInstances() ImmutableMap.of(symbolA, columnA, symbolB, columnB, symbolC, columnC), ImmutableMap.of(symbolA, 0, symbolB, 1, symbolC, 2), 1); - Page page = new Page( + SourcePage page = SourcePage.create(new Page( createLongSequenceBlock(0, 101), createLongSequenceBlock(100, 201), - createLongSequenceBlock(200, 301)); + createLongSequenceBlock(200, 301))); TestingDynamicFilter dynamicFilter = new TestingDynamicFilter(1); dynamicFilter.update(TupleDomain.withColumnDomains( @@ -350,7 
+351,7 @@ public void testDifferentDynamicFilterInstances() public void testIneffectiveFilter() { ColumnHandle column = new TestingColumnHandle("column"); - List inputPages = generateInputPages(3, 1, 1024); + List inputPages = generateInputPages(3, 1, 1024); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of(column, getRangePredicate(100, 5000))), ImmutableMap.of(column, 0), @@ -360,20 +361,20 @@ public void testIneffectiveFilter() // EffectiveFilterProfiler should turn off row filtering assertThat(filterPage(inputPages.get(2), filterEvaluator).size()).isEqualTo(1024); - assertThat(inputPages.get(2).getBlock(0)).isInstanceOf(LazyBlock.class); + assertThat(inputPages.get(2).wasLoaded(0)).isFalse(); } @Test public void testEffectiveFilter() { ColumnHandle column = new TestingColumnHandle("column"); - List inputPages = generateInputPages(5, 1, 1024); + List inputPages = generateInputPages(5, 1, 1024); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of(column, singleValue(BIGINT, 13L))), ImmutableMap.of(column, 0), 0.1); // EffectiveFilterProfiler should not turn off row filtering - for (Page inputPage : inputPages) { + for (TestingSourcePage inputPage : inputPages) { assertThat(filterPage(inputPage, filterEvaluator).size()).isEqualTo(1); } } @@ -383,7 +384,7 @@ public void testIneffectiveFilterFirst() { ColumnHandle columnA = new TestingColumnHandle("columnA"); ColumnHandle columnB = new TestingColumnHandle("columnB"); - List inputPages = generateInputPages(3, 2, 1024); + List inputPages = generateInputPages(3, 2, 1024); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( columnA, getRangePredicate(100, 1024), @@ -395,7 +396,7 @@ columnB, singleValue(BIGINT, 13L))), // EffectiveFilterProfiler should turn off row filtering only for the first column filter assertThat(filterPage(inputPages.get(2), filterEvaluator).size()).isEqualTo(1); - assertThat(inputPages.get(2).getBlock(0)).isInstanceOf(LazyBlock.class); + assertThat(inputPages.get(2).wasLoaded(0)).isFalse(); } @Test @@ -403,7 +404,7 @@ public void testIneffectiveFilterLast() { ColumnHandle columnA = new TestingColumnHandle("columnA"); ColumnHandle columnB = new TestingColumnHandle("columnB"); - List inputPages = generateInputPages(4, 2, 1024); + List inputPages = generateInputPages(4, 2, 1024); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( columnA, getRangePredicate(50, 950), @@ -416,7 +417,7 @@ columnB, getRangePredicate(100, 1024))), // EffectiveFilterProfiler should turn off row filtering only for the last column filter assertThat(filterPage(inputPages.get(3), filterEvaluator).size()).isEqualTo(900); - assertThat(inputPages.get(3).getBlock(1)).isInstanceOf(LazyBlock.class); + assertThat(inputPages.get(3).wasLoaded(1)).isFalse(); } @Test @@ -425,18 +426,18 @@ public void testMultipleColumnsShortCircuit() ColumnHandle columnA = new TestingColumnHandle("columnA"); ColumnHandle columnB = new TestingColumnHandle("columnB"); ColumnHandle columnC = new TestingColumnHandle("columnC"); - List inputPages = generateInputPages(5, 3, 100); + List inputPages = generateInputPages(5, 3, 100); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( columnA, multipleValues(BIGINT, ImmutableList.of(-10L, 5L, 15L, 35L, 50L, 85L, 95L, 105L)), columnB, 
singleValue(BIGINT, 0L), columnC, getRangePredicate(150, 250))), ImmutableMap.of(columnA, 0, columnB, 1, columnC, 2)); - for (Page inputPage : inputPages) { + for (TestingSourcePage inputPage : inputPages) { assertThat(filterPage(inputPage, filterEvaluator).size()).isEqualTo(0); - assertThat(inputPage.getBlock(0)).isNotInstanceOf(LazyBlock.class); - assertThat(inputPage.getBlock(1)).isNotInstanceOf(LazyBlock.class); - assertThat(inputPage.getBlock(2)).isInstanceOf(LazyBlock.class); + assertThat(inputPage.wasLoaded(0)).isTrue(); + assertThat(inputPage.wasLoaded(1)).isTrue(); + assertThat(inputPage.wasLoaded(2)).isFalse(); } } @@ -445,23 +446,23 @@ public void testDynamicFilterOnSubsetOfColumns() { ColumnHandle columnB = new TestingColumnHandle("columnB"); ColumnHandle columnD = new TestingColumnHandle("columnD"); - List inputPages = generateInputPages(5, 5, 1024); + List inputPages = generateInputPages(5, 5, 1024); FilterEvaluator filterEvaluator = createDynamicFilterEvaluator( TupleDomain.withColumnDomains(ImmutableMap.of( columnB, multipleValues(BIGINT, ImmutableList.of(-10L, 5L, 15L, 35L, 50L, 85L, 95L, 105L)), columnD, getRangePredicate(-50, 90))), ImmutableMap.of(columnB, 1, columnD, 3)); - for (Page inputPage : inputPages) { + for (TestingSourcePage inputPage : inputPages) { assertThat(filterPage(inputPage, filterEvaluator).size()).isEqualTo(5); - assertThat(inputPage.getBlock(0)).isInstanceOf(LazyBlock.class); - assertThat(inputPage.getBlock(1)).isNotInstanceOf(LazyBlock.class); - assertThat(inputPage.getBlock(2)).isInstanceOf(LazyBlock.class); - assertThat(inputPage.getBlock(3)).isNotInstanceOf(LazyBlock.class); - assertThat(inputPage.getBlock(4)).isInstanceOf(LazyBlock.class); + assertThat(inputPage.wasLoaded(0)).isFalse(); + assertThat(inputPage.wasLoaded(1)).isTrue(); + assertThat(inputPage.wasLoaded(2)).isFalse(); + assertThat(inputPage.wasLoaded(3)).isTrue(); + assertThat(inputPage.wasLoaded(4)).isFalse(); } } - private static SelectedPositions filterPage(Page page, FilterEvaluator filterEvaluator) + private static SelectedPositions filterPage(SourcePage page, FilterEvaluator filterEvaluator) { FilterEvaluator.SelectionResult result = filterEvaluator.evaluate(FULL_CONNECTOR_SESSION, positionsRange(0, page.getPositionCount()), page); return result.selectedPositions(); @@ -482,11 +483,11 @@ private static void verifySelectedPositions(SelectedPositions selectedPositions, assertThat(selectedPositions.size()).isEqualTo(rangeSize); } - private static List generateInputPages(int pages, int blocks, int positionsPerBlock) + private static List generateInputPages(int pages, int blocks, int positionsPerBlock) { return IntStream.range(0, pages) - .mapToObj(i -> new Page(IntStream.range(0, blocks) - .mapToObj(_ -> new LazyBlock(positionsPerBlock, () -> createLongSequenceBlock(0, positionsPerBlock))) + .mapToObj(_ -> new TestingSourcePage(positionsPerBlock, IntStream.range(0, blocks) + .mapToObj(_ -> createLongSequenceBlock(0, positionsPerBlock)) .toArray(Block[]::new))) .collect(toImmutableList()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java b/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java index 9651ed0ed24c..0dd0cbdc19cf 100644 --- a/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java +++ b/core/trino-main/src/test/java/io/trino/sql/gen/TestPageFunctionCompiler.java @@ -22,6 +22,7 @@ import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import 
io.trino.spi.connector.SourcePage; import io.trino.sql.relational.CallExpression; import org.junit.jupiter.api.Test; @@ -80,7 +81,7 @@ public void testGeneratedClassName() String classSuffix = stageId + "_" + planNodeId; Supplier projectionSupplier = functionCompiler.compileProjection(ADD_10_EXPRESSION, Optional.of(classSuffix)); PageProjection projection = projectionSupplier.get(); - Work work = projection.project(SESSION, new DriverYieldSignal(), createLongBlockPage(0), SelectedPositions.positionsRange(0, 1)); + Work work = projection.project(SESSION, new DriverYieldSignal(), SourcePage.create(createLongBlockPage(0)), SelectedPositions.positionsRange(0, 1)); // class name should look like PageProjectionOutput_20170707_223500_67496_zguwn_2_7_XX assertThat(work.getClass().getSimpleName().startsWith("PageProjectionWork_" + stageId.replace('.', '_') + "_" + planNodeId)).isTrue(); } @@ -103,7 +104,7 @@ public void testCache() private Block project(PageProjection projection, Page page, SelectedPositions selectedPositions) { - Work work = projection.project(SESSION, new DriverYieldSignal(), page, selectedPositions); + Work work = projection.project(SESSION, new DriverYieldSignal(), SourcePage.create(page), selectedPositions); assertThat(work.process()).isTrue(); return work.getResult(); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSource.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSource.java index 9fd9f898e3c5..7a5ee82a9d51 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSource.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorPageSource.java @@ -55,8 +55,26 @@ default OptionalLong getCompletedPositions() /** * Gets the next page of data. This method is allowed to return null. + * + * @deprecated Use {@link #getNextSourcePage()} instead + */ + @Deprecated(forRemoval = true) + default Page getNextPage() + { + throw new UnsupportedOperationException(); + } + + /** + * Gets the next page of data. This method is allowed to return null. */ - Page getNextPage(); + default SourcePage getNextSourcePage() + { + Page nextPage = getNextPage(); + if (nextPage == null) { + return null; + } + return SourcePage.create(nextPage); + } /** * Get the total memory that needs to be reserved in the memory pool. 
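[Editor's note: the new default methods give existing connectors a compatibility bridge; new code should implement getNextSourcePage() directly and never override the deprecated getNextPage(). Below is a minimal sketch of that calling convention, illustrative only: the class name InMemoryPageSource and the in-memory page list are hypothetical, modeled loosely on the existing FixedPageSource, and only the required ConnectorPageSource methods are shown.]

    import io.trino.spi.Page;
    import io.trino.spi.connector.ConnectorPageSource;
    import io.trino.spi.connector.SourcePage;

    import java.util.Iterator;
    import java.util.List;

    // Hypothetical page source written against the new SPI: it implements
    // getNextSourcePage() directly, so the deprecated getNextPage() default
    // (which now throws UnsupportedOperationException) is never invoked.
    final class InMemoryPageSource
            implements ConnectorPageSource
    {
        private final Iterator<Page> pages;
        private long completedBytes;
        private boolean closed;

        InMemoryPageSource(List<Page> pages)
        {
            this.pages = pages.iterator();
        }

        @Override
        public long getCompletedBytes()
        {
            return completedBytes;
        }

        @Override
        public long getReadTimeNanos()
        {
            return 0;
        }

        @Override
        public boolean isFinished()
        {
            return closed || !pages.hasNext();
        }

        @Override
        public SourcePage getNextSourcePage()
        {
            if (isFinished()) {
                return null;
            }
            Page page = pages.next();
            completedBytes += page.getSizeInBytes();
            // Wrap an already materialized page. A reader that can defer I/O
            // would instead return its own SourcePage implementation whose
            // getBlock(int) loads channels on demand, as the OrcSourcePage
            // introduced further down in this change does.
            return SourcePage.create(page);
        }

        @Override
        public long getMemoryUsage()
        {
            return 0;
        }

        @Override
        public void close()
        {
            closed = true;
        }
    }
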
diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/EmptyPageSource.java b/core/trino-spi/src/main/java/io/trino/spi/connector/EmptyPageSource.java index edf0e69ce4d2..0ff0106f3120 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/EmptyPageSource.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/EmptyPageSource.java @@ -37,11 +37,18 @@ public boolean isFinished() } @Override + @SuppressWarnings("removal") public Page getNextPage() { return null; } + @Override + public SourcePage getNextSourcePage() + { + return null; + } + @Override public long getMemoryUsage() { diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/FixedPageSource.java b/core/trino-spi/src/main/java/io/trino/spi/connector/FixedPageSource.java index 39536c57ea75..50a3d4057546 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/FixedPageSource.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/FixedPageSource.java @@ -74,6 +74,7 @@ public boolean isFinished() } @Override + @SuppressWarnings("removal") public Page getNextPage() { if (isFinished()) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/FixedSourcePage.java b/core/trino-spi/src/main/java/io/trino/spi/connector/FixedSourcePage.java new file mode 100644 index 000000000000..67460269cdc1 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/FixedSourcePage.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; + +import java.util.function.ObjLongConsumer; + +import static java.util.Objects.requireNonNull; + +final class FixedSourcePage + implements SourcePage +{ + private Page page; + + FixedSourcePage(Page page) + { + requireNonNull(page, "page is null"); + this.page = page; + } + + @Override + public int getPositionCount() + { + return page.getPositionCount(); + } + + @Override + public long getSizeInBytes() + { + return page.getSizeInBytes(); + } + + @Override + public long getRetainedSizeInBytes() + { + return page.getRetainedSizeInBytes(); + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer<Object> consumer) + { + for (int i = 0; i < page.getChannelCount(); i++) { + page.getBlock(i).retainedBytesForEachPart(consumer); + } + } + + @Override + public int getChannelCount() + { + return page.getChannelCount(); + } + + @Override + public Block getBlock(int channel) + { + return page.getBlock(channel); + } + + @Override + public Page getPage() + { + return page; + } + + @Override + public Page getColumns(int[] channels) + { + return page.getColumns(channels); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + page = page.getPositions(positions, offset, size); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/PositionCountSourcePage.java b/core/trino-spi/src/main/java/io/trino/spi/connector/PositionCountSourcePage.java new file mode 100644 index 000000000000..cbb08e2c4363 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/PositionCountSourcePage.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.connector; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; + +import java.util.Objects; +import java.util.function.ObjLongConsumer; + +final class PositionCountSourcePage + implements SourcePage +{ + private int positionCount; + + PositionCountSourcePage(int positionCount) + { + if (positionCount < 0) { + throw new IllegalArgumentException("positionCount is negative"); + } + this.positionCount = positionCount; + } + + @Override + public int getPositionCount() + { + return positionCount; + } + + @Override + public long getSizeInBytes() + { + return 0; + } + + @Override + public long getRetainedSizeInBytes() + { + return 0; + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer<Object> consumer) {} + + @Override + public int getChannelCount() + { + return 0; + } + + @Override + public Block getBlock(int channel) + { + throw new IllegalArgumentException("Page has no channels"); + } + + @Override + public Page getPage() + { + return new Page(positionCount); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + if (size > positionCount) { + throw new IllegalArgumentException("size is larger than positionCount"); + } + + for (int i = 0; i < size; i++) { + Objects.checkIndex(offset + i, positionCount); + } + positionCount = size; + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/RecordPageSource.java b/core/trino-spi/src/main/java/io/trino/spi/connector/RecordPageSource.java index 611a08fdda5b..fd30bea015c5 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/RecordPageSource.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/RecordPageSource.java @@ -81,6 +81,7 @@ public boolean isFinished() } @Override + @SuppressWarnings("removal") public Page getNextPage() { if (!closed) { diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/SourcePage.java b/core/trino-spi/src/main/java/io/trino/spi/connector/SourcePage.java new file mode 100644 index 000000000000..37316e2331d5 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/SourcePage.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.connector; + +import io.trino.spi.Page; +import io.trino.spi.block.Block; + +import java.util.function.ObjLongConsumer; + +/** + * A page of data from a connector. + * <p>

+ * A page has a fixed number of positions and a fixed set of channels. + * <p>

+ * This interface is not thread safe. + */ +public interface SourcePage +{ + /** + * Creates a new SourcePage with the specified position count and no channels. + */ + static SourcePage create(int positionCount) + { + return new PositionCountSourcePage(positionCount); + } + + /** + * Creates a new SourcePage from the specified block. + */ + static SourcePage create(Block block) + { + return new FixedSourcePage(new Page(block.getPositionCount(), block)); + } + + /** + * Creates a new SourcePage from the specified page. + */ + static SourcePage create(Page page) + { + return new FixedSourcePage(page); + } + + /** + * Gets the number of positions in the page. + */ + int getPositionCount(); + + /** + * Gets the current loaded size of the page in bytes. + */ + long getSizeInBytes(); + + /** + * Gets the current retained size of the page in bytes. + */ + long getRetainedSizeInBytes(); + + /** + * Calls retainedBytesForEachPart on all loaded blocks. + */ + void retainedBytesForEachPart(ObjLongConsumer<Object> consumer); + + /** + * Gets the number of channels in the page. + */ + int getChannelCount(); + + /** + * Gets the block for the specified channel. + */ + Block getBlock(int channel); + + /** + * Gets all data as a single {@link Page}, loading any channels that have not been read yet. + */ + Page getPage(); + + /** + * Gets a projection of the page containing only the specified channels. + */ + default Page getColumns(int[] channels) + { + Block[] blocks = new Block[channels.length]; + for (int i = 0; i < channels.length; i++) { + blocks[i] = getBlock(channels[i]); + } + return new Page(getPositionCount(), blocks); + } + + /** + * Modifies this page to mask data internally. After this method is called + * this page and all returned blocks and pages will have the specified size. + * <p>

+ * This method should be preferred to {@link Block#getPositions(int[], int, int)} + * and {@link Page#getPositions(int[], int, int)} where possible, as this allows + * the underlying reader to filter positions on subsequent reads. + */ + void selectPositions(int[] positions, int offset, int size); +} diff --git a/lib/trino-orc/src/main/java/io/trino/orc/OrcReader.java b/lib/trino-orc/src/main/java/io/trino/orc/OrcReader.java index 5fcda7220bc2..2e8608e8698c 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/OrcReader.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/OrcReader.java @@ -34,7 +34,7 @@ import io.trino.orc.metadata.PostScript.HiveWriterVersion; import io.trino.orc.stream.OrcChunkLoader; import io.trino.orc.stream.OrcInputStream; -import io.trino.spi.Page; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import org.joda.time.DateTimeZone; @@ -252,6 +252,7 @@ public CompressionKind getCompressionKind() public OrcRecordReader createRecordReader( List readColumns, List readTypes, + boolean appendRowNumberColumn, OrcPredicate predicate, DateTimeZone legacyFileTimeZone, AggregatedMemoryContext memoryUsage, @@ -263,6 +264,7 @@ public OrcRecordReader createRecordReader( readColumns, readTypes, Collections.nCopies(readColumns.size(), fullyProjectedLayout()), + appendRowNumberColumn, predicate, 0, orcDataSource.getEstimatedSize(), @@ -277,6 +279,7 @@ public OrcRecordReader createRecordReader( List readColumns, List readTypes, List readLayouts, + boolean appendRowNumberColumn, OrcPredicate predicate, long offset, long length, @@ -291,6 +294,7 @@ public OrcRecordReader createRecordReader( requireNonNull(readColumns, "readColumns is null"), requireNonNull(readTypes, "readTypes is null"), requireNonNull(readLayouts, "readLayouts is null"), + appendRowNumberColumn, requireNonNull(predicate, "predicate is null"), footer.getNumberOfRows(), footer.getStripes(), @@ -416,6 +420,7 @@ static void validateFile( try (OrcRecordReader orcRecordReader = orcReader.createRecordReader( orcReader.getRootColumn().getNestedColumns(), readTypes, + false, OrcPredicate.TRUE, UTC, newSimpleAggregatedMemoryContext(), @@ -424,9 +429,9 @@ static void validateFile( throwIfUnchecked(exception); return new RuntimeException(exception); })) { - for (Page page = orcRecordReader.nextPage(); page != null; page = orcRecordReader.nextPage()) { + for (SourcePage page = orcRecordReader.nextPage(); page != null; page = orcRecordReader.nextPage()) { // fully load the page - page.getLoadedPage(); + page.getPage(); } } } diff --git a/lib/trino-orc/src/main/java/io/trino/orc/OrcRecordReader.java b/lib/trino-orc/src/main/java/io/trino/orc/OrcRecordReader.java index ded8bfbc4598..5f9a4010b002 100644 --- a/lib/trino-orc/src/main/java/io/trino/orc/OrcRecordReader.java +++ b/lib/trino-orc/src/main/java/io/trino/orc/OrcRecordReader.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.io.Closer; +import com.google.errorprone.annotations.CheckReturnValue; import com.google.errorprone.annotations.FormatMethod; import io.airlift.slice.Slice; import io.airlift.units.DataSize; @@ -40,7 +41,10 @@ import io.trino.orc.stream.InputStreamSources; import io.trino.spi.Page; import io.trino.spi.block.Block; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; +import jakarta.annotation.Nullable; import org.joda.time.DateTimeZone; import java.io.Closeable; @@ -54,9 +58,11 @@ import 
java.util.Optional; import java.util.OptionalInt; import java.util.function.Function; +import java.util.function.ObjLongConsumer; import java.util.function.Predicate; import java.util.stream.Collectors; +import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.instanceSize; import static io.trino.orc.OrcDataSourceUtils.mergeAdjacentDiskRanges; @@ -70,6 +76,7 @@ import static java.lang.Math.min; import static java.lang.Math.toIntExact; import static java.util.Comparator.comparingLong; +import static java.util.Objects.checkIndex; import static java.util.Objects.requireNonNull; public class OrcRecordReader @@ -77,7 +84,9 @@ public class OrcRecordReader { private static final int INSTANCE_SIZE = instanceSize(OrcRecordReader.class); + private final List columns; private final OrcDataSource orcDataSource; + private final boolean appendRowNumberColumn; private final ColumnReader[] columnReaders; private final long[] currentBytesPerCell; @@ -129,6 +138,7 @@ public OrcRecordReader( List readColumns, List readTypes, List readLayouts, + boolean appendRowNumberColumn, OrcPredicate predicate, long numberOfRows, List fileStripes, @@ -152,7 +162,7 @@ public OrcRecordReader( FieldMapperFactory fieldMapperFactory) throws OrcCorruptionException { - requireNonNull(readColumns, "readColumns is null"); + this.columns = requireNonNull(readColumns, "readColumns is null"); checkArgument(readColumns.stream().distinct().count() == readColumns.size(), "readColumns contains duplicate entries"); requireNonNull(readTypes, "readTypes is null"); checkArgument(readColumns.size() == readTypes.size(), "readColumns and readTypes must have the same size"); @@ -168,6 +178,7 @@ public OrcRecordReader( requireNonNull(userMetadata, "userMetadata is null"); requireNonNull(memoryUsage, "memoryUsage is null"); requireNonNull(exceptionTransform, "exceptionTransform is null"); + this.appendRowNumberColumn = appendRowNumberColumn; this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(orcTypes, readTypes)); @@ -406,7 +417,7 @@ public void close() } } - public Page nextPage() + public SourcePage nextPage() throws IOException { // update position for current row group (advancing resets them) @@ -447,21 +458,156 @@ public Page nextPage() // create a lazy page blockFactory.nextPage(); Arrays.fill(currentBytesPerCell, 0); - Block[] blocks = new Block[columnReaders.length]; - for (int i = 0; i < columnReaders.length; i++) { - int columnIndex = i; - blocks[columnIndex] = blockFactory.createBlock( - currentBatchSize, - columnReaders[columnIndex]::readBlock, - false); - listenForLoads(blocks[columnIndex], block -> blockLoaded(columnIndex, block)); - } - - Page page = new Page(currentBatchSize, blocks); + SourcePage page = new OrcSourcePage(currentBatchSize); validateWritePageChecksum(page); return page; } + private class OrcSourcePage + implements SourcePage + { + private final Block[] blocks = new Block[columnReaders.length + (appendRowNumberColumn ? 1 : 0)]; + private final int rowNumberColumnIndex = appendRowNumberColumn ? 
columnReaders.length : -1; + private SelectedPositions selectedPositions; + + public OrcSourcePage(int positionCount) + { + selectedPositions = new SelectedPositions(positionCount, null); + } + + @Override + public int getPositionCount() + { + return selectedPositions.positionCount(); + } + + @Override + public long getSizeInBytes() + { + long sizeInBytes = 0; + for (Block block : blocks) { + if (block != null) { + sizeInBytes += block.getSizeInBytes(); + } + } + return sizeInBytes; + } + + @Override + public long getRetainedSizeInBytes() + { + long retainedSizeInBytes = 0; + for (Block block : blocks) { + if (block != null) { + retainedSizeInBytes += block.getRetainedSizeInBytes(); + } + } + return retainedSizeInBytes; + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer consumer) + { + for (Block block : blocks) { + if (block != null) { + block.retainedBytesForEachPart(consumer); + } + } + } + + @Override + public int getChannelCount() + { + return blocks.length; + } + + @Override + public Block getBlock(int channel) + { + checkIndex(channel, blocks.length); + + Block block = blocks[channel]; + if (block == null) { + if (channel == rowNumberColumnIndex) { + block = selectedPositions.createRowNumberBlock(filePosition); + } + else { + // todo use selected positions to improve read performance + block = blockFactory.createBlock( + currentBatchSize, + columnReaders[channel]::readBlock, + false); + listenForLoads(block, nestedBlock -> blockLoaded(channel, nestedBlock)); + block = selectedPositions.apply(block); + } + blocks[channel] = block; + } + return block; + } + + @Override + public Page getPage() + { + // ensure all blocks are loaded + for (int i = 0; i < blocks.length; i++) { + getBlock(i); + } + return new Page(selectedPositions.positionCount(), blocks); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + selectedPositions = selectedPositions.selectPositions(positions, offset, size); + for (int i = 0; i < blocks.length; i++) { + Block block = blocks[i]; + if (block != null) { + block = selectedPositions.apply(block); + blocks[i] = block; + } + } + } + } + + private record SelectedPositions(int positionCount, @Nullable int[] positions) + { + @CheckReturnValue + public Block apply(Block block) + { + if (positions == null) { + return block; + } + return block.getPositions(positions, 0, positionCount); + } + + public Block createRowNumberBlock(long filePosition) + { + long[] rowNumbers = new long[positionCount]; + for (int i = 0; i < positionCount; i++) { + int position = positions == null ? 
i : positions[i]; + rowNumbers[i] = filePosition + position; + } + return new LongArrayBlock(positionCount, Optional.empty(), rowNumbers); + } + + @CheckReturnValue + public SelectedPositions selectPositions(int[] positions, int offset, int size) + { + if (this.positions == null) { + for (int i = 0; i < size; i++) { + checkIndex(offset + i, positionCount); + } + return new SelectedPositions(size, Arrays.copyOfRange(positions, offset, offset + size)); + } + + int[] newPositions = new int[size]; + for (int i = 0; i < size; i++) { + newPositions[i] = this.positions[positions[offset + i]]; + } + return new SelectedPositions(size, newPositions); + } + } + private void blockLoaded(int columnIndex, Block block) { if (block.getPositionCount() <= 0) { @@ -586,10 +732,10 @@ private void validateWriteStripe(int rowCount) writeChecksumBuilder.ifPresent(builder -> builder.addStripe(rowCount)); } - private void validateWritePageChecksum(Page page) + private void validateWritePageChecksum(SourcePage sourcePage) { if (writeChecksumBuilder.isPresent()) { - page = page.getLoadedPage(); + Page page = sourcePage.getPage(); writeChecksumBuilder.get().addPage(page); rowGroupStatisticsValidation.get().addPage(page); stripeStatisticsValidation.get().addPage(page); @@ -597,6 +743,16 @@ private void validateWritePageChecksum(Page page) } } + @Override + public String toString() + { + return toStringHelper(this) + .add("orcDataSource", orcDataSource.getId()) + .add("columns", columns) + .add("appendRowNumberColumn", appendRowNumberColumn) + .toString(); + } + private static ColumnReader[] createColumnReaders( List columns, List readTypes, diff --git a/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java b/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java index c54da217b737..eadbeed05874 100644 --- a/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java +++ b/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkColumnReaders.java @@ -19,6 +19,7 @@ import io.trino.plugin.tpch.DecimalTypeMapping; import io.trino.spi.Page; import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.DecimalType; import io.trino.spi.type.SqlDecimal; import io.trino.spi.type.SqlTimestamp; @@ -328,8 +329,8 @@ public Object readLineitem(LineitemBenchmarkData data) { List pages = new ArrayList<>(); try (OrcRecordReader recordReader = data.createRecordReader()) { - for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { - pages.add(page.getLoadedPage()); + for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { + pages.add(page.getPage()); } } return pages; @@ -375,7 +376,7 @@ private Object readFirstColumn(OrcRecordReader recordReader) throws IOException { List blocks = new ArrayList<>(); - for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { + for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { blocks.add(page.getBlock(0).getLoadedBlock()); } return blocks; @@ -429,6 +430,7 @@ OrcRecordReader createRecordReader() return orcReader.createRecordReader( orcReader.getRootColumn().getNestedColumns(), types, + false, OrcPredicate.TRUE, UTC, // arbitrary newSimpleAggregatedMemoryContext(), diff --git a/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkOrcDecimalReader.java b/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkOrcDecimalReader.java index 3d102e218a01..fafa4ab1a142 100644 --- 
a/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkOrcDecimalReader.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/BenchmarkOrcDecimalReader.java
@@ -14,8 +14,8 @@
package io.trino.orc;

import com.google.common.collect.ImmutableList;
-import io.trino.spi.Page;
import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.SqlDecimal;
import org.joda.time.DateTimeZone;
@@ -71,8 +71,8 @@ public Object readDecimal(BenchmarkData data)
{
    OrcRecordReader recordReader = data.createRecordReader();
    List<Block> blocks = new ArrayList<>();
-    for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
-        blocks.add(page.getBlock(0).getLoadedBlock());
+    for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
+        blocks.add(page.getBlock(0));
    }
    return blocks;
}
@@ -118,6 +118,7 @@ private OrcRecordReader createRecordReader()
    return orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(DECIMAL_TYPE),
+            false,
            OrcPredicate.TRUE,
            DateTimeZone.UTC, // arbitrary
            newSimpleAggregatedMemoryContext(),
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java b/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java
index 5f02c8c8fe7e..39ad6ddf6a6a 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/OrcTester.java
@@ -32,6 +32,7 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.MapBlockBuilder;
import io.trino.spi.block.RowBlockBuilder;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
@@ -497,7 +498,7 @@ private static void assertFileContentsTrino(
    boolean isFirst = true;
    int rowsProcessed = 0;
    Iterator<?> iterator = expectedValues.iterator();
-    for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
+    for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) {
        int batchSize = page.getPositionCount();
        if (skipStripe && rowsProcessed < 10000) {
            assertThat(advance(iterator, batchSize)).isEqualTo(batchSize);
@@ -626,6 +627,7 @@ static OrcRecordReader createCustomOrcRecordReader(TempFile tempFile, OrcPredica
    return orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(type),
+            false,
            predicate,
            HIVE_STORAGE_TIME_ZONE,
            newSimpleAggregatedMemoryContext(),
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestCachingOrcDataSource.java b/lib/trino-orc/src/test/java/io/trino/orc/TestCachingOrcDataSource.java
index 49fa3724d0f6..6e4ae9d3aee8 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestCachingOrcDataSource.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestCachingOrcDataSource.java
@@ -20,8 +20,8 @@
import io.airlift.units.DataSize.Unit;
import io.trino.orc.metadata.StripeInformation;
import io.trino.orc.stream.OrcDataReader;
-import io.trino.spi.Page;
import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
@@ -201,6 +201,7 @@ private void doIntegration(TestingOrcDataSource orcDataSource, DataSize maxMerge
    OrcRecordReader orcRecordReader = orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(VARCHAR),
+            false,
            (numberOfRows, statisticsByColumnIndex) -> true,
            HIVE_STORAGE_TIME_ZONE,
            newSimpleAggregatedMemoryContext(),
@@ -208,11 +209,10 @@ private void doIntegration(TestingOrcDataSource orcDataSource, DataSize maxMerge
            RuntimeException::new);
    int positionCount = 0;
    while (true) {
-        Page page = orcRecordReader.nextPage();
+        SourcePage page = orcRecordReader.nextPage();
        if (page == null) {
            break;
        }
-        page = page.getLoadedPage();
        Block block = page.getBlock(0);
        positionCount += block.getPositionCount();
    }
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcLz4.java b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcLz4.java
index 2848b7dce103..fc48ca941947 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcLz4.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcLz4.java
@@ -15,8 +15,8 @@
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slices;
-import io.trino.spi.Page;
import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Test;
@@ -60,6 +60,7 @@ private void testReadLz4(byte[] data)
    try (OrcRecordReader reader = orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(BIGINT, INTEGER, BIGINT),
+            false,
            OrcPredicate.TRUE,
            DateTimeZone.UTC,
            newSimpleAggregatedMemoryContext(),
@@ -67,11 +68,10 @@ private void testReadLz4(byte[] data)
            RuntimeException::new)) {
        int rows = 0;
        while (true) {
-            Page page = reader.nextPage();
+            SourcePage page = reader.nextPage();
            if (page == null) {
                break;
            }
-            page = page.getLoadedPage();
            rows += page.getPositionCount();

            Block xBlock = page.getBlock(0);
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderMemoryUsage.java b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderMemoryUsage.java
index d59ee52d419c..e7cb30476a58 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderMemoryUsage.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderMemoryUsage.java
@@ -15,6 +15,7 @@
import io.trino.orc.metadata.CompressionKind;
import io.trino.spi.Page;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.StandardTypes;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
@@ -60,11 +61,11 @@ public void testVarcharTypeWithoutNulls()
    long readerMemoryUsage = reader.getMemoryUsage();
    while (true) {
-        Page page = reader.nextPage();
-        if (page == null) {
+        SourcePage sourcePage = reader.nextPage();
+        if (sourcePage == null) {
            break;
        }
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();

        // We only verify the memory usage when the batchSize reaches MAX_BATCH_SIZE as batchSize may be
        // increasing during the test, which will cause the StreamReader buffer sizes to increase too.
@@ -105,11 +106,11 @@ public void testBigIntTypeWithNulls()
    long readerMemoryUsage = reader.getMemoryUsage();
    while (true) {
-        Page page = reader.nextPage();
-        if (page == null) {
+        SourcePage sourcePage = reader.nextPage();
+        if (sourcePage == null) {
            break;
        }
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();

        // We only verify the memory usage when the batchSize reaches MAX_BATCH_SIZE as batchSize may be
        // increasing during the test, which will cause the StreamReader buffer sizes to increase too.
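All of the ORC test changes in this patch are the same mechanical migration, so it is worth spelling out once: `nextPage()` now returns a `SourcePage` (or null at end of file), and the explicit `page.getLoadedPage()` step disappears. `getBlock(i)` materializes a single channel on demand, while `getPage()` materializes all of them. A minimal sketch of the new consumer loop, using only the reader methods visible in these hunks:

    import io.trino.spi.block.Block;
    import io.trino.spi.connector.SourcePage;

    final class SourcePageLoops
    {
        private SourcePageLoops() {}

        // Drains an OrcRecordReader and counts positions.
        static long countPositions(io.trino.orc.OrcRecordReader reader)
                throws java.io.IOException
        {
            long positions = 0;
            while (true) {
                SourcePage sourcePage = reader.nextPage();
                if (sourcePage == null) {
                    break; // end of file
                }
                Block block = sourcePage.getBlock(0); // loads channel 0 only
                positions += block.getPositionCount();
                sourcePage.getPage(); // materializes every channel (the old getLoadedPage())
            }
            return positions;
        }
    }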
@@ -152,11 +153,11 @@ public void testMapTypeWithNulls()
    long readerMemoryUsage = reader.getMemoryUsage();
    while (true) {
-        Page page = reader.nextPage();
-        if (page == null) {
+        SourcePage sourcePage = reader.nextPage();
+        if (sourcePage == null) {
            break;
        }
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();

        // We only verify the memory usage when the batchSize reaches MAX_BATCH_SIZE as batchSize may be
        // increasing during the test, which will cause the StreamReader buffer sizes to increase too.
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderPositions.java b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderPositions.java
index 7c0ae4650258..64fcc042d0a1 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderPositions.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcReaderPositions.java
@@ -23,6 +23,7 @@
import io.trino.orc.metadata.statistics.IntegerStatistics;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
@@ -75,7 +76,7 @@ public void testEntireFile()
    assertThat(reader.getFilePosition()).isEqualTo(reader.getReaderPosition());

    for (int i = 0; i < 5; i++) {
-        Page page = reader.nextPage().getLoadedPage();
+        Page page = reader.nextPage().getPage();
        assertThat(page.getPositionCount()).isEqualTo(20);
        assertThat(reader.getReaderPosition()).isEqualTo(i * 20L);
        assertThat(reader.getFilePosition()).isEqualTo(reader.getReaderPosition());
@@ -113,21 +114,20 @@ public void testStripeSkipping()
    assertThat(reader.getReaderPosition()).isEqualTo(0);

    // second stripe
-    Page page = reader.nextPage().getLoadedPage();
+    Page page = reader.nextPage().getPage();
    assertThat(page.getPositionCount()).isEqualTo(20);
    assertThat(reader.getReaderPosition()).isEqualTo(0);
    assertThat(reader.getFilePosition()).isEqualTo(20);
    assertCurrentBatch(page, 1);

    // fourth stripe
-    page = reader.nextPage().getLoadedPage();
+    page = reader.nextPage().getPage();
    assertThat(page.getPositionCount()).isEqualTo(20);
    assertThat(reader.getReaderPosition()).isEqualTo(20);
    assertThat(reader.getFilePosition()).isEqualTo(60);
    assertCurrentBatch(page, 3);

-    page = reader.nextPage();
-    assertThat(page).isNull();
+    assertThat(reader.nextPage()).isNull();
    assertThat(reader.getReaderPosition()).isEqualTo(40);
    assertThat(reader.getFilePosition()).isEqualTo(100);
}
@@ -160,11 +160,11 @@ public void testRowGroupSkipping()
    long position = 50_000;
    while (true) {
-        Page page = reader.nextPage();
-        if (page == null) {
+        SourcePage sourcePage = reader.nextPage();
+        if (sourcePage == null) {
            break;
        }
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();
        Block block = page.getBlock(0);

        for (int i = 0; i < block.getPositionCount(); i++) {
@@ -212,11 +212,11 @@ public void testBatchSizesForVariableWidth()
    int currentStringBytes = baseStringBytes + Integer.BYTES + Byte.BYTES;
    int rowCountsInCurrentRowGroup = 0;
    while (true) {
-        Page page = reader.nextPage();
-        if (page == null) {
+        SourcePage sourcePage = reader.nextPage();
+        if (sourcePage == null) {
            break;
        }
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();

        rowCountsInCurrentRowGroup += page.getPositionCount();

@@ -268,11 +268,11 @@ public void testBatchSizesForFixedWidth()
    int rowCountsInCurrentRowGroup = 0;
    while (true) {
-        Page page = reader.nextPage();
-        if (page == null) {
+        SourcePage sourcePage = reader.nextPage();
+        if (sourcePage == null) {
            break;
        }
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();

        rowCountsInCurrentRowGroup += page.getPositionCount();
        Block block = page.getBlock(0);
@@ -332,11 +332,11 @@ public void testBatchSizeGrowth()
    int expectedBatchSize = INITIAL_BATCH_SIZE;
    int rowCountsInCurrentRowGroup = 0;
    while (true) {
-        Page page = reader.nextPage();
-        if (page == null) {
+        SourcePage sourcePage = reader.nextPage();
+        if (sourcePage == null) {
            break;
        }
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();

        assertThat(page.getPositionCount()).isEqualTo(expectedBatchSize);
        assertThat(reader.getReaderPosition()).isEqualTo(totalReadRows);
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcWithoutRowGroupInfo.java b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcWithoutRowGroupInfo.java
index e43eeb70356b..d5132556b8de 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestOrcWithoutRowGroupInfo.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestOrcWithoutRowGroupInfo.java
@@ -15,9 +15,9 @@
import com.google.common.collect.ImmutableList;
import io.trino.orc.metadata.OrcColumnId;
-import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlRow;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.predicate.Domain;
import io.trino.spi.type.RowType;
import org.joda.time.DateTimeZone;
@@ -67,6 +67,7 @@ private void testAndVerifyResults(OrcPredicate orcPredicate)
    OrcRecordReader reader = orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(INTEGER, BIGINT, INTEGER, BIGINT, BIGINT, rowType),
+            false,
            orcPredicate,
            DateTimeZone.UTC,
            newSimpleAggregatedMemoryContext(),
@@ -75,11 +76,10 @@ private void testAndVerifyResults(OrcPredicate orcPredicate)
    int rows = 0;
    while (true) {
-        Page page = reader.nextPage();
+        SourcePage page = reader.nextPage();
        if (page == null) {
            break;
        }
-        page = page.getLoadedPage();
        rows += page.getPositionCount();

        Block rowBlock = page.getBlock(5);
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestReadBloomFilter.java b/lib/trino-orc/src/test/java/io/trino/orc/TestReadBloomFilter.java
index c2aae84f0ab3..8906ceefd0aa 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestReadBloomFilter.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestReadBloomFilter.java
@@ -88,7 +88,7 @@ private static void testType(Type type, List uniqueValues, T inBloomFilte
    // without predicate a normal block will be created
    try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, OrcPredicate.TRUE, type, MAX_BATCH_SIZE)) {
-        assertThat(recordReader.nextPage().getLoadedPage().getPositionCount()).isEqualTo(MAX_BATCH_SIZE);
+        assertThat(recordReader.nextPage().getPage().getPositionCount()).isEqualTo(MAX_BATCH_SIZE);
    }

    // predicate for specific value within the min/max range without bloom filter being enabled
@@ -97,7 +97,7 @@ private static void testType(Type type, List uniqueValues, T inBloomFilte
            .build();

    try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, noBloomFilterPredicate, type, MAX_BATCH_SIZE)) {
-        assertThat(recordReader.nextPage().getLoadedPage().getPositionCount()).isEqualTo(MAX_BATCH_SIZE);
+        assertThat(recordReader.nextPage().getPage().getPositionCount()).isEqualTo(MAX_BATCH_SIZE);
    }

    // predicate for specific value within the min/max range with bloom filter enabled, but a value not in the bloom filter
@@ -117,7 +117,7 @@ private static void testType(Type type, List uniqueValues, T inBloomFilte
            .build();
    try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, matchBloomFilterPredicate, type, MAX_BATCH_SIZE)) {
-        assertThat(recordReader.nextPage().getLoadedPage().getPositionCount()).isEqualTo(MAX_BATCH_SIZE);
+        assertThat(recordReader.nextPage().getPage().getPositionCount()).isEqualTo(MAX_BATCH_SIZE);
    }
}
}
@@ -135,6 +135,7 @@ private static OrcRecordReader createCustomOrcRecordReader(TempFile tempFile, Or
    return orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(type),
+            false,
            predicate,
            HIVE_STORAGE_TIME_ZONE,
            newSimpleAggregatedMemoryContext(),
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestStructColumnReader.java b/lib/trino-orc/src/test/java/io/trino/orc/TestStructColumnReader.java
index 65e3078963e6..a3d022175d80 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestStructColumnReader.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestStructColumnReader.java
@@ -266,13 +266,14 @@ private RowBlock read(TempFile tempFile, Type readerType)
    OrcRecordReader recordReader = orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(readerType),
+            false,
            OrcPredicate.TRUE,
            UTC,
            newSimpleAggregatedMemoryContext(),
            OrcReader.INITIAL_BATCH_SIZE,
            RuntimeException::new);

-    RowBlock block = (RowBlock) recordReader.nextPage().getLoadedPage().getBlock(0);
+    RowBlock block = (RowBlock) recordReader.nextPage().getBlock(0).getLoadedBlock();
    recordReader.close();
    return block;
}
diff --git a/lib/trino-orc/src/test/java/io/trino/orc/TestTimestampTzMicros.java b/lib/trino-orc/src/test/java/io/trino/orc/TestTimestampTzMicros.java
index 03595c186be7..17bfb2604df3 100644
--- a/lib/trino-orc/src/test/java/io/trino/orc/TestTimestampTzMicros.java
+++ b/lib/trino-orc/src/test/java/io/trino/orc/TestTimestampTzMicros.java
@@ -48,6 +48,7 @@ public void test()
    try (OrcRecordReader reader = orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(timestampTzType, TimestampType.createTimestampType(6)),
+            false,
            OrcPredicate.TRUE,
            DateTimeZone.UTC,
            newSimpleAggregatedMemoryContext(),
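The new `false` argument threaded through every ORC `createRecordReader(...)` call site above sits between the column types and the predicate. These hunks never name the parameter; by symmetry with the `appendRowNumberColumn` field added to `ParquetReader` below, it presumably toggles an appended row-number channel. A hypothetical call with the flag enabled (parameter meaning assumed, not confirmed by these hunks):

    // Hedged sketch: `orcReader` is an io.trino.orc.OrcReader as in the tests above.
    OrcRecordReader reader = orcReader.createRecordReader(
            orcReader.getRootColumn().getNestedColumns(),
            ImmutableList.of(BIGINT),
            true, // the new flag; every call site in this patch passes false
            OrcPredicate.TRUE,
            DateTimeZone.UTC,
            newSimpleAggregatedMemoryContext(),
            OrcReader.INITIAL_BATCH_SIZE,
            RuntimeException::new);
    SourcePage page = reader.nextPage();
    Block rowNumbers = page.getBlock(1); // channel after the single projected column (assumption)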
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java
index 5ba0976581bf..d2656f72711c 100644
--- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java
@@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
+import com.google.errorprone.annotations.CheckReturnValue;
import com.google.errorprone.annotations.FormatMethod;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
@@ -43,8 +44,10 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.DictionaryBlock;
+import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.block.RowBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.metrics.Metric;
import io.trino.spi.metrics.Metrics;
import io.trino.spi.type.ArrayType;
@@ -64,12 +67,14 @@
import java.io.Closeable;
import java.io.IOException;
import java.time.ZoneId;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
+import java.util.function.ObjLongConsumer;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
@@ -88,6 +93,7 @@
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
+import static java.util.Objects.checkIndex;
import static java.util.Objects.requireNonNull;

public class ParquetReader
@@ -103,6 +109,7 @@ public class ParquetReader
    private final Optional<String> fileCreatedBy;
    private final List<RowGroupInfo> rowGroups;
    private final List<Column> columnFields;
+    private final boolean appendRowNumberColumn;
    private final List<PrimitiveField> primitiveFields;
    private final ParquetDataSource dataSource;
    private final ZoneId zoneId;
@@ -142,6 +149,7 @@ public class ParquetReader
    public ParquetReader(
            Optional<String> fileCreatedBy,
            List<Column> columnFields,
+            boolean appendRowNumberColumn,
            List<RowGroupInfo> rowGroups,
            ParquetDataSource dataSource,
            DateTimeZone timeZone,
@@ -155,6 +163,7 @@ public ParquetReader(
    this.fileCreatedBy = requireNonNull(fileCreatedBy, "fileCreatedBy is null");
    requireNonNull(columnFields, "columnFields is null");
    this.columnFields = ImmutableList.copyOf(columnFields);
+    this.appendRowNumberColumn = appendRowNumberColumn;
    this.primitiveFields = getPrimitiveFields(columnFields.stream().map(Column::field).collect(toImmutableList()));
    this.rowGroups = requireNonNull(rowGroups, "rowGroups is null");
    this.dataSource = requireNonNull(dataSource, "dataSource is null");
@@ -247,7 +256,7 @@ public void close()
    }
}

-public Page nextPage()
+public SourcePage nextPage()
        throws IOException
{
    int batchSize = nextBatch();
@@ -256,16 +265,151 @@ public Page nextPage()
    }
    // create a lazy page
    blockFactory.nextPage();
-    Block[] blocks = new Block[columnFields.size()];
-    for (int channel = 0; channel < columnFields.size(); channel++) {
-        Field field = columnFields.get(channel).field();
-        blocks[channel] = blockFactory.createBlock(batchSize, () -> readBlock(field));
-    }
-    Page page = new Page(batchSize, blocks);
+    SourcePage page = new ParquetSourcePage(batchSize);
    validateWritePageChecksum(page);
    return page;
}
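Before the inner class that implements it: `appendRowNumberColumn` reserves one synthetic channel after the projected columns (`rowNumberColumnIndex`), which `getBlock` fills from `SelectedPositions.createRowNumberBlock(lastBatchStartRow())`. The arithmetic, with invented numbers: a batch starting at file row 1000 whose positions 0, 2 and 5 survived filtering yields row numbers 1000, 1002 and 1005.

    // Standalone illustration of createRowNumberBlock (values invented).
    final class RowNumberArithmetic
    {
        public static void main(String[] args)
        {
            long startRowNumber = 1000;        // lastBatchStartRow()
            int[] positions = {0, 2, 5};       // surviving positions; null would mean "all"
            long[] rowNumbers = new long[positions.length];
            for (int i = 0; i < positions.length; i++) {
                rowNumbers[i] = startRowNumber + positions[i];
            }
            // rowNumbers == {1000, 1002, 1005}; the reader wraps this in a LongArrayBlock
        }
    }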
+private class ParquetSourcePage
+        implements SourcePage
+{
+    private final Block[] blocks = new Block[columnFields.size() + (appendRowNumberColumn ? 1 : 0)];
+    private final int rowNumberColumnIndex = appendRowNumberColumn ? columnFields.size() : -1;
+    private SelectedPositions selectedPositions;
+
+    public ParquetSourcePage(int positionCount)
+    {
+        selectedPositions = new SelectedPositions(positionCount, null);
+    }
+
+    @Override
+    public int getPositionCount()
+    {
+        return selectedPositions.positionCount();
+    }
+
+    @Override
+    public long getSizeInBytes()
+    {
+        long sizeInBytes = 0;
+        for (Block block : blocks) {
+            if (block != null) {
+                sizeInBytes += block.getSizeInBytes();
+            }
+        }
+        return sizeInBytes;
+    }
+
+    @Override
+    public long getRetainedSizeInBytes()
+    {
+        long retainedSizeInBytes = 0;
+        for (Block block : blocks) {
+            if (block != null) {
+                retainedSizeInBytes += block.getRetainedSizeInBytes();
+            }
+        }
+        return retainedSizeInBytes;
+    }
+
+    @Override
+    public void retainedBytesForEachPart(ObjLongConsumer<Object> consumer)
+    {
+        for (Block block : blocks) {
+            if (block != null) {
+                block.retainedBytesForEachPart(consumer);
+            }
+        }
+    }
+
+    @Override
+    public int getChannelCount()
+    {
+        return blocks.length;
+    }
+
+    @Override
+    public Block getBlock(int channel)
+    {
+        Block block = blocks[channel];
+        if (block == null) {
+            if (channel == rowNumberColumnIndex) {
+                block = selectedPositions.createRowNumberBlock(lastBatchStartRow());
+            }
+            else {
+                // todo use selected positions to improve read performance
+                Field field = columnFields.get(channel).field();
+                block = blockFactory.createBlock(batchSize, () -> readBlock(field));
+                block = selectedPositions.apply(block);
+            }
+            blocks[channel] = block;
+        }
+        return block;
+    }
+
+    @Override
+    public Page getPage()
+    {
+        // ensure all blocks are loaded
+        for (int i = 0; i < blocks.length; i++) {
+            getBlock(i);
+        }
+        return new Page(selectedPositions.positionCount(), blocks);
+    }
+
+    @Override
+    public void selectPositions(int[] positions, int offset, int size)
+    {
+        selectedPositions = selectedPositions.selectPositions(positions, offset, size);
+        for (int i = 0; i < blocks.length; i++) {
+            Block block = blocks[i];
+            if (block != null) {
+                block = selectedPositions.apply(block);
+                blocks[i] = block;
+            }
+        }
+    }
+}
+
+private record SelectedPositions(int positionCount, @Nullable int[] positions)
+{
+    @CheckReturnValue
+    public Block apply(Block block)
+    {
+        if (positions == null) {
+            return block;
+        }
+        return block.getPositions(positions, 0, positionCount);
+    }
+
+    public Block createRowNumberBlock(long startRowNumber)
+    {
+        long[] rowNumbers = new long[positionCount];
+        for (int i = 0; i < positionCount; i++) {
+            int position = positions == null ? i : positions[i];
+            rowNumbers[i] = startRowNumber + position;
+        }
+        return new LongArrayBlock(positionCount, Optional.empty(), rowNumbers);
+    }
+
+    @CheckReturnValue
+    public SelectedPositions selectPositions(int[] positions, int offset, int size)
+    {
+        if (this.positions == null) {
+            for (int i = 0; i < size; i++) {
+                checkIndex(offset + i, positionCount);
+            }
+            return new SelectedPositions(size, Arrays.copyOfRange(positions, offset, offset + size));
+        }
+
+        int[] newPositions = new int[size];
+        for (int i = 0; i < size; i++) {
+            newPositions[i] = this.positions[positions[offset + i]];
+        }
+        return new SelectedPositions(size, newPositions);
+    }
+}
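`SelectedPositions.selectPositions` composes successive filters by indirection: once a positions array exists, a later selection indexes into it rather than into the original batch. A standalone worked example of that composition:

    // Illustration (values invented) of the composition in selectPositions above.
    final class SelectionComposition
    {
        public static void main(String[] args)
        {
            int[] first = {0, 2, 5};  // filter #1 keeps batch positions 0, 2, 5
            int[] second = {1, 2};    // filter #2 keeps rows 1 and 2 of that result
            int[] composed = new int[second.length];
            for (int i = 0; i < second.length; i++) {
                composed[i] = first[second[i]]; // same loop as selectPositions
            }
            // composed == {2, 5}; apply() then calls block.getPositions(composed, 0, 2)
        }
    }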
+
/**
 * Get the global row index of the first row in the last batch.
 */
@@ -627,10 +771,10 @@ private static FilteredRowRanges[] calculateFilteredRowRanges(
    return blockRowRanges;
}

-private void validateWritePageChecksum(Page page)
+private void validateWritePageChecksum(SourcePage sourcePage)
{
    if (writeChecksumBuilder.isPresent()) {
-        page = page.getLoadedPage();
+        Page page = sourcePage.getPage();
        writeChecksumBuilder.get().addPage(page);
        rowGroupStatisticsValidation.orElseThrow().addPage(page);
    }
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java
index a9c4c54f794c..7d5c5c0967c8 100644
--- a/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java
@@ -32,6 +32,7 @@
import io.trino.parquet.reader.RowGroupInfo;
import io.trino.parquet.writer.ColumnWriter.BufferData;
import io.trino.spi.Page;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.Type;
import jakarta.annotation.Nullable;
import org.apache.parquet.column.ColumnDescriptor;
@@ -242,9 +243,9 @@ public void validate(ParquetDataSource input)
    try {
        ParquetMetadata parquetMetadata = MetadataReader.readFooter(input, Optional.of(writeValidation));
        try (ParquetReader parquetReader = createParquetReader(input, parquetMetadata, writeValidation)) {
-            for (Page page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) {
+            for (SourcePage page = parquetReader.nextPage(); page != null; page = parquetReader.nextPage()) {
                // fully load the page
-                page.getLoadedPage();
+                page.getPage();
            }
        }
    }
@@ -286,6 +287,7 @@ private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetada
    return new ParquetReader(
            Optional.ofNullable(fileMetaData.getCreatedBy()),
            columnFields.build(),
+            false,
            rowGroupInfoBuilder.build(),
            input,
            parquetTimeZone.orElseThrow(),
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java
index 9f7918115838..f8d4cee8b66c 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/BenchmarkColumnarFilterParquetData.java
@@ -30,6 +30,7 @@
import io.trino.parquet.writer.ParquetWriterOptions;
import io.trino.spi.Page;
import io.trino.spi.connector.DynamicFilter;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.sql.gen.ExpressionCompiler;
@@ -240,7 +241,7 @@ public long compiled()
{
    ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), columnTypes, columnNames);
    LocalMemoryContext context = newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName());
-    Page inputPage = reader.nextPage();
+    SourcePage inputPage = reader.nextPage();
    long outputRows = 0;
    while (inputPage != null) {
        WorkProcessor workProcessor = compiledProcessor.createWorkProcessor(
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java
index 0528adcd1166..b992934c79c2 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java
@@ -165,6 +165,7 @@ public static ParquetReader createParquetReader(
    return new ParquetReader(
            Optional.ofNullable(fileMetaData.getCreatedBy()),
            columnFields.build(),
+            false,
            rowGroups,
            input,
            UTC,
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java
index d42725e5acb2..cfa97ee617c0 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestByteStreamSplitEncoding.java
@@ -18,9 +18,9 @@
import io.trino.parquet.ParquetDataSource;
import io.trino.parquet.ParquetReaderOptions;
import io.trino.parquet.metadata.ParquetMetadata;
-import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.IntArrayBlock;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.Type;
import org.testng.annotations.Test;
@@ -79,7 +79,7 @@ private static void readAndCompare(ParquetReader reader, List> expe
{
    int rowCount = 0;
    int pageCount = 0;
-    Page page = reader.nextPage();
+    SourcePage page = reader.nextPage();
    while (page != null) {
        assertThat(page.getChannelCount()).isEqualTo(2);
        if (pageCount % 2 == 1) { // Skip loading every alternative page
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java
index aabb734e5b0c..31a41c2688d9 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestInt96Timestamp.java
@@ -25,9 +25,9 @@
import io.trino.parquet.PrimitiveField;
import io.trino.parquet.metadata.ParquetMetadata;
import io.trino.plugin.base.type.DecodedTimestamp;
-import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.Fixed12Block;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Timestamps;
@@ -115,7 +115,7 @@ public void testNanosOutsideDayRange()
    ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty());
    ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames);

-    Page page = reader.nextPage();
+    SourcePage page = reader.nextPage();
    ImmutableList.Builder builder = ImmutableList.builder();
    while (page != null) {
        Fixed12Block block = (Fixed12Block) page.getBlock(0).getLoadedBlock();
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java
index db8b4225770a..85ce0784b47e 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestParquetReader.java
@@ -23,11 +23,11 @@
import io.trino.parquet.metadata.BlockMetadata;
import io.trino.parquet.metadata.ParquetMetadata;
import io.trino.parquet.writer.ParquetWriterOptions;
-import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.LazyBlock;
import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.metrics.Count;
import io.trino.spi.metrics.Metric;
import io.trino.spi.predicate.Domain;
@@ -92,7 +92,7 @@ public void testColumnReaderMemoryUsage()
    AggregatedMemoryContext memoryContext = newSimpleAggregatedMemoryContext();
    ParquetReader reader = createParquetReader(dataSource, parquetMetadata, memoryContext, types, columnNames);

-    Page page = reader.nextPage();
+    SourcePage page = reader.nextPage();
    assertThat(page.getBlock(0)).isInstanceOf(LazyBlock.class);
    assertThat(memoryContext.getBytes()).isEqualTo(0);
    page.getBlock(0).getLoadedBlock();
@@ -143,7 +143,7 @@ public void testEmptyRowRangesWithColumnIndex()
            "l_commitdate", Domain.create(ValueSet.ofRanges(Range.greaterThan(DATE, LocalDate.of(1995, 1, 1).toEpochDay())), false)));
    try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, new ParquetReaderOptions(), newSimpleAggregatedMemoryContext(), types, columnNames, predicate)) {
-        Page page = reader.nextPage();
+        SourcePage page = reader.nextPage();
        int rowsRead = 0;
        while (page != null) {
            rowsRead += page.getPositionCount();
@@ -234,7 +234,7 @@ private void testReadingOldParquetFiles(File file, List columnNames, Typ
    ConnectorSession session = TestingConnectorSession.builder().build();
    ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty());
    try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), ImmutableList.of(columnType), columnNames)) {
-        Page page = reader.nextPage();
+        SourcePage page = reader.nextPage();
        Iterator expected = expectedValues.iterator();
        while (page != null) {
            Block block = page.getBlock(0);
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java
index 390608f445a9..4555bc525b7d 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/reader/TestTimeMillis.java
@@ -18,8 +18,8 @@
import io.trino.parquet.ParquetDataSource;
import io.trino.parquet.ParquetReaderOptions;
import io.trino.parquet.metadata.ParquetMetadata;
-import io.trino.spi.Page;
import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.SqlTime;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.Type;
@@ -63,7 +63,7 @@ private void testTimeMillsInt32(TimeType timeType)
    ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty());
    ParquetReader reader = createParquetReader(dataSource, parquetMetadata, newSimpleAggregatedMemoryContext(), types, columnNames);

-    Page page = reader.nextPage();
+    SourcePage page = reader.nextPage();
    Block block = page.getBlock(0).getLoadedBlock();
    assertThat(block.getPositionCount()).isEqualTo(1);
    // TIME '15:03:00'
diff --git a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java
index 83f07fa7ebc2..2a158cba1af9 100644
--- a/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java
+++ b/lib/trino-parquet/src/test/java/io/trino/parquet/writer/TestParquetWriter.java
@@ -36,6 +36,7 @@
import io.trino.parquet.reader.TestingParquetDataSource;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.DecimalType;
@@ -457,7 +458,7 @@ public void testWriteBloomFilters(Type type, List data)
            ImmutableMap.of(
                    "columnA", Domain.singleValue(type, data.get(data.size() / 2))));
    try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, new ParquetReaderOptions().withBloomFilter(true), newSimpleAggregatedMemoryContext(), types, columnNames, predicate)) {
-        Page page = reader.nextPage();
+        SourcePage page = reader.nextPage();
        int rowsRead = 0;
        while (page != null) {
            rowsRead += page.getPositionCount();
@@ -467,7 +468,7 @@ public void testWriteBloomFilters(Type type, List data)
    }

    try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, new ParquetReaderOptions().withBloomFilter(false), newSimpleAggregatedMemoryContext(), types, columnNames, predicate)) {
-        Page page = reader.nextPage();
+        SourcePage page = reader.nextPage();
        int rowsRead = 0;
        while (page != null) {
            rowsRead += page.getPositionCount();
diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedPageSource.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedPageSource.java
index 958f6f2a9f9d..968761eaf9c1 100644
--- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedPageSource.java
+++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/MappedPageSource.java
@@ -15,10 +15,13 @@
import com.google.common.primitives.Ints;
import io.trino.spi.Page;
+import io.trino.spi.block.Block;
import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;

import java.io.IOException;
import java.util.List;
+import java.util.function.ObjLongConsumer;

import static java.util.Objects.requireNonNull;
@@ -53,13 +56,13 @@ public boolean isFinished()
}

@Override
-public Page getNextPage()
+public SourcePage getNextSourcePage()
{
-    Page nextPage = delegate.getNextPage();
+    SourcePage nextPage = delegate.getNextSourcePage();
    if (nextPage == null) {
        return null;
    }
-    return nextPage.getColumns(delegateFieldIndex);
+    return new MappedSourcePage(nextPage, delegateFieldIndex);
}

@Override
@@ -74,4 +77,72 @@ public void close()
{
    delegate.close();
}
+
+private record MappedSourcePage(SourcePage sourcePage, int[] channels)
+        implements SourcePage
+{
+    private MappedSourcePage
+    {
+        requireNonNull(sourcePage, "sourcePage is null");
+        requireNonNull(channels, "channels is null");
+    }
+
+    @Override
+    public int getPositionCount()
+    {
+        return sourcePage.getPositionCount();
+    }
+
+    @Override
+    public long getSizeInBytes()
+    {
+        return sourcePage.getSizeInBytes();
+    }
+
+    @Override
+    public long getRetainedSizeInBytes()
+    {
+        return sourcePage.getRetainedSizeInBytes();
+    }
+
+    @Override
+    public void retainedBytesForEachPart(ObjLongConsumer consumer)
+    {
+        sourcePage.retainedBytesForEachPart(consumer);
+    }
+
+    @Override
+    public int getChannelCount()
+    {
+        return channels.length;
+    }
+
+    @Override
+    public Block getBlock(int channel)
+    {
+        return sourcePage.getBlock(channels[channel]);
+    }
+
+    @Override
+    public Page getPage()
+    {
+        return sourcePage.getColumns(channels);
+    }
+
+    @Override
+    public Page getColumns(int[] channels)
+    {
+        int[] newChannels = new int[channels.length];
+        for (int i = 0; i < channels.length; i++) {
+            newChannels[i] = this.channels[channels[i]];
+        }
+        return sourcePage.getColumns(newChannels);
+    }
+
+    @Override
+    public void selectPositions(int[] positions, int offset, int size)
+    {
+        sourcePage.selectPositions(positions, offset, size);
+    }
+}
}
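`MappedSourcePage` applies the same indirection to channels that `SelectedPositions` applies to rows: `getBlock(channel)` translates through `channels`, and `getColumns` folds two channel maps into one before delegating. A standalone example of the `newChannels` arithmetic, with invented channel numbers:

    // Illustration (values invented) of the remapping in getColumns above.
    final class ChannelComposition
    {
        public static void main(String[] args)
        {
            int[] channels = {3, 1};  // exposed channel i is delegate channel channels[i]
            int[] requested = {1, 0}; // caller wants exposed channels 1, then 0
            int[] newChannels = new int[requested.length];
            for (int i = 0; i < requested.length; i++) {
                newChannels[i] = channels[requested[i]];
            }
            // newChannels == {1, 3}: the delegate is asked for its columns 1 and 3
        }
    }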
diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/MergeJdbcPageSource.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/MergeJdbcPageSource.java
index 87a867a232a4..535d0742b727 100644
--- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/MergeJdbcPageSource.java
+++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/MergeJdbcPageSource.java
@@ -18,6 +18,7 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.RowBlock;
import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;

import java.io.IOException;
import java.util.List;
@@ -56,9 +57,9 @@ public boolean isFinished()
}

@Override
-public Page getNextPage()
+public SourcePage getNextSourcePage()
{
-    Page page = delegate.getNextPage();
+    SourcePage page = delegate.getNextSourcePage();
    if (page == null || columnAdaptations.isEmpty()) {
        return page;
    }
@@ -66,14 +67,14 @@ public Page getNextPage()
    return getColumnAdaptationsPage(page);
}

-private Page getColumnAdaptationsPage(Page page)
+private SourcePage getColumnAdaptationsPage(SourcePage page)
{
    Block[] blocks = new Block[columnAdaptations.size()];
    for (int i = 0; i < columnAdaptations.size(); i++) {
        blocks[i] = columnAdaptations.get(i).getBlock(page);
    }
-    return new Page(page.getPositionCount(), blocks);
+    return SourcePage.create(new Page(page.getPositionCount(), blocks));
}

@Override
@@ -91,7 +92,7 @@ public void close()

public interface ColumnAdaptation
{
-    Block getBlock(Page sourcePage);
+    Block getBlock(SourcePage sourcePage);
}

public static final class MergedRowAdaptation
@@ -105,7 +106,7 @@ public MergedRowAdaptation(List<Integer> mergeRowIdSourceChannels)
}

@Override
-public Block getBlock(Page page)
+public Block getBlock(SourcePage page)
{
    requireNonNull(page, "page is null");
    Block[] mergeRowIdBlocks = new Block[mergeRowIdSourceChannels.size()];
@@ -125,7 +126,7 @@ public record SourceColumn(int sourceChannel)
}

@Override
-public Block getBlock(Page sourcePage)
+public Block getBlock(SourcePage sourcePage)
{
    return sourcePage.getBlock(sourceChannel);
}
diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryEmptyProjectionPageSource.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryEmptyProjectionPageSource.java
index 1165faa05c2c..f44afbb462d1 100644
--- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryEmptyProjectionPageSource.java
+++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryEmptyProjectionPageSource.java
@@ -13,8 +13,8 @@
 */
package io.trino.plugin.bigquery;

-import io.trino.spi.Page;
import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;

import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES;
@@ -57,11 +57,11 @@ public boolean isFinished()
}

@Override
-public Page getNextPage()
+public SourcePage getNextSourcePage()
{
    int positionCount = toIntExact(min(MAX_RLE_PAGE_SIZE, numberOfRows - outputRows));
    outputRows += positionCount;
-    return new Page(positionCount);
+    return SourcePage.create(positionCount);
}

@Override
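`SourcePage.create(positionCount)`, used above in place of `new Page(positionCount)`, models a rows-only page: a position count with zero channels, for queries that project no data columns. A minimal counting loop in that style (row total and batch size invented):

    import io.trino.spi.connector.SourcePage;

    // Emits rows-only pages until numberOfRows are accounted for.
    final class RowCountingPages
    {
        private RowCountingPages() {}

        static long drain(long numberOfRows, int maxPageSize)
        {
            long emitted = 0;
            while (emitted < numberOfRows) {
                int positionCount = (int) Math.min(maxPageSize, numberOfRows - emitted);
                SourcePage page = SourcePage.create(positionCount); // zero channels, N rows
                emitted += page.getPositionCount();
            }
            return emitted;
        }
    }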
diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryQueryPageSource.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryQueryPageSource.java
index d403d8e839de..211b37dbc419 100644
--- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryQueryPageSource.java
+++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryQueryPageSource.java
@@ -28,6 +28,7 @@
import io.trino.spi.block.RowBlockBuilder;
import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
@@ -157,7 +158,7 @@ public long getMemoryUsage()
}

@Override
-public Page getNextPage()
+public SourcePage getNextSourcePage()
{
    verify(pageBuilder.isEmpty());
    if (tableResult == null) {
@@ -190,7 +191,7 @@ else if (tableResult.hasNextPage()) {

    Page page = pageBuilder.build();
    pageBuilder.reset();
-    return page;
+    return SourcePage.create(page);
}

private void appendTo(Type type, FieldValue value, BlockBuilder output)
diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageArrowPageSource.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageArrowPageSource.java
index 515d6d014281..2c54895a7d0a 100644
--- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageArrowPageSource.java
+++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageArrowPageSource.java
@@ -20,6 +20,7 @@
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
@@ -102,7 +103,7 @@ public boolean isFinished()
}

@Override
-public Page getNextPage()
+public SourcePage getNextSourcePage()
{
    checkState(pageBuilder.isEmpty(), "PageBuilder is not empty at the beginning of a new page");
    ReadRowsResponse response;
@@ -125,7 +126,7 @@ public Page getNextPage()
        pageBuilder.reset();
    }
    readTimeNanos.addAndGet(System.nanoTime() - start);
-    return page;
+    return SourcePage.create(page);
}

@Override
diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageAvroPageSource.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageAvroPageSource.java
index 6308b26644b0..46a20e21d008 100644
--- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageAvroPageSource.java
+++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryStorageAvroPageSource.java
@@ -25,6 +25,7 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.RowBlockBuilder;
import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
@@ -155,7 +156,7 @@ public boolean isFinished()
}

@Override
-public Page getNextPage()
+public SourcePage getNextSourcePage()
{
    checkState(pageBuilder.isEmpty(), "PageBuilder is not empty at the beginning of a new page");
    ReadRowsResponse response;
@@ -181,7 +182,7 @@ public Page getNextPage()
    Page page = pageBuilder.build();
    pageBuilder.reset();
    readTimeNanos.addAndGet(System.nanoTime() - start);
-    return page;
+    return SourcePage.create(page);
}

private static Object getValueRecord(GenericRecord record, BigQueryColumnHandle columnHandle)
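For eager, builder-based sources such as these BigQuery readers, the migration is a single line: keep assembling a `Page` as before and wrap it with `SourcePage.create(page)` on the way out. The shared tail of the `getNextSourcePage()` implementations above, extracted as a sketch:

    import io.trino.spi.Page;
    import io.trino.spi.PageBuilder;
    import io.trino.spi.connector.SourcePage;

    // Minimal sketch (not a full ConnectorPageSource): flush the builder and
    // wrap the fully materialized page; no lazy channels are involved.
    final class EagerSourcePages
    {
        private EagerSourcePages() {}

        static SourcePage flush(PageBuilder pageBuilder)
        {
            Page page = pageBuilder.build();
            pageBuilder.reset();
            return SourcePage.create(page);
        }
    }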
diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java
index 56b425b1bcc2..88e2672305b9 100644
--- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java
+++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHolePageSource.java
@@ -17,6 +17,7 @@
import io.airlift.units.Duration;
import io.trino.spi.Page;
import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;

import java.util.concurrent.CompletableFuture;
@@ -49,7 +50,7 @@ class BlackHolePageSource
}

@Override
-public Page getNextPage()
+public SourcePage getNextSourcePage()
{
    if (isFinished()) {
        return null;
@@ -58,14 +59,14 @@ public Page getNextPage()
    if (currentPage != null) {
        Page page = getFutureValue(currentPage);
        currentPage = null;
-        return page;
+        return SourcePage.create(page);
    }

    pagesLeft--;
    completedBytes += page.getSizeInBytes();

    if (pageProcessingDelayInMillis == 0) {
-        return page;
+        return SourcePage.create(page);
    }
    currentPage = toCompletableFuture(executorService.schedule(() -> page, pageProcessingDelayInMillis, MILLISECONDS));
    return null;
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java
index 66c8134caef3..ed1a6026df1f 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMergeSink.java
@@ -32,7 +32,6 @@
import io.trino.plugin.base.metrics.FileFormatDataSourceStats;
import io.trino.plugin.deltalake.delete.RoaringBitmapArray;
import io.trino.plugin.deltalake.transactionlog.DeletionVectorEntry;
-import io.trino.plugin.hive.ReaderPageSource;
import io.trino.plugin.hive.parquet.ParquetFileWriter;
import io.trino.plugin.hive.parquet.ParquetPageSourceFactory;
import io.trino.plugin.hive.parquet.TrinoParquetDataSource;
@@ -44,6 +43,7 @@
import io.trino.spi.connector.ConnectorPageSink;
import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
@@ -367,7 +367,7 @@ private Slice writeMergeResult(Slice path, FileDeletion deletion)
    deletedRows.or(rowsDeletedByUpdate);

    if (cdfEnabled) {
-        try (ConnectorPageSource connectorPageSource = createParquetPageSource(Location.of(path.toStringUtf8())).get()) {
+        try (ConnectorPageSource connectorPageSource = createParquetPageSource(Location.of(path.toStringUtf8()))) {
            readConnectorPageSource(
                    connectorPageSource,
                    rowsDeletedByDelete,
@@ -542,7 +542,7 @@ private Optional rewriteParquetFile(Location path, FileDeletion de
{
    RoaringBitmapArray rowsDeletedByDelete = deletion.rowsDeletedByDelete();
    RoaringBitmapArray rowsDeletedByUpdate = deletion.rowsDeletedByUpdate();
-    try (ConnectorPageSource connectorPageSource = createParquetPageSource(path).get()) {
+    try (ConnectorPageSource connectorPageSource = createParquetPageSource(path)) {
        readConnectorPageSource(
                connectorPageSource,
                rowsDeletedByDelete,
@@ -583,10 +583,12 @@ private void readConnectorPageSource(
{
    long filePosition = 0;
    while (!connectorPageSource.isFinished()) {
-        Page page = connectorPageSource.getNextPage();
-        if (page == null) {
+        SourcePage sourcePage = connectorPageSource.getNextSourcePage();
+        if (sourcePage == null) {
            continue;
        }
+        // fully load page
+        Page page = sourcePage.getPage();

        int positionCount = page.getPositionCount();
        int[] retained = new int[positionCount];
@@ -628,14 +630,14 @@ private void storeCdfEntries(Page page, int[] deleted, int deletedCount, FileDel
    if (cdfPageSink == null) {
        cdfPageSink = cdfPageSinkSupplier.get();
    }
-    Page cdfPage = page.getPositions(deleted, 0, deletedCount);
+    page = page.getPositions(deleted, 0, deletedCount);
    Block[] outputBlocks = new Block[nonSynthesizedColumns.size() + 1];
    int cdfPageIndex = 0;
    int partitionIndex = 0;
    List<String> partitionValues = deletion.partitionValues;
    for (int i = 0; i < nonSynthesizedColumns.size(); i++) {
        if (nonSynthesizedColumns.get(i).columnType() == REGULAR) {
-            outputBlocks[i] = cdfPage.getBlock(cdfPageIndex);
+            outputBlocks[i] = page.getBlock(cdfPageIndex);
            cdfPageIndex++;
        }
        else {
@@ -644,18 +646,18 @@ private void storeCdfEntries(Page page, int[] deleted, int deletedCount, FileDel
                    deserializePartitionValue(
                            nonSynthesizedColumns.get(i),
                            Optional.ofNullable(partitionValues.get(partitionIndex)))),
-                    cdfPage.getPositionCount());
+                    page.getPositionCount());
            partitionIndex++;
        }
    }
    Block cdfOperationBlock = RunLengthEncodedBlock.create(
-            nativeValueToBlock(VARCHAR, utf8Slice(operation)), cdfPage.getPositionCount());
+            nativeValueToBlock(VARCHAR, utf8Slice(operation)), page.getPositionCount());
    outputBlocks[nonSynthesizedColumns.size()] = cdfOperationBlock;
-    cdfPageSink.appendPage(new Page(cdfPage.getPositionCount(), outputBlocks));
+    cdfPageSink.appendPage(new Page(page.getPositionCount(), outputBlocks));
}

-private ReaderPageSource createParquetPageSource(Location path)
+private ConnectorPageSource createParquetPageSource(Location path)
        throws IOException
{
    TrinoInputFile inputFile = fileSystem.newInputFile(path);
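`readConnectorPageSource` keeps the existing pattern: materialize the page (`sourcePage.getPage()`), collect surviving row indexes into `retained`, and compact with a single `getPositions` call. That filter step as a standalone sketch, assuming `RoaringBitmapArray.contains(long)` as the surrounding code implies:

    import io.trino.spi.Page;

    // Drops rows whose absolute file position is marked deleted.
    static Page dropDeletedRows(Page page, long filePosition, RoaringBitmapArray deletedRows)
    {
        int[] retained = new int[page.getPositionCount()];
        int retainedCount = 0;
        for (int position = 0; position < page.getPositionCount(); position++) {
            if (!deletedRows.contains(filePosition + position)) {
                retained[retainedCount++] = position;
            }
        }
        return page.getPositions(retained, 0, retainedCount);
    }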
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSource.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSource.java
deleted file mode 100644
index bdc6c6c7af35..000000000000
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSource.java
+++ /dev/null
@@ -1,263 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.deltalake;
-
-import io.airlift.json.JsonCodec;
-import io.airlift.json.JsonCodecFactory;
-import io.trino.plugin.deltalake.delete.PageFilter;
-import io.trino.plugin.hive.ReaderProjectionsAdapter;
-import io.trino.spi.Page;
-import io.trino.spi.TrinoException;
-import io.trino.spi.block.Block;
-import io.trino.spi.block.RowBlock;
-import io.trino.spi.block.RunLengthEncodedBlock;
-import io.trino.spi.connector.ConnectorPageSource;
-import io.trino.spi.metrics.Metrics;
-import io.trino.spi.predicate.Utils;
-import io.trino.spi.type.Type;
-
-import java.io.IOException;
-import java.io.UncheckedIOException;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.OptionalLong;
-import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-import java.util.function.Supplier;
-
-import static com.google.common.base.Throwables.throwIfInstanceOf;
-import static io.airlift.slice.Slices.utf8Slice;
-import static io.airlift.slice.Slices.wrappedBuffer;
-import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_MODIFIED_TIME_COLUMN_NAME;
-import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_MODIFIED_TIME_TYPE;
-import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_SIZE_COLUMN_NAME;
-import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_SIZE_TYPE;
-import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.PATH_COLUMN_NAME;
-import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.PATH_TYPE;
-import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.ROW_ID_COLUMN_NAME;
-import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_DATA;
-import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.deserializePartitionValue;
-import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
-import static io.trino.spi.type.TimeZoneKey.UTC_KEY;
-import static io.trino.spi.type.VarcharType.VARCHAR;
-import static java.util.Objects.requireNonNull;
-
-public class DeltaLakePageSource
-        implements ConnectorPageSource
-{
-    private static final JsonCodec<List<String>> PARTITIONS_CODEC = new JsonCodecFactory().listJsonCodec(String.class);
-
-    private final Block[] prefilledBlocks;
-    private final int[] delegateIndexes;
-    private final int rowIdIndex;
-    private final Block pathBlock;
-    private final Block partitionsBlock;
-    private final ConnectorPageSource delegate;
-    private final Optional<ReaderProjectionsAdapter> projectionsAdapter;
-    private final Supplier<Optional<PageFilter>> deletePredicate;
-
-    public DeltaLakePageSource(
-            List<DeltaLakeColumnHandle> columns,
-            Set<String> missingColumnNames,
-            Map<String, Optional<String>> partitionKeys,
-            Optional<List<String>> partitionValues,
-            ConnectorPageSource delegate,
-            Optional<ReaderProjectionsAdapter> projectionsAdapter,
-            String path,
-            long fileSize,
-            long fileModifiedTime,
-            Supplier<Optional<PageFilter>> deletePredicate)
-    {
-        int size = columns.size();
-        requireNonNull(partitionKeys, "partitionKeys is null");
-        this.delegate = requireNonNull(delegate, "delegate is null");
-        this.projectionsAdapter = requireNonNull(projectionsAdapter, "projectionsAdapter is null");
-
-        this.prefilledBlocks = new Block[size];
-        this.delegateIndexes = new int[size];
-
-        int outputIndex = 0;
-        int delegateIndex = 0;
-
-        int rowIdIndex = -1;
-        Block pathBlock = null;
-        Block partitionsBlock = null;
-
-        for (DeltaLakeColumnHandle column : columns) {
-            if (column.isBaseColumn() && partitionKeys.containsKey(column.basePhysicalColumnName())) {
-                Type type = column.baseType();
-                Object prefilledValue = deserializePartitionValue(column, partitionKeys.get(column.basePhysicalColumnName()));
-                prefilledBlocks[outputIndex] = Utils.nativeValueToBlock(type, prefilledValue);
-                delegateIndexes[outputIndex] = -1;
-            }
-            else if (column.baseColumnName().equals(PATH_COLUMN_NAME)) {
-                prefilledBlocks[outputIndex] = Utils.nativeValueToBlock(PATH_TYPE, utf8Slice(path));
-                delegateIndexes[outputIndex] = -1;
-            }
-            else if (column.baseColumnName().equals(FILE_SIZE_COLUMN_NAME)) {
-                prefilledBlocks[outputIndex] = Utils.nativeValueToBlock(FILE_SIZE_TYPE, fileSize);
-                delegateIndexes[outputIndex] = -1;
-            }
-            else if (column.baseColumnName().equals(FILE_MODIFIED_TIME_COLUMN_NAME)) {
-                long packedTimestamp = packDateTimeWithZone(fileModifiedTime, UTC_KEY);
-                prefilledBlocks[outputIndex] = Utils.nativeValueToBlock(FILE_MODIFIED_TIME_TYPE, packedTimestamp);
-                delegateIndexes[outputIndex] = -1;
-            }
-            else if (column.baseColumnName().equals(ROW_ID_COLUMN_NAME)) {
-                rowIdIndex = outputIndex;
-                pathBlock = Utils.nativeValueToBlock(VARCHAR, utf8Slice(path));
-                partitionsBlock = Utils.nativeValueToBlock(VARCHAR, wrappedBuffer(PARTITIONS_CODEC.toJsonBytes(partitionValues.orElseThrow(() -> new IllegalStateException("partitionValues not provided")))));
-                delegateIndexes[outputIndex] = delegateIndex;
-                delegateIndex++;
-            }
-            else if (missingColumnNames.contains(column.baseColumnName())) {
-                prefilledBlocks[outputIndex] = Utils.nativeValueToBlock(column.baseType(), null);
-                delegateIndexes[outputIndex] = -1;
-            }
-            else {
-                delegateIndexes[outputIndex] = delegateIndex;
-                delegateIndex++;
-            }
-            outputIndex++;
-        }
-
-        this.rowIdIndex = rowIdIndex;
-        this.pathBlock = pathBlock;
-        this.partitionsBlock = partitionsBlock;
-        this.deletePredicate = requireNonNull(deletePredicate, "deletePredicate is null");
-    }
-
-    @Override
-    public long getCompletedBytes()
-    {
-        return delegate.getCompletedBytes();
-    }
-
-    @Override
-    public OptionalLong getCompletedPositions()
-    {
-        return delegate.getCompletedPositions();
-    }
-
-    @Override
-    public long getReadTimeNanos()
-    {
-        return delegate.getReadTimeNanos();
-    }
-
-    @Override
-    public boolean isFinished()
-    {
-        return delegate.isFinished();
-    }
-
-    @Override
-    public CompletableFuture<?> isBlocked()
-    {
-        return delegate.isBlocked();
-    }
-
-    @Override
-    public Page getNextPage()
-    {
-        try {
-            Page dataPage = delegate.getNextPage();
-            if (dataPage == null) {
-                return null;
-            }
-            if (projectionsAdapter.isPresent()) {
-                dataPage = projectionsAdapter.get().adaptPage(dataPage);
-            }
-            Optional<PageFilter> deleteFilterPredicate = deletePredicate.get();
-            if (deleteFilterPredicate.isPresent()) {
-                dataPage = deleteFilterPredicate.get().apply(dataPage);
-            }
-
-            int batchSize = dataPage.getPositionCount();
-            Block[] blocks = new Block[prefilledBlocks.length];
-            for (int i = 0; i < prefilledBlocks.length; i++) {
-                if (prefilledBlocks[i] != null) {
-                    blocks[i] = RunLengthEncodedBlock.create(prefilledBlocks[i], batchSize);
-                }
-                else if (i == rowIdIndex) {
-                    blocks[i] = createRowIdBlock(dataPage.getBlock(delegateIndexes[i]));
-                }
-                else {
-                    blocks[i] = dataPage.getBlock(delegateIndexes[i]);
-                }
-            }
-            return new Page(batchSize, blocks);
-        }
-        catch (RuntimeException e) {
-            closeWithSuppression(e);
-            throwIfInstanceOf(e, TrinoException.class);
-            throw new TrinoException(DELTA_LAKE_BAD_DATA, e);
-        }
-    }
-
-    private Block createRowIdBlock(Block rowIndexBlock)
-    {
-        int positions = rowIndexBlock.getPositionCount();
-        Block[] fields = {
-                RunLengthEncodedBlock.create(pathBlock, positions),
-                rowIndexBlock,
-                RunLengthEncodedBlock.create(partitionsBlock, positions),
-        };
-        return RowBlock.fromFieldBlocks(positions, fields);
-    }
-
-    @Override
-    public void close()
-    {
-        try {
-            delegate.close();
-        }
-        catch (IOException e) {
-            throw new UncheckedIOException(e);
-        }
-    }
-
-    @Override
-    public String toString()
-    {
-        return delegate.toString();
-    }
-
-    @Override
-    public long getMemoryUsage()
-    {
-        return delegate.getMemoryUsage();
-    }
-
-    @Override
-    public Metrics getMetrics()
-    {
-        return delegate.getMetrics();
-    }
-
-    protected void closeWithSuppression(Throwable throwable)
-    {
-        requireNonNull(throwable, "throwable is null");
-        try {
-            close();
-        }
-        catch (RuntimeException e) {
-            // Self-suppression not permitted
-            if (throwable != e) {
-                throwable.addSuppressed(e);
-            }
-        }
-    }
-}
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java
index a65c04d58d1d..180f457e4ae9 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakePageSourceProvider.java
@@ -19,6 +19,8 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.inject.Inject;
+import io.airlift.json.JsonCodec;
+import io.airlift.json.JsonCodecFactory;
 import io.trino.filesystem.Location;
 import io.trino.filesystem.TrinoFileSystem;
 import io.trino.filesystem.TrinoFileSystemFactory;
@@ -29,16 +31,13 @@
 import io.trino.parquet.metadata.ParquetMetadata;
 import io.trino.parquet.reader.MetadataReader;
 import io.trino.plugin.base.metrics.FileFormatDataSourceStats;
-import io.trino.plugin.deltalake.delete.PageFilter;
 import io.trino.plugin.deltalake.delete.PositionDeleteFilter;
 import io.trino.plugin.deltalake.delete.RoaringBitmapArray;
 import io.trino.plugin.deltalake.transactionlog.DeletionVectorEntry;
 import io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.ColumnMappingMode;
 import io.trino.plugin.hive.HiveColumnHandle;
 import io.trino.plugin.hive.HiveColumnProjectionInfo;
-import io.trino.plugin.hive.HivePageSourceProvider;
-import io.trino.plugin.hive.ReaderPageSource;
-import io.trino.plugin.hive.ReaderProjectionsAdapter;
+import io.trino.plugin.hive.TransformConnectorPageSource;
 import io.trino.plugin.hive.parquet.ParquetPageSourceFactory;
 import io.trino.plugin.hive.parquet.ParquetReaderConfig;
 import io.trino.plugin.hive.parquet.TrinoParquetDataSource;
@@ -46,6 +45,8 @@
 import io.trino.spi.TrinoException;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.LongArrayBlock;
+import io.trino.spi.block.RowBlock;
+import io.trino.spi.block.RunLengthEncodedBlock;
 import io.trino.spi.connector.ColumnHandle;
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorPageSourceProvider;
@@ -58,6 +59,7 @@
 import io.trino.spi.connector.FixedPageSource;
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.TupleDomain;
+import io.trino.spi.predicate.Utils;
 import io.trino.spi.type.StandardTypes;
 import io.trino.spi.type.TypeManager;
 import org.apache.parquet.schema.MessageType;
@@ -71,16 +73,23 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.OptionalLong;
+import java.util.Set;
 import java.util.function.Function;
-import java.util.function.Supplier;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableMap.toImmutableMap;
 import static com.google.common.collect.Iterables.getOnlyElement;
-import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
+import static io.airlift.slice.Slices.utf8Slice;
+import static io.airlift.slice.Slices.wrappedBuffer;
 import static io.trino.plugin.deltalake.DeltaHiveTypeTranslator.toHiveType;
+import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_MODIFIED_TIME_COLUMN_NAME;
+import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_MODIFIED_TIME_TYPE;
+import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_SIZE_COLUMN_NAME;
+import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.FILE_SIZE_TYPE;
+import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.PATH_COLUMN_NAME;
+import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.PATH_TYPE;
 import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.ROW_ID_COLUMN_NAME;
 import static io.trino.plugin.deltalake.DeltaLakeColumnHandle.rowPositionColumnHandle;
 import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR;
@@ -95,8 +104,11 @@
 import static io.trino.plugin.deltalake.delete.DeletionVectors.readDeletionVectors;
 import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractSchema;
 import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.getColumnMappingMode;
+import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.deserializePartitionValue;
 import static io.trino.plugin.hive.parquet.ParquetPageSourceFactory.PARQUET_ROW_INDEX_COLUMN;
-import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES;
+import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
+import static io.trino.spi.type.TimeZoneKey.UTC_KEY;
+import static io.trino.spi.type.VarcharType.VARCHAR;
 import static java.lang.Math.min;
 import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
@@ -104,12 +116,9 @@
 public class DeltaLakePageSourceProvider
         implements ConnectorPageSourceProvider
 {
-    // This is used whenever a query doesn't reference any data columns.
-    // We need to limit the number of rows per page in case there are projections
-    // in the query that can cause page sizes to explode. For example: SELECT rand() FROM some_table
-    // TODO (https://github.com/trinodb/trino/issues/16824) allow connector to return pages of arbitrary row count and handle this gracefully in engine
-    private static final int MAX_RLE_PAGE_SIZE = DEFAULT_MAX_PAGE_SIZE_IN_BYTES / SIZE_OF_LONG;
-    private static final int MAX_RLE_ROW_ID_PAGE_SIZE = DEFAULT_MAX_PAGE_SIZE_IN_BYTES / (SIZE_OF_LONG * 2);
+    private static final JsonCodec<List<String>> PARTITIONS_CODEC = new JsonCodecFactory().listJsonCodec(String.class);
+
+    private static final int MAX_ROW_ID_POSITIONS = 100_000;
 
     private final TrinoFileSystemFactory fileSystemFactory;
     private final FileFormatDataSourceStats fileFormatDataSourceStats;
@@ -192,22 +201,21 @@ public ConnectorPageSource createPageSource(
         if (!partitionMatchesPredicate(split.getPartitionKeys(), partitionColumnDomains)) {
             return new EmptyPageSource();
         }
+        // Skip reading the file if none of the actual file columns are being read
         if (filteredSplitPredicate.isAll() &&
                 split.getStart() == 0 && split.getLength() == split.getFileSize() &&
                 split.getFileRowCount().isPresent() &&
                 split.getDeletionVector().isEmpty() &&
                 (regularColumns.isEmpty() || onlyRowIdColumn(regularColumns))) {
-            return new DeltaLakePageSource(
+            return projectColumns(
                     deltaLakeColumns,
                     ImmutableSet.of(),
                     partitionKeys,
                     partitionValues,
                     generatePages(split.getFileRowCount().get(), onlyRowIdColumn(regularColumns)),
-                    Optional.empty(),
                    split.getPath(),
                     split.getFileSize(),
-                    split.getFileModifiedTime(),
-                    Optional::empty);
+                    split.getFileModifiedTime());
         }
 
         Location location = Location.of(split.getPath());
@@ -222,28 +230,30 @@ public ConnectorPageSource createPageSource(
         Map<Integer, String> parquetFieldIdToName = columnMappingMode == ColumnMappingMode.ID ? loadParquetIdAndNameMapping(inputFile, options) : ImmutableMap.of();
 
-        ImmutableSet.Builder<String> missingColumnNames = ImmutableSet.builder();
-        ImmutableList.Builder<HiveColumnHandle> hiveColumnHandles = ImmutableList.builder();
+        ImmutableSet.Builder<String> missingColumnNamesBuilder = ImmutableSet.builder();
+        ImmutableList.Builder<HiveColumnHandle> hiveColumnHandlesBuilder = ImmutableList.builder();
         for (DeltaLakeColumnHandle column : regularColumns) {
             if (column.baseColumnName().equals(ROW_ID_COLUMN_NAME)) {
-                hiveColumnHandles.add(PARQUET_ROW_INDEX_COLUMN);
+                hiveColumnHandlesBuilder.add(PARQUET_ROW_INDEX_COLUMN);
                 continue;
             }
             toHiveColumnHandle(column, columnMappingMode, parquetFieldIdToName).ifPresentOrElse(
-                    hiveColumnHandles::add,
-                    () -> missingColumnNames.add(column.baseColumnName()));
+                    hiveColumnHandlesBuilder::add,
+                    () -> missingColumnNamesBuilder.add(column.baseColumnName()));
         }
         if (split.getDeletionVector().isPresent() && !regularColumns.contains(rowPositionColumnHandle())) {
-            hiveColumnHandles.add(PARQUET_ROW_INDEX_COLUMN);
+            hiveColumnHandlesBuilder.add(PARQUET_ROW_INDEX_COLUMN);
         }
+        List<HiveColumnHandle> hiveColumnHandles = hiveColumnHandlesBuilder.build();
+        Set<String> missingColumnNames = missingColumnNamesBuilder.build();
 
         TupleDomain<HiveColumnHandle> parquetPredicate = getParquetTupleDomain(filteredSplitPredicate.simplify(domainCompactionThreshold), columnMappingMode, parquetFieldIdToName);
 
-        ReaderPageSource pageSource = ParquetPageSourceFactory.createPageSource(
+        ConnectorPageSource delegate = ParquetPageSourceFactory.createPageSource(
                 inputFile,
                 split.getStart(),
                 split.getLength(),
-                hiveColumnHandles.build(),
+                hiveColumnHandles,
                 ImmutableList.of(parquetPredicate),
                 true,
                 parquetDateTimeZone,
@@ -253,40 +263,83 @@ public ConnectorPageSource createPageSource(
                 domainCompactionThreshold,
                 OptionalLong.of(split.getFileSize()));
-        Optional<ReaderProjectionsAdapter> projectionsAdapter = pageSource.getReaderColumns().map(readerColumns ->
-                new ReaderProjectionsAdapter(
-                        hiveColumnHandles.build(),
-                        readerColumns,
-                        column -> ((HiveColumnHandle) column).getType(),
-                        HivePageSourceProvider::getProjection));
-
-        Supplier<Optional<PageFilter>> deletePredicate = Suppliers.memoize(() -> {
-            if (split.getDeletionVector().isEmpty()) {
-                return Optional.empty();
-            }
-
-            List<DeltaLakeColumnHandle> requiredColumns = ImmutableList.builderWithExpectedSize(regularColumns.size() + 1)
-                    .addAll(regularColumns)
-                    .add(rowPositionColumnHandle())
-                    .build();
-            PositionDeleteFilter deleteFilter = readDeletes(fileSystem, Location.of(table.location()), split.getDeletionVector().get());
-            return Optional.of(deleteFilter.createPredicate(requiredColumns));
-        });
+        if (split.getDeletionVector().isPresent()) {
+            var pageFilterSupplier = Suppliers.memoize(() -> {
+                List<DeltaLakeColumnHandle> requiredColumns = ImmutableList.builderWithExpectedSize(regularColumns.size() + 1)
+                        .addAll(regularColumns)
+                        .add(rowPositionColumnHandle())
+                        .build();
+                PositionDeleteFilter deleteFilter = readDeletes(fileSystem, Location.of(table.location()), split.getDeletionVector().get());
+                return deleteFilter.createPredicate(requiredColumns);
+            });
+            delegate = TransformConnectorPageSource.create(delegate, page -> pageFilterSupplier.get().apply(page));
+        }
 
-        return new DeltaLakePageSource(
+        return projectColumns(
                 deltaLakeColumns,
-                missingColumnNames.build(),
+                missingColumnNames,
                 partitionKeys,
                 partitionValues,
-                pageSource.get(),
-                projectionsAdapter,
+                delegate,
                 split.getPath(),
                 split.getFileSize(),
-                split.getFileModifiedTime(),
-                deletePredicate);
+                split.getFileModifiedTime());
     }
 
-    private PositionDeleteFilter readDeletes(
+    public static ConnectorPageSource projectColumns(
+            List<DeltaLakeColumnHandle> deltaLakeColumns,
+            Set<String> missingColumnNames,
+            Map<String, Optional<String>> partitionKeys,
+            Optional<List<String>> partitionValues,
+            ConnectorPageSource delegate,
+            String path,
+            long fileSize,
+            long fileModifiedTime)
+    {
+        int delegateIndex = 0;
+        TransformConnectorPageSource.Builder transform = TransformConnectorPageSource.builder();
+        for (DeltaLakeColumnHandle column : deltaLakeColumns) {
+            if (column.isBaseColumn() && partitionKeys.containsKey(column.basePhysicalColumnName())) {
+                Object prefilledValue = deserializePartitionValue(column, partitionKeys.get(column.basePhysicalColumnName()));
+                transform.constantValue(Utils.nativeValueToBlock(column.baseType(), prefilledValue));
+            }
+            else if (column.baseColumnName().equals(PATH_COLUMN_NAME)) {
+                transform.constantValue(Utils.nativeValueToBlock(PATH_TYPE, utf8Slice(path)));
+            }
+            else if (column.baseColumnName().equals(FILE_SIZE_COLUMN_NAME)) {
+                transform.constantValue(Utils.nativeValueToBlock(FILE_SIZE_TYPE, fileSize));
+            }
+            else if (column.baseColumnName().equals(FILE_MODIFIED_TIME_COLUMN_NAME)) {
+                long packedTimestamp = packDateTimeWithZone(fileModifiedTime, UTC_KEY);
+                transform.constantValue(Utils.nativeValueToBlock(FILE_MODIFIED_TIME_TYPE, packedTimestamp));
+            }
+            else if (column.baseColumnName().equals(ROW_ID_COLUMN_NAME)) {
+                Block pathBlock = Utils.nativeValueToBlock(VARCHAR, utf8Slice(path));
+                Block partitionsBlock = Utils.nativeValueToBlock(VARCHAR, wrappedBuffer(PARTITIONS_CODEC.toJsonBytes(partitionValues.orElseThrow(() -> new IllegalStateException("partitionValues not provided")))));
+                transform.transform(delegateIndex, new CreateRowIdBlock(pathBlock, partitionsBlock));
+                delegateIndex++;
+            }
+            else if (missingColumnNames.contains(column.baseColumnName())) {
+                transform.constantValue(column.baseType().createNullBlock());
+            }
+            else {
+                transform.column(delegateIndex);
+                delegateIndex++;
+            }
+        }
+        return transform.build(delegate);
+    }
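Editor's note: projectColumns is the new reusable entry point that replaces both DeltaLakePageSource and the manual projections adapter. A condensed sketch of the builder pattern as it is used in this patch; the method names come from the diff itself, but treat the exact signatures as illustrative rather than authoritative:

    import io.trino.plugin.hive.TransformConnectorPageSource;
    import io.trino.spi.block.Block;
    import io.trino.spi.connector.ConnectorPageSource;
    import io.trino.spi.predicate.Utils;

    import static io.trino.spi.type.BigintType.BIGINT;

    public final class TransformExample
    {
        private TransformExample() {}

        // Builds a decorated page source: output channels are declared in order,
        // each either a constant, a pass-through delegate channel, or a per-block transform.
        public static ConnectorPageSource decorate(ConnectorPageSource delegate)
        {
            TransformConnectorPageSource.Builder transform = TransformConnectorPageSource.builder();
            transform.constantValue(Utils.nativeValueToBlock(BIGINT, 42L)); // output 0: prefilled constant
            transform.column(0);                                            // output 1: delegate channel 0 unchanged
            transform.transform(1, (Block block) -> block);                 // output 2: delegate channel 1 through a function
            return transform.build(delegate);
        }
    }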
+
+    private static Block createRowIdBlock(Block pathValue, Block rowIndexBlock, Block partitionsValue)
+    {
+        return RowBlock.fromFieldBlocks(rowIndexBlock.getPositionCount(), new Block[] {
+                RunLengthEncodedBlock.create(pathValue, rowIndexBlock.getPositionCount()),
+                rowIndexBlock,
+                RunLengthEncodedBlock.create(partitionsValue, rowIndexBlock.getPositionCount()),
+        });
+    }
+
+    private static PositionDeleteFilter readDeletes(
             TrinoFileSystem fileSystem,
             Location tableLocation,
             DeletionVectorEntry deletionVector)
@@ -300,7 +353,7 @@ private PositionDeleteFilter readDeletes(
         }
     }
 
-    public Map<Integer, String> loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options)
+    private Map<Integer, String> loadParquetIdAndNameMapping(TrinoInputFile inputFile, ParquetReaderOptions options)
     {
         try (ParquetDataSource dataSource = new TrinoParquetDataSource(inputFile, options, fileFormatDataSourceStats)) {
             ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty());
@@ -373,9 +426,6 @@ private static ConnectorPageSource generatePages(long totalRowCount, boolean projectRowNumber)
         return new FixedPageSource(
                 new AbstractIterator<>()
                 {
-                    private static final Block[] EMPTY_BLOCKS = new Block[0];
-
-                    private final int maxPageSize = projectRowNumber ? MAX_RLE_ROW_ID_PAGE_SIZE : MAX_RLE_PAGE_SIZE;
                     private long rowIndex;
 
                     @Override
@@ -384,16 +434,17 @@ protected Page computeNext()
                         if (rowIndex == totalRowCount) {
                             return endOfData();
                         }
-                        int pageSize = toIntExact(min(maxPageSize, totalRowCount - rowIndex));
-                        Block[] blocks;
+                        int pageSize = toIntExact(min(MAX_ROW_ID_POSITIONS, totalRowCount - rowIndex));
+
+                        Page page;
                         if (projectRowNumber) {
-                            blocks = new Block[] {createRowNumberBlock(rowIndex, pageSize)};
+                            page = new Page(pageSize, createRowNumberBlock(rowIndex, pageSize));
                         }
                         else {
-                            blocks = EMPTY_BLOCKS;
+                            page = new Page(pageSize);
                        }
                         rowIndex += pageSize;
-                        return new Page(pageSize, blocks);
+                        return page;
                     }
                 },
                 0);
@@ -407,4 +458,14 @@ private static Block createRowNumberBlock(long baseIndex, int size)
         }
         return new LongArrayBlock(size, Optional.empty(), rowIndices);
     }
+
+    private record CreateRowIdBlock(Block pathBlock, Block partitionsBlock)
+            implements Function<Block, Block>
+    {
+        @Override
+        public Block apply(Block block)
+        {
+            return createRowIdBlock(pathBlock, block, partitionsBlock);
+        }
+    }
 }
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java
index c6f85f815c0d..e0dd4e886103 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeWriter.java
@@ -32,7 +32,6 @@
 import io.trino.spi.block.ColumnarArray;
 import io.trino.spi.block.ColumnarMap;
 import io.trino.spi.block.DictionaryBlock;
-import io.trino.spi.block.LazyBlock;
 import io.trino.spi.block.LazyBlockLoader;
 import io.trino.spi.block.LongArrayBlock;
 import io.trino.spi.block.RowBlock;
@@ -138,9 +137,7 @@ public void appendRows(Page originalPage)
             Block originalBlock = originalPage.getBlock(index);
             Function<Block, Block> coercer = coercers.get(index);
             if (coercer != null) {
-                translatedBlocks[index] = new LazyBlock(
-                        originalBlock.getPositionCount(),
-                        new CoercionLazyBlockLoader(originalBlock, coercer));
+                translatedBlocks[index] = coercer.apply(originalBlock);
             }
             else {
                 translatedBlocks[index] = originalBlock;
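Editor's note: with the LazyBlock wrapper gone, coercion in the writer is just an eager per-channel map over the page. A minimal sketch of that shape, assuming nothing beyond plain Page/Block SPI (the helper name is hypothetical):

    import io.trino.spi.Page;
    import io.trino.spi.block.Block;

    import java.util.List;
    import java.util.function.Function;

    public final class EagerCoercionExample
    {
        private EagerCoercionExample() {}

        // coercers.get(i) is null when channel i needs no coercion
        public static Page coerce(Page page, List<Function<Block, Block>> coercers)
        {
            Block[] out = new Block[page.getChannelCount()];
            for (int i = 0; i < page.getChannelCount(); i++) {
                Function<Block, Block> coercer = coercers.get(i);
                Block block = page.getBlock(i);
                out[i] = (coercer == null) ? block : coercer.apply(block);
            }
            return new Page(page.getPositionCount(), out);
        }
    }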
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PageFilter.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PageFilter.java
index 7e0a7438e594..7ee914289e76 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PageFilter.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PageFilter.java
@@ -13,9 +13,9 @@
  */
 package io.trino.plugin.deltalake.delete;
 
-import io.trino.spi.Page;
+import io.trino.spi.connector.SourcePage;
 
 import java.util.function.Function;
 
 public interface PageFilter
-        extends Function<Page, Page> {}
+        extends Function<SourcePage, SourcePage> {}
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PositionDeleteFilter.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PositionDeleteFilter.java
index d34c0188e606..3f0c84abb3bc 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PositionDeleteFilter.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/delete/PositionDeleteFilter.java
@@ -50,7 +50,8 @@ public PageFilter createPredicate(List<DeltaLakeColumnHandle> columns)
             if (retainedCount == positionCount) {
                 return page;
             }
-            return page.getPositions(retained, 0, retainedCount);
+            page.selectPositions(retained, 0, retainedCount);
+            return page;
         };
     }
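Editor's note: the semantic shift here is easy to miss — Page.getPositions returned a new page, while SourcePage.selectPositions narrows the page in place, so the filter now returns the same instance. A self-contained sketch of the pattern (the predicate stands in for the deletion-vector bitmap used by the surrounding code):

    import io.trino.spi.connector.SourcePage;

    import java.util.function.LongPredicate;

    public final class SelectPositionsExample
    {
        private SelectPositionsExample() {}

        // Keeps only positions whose absolute row index passes the predicate;
        // mutates and returns the same SourcePage instance.
        public static SourcePage filter(SourcePage page, long baseRowIndex, LongPredicate keepRow)
        {
            int positionCount = page.getPositionCount();
            int[] retained = new int[positionCount];
            int retainedCount = 0;
            for (int position = 0; position < positionCount; position++) {
                if (keepRow.test(baseRowIndex + position)) {
                    retained[retainedCount++] = position;
                }
            }
            if (retainedCount < positionCount) {
                page.selectPositions(retained, 0, retainedCount);
            }
            return page;
        }
    }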
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java
index 59a2d87850a2..44b9c427f530 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/functions/tablechanges/TableChangesFunctionProcessor.java
@@ -22,13 +22,14 @@
 import io.trino.parquet.ParquetReaderOptions;
 import io.trino.plugin.base.metrics.FileFormatDataSourceStats;
 import io.trino.plugin.deltalake.DeltaLakeColumnHandle;
-import io.trino.plugin.deltalake.DeltaLakePageSource;
-import io.trino.plugin.hive.ReaderPageSource;
+import io.trino.plugin.deltalake.DeltaLakePageSourceProvider;
 import io.trino.plugin.hive.parquet.ParquetPageSourceFactory;
 import io.trino.spi.Page;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.RunLengthEncodedBlock;
+import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.function.table.TableFunctionProcessorState;
 import io.trino.spi.function.table.TableFunctionSplitProcessor;
 import io.trino.spi.predicate.TupleDomain;
@@ -40,7 +41,6 @@
 import java.util.OptionalInt;
 import java.util.OptionalLong;
 
-import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.airlift.slice.Slices.utf8Slice;
 import static io.trino.plugin.deltalake.DeltaLakeCdfPageSink.CHANGE_TYPE_COLUMN_NAME;
@@ -68,7 +68,7 @@ public class TableChangesFunctionProcessor
     private static final Page EMPTY_PAGE = new Page(0);
 
     private final TableChangesFileType fileType;
-    private final DeltaLakePageSource deltaLakePageSource;
+    private final ConnectorPageSource deltaLakePageSource;
     private final Block currentVersionAsBlock;
     private final Block currentVersionCommitTimestampAsBlock;
@@ -117,7 +117,7 @@ public TableFunctionProcessorState process()
 
     private TableFunctionProcessorState processCdfFile()
     {
-        Page page = deltaLakePageSource.getNextPage();
+        SourcePage page = deltaLakePageSource.getNextSourcePage();
         if (page != null) {
             int filePageColumns = page.getChannelCount();
             Block[] resultBlock = new Block[filePageColumns + NUMBER_OF_ADDITIONAL_COLUMNS_FOR_CDF_FILE];
@@ -138,7 +138,7 @@ private TableFunctionProcessorState processCdfFile()
 
     private TableFunctionProcessorState processDataFile()
     {
-        Page page = deltaLakePageSource.getNextPage();
+        SourcePage page = deltaLakePageSource.getNextSourcePage();
         if (page != null) {
             int filePageColumns = page.getChannelCount();
             Block[] blocks = new Block[filePageColumns + NUMBER_OF_ADDITIONAL_COLUMNS_FOR_DATA_FILE];
@@ -159,7 +159,7 @@ private TableFunctionProcessorState processDataFile()
         return TableFunctionProcessorState.Processed.produced(EMPTY_PAGE);
     }
 
-    private static DeltaLakePageSource createDeltaLakePageSource(
+    private static ConnectorPageSource createDeltaLakePageSource(
             ConnectorSession session,
             TrinoFileSystemFactory fileSystemFactory,
             DateTimeZone parquetDateTimeZone,
@@ -193,7 +193,7 @@ private static DeltaLakePageSource createDeltaLakePageSource(
             case DATA_FILE -> handle.columns();
         };
 
-        ReaderPageSource pageSource = ParquetPageSourceFactory.createPageSource(
+        ConnectorPageSource pageSource = ParquetPageSourceFactory.createPageSource(
                 inputFile,
                 0,
                 split.fileSize(),
@@ -207,18 +207,14 @@
                 domainCompactionThreshold,
                 OptionalLong.empty());
 
-        verify(pageSource.getReaderColumns().isEmpty(), "Unexpected reader columns: %s", pageSource.getReaderColumns().orElse(null));
-
-        return new DeltaLakePageSource(
+        return DeltaLakePageSourceProvider.projectColumns(
                 splitColumns,
                 ImmutableSet.of(),
                 partitionKeys,
                 Optional.empty(),
-                pageSource.get(),
-                Optional.empty(),
+                pageSource,
                 split.path(),
                 split.fileSize(),
-                0L,
-                Optional::empty);
+                0);
     }
 }
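Editor's note: callers of the new API poll getNextSourcePage, which may return null while the source is not yet finished, and read blocks on demand — this patch relies on SourcePage deferring block materialization until getBlock is called. A generic consumption sketch (not the processor itself; a real operator would also honor isBlocked instead of spinning):

    import io.trino.spi.block.Block;
    import io.trino.spi.connector.ConnectorPageSource;
    import io.trino.spi.connector.SourcePage;

    public final class ReadLoopExample
    {
        private ReadLoopExample() {}

        public static long countRows(ConnectorPageSource pageSource)
        {
            long rows = 0;
            while (!pageSource.isFinished()) {
                SourcePage page = pageSource.getNextSourcePage();
                if (page == null) {
                    continue; // not finished, but no page available yet
                }
                for (int channel = 0; channel < page.getChannelCount(); channel++) {
                    Block block = page.getBlock(channel); // each channel is materialized only when requested
                }
                rows += page.getPositionCount();
            }
            return rows;
        }
    }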
diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java
index 5099d6b25cb9..a917d52a6bd6 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/transactionlog/checkpoint/CheckpointEntryIterator.java
@@ -20,8 +20,6 @@
 import com.google.common.math.LongMath;
 import io.airlift.log.Logger;
 import io.trino.filesystem.TrinoInputFile;
-import io.trino.parquet.Column;
-import io.trino.parquet.Field;
 import io.trino.parquet.ParquetReaderOptions;
 import io.trino.plugin.base.metrics.FileFormatDataSourceStats;
 import io.trino.plugin.deltalake.DeltaHiveTypeTranslator;
@@ -39,17 +37,16 @@
 import io.trino.plugin.hive.HiveColumnHandle;
 import io.trino.plugin.hive.HiveColumnHandle.ColumnType;
 import io.trino.plugin.hive.HiveColumnProjectionInfo;
-import io.trino.plugin.hive.ReaderPageSource;
-import io.trino.plugin.hive.parquet.ParquetPageSource;
 import io.trino.plugin.hive.parquet.ParquetPageSourceFactory;
-import io.trino.spi.Page;
 import io.trino.spi.TrinoException;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.LongArrayBlock;
 import io.trino.spi.block.RowBlock;
 import io.trino.spi.block.SqlRow;
 import io.trino.spi.block.ValueBlock;
+import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.TupleDomain;
 import io.trino.spi.type.ArrayType;
@@ -62,6 +59,8 @@
 import jakarta.annotation.Nullable;
 import org.joda.time.DateTimeZone;
 
+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.ArrayDeque;
 import java.util.List;
 import java.util.Map;
@@ -73,8 +72,6 @@
 import java.util.function.Predicate;
 
 import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkState;
-import static com.google.common.base.Verify.verify;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.MoreCollectors.onlyElement;
 import static com.google.common.collect.MoreCollectors.toOptional;
@@ -137,7 +134,7 @@ public String getColumnName()
 
     private final String checkpointPath;
     private final ConnectorSession session;
-    private final ParquetPageSource pageSource;
+    private final ConnectorPageSource pageSource;
     private final MapType stringMap;
     private final ArrayType stringList;
     private final Queue<DeltaLakeTransactionLogEntry> nextEntries;
@@ -146,7 +143,6 @@ public String getColumnName()
     private final TupleDomain<DeltaLakeColumnHandle> partitionConstraint;
     private final Optional<RowType> txnType;
     private final Optional<RowType> addType;
-    private final Optional<RowType> addPartitionValuesType;
     private final Optional<RowType> addDeletionVectorType;
     private final Optional<RowType> addParsedStatsFieldType;
     private final Optional<RowType> removeType;
@@ -160,7 +156,7 @@ public String getColumnName()
     private boolean deletionVectorsEnabled;
     private List<DeltaLakeColumnMetadata> schema;
     private List<DeltaLakeColumnMetadata> columnsWithMinMaxStats;
-    private Page page;
+    private SourcePage page;
     private int pagePosition;
 
     public CheckpointEntryIterator(
@@ -209,17 +205,14 @@ public CheckpointEntryIterator(
             HiveColumnHandle column = buildColumnHandle(field, checkpointSchemaManager, this.metadataEntry, this.protocolEntry, addStatsMinMaxColumnFilter).toHiveColumnHandle();
             columnsBuilder.add(column);
             disjunctDomainsBuilder.add(buildTupleDomainColumnHandle(field, column));
-            if (field == ADD) {
-                Type addEntryPartitionValuesType = checkpointSchemaManager.getAddEntryPartitionValuesType();
-                columnsBuilder.add(new DeltaLakeColumnHandle("add", addEntryPartitionValuesType, OptionalInt.empty(), "add", addEntryPartitionValuesType, REGULAR, Optional.empty()).toHiveColumnHandle());
-            }
         }
+        List<HiveColumnHandle> columns = columnsBuilder.build();
 
-        ReaderPageSource pageSource = ParquetPageSourceFactory.createPageSource(
+        this.pageSource = ParquetPageSourceFactory.createPageSource(
                 checkpoint,
                 0,
                 fileSize,
-                columnsBuilder.build(),
+                columns,
                 disjunctDomainsBuilder.build(), // OR-ed condition
                 true,
                 DateTimeZone.UTC,
@@ -229,24 +222,20 @@ public CheckpointEntryIterator(
                 domainCompactionThreshold,
                 OptionalLong.of(fileSize));
 
-        this.pageSource = (ParquetPageSource) pageSource.get();
-
         try {
-            verify(pageSource.getReaderColumns().isEmpty(), "All columns expected to be base columns");
-
             this.nextEntries = new ArrayDeque<>();
             this.extractors = fields.stream()
                     .map(this::createCheckpointFieldExtractor)
                     .collect(toImmutableList());
-            txnType = getParquetType(fields, TRANSACTION);
-            addType = getAddParquetTypeContainingField(fields, "path");
-            addPartitionValuesType = getAddParquetTypeContainingField(fields, "partitionValues");
+            txnType = getParquetType(fields, TRANSACTION, columns);
+            addType = getAddParquetTypeContainingField(fields, "path", columns);
             addDeletionVectorType = addType.flatMap(type -> getOptionalFieldType(type, "deletionVector"));
             addParsedStatsFieldType = addType.flatMap(type -> getOptionalFieldType(type, "stats_parsed"));
-            removeType = getParquetType(fields, REMOVE);
+            removeType = getParquetType(fields, REMOVE, columns);
             removeDeletionVectorType = removeType.flatMap(type -> getOptionalFieldType(type, "deletionVector"));
-            metadataType = getParquetType(fields, METADATA);
-            protocolType = getParquetType(fields, PROTOCOL);
-            sidecarType = getParquetType(fields, SIDECAR);
+            metadataType = getParquetType(fields, METADATA, columns);
+            protocolType = getParquetType(fields, PROTOCOL, columns);
+            sidecarType = getParquetType(fields, SIDECAR, columns);
         }
         catch (Exception e) {
             try {
@@ -254,7 +243,7 @@ public CheckpointEntryIterator(
             }
             catch (Exception _) {
            }
-            throw new TrinoException(GENERIC_INTERNAL_ERROR, "Error while initilizing the checkpoint entry iterator for the file %s" .formatted(checkpoint.location()));
+            throw new TrinoException(GENERIC_INTERNAL_ERROR, "Error while initializing the checkpoint entry iterator for the file %s".formatted(checkpoint.location()), e);
         }
     }
@@ -267,49 +256,47 @@ private static Optional<RowType> getOptionalFieldType(RowType type, String fieldName)
                 .map(RowType.class::cast);
     }
 
-    private Optional<RowType> getAddParquetTypeContainingField(Set<EntryType> fields, String fieldName)
+    private static Optional<RowType> getAddParquetTypeContainingField(Set<EntryType> fields, String fieldName, List<HiveColumnHandle> columns)
     {
         return fields.contains(ADD) ?
-                this.pageSource.getColumnFields().stream()
-                        .filter(column -> column.name().equals(ADD.getColumnName()) &&
-                                column.field().getType() instanceof RowType rowType &&
+                columns.stream()
+                        .filter(column -> column.getName().equals(ADD.getColumnName()) &&
+                                column.getType() instanceof RowType rowType &&
                                 rowType.getFields().stream().map(RowType.Field::getName).filter(Optional::isPresent).flatMap(Optional::stream).anyMatch(fieldName::equals))
                         // The field even if it was requested might not exist in Parquet file
                         .collect(toOptional())
-                        .map(Column::field)
-                        .map(Field::getType)
+                        .map(HiveColumnHandle::getType)
                         .map(RowType.class::cast) :
                 Optional.empty();
     }
 
-    private Optional<RowType> getParquetType(Set<EntryType> fields, EntryType field)
+    private static Optional<RowType> getParquetType(Set<EntryType> fields, EntryType field, List<HiveColumnHandle> columns)
     {
-        return fields.contains(field) ? getParquetType(field.getColumnName()).map(RowType.class::cast) : Optional.empty();
+        return fields.contains(field) ? getParquetType(field.getColumnName(), columns).map(RowType.class::cast) : Optional.empty();
     }
 
-    private Optional<Type> getParquetType(String columnName)
+    private static Optional<Type> getParquetType(String columnName, List<HiveColumnHandle> columns)
     {
-        return pageSource.getColumnFields().stream()
-                .filter(column -> column.name().equals(columnName))
+        return columns.stream()
+                .filter(column -> column.getName().equals(columnName))
                 // The field even if it was requested may not exist in Parquet file
                 .collect(toOptional())
-                .map(Column::field)
-                .map(Field::getType);
+                .map(HiveColumnHandle::getType);
     }
 
     private CheckpointFieldExtractor createCheckpointFieldExtractor(EntryType entryType)
     {
         return switch (entryType) {
-            case TRANSACTION -> (session, pagePosition, blocks) -> buildTxnEntry(session, pagePosition, blocks[0]);
+            case TRANSACTION -> this::buildTxnEntry;
             case ADD -> new AddFileEntryExtractor();
-            case REMOVE -> (session, pagePosition, blocks) -> buildRemoveEntry(session, pagePosition, blocks[0]);
-            case METADATA -> (session, pagePosition, blocks) -> buildMetadataEntry(session, pagePosition, blocks[0]);
-            case PROTOCOL -> (session, pagePosition, blocks) -> buildProtocolEntry(session, pagePosition, blocks[0]);
-            case SIDECAR -> (session, pagePosition, blocks) -> buildSidecarEntry(session, pagePosition, blocks[0]);
+            case REMOVE -> this::buildRemoveEntry;
+            case METADATA -> this::buildMetadataEntry;
+            case PROTOCOL -> this::buildProtocolEntry;
+            case SIDECAR -> this::buildSidecarEntry;
         };
     }
 
-    private DeltaLakeColumnHandle buildColumnHandle(
+    private static DeltaLakeColumnHandle buildColumnHandle(
             EntryType entryType,
             CheckpointSchemaManager schemaManager,
             MetadataEntry metadataEntry,
@@ -318,7 +305,7 @@ private DeltaLakeColumnHandle buildColumnHandle(
     {
         Type type = switch (entryType) {
             case TRANSACTION -> schemaManager.getTxnEntryType();
-            case ADD -> schemaManager.getAddEntryType(metadataEntry, protocolEntry, addStatsMinMaxColumnFilter.orElseThrow(), true, true, false);
+            case ADD -> schemaManager.getAddEntryType(metadataEntry, protocolEntry, addStatsMinMaxColumnFilter.orElseThrow(), true, true, true);
             case REMOVE -> schemaManager.getRemoveEntryType();
             case METADATA -> schemaManager.getMetadataEntryType();
             case PROTOCOL -> schemaManager.getProtocolEntryType(true, true);
@@ -520,31 +507,25 @@ private class AddFileEntryExtractor
     {
         @Nullable
         @Override
-        public DeltaLakeTransactionLogEntry getEntry(ConnectorSession session, int pagePosition, Block... blocks)
+        public DeltaLakeTransactionLogEntry getEntry(ConnectorSession session, int pagePosition, Block addBlock)
         {
-            checkState(blocks.length == getRequiredChannels(), "Unexpected amount of blocks: %s", blocks.length);
-            Block addBlock = blocks[0];
-            Block addPartitionValuesBlock = blocks[1];
             log.debug("Building add entry from %s pagePosition %d", addBlock, pagePosition);
             if (addBlock.isNull(pagePosition)) {
                 return null;
             }
 
-            checkState(!addPartitionValuesBlock.isNull(pagePosition), "Inconsistent blocks provided while building the add file entry");
-            SqlRow addPartitionValuesRow = getRow(addPartitionValuesBlock, pagePosition);
-            CheckpointFieldReader addPartitionValuesReader = new CheckpointFieldReader(session, addPartitionValuesRow, addPartitionValuesType.orElseThrow());
-            Map<String, String> partitionValues = addPartitionValuesReader.getMap(stringMap, "partitionValues");
-            Map<String, Optional<String>> canonicalPartitionValues = canonicalizePartitionValues(partitionValues);
-            if (!partitionConstraint.isAll() && !partitionMatchesPredicate(canonicalPartitionValues, partitionConstraint.getDomains().orElseThrow())) {
-                return null;
-            }
-
             // Materialize from Parquet the information needed to build the AddEntry instance
             addBlock = addBlock.getLoadedBlock();
             SqlRow addEntryRow = getRow(addBlock, pagePosition);
             log.debug("Block %s has %s fields", addBlock, addEntryRow.getFieldCount());
             CheckpointFieldReader addReader = new CheckpointFieldReader(session, addEntryRow, addType.orElseThrow());
 
+            Map<String, String> partitionValues = addReader.getMap(stringMap, "partitionValues");
+            Map<String, Optional<String>> canonicalPartitionValues = canonicalizePartitionValues(partitionValues);
+            if (!partitionConstraint.isAll() && !partitionMatchesPredicate(canonicalPartitionValues, partitionConstraint.getDomains().orElseThrow())) {
+                return null;
+            }
+
             String path = addReader.getString("path");
             long size = addReader.getLong("size");
             long modificationTime = addReader.getLong("modificationTime");
@@ -579,15 +560,9 @@ public DeltaLakeTransactionLogEntry getEntry(ConnectorSession session, int pagePosition, Block... blocks)
             log.debug("Result: %s", result);
             return DeltaLakeTransactionLogEntry.addFileEntry(result);
         }
-
-        @Override
-        public int getRequiredChannels()
-        {
-            return 2;
-        }
     }
 
-    private DeletionVectorEntry parseDeletionVectorFromParquet(ConnectorSession session, SqlRow row, RowType type)
+    private static DeletionVectorEntry parseDeletionVectorFromParquet(ConnectorSession session, SqlRow row, RowType type)
     {
         checkArgument(row.getFieldCount() == 5, "Deletion vector entry must have 5 fields");
 
@@ -720,29 +695,35 @@ private static long getLongField(SqlRow row, int field)
     @Override
     protected DeltaLakeTransactionLogEntry computeNext()
     {
-        if (nextEntries.isEmpty()) {
-            fillNextEntries();
+        try {
+            if (nextEntries.isEmpty()) {
+                fillNextEntries();
+            }
+            if (!nextEntries.isEmpty()) {
+                return nextEntries.remove();
+            }
+            pageSource.close();
+            return endOfData();
         }
-        if (!nextEntries.isEmpty()) {
-            return nextEntries.remove();
+        catch (IOException e) {
+            throw new UncheckedIOException(e);
         }
-        pageSource.close();
-        return endOfData();
     }
 
     private boolean tryAdvancePage()
+            throws IOException
     {
         if (pageSource.isFinished()) {
            pageSource.close();
             return false;
         }
         boolean isFirstPage = page == null;
-        page = pageSource.getNextPage();
+        page = pageSource.getNextSourcePage();
         if (page == null) {
             return false;
         }
         if (isFirstPage) {
-            int requiredExtractorChannels = extractors.stream().mapToInt(CheckpointFieldExtractor::getRequiredChannels).sum();
+            int requiredExtractorChannels = extractors.size();
             if (page.getChannelCount() != requiredExtractorChannels) {
                 throw new TrinoException(DELTA_LAKE_INVALID_SCHEMA,
                         format("Expected page in %s to contain %d channels, but found %d",
@@ -755,10 +736,16 @@ private boolean tryAdvancePage()
 
     public void close()
     {
-        pageSource.close();
+        try {
+            pageSource.close();
+        }
+        catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
     }
 
     private void fillNextEntries()
+            throws IOException
     {
         while (nextEntries.isEmpty()) {
             // grab next page if needed
@@ -774,9 +761,7 @@ private void fillNextEntries()
                 DeltaLakeTransactionLogEntry entry;
                 if (extractor instanceof AddFileEntryExtractor) {
                     // Avoid unnecessary loading of the block in case there is a partition predicate mismatch for this add entry
-                    Block addBlock = page.getBlock(blockIndex);
-                    Block addPartitionValuesBlock = page.getBlock(blockIndex + 1);
-                    entry = extractor.getEntry(session, pagePosition, addBlock, addPartitionValuesBlock.getLoadedBlock());
+                    entry = extractor.getEntry(session, pagePosition, page.getBlock(blockIndex));
                 }
                 else {
                     entry = extractor.getEntry(session, pagePosition, page.getBlock(blockIndex).getLoadedBlock());
@@ -784,7 +769,7 @@ private void fillNextEntries()
                 if (entry != null) {
                     nextEntries.add(entry);
                 }
-                blockIndex += extractor.getRequiredChannels();
+                blockIndex++;
             }
             pagePosition++;
         }
@@ -812,12 +797,7 @@ private interface CheckpointFieldExtractor
          * checkpoint filter criteria.
          */
         @Nullable
-        DeltaLakeTransactionLogEntry getEntry(ConnectorSession session, int pagePosition, Block... blocks);
-
-        default int getRequiredChannels()
-        {
-            return 1;
-        }
+        DeltaLakeTransactionLogEntry getEntry(ConnectorSession session, int pagePosition, Block block);
     }
 
     private static SqlRow getRow(Block block, int position)
diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java
index b3800929087b..609aabbf7f6e 100644
--- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java
+++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeBasic.java
@@ -42,8 +42,8 @@
 import io.trino.plugin.deltalake.transactionlog.checkpoint.TransactionLogTail;
 import io.trino.plugin.deltalake.transactionlog.statistics.DeltaLakeFileStatistics;
 import io.trino.plugin.hive.parquet.TrinoParquetDataSource;
-import io.trino.spi.Page;
 import io.trino.spi.block.Block;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.RowType;
 import io.trino.spi.type.SqlDate;
 import io.trino.spi.type.SqlTimestamp;
@@ -393,7 +393,7 @@ private void testPartitionValuesParsedCheckpoint(ColumnMappingMode columnMappingMode)
         ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, Optional.empty());
         try (ParquetReader reader = createParquetReader(dataSource, parquetMetadata, ImmutableList.of(addEntryType), List.of("add"))) {
             List<Object> actual = new ArrayList<>();
-            Page page = reader.nextPage();
+            SourcePage page = reader.nextPage();
             while (page != null) {
                 Block block = page.getBlock(0);
                 for (int i = 0; i < block.getPositionCount(); i++) {
diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeNodeLocalDynamicSplitPruning.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeNodeLocalDynamicSplitPruning.java
index 23290483a402..70e0e0473eaa 100644
--- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeNodeLocalDynamicSplitPruning.java
+++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeNodeLocalDynamicSplitPruning.java
@@ -38,6 +38,7 @@
 import io.trino.spi.connector.ColumnHandle;
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.DynamicFilter;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.Range;
 import io.trino.spi.predicate.TupleDomain;
@@ -156,7 +157,7 @@ public void testDynamicSplitPruningOnUnpartitionedTable()
                         keyColumnHandle, Domain.singleValue(INTEGER, 1L)));
         try (ConnectorPageSource emptyPageSource = createTestingPageSource(transaction, deltaLakeConfig, split, tableHandle, ImmutableList.of(keyColumnHandle, dataColumnHandle), getDynamicFilter(splitPruningPredicate))) {
-            assertThat(emptyPageSource.getNextPage()).isNull();
+            assertThat(emptyPageSource.getNextSourcePage()).isNull();
         }
 
         TupleDomain<ColumnHandle> nonSelectivePredicate = TupleDomain.withColumnDomains(
@@ -164,7 +165,7 @@
                         keyColumnHandle, Domain.singleValue(INTEGER, (long) keyColumnValue)));
         try (ConnectorPageSource nonEmptyPageSource = createTestingPageSource(transaction, deltaLakeConfig, split, tableHandle, ImmutableList.of(keyColumnHandle, dataColumnHandle), getDynamicFilter(nonSelectivePredicate))) {
-            Page page = nonEmptyPageSource.getNextPage();
+            SourcePage page = nonEmptyPageSource.getNextSourcePage();
             assertThat(page).isNotNull();
             assertThat(page.getPositionCount()).isEqualTo(1);
             assertThat(INTEGER.getInt(page.getBlock(0), 0)).isEqualTo(keyColumnValue);
@@ -270,7 +271,7 @@ public void testDynamicSplitPruningWithExplicitPartitionFilter()
                 tableHandle,
                 ImmutableList.of(dateColumnHandle, receiptColumnHandle, amountColumnHandle),
                 getDynamicFilter(partitionPredicate))) {
-            assertThat(emptyPageSource.getNextPage()).isNull();
+            assertThat(emptyPageSource.getNextSourcePage()).isNull();
         }
     }
 
@@ -290,7 +291,7 @@
                 tableHandle,
                 ImmutableList.of(dateColumnHandle, receiptColumnHandle, amountColumnHandle),
                 getDynamicFilter(partitionPredicate))) {
-            Page page = nonEmptyPageSource.getNextPage();
+            SourcePage page = nonEmptyPageSource.getNextSourcePage();
             assertThat(page).isNotNull();
             assertThat(page.getPositionCount()).isEqualTo(1);
             assertThat(INTEGER.getInt(page.getBlock(0), 0)).isEqualTo(dateColumnValue);
diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSource.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSource.java
deleted file mode 100644
index 022467bd0bb2..000000000000
--- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakePageSource.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.deltalake;
-
-import io.trino.spi.connector.ConnectorPageSource;
-import org.junit.jupiter.api.Test;
-
-import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden;
-
-public class TestDeltaLakePageSource
-{
-    @Test
-    public void testEverythingImplemented()
-    {
-        assertAllMethodsOverridden(ConnectorPageSource.class, DeltaLakePageSource.class);
-    }
-}
diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java
index aa68fbbf2060..2869723877cc 100644
--- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java
+++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/CountQueryPageSource.java
@@ -14,8 +14,8 @@
 package io.trino.plugin.elasticsearch;
 
 import io.trino.plugin.elasticsearch.client.ElasticsearchClient;
-import io.trino.spi.Page;
 import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
 
 import static io.trino.plugin.elasticsearch.ElasticsearchQueryBuilder.buildSearchQuery;
 import static java.lang.Math.toIntExact;
@@ -60,12 +60,12 @@ public boolean isFinished()
     }
 
     @Override
-    public Page getNextPage()
+    public SourcePage getNextSourcePage()
     {
         int batch = toIntExact(Math.min(BATCH_SIZE, remaining));
         remaining -= batch;
 
-        return new Page(batch);
+        return SourcePage.create(batch);
     }
 
     @Override
diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/PassthroughQueryPageSource.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/PassthroughQueryPageSource.java
index 4a12f2b604c7..a4317a3b9fa7 100644
--- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/PassthroughQueryPageSource.java
+++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/PassthroughQueryPageSource.java
@@ -13,13 +13,12 @@
  */
 package io.trino.plugin.elasticsearch;
 
-import com.google.common.collect.ImmutableList;
+import io.airlift.slice.Slice;
 import io.airlift.slice.Slices;
 import io.trino.plugin.elasticsearch.client.ElasticsearchClient;
-import io.trino.spi.Page;
-import io.trino.spi.PageBuilder;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
 
 import java.io.IOException;
 
@@ -62,7 +61,7 @@ public boolean isFinished()
     }
 
     @Override
-    public Page getNextPage()
+    public SourcePage getNextSourcePage()
     {
         if (done) {
             return null;
@@ -70,11 +69,10 @@ public Page getNextPage()
 
         done = true;
 
-        PageBuilder page = new PageBuilder(1, ImmutableList.of(VARCHAR));
-        page.declarePosition();
-        BlockBuilder column = page.getBlockBuilder(0);
-        VARCHAR.writeSlice(column, Slices.utf8Slice(result));
-        return page.build();
+        Slice slice = Slices.utf8Slice(result);
+        BlockBuilder column = VARCHAR.createBlockBuilder(null, 0, slice.length());
+        VARCHAR.writeSlice(column, slice);
+        return SourcePage.create(column.build());
     }
 
     @Override
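Editor's note: the two Elasticsearch sources above show both SourcePage construction paths — SourcePage.create(int) emits positions with no columns at all (count-only scans), and SourcePage.create(Block) wraps a single column. A compact sketch of the one-column form, using only calls that appear in this patch:

    import io.airlift.slice.Slice;
    import io.airlift.slice.Slices;
    import io.trino.spi.block.BlockBuilder;
    import io.trino.spi.connector.SourcePage;

    import static io.trino.spi.type.VarcharType.VARCHAR;

    public final class SingleValuePageExample
    {
        private SingleValuePageExample() {}

        // Builds a one-row, one-column VARCHAR SourcePage from a string value
        public static SourcePage varcharPage(String value)
        {
            Slice slice = Slices.utf8Slice(value);
            BlockBuilder column = VARCHAR.createBlockBuilder(null, 1, slice.length());
            VARCHAR.writeSlice(column, slice);
            return SourcePage.create(column.build());
        }
    }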
diff --git a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java
index e32f722ee52e..cc0aa167f75e 100644
--- a/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java
+++ b/plugin/trino-elasticsearch/src/main/java/io/trino/plugin/elasticsearch/ScanQueryPageSource.java
@@ -23,6 +23,7 @@
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.block.PageBuilderStatus;
 import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.type.RowType;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeManager;
@@ -151,7 +152,7 @@ public void close()
     }
 
     @Override
-    public Page getNextPage()
+    public SourcePage getNextSourcePage()
     {
         long size = 0;
         while (size < PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES && iterator.hasNext()) {
@@ -187,7 +188,7 @@
             columnBuilders[i] = columnBuilders[i].newBlockBuilderLike(null);
         }
 
-        return new Page(blocks);
+        return SourcePage.create(new Page(blocks));
     }
 
     private static Map<String, Object> resolveField(Map<String, Object> document, ElasticsearchColumnHandle columnHandle)
diff --git a/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerMetadata.java b/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerMetadata.java
index 3d95e2395651..6e3149dac4a3 100644
--- a/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerMetadata.java
+++ b/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerMetadata.java
@@ -18,7 +18,6 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.errorprone.annotations.concurrent.GuardedBy;
 import io.airlift.slice.Slice;
-import io.trino.spi.Page;
 import io.trino.spi.TrinoException;
 import io.trino.spi.block.Block;
 import io.trino.spi.connector.ColumnHandle;
@@ -38,6 +37,7 @@
 import io.trino.spi.connector.SaveMode;
 import io.trino.spi.connector.SchemaTableName;
 import io.trino.spi.connector.SchemaTablePrefix;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.connector.ViewNotFoundException;
 import io.trino.spi.function.BoundSignature;
 import io.trino.spi.function.FunctionDependencyDeclaration;
@@ -600,9 +600,9 @@ private Map<String, List<Object>> getColumnValues(SchemaTableName tableName, Tab
         }
         ImmutableMap.Builder<String, List<Object>> columnValues = ImmutableMap.builder();
         try (FakerPageSource pageSource = new FakerPageSource(faker, random, dictionaryColumns, 0, MAX_DICTIONARY_SIZE * 2)) {
-            Page page = null;
+            SourcePage page = null;
             while (page == null) {
-                page = pageSource.getNextPage();
+                page = pageSource.getNextSourcePage();
             }
             Map<String, Type> types = columns.stream().collect(toImmutableMap(ColumnInfo::name, ColumnInfo::type));
             for (int channel = 0; channel < dictionaryColumns.size(); channel++) {
diff --git a/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerPageSource.java b/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerPageSource.java
index e1fd97128e08..d0d8d6ef7848 100644
--- a/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerPageSource.java
+++ b/plugin/trino-faker/src/main/java/io/trino/plugin/faker/FakerPageSource.java
@@ -26,6 +26,7 @@
 import io.trino.spi.block.RowBlock;
 import io.trino.spi.block.RowBlockBuilder;
 import io.trino.spi.connector.ConnectorPageSource;
+import io.trino.spi.connector.SourcePage;
 import io.trino.spi.predicate.Domain;
 import io.trino.spi.predicate.Range;
 import io.trino.spi.predicate.ValueSet;
@@ -209,7 +210,7 @@ public boolean isFinished()
     }
 
     @Override
-    public Page getNextPage()
+    public SourcePage getNextSourcePage()
     {
         if (!closed) {
             int positions = (int) Math.min(limit - completedRows, ROWS_PER_PAGE);
@@ -233,7 +234,7 @@
         if ((closed && !pageBuilder.isEmpty()) || pageBuilder.isFull()) {
             Page page = pageBuilder.build();
             pageBuilder.reset();
-            return page;
+            return SourcePage.create(page);
         }
 
         return null;
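Editor's note: FakerPageSource keeps its PageBuilder pipeline and only wraps the final Page — SourcePage.create(Page) is the bridge for sources that still assemble ordinary pages. A minimal sketch of that bridge (the class name is illustrative):

    import io.trino.spi.Page;
    import io.trino.spi.PageBuilder;
    import io.trino.spi.connector.SourcePage;

    import java.util.List;

    import static io.trino.spi.type.BigintType.BIGINT;

    public final class PageBuilderBridgeExample
    {
        private PageBuilderBridgeExample() {}

        public static SourcePage buildPage(long[] values)
        {
            PageBuilder pageBuilder = new PageBuilder(List.of(BIGINT));
            for (long value : values) {
                pageBuilder.declarePosition();
                BIGINT.writeLong(pageBuilder.getBlockBuilder(0), value);
            }
            Page page = pageBuilder.build(); // an ordinary, fully materialized page
            return SourcePage.create(page);  // wrapped for the new ConnectorPageSource API
        }
    }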
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/BucketAdapter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/BucketAdapter.java
new file mode 100644
index 000000000000..0dff33b9ba1b
--- /dev/null
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/BucketAdapter.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.hive;
+
+import io.trino.metastore.HiveType;
+import io.trino.metastore.type.TypeInfo;
+import io.trino.plugin.hive.util.HiveBucketing;
+import io.trino.spi.Page;
+import io.trino.spi.TrinoException;
+import io.trino.spi.connector.SourcePage;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import jakarta.annotation.Nullable;
+
+import java.util.List;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.plugin.hive.HiveErrorCode.HIVE_INVALID_BUCKET_FILES;
+import static io.trino.plugin.hive.util.HiveBucketing.getHiveBucket;
+import static java.lang.String.format;
+
+public class BucketAdapter
+{
+    private final int[] bucketColumns;
+    private final HiveBucketing.BucketingVersion bucketingVersion;
+    private final int bucketToKeep;
+    private final int tableBucketCount;
+    private final int partitionBucketCount; // for sanity check only
+    private final List<TypeInfo> typeInfoList;
+
+    public BucketAdapter(HivePageSourceProvider.BucketAdaptation bucketAdaptation)
+    {
+        this.bucketColumns = bucketAdaptation.getBucketColumnIndices();
+        this.bucketingVersion = bucketAdaptation.getBucketingVersion();
+        this.bucketToKeep = bucketAdaptation.getBucketToKeep();
+        this.typeInfoList = bucketAdaptation.getBucketColumnHiveTypes().stream()
+                .map(HiveType::getTypeInfo)
+                .collect(toImmutableList());
+        this.tableBucketCount = bucketAdaptation.getTableBucketCount();
+        this.partitionBucketCount = bucketAdaptation.getPartitionBucketCount();
+    }
+
+    @Nullable
+    public SourcePage filterPageToEligibleRowsOrDiscard(SourcePage page)
+    {
+        IntArrayList ids = new IntArrayList(page.getPositionCount());
+        Page bucketColumnsPage = page.getColumns(bucketColumns);
+        for (int position = 0; position < page.getPositionCount(); position++) {
+            int bucket = getHiveBucket(bucketingVersion, tableBucketCount, typeInfoList, bucketColumnsPage, position);
+            if ((bucket - bucketToKeep) % partitionBucketCount != 0) {
+                throw new TrinoException(HIVE_INVALID_BUCKET_FILES, format(
+                        "A row that is supposed to be in bucket %s is encountered. Only rows in bucket %s (modulo %s) are expected",
+                        bucket, bucketToKeep % partitionBucketCount, partitionBucketCount));
+            }
+            if (bucket == bucketToKeep) {
+                ids.add(position);
+            }
+        }
+        int retainedRowCount = ids.size();
+        if (retainedRowCount == 0) {
+            return null;
+        }
+        if (retainedRowCount == page.getPositionCount()) {
+            return page;
+        }
+        page.selectPositions(ids.elements(), 0, retainedRowCount);
+        return page;
+    }
+}
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/BucketValidator.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/BucketValidator.java
new file mode 100644
index 000000000000..5c74159b3fdb
--- /dev/null
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/BucketValidator.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.hive;
+
+import io.trino.filesystem.Location;
+import io.trino.metastore.type.TypeInfo;
+import io.trino.plugin.hive.util.HiveBucketing;
+import io.trino.spi.Page;
+import io.trino.spi.TrinoException;
+import io.trino.spi.connector.SourcePage;
+
+import java.util.List;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static io.trino.plugin.hive.HiveErrorCode.HIVE_INVALID_BUCKET_FILES;
+import static io.trino.plugin.hive.util.HiveBucketing.getHiveBucket;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+public class BucketValidator
+{
+    // validate every ~100 rows but using a prime number
+    public static final int VALIDATION_STRIDE = 97;
+
+    private final Location path;
+    private final int[] bucketColumnIndices;
+    private final List<TypeInfo> bucketColumnTypes;
+    private final HiveBucketing.BucketingVersion bucketingVersion;
+    private final int bucketCount;
+    private final int expectedBucket;
+
+    public BucketValidator(
+            Location path,
+            int[] bucketColumnIndices,
+            List<TypeInfo> bucketColumnTypes,
+            HiveBucketing.BucketingVersion bucketingVersion,
+            int bucketCount,
+            int expectedBucket)
+    {
+        this.path = requireNonNull(path, "path is null");
+        this.bucketColumnIndices = requireNonNull(bucketColumnIndices, "bucketColumnIndices is null");
+        this.bucketColumnTypes = requireNonNull(bucketColumnTypes, "bucketColumnTypes is null");
+        this.bucketingVersion = requireNonNull(bucketingVersion, "bucketingVersion is null");
+        this.bucketCount = bucketCount;
+        this.expectedBucket = expectedBucket;
+        checkArgument(bucketColumnIndices.length == bucketColumnTypes.size(), "indices and types counts mismatch");
+    }
+
+    public void validate(SourcePage page)
+    {
+        Page bucketColumnsPage = page.getColumns(bucketColumnIndices);
+        for (int position = 0; position < page.getPositionCount(); position += VALIDATION_STRIDE) {
+            int bucket = getHiveBucket(bucketingVersion, bucketCount, bucketColumnTypes, bucketColumnsPage, position);
+            if (bucket != expectedBucket) {
+                throw new TrinoException(
+                        HIVE_INVALID_BUCKET_FILES,
+                        format("Hive table is corrupt. File '%s' is for bucket %s, but contains a row for bucket %s.", path, expectedBucket, bucket));
+            }
+        }
+    }
+}
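Editor's note: both new classes above rely on SourcePage.getColumns(int[]) so that only the bucketing columns are materialized, and the validator samples on a prime stride rather than hashing every row. A hedged, self-contained sketch of that sampling loop, with a stand-in for getHiveBucket:

    import io.trino.spi.Page;
    import io.trino.spi.connector.SourcePage;

    public final class StrideValidationExample
    {
        private StrideValidationExample() {}

        // A prime stride avoids aligning the sample with periodic patterns in the data
        private static final int VALIDATION_STRIDE = 97;

        // Stand-in for HiveBucketing.getHiveBucket
        public interface BucketFunction
        {
            int bucket(Page bucketColumnsPage, int position);
        }

        public static void validate(SourcePage page, int[] bucketColumnIndices, BucketFunction bucketFunction, int expectedBucket)
        {
            // Materializes only the columns that participate in bucketing
            Page bucketColumnsPage = page.getColumns(bucketColumnIndices);
            for (int position = 0; position < page.getPositionCount(); position += VALIDATION_STRIDE) {
                int bucket = bucketFunction.bucket(bucketColumnsPage, position);
                if (bucket != expectedBucket) {
                    throw new IllegalStateException(
                            "row at position " + position + " belongs to bucket " + bucket + ", expected " + expectedBucket);
                }
            }
        }
    }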
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSource.java
deleted file mode 100644
index 99931440040d..000000000000
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSource.java
+++ /dev/null
@@ -1,375 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.trino.plugin.hive;
-
-import com.google.common.collect.ImmutableList;
-import io.trino.filesystem.Location;
-import io.trino.metastore.HiveType;
-import io.trino.metastore.type.TypeInfo;
-import io.trino.plugin.hive.HivePageSourceProvider.BucketAdaptation;
-import io.trino.plugin.hive.HivePageSourceProvider.ColumnMapping;
-import io.trino.plugin.hive.coercions.CoercionUtils.CoercionContext;
-import io.trino.plugin.hive.coercions.TypeCoercer;
-import io.trino.plugin.hive.util.HiveBucketing.BucketingVersion;
-import io.trino.spi.Page;
-import io.trino.spi.TrinoException;
-import io.trino.spi.block.Block;
-import io.trino.spi.block.LazyBlock;
-import io.trino.spi.block.LazyBlockLoader;
-import io.trino.spi.block.RunLengthEncodedBlock;
-import io.trino.spi.connector.ConnectorPageSource;
-import io.trino.spi.metrics.Metrics;
-import io.trino.spi.type.Type;
-import io.trino.spi.type.TypeManager;
-import it.unimi.dsi.fastutil.ints.IntArrayList;
-import jakarta.annotation.Nullable;
-
-import java.io.IOException;
-import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-import java.util.OptionalLong;
-import java.util.concurrent.CompletableFuture;
-import java.util.function.Function;
-
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkState;
-import static com.google.common.collect.ImmutableList.toImmutableList;
-import static io.trino.plugin.base.util.Closables.closeAllSuppress;
-import static io.trino.plugin.hive.HiveColumnHandle.isRowIdColumnHandle;
-import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR;
-import static io.trino.plugin.hive.HiveErrorCode.HIVE_INVALID_BUCKET_FILES;
-import static io.trino.plugin.hive.HivePageSourceProvider.ColumnMappingKind.EMPTY;
-import static io.trino.plugin.hive.HivePageSourceProvider.ColumnMappingKind.PREFILLED;
-import static io.trino.plugin.hive.coercions.CoercionUtils.createCoercer;
-import static io.trino.plugin.hive.util.HiveBucketing.getHiveBucket;
-import static io.trino.plugin.hive.util.HiveTypeUtil.getHiveTypeForDereferences;
-import static java.lang.String.format;
-import static java.util.Objects.requireNonNull;
-
-public class HivePageSource
-        implements ConnectorPageSource
-{
-    public static final int ORIGINAL_TRANSACTION_CHANNEL = 0;
-    public static final int BUCKET_CHANNEL = 1;
-    public static final int ROW_ID_CHANNEL = 2;
-
-    private final List<ColumnMapping> columnMappings;
-    private final Optional<BucketAdapter> bucketAdapter;
-    private final Optional<BucketValidator> bucketValidator;
-    private final Object[] prefilledValues;
-    private final Type[] types;
-    private final List<Optional<TypeCoercer<? extends Type, ? extends Type>>> coercers;
-    private final Optional<ReaderProjectionsAdapter> projectionsAdapter;
-
-    private final ConnectorPageSource delegate;
-
-    public HivePageSource(
-            List<ColumnMapping> columnMappings,
-            Optional<BucketAdaptation> bucketAdaptation,
-            Optional<BucketValidator> bucketValidator,
-            Optional<ReaderProjectionsAdapter> projectionsAdapter,
-            TypeManager typeManager,
-            CoercionContext coercionContext,
-            ConnectorPageSource delegate)
-    {
-        requireNonNull(columnMappings, "columnMappings is null");
-        requireNonNull(typeManager, "typeManager is null");
-        requireNonNull(coercionContext, "coercionContext is null");
-
-        this.delegate = requireNonNull(delegate, "delegate is null");
-        this.columnMappings = columnMappings;
-        this.bucketAdapter = bucketAdaptation.map(BucketAdapter::new);
-        this.bucketValidator = requireNonNull(bucketValidator, "bucketValidator is null");
-
-        this.projectionsAdapter = requireNonNull(projectionsAdapter, "projectionsAdapter is null");
-
-        int size = columnMappings.size();
-
-        prefilledValues = new Object[size];
-        types = new Type[size];
-        ImmutableList.Builder<Optional<TypeCoercer<? extends Type, ? extends Type>>> coercers = ImmutableList.builder();
-
-        for (int columnIndex = 0; columnIndex < size; columnIndex++) {
-            ColumnMapping columnMapping = columnMappings.get(columnIndex);
-            HiveColumnHandle column = columnMapping.getHiveColumnHandle();
-
-            Type type = column.getType();
-            types[columnIndex] = type;
-
-            if (columnMapping.getKind() != EMPTY && columnMapping.getBaseTypeCoercionFrom().isPresent()) {
-                List<Integer> dereferenceIndices = column.getHiveColumnProjectionInfo()
-                        .map(HiveColumnProjectionInfo::getDereferenceIndices)
-                        .orElse(ImmutableList.of());
-                HiveType fromType = getHiveTypeForDereferences(columnMapping.getBaseTypeCoercionFrom().get(), dereferenceIndices).get();
-                HiveType toType = columnMapping.getHiveColumnHandle().getHiveType();
-                coercers.add(createCoercer(typeManager, fromType, toType, coercionContext));
-            }
-            else {
-                coercers.add(Optional.empty());
-            }
-
-            if (columnMapping.getKind() == EMPTY || isRowIdColumnHandle(column)) {
-                prefilledValues[columnIndex] = null;
-            }
-            else if (columnMapping.getKind() == PREFILLED) {
-                prefilledValues[columnIndex] = columnMapping.getPrefilledValue().getValue();
-            }
-        }
-        this.coercers = coercers.build();
-    }
-
-    @Override
-    public long getCompletedBytes()
-    {
-        return delegate.getCompletedBytes();
-    }
-
-    @Override
-    public OptionalLong getCompletedPositions()
-    {
-        return delegate.getCompletedPositions();
-    }
-
-    @Override
-    public long getReadTimeNanos()
-    {
-        return delegate.getReadTimeNanos();
-    }
-
-    @Override
-    public boolean isFinished()
-    {
-        return delegate.isFinished();
-    }
-
-    @Override
-    public CompletableFuture<?> isBlocked()
-    {
-        return delegate.isBlocked();
-    }
-
-    @Override
-    public Page getNextPage()
-    {
-        try {
-            Page dataPage = delegate.getNextPage();
-            if (dataPage == null) {
-                return null;
-            }
-
-            if (projectionsAdapter.isPresent()) {
-                dataPage = projectionsAdapter.get().adaptPage(dataPage);
-            }
-
-            if (bucketAdapter.isPresent()) {
-                dataPage = bucketAdapter.get().filterPageToEligibleRowsOrDiscard(dataPage);
-                if (dataPage == null) {
-                    return null;
-                }
-            }
-            else {
-                // bucket adaptation already validates that data is in the right bucket
-                final Page dataPageRef = dataPage;
-                bucketValidator.ifPresent(validator -> validator.validate(dataPageRef));
-            }
-
-            int batchSize = dataPage.getPositionCount();
-            List<Block> blocks = new ArrayList<>();
-            for (int fieldId = 0; fieldId < columnMappings.size(); fieldId++) {
-                ColumnMapping columnMapping = columnMappings.get(fieldId);
-                switch (columnMapping.getKind()) {
-                    case PREFILLED:
-                    case EMPTY:
-                        blocks.add(RunLengthEncodedBlock.create(types[fieldId], prefilledValues[fieldId], batchSize));
-                        break;
-                    case REGULAR:
-                    case SYNTHESIZED:
-                        Block block = dataPage.getBlock(columnMapping.getIndex());
-                        Optional<TypeCoercer<? extends Type, ? extends Type>> coercer = coercers.get(fieldId);
-                        if (coercer.isPresent()) {
-                            block = new LazyBlock(batchSize, new CoercionLazyBlockLoader(block, coercer.get()));
-                        }
-                        blocks.add(block);
-                        break;
-                    case INTERIM:
-                        // interim columns don't show up in output
-                        break;
-                    default:
-                        throw new UnsupportedOperationException();
-                }
-            }
-
-            return new Page(batchSize, blocks.toArray(new Block[0]));
-        }
-        catch (TrinoException e) {
-            closeAllSuppress(e, this);
-            throw e;
-        }
-        catch (RuntimeException e) {
-            closeAllSuppress(e, this);
-            throw new TrinoException(HIVE_CURSOR_ERROR, e);
-        }
-    }
-
-    @Override
-    public void close()
-    {
-        try {
-            delegate.close();
-        }
-        catch (IOException e) {
-            throw new UncheckedIOException(e);
-        }
-    }
-
-    @Override
-    public String toString()
-    {
-        return delegate.toString();
-    }
-
-    @Override
-    public long getMemoryUsage()
-    {
-        return delegate.getMemoryUsage();
-    }
-
-    @Override
-    public Metrics getMetrics()
-    {
-        return delegate.getMetrics();
-    }
-
-    public ConnectorPageSource getPageSource()
-    {
-        return delegate;
-    }
-
-    private static final class CoercionLazyBlockLoader
-            implements LazyBlockLoader
-    {
-        private final Function<Block, Block> coercer;
-        private Block block;
-
-        public CoercionLazyBlockLoader(Block block, Function<Block, Block> coercer)
-        {
-            this.block = requireNonNull(block, "block is null");
-            this.coercer = requireNonNull(coercer, "coercer is null");
-        }
-
-        @Override
-        public Block load()
-        {
-            checkState(block != null, "Already loaded");
-
-            Block loaded = coercer.apply(block.getLoadedBlock());
-            // clear reference to loader to free resources, since load was successful
-            block = null;
-
-            return loaded;
-        }
-    }
-
-    public static class BucketAdapter
-    {
-        private final int[] bucketColumns;
-        private final BucketingVersion bucketingVersion;
-        private final int bucketToKeep;
-        private final int tableBucketCount;
-        private final int partitionBucketCount; // for sanity check only
-        private final List<TypeInfo> typeInfoList;
-
-        public BucketAdapter(BucketAdaptation bucketAdaptation)
-        {
-            this.bucketColumns = bucketAdaptation.getBucketColumnIndices();
-            this.bucketingVersion = bucketAdaptation.getBucketingVersion();
-            this.bucketToKeep = bucketAdaptation.getBucketToKeep();
-            this.typeInfoList = bucketAdaptation.getBucketColumnHiveTypes().stream()
-                    .map(HiveType::getTypeInfo)
-                    .collect(toImmutableList());
-            this.tableBucketCount = bucketAdaptation.getTableBucketCount();
-            this.partitionBucketCount = bucketAdaptation.getPartitionBucketCount();
-        }
-
-        @Nullable
-        public Page filterPageToEligibleRowsOrDiscard(Page page)
-        {
-            IntArrayList ids = new IntArrayList(page.getPositionCount());
-            Page bucketColumnsPage = page.getColumns(bucketColumns);
-            for (int position = 0; position < page.getPositionCount(); position++) {
-                int bucket = getHiveBucket(bucketingVersion, tableBucketCount, typeInfoList, bucketColumnsPage, position);
-                if ((bucket - bucketToKeep) % partitionBucketCount != 0) {
-                    throw new TrinoException(HIVE_INVALID_BUCKET_FILES, format(
-                            "A row that is supposed to be in bucket %s is encountered. Only rows in bucket %s (modulo %s) are expected",
-                            bucket, bucketToKeep % partitionBucketCount, partitionBucketCount));
-                }
-                if (bucket == bucketToKeep) {
-                    ids.add(position);
-                }
-            }
-            int retainedRowCount = ids.size();
-            if (retainedRowCount == 0) {
-                return null;
-            }
-            if (retainedRowCount == page.getPositionCount()) {
-                return page;
-            }
-            return page.getPositions(ids.elements(), 0, retainedRowCount);
-        }
-    }
-
-    public static class BucketValidator
-    {
-        // validate every ~100 rows but using a prime number
-        public static final int VALIDATION_STRIDE = 97;
-
-        private final Location path;
-        private final int[] bucketColumnIndices;
-        private final List<TypeInfo> bucketColumnTypes;
-        private final BucketingVersion bucketingVersion;
-        private final int bucketCount;
-        private final int expectedBucket;
-
-        public BucketValidator(
-                Location path,
-                int[] bucketColumnIndices,
-                List<TypeInfo> bucketColumnTypes,
-                BucketingVersion bucketingVersion,
-                int bucketCount,
-                int expectedBucket)
-        {
-            this.path = requireNonNull(path, "path is null");
-            this.bucketColumnIndices = requireNonNull(bucketColumnIndices, "bucketColumnIndices is null");
-            this.bucketColumnTypes = requireNonNull(bucketColumnTypes, "bucketColumnTypes is null");
-            this.bucketingVersion = requireNonNull(bucketingVersion, "bucketingVersion is null");
-            this.bucketCount = bucketCount;
-            this.expectedBucket = expectedBucket;
-            checkArgument(bucketColumnIndices.length == bucketColumnTypes.size(), "indices and types counts mismatch");
-        }
-
-        public void validate(Page page)
-        {
-            Page bucketColumnsPage = page.getColumns(bucketColumnIndices);
-            for (int position = 0; position < page.getPositionCount(); position += VALIDATION_STRIDE) {
-                int bucket = getHiveBucket(bucketingVersion, bucketCount, bucketColumnTypes, bucketColumnsPage, position);
-                if (bucket != expectedBucket) {
-                    throw new TrinoException(HIVE_INVALID_BUCKET_FILES,
-                            format("Hive table is corrupt. 
File '%s' is for bucket %s, but contains a row for bucket %s.", path, expectedBucket, bucket)); - } - } - } - } -} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceFactory.java index 43979da2c704..0756279a362a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceFactory.java @@ -15,6 +15,7 @@ import io.trino.filesystem.Location; import io.trino.plugin.hive.acid.AcidTransaction; +import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.predicate.TupleDomain; @@ -24,7 +25,7 @@ public interface HivePageSourceFactory { - Optional createPageSource( + Optional createPageSource( ConnectorSession session, Location path, long start, diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java index 1ce82c26c55e..0fa66f92e875 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePageSourceProvider.java @@ -13,8 +13,7 @@ */ package io.trino.plugin.hive; -import com.google.common.collect.BiMap; -import com.google.common.collect.ImmutableBiMap; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; @@ -22,11 +21,11 @@ import io.trino.metastore.HiveType; import io.trino.metastore.HiveTypeName; import io.trino.metastore.type.TypeInfo; -import io.trino.plugin.hive.HivePageSource.BucketValidator; import io.trino.plugin.hive.HiveSplit.BucketConversion; import io.trino.plugin.hive.HiveSplit.BucketValidation; import io.trino.plugin.hive.acid.AcidTransaction; import io.trino.plugin.hive.coercions.CoercionUtils.CoercionContext; +import io.trino.plugin.hive.coercions.TypeCoercer; import io.trino.plugin.hive.util.HiveBucketing.BucketingVersion; import io.trino.spi.TrinoException; import io.trino.spi.connector.ColumnHandle; @@ -41,6 +40,8 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.Utils; +import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import java.util.ArrayList; @@ -48,10 +49,10 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; +import java.util.function.Function; import java.util.regex.Pattern; import static com.google.common.base.Preconditions.checkArgument; @@ -67,6 +68,7 @@ import static io.trino.plugin.hive.HivePageSourceProvider.ColumnMapping.toColumnHandles; import static io.trino.plugin.hive.HivePageSourceProvider.ColumnMappingKind.PREFILLED; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; +import static io.trino.plugin.hive.coercions.CoercionUtils.createCoercer; import static io.trino.plugin.hive.coercions.CoercionUtils.createTypeFromCoercer; import static io.trino.plugin.hive.coercions.CoercionUtils.extractHiveStorageFormat; import static io.trino.plugin.hive.util.HiveBucketing.HiveBucketFilter; @@ -81,6 +83,9 @@ public class HivePageSourceProvider implements 
ConnectorPageSourceProvider { + public static final int ORIGINAL_TRANSACTION_CHANNEL = 0; + public static final int BUCKET_CHANNEL = 1; + public static final int ROW_ID_CHANNEL = 2; // The original file path looks like this: /root/dir/nnnnnnn_m(_copy_ccc)? private static final Pattern ORIGINAL_FILE_PATH_MATCHER = Pattern.compile("(?s)(?.*)/(?(?\\d+)_(?.*)?)$"); @@ -200,7 +205,7 @@ public static Optional createHivePageSource( for (HivePageSourceFactory pageSourceFactory : pageSourceFactories) { List desiredColumns = toColumnHandles(regularAndInterimColumnMappings, typeManager, coercionContext); - Optional readerWithProjections = pageSourceFactory.createPageSource( + Optional pageSource = pageSourceFactory.createPageSource( session, path, start, @@ -215,29 +220,75 @@ public static Optional createHivePageSource( originalFile, transaction); - if (readerWithProjections.isPresent()) { - ConnectorPageSource pageSource = readerWithProjections.get().get(); - - Optional readerProjections = readerWithProjections.get().getReaderColumns(); - Optional adapter = Optional.empty(); - if (readerProjections.isPresent()) { - adapter = Optional.of(hiveProjectionsAdapter(desiredColumns, readerProjections.get())); - } - - return Optional.of(new HivePageSource( - columnMappings, + if (pageSource.isPresent()) { + return Optional.of(createHivePageSource(columnMappings, bucketAdaptation, bucketValidator, - adapter, typeManager, coercionContext, - pageSource)); + pageSource.get())); } } return Optional.empty(); } + @VisibleForTesting + static ConnectorPageSource createHivePageSource( + List columnMappings, + Optional bucketAdaptation, + Optional bucketValidator, + TypeManager typeManager, + CoercionContext coercionContext, + ConnectorPageSource pageSource) + { + if (bucketAdaptation.isPresent()) { + BucketAdapter bucketAdapter = new BucketAdapter(bucketAdaptation.get()); + pageSource = TransformConnectorPageSource.create(pageSource, bucketAdapter::filterPageToEligibleRowsOrDiscard); + } + else if (bucketValidator.isPresent()) { + BucketValidator validator = bucketValidator.get(); + pageSource = TransformConnectorPageSource.create(pageSource, page -> { + validator.validate(page); + return page; + }); + } + + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + for (ColumnMapping columnMapping : columnMappings) { + HiveColumnHandle column = columnMapping.getHiveColumnHandle(); + + Type type = column.getType(); + switch (columnMapping.getKind()) { + case PREFILLED -> transforms.constantValue(Utils.nativeValueToBlock(type, columnMapping.getPrefilledValue().getValue())); + case EMPTY -> transforms.constantValue(type.createNullBlock()); + case REGULAR, SYNTHESIZED -> { + Optional> coercer = Optional.empty(); + if (columnMapping.getBaseTypeCoercionFrom().isPresent()) { + List dereferenceIndices = column.getHiveColumnProjectionInfo() + .map(HiveColumnProjectionInfo::getDereferenceIndices) + .orElse(ImmutableList.of()); + HiveType fromType = getHiveTypeForDereferences(columnMapping.getBaseTypeCoercionFrom().get(), dereferenceIndices).orElseThrow(); + HiveType toType = columnMapping.getHiveColumnHandle().getHiveType(); + coercer = createCoercer(typeManager, fromType, toType, coercionContext); + } + + int inputChannel = columnMapping.getIndex(); + if (coercer.isPresent()) { + transforms.transform(inputChannel, coercer.get()); + } + else { + transforms.column(inputChannel); + } + } + case INTERIM -> { + // interim columns don't show up in output + } + } + } + return 
transforms.build(pageSource); + } + private static boolean shouldSkipBucket(HiveTableHandle hiveTable, HiveSplit hiveSplit, DynamicFilter dynamicFilter) { if (hiveSplit.getTableBucketNumber().isEmpty()) { @@ -267,13 +318,38 @@ private static boolean shouldSkipSplit(List columnMappings, Dynam return false; } - private static ReaderProjectionsAdapter hiveProjectionsAdapter(List expectedColumns, ReaderColumns readColumns) + /** + * Create a page source using base columns and project the final shape from the base columns. + * This utility is used for page sources that do not handle column dereferences directly. + */ + public static ConnectorPageSource projectColumnDereferences(List columns, Function, ConnectorPageSource> pageSourceFactory) { - return new ReaderProjectionsAdapter( - expectedColumns.stream().map(ColumnHandle.class::cast).collect(toImmutableList()), - readColumns, - column -> ((HiveColumnHandle) column).getType(), - HivePageSourceProvider::getProjection); + // determine base columns and create transform to project the final shape from the base columns + List baseColumns = new ArrayList<>(); + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + Map baseColumnOrdinalByColumnIndex = new HashMap<>(); + for (HiveColumnHandle column : columns) { + HiveColumnHandle baseColumn = column.getBaseColumn(); + Integer ordinal = baseColumnOrdinalByColumnIndex.get(baseColumn.getBaseHiveColumnIndex()); + if (ordinal == null) { + ordinal = baseColumns.size(); + baseColumnOrdinalByColumnIndex.put(baseColumn.getBaseHiveColumnIndex(), ordinal); + baseColumns.add(baseColumn); + } + + if (column.isBaseColumn()) { + transforms.column(ordinal); + } + else { + transforms.dereferenceField(ImmutableList.builder() + .add(ordinal) + .addAll(getProjection(column, baseColumn)) + .build()); + } + } + + ConnectorPageSource connectorPageSource = pageSourceFactory.apply(baseColumns); + return transforms.build(connectorPageSource); } public static List getProjection(ColumnHandle expected, ColumnHandle read) @@ -633,158 +709,4 @@ static Optional createBucketValidator(Location path, Optional projectBaseColumns(List columns) - { - return projectBaseColumns(columns, false); - } - - /** - * Creates a mapping between the input {@code columns} and base columns based on baseHiveColumnIndex or baseColumnName if required. - */ - public static Optional projectBaseColumns(List columns, boolean useColumnNames) - { - requireNonNull(columns, "columns is null"); - - // No projection is required if all columns are base columns - if (columns.stream().allMatch(HiveColumnHandle::isBaseColumn)) { - return Optional.empty(); - } - - ImmutableList.Builder projectedColumns = ImmutableList.builder(); - ImmutableList.Builder outputColumnMapping = ImmutableList.builder(); - Map mappedHiveBaseColumnKeys = new HashMap<>(); - int projectedColumnCount = 0; - - for (HiveColumnHandle column : columns) { - Object baseColumnKey = useColumnNames ? 
column.getBaseColumnName() : column.getBaseHiveColumnIndex(); - Integer mapped = mappedHiveBaseColumnKeys.get(baseColumnKey); - - if (mapped == null) { - projectedColumns.add(column.getBaseColumn()); - mappedHiveBaseColumnKeys.put(baseColumnKey, projectedColumnCount); - outputColumnMapping.add(projectedColumnCount); - projectedColumnCount++; - } - else { - outputColumnMapping.add(mapped); - } - } - - return Optional.of(new ReaderColumns(projectedColumns.build(), outputColumnMapping.build())); - } - - /** - * Creates a set of sufficient columns for the input projected columns and prepares a mapping between the two. For example, - * if input columns include columns "a.b" and "a.b.c", then they will be projected from a single column "a.b". - */ - public static Optional projectSufficientColumns(List columns) - { - requireNonNull(columns, "columns is null"); - - if (columns.stream().allMatch(HiveColumnHandle::isBaseColumn)) { - return Optional.empty(); - } - - ImmutableBiMap.Builder dereferenceChainsBuilder = ImmutableBiMap.builder(); - - for (HiveColumnHandle column : columns) { - List indices = column.getHiveColumnProjectionInfo() - .map(HiveColumnProjectionInfo::getDereferenceIndices) - .orElse(ImmutableList.of()); - - DereferenceChain dereferenceChain = new DereferenceChain(column.getBaseColumnName(), indices); - dereferenceChainsBuilder.put(dereferenceChain, column); - } - - BiMap dereferenceChains = dereferenceChainsBuilder.build(); - - List sufficientColumns = new ArrayList<>(); - ImmutableList.Builder outputColumnMapping = ImmutableList.builder(); - - Map pickedColumns = new HashMap<>(); - - // Pick a covering column for every column - for (HiveColumnHandle columnHandle : columns) { - DereferenceChain column = dereferenceChains.inverse().get(columnHandle); - List orderedPrefixes = column.getOrderedPrefixes(); - DereferenceChain chosenColumn = null; - - // Shortest existing prefix is chosen as the input. 
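The deleted loop that follows applies exactly this rule over DereferenceChain prefixes: the shortest prefix that is itself a requested column is sufficient to answer the longer chain, since the remaining fields can be dereferenced from it. A compact standalone illustration with plain collections in place of the deleted classes (names are hypothetical):

import java.util.List;
import java.util.Set;

final class CoveringPrefixSketch
{
    private CoveringPrefixSketch() {}

    // For a requested dereference chain such as a.b.c, try its prefixes in
    // increasing length (a, a.b, a.b.c) and return the first one that is
    // itself a requested column.
    static List<Integer> chooseCoveringPrefix(List<Integer> chain, Set<List<Integer>> requestedChains)
    {
        for (int length = 0; length <= chain.size(); length++) {
            List<Integer> prefix = chain.subList(0, length);
            if (requestedChains.contains(prefix)) {
                return prefix;
            }
        }
        // every chain is in the requested set, so it always covers itself
        throw new IllegalStateException("no covering prefix for " + chain);
    }
}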
- for (DereferenceChain prefix : orderedPrefixes) { - if (dereferenceChains.containsKey(prefix)) { - chosenColumn = prefix; - break; - } - } - - checkState(chosenColumn != null, "chosenColumn is null"); - int inputBlockIndex; - - if (pickedColumns.containsKey(chosenColumn)) { - // Use already picked column - inputBlockIndex = pickedColumns.get(chosenColumn); - } - else { - // Add a new column for the reader - sufficientColumns.add(dereferenceChains.get(chosenColumn)); - pickedColumns.put(chosenColumn, sufficientColumns.size() - 1); - inputBlockIndex = sufficientColumns.size() - 1; - } - - outputColumnMapping.add(inputBlockIndex); - } - - return Optional.of(new ReaderColumns(sufficientColumns, outputColumnMapping.build())); - } - - private static class DereferenceChain - { - private final String name; - private final List indices; - - public DereferenceChain(String name, List indices) - { - this.name = requireNonNull(name, "name is null"); - this.indices = ImmutableList.copyOf(requireNonNull(indices, "indices is null")); - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - DereferenceChain that = (DereferenceChain) o; - return Objects.equals(name, that.name) && - Objects.equals(indices, that.indices); - } - - @Override - public int hashCode() - { - return Objects.hash(name, indices); - } - - /** - * Get Prefixes of this Dereference chain in increasing order of lengths - */ - public List getOrderedPrefixes() - { - ImmutableList.Builder prefixes = ImmutableList.builder(); - - for (int prefixLen = 0; prefixLen <= indices.size(); prefixLen++) { - prefixes.add(new DereferenceChain(name, indices.subList(0, prefixLen))); - } - - return prefixes.build(); - } - } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateBucketFunction.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateBucketFunction.java index 45f916efcd03..1697bb8d317c 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateBucketFunction.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveUpdateBucketFunction.java @@ -19,7 +19,7 @@ import io.trino.spi.block.SqlRow; import io.trino.spi.connector.BucketFunction; -import static io.trino.plugin.hive.HivePageSource.BUCKET_CHANNEL; +import static io.trino.plugin.hive.HivePageSourceProvider.BUCKET_CHANNEL; import static io.trino.spi.type.IntegerType.INTEGER; public class HiveUpdateBucketFunction diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java index 855b972e66fb..edd9eef8ab8b 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/MergeFileWriter.java @@ -42,9 +42,9 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.orc.OrcWriter.OrcOperation.DELETE; import static io.trino.orc.OrcWriter.OrcOperation.INSERT; -import static io.trino.plugin.hive.HivePageSource.BUCKET_CHANNEL; -import static io.trino.plugin.hive.HivePageSource.ORIGINAL_TRANSACTION_CHANNEL; -import static io.trino.plugin.hive.HivePageSource.ROW_ID_CHANNEL; +import static io.trino.plugin.hive.HivePageSourceProvider.BUCKET_CHANNEL; +import static io.trino.plugin.hive.HivePageSourceProvider.ORIGINAL_TRANSACTION_CHANNEL; +import static io.trino.plugin.hive.HivePageSourceProvider.ROW_ID_CHANNEL; 
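These three channel constants moved from HivePageSource to HivePageSourceProvider, but the contract they encode is unchanged: the ACID row-id layout is positional, so consumers such as MergeFileWriter fetch the id blocks by channel index rather than by name. A hedged sketch of such a consumer (the helper class is hypothetical):

import io.trino.spi.Page;
import io.trino.spi.block.Block;

import static io.trino.plugin.hive.HivePageSourceProvider.BUCKET_CHANNEL;
import static io.trino.plugin.hive.HivePageSourceProvider.ORIGINAL_TRANSACTION_CHANNEL;
import static io.trino.plugin.hive.HivePageSourceProvider.ROW_ID_CHANNEL;

final class AcidRowIdSketch
{
    private AcidRowIdSketch() {}

    static Block originalTransaction(Page rowIdPage)
    {
        return rowIdPage.getBlock(ORIGINAL_TRANSACTION_CHANNEL); // channel 0
    }

    static Block bucket(Page rowIdPage)
    {
        return rowIdPage.getBlock(BUCKET_CHANNEL); // channel 1
    }

    static Block rowId(Page rowIdPage)
    {
        return rowIdPage.getBlock(ROW_ID_CHANNEL); // channel 2
    }
}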
import static io.trino.plugin.hive.HiveStorageFormat.ORC; import static io.trino.plugin.hive.acid.AcidSchema.ACID_COLUMN_NAMES; import static io.trino.plugin.hive.acid.AcidSchema.createAcidSchema; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderColumns.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderColumns.java deleted file mode 100644 index 1c8b12d8a464..000000000000 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderColumns.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.hive; - -import com.google.common.collect.ImmutableList; -import io.trino.spi.connector.ColumnHandle; - -import java.util.List; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -/** - * Stores a mapping between - * - the projected columns required by a connector level pagesource and - * - the columns supplied by format-specific page source - *
- * <p>
- * Currently used in {@link HivePageSource}, {@code io.trino.plugin.iceberg.IcebergPageSource}, - * and {@code io.trino.plugin.deltalake.DeltaLakePageSource}. - */ -public class ReaderColumns -{ - // columns to be read by the reader (ordered) - private final List readerColumns; - // indices for mapping expected column handles to the reader's column handles - private final List readerBlockIndices; - - public ReaderColumns(List readerColumns, List readerBlockIndices) - { - this.readerColumns = ImmutableList.copyOf(requireNonNull(readerColumns, "readerColumns is null")); - - readerBlockIndices.forEach(value -> checkArgument(value >= 0 && value < readerColumns.size(), "block index out of bounds")); - this.readerBlockIndices = ImmutableList.copyOf(requireNonNull(readerBlockIndices, "readerBlockIndices is null")); - } - - /** - * For a column required by the wrapper page source, returns the column read by the delegate page source or record cursor. - */ - public ColumnHandle getForColumnAt(int index) - { - checkArgument(index >= 0 && index < readerBlockIndices.size(), "index is not valid"); - int readerIndex = readerBlockIndices.get(index); - return readerColumns.get(readerIndex); - } - - /** - * For a channel expected by wrapper page source, returns the channel index in the underlying page source or record cursor. - */ - public int getPositionForColumnAt(int index) - { - checkArgument(index >= 0 && index < readerBlockIndices.size(), "index is invalid"); - return readerBlockIndices.get(index); - } - - /** - * returns the actual list of columns being read by underlying page source or record cursor in order. - */ - public List get() - { - return readerColumns; - } -} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderPageSource.java deleted file mode 100644 index ca87810e4fc7..000000000000 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderPageSource.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.hive; - -import io.trino.spi.connector.ConnectorPageSource; - -import java.util.Optional; - -import static java.util.Objects.requireNonNull; - -/** - * A wrapper class for - *
- * <ul>
- * <li>delegate reader page source and
- * <li>columns to be read by the delegate (present only if different from the columns desired by the connector pagesource)
- * </ul>
- */ -public class ReaderPageSource -{ - private final ConnectorPageSource connectorPageSource; - private final Optional columns; - - public ReaderPageSource(ConnectorPageSource connectorPageSource, Optional columns) - { - this.connectorPageSource = requireNonNull(connectorPageSource, "connectorPageSource is null"); - this.columns = requireNonNull(columns, "columns is null"); - } - - public ConnectorPageSource get() - { - return connectorPageSource; - } - - public Optional getReaderColumns() - { - return columns; - } - - public static ReaderPageSource noProjectionAdaptation(ConnectorPageSource connectorPageSource) - { - return new ReaderPageSource(connectorPageSource, Optional.empty()); - } -} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderProjectionsAdapter.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderProjectionsAdapter.java deleted file mode 100644 index fa0f9ce7c84d..000000000000 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/ReaderProjectionsAdapter.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.hive; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import io.trino.spi.Page; -import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; -import io.trino.spi.block.LazyBlockLoader; -import io.trino.spi.connector.ColumnHandle; -import io.trino.spi.type.Type; -import jakarta.annotation.Nullable; - -import java.util.List; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.spi.block.RowBlock.getRowFieldsFromBlock; -import static java.util.Objects.requireNonNull; - -public class ReaderProjectionsAdapter -{ - private final List outputToInputMapping; - private final List outputTypes; - private final List inputTypes; - - public ReaderProjectionsAdapter( - List expectedColumns, - ReaderColumns readColumns, - ColumnTypeGetter typeGetter, - ProjectionGetter projectionGetter) - { - requireNonNull(expectedColumns, "expectedColumns is null"); - requireNonNull(readColumns, "readColumns is null"); - - ImmutableList.Builder mappingBuilder = ImmutableList.builder(); - - for (int i = 0; i < expectedColumns.size(); i++) { - ColumnHandle projectedColumnHandle = readColumns.getForColumnAt(i); - int inputChannel = readColumns.getPositionForColumnAt(i); - List dereferences = projectionGetter.get(expectedColumns.get(i), projectedColumnHandle); - - mappingBuilder.add(new ChannelMapping(inputChannel, dereferences)); - } - - outputToInputMapping = mappingBuilder.build(); - - outputTypes = expectedColumns.stream() - .map(typeGetter::get) - .collect(toImmutableList()); - - inputTypes = readColumns.get().stream() - .map(typeGetter::get) - .collect(toImmutableList()); - } - - @Nullable - public Page adaptPage(@Nullable Page input) - { - if 
(input == null) { - return null; - } - - Block[] blocks = new Block[outputToInputMapping.size()]; - - // Prepare adaptations to extract dereferences - for (int i = 0; i < outputToInputMapping.size(); i++) { - ChannelMapping mapping = outputToInputMapping.get(i); - - Block inputBlock = input.getBlock(mapping.getInputChannelIndex()); - blocks[i] = createAdaptedLazyBlock(inputBlock, mapping.getDereferenceSequence()); - } - - return new Page(input.getPositionCount(), blocks); - } - - private static Block createAdaptedLazyBlock(Block inputBlock, List dereferenceSequence) - { - if (dereferenceSequence.isEmpty()) { - return inputBlock; - } - - if (inputBlock == null) { - return null; - } - - return new LazyBlock(inputBlock.getPositionCount(), new DereferenceBlockLoader(inputBlock, dereferenceSequence)); - } - - private static class DereferenceBlockLoader - implements LazyBlockLoader - { - private final List dereferenceSequence; - private boolean loaded; - private Block inputBlock; - - DereferenceBlockLoader(Block inputBlock, List dereferenceSequence) - { - this.inputBlock = requireNonNull(inputBlock, "inputBlock is null"); - this.dereferenceSequence = requireNonNull(dereferenceSequence, "dereferenceSequence is null"); - } - - @Override - public Block load() - { - checkState(!loaded, "Already loaded"); - Block loadedBlock = loadInternalBlock(dereferenceSequence, inputBlock); - inputBlock = null; - loaded = true; - return loadedBlock; - } - - /** - * Applies dereference operations on the input block to extract the required internal block. If the input block is lazy - * in a nested manner, this implementation avoids loading the entire input block. - */ - private Block loadInternalBlock(List dereferences, Block parentBlock) - { - if (dereferences.isEmpty()) { - return parentBlock.getLoadedBlock(); - } - - List fields = getRowFieldsFromBlock(parentBlock); - - int dereferenceIndex = dereferences.get(0); - List remainingDereferences = dereferences.subList(1, dereferences.size()); - - Block fieldBlock = fields.get(dereferenceIndex); - return loadInternalBlock(remainingDereferences, fieldBlock); - } - } - - List getOutputToInputMapping() - { - return outputToInputMapping; - } - - List getOutputTypes() - { - return outputTypes; - } - - List getInputTypes() - { - return inputTypes; - } - - @VisibleForTesting - static class ChannelMapping - { - private final int inputChannelIndex; - private final List dereferenceSequence; - - public ChannelMapping(int inputBlockIndex, List dereferenceSequence) - { - checkArgument(inputBlockIndex >= 0, "inputBlockIndex cannot be negative"); - this.inputChannelIndex = inputBlockIndex; - this.dereferenceSequence = ImmutableList.copyOf(requireNonNull(dereferenceSequence, "dereferenceSequence is null")); - } - - public int getInputChannelIndex() - { - return inputChannelIndex; - } - - public List getDereferenceSequence() - { - return dereferenceSequence; - } - } - - public interface ColumnTypeGetter - { - Type get(ColumnHandle column); - } - - public interface ProjectionGetter - { - List get(ColumnHandle required, ColumnHandle read); - } -} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/TransformConnectorPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/TransformConnectorPageSource.java new file mode 100644 index 000000000000..9f56e1b7411e --- /dev/null +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/TransformConnectorPageSource.java @@ -0,0 +1,364 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.CheckReturnValue; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; +import io.trino.spi.metrics.Metrics; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.function.ObjLongConsumer; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.plugin.base.util.Closables.closeAllSuppress; +import static io.trino.spi.block.RowBlock.getRowFieldsFromBlock; +import static java.util.Objects.requireNonNull; + +public final class TransformConnectorPageSource + implements ConnectorPageSource +{ + private final ConnectorPageSource connectorPageSource; + private final Function transform; + + @CheckReturnValue + public static TransformConnectorPageSource create(ConnectorPageSource connectorPageSource, Function transform) + { + return new TransformConnectorPageSource(connectorPageSource, transform); + } + + private TransformConnectorPageSource(ConnectorPageSource connectorPageSource, Function transform) + { + this.connectorPageSource = requireNonNull(connectorPageSource, "connectorPageSource is null"); + this.transform = requireNonNull(transform, "transform is null"); + } + + @Override + public long getCompletedBytes() + { + return connectorPageSource.getCompletedBytes(); + } + + @Override + public OptionalLong getCompletedPositions() + { + return connectorPageSource.getCompletedPositions(); + } + + @Override + public long getReadTimeNanos() + { + return connectorPageSource.getReadTimeNanos(); + } + + @Override + public boolean isFinished() + { + return connectorPageSource.isFinished(); + } + + @Override + public SourcePage getNextSourcePage() + { + try { + SourcePage page = connectorPageSource.getNextSourcePage(); + if (page == null) { + return null; + } + return transform.apply(page); + } + catch (Throwable e) { + closeAllSuppress(e, connectorPageSource); + throw e; + } + } + + @Override + public long getMemoryUsage() + { + return connectorPageSource.getMemoryUsage(); + } + + @Override + public void close() + throws IOException + { + connectorPageSource.close(); + } + + @Override + public CompletableFuture isBlocked() + { + return connectorPageSource.isBlocked(); + } + + @Override + public Metrics getMetrics() + { + return connectorPageSource.getMetrics(); + } + + @CheckReturnValue + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private final List> transforms = new ArrayList<>(); + private boolean requiresTransform; + + private Builder() {} + + @CanIgnoreReturnValue + public Builder 
constantValue(Block constantValue) + { + requiresTransform = true; + transforms.add(new ConstantValue(constantValue)); + return this; + } + + @CanIgnoreReturnValue + public Builder column(int inputField) + { + return column(inputField, Optional.empty()); + } + + @CanIgnoreReturnValue + public Builder column(int inputField, Optional> transform) + { + if (transform.isPresent()) { + return transform(inputField, transform.get()); + } + + if (inputField != transforms.size()) { + requiresTransform = true; + } + transforms.add(new InputColumn(inputField)); + return this; + } + + @CanIgnoreReturnValue + public Builder dereferenceField(List path) + { + return dereferenceField(path, Optional.empty()); + } + + @CanIgnoreReturnValue + public Builder dereferenceField(List path, Optional> transform) + { + requireNonNull(path, "path is null"); + if (path.size() == 1) { + return column(path.getFirst(), transform); + } + + requiresTransform = true; + transforms.add(new DereferenceFieldTransform(path, transform)); + return this; + } + + @CanIgnoreReturnValue + public Builder transform(int inputColumn, Function transform) + { + requireNonNull(transform, "transform is null"); + requiresTransform = true; + transforms.add(new TransformBlock(transform, inputColumn)); + return this; + } + + @CanIgnoreReturnValue + public Builder transform(Function transform) + { + requiresTransform = true; + transforms.add(transform); + return this; + } + + @CheckReturnValue + public ConnectorPageSource build(ConnectorPageSource pageSource) + { + if (!requiresTransform) { + return pageSource; + } + + List> functions = List.copyOf(transforms); + return new TransformConnectorPageSource(pageSource, new TransformPages(functions)); + } + } + + private record ConstantValue(Block constantValue) + implements Function + { + @Override + public Block apply(SourcePage page) + { + return RunLengthEncodedBlock.create(constantValue, page.getPositionCount()); + } + } + + private record InputColumn(int inputField) + implements Function + { + @Override + public Block apply(SourcePage page) + { + return page.getBlock(inputField); + } + } + + private record DereferenceFieldTransform(List path, Optional> transform) + implements Function + { + private DereferenceFieldTransform + { + path = ImmutableList.copyOf(requireNonNull(path, "path is null")); + checkArgument(!path.isEmpty(), "path is empty"); + checkArgument(path.stream().allMatch(element -> element >= 0), "path element is negative"); + requireNonNull(transform, "transform is null"); + } + + @Override + public Block apply(SourcePage sourcePage) + { + Block block = sourcePage.getBlock(path.getFirst()); + for (int dereferenceIndex : path.subList(1, path.size())) { + block = getRowFieldsFromBlock(block).get(dereferenceIndex); + } + if (transform.isPresent()) { + block = transform.get().apply(block); + } + return block; + } + } + + private record TransformBlock(Function transform, int inputColumn) + implements Function + { + @Override + public Block apply(SourcePage page) + { + return transform.apply(page.getBlock(inputColumn)); + } + } + + private record TransformPages(List> functions) + implements Function + { + private TransformPages + { + functions = List.copyOf(requireNonNull(functions, "functions is null")); + } + + @Override + public SourcePage apply(SourcePage page) + { + return new TransformSourcePage(page, functions); + } + } + + private record TransformSourcePage(SourcePage sourcePage, List> transforms, Block[] blocks) + implements SourcePage + { + private TransformSourcePage(SourcePage 
sourcePage, List> transforms) + { + this(sourcePage, transforms, new Block[transforms.size()]); + } + + private TransformSourcePage + { + requireNonNull(sourcePage, "sourcePage is null"); + transforms = List.copyOf(requireNonNull(transforms, "transforms is null")); + requireNonNull(blocks, "blocks is null"); + checkArgument(transforms.size() == blocks.length, "transforms and blocks size mismatch"); + } + + @Override + public int getPositionCount() + { + return sourcePage.getPositionCount(); + } + + @Override + public long getSizeInBytes() + { + return sourcePage.getSizeInBytes(); + } + + @Override + public long getRetainedSizeInBytes() + { + return sourcePage.getRetainedSizeInBytes(); + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer consumer) + { + for (Block block : blocks) { + if (block != null) { + block.retainedBytesForEachPart(consumer); + } + } + } + + @Override + public int getChannelCount() + { + return blocks.length; + } + + @Override + public Block getBlock(int channel) + { + Block block = blocks[channel]; + if (block == null) { + block = transforms.get(channel).apply(sourcePage); + blocks[channel] = block; + } + return block; + } + + @Override + public Page getPage() + { + for (int i = 0; i < blocks.length; i++) { + getBlock(i); + } + return new Page(getPositionCount(), blocks); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + sourcePage.selectPositions(positions, offset, size); + for (int i = 0; i < blocks.length; i++) { + Block block = blocks[i]; + if (block != null) { + blocks[i] = block.getPositions(positions, offset, size); + } + } + } + } +} diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java index 54503fc61779..8085a757d95f 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java @@ -18,9 +18,9 @@ import io.trino.hive.formats.avro.AvroFileReader; import io.trino.hive.formats.avro.AvroTypeBlockHandler; import io.trino.hive.formats.avro.AvroTypeException; -import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import org.apache.avro.Schema; import java.io.IOException; @@ -75,11 +75,11 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { try { if (avroFileReader.hasNext()) { - return avroFileReader.next(); + return SourcePage.create(avroFileReader.next()); } else { return null; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java index 20209c24b4ae..db9052f69cd2 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java @@ -27,11 +27,9 @@ import io.trino.plugin.hive.AcidInfo; import io.trino.plugin.hive.HiveColumnHandle; import io.trino.plugin.hive.HivePageSourceFactory; -import io.trino.plugin.hive.HiveTimestampPrecision; -import io.trino.plugin.hive.ReaderColumns; -import io.trino.plugin.hive.ReaderPageSource; import io.trino.plugin.hive.acid.AcidTransaction; import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorPageSource; import 
io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.EmptyPageSource; import io.trino.spi.predicate.TupleDomain; @@ -53,12 +51,10 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.hive.formats.HiveClassNames.AVRO_SERDE_CLASS; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; +import static io.trino.plugin.hive.HivePageSourceProvider.projectColumnDereferences; import static io.trino.plugin.hive.HiveSessionProperties.getTimestampPrecision; -import static io.trino.plugin.hive.ReaderPageSource.noProjectionAdaptation; import static io.trino.plugin.hive.avro.AvroHiveFileUtils.getCanonicalToGivenFieldName; import static io.trino.plugin.hive.avro.AvroHiveFileUtils.wrapInUnionWithNull; import static io.trino.plugin.hive.util.HiveUtil.splitError; @@ -80,7 +76,7 @@ public AvroPageSourceFactory(TrinoFileSystemFactory trinoFileSystemFactory) } @Override - public Optional createPageSource( + public Optional createPageSource( ConnectorSession session, Location path, long start, @@ -100,19 +96,35 @@ public Optional createPageSource( } checkArgument(acidInfo.isEmpty(), "Acid is not supported"); - List projectedReaderColumns = columns; - Optional readerProjections = projectBaseColumns(columns); + TrinoFileSystem trinoFileSystem = trinoFileSystemFactory.create(session); + TrinoInputFile inputFile = cacheInputIfSmall(path, start, length, estimatedFileSize, trinoFileSystem); - if (readerProjections.isPresent()) { - projectedReaderColumns = readerProjections.get().get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toImmutableList()); + long actualSplitSize; + try { + actualSplitSize = min(inputFile.length() - start, length); + } + catch (IOException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, splitError(e, path, start, length), e); } - TrinoFileSystem trinoFileSystem = trinoFileSystemFactory.create(session); - TrinoInputFile inputFile = trinoFileSystem.newInputFile(path); - HiveTimestampPrecision hiveTimestampPrecision = getTimestampPrecision(session); + // Split may be empty now that the correct file size is known + if (actualSplitSize <= 0) { + return Optional.of(new EmptyPageSource()); + } + + return Optional.of(projectColumnDereferences(columns, baseColumns -> createPageSource(session, trinoFileSystem, inputFile, start, actualSplitSize, schema, baseColumns))); + } + private static AvroPageSource createPageSource( + ConnectorSession session, + TrinoFileSystem trinoFileSystem, + TrinoInputFile inputFile, + long start, + long length, + io.trino.plugin.hive.Schema schema, + List columns) + { + verify(columns.stream().allMatch(HiveColumnHandle::isBaseColumn), "All columns must be base columns"); Schema tableSchema; try { tableSchema = AvroHiveFileUtils.determineSchemaOrThrowException(trinoFileSystem, schema.serdeProperties()); @@ -121,42 +133,22 @@ public Optional createPageSource( throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Unable to load or parse schema", e); } - try { - if (estimatedFileSize < BUFFER_SIZE.toBytes()) { - try (TrinoInputStream input = inputFile.newStream()) { - byte[] data = input.readAllBytes(); - inputFile = new MemoryInputFile(path, Slices.wrappedBuffer(data)); - } - } - length = min(inputFile.length() - start, length); - } - catch (TrinoException e) { - throw e; - } - catch 
(Exception e) { - throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, splitError(e, path, start, length), e); - } - - // Split may be empty now that the correct file size is known - if (length <= 0) { - return Optional.of(noProjectionAdaptation(new EmptyPageSource())); - } - Schema maskedSchema; try { - maskedSchema = maskColumnsFromTableSchema(projectedReaderColumns, tableSchema); + maskedSchema = maskColumnsFromTableSchema(columns, tableSchema); } catch (org.apache.avro.AvroTypeException e) { - throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(path), e); + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(inputFile.location()), e); } + int hiveTimestampPrecision = getTimestampPrecision(session).getPrecision(); if (maskedSchema.getFields().isEmpty()) { - // no non-masked columns to select from partition schema - // hack to return null rows with same total count as underlying data file - // will error if UUID is same name as base column for underlying storage table but should never - // return false data. If file data has f+uuid column in schema then resolution of read null from not null will fail. + // No non-masked columns to select from partition schema. + // Hack to return null rows with the same total count as the underlying data file. + // This will error if UUID is the same name as base column for underlying storage table, but this should never + // return false data. If file data has f+uuid column in schema, then the resolution of read null from not null will fail. SchemaBuilder.FieldAssembler nullSchema = SchemaBuilder.record("null_only").fields(); - for (int i = 0; i < Math.max(projectedReaderColumns.size(), 1); i++) { + for (int i = 0; i < Math.max(columns.size(), 1); i++) { String notAColumnName = null; while (Objects.isNull(notAColumnName) || Objects.nonNull(tableSchema.getField(notAColumnName))) { notAColumnName = "f" + UUID.randomUUID().toString().replace('-', '_'); @@ -164,28 +156,28 @@ public Optional createPageSource( nullSchema = nullSchema.name(notAColumnName).type(Schema.create(Schema.Type.NULL)).withDefault(null); } try { - return Optional.of(noProjectionAdaptation(new AvroPageSource(inputFile, nullSchema.endRecord(), new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length))); + return new AvroPageSource(inputFile, nullSchema.endRecord(), new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision)), start, length); } catch (IOException e) { throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e); } catch (AvroTypeException e) { - throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(path), e); + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(inputFile.location()), e); } } try { - return Optional.of(new ReaderPageSource(new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length), readerProjections)); + return new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision)), start, length); } catch (IOException e) { throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e); } catch (AvroTypeException e) { - throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split 
from %s".formatted(path), e); + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, "Avro type resolution error when initializing split from %s".formatted(inputFile.location()), e); } } - private Schema maskColumnsFromTableSchema(List columns, Schema tableSchema) + private static Schema maskColumnsFromTableSchema(List columns, Schema tableSchema) { verify(tableSchema.getType() == Schema.Type.RECORD); Set maskedColumns = columns.stream().map(HiveColumnHandle::getBaseColumnName).collect(LinkedHashSet::new, HashSet::add, AbstractCollection::addAll); @@ -215,9 +207,9 @@ private Schema maskColumnsFromTableSchema(List columns, Schema .withDefault(defaultObj); } catch (org.apache.avro.AvroTypeException e) { - // in order to maintain backwards compatibility invalid defaults are mapped to null - // behavior defined by io.trino.tests.product.hive.TestAvroSchemaStrictness.testInvalidUnionDefaults - // solution is to make the field nullable and default-able to null. Any place default would be used, null will be + // To maintain backwards compatibility, invalid defaults are mapped to null. + // This behavior defined by io.trino.tests.product.hive.TestAvroSchemaStrictness.testInvalidUnionDefaults. + // The solution is to make the field nullable and default-able to null. Any place default would be used, null will be. if (e.getMessage().contains("Invalid default")) { maskedSchema = maskedSchema .name(field.name()) @@ -242,4 +234,24 @@ private Schema maskColumnsFromTableSchema(List columns, Schema } return maskedSchema.endRecord(); } + + private static TrinoInputFile cacheInputIfSmall(Location path, long start, long length, long estimatedFileSize, TrinoFileSystem trinoFileSystem) + { + TrinoInputFile inputFile = trinoFileSystem.newInputFile(path); + try { + if (estimatedFileSize < BUFFER_SIZE.toBytes()) { + try (TrinoInputStream input = inputFile.newStream()) { + byte[] data = input.readAllBytes(); + inputFile = new MemoryInputFile(path, Slices.wrappedBuffer(data)); + } + } + } + catch (TrinoException e) { + throw e; + } + catch (RuntimeException | IOException e) { + throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, splitError(e, path, start, length), e); + } + return inputFile; + } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSource.java index 2a3de567b717..b4652ca3e1f6 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSource.java @@ -21,6 +21,7 @@ import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import java.io.IOException; import java.util.OptionalLong; @@ -55,7 +56,7 @@ public LinePageSource(LineReader lineReader, LineDeserializer deserializer, Line } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { try { while (!pageBuilder.isFull() && lineReader.readLine(lineBuffer)) { @@ -64,7 +65,7 @@ public Page getNextPage() Page page = pageBuilder.build(); completedPositions += page.getPositionCount(); pageBuilder.reset(); - return page; + return SourcePage.create(page); } catch (TrinoException e) { closeAllSuppress(e, this); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java index 2ecd6770a0d8..21fa9d4832a1 100644 --- 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/line/LinePageSourceFactory.java @@ -29,11 +29,10 @@ import io.trino.plugin.hive.AcidInfo; import io.trino.plugin.hive.HiveColumnHandle; import io.trino.plugin.hive.HivePageSourceFactory; -import io.trino.plugin.hive.ReaderColumns; -import io.trino.plugin.hive.ReaderPageSource; import io.trino.plugin.hive.Schema; import io.trino.plugin.hive.acid.AcidTransaction; import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.EmptyPageSource; import io.trino.spi.predicate.TupleDomain; @@ -48,8 +47,7 @@ import static io.trino.hive.formats.line.LineDeserializer.EMPTY_LINE_DESERIALIZER; import static io.trino.hive.thrift.metastore.hive_metastoreConstants.FILE_INPUT_FORMAT; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; -import static io.trino.plugin.hive.ReaderPageSource.noProjectionAdaptation; +import static io.trino.plugin.hive.HivePageSourceProvider.projectColumnDereferences; import static io.trino.plugin.hive.util.HiveUtil.getFooterCount; import static io.trino.plugin.hive.util.HiveUtil.getHeaderCount; import static io.trino.plugin.hive.util.HiveUtil.splitError; @@ -76,7 +74,7 @@ protected LinePageSourceFactory( } @Override - public Optional createPageSource( + public Optional createPageSource( ConnectorSession session, Location path, long start, @@ -98,7 +96,19 @@ public Optional createPageSource( checkArgument(acidInfo.isEmpty(), "Acid is not supported"); - // get header and footer count + return Optional.of(projectColumnDereferences(columns, baseColumns -> createPageSource(session, path, start, length, estimatedFileSize, schema, baseColumns))); + } + + private ConnectorPageSource createPageSource( + ConnectorSession session, + Location path, + long start, + long length, + long estimatedFileSize, + Schema schema, + List columns) + { + // get header and footer counts int headerCount = getHeaderCount(schema.serdeProperties()); if (headerCount > 1) { checkArgument(start == 0, "Multiple header rows are not supported for a split file"); @@ -108,20 +118,11 @@ public Optional createPageSource( checkArgument(start == 0, "Footer not supported for a split file"); } - // setup projected columns - List projectedReaderColumns = columns; - Optional readerProjections = projectBaseColumns(columns); - if (readerProjections.isPresent()) { - projectedReaderColumns = readerProjections.get().get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toImmutableList()); - } - // create deserializer LineDeserializer lineDeserializer = EMPTY_LINE_DESERIALIZER; if (!columns.isEmpty()) { lineDeserializer = lineDeserializerFactory.create( - projectedReaderColumns.stream() + columns.stream() .map(column -> new Column(column.getName(), column.getType(), column.getBaseHiveColumnIndex())) .collect(toImmutableList()), schema.serdeProperties()); @@ -141,16 +142,15 @@ public Optional createPageSource( // Skip empty inputs if (length <= 0) { - return Optional.of(noProjectionAdaptation(new EmptyPageSource())); + return new EmptyPageSource(); } LineReader lineReader = lineReaderFactory.createLineReader(inputFile, start, length, headerCount, footerCount); // Split may be empty after discovering the real file size and skipping headers if (lineReader.isClosed()) { - return 
Optional.of(noProjectionAdaptation(new EmptyPageSource())); + return new EmptyPageSource(); } - LinePageSource pageSource = new LinePageSource(lineReader, lineDeserializer, lineReaderFactory.createLineBuffer(), path); - return Optional.of(new ReaderPageSource(pageSource, readerProjections)); + return new LinePageSource(lineReader, lineDeserializer, lineReaderFactory.createLineBuffer(), path); } catch (TrinoException e) { throw e; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeleteDeltaPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeleteDeltaPageSource.java index b948206a8bbe..c03e671859d7 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeleteDeltaPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeleteDeltaPageSource.java @@ -27,9 +27,9 @@ import io.trino.orc.OrcReaderOptions; import io.trino.orc.OrcRecordReader; import io.trino.plugin.base.metrics.FileFormatDataSourceStats; -import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import java.io.IOException; import java.io.UncheckedIOException; @@ -135,6 +135,7 @@ private OrcDeleteDeltaPageSource( rowIdColumns, ImmutableList.of(BIGINT, INTEGER, BIGINT), ImmutableList.of(fullyProjectedLayout(), fullyProjectedLayout(), fullyProjectedLayout()), + false, OrcPredicate.TRUE, 0, fileSize, @@ -164,10 +165,10 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { try { - Page page = recordReader.nextPage(); + SourcePage page = recordReader.nextPage(); if (page == null) { close(); } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeletedRows.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeletedRows.java index a7204bab872d..64514d63303f 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeletedRows.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcDeletedRows.java @@ -27,9 +27,9 @@ import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.DictionaryBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.EmptyPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.security.ConnectorIdentity; import jakarta.annotation.Nullable; @@ -40,6 +40,7 @@ import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; +import java.util.function.ObjLongConsumer; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; @@ -100,52 +101,20 @@ public OrcDeletedRows( this.memoryUsage = memoryContext.newLocalMemoryContext(OrcDeletedRows.class.getSimpleName()); } - public MaskDeletedRowsFunction getMaskDeletedRowsFunction(Page sourcePage, OptionalLong startRowId) + public SourcePage maskPage(SourcePage sourcePage, OptionalLong startRowId) { - return new MaskDeletedRows(sourcePage, startRowId); - } - - public interface MaskDeletedRowsFunction - { - /** - * Retained position count - */ - int getPositionCount(); - - Block apply(Block block); - - static MaskDeletedRowsFunction noMaskForPage(Page page) - { - int positionCount = page.getPositionCount(); - return new MaskDeletedRowsFunction() - { - @Override - public int getPositionCount() - { - return positionCount; - } - - @Override - public Block apply(Block block) 
- { - return block; - } - }; - } + return new OrcAcidMaskedSourcePage(sourcePage, startRowId); } @NotThreadSafe - private class MaskDeletedRows - implements MaskDeletedRowsFunction + private class OrcAcidMaskedSourcePage + implements SourcePage { - @Nullable - private Page sourcePage; - private int positionCount; - @Nullable - private int[] validPositions; private final OptionalLong startRowId; + private final SourcePage sourcePage; + private boolean deleteMaskApplied; - public MaskDeletedRows(Page sourcePage, OptionalLong startRowId) + public OrcAcidMaskedSourcePage(SourcePage sourcePage, OptionalLong startRowId) { this.sourcePage = requireNonNull(sourcePage, "sourcePage is null"); this.startRowId = requireNonNull(startRowId, "startRowId is null"); @@ -154,35 +123,72 @@ public MaskDeletedRows(Page sourcePage, OptionalLong startRowId) @Override public int getPositionCount() { - if (sourcePage != null) { - loadValidPositions(); - verify(sourcePage == null); + applyDeleteMaskIfNecessary(); + return sourcePage.getPositionCount(); + } + + @Override + public long getSizeInBytes() + { + if (!deleteMaskApplied) { + return 0; } + return sourcePage.getSizeInBytes(); + } - return positionCount; + @Override + public long getRetainedSizeInBytes() + { + return sourcePage.getRetainedSizeInBytes(); } @Override - public Block apply(Block block) + public void retainedBytesForEachPart(ObjLongConsumer consumer) { - if (sourcePage != null) { - loadValidPositions(); - verify(sourcePage == null); - } + sourcePage.retainedBytesForEachPart(consumer); + } + + @Override + public int getChannelCount() + { + return sourcePage.getChannelCount(); + } + + @Override + public Block getBlock(int channel) + { + applyDeleteMaskIfNecessary(); + return sourcePage.getBlock(channel); + } + + @Override + public Page getPage() + { + applyDeleteMaskIfNecessary(); + return sourcePage.getPage(); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + applyDeleteMaskIfNecessary(); + sourcePage.selectPositions(positions, offset, size); + } - if (positionCount == block.getPositionCount()) { - return block; + private void applyDeleteMaskIfNecessary() + { + if (deleteMaskApplied) { + return; } - return DictionaryBlock.create(positionCount, block, validPositions); + applyDeleteMask(); } - private void loadValidPositions() + private void applyDeleteMask() { - verify(sourcePage != null, "sourcePage is null"); + verify(!deleteMaskApplied, "mask already applied"); Set deletedRows = getDeletedRows(); if (deletedRows.isEmpty()) { - this.positionCount = sourcePage.getPositionCount(); - this.sourcePage = null; + deleteMaskApplied = true; return; } @@ -195,9 +201,8 @@ private void loadValidPositions() validPositionsIndex++; } } - this.positionCount = validPositionsIndex; - this.validPositions = validPositions; - this.sourcePage = null; + sourcePage.selectPositions(validPositions, 0, validPositionsIndex); + deleteMaskApplied = true; } private RowId getRowId(int position) @@ -240,7 +245,7 @@ private Set getDeletedRows() /** * Triggers loading of deleted rows ids. Single call to the method may load just part of ids. * If more ids to be loaded remain, method returns false and should be called once again. 
- * Final call will return true and the loaded ids can be consumed via {@link #getMaskDeletedRowsFunction(Page, OptionalLong)} + * Final call will return true and the loaded ids can be consumed via {@link #maskPage(SourcePage, OptionalLong)} * * @return true when fully loaded, and false if this method should be called again */ @@ -296,7 +301,7 @@ private class Loader @Nullable private Location currentPath; @Nullable - private Page currentPage; + private SourcePage currentPage; private int currentPagePosition; public Optional> loadOrYield() @@ -321,7 +326,7 @@ public Optional> loadOrYield() if (currentPageSource != null) { while (!currentPageSource.isFinished() || currentPage != null) { if (currentPage == null) { - currentPage = currentPageSource.getNextPage(); + currentPage = currentPageSource.getNextSourcePage(); currentPagePosition = 0; } @@ -381,7 +386,7 @@ public void close() } } - private static long retainedMemorySize(int rowCount, @Nullable Page currentPage) + private static long retainedMemorySize(int rowCount, @Nullable SourcePage currentPage) { long pageSize = (currentPage != null) ? currentPage.getRetainedSizeInBytes() : 0; return sizeOfObjectArray(rowCount) + ((long) rowCount * RowId.INSTANCE_SIZE) + pageSize; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java index 19f6332d433d..4c8df657c118 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSource.java @@ -13,7 +13,6 @@ */ package io.trino.plugin.hive.orc; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.Closer; import io.trino.memory.context.AggregatedMemoryContext; @@ -25,50 +24,29 @@ import io.trino.orc.metadata.CompressionKind; import io.trino.plugin.base.metrics.FileFormatDataSourceStats; import io.trino.plugin.base.metrics.LongCount; -import io.trino.plugin.hive.coercions.TypeCoercer; -import io.trino.plugin.hive.orc.OrcDeletedRows.MaskDeletedRowsFunction; -import io.trino.spi.Page; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; -import io.trino.spi.block.LazyBlockLoader; -import io.trino.spi.block.LongArrayBlock; -import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.metrics.Metrics; -import io.trino.spi.type.Type; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.List; import java.util.Optional; import java.util.OptionalLong; import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static io.trino.plugin.base.util.Closables.closeAllSuppress; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; -import static io.trino.plugin.hive.HivePageSource.BUCKET_CHANNEL; -import static io.trino.plugin.hive.HivePageSource.ORIGINAL_TRANSACTION_CHANNEL; -import static io.trino.plugin.hive.HivePageSource.ROW_ID_CHANNEL; -import static io.trino.plugin.hive.orc.OrcFileWriter.computeBucketValue; -import static io.trino.spi.block.RowBlock.fromFieldBlocks; -import static io.trino.spi.predicate.Utils.nativeValueToBlock; 
-import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.IntegerType.INTEGER; import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class OrcPageSource implements ConnectorPageSource { - private static final Block ORIGINAL_FILE_TRANSACTION_ID_BLOCK = nativeValueToBlock(BIGINT, 0L); - public static final String ORC_CODEC_METRIC_PREFIX = "OrcReaderCompressionFormat_"; + private static final String ORC_CODEC_METRIC_PREFIX = "OrcReaderCompressionFormat_"; private final OrcRecordReader recordReader; - private final List columnAdaptations; private final OrcDataSource orcDataSource; private final Optional deletedRows; @@ -85,11 +63,10 @@ public class OrcPageSource private long completedPositions; - private Optional outstandingPage = Optional.empty(); + private Optional outstandingPage = Optional.empty(); public OrcPageSource( OrcRecordReader recordReader, - List columnAdaptations, OrcDataSource orcDataSource, Optional deletedRows, Optional originalFileRowId, @@ -98,7 +75,6 @@ public OrcPageSource( CompressionKind compressionKind) { this.recordReader = requireNonNull(recordReader, "recordReader is null"); - this.columnAdaptations = ImmutableList.copyOf(requireNonNull(columnAdaptations, "columnAdaptations is null")); this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null"); this.deletedRows = requireNonNull(deletedRows, "deletedRows is null"); this.stats = requireNonNull(stats, "stats is null"); @@ -133,9 +109,9 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { - Page page; + SourcePage page; try { if (outstandingPage.isPresent()) { page = outstandingPage.get(); @@ -173,21 +149,10 @@ public Page getNextPage() localMemoryContext.setBytes(page.getRetainedSizeInBytes()); return null; // return control to engine so it can update memory usage for query } + page = deletedRows.get().maskPage(page, startRowId); } - MaskDeletedRowsFunction maskDeletedRowsFunction = deletedRows - .map(deletedRows -> deletedRows.getMaskDeletedRowsFunction(page, startRowId)) - .orElseGet(() -> MaskDeletedRowsFunction.noMaskForPage(page)); - return getColumnAdaptationsPage(page, maskDeletedRowsFunction, recordReader.getFilePosition(), startRowId); - } - - private Page getColumnAdaptationsPage(Page page, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - Block[] blocks = new Block[columnAdaptations.size()]; - for (int i = 0; i < columnAdaptations.size(); i++) { - blocks[i] = columnAdaptations.get(i).block(page, maskDeletedRowsFunction, filePosition, startRowId); - } - return new Page(maskDeletedRowsFunction.getPositionCount(), blocks); + return page; } static TrinoException handleException(OrcDataSourceId dataSourceId, Exception exception) @@ -235,8 +200,7 @@ public void close() public String toString() { return toStringHelper(this) - .add("orcDataSource", orcDataSource.getId()) - .add("columns", columnAdaptations) + .add("orcReader", recordReader) .toString(); } @@ -251,243 +215,4 @@ public Metrics getMetrics() { return new Metrics(ImmutableMap.of(ORC_CODEC_METRIC_PREFIX + compressionKind.name(), new LongCount(recordReader.getTotalDataLength()))); } - - public interface ColumnAdaptation - { - Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId); - - static ColumnAdaptation nullColumn(Type type) - { - return new NullColumn(type); - } - - static ColumnAdaptation 
sourceColumn(int index) - { - return new SourceColumn(index); - } - - static ColumnAdaptation coercedColumn(int index, TypeCoercer typeCoercer) - { - return new CoercedColumn(sourceColumn(index), typeCoercer); - } - - static ColumnAdaptation constantColumn(Block singleValueBlock) - { - return new ConstantAdaptation(singleValueBlock); - } - - static ColumnAdaptation positionColumn() - { - return new PositionAdaptation(); - } - - static ColumnAdaptation mergedRowColumns() - { - return new MergedRowAdaptation(); - } - - static ColumnAdaptation mergedRowColumnsWithOriginalFiles(long startingRowId, int bucketId) - { - return new MergedRowAdaptationWithOriginalFiles(startingRowId, bucketId); - } - } - - private static class NullColumn - implements ColumnAdaptation - { - private final Type type; - private final Block nullBlock; - - public NullColumn(Type type) - { - this.type = requireNonNull(type, "type is null"); - this.nullBlock = type.createNullBlock(); - } - - @Override - public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - return RunLengthEncodedBlock.create(nullBlock, maskDeletedRowsFunction.getPositionCount()); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("type", type) - .toString(); - } - } - - private static class SourceColumn - implements ColumnAdaptation - { - private final int index; - - public SourceColumn(int index) - { - checkArgument(index >= 0, "index is negative"); - this.index = index; - } - - @Override - public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - return new LazyBlock(maskDeletedRowsFunction.getPositionCount(), new MaskingBlockLoader(maskDeletedRowsFunction, sourcePage.getBlock(index))); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("index", index) - .toString(); - } - - private static final class MaskingBlockLoader - implements LazyBlockLoader - { - private MaskDeletedRowsFunction maskDeletedRowsFunction; - private Block sourceBlock; - - public MaskingBlockLoader(MaskDeletedRowsFunction maskDeletedRowsFunction, Block sourceBlock) - { - this.maskDeletedRowsFunction = requireNonNull(maskDeletedRowsFunction, "maskDeletedRowsFunction is null"); - this.sourceBlock = requireNonNull(sourceBlock, "sourceBlock is null"); - } - - @Override - public Block load() - { - checkState(maskDeletedRowsFunction != null, "Already loaded"); - - Block resultBlock = maskDeletedRowsFunction.apply(sourceBlock.getLoadedBlock()); - - maskDeletedRowsFunction = null; - sourceBlock = null; - - return resultBlock; - } - } - } - - private static class CoercedColumn - implements ColumnAdaptation - { - private final ColumnAdaptation delegate; - private final TypeCoercer typeCoercer; - - public CoercedColumn(ColumnAdaptation delegate, TypeCoercer typeCoercer) - { - this.delegate = requireNonNull(delegate, "delegate is null"); - this.typeCoercer = requireNonNull(typeCoercer, "typeCoercer is null"); - } - - @Override - public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - Block block = delegate.block(sourcePage, maskDeletedRowsFunction, filePosition, startRowId); - return new LazyBlock(block.getPositionCount(), () -> typeCoercer.apply(block.getLoadedBlock())); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("delegate", delegate) - 
.add("fromType", typeCoercer.getFromType()) - .add("toType", typeCoercer.getToType()) - .toString(); - } - } - - /* - * The rowId contains the ACID columns - - originalTransaction, rowId, bucket - */ - private static final class MergedRowAdaptation - implements ColumnAdaptation - { - @Override - public Block block(Page page, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - requireNonNull(page, "page is null"); - return maskDeletedRowsFunction.apply(fromFieldBlocks( - page.getPositionCount(), - new Block[] { - page.getBlock(ORIGINAL_TRANSACTION_CHANNEL), - page.getBlock(BUCKET_CHANNEL), - page.getBlock(ROW_ID_CHANNEL) - })); - } - } - - /** - * The rowId contains the ACID columns - - originalTransaction, rowId, bucket, - * derived from the original file. The transactionId is always zero, - * and the rowIds count up from the startingRowId. - */ - private static final class MergedRowAdaptationWithOriginalFiles - implements ColumnAdaptation - { - private final long startingRowId; - private final Block bucketBlock; - - public MergedRowAdaptationWithOriginalFiles(long startingRowId, int bucketId) - { - this.startingRowId = startingRowId; - this.bucketBlock = nativeValueToBlock(INTEGER, (long) computeBucketValue(bucketId, 0)); - } - - @Override - public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - int positionCount = sourcePage.getPositionCount(); - return maskDeletedRowsFunction.apply(fromFieldBlocks( - positionCount, - new Block[] { - RunLengthEncodedBlock.create(ORIGINAL_FILE_TRANSACTION_ID_BLOCK, positionCount), - RunLengthEncodedBlock.create(bucketBlock, positionCount), - createRowNumberBlock(startingRowId, filePosition, positionCount) - })); - } - } - - private static class ConstantAdaptation - implements ColumnAdaptation - { - private final Block singleValueBlock; - - public ConstantAdaptation(Block singleValueBlock) - { - requireNonNull(singleValueBlock, "singleValueBlock is null"); - checkArgument(singleValueBlock.getPositionCount() == 1, "ConstantColumnAdaptation singleValueBlock may only contain one position"); - this.singleValueBlock = singleValueBlock; - } - - @Override - public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - return RunLengthEncodedBlock.create(singleValueBlock, sourcePage.getPositionCount()); - } - } - - private static class PositionAdaptation - implements ColumnAdaptation - { - @Override - public Block block(Page sourcePage, MaskDeletedRowsFunction maskDeletedRowsFunction, long filePosition, OptionalLong startRowId) - { - checkArgument(startRowId.isEmpty(), "startRowId should not be specified when using PositionAdaptation"); - return createRowNumberBlock(0, filePosition, sourcePage.getPositionCount()); - } - } - - private static Block createRowNumberBlock(long startingRowId, long filePosition, int positionCount) - { - long[] translatedRowIds = new long[positionCount]; - for (int index = 0; index < positionCount; index++) { - translatedRowIds[index] = startingRowId + filePosition + index; - } - return new LongArrayBlock(positionCount, Optional.empty(), translatedRowIds); - } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java index 5c23a767fdab..dce15ea1d655 100644 --- 
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/orc/OrcPageSourceFactory.java @@ -39,17 +39,19 @@ import io.trino.plugin.hive.HiveColumnProjectionInfo; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HivePageSourceFactory; -import io.trino.plugin.hive.ReaderColumns; -import io.trino.plugin.hive.ReaderPageSource; import io.trino.plugin.hive.Schema; +import io.trino.plugin.hive.TransformConnectorPageSource; import io.trino.plugin.hive.acid.AcidSchema; import io.trino.plugin.hive.acid.AcidTransaction; import io.trino.plugin.hive.coercions.TypeCoercer; -import io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation; import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.LongArrayBlock; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.EmptyPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; @@ -58,10 +60,12 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalInt; +import java.util.function.Function; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -69,6 +73,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Maps.uniqueIndex; +import static com.google.common.collect.MoreCollectors.toOptional; import static io.trino.hive.formats.HiveClassNames.ORC_SERDE_CLASS; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.orc.OrcReader.INITIAL_BATCH_SIZE; @@ -83,7 +88,10 @@ import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; import static io.trino.plugin.hive.HiveErrorCode.HIVE_FILE_MISSING_COLUMN_NAMES; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; +import static io.trino.plugin.hive.HivePageSourceProvider.BUCKET_CHANNEL; +import static io.trino.plugin.hive.HivePageSourceProvider.ORIGINAL_TRANSACTION_CHANNEL; +import static io.trino.plugin.hive.HivePageSourceProvider.ROW_ID_CHANNEL; +import static io.trino.plugin.hive.HivePageSourceProvider.getProjection; import static io.trino.plugin.hive.HiveSessionProperties.getOrcLazyReadSmallRanges; import static io.trino.plugin.hive.HiveSessionProperties.getOrcMaxBufferSize; import static io.trino.plugin.hive.HiveSessionProperties.getOrcMaxMergeDistance; @@ -93,11 +101,13 @@ import static io.trino.plugin.hive.HiveSessionProperties.isOrcBloomFiltersEnabled; import static io.trino.plugin.hive.HiveSessionProperties.isOrcNestedLazy; import static io.trino.plugin.hive.HiveSessionProperties.isUseOrcColumnNames; -import static io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation.mergedRowColumns; +import static io.trino.plugin.hive.orc.OrcFileWriter.computeBucketValue; import static io.trino.plugin.hive.orc.OrcPageSource.handleException; import static io.trino.plugin.hive.orc.OrcTypeTranslator.createCoercer; import static io.trino.plugin.hive.util.HiveUtil.splitError; import static 
io.trino.spi.StandardErrorCode.NOT_SUPPORTED; +import static io.trino.spi.block.RowBlock.fromFieldBlocks; +import static io.trino.spi.predicate.Utils.nativeValueToBlock; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static java.lang.String.format; @@ -107,11 +117,11 @@ import static java.util.function.Function.identity; import static java.util.stream.Collectors.mapping; import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toUnmodifiableList; public class OrcPageSourceFactory implements HivePageSourceFactory { + private static final Block ORIGINAL_FILE_TRANSACTION_ID_BLOCK = nativeValueToBlock(BIGINT, 0L); private static final Pattern DEFAULT_HIVE_COLUMN_NAME_PATTERN = Pattern.compile("_col\\d+"); private final OrcReaderOptions orcReaderOptions; private final TrinoFileSystemFactory fileSystemFactory; @@ -163,7 +173,7 @@ public static boolean stripUnnecessaryProperties(String serializationLibraryName } @Override - public Optional createPageSource( + public Optional createPageSource( ConnectorSession session, Location path, long start, @@ -182,25 +192,15 @@ public Optional createPageSource( return Optional.empty(); } - List readerColumnHandles = columns; - - Optional readerColumns = projectBaseColumns(columns); - if (readerColumns.isPresent()) { - readerColumnHandles = readerColumns.get().get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toUnmodifiableList()); - } - - ConnectorPageSource orcPageSource = createOrcPageSource( + return Optional.of(createOrcPageSource( session, path, start, length, estimatedFileSize, fileModifiedTime, - readerColumnHandles, columns, - isUseOrcColumnNames(session), + isUseOrcColumnNames(session) || schema.isFullAcidTable(), schema.isFullAcidTable(), effectivePredicate, legacyTimeZone, @@ -217,9 +217,7 @@ public Optional createPageSource( bucketNumber, originalFile, transaction, - stats); - - return Optional.of(new ReaderPageSource(orcPageSource, readerColumns)); + stats)); } private ConnectorPageSource createOrcPageSource( @@ -230,7 +228,6 @@ private ConnectorPageSource createOrcPageSource( long estimatedFileSize, long fileModifiedTime, List columns, - List projections, boolean useOrcColumnNames, boolean isFullAcid, TupleDomain effectivePredicate, @@ -249,7 +246,6 @@ private ConnectorPageSource createOrcPageSource( OrcDataSource orcDataSource; - boolean originalFilesPresent = acidInfo.isPresent() && !acidInfo.get().getOriginalFiles().isEmpty(); try { TrinoFileSystem fileSystem = fileSystemFactory.create(session); TrinoInputFile inputFile = fileSystem.newInputFile(path, estimatedFileSize, Instant.ofEpochMilli(fileModifiedTime)); @@ -276,110 +272,156 @@ private ConnectorPageSource createOrcPageSource( } List fileColumns = reader.getRootColumn().getNestedColumns(); - int actualColumnCount = columns.size() + (isFullAcid ? 
3 : 0); - List fileReadColumns = new ArrayList<>(actualColumnCount); - List fileReadTypes = new ArrayList<>(actualColumnCount); - List fileReadLayouts = new ArrayList<>(actualColumnCount); + List fileReadColumns = new ArrayList<>(); + List fileReadTypes = new ArrayList<>(); + List fileReadLayouts = new ArrayList<>(); + boolean originalFilesPresent = acidInfo.isPresent() && hasOriginalFiles(acidInfo.get()); if (isFullAcid && !originalFilesPresent) { verifyAcidSchema(reader, path); Map acidColumnsByName = uniqueIndex(fileColumns, orcColumn -> orcColumn.getColumnName().toLowerCase(ENGLISH)); - fileColumns = ensureColumnNameConsistency(acidColumnsByName.get(AcidSchema.ACID_COLUMN_ROW_STRUCT.toLowerCase(ENGLISH)).getNestedColumns(), columns); + fileColumns = ensureColumnNameConsistency( + requireNonNull(acidColumnsByName.get(AcidSchema.ACID_COLUMN_ROW_STRUCT.toLowerCase(ENGLISH))).getNestedColumns(), + columns); - fileReadColumns.add(acidColumnsByName.get(AcidSchema.ACID_COLUMN_ORIGINAL_TRANSACTION.toLowerCase(ENGLISH))); + fileReadColumns.add(requireNonNull(acidColumnsByName.get(AcidSchema.ACID_COLUMN_ORIGINAL_TRANSACTION.toLowerCase(ENGLISH)))); fileReadTypes.add(BIGINT); fileReadLayouts.add(fullyProjectedLayout()); - fileReadColumns.add(acidColumnsByName.get(AcidSchema.ACID_COLUMN_BUCKET.toLowerCase(ENGLISH))); + fileReadColumns.add(requireNonNull(acidColumnsByName.get(AcidSchema.ACID_COLUMN_BUCKET.toLowerCase(ENGLISH)))); fileReadTypes.add(INTEGER); fileReadLayouts.add(fullyProjectedLayout()); - fileReadColumns.add(acidColumnsByName.get(AcidSchema.ACID_COLUMN_ROW_ID.toLowerCase(ENGLISH))); + fileReadColumns.add(requireNonNull(acidColumnsByName.get(AcidSchema.ACID_COLUMN_ROW_ID.toLowerCase(ENGLISH)))); fileReadTypes.add(BIGINT); fileReadLayouts.add(fullyProjectedLayout()); } Map fileColumnsByName = ImmutableMap.of(); - if (useOrcColumnNames || isFullAcid) { + if (useOrcColumnNames) { verifyFileHasColumnNames(fileColumns, path); // Convert column names read from ORC files to lower case to be consistent with those stored in Hive Metastore fileColumnsByName = uniqueIndex(fileColumns, orcColumn -> orcColumn.getColumnName().toLowerCase(ENGLISH)); } - Map>> projectionsByColumnName = ImmutableMap.of(); - Map>> projectionsByColumnIndex = ImmutableMap.of(); - if (useOrcColumnNames || isFullAcid) { - projectionsByColumnName = projections.stream() - .collect(Collectors.groupingBy( - HiveColumnHandle::getBaseColumnName, - mapping( - OrcPageSourceFactory::getDereferencesAsList, toList()))); - } - else { - projectionsByColumnIndex = projections.stream() - .collect(Collectors.groupingBy( - HiveColumnHandle::getBaseHiveColumnIndex, - mapping( - OrcPageSourceFactory::getDereferencesAsList, toList()))); - } + // Mapping from the base column key (name or index) to the dereference paths (fields) requested by the caller + Map>> projectionsByBaseColumnKey = columns.stream() + .collect(Collectors.groupingBy( + useOrcColumnNames ? 
HiveColumnHandle::getBaseColumnName : HiveColumnHandle::getBaseHiveColumnIndex, + mapping( + OrcPageSourceFactory::getDereferencesAsList, toList()))); TupleDomainOrcPredicateBuilder predicateBuilder = TupleDomainOrcPredicate.builder() .setBloomFiltersEnabled(options.isBloomFiltersEnabled()) .setDomainCompactionThreshold(domainCompactionThreshold); Map effectivePredicateDomains = effectivePredicate.getDomains() .orElseThrow(() -> new IllegalArgumentException("Effective predicate is none")); - List columnAdaptations = new ArrayList<>(columns.size()); + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + Map baseColumnKeyToOrdinal = new HashMap<>(); for (HiveColumnHandle column : columns) { - OrcColumn orcColumn = null; - OrcReader.ProjectedLayout projectedLayout = null; - Map, Domain> columnDomains = null; - - if (useOrcColumnNames || isFullAcid) { - String columnName = column.getName().toLowerCase(ENGLISH); - orcColumn = fileColumnsByName.get(columnName); - if (orcColumn != null) { - projectedLayout = createProjectedLayout(orcColumn, projectionsByColumnName.get(columnName)); - columnDomains = effectivePredicateDomains.entrySet().stream() - .filter(columnDomain -> columnDomain.getKey().getBaseColumnName().toLowerCase(ENGLISH).equals(columnName)) - .collect(toImmutableMap(columnDomain -> columnDomain.getKey().getHiveColumnProjectionInfo(), Map.Entry::getValue)); + HiveColumnHandle baseColumn = column.getBaseColumn(); + Integer ordinal = baseColumnKeyToOrdinal.get(useOrcColumnNames ? column.getBaseColumnName() : column.getBaseHiveColumnIndex()); + if (ordinal == null) { + OrcColumn orcBaseColumn = null; + OrcReader.ProjectedLayout projectedLayout = null; + Map, Domain> columnDomains = null; + if (useOrcColumnNames) { + String columnName = baseColumn.getName().toLowerCase(ENGLISH); + orcBaseColumn = fileColumnsByName.get(columnName); + if (orcBaseColumn != null) { + projectedLayout = createProjectedLayout(orcBaseColumn, projectionsByBaseColumnKey.get(columnName)); + columnDomains = effectivePredicateDomains.entrySet().stream() + .filter(columnDomain -> columnDomain.getKey().getBaseColumnName().toLowerCase(ENGLISH).equals(columnName)) + .collect(toImmutableMap(columnDomain -> columnDomain.getKey().getHiveColumnProjectionInfo(), Map.Entry::getValue)); + } } - } - else if (column.getBaseHiveColumnIndex() < fileColumns.size()) { - orcColumn = fileColumns.get(column.getBaseHiveColumnIndex()); - if (orcColumn != null) { - projectedLayout = createProjectedLayout(orcColumn, projectionsByColumnIndex.get(column.getBaseHiveColumnIndex())); - columnDomains = effectivePredicateDomains.entrySet().stream() - .filter(columnDomain -> columnDomain.getKey().getBaseHiveColumnIndex() == column.getBaseHiveColumnIndex()) - .collect(toImmutableMap(columnDomain -> columnDomain.getKey().getHiveColumnProjectionInfo(), Map.Entry::getValue)); + else if (baseColumn.getBaseHiveColumnIndex() < fileColumns.size()) { + orcBaseColumn = fileColumns.get(baseColumn.getBaseHiveColumnIndex()); + if (orcBaseColumn != null) { + projectedLayout = createProjectedLayout(orcBaseColumn, projectionsByBaseColumnKey.get(baseColumn.getBaseHiveColumnIndex())); + columnDomains = effectivePredicateDomains.entrySet().stream() + .filter(columnDomain -> columnDomain.getKey().getBaseHiveColumnIndex() == baseColumn.getBaseHiveColumnIndex()) + .collect(toImmutableMap(columnDomain -> columnDomain.getKey().getHiveColumnProjectionInfo(), Map.Entry::getValue)); + } } - } - Type readType = column.getType(); - if 
(orcColumn != null) { - int sourceIndex = fileReadColumns.size(); - Optional> coercer = createCoercer(orcColumn.getColumnType(), orcColumn.getNestedColumns(), readType); - if (coercer.isPresent()) { - fileReadTypes.add(coercer.get().getFromType()); - columnAdaptations.add(ColumnAdaptation.coercedColumn(sourceIndex, coercer.get())); + if (orcBaseColumn == null) { + transforms.constantValue(column.getType().createNullBlock()); + continue; } - else { - columnAdaptations.add(ColumnAdaptation.sourceColumn(sourceIndex)); - fileReadTypes.add(readType); - } - fileReadColumns.add(orcColumn); + + ordinal = fileReadColumns.size(); + baseColumnKeyToOrdinal.put(useOrcColumnNames ? column.getBaseColumnName() : column.getBaseHiveColumnIndex(), ordinal); + fileReadColumns.add(orcBaseColumn); fileReadLayouts.add(projectedLayout); + // TODO it should be possible to compute fileReadType without creating the coercer + fileReadTypes.add(createCoercer(orcBaseColumn.getColumnType(), orcBaseColumn.getNestedColumns(), baseColumn.getType()) + .map(TypeCoercer::getFromType) + .orElse(baseColumn.getType())); // Add predicates on top-level and nested columns for (Map.Entry, Domain> columnDomain : columnDomains.entrySet()) { - OrcColumn nestedColumn = getNestedColumn(orcColumn, columnDomain.getKey()); + OrcColumn nestedColumn = getNestedColumn(orcBaseColumn, columnDomain.getKey()); if (nestedColumn != null) { predicateBuilder.addColumn(nestedColumn.getColumnId(), columnDomain.getValue()); } } } + + OrcColumn orcBaseColumn = fileReadColumns.get(ordinal); + if (column.isBaseColumn()) { + Optional> coercer = createCoercer(orcBaseColumn.getColumnType(), orcBaseColumn.getNestedColumns(), column.getType()); + transforms.column(ordinal, coercer.map(identity())); + } else { + // Dereference the nested column by name + OrcColumn orcFieldColumn = orcBaseColumn; + for (String fieldName : column.getHiveColumnProjectionInfo().orElseThrow().getDereferenceNames()) { + Optional nestedField = orcFieldColumn.getNestedColumns().stream() + .filter(field -> field.getColumnName().equalsIgnoreCase(fieldName)) + .collect(toOptional()); + if (nestedField.isEmpty()) { + orcFieldColumn = null; + break; + } + orcFieldColumn = nestedField.get(); + } + + if (orcFieldColumn == null) { + transforms.constantValue(column.getType().createNullBlock()); + } + else { + Optional> coercer = createCoercer(orcFieldColumn.getColumnType(), orcFieldColumn.getNestedColumns(), column.getType()); + transforms.dereferenceField( + ImmutableList.builder() + .add(ordinal) + .addAll(getProjection(column, baseColumn)) + .build(), + coercer.map(identity())); + } + } + } + + Optional originalFileRowId = acidInfo + .filter(OrcPageSourceFactory::hasOriginalFiles) + // TODO reduce number of file footer accesses. Currently this is quadratic to the number of original files. 
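+ // The preceding row count becomes the starting row id for this original file's synthesized rowIds, keeping them unique across the original files (see MergedRowAdaptationWithOriginalFiles below)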
+ .map(_ -> OriginalFilesUtils.getPrecedingRowCount( + acidInfo.get().getOriginalFiles(), + path, + fileSystemFactory, + session.getIdentity(), + options, + stats)); + + boolean appendRowNumberColumn = false; + if (transaction.isMerge()) { + if (originalFile) { + transforms.transform(new MergedRowAdaptationWithOriginalFiles(originalFileRowId.orElse(0L), bucketNumber.orElse(0))); + appendRowNumberColumn = true; + } + else { + transforms.transform(new MergedRowPageFunction()); } } @@ -387,6 +429,7 @@ else if (column.getBaseHiveColumnIndex() < fileColumns.size()) { fileReadColumns, fileReadTypes, fileReadLayouts, + appendRowNumberColumn, predicateBuilder.build(), start, length, @@ -406,37 +449,15 @@ else if (column.getBaseHiveColumnIndex() < fileColumns.size()) { bucketNumber, memoryUsage)); - Optional originalFileRowId = acidInfo - .filter(OrcPageSourceFactory::hasOriginalFiles) - // TODO reduce number of file footer accesses. Currently this is quadratic to the number of original files. - .map(info -> OriginalFilesUtils.getPrecedingRowCount( - acidInfo.get().getOriginalFiles(), - path, - fileSystemFactory, - session.getIdentity(), - options, - stats)); - - if (transaction.isMerge()) { - if (originalFile) { - int bucket = bucketNumber.orElse(0); - long startingRowId = originalFileRowId.orElse(0L); - columnAdaptations.add(OrcPageSource.ColumnAdaptation.mergedRowColumnsWithOriginalFiles(startingRowId, bucket)); - } - else { - columnAdaptations.add(mergedRowColumns()); - } - } - - return new OrcPageSource( + ConnectorPageSource pageSource = new OrcPageSource( recordReader, - columnAdaptations, orcDataSource, deletedRows, originalFileRowId, memoryUsage, stats, reader.getCompressionKind()); + return transforms.build(pageSource); } catch (Exception e) { try { @@ -456,21 +477,21 @@ else if (column.getBaseHiveColumnIndex() < fileColumns.size()) { private static void validateOrcAcidVersion(Location path, OrcReader reader) { - // Trino cannot read ORC ACID tables with version < 2 (written by Hive older than 3.0) + // Trino cannot read ORC ACID tables with a version < 2 (written by Hive older than 3.0) // See https://github.com/trinodb/trino/issues/2790#issuecomment-591901728 for more context - // If we did not manage to validate if ORC ACID version used by table is supported one base don _orc_acid_version metadata file + // If we could not validate whether the ORC ACID version used by the table is supported based on the _orc_acid_version metadata file, // we check the data file footer. if (reader.getFooter().getNumberOfRows() == 0) { - // file is empty. assuming we are good. We do not want to depend on metadata in such case + // If the file is empty, assume the version is good. We do not want to depend on metadata in such a case // as some hadoop distributions do not write ORC ACID metadata for empty ORC files return; } int writerId = reader.getFooter().getWriterId().orElseThrow(() -> new TrinoException(HIVE_BAD_DATA, "writerId not set in ORC metadata in " + path)); if (writerId == TRINO_WRITER_ID || writerId == PRESTO_WRITER_ID) { - // file written by Trino. We are good. + // The file was written by Trino, so it is ok return; } @@ -505,21 +526,22 @@ private static Optional getHiveAcidVersion(OrcReader reader) * top-level columns, not nested columns. * * @param fileColumns All OrcColumns nested in the root column of the table. - * @param desiredColumns HiveColumnHandles for the metastore's table columns. 
+ * @param columns Columns from the Hive metastore that are being read + * @return Return the fileColumns list with any OrcColumn corresponding to a desiredColumn renamed if + * the names differ from those specified in the desiredColumns. + */ - private static List ensureColumnNameConsistency(List fileColumns, List desiredColumns) + private static List ensureColumnNameConsistency(List fileColumns, List columns) { int columnCount = fileColumns.size(); ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(columnCount); - Map desiredColumnsByNumber = desiredColumns.stream() + Map baseColumnsByColumnIndex = columns.stream() + .map(HiveColumnHandle::getBaseColumn) .collect(toImmutableMap(HiveColumnHandle::getBaseHiveColumnIndex, identity())); for (int index = 0; index < columnCount; index++) { OrcColumn column = fileColumns.get(index); - HiveColumnHandle handle = desiredColumnsByNumber.get(index); + HiveColumnHandle handle = baseColumnsByColumnIndex.get(index); if (handle != null && !column.getColumnName().equals(handle.getName())) { column = new OrcColumn(column.getPath(), column.getColumnId(), handle.getName(), column.getColumnType(), column.getOrcDataSourceId(), column.getNestedColumns(), column.getAttributes()); } @@ -542,6 +564,66 @@ private static void verifyFileHasColumnNames(List columns, Location p } } + /* + * The rowId contains the ACID columns - originalTransaction, rowId, bucket + */ + private static final class MergedRowPageFunction + implements Function + { + @Override + public Block apply(SourcePage page) + { + return fromFieldBlocks( + page.getPositionCount(), + new Block[] { + page.getBlock(ORIGINAL_TRANSACTION_CHANNEL), + page.getBlock(BUCKET_CHANNEL), + page.getBlock(ROW_ID_CHANNEL) + }); + } + } + + /** + * The rowId contains the ACID columns - originalTransaction, rowId, bucket, + * derived from the original file. The transactionId is always zero, + * and the rowIds count up from the startingRowId. 
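+ * For example, the row at file position p is assigned the rowId startingRowId + p.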
+ */ + private static final class MergedRowAdaptationWithOriginalFiles + implements Function + { + private final long startingRowId; + private final Block bucketBlock; + + public MergedRowAdaptationWithOriginalFiles(long startingRowId, int bucketId) + { + this.startingRowId = startingRowId; + this.bucketBlock = nativeValueToBlock(INTEGER, (long) computeBucketValue(bucketId, 0)); + } + + @Override + public Block apply(SourcePage sourcePage) + { + int positionCount = sourcePage.getPositionCount(); + + LongArrayBlock rowNumberBlock = (LongArrayBlock) sourcePage.getBlock(sourcePage.getChannelCount() - 1); + if (startingRowId != 0) { + long[] newRowNumbers = new long[rowNumberBlock.getPositionCount()]; + for (int index = 0; index < rowNumberBlock.getPositionCount(); index++) { + newRowNumbers[index] = startingRowId + rowNumberBlock.getLong(index); + } + rowNumberBlock = new LongArrayBlock(rowNumberBlock.getPositionCount(), Optional.empty(), newRowNumbers); + } + + return fromFieldBlocks( + positionCount, + new Block[] { + RunLengthEncodedBlock.create(ORIGINAL_FILE_TRANSACTION_ID_BLOCK, positionCount), + RunLengthEncodedBlock.create(bucketBlock, positionCount), + rowNumberBlock + }); + } + } + static void verifyAcidSchema(OrcReader orcReader, Location path) { OrcColumn rootColumn = orcReader.getRootColumn(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java index 8a3510f39238..c6dc63d47cc6 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSource.java @@ -13,30 +13,20 @@ */ package io.trino.plugin.hive.parquet; -import com.google.common.collect.ImmutableList; import io.trino.parquet.Column; import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.ParquetDataSourceId; import io.trino.parquet.reader.ParquetReader; -import io.trino.plugin.hive.coercions.TypeCoercer; -import io.trino.spi.Page; import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; -import io.trino.spi.block.LongArrayBlock; -import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.metrics.Metrics; -import io.trino.spi.type.Type; import java.io.IOException; import java.io.UncheckedIOException; import java.util.List; -import java.util.Optional; import java.util.OptionalLong; -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; import static io.trino.plugin.base.util.Closables.closeAllSuppress; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; @@ -47,19 +37,13 @@ public class ParquetPageSource implements ConnectorPageSource { private final ParquetReader parquetReader; - private final List columnAdaptations; - private final boolean isColumnAdaptationRequired; private boolean closed; private long completedPositions; - private ParquetPageSource( - ParquetReader parquetReader, - List columnAdaptations) + public ParquetPageSource(ParquetReader parquetReader) { this.parquetReader = requireNonNull(parquetReader, "parquetReader is null"); - this.columnAdaptations = ImmutableList.copyOf(requireNonNull(columnAdaptations, "columnAdaptations is null")); - 
this.isColumnAdaptationRequired = isColumnAdaptationRequired(columnAdaptations); } public List getColumnFields() @@ -98,11 +82,11 @@ public long getMemoryUsage() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { - Page page; + SourcePage page; try { - page = getColumnAdaptationsPage(parquetReader.nextPage()); + page = parquetReader.nextPage(); } catch (IOException | RuntimeException e) { closeAllSuppress(e, this); @@ -140,70 +124,6 @@ public Metrics getMetrics() return parquetReader.getMetrics(); } - public static Builder builder() - { - return new Builder(); - } - - public static class Builder - { - private final ImmutableList.Builder columns = ImmutableList.builder(); - - private Builder() {} - - public Builder addConstantColumn(Block value) - { - columns.add(new ConstantColumn(value)); - return this; - } - - public Builder addSourceColumn(int sourceChannel) - { - columns.add(new SourceColumn(sourceChannel)); - return this; - } - - public Builder addNullColumn(Type type) - { - columns.add(new NullColumn(type)); - return this; - } - - public Builder addRowIndexColumn() - { - columns.add(new RowIndexColumn()); - return this; - } - - public Builder addCoercedColumn(int sourceChannel, TypeCoercer typeCoercer) - { - columns.add(new CoercedColumn(new SourceColumn(sourceChannel), typeCoercer)); - return this; - } - - public ConnectorPageSource build(ParquetReader parquetReader) - { - return new ParquetPageSource(parquetReader, this.columns.build()); - } - } - - private Page getColumnAdaptationsPage(Page page) - { - if (!isColumnAdaptationRequired) { - return page; - } - if (page == null) { - return null; - } - int batchSize = page.getPositionCount(); - Block[] blocks = new Block[columnAdaptations.size()]; - long startRowId = parquetReader.lastBatchStartRow(); - for (int columnChannel = 0; columnChannel < columnAdaptations.size(); columnChannel++) { - blocks[columnChannel] = columnAdaptations.get(columnChannel).getBlock(page, startRowId); - } - return new Page(batchSize, blocks); - } - static TrinoException handleException(ParquetDataSourceId dataSourceId, Exception exception) { if (exception instanceof TrinoException) { @@ -214,134 +134,4 @@ static TrinoException handleException(ParquetDataSourceId dataSourceId, Exceptio } return new TrinoException(HIVE_CURSOR_ERROR, format("Failed to read Parquet file: %s", dataSourceId), exception); } - - private static boolean isColumnAdaptationRequired(List columnAdaptations) - { - // If no synthetic columns are added and the source columns are in order, no adaptations are required - for (int columnChannel = 0; columnChannel < columnAdaptations.size(); columnChannel++) { - ColumnAdaptation column = columnAdaptations.get(columnChannel); - if (column instanceof SourceColumn) { - int delegateChannel = ((SourceColumn) column).getSourceChannel(); - if (columnChannel != delegateChannel) { - return true; - } - } - else { - return true; - } - } - return false; - } - - private interface ColumnAdaptation - { - Block getBlock(Page sourcePage, long startRowId); - } - - private static class NullColumn - implements ColumnAdaptation - { - private final Block nullBlock; - - private NullColumn(Type type) - { - this.nullBlock = type.createNullBlock(); - } - - @Override - public Block getBlock(Page sourcePage, long startRowId) - { - return RunLengthEncodedBlock.create(nullBlock, sourcePage.getPositionCount()); - } - } - - private static class SourceColumn - implements ColumnAdaptation - { - private final int sourceChannel; - - private 
SourceColumn(int sourceChannel) - { - checkArgument(sourceChannel >= 0, "sourceChannel is negative"); - this.sourceChannel = sourceChannel; - } - - @Override - public Block getBlock(Page sourcePage, long startRowId) - { - return sourcePage.getBlock(sourceChannel); - } - - public int getSourceChannel() - { - return sourceChannel; - } - } - - private static class ConstantColumn - implements ColumnAdaptation - { - private final Block singleValueBlock; - - private ConstantColumn(Block singleValueBlock) - { - checkArgument(singleValueBlock.getPositionCount() == 1, "ConstantColumnAdaptation singleValueBlock may only contain one position"); - this.singleValueBlock = singleValueBlock; - } - - @Override - public Block getBlock(Page sourcePage, long startRowId) - { - return RunLengthEncodedBlock.create(singleValueBlock, sourcePage.getPositionCount()); - } - } - - private static class RowIndexColumn - implements ColumnAdaptation - { - @Override - public Block getBlock(Page sourcePage, long startRowId) - { - return createRowNumberBlock(startRowId, sourcePage.getPositionCount()); - } - } - - private static class CoercedColumn - implements ParquetPageSource.ColumnAdaptation - { - private final ParquetPageSource.SourceColumn sourceColumn; - private final TypeCoercer typeCoercer; - - public CoercedColumn(ParquetPageSource.SourceColumn sourceColumn, TypeCoercer typeCoercer) - { - this.sourceColumn = requireNonNull(sourceColumn, "sourceColumn is null"); - this.typeCoercer = requireNonNull(typeCoercer, "typeCoercer is null"); - } - - @Override - public Block getBlock(Page sourcePage, long startRowId) - { - Block block = sourceColumn.getBlock(sourcePage, startRowId); - return new LazyBlock(block.getPositionCount(), () -> typeCoercer.apply(block.getLoadedBlock())); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("sourceColumn", sourceColumn) - .add("fromType", typeCoercer.getFromType()) - .add("toType", typeCoercer.getToType()) - .toString(); - } - } - - private static Block createRowNumberBlock(long baseIndex, int size) - { - long[] rowIndices = new long[size]; - for (int position = 0; position < size; position++) { - rowIndices[position] = baseIndex + position; - } - return new LongArrayBlock(size, Optional.empty(), rowIndices); - } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index f364a9e18739..a19664cdba59 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -13,6 +13,9 @@ */ package io.trino.plugin.hive.parquet; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.BiMap; +import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -42,14 +45,15 @@ import io.trino.plugin.hive.HiveColumnProjectionInfo; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HivePageSourceFactory; -import io.trino.plugin.hive.ReaderColumns; -import io.trino.plugin.hive.ReaderPageSource; import io.trino.plugin.hive.Schema; +import io.trino.plugin.hive.TransformConnectorPageSource; import io.trino.plugin.hive.acid.AcidTransaction; import io.trino.plugin.hive.coercions.TypeCoercer; import 
io.trino.spi.TrinoException; +import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import org.apache.parquet.column.ColumnDescriptor; @@ -63,6 +67,9 @@ import java.io.IOException; import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -70,8 +77,10 @@ import java.util.OptionalInt; import java.util.OptionalLong; import java.util.Set; +import java.util.function.Function; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.hive.formats.HiveClassNames.PARQUET_HIVE_SERDE_CLASS; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; @@ -86,8 +95,7 @@ import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; -import static io.trino.plugin.hive.HivePageSourceProvider.projectSufficientColumns; +import static io.trino.plugin.hive.HivePageSourceProvider.getProjection; import static io.trino.plugin.hive.HiveSessionProperties.getParquetMaxReadBlockRowCount; import static io.trino.plugin.hive.HiveSessionProperties.getParquetMaxReadBlockSize; import static io.trino.plugin.hive.HiveSessionProperties.getParquetSmallFileThreshold; @@ -101,7 +109,6 @@ import static io.trino.spi.type.BigintType.BIGINT; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toUnmodifiableList; public class ParquetPageSourceFactory implements HivePageSourceFactory @@ -151,7 +158,7 @@ public static boolean stripUnnecessaryProperties(String serializationLibraryName } @Override - public Optional createPageSource( + public Optional createPageSource( ConnectorSession session, Location path, long start, @@ -199,7 +206,7 @@ public Optional createPageSource( /** * This method is available for other callers to use directly. */ - public static ReaderPageSource createPageSource( + public static ConnectorPageSource createPageSource( TrinoInputFile inputFile, long start, long length, @@ -261,18 +268,12 @@ public static ReaderPageSource createPageSource( domainCompactionThreshold, options); - Optional readerProjections = projectBaseColumns(columns, useColumnNames); - List baseColumns = readerProjections.map(projection -> - projection.get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toUnmodifiableList())) - .orElse(columns); - ParquetDataSourceId dataSourceId = dataSource.getId(); ParquetDataSource finalDataSource = dataSource; - ParquetReaderProvider parquetReaderProvider = fields -> new ParquetReader( + ParquetReaderProvider parquetReaderProvider = (fields, appendRowNumberColumn) -> new ParquetReader( Optional.ofNullable(fileMetaData.getCreatedBy()), fields, + appendRowNumberColumn, rowGroups, finalDataSource, timeZone, @@ -283,8 +284,7 @@ public static ReaderPageSource createPageSource( // are not present in the Parquet files which are read with disjunct predicates. 
parquetPredicates.size() == 1 ? Optional.of(parquetPredicates.get(0)) : Optional.empty(), parquetWriteValidation); - ConnectorPageSource parquetPageSource = createParquetPageSource(baseColumns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); - return new ReaderPageSource(parquetPageSource, readerProjections); + return createParquetPageSource(columns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); } catch (Exception e) { try { @@ -321,11 +321,7 @@ public static ParquetDataSource createDataSource( public static Optional getParquetMessageType(List columns, boolean useColumnNames, MessageType fileSchema) { - Optional message = projectSufficientColumns(columns) - .map(projection -> projection.get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toUnmodifiableList())) - .orElse(columns).stream() + Optional message = projectSufficientColumns(columns).stream() .filter(column -> column.getColumnType() == REGULAR) .map(column -> getColumnType(column, fileSchema, useColumnNames)) .filter(Optional::isPresent) @@ -335,6 +331,58 @@ public static Optional getParquetMessageType(List return message; } + /** + * Creates a minimal set of columns sufficient to read all of the input projected columns. For example, + * if the input columns include "a.b" and "a.b.c", then both will be projected from the single column "a.b". + */ + @VisibleForTesting + static List projectSufficientColumns(List columns) + { + requireNonNull(columns, "columns is null"); + + if (columns.stream().allMatch(HiveColumnHandle::isBaseColumn)) { + return columns; + } + + ImmutableBiMap.Builder dereferenceChainsBuilder = ImmutableBiMap.builder(); + + for (HiveColumnHandle column : columns) { + List indices = column.getHiveColumnProjectionInfo() + .map(HiveColumnProjectionInfo::getDereferenceIndices) + .orElse(ImmutableList.of()); + + DereferenceChain dereferenceChain = new DereferenceChain(column.getBaseColumnName(), indices); + dereferenceChainsBuilder.put(dereferenceChain, column); + } + + BiMap dereferenceChains = dereferenceChainsBuilder.build(); + + List sufficientColumns = new ArrayList<>(); + Set chosenColumns = new HashSet<>(); + + // Pick a covering column for every column + for (HiveColumnHandle columnHandle : columns) { + DereferenceChain column = requireNonNull(dereferenceChains.inverse().get(columnHandle)); + List orderedPrefixes = column.getOrderedPrefixes(); + + // Shortest existing prefix is chosen as the input. 
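+ // e.g., when "a.b" and "a.b.c" are both requested, the shortest requested prefix for "a.b.c" is "a.b", so "a.b" covers both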
+ DereferenceChain chosenColumn = null; + for (DereferenceChain prefix : orderedPrefixes) { + if (dereferenceChains.containsKey(prefix)) { + chosenColumn = prefix; + break; + } + } + checkState(chosenColumn != null, "chosenColumn is null"); + + if (chosenColumns.add(chosenColumn)) { + sufficientColumns.add(dereferenceChains.get(chosenColumn)); + } + } + + return sufficientColumns; + } + public static Optional getColumnType(HiveColumnHandle column, MessageType messageType, boolean useParquetColumnNames) { Optional baseColumnType = getBaseColumnParquetType(column, messageType, useParquetColumnNames); @@ -412,59 +460,73 @@ public static TupleDomain getParquetTupleDomain( public interface ParquetReaderProvider { - ParquetReader createParquetReader(List fields) + ParquetReader createParquetReader(List fields, boolean appendRowNumberColumn) throws IOException; } public static ConnectorPageSource createParquetPageSource( - List baseColumns, + List columnHandles, MessageType fileSchema, MessageColumnIO messageColumn, boolean useColumnNames, ParquetReaderProvider parquetReaderProvider) throws IOException { - ParquetPageSource.Builder pageSourceBuilder = ParquetPageSource.builder(); - ImmutableList.Builder parquetColumnFieldsBuilder = ImmutableList.builder(); - int sourceChannel = 0; - for (HiveColumnHandle column : baseColumns) { + List parquetColumnFieldsBuilder = new ArrayList<>(columnHandles.size()); + Map baseColumnIdToOrdinal = new HashMap<>(); + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + boolean appendRowNumberColumn = false; + for (HiveColumnHandle column : columnHandles) { if (column == PARQUET_ROW_INDEX_COLUMN) { - pageSourceBuilder.addRowIndexColumn(); + appendRowNumberColumn = true; + transforms.transform(new GetRowPositionFromSource()); continue; } - checkArgument(column.getColumnType() == REGULAR, "column type must be REGULAR: %s", column); - Optional parquetType = getBaseColumnParquetType(column, fileSchema, useColumnNames); + + HiveColumnHandle baseColumn = column.getBaseColumn(); + Optional parquetType = getBaseColumnParquetType(baseColumn, fileSchema, useColumnNames); if (parquetType.isEmpty()) { - pageSourceBuilder.addNullColumn(column.getBaseType()); + transforms.constantValue(column.getBaseType().createNullBlock()); continue; } - String columnName = useColumnNames ? column.getBaseColumnName() : fileSchema.getFields().get(column.getBaseHiveColumnIndex()).getName(); + String baseColumnName = useColumnNames ? 
baseColumn.getBaseColumnName() : fileSchema.getFields().get(baseColumn.getBaseHiveColumnIndex()).getName(); Optional> coercer = Optional.empty(); - ColumnIO columnIO = lookupColumnByName(messageColumn, columnName); - if (columnIO != null && columnIO.getType().isPrimitive()) { - PrimitiveType primitiveType = columnIO.getType().asPrimitiveType(); - coercer = createCoercer(primitiveType.getPrimitiveTypeName(), primitiveType.getLogicalTypeAnnotation(), column.getBaseType()); - } + Integer ordinal = baseColumnIdToOrdinal.get(baseColumnName); + if (ordinal == null) { + ColumnIO columnIO = lookupColumnByName(messageColumn, baseColumnName); + if (columnIO != null && columnIO.getType().isPrimitive()) { + PrimitiveType primitiveType = columnIO.getType().asPrimitiveType(); + coercer = createCoercer(primitiveType.getPrimitiveTypeName(), primitiveType.getLogicalTypeAnnotation(), baseColumn.getBaseType()); + } + io.trino.spi.type.Type readType = coercer.map(TypeCoercer::getFromType).orElseGet(baseColumn::getBaseType); - io.trino.spi.type.Type readType = coercer.map(TypeCoercer::getFromType).orElseGet(column::getBaseType); + Optional field = constructField(readType, columnIO); + if (field.isEmpty()) { + transforms.constantValue(column.getType().createNullBlock()); + continue; + } - Optional field = constructField(readType, columnIO); - if (field.isEmpty()) { - pageSourceBuilder.addNullColumn(readType); - continue; + ordinal = parquetColumnFieldsBuilder.size(); + parquetColumnFieldsBuilder.add(new Column(baseColumnName, field.get())); + baseColumnIdToOrdinal.put(baseColumnName, ordinal); } - parquetColumnFieldsBuilder.add(new Column(columnName, field.get())); - if (coercer.isPresent()) { - pageSourceBuilder.addCoercedColumn(sourceChannel, coercer.get()); + + if (column.isBaseColumn()) { + transforms.column(ordinal, coercer.map(Function.identity())); } else { - pageSourceBuilder.addSourceColumn(sourceChannel); + transforms.dereferenceField( + ImmutableList.builder() + .add(ordinal) + .addAll(getProjection(column, baseColumn)) + .build(), + coercer.map(Function.identity())); } - sourceChannel++; } - - return pageSourceBuilder.build(parquetReaderProvider.createParquetReader(parquetColumnFieldsBuilder.build())); + ParquetReader parquetReader = parquetReaderProvider.createParquetReader(parquetColumnFieldsBuilder, appendRowNumberColumn); + ConnectorPageSource pageSource = new ParquetPageSource(parquetReader); + return transforms.build(pageSource); } private static Optional getBaseColumnParquetType(HiveColumnHandle column, MessageType messageType, boolean useParquetColumnNames) @@ -506,4 +568,37 @@ private static Optional> dereferenceSubFiel return Optional.of(typeBuilder.build()); } + + private record GetRowPositionFromSource() + implements Function + { + @Override + public Block apply(SourcePage page) + { + return page.getBlock(page.getChannelCount() - 1); + } + } + + private record DereferenceChain(String name, List indices) + { + private DereferenceChain + { + requireNonNull(name, "name is null"); + indices = ImmutableList.copyOf(requireNonNull(indices, "indices is null")); + } + + /** + * Get Prefixes of this Dereference chain in increasing order of lengths + */ + public List getOrderedPrefixes() + { + ImmutableList.Builder prefixes = ImmutableList.builder(); + + for (int prefixLen = 0; prefixLen <= indices.size(); prefixLen++) { + prefixes.add(new DereferenceChain(name, indices.subList(0, prefixLen))); + } + + return prefixes.build(); + } + } } diff --git 
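createParquetPageSource above now describes every output channel as one of three transforms over the reader's base channels: pass a base channel through (optionally coerced), dereference a field path rooted at a base channel, or emit a constant null block for columns missing from the file; the synthetic row index, when requested, is simply the reader's trailing channel (GetRowPositionFromSource). The sketch below is a stdlib-only model of that mapping — ChannelPlan, Request and every other name in it are hypothetical, not Trino types — showing how distinct projections of one base column share a single reader ordinal:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class ChannelPlan
{
    sealed interface Transform permits Constant, Column, Dereference {}

    record Constant() implements Transform {}                        // null block

    record Column(int ordinal) implements Transform {}               // base channel

    record Dereference(int ordinal, List<Integer> path) implements Transform {}

    public static void main(String[] args)
    {
        // Requested columns: "a.b.c", "a.b", and a column absent from the file
        record Request(String base, List<Integer> path, boolean present) {}
        List<Request> requests = List.of(
                new Request("a", List.of(0, 1), true),
                new Request("a", List.of(0), true),
                new Request("missing_col", List.of(), false));

        List<String> readerColumns = new ArrayList<>();
        Map<String, Integer> ordinals = new HashMap<>();
        List<Transform> outputs = new ArrayList<>();

        for (Request request : requests) {
            if (!request.present()) {
                outputs.add(new Constant());
                continue;
            }
            // distinct projections of one base column share a reader ordinal
            Integer ordinal = ordinals.get(request.base());
            if (ordinal == null) {
                ordinal = readerColumns.size();
                readerColumns.add(request.base());
                ordinals.put(request.base(), ordinal);
            }
            outputs.add(request.path().isEmpty()
                    ? new Column(ordinal)
                    : new Dereference(ordinal, request.path()));
        }

        System.out.println("reader columns: " + readerColumns);
        // reader columns: [a]
        System.out.println("outputs: " + outputs);
        // outputs: [Dereference[ordinal=0, path=[0, 1]], Dereference[ordinal=0, path=[0]], Constant[]]
    }
}

Deduplicating on the base column name means the file is decoded at most once per base column no matter how many dereferences of it the query requests; the ReaderPageSource/ReaderColumns pair existed to carry that same mapping around separately, which is why it can be deleted.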
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSource.java index 2d2d45774ac7..874deac23563 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSource.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSource.java @@ -14,6 +14,7 @@ package io.trino.plugin.hive.rcfile; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.CheckReturnValue; import io.airlift.units.DataSize; import io.trino.hive.formats.FileCorruptionException; import io.trino.hive.formats.rcfile.RcFileReader; @@ -21,14 +22,16 @@ import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; -import io.trino.spi.block.LazyBlock; -import io.trino.spi.block.LazyBlockLoader; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; +import jakarta.annotation.Nullable; import java.io.IOException; +import java.util.Arrays; import java.util.List; +import java.util.function.ObjLongConsumer; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; @@ -36,6 +39,7 @@ import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; import static java.lang.String.format; +import static java.util.Objects.checkIndex; import static java.util.Objects.requireNonNull; public class RcFilePageSource @@ -106,7 +110,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { try { // advance in the current batch @@ -119,17 +123,7 @@ public Page getNextPage() return null; } - Block[] blocks = new Block[hiveColumnIndexes.length]; - for (int fieldId = 0; fieldId < blocks.length; fieldId++) { - if (constantBlocks[fieldId] != null) { - blocks[fieldId] = RunLengthEncodedBlock.create(constantBlocks[fieldId], currentPageSize); - } - else { - blocks[fieldId] = createBlock(currentPageSize, fieldId); - } - } - - return new Page(currentPageSize, blocks); + return new RcFileSourcePage(currentPageSize); } catch (TrinoException e) { closeAllSuppress(e, this); @@ -173,46 +167,136 @@ public long getMemoryUsage() return GUESSED_MEMORY_USAGE; } - private Block createBlock(int currentPageSize, int fieldId) + private final class RcFileSourcePage + implements SourcePage { - int hiveColumnIndex = hiveColumnIndexes[fieldId]; + private final int expectedBatchId = pageId; + private final Block[] blocks = new Block[hiveColumnIndexes.length]; + private SelectedPositions selectedPositions; - return new LazyBlock( - currentPageSize, - new RcFileBlockLoader(hiveColumnIndex)); - } + private long sizeInBytes; + private long retainedSizeInBytes; - private final class RcFileBlockLoader - implements LazyBlockLoader - { - private final int expectedBatchId = pageId; - private final int columnIndex; - private boolean loaded; + public RcFileSourcePage(int positionCount) + { + selectedPositions = new SelectedPositions(positionCount, null); + } + + @Override + public int getPositionCount() + { + return selectedPositions.positionCount(); + } + + @Override + public long getSizeInBytes() + { + return sizeInBytes; + } + + @Override + public long getRetainedSizeInBytes() + { + return retainedSizeInBytes; + } - public RcFileBlockLoader(int columnIndex) + 
@Override + public void retainedBytesForEachPart(ObjLongConsumer consumer) { - this.columnIndex = columnIndex; + for (Block block : blocks) { + if (block != null) { + block.retainedBytesForEachPart(consumer); + } + } } @Override - public Block load() + public int getChannelCount() + { + return blocks.length; + } + + @Override + public Block getBlock(int channel) { - checkState(!loaded, "Already loaded"); checkState(pageId == expectedBatchId); + Block block = blocks[channel]; + if (block == null) { + if (constantBlocks[channel] != null) { + block = RunLengthEncodedBlock.create(constantBlocks[channel], selectedPositions.positionCount()); + } + else { + try { + // todo use selected positions to improve read performance + block = rcFileReader.readBlock(hiveColumnIndexes[channel]); + } + catch (FileCorruptionException e) { + throw new TrinoException(HIVE_BAD_DATA, format("Corrupted RC file: %s", rcFileReader.getFileLocation()), e); + } + catch (IOException | RuntimeException e) { + throw new TrinoException(HIVE_CURSOR_ERROR, format("Failed to read RC file: %s", rcFileReader.getFileLocation()), e); + } + block = selectedPositions.apply(block); + } + blocks[channel] = block; + sizeInBytes += block.getSizeInBytes(); + retainedSizeInBytes += block.getRetainedSizeInBytes(); + } + return block; + } - Block block; - try { - block = rcFileReader.readBlock(columnIndex); + @Override + public Page getPage() + { + // ensure all blocks are loaded + for (int i = 0; i < blocks.length; i++) { + getBlock(i); } - catch (FileCorruptionException e) { - throw new TrinoException(HIVE_BAD_DATA, format("Corrupted RC file: %s", rcFileReader.getFileLocation()), e); + return new Page(selectedPositions.positionCount(), blocks); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + selectedPositions = selectedPositions.selectPositions(positions, offset, size); + retainedSizeInBytes = 0; + for (int i = 0; i < blocks.length; i++) { + Block block = blocks[i]; + if (block != null) { + block = selectedPositions.apply(block); + retainedSizeInBytes += block.getRetainedSizeInBytes(); + blocks[i] = block; + } } - catch (IOException | RuntimeException e) { - throw new TrinoException(HIVE_CURSOR_ERROR, format("Failed to read RC file: %s", rcFileReader.getFileLocation()), e); + } + } + + private record SelectedPositions(int positionCount, @Nullable int[] positions) + { + @CheckReturnValue + public Block apply(Block block) + { + if (positions == null) { + return block; } + return block.getPositions(positions, 0, positionCount); + } - loaded = true; - return block; + @CheckReturnValue + public SelectedPositions selectPositions(int[] positions, int offset, int size) + { + if (this.positions == null) { + for (int i = 0; i < size; i++) { + checkIndex(offset + i, positionCount); + } + return new SelectedPositions(size, Arrays.copyOfRange(positions, offset, offset + size)); + } + + int[] newPositions = new int[size]; + for (int i = 0; i < size; i++) { + newPositions[i] = this.positions[positions[offset + i]]; + } + return new SelectedPositions(size, newPositions); } } } diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java index 8acd21a7c425..dac87b44f06e 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/rcfile/RcFilePageSourceFactory.java @@ -33,8 +33,6 @@ 
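SelectedPositions.selectPositions above has to compose selections: once a page has been filtered, any further selection is expressed in the coordinates of the already-filtered page, so it must be routed through the existing position array rather than applied to the raw block. A stdlib-only sketch of the composition rule (ComposeSelections is a local name for this example, not a Trino class):

import java.util.Arrays;

// Demonstrates the composition in SelectedPositions.selectPositions: a second
// selection indexes into the first, so result[i] = first[second[offset + i]].
final class ComposeSelections
{
    static int[] compose(int[] first, int[] second, int offset, int size)
    {
        int[] result = new int[size];
        for (int i = 0; i < size; i++) {
            result[i] = first[second[offset + i]];
        }
        return result;
    }

    public static void main(String[] args)
    {
        int[] first = {2, 5, 7, 9};    // rows kept by an earlier filter
        int[] second = {0, 3};         // now keep the 1st and 4th of those
        System.out.println(Arrays.toString(compose(first, second, 0, second.length)));
        // [2, 9] — positions in the underlying, unfiltered block
    }
}

This is also why selectPositions re-applies the selection to every block that was already materialized and rebuilds retainedSizeInBytes, while unread channels stay untouched until getBlock asks for them.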
import io.trino.plugin.hive.HiveColumnHandle; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HivePageSourceFactory; -import io.trino.plugin.hive.ReaderColumns; -import io.trino.plugin.hive.ReaderPageSource; import io.trino.plugin.hive.Schema; import io.trino.plugin.hive.acid.AcidTransaction; import io.trino.spi.TrinoException; @@ -51,13 +49,11 @@ import java.util.OptionalInt; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.hive.formats.HiveClassNames.COLUMNAR_SERDE_CLASS; import static io.trino.hive.formats.HiveClassNames.LAZY_BINARY_COLUMNAR_SERDE_CLASS; import static io.trino.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.trino.plugin.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; -import static io.trino.plugin.hive.ReaderPageSource.noProjectionAdaptation; +import static io.trino.plugin.hive.HivePageSourceProvider.projectColumnDereferences; import static io.trino.plugin.hive.util.HiveUtil.splitError; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @@ -83,7 +79,7 @@ public static boolean stripUnnecessaryProperties(String serializationLibraryName } @Override - public Optional createPageSource( + public Optional createPageSource( ConnectorSession session, Location path, long start, @@ -112,15 +108,18 @@ else if (serializationLibraryName.equals(COLUMNAR_SERDE_CLASS)) { checkArgument(acidInfo.isEmpty(), "Acid is not supported"); - List projectedReaderColumns = columns; - Optional readerProjections = projectBaseColumns(columns); - - if (readerProjections.isPresent()) { - projectedReaderColumns = readerProjections.get().get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toImmutableList()); - } + return Optional.of(projectColumnDereferences(columns, baseColumns -> createPageSource(session, path, start, length, estimatedFileSize, baseColumns, columnEncodingFactory))); + } + private ConnectorPageSource createPageSource( + ConnectorSession session, + Location path, + long start, + long length, + long estimatedFileSize, + List columns, + ColumnEncodingFactory columnEncodingFactory) + { TrinoFileSystem trinoFileSystem = fileSystemFactory.create(session); TrinoInputFile inputFile = trinoFileSystem.newInputFile(path); try { @@ -144,12 +143,12 @@ else if (serializationLibraryName.equals(COLUMNAR_SERDE_CLASS)) { // Split may be empty now that the correct file size is known if (length <= 0) { - return Optional.of(noProjectionAdaptation(new EmptyPageSource())); + return new EmptyPageSource(); } try { ImmutableMap.Builder readColumns = ImmutableMap.builder(); - for (HiveColumnHandle column : projectedReaderColumns) { + for (HiveColumnHandle column : columns) { readColumns.put(column.getBaseHiveColumnIndex(), column.getType()); } @@ -160,8 +159,7 @@ else if (serializationLibraryName.equals(COLUMNAR_SERDE_CLASS)) { start, length); - ConnectorPageSource pageSource = new RcFilePageSource(rcFileReader, projectedReaderColumns); - return Optional.of(new ReaderPageSource(pageSource, readerProjections)); + return new RcFilePageSource(rcFileReader, columns); } catch (TrinoException e) { throw e; diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/TempFileReader.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/TempFileReader.java index d9bc16a3189f..dec6c5534055 100644 --- 
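RcFilePageSourceFactory now delegates all dereference handling to HivePageSourceProvider.projectColumnDereferences: the factory hands over the requested (possibly projected) columns together with a callback that only ever sees distinct base columns, and gets back a ConnectorPageSource with the dereference transforms already applied. A stdlib-only toy of that control flow — ProjectDereferencesToy and the string-based columns are illustrative, not the real signatures:

import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Function;

final class ProjectDereferencesToy
{
    public static void main(String[] args)
    {
        // Requested columns, some of them projections like "a.b"
        List<String> projected = List.of("a.b", "a.c", "d");

        // Stand-in for the per-format factory passed as the second argument:
        // it only ever sees distinct base columns.
        Function<Set<String>, String> readerFactory = baseColumns -> "reader over " + baseColumns;

        Set<String> baseColumns = new LinkedHashSet<>();
        for (String column : projected) {
            baseColumns.add(column.split("\\.", 2)[0]);  // base column name
        }
        System.out.println(readerFactory.apply(baseColumns));
        // reader over [a, d]
    }
}

The same helper also absorbs the ReaderPageSource.noProjectionAdaptation wrapping that each format factory previously had to do by hand.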
a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/TempFileReader.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/TempFileReader.java @@ -21,6 +21,7 @@ import io.trino.orc.OrcRecordReader; import io.trino.spi.Page; import io.trino.spi.TrinoException; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import java.io.IOException; @@ -48,6 +49,7 @@ public TempFileReader(List types, OrcDataSource dataSource) reader = orcReader.createRecordReader( orcReader.getRootColumn().getNestedColumns(), types, + false, OrcPredicate.TRUE, UTC, newSimpleAggregatedMemoryContext(), @@ -67,13 +69,13 @@ protected Page computeNext() throw new InterruptedIOException(); } - Page page = reader.nextPage(); + SourcePage page = reader.nextPage(); if (page == null) { return endOfData(); } // eagerly load the page - return page.getLoadedPage(); + return page.getPage().getLoadedPage(); } catch (IOException e) { throw handleException(e); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java index af40a28e0d83..ee00c620c4e0 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSink.java @@ -39,6 +39,7 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.connector.SourcePage; import io.trino.spi.security.ConnectorIdentity; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; @@ -249,9 +250,9 @@ private static long writeTestFile(TrinoFileSystemFactory fileSystemFactory, Hive List pages = new ArrayList<>(); try (ConnectorPageSource pageSource = createPageSource(fileSystemFactory, transaction, config, fileEntry.location())) { while (!pageSource.isFinished()) { - Page nextPage = pageSource.getNextPage(); + SourcePage nextPage = pageSource.getNextSourcePage(); if (nextPage != null) { - pages.add(nextPage.getLoadedPage()); + pages.add(nextPage.getPage().getLoadedPage()); } } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSource.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSource.java index 5e9bed4dc879..d52c1f5b7b92 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSource.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSource.java @@ -22,6 +22,7 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import org.junit.jupiter.api.Test; import java.io.IOException; @@ -35,11 +36,11 @@ import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.PARTITION_KEY; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; import static io.trino.plugin.hive.HivePageSourceProvider.createBucketValidator; +import static io.trino.plugin.hive.HivePageSourceProvider.createHivePageSource; import static io.trino.plugin.hive.HiveStorageFormat.PARQUET; import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; import static io.trino.plugin.hive.util.HiveBucketing.BucketingVersion.BUCKETING_V1; import static io.trino.spi.predicate.Utils.nativeValueToBlock; -import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden; import static io.trino.spi.type.BigintType.BIGINT; import static 
io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; @@ -47,12 +48,6 @@ public class TestHivePageSource { - @Test - public void testEverythingImplemented() - { - assertAllMethodsOverridden(ConnectorPageSource.class, HivePageSource.class); - } - @Test public void testGetNextPageSucceedsWhenHiveBucketingEnabled() throws IOException @@ -96,7 +91,7 @@ private void testGetNextPageWhenHiveBucketingEnabled(OptionalInt tableBucketNumb List bucketColumns = columns.stream().filter(c -> c.getName().equals(bucketColumnName)).toList(); Optional bucketValidation = Optional.of(new HiveSplit.BucketValidation(BUCKETING_V1, 8, bucketColumns)); - Optional bucketValidator = createBucketValidator( + Optional bucketValidator = createBucketValidator( Location.of("memory:///test"), bucketValidation, tableBucketNumber, @@ -109,16 +104,15 @@ private void testGetNextPageWhenHiveBucketingEnabled(OptionalInt tableBucketNumb Page page = new Page(1, blocks); try ( - ConnectorPageSource pageSource = new TestScanFilterAndProjectOperator.SinglePagePageSource(page); - HivePageSource hivePageSource = new HivePageSource( + ConnectorPageSource pageSource = new TestScanFilterAndProjectOperator.SinglePagePageSource(SourcePage.create(page)); + ConnectorPageSource hivePageSource = createHivePageSource( columnMappings, Optional.empty(), bucketValidator, - Optional.empty(), TESTING_TYPE_MANAGER, new CoercionUtils.CoercionContext(DEFAULT_PRECISION, PARQUET), pageSource)) { - hivePageSource.getNextPage(); + hivePageSource.getNextSourcePage(); } } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSourceProvider.java similarity index 53% rename from plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java rename to plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSourceProvider.java index 2dd81178b3c8..02c8cb74b7e2 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderProjectionsAdapter.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHivePageSourceProvider.java @@ -14,14 +14,13 @@ package io.trino.plugin.hive; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.LazyBlock; import io.trino.spi.block.RowBlock; import io.trino.spi.block.SqlRow; -import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.FixedPageSource; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -29,142 +28,81 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.block.BlockAssertions.assertBlockEquals; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; +import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; +import static io.trino.plugin.hive.HiveColumnHandle.createBaseColumn; +import static io.trino.plugin.hive.HivePageSourceProvider.projectColumnDereferences; +import static io.trino.plugin.hive.TestHivePageSourceProvider.RowData.rowData; import 
static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.ROWTYPE_OF_ROW_AND_PRIMITIVES; import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.createProjectedColumnHandle; -import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.createTestFullColumns; -import static io.trino.plugin.hive.TestReaderProjectionsAdapter.RowData.rowData; +import static io.trino.plugin.hive.util.HiveTypeTranslator.toHiveType; import static io.trino.spi.type.BigintType.BIGINT; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; -public class TestReaderProjectionsAdapter +class TestHivePageSourceProvider { - private static final String TEST_COLUMN_NAME = "col"; - private static final Type TEST_COLUMN_TYPE = ROWTYPE_OF_ROW_AND_PRIMITIVES; - - private static final Map TEST_FULL_COLUMNS = createTestFullColumns( - ImmutableList.of(TEST_COLUMN_NAME), - ImmutableMap.of(TEST_COLUMN_NAME, TEST_COLUMN_TYPE)); + private static final HiveColumnHandle BASE_COLUMN = createBaseColumn("col", 0, toHiveType(ROWTYPE_OF_ROW_AND_PRIMITIVES), ROWTYPE_OF_ROW_AND_PRIMITIVES, REGULAR, Optional.empty()); @Test - public void testAdaptPage() + void testProjectColumnDereferences() + throws Exception { List columns = ImmutableList.of( - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col"), ImmutableList.of(0, 0)), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col"), ImmutableList.of(0))); - - Optional readerProjections = projectBaseColumns(columns); - - List inputBlockData = new ArrayList<>(); - inputBlockData.add(rowData(rowData(11L, 12L, 13L), 1L)); - inputBlockData.add(rowData(null, 2L)); - inputBlockData.add(null); - inputBlockData.add(rowData(rowData(31L, 32L, 33L), 3L)); - - ReaderProjectionsAdapter adapter = new ReaderProjectionsAdapter( - columns.stream().map(ColumnHandle.class::cast).collect(toImmutableList()), - readerProjections.get(), - column -> ((HiveColumnHandle) column).getType(), - HivePageSourceProvider::getProjection); - verifyPageAdaptation(adapter, ImmutableList.of(inputBlockData)); + createProjectedColumnHandle(BASE_COLUMN, ImmutableList.of(0, 0)), + createProjectedColumnHandle(BASE_COLUMN, ImmutableList.of(0))); + + Page outputPage; + try (ConnectorPageSource connectorPageSource = projectColumnDereferences(columns, TestHivePageSourceProvider::createPageSource)) { + outputPage = connectorPageSource + .getNextSourcePage() + .getPage(); + } + // Verify output block values + Block baseInputBlock = createInputPage().getBlock(0); + for (int i = 0, columnsSize = columns.size(); i < columnsSize; i++) { + HiveColumnHandle column = columns.get(i); + verifyBlock( + outputPage.getBlock(i), + column.getType(), + baseInputBlock, + BASE_COLUMN.getType(), + HivePageSourceProvider.getProjection(column, BASE_COLUMN)); + } } - @Test - public void testLazyDereferenceProjectionLoading() + private static FixedPageSource createPageSource(List columns) { - List columns = ImmutableList.of(createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col"), ImmutableList.of(0, 0))); + assertThat(columns).containsOnly(BASE_COLUMN); + return new FixedPageSource(ImmutableList.of(createInputPage())); + } + private static Page createInputPage() + { List inputBlockData = new ArrayList<>(); inputBlockData.add(rowData(rowData(11L, 12L, 13L), 1L)); inputBlockData.add(rowData(null, 2L)); inputBlockData.add(null); inputBlockData.add(rowData(rowData(31L, 32L, 33L), 3L)); - // Produce an output page by applying adaptation - Optional readerProjections = projectBaseColumns(columns); - 
ReaderProjectionsAdapter adapter = new ReaderProjectionsAdapter( - columns.stream().map(ColumnHandle.class::cast).collect(toImmutableList()), - readerProjections.get(), - column -> ((HiveColumnHandle) column).getType(), - HivePageSourceProvider::getProjection); - Page inputPage = createPage(ImmutableList.of(inputBlockData), adapter.getInputTypes()); - adapter.adaptPage(inputPage).getLoadedPage(); - - // Verify that only the block corresponding to subfield "col.f_row_0.f_bigint_0" should be completely loaded, others are not. - - // Assertion for "col" - Block lazyBlockLevel1 = inputPage.getBlock(0); - assertThat(lazyBlockLevel1 instanceof LazyBlock).isTrue(); - assertThat(lazyBlockLevel1.isLoaded()).isFalse(); - RowBlock rowBlockLevel1 = (RowBlock) ((LazyBlock) lazyBlockLevel1).getBlock(); - assertThat(rowBlockLevel1.isLoaded()).isFalse(); - - // Assertion for "col.f_row_0" and col.f_bigint_0" - assertThat(rowBlockLevel1.getFieldBlock(0).isLoaded()).isFalse(); - assertThat(rowBlockLevel1.getFieldBlock(1).isLoaded()).isFalse(); - - Block lazyBlockLevel2 = rowBlockLevel1.getFieldBlock(0); - assertThat(lazyBlockLevel2 instanceof LazyBlock).isTrue(); - RowBlock rowBlockLevel2 = ((RowBlock) ((LazyBlock) lazyBlockLevel2).getBlock()); - assertThat(rowBlockLevel2.isLoaded()).isFalse(); - // Assertion for "col.f_row_0.f_bigint_0" and "col.f_row_0.f_bigint_1" - assertThat(rowBlockLevel2.getFieldBlock(0).isLoaded()).isTrue(); - assertThat(rowBlockLevel2.getFieldBlock(1).isLoaded()).isFalse(); - } - - private void verifyPageAdaptation(ReaderProjectionsAdapter adapter, List> inputPageData) - { - List columnMapping = adapter.getOutputToInputMapping(); - List outputTypes = adapter.getOutputTypes(); - List inputTypes = adapter.getInputTypes(); - - Page inputPage = createPage(inputPageData, inputTypes); - Page outputPage = adapter.adaptPage(inputPage).getLoadedPage(); - - // Verify output block values - for (int i = 0; i < columnMapping.size(); i++) { - ReaderProjectionsAdapter.ChannelMapping mapping = columnMapping.get(i); - int inputBlockIndex = mapping.getInputChannelIndex(); - verifyBlock( - outputPage.getBlock(i), - outputTypes.get(i), - inputPage.getBlock(inputBlockIndex), - inputTypes.get(inputBlockIndex), - mapping.getDereferenceSequence()); - } - } - - private static Page createPage(List> pageData, List types) - { - Block[] inputPageBlocks = new Block[pageData.size()]; - for (int i = 0; i < inputPageBlocks.length; i++) { - inputPageBlocks[i] = createInputBlock(pageData.get(i), types.get(i)); - } - - return new Page(inputPageBlocks); + return new Page(createInputBlock(inputBlockData, BASE_COLUMN.getType())); } private static Block createInputBlock(List data, Type type) { - int positionCount = data.size(); - if (type instanceof RowType) { - return new LazyBlock(data.size(), () -> createRowBlockWithLazyNestedBlocks(data, (RowType) type)); + return createRowBlock(data, (RowType) type); } if (BIGINT.equals(type)) { - return new LazyBlock(positionCount, () -> createLongArrayBlock(data)); + return createLongArrayBlock(data); } throw new UnsupportedOperationException(); } - private static Block createRowBlockWithLazyNestedBlocks(List data, RowType rowType) + private static Block createRowBlock(List data, RowType rowType) { int positionCount = data.size(); @@ -203,13 +141,12 @@ private static Block createRowBlockWithLazyNestedBlocks(List data, RowTy private static Block createLongArrayBlock(List data) { BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(data.size()); - for (int i = 0; i < 
data.size(); i++) { - Long value = (Long) data.get(i); - if (value == null) { + for (Object datum : data) { + if (datum == null) { builder.appendNull(); } else { - BIGINT.writeLong(builder, value); + BIGINT.writeLong(builder, (Long) datum); } } return builder.build(); @@ -244,7 +181,7 @@ private static Block createProjectedColumnBlock(Block data, Type finalType, RowT // Apply all dereferences except for the last one, because the type can be different for (int j = 0; j < dereferences.size() - 1; j++) { if (isNull) { - // If null element is discovered at any dereferencing step, break + // If a null element is discovered at any dereferencing step, break break; } @@ -261,7 +198,7 @@ private static Block createProjectedColumnBlock(Block data, Type finalType, RowT currentData = sourceType.getObject(fieldBlock, rawIndex); } - isNull = isNull || (currentData == null); + isNull = currentData == null; } if (isNull) { @@ -280,7 +217,7 @@ private static Block createProjectedColumnBlock(Block data, Type finalType, RowT static class RowData { - private final List data; + private final List data; private RowData(Object... data) { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java index 40b2a9b598c5..0872e83afc5c 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNodeLocalDynamicSplitPruning.java @@ -88,7 +88,7 @@ void testDynamicBucketPruning() } try (ConnectorPageSource nonEmptyPageSource = createTestingPageSource(transaction, config, getDynamicFilter(getNonSelectiveBucketTupleDomain()))) { - assertThat(nonEmptyPageSource.getClass()).isEqualTo(HivePageSource.class); + assertThat(nonEmptyPageSource.getClass()).isNotEqualTo(EmptyPageSource.class); } } @@ -104,7 +104,7 @@ void testDynamicPartitionPruning() } try (ConnectorPageSource nonEmptyPageSource = createTestingPageSource(transaction, config, getDynamicFilter(getNonSelectivePartitionTupleDomain()))) { - assertThat(nonEmptyPageSource.getClass()).isEqualTo(HivePageSource.class); + assertThat(nonEmptyPageSource.getClass()).isNotEqualTo(EmptyPageSource.class); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestOrcPageSourceMemoryTracking.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestOrcPageSourceMemoryTracking.java index a0a48f1746a6..09342d166894 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestOrcPageSourceMemoryTracking.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestOrcPageSourceMemoryTracking.java @@ -46,6 +46,7 @@ import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import io.trino.sql.gen.CursorProcessorCompiler; @@ -206,7 +207,7 @@ private void testPageSource(boolean useCache) int totalRows = 0; while (totalRows < 20000) { assertThat(pageSource.isFinished()).isFalse(); - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); assertThat(page).isNotNull(); Block block = page.getBlock(1); @@ -239,7 +240,7 @@ private void testPageSource(boolean useCache) memoryUsage = -1; while (totalRows < 40000) { assertThat(pageSource.isFinished()).isFalse(); - Page page = 
pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); assertThat(page).isNotNull(); Block block = page.getBlock(1); @@ -272,7 +273,7 @@ private void testPageSource(boolean useCache) memoryUsage = -1; while (totalRows < NUM_ROWS) { assertThat(pageSource.isFinished()).isFalse(); - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); assertThat(page).isNotNull(); Block block = page.getBlock(1); @@ -303,7 +304,7 @@ private void testPageSource(boolean useCache) } assertThat(pageSource.isFinished()).isFalse(); - assertThat(pageSource.getNextPage()).isNull(); + assertThat(pageSource.getNextSourcePage()).isNull(); assertThat(pageSource.isFinished()).isTrue(); if (useCache) { // file is fully cached @@ -360,12 +361,13 @@ private void testMaxReadBytes(int rowCount) try { int positionCount = 0; while (true) { - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (pageSource.isFinished()) { break; } assertThat(page).isNotNull(); - page = page.getLoadedPage(); + // load all page data + page.getPage().getLoadedPage(); positionCount += page.getPositionCount(); // assert upper bound is tight // ignore the first MAX_BATCH_SIZE rows given the sizes are set when loading the blocks @@ -375,7 +377,7 @@ private void testMaxReadBytes(int rowCount) assertThat(page.getSizeInBytes() < (long) maxReadBytes * (MAX_BATCH_SIZE / step) || 1 == page.getPositionCount()).isTrue(); } } - + pageSource.close(); // verify the stats are correctly recorded Distribution distribution = stats.getMaxCombinedBytesPerRow().getAllTime(); assertThat((int) distribution.getCount()).isEqualTo(1); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderColumns.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderColumns.java deleted file mode 100644 index c3863d4f557e..000000000000 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestReaderColumns.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.plugin.hive; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.trino.spi.type.Type; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; -import static io.trino.plugin.hive.HivePageSourceProvider.projectSufficientColumns; -import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.ROWTYPE_OF_PRIMITIVES; -import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.ROWTYPE_OF_ROW_AND_PRIMITIVES; -import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.createProjectedColumnHandle; -import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.createTestFullColumns; -import static io.trino.spi.type.BigintType.BIGINT; -import static org.assertj.core.api.Assertions.assertThat; - -public class TestReaderColumns -{ - private static final List TEST_COLUMN_NAMES = ImmutableList.of( - "col_bigint", - "col_struct_of_primitives", - "col_struct_of_non_primitives", - "col_partition_key_1", - "col_partition_key_2"); - - private static final Map TEST_COLUMN_TYPES = ImmutableMap.builder() - .put("col_bigint", BIGINT) - .put("col_struct_of_primitives", ROWTYPE_OF_PRIMITIVES) - .put("col_struct_of_non_primitives", ROWTYPE_OF_ROW_AND_PRIMITIVES) - .put("col_partition_key_1", BIGINT) - .put("col_partition_key_2", BIGINT) - .buildOrThrow(); - - private static final Map TEST_FULL_COLUMNS = createTestFullColumns(TEST_COLUMN_NAMES, TEST_COLUMN_TYPES); - - @Test - public void testNoProjections() - { - List columns = new ArrayList<>(TEST_FULL_COLUMNS.values()); - Optional mapping; - - mapping = projectBaseColumns(columns); - assertThat(mapping.isEmpty()) - .describedAs("Full columns should not require any adaptation") - .isTrue(); - - mapping = projectSufficientColumns(columns); - assertThat(mapping.isEmpty()) - .describedAs("Full columns should not require any adaptation") - .isTrue(); - } - - @Test - public void testBaseColumnsProjection() - { - List columns = ImmutableList.of( - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_primitives"), ImmutableList.of(0)), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_primitives"), ImmutableList.of(1)), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_bigint"), ImmutableList.of()), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_non_primitives"), ImmutableList.of(0, 1)), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_non_primitives"), ImmutableList.of(0))); - - Optional mapping = projectBaseColumns(columns); - assertThat(mapping.isPresent()) - .describedAs("Full columns should be created for corresponding projected columns") - .isTrue(); - - List readerColumns = mapping.get().get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toImmutableList()); - - for (int i = 0; i < columns.size(); i++) { - HiveColumnHandle column = columns.get(i); - int readerIndex = mapping.get().getPositionForColumnAt(i); - HiveColumnHandle readerColumn = (HiveColumnHandle) mapping.get().getForColumnAt(i); - assertThat(column.getBaseColumn()).isEqualTo(readerColumn); - assertThat(readerColumns.get(readerIndex)).isEqualTo(readerColumn); - } - } - - @Test - public void testProjectSufficientColumns() - { - List columns = ImmutableList.of( - 
createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_primitives"), ImmutableList.of(0)), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_primitives"), ImmutableList.of(1)), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_bigint"), ImmutableList.of()), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_non_primitives"), ImmutableList.of(0, 1)), - createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_non_primitives"), ImmutableList.of(0))); - - Optional readerProjections = projectSufficientColumns(columns); - assertThat(readerProjections.isPresent()) - .describedAs("expected readerProjections to be present") - .isTrue(); - - assertThat(readerProjections.get().getForColumnAt(0)).isEqualTo(columns.get(0)); - assertThat(readerProjections.get().getForColumnAt(1)).isEqualTo(columns.get(1)); - assertThat(readerProjections.get().getForColumnAt(2)).isEqualTo(columns.get(2)); - assertThat(readerProjections.get().getForColumnAt(3)).isEqualTo(columns.get(4)); - assertThat(readerProjections.get().getForColumnAt(4)).isEqualTo(columns.get(4)); - - assertThat(readerProjections.get().getPositionForColumnAt(0)).isEqualTo(0); - assertThat(readerProjections.get().getPositionForColumnAt(1)).isEqualTo(1); - assertThat(readerProjections.get().getPositionForColumnAt(2)).isEqualTo(2); - assertThat(readerProjections.get().getPositionForColumnAt(3)).isEqualTo(3); - assertThat(readerProjections.get().getPositionForColumnAt(4)).isEqualTo(3); - - List readerColumns = readerProjections.get().get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toImmutableList()); - assertThat(readerColumns.get(0)).isEqualTo(columns.get(0)); - assertThat(readerColumns.get(1)).isEqualTo(columns.get(1)); - assertThat(readerColumns.get(2)).isEqualTo(columns.get(2)); - assertThat(readerColumns.get(3)).isEqualTo(columns.get(4)); - } -} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcDeletedRows.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcDeletedRows.java index 803bd209eb86..3977a53385c3 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcDeletedRows.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcDeletedRows.java @@ -22,6 +22,7 @@ import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.RunLengthEncodedBlock; +import io.trino.spi.connector.SourcePage; import io.trino.spi.security.ConnectorIdentity; import org.apache.hadoop.hive.ql.io.AcidUtils; import org.junit.jupiter.api.Test; @@ -69,7 +70,7 @@ public void testDeleteLocations() // page with deleted rows Page testPage = createTestPage(0, 10); - Block block = deletedRows.getMaskDeletedRowsFunction(testPage, OptionalLong.empty()).apply(testPage.getBlock(0)); + Block block = deletedRows.maskPage(SourcePage.create(testPage), OptionalLong.empty()).getBlock(0); Set validRows = resultBuilder(SESSION, BIGINT) .page(new Page(block)) .build() @@ -80,7 +81,7 @@ public void testDeleteLocations() // page with no deleted rows testPage = createTestPage(10, 20); - block = deletedRows.getMaskDeletedRowsFunction(testPage, OptionalLong.empty()).apply(testPage.getBlock(1)); + block = deletedRows.maskPage(SourcePage.create(testPage), OptionalLong.empty()).getBlock(1); assertThat(block.getPositionCount()).isEqualTo(10); } @@ -98,7 +99,7 @@ public void testDeletedLocationsOriginalFiles() // page with deleted rows Page testPage = createTestPage(0, 8); - Block block = 
deletedRows.getMaskDeletedRowsFunction(testPage, OptionalLong.of(0L)).apply(testPage.getBlock(0)); + Block block = deletedRows.maskPage(SourcePage.create(testPage), OptionalLong.of(0L)).getBlock(0); Set validRows = resultBuilder(SESSION, BIGINT) .page(new Page(block)) .build() @@ -109,7 +110,7 @@ public void testDeletedLocationsOriginalFiles() // page with no deleted rows testPage = createTestPage(5, 9); - block = deletedRows.getMaskDeletedRowsFunction(testPage, OptionalLong.empty()).apply(testPage.getBlock(1)); + block = deletedRows.maskPage(SourcePage.create(testPage), OptionalLong.empty()).getBlock(1); assertThat(block.getPositionCount()).isEqualTo(4); } @@ -123,7 +124,7 @@ public void testDeletedLocationsAfterMinorCompaction() // page with deleted rows Page testPage = createTestPage(0, 10); - Block block = deletedRows.getMaskDeletedRowsFunction(testPage, OptionalLong.empty()).apply(testPage.getBlock(0)); + Block block = deletedRows.maskPage(SourcePage.create(testPage), OptionalLong.empty()).getBlock(0); Set validRows = resultBuilder(SESSION, BIGINT) .page(new Page(block)) .build() @@ -134,7 +135,7 @@ public void testDeletedLocationsAfterMinorCompaction() // page with no deleted rows testPage = createTestPage(10, 20); - block = deletedRows.getMaskDeletedRowsFunction(testPage, OptionalLong.empty()).apply(testPage.getBlock(1)); + block = deletedRows.maskPage(SourcePage.create(testPage), OptionalLong.empty()).getBlock(1); assertThat(block.getPositionCount()).isEqualTo(10); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPageSourceFactory.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPageSourceFactory.java index 3aa3a4d33172..5973183e2e45 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPageSourceFactory.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPageSourceFactory.java @@ -25,10 +25,9 @@ import io.trino.plugin.hive.HiveColumnHandle; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HivePageSourceFactory; -import io.trino.plugin.hive.ReaderPageSource; import io.trino.plugin.hive.Schema; -import io.trino.spi.Page; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.security.ConnectorIdentity; @@ -49,7 +48,6 @@ import java.util.Set; import java.util.function.LongPredicate; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.io.Resources.getResource; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; @@ -248,7 +246,7 @@ private static List readFile( new FileFormatDataSourceStats(), new HiveConfig()); - Optional pageSourceWithProjections = pageSourceFactory.createPageSource( + ConnectorPageSource pageSource = pageSourceFactory.createPageSource( SESSION, location, 0, @@ -261,13 +259,8 @@ private static List readFile( acidInfo, OptionalInt.empty(), false, - NO_ACID_TRANSACTION); - - checkArgument(pageSourceWithProjections.isPresent()); - checkArgument(pageSourceWithProjections.get().getReaderColumns().isEmpty(), - "projected columns not expected here"); - - ConnectorPageSource pageSource = pageSourceWithProjections.get().get(); + NO_ACID_TRANSACTION) + .orElseThrow(); int nationKeyColumn = columnNames.indexOf("n_nationkey"); int nameColumn = columnNames.indexOf("n_name"); @@ -276,12 
+269,11 @@ private static List readFile( ImmutableList.Builder rows = ImmutableList.builder(); while (!pageSource.isFinished()) { - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (page == null) { continue; } - page = page.getLoadedPage(); for (int position = 0; position < page.getPositionCount(); position++) { long nationKey = -42; if (nationKeyColumn >= 0) { diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java index bbca43a8720f..e7f61efbfa01 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/orc/TestOrcPredicates.java @@ -42,6 +42,7 @@ import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.RowType; @@ -143,7 +144,7 @@ private static void assertFilteredRows( try (ConnectorPageSource pageSource = createPageSource(fileSystemFactory, location, effectivePredicate, columnsToRead, session)) { int filteredRows = 0; while (!pageSource.isFinished()) { - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (page != null) { filteredRows += page.getPositionCount(); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java index 38d209b5ee4e..62659feadd16 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetTester.java @@ -36,7 +36,6 @@ import io.trino.plugin.hive.parquet.write.SingleLevelArrayMapKeyValuesSchemaConverter; import io.trino.plugin.hive.parquet.write.SingleLevelArraySchemaConverter; import io.trino.plugin.hive.parquet.write.TestingMapredParquetOutputFormat; -import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.ArrayBlockBuilder; @@ -49,6 +48,7 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.RecordCursor; import io.trino.spi.connector.RecordPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.CharType; import io.trino.spi.type.DateType; @@ -500,7 +500,7 @@ private static void assertPageSource(List types, Iterator[] valuesByFie private static void assertPageSource(List types, Iterator[] valuesByField, ConnectorPageSource pageSource, Optional maxReadBlockSize) { while (!pageSource.isFinished()) { - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (page == null) { continue; } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java index adbcb32c3211..110f42fcd204 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/ParquetUtil.java @@ -103,8 +103,7 @@ private static ConnectorPageSource createPageSource(ConnectorSession session, Fi OptionalInt.empty(), false, NO_ACID_TRANSACTION) - .orElseThrow() - 
.get(); + .orElseThrow(); } private static List getBaseColumns(List columnNames, List columnTypes) diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java index 48e92a420613..f283a73b6758 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestParquetPageSourceFactory.java @@ -14,20 +14,31 @@ package io.trino.plugin.hive.parquet; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.trino.metastore.HiveType; import io.trino.plugin.hive.HiveColumnHandle; import io.trino.plugin.hive.HiveColumnProjectionInfo; import io.trino.spi.type.IntegerType; import io.trino.spi.type.RowType; +import io.trino.spi.type.Type; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.Optional; import static io.trino.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; +import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.ROWTYPE_OF_PRIMITIVES; +import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.ROWTYPE_OF_ROW_AND_PRIMITIVES; +import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.createProjectedColumnHandle; +import static io.trino.plugin.hive.TestHiveReaderProjectionsUtil.createTestFullColumns; +import static io.trino.plugin.hive.parquet.ParquetPageSourceFactory.projectSufficientColumns; import static io.trino.plugin.hive.util.HiveTypeTranslator.toHiveType; +import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.RowType.rowType; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; import static org.apache.parquet.schema.Type.Repetition.OPTIONAL; @@ -71,4 +82,48 @@ private void testGetNestedMixedRepetitionColumnType(boolean useColumnNames) new PrimitiveType(REQUIRED, INT32, "required_level3")))); assertThat(ParquetPageSourceFactory.getColumnType(columnHandle, fileSchema, useColumnNames).get()).isEqualTo(fileSchema.getType("optional_level1")); } + + private static final List TEST_COLUMN_NAMES = ImmutableList.of( + "col_bigint", + "col_struct_of_primitives", + "col_struct_of_non_primitives", + "col_partition_key_1", + "col_partition_key_2"); + + private static final Map TEST_COLUMN_TYPES = ImmutableMap.builder() + .put("col_bigint", BIGINT) + .put("col_struct_of_primitives", ROWTYPE_OF_PRIMITIVES) + .put("col_struct_of_non_primitives", ROWTYPE_OF_ROW_AND_PRIMITIVES) + .put("col_partition_key_1", BIGINT) + .put("col_partition_key_2", BIGINT) + .buildOrThrow(); + + private static final Map TEST_FULL_COLUMNS = createTestFullColumns(TEST_COLUMN_NAMES, TEST_COLUMN_TYPES); + + @Test + void testNoProjections() + { + List columns = new ArrayList<>(TEST_FULL_COLUMNS.values()); + List sufficientColumns = projectSufficientColumns(columns); + assertThat(sufficientColumns) + .describedAs("Full columns should not require any adaptation") + .isEqualTo(columns); + } + + @Test + void testProjectSufficientColumns() + { + List columns = ImmutableList.of( + createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_primitives"), ImmutableList.of(0)), + 
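+                // the two col_struct_of_non_primitives projections below collapse into one
+                // sufficient column, since (0) is a prefix of (0, 1); the assertions that
+                // follow therefore expect four entries, with the last equal to columns.get(4)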
createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_primitives"), ImmutableList.of(1)), + createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_bigint"), ImmutableList.of()), + createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_non_primitives"), ImmutableList.of(0, 1)), + createProjectedColumnHandle(TEST_FULL_COLUMNS.get("col_struct_of_non_primitives"), ImmutableList.of(0))); + + List sufficientColumns = projectSufficientColumns(columns); + assertThat(sufficientColumns.get(0)).isEqualTo(columns.get(0)); + assertThat(sufficientColumns.get(1)).isEqualTo(columns.get(1)); + assertThat(sufficientColumns.get(2)).isEqualTo(columns.get(2)); + assertThat(sufficientColumns.get(3)).isEqualTo(columns.get(4)); + } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestTimestamp.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestTimestamp.java index f7b2427347ee..e56d64425202 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestTimestamp.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestTimestamp.java @@ -19,10 +19,10 @@ import com.google.common.collect.Range; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HiveTimestampPrecision; -import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.SqlTimestamp; import io.trino.spi.type.Type; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -146,7 +146,7 @@ private static void testReadingAs(Type type, ConnectorSession session, ParquetTe Iterator expected = expectedValues.iterator(); try (ConnectorPageSource pageSource = ParquetUtil.createPageSource(session, tempFile.getFile(), columnNames, ImmutableList.of(type), dateTimeZone)) { // skip a page to exercise the decoder's skip() logic - Page firstPage = pageSource.getNextPage(); + SourcePage firstPage = pageSource.getNextSourcePage(); assertThat(firstPage.getPositionCount() > 0) .describedAs("Expected first page to have at least 1 row") .isTrue(); @@ -157,7 +157,7 @@ private static void testReadingAs(Type type, ConnectorSession session, ParquetTe int pageCount = 1; while (!pageSource.isFinished()) { - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (page == null) { continue; } diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSource.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSource.java deleted file mode 100644 index 337ad783db8e..000000000000 --- a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSource.java +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.plugin.hudi; - -import io.trino.plugin.hive.HiveColumnHandle; -import io.trino.spi.Page; -import io.trino.spi.block.Block; -import io.trino.spi.block.RunLengthEncodedBlock; -import io.trino.spi.connector.ConnectorPageSource; -import io.trino.spi.metrics.Metrics; - -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.OptionalLong; -import java.util.concurrent.CompletableFuture; - -import static io.airlift.slice.Slices.utf8Slice; -import static io.trino.plugin.base.util.Closables.closeAllSuppress; -import static io.trino.plugin.hive.HiveColumnHandle.FILE_MODIFIED_TIME_COLUMN_NAME; -import static io.trino.plugin.hive.HiveColumnHandle.FILE_MODIFIED_TIME_TYPE_SIGNATURE; -import static io.trino.plugin.hive.HiveColumnHandle.FILE_SIZE_COLUMN_NAME; -import static io.trino.plugin.hive.HiveColumnHandle.FILE_SIZE_TYPE_SIGNATURE; -import static io.trino.plugin.hive.HiveColumnHandle.PARTITION_COLUMN_NAME; -import static io.trino.plugin.hive.HiveColumnHandle.PARTITION_TYPE_SIGNATURE; -import static io.trino.plugin.hive.HiveColumnHandle.PATH_COLUMN_NAME; -import static io.trino.plugin.hive.HiveColumnHandle.PATH_TYPE; -import static io.trino.spi.predicate.Utils.nativeValueToBlock; -import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; -import static io.trino.spi.type.TimeZoneKey.UTC_KEY; -import static java.util.Objects.requireNonNull; - -public class HudiPageSource - implements ConnectorPageSource -{ - private final Block[] prefilledBlocks; - private final int[] delegateIndexes; - private final ConnectorPageSource dataPageSource; - - public HudiPageSource( - String partitionName, - List columnHandles, - Map partitionBlocks, - ConnectorPageSource dataPageSource, - String path, - long fileSize, - long fileModifiedTime) - { - requireNonNull(columnHandles, "columnHandles is null"); - this.dataPageSource = requireNonNull(dataPageSource, "dataPageSource is null"); - - int size = columnHandles.size(); - this.prefilledBlocks = new Block[size]; - this.delegateIndexes = new int[size]; - - int outputIndex = 0; - int delegateIndex = 0; - - for (HiveColumnHandle column : columnHandles) { - if (partitionBlocks.containsKey(column.getName())) { - Block partitionValue = partitionBlocks.get(column.getName()); - prefilledBlocks[outputIndex] = partitionValue; - delegateIndexes[outputIndex] = -1; - } - else if (column.getName().equals(PARTITION_COLUMN_NAME)) { - prefilledBlocks[outputIndex] = nativeValueToBlock(PARTITION_TYPE_SIGNATURE, utf8Slice(partitionName)); - delegateIndexes[outputIndex] = -1; - } - else if (column.getName().equals(PATH_COLUMN_NAME)) { - prefilledBlocks[outputIndex] = nativeValueToBlock(PATH_TYPE, utf8Slice(path)); - delegateIndexes[outputIndex] = -1; - } - else if (column.getName().equals(FILE_SIZE_COLUMN_NAME)) { - prefilledBlocks[outputIndex] = nativeValueToBlock(FILE_SIZE_TYPE_SIGNATURE, fileSize); - delegateIndexes[outputIndex] = -1; - } - else if (column.getName().equals(FILE_MODIFIED_TIME_COLUMN_NAME)) { - long packedTimestamp = packDateTimeWithZone(fileModifiedTime, UTC_KEY); - prefilledBlocks[outputIndex] = nativeValueToBlock(FILE_MODIFIED_TIME_TYPE_SIGNATURE, packedTimestamp); - delegateIndexes[outputIndex] = -1; - } - else { - delegateIndexes[outputIndex] = delegateIndex; - delegateIndex++; - } - outputIndex++; - } - } - - @Override - public long getCompletedBytes() - { - return dataPageSource.getCompletedBytes(); - } - - @Override - public long getReadTimeNanos() - { - return dataPageSource.getReadTimeNanos(); - 
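Note: the constructor bookkeeping above is the whole reason this class existed: per output channel it records either a constant single-position block or an index into the delegate page. The per-page work, condensed from the deleted getNextPage() below, is only a few lines; TransformConnectorPageSource, introduced later in this diff, centralizes exactly this wiring:

    import io.trino.spi.Page;
    import io.trino.spi.block.Block;
    import io.trino.spi.block.RunLengthEncodedBlock;

    // Condensed sketch of the deleted getNextPage(): expand each constant to a
    // run-length-encoded block of the page's length; pass delegate channels through.
    static Page merge(Page dataPage, Block[] prefilledBlocks, int[] delegateIndexes)
    {
        int positionCount = dataPage.getPositionCount();
        Block[] blocks = new Block[prefilledBlocks.length];
        for (int i = 0; i < prefilledBlocks.length; i++) {
            blocks[i] = (prefilledBlocks[i] != null)
                    ? RunLengthEncodedBlock.create(prefilledBlocks[i], positionCount)
                    : dataPage.getBlock(delegateIndexes[i]);
        }
        return new Page(positionCount, blocks);
    }
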
} - - @Override - public boolean isFinished() - { - return dataPageSource.isFinished(); - } - - @Override - public CompletableFuture isBlocked() - { - return dataPageSource.isBlocked(); - } - - @Override - public OptionalLong getCompletedPositions() - { - return dataPageSource.getCompletedPositions(); - } - - @Override - public Metrics getMetrics() - { - return dataPageSource.getMetrics(); - } - - @Override - public Page getNextPage() - { - try { - Page page = dataPageSource.getNextPage(); - if (page == null) { - return null; - } - int positionCount = page.getPositionCount(); - Block[] blocks = new Block[prefilledBlocks.length]; - for (int i = 0; i < prefilledBlocks.length; i++) { - if (prefilledBlocks[i] != null) { - blocks[i] = RunLengthEncodedBlock.create(prefilledBlocks[i], positionCount); - } - else { - blocks[i] = page.getBlock(delegateIndexes[i]); - } - } - return new Page(positionCount, blocks); - } - catch (RuntimeException e) { - closeAllSuppress(e, this); - throw e; - } - } - - @Override - public long getMemoryUsage() - { - return dataPageSource.getMemoryUsage(); - } - - @Override - public void close() - throws IOException - { - dataPageSource.close(); - } -} diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java index a40079af1627..4d0cc94a20b3 100644 --- a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java +++ b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiPageSourceProvider.java @@ -33,7 +33,7 @@ import io.trino.plugin.base.metrics.FileFormatDataSourceStats; import io.trino.plugin.hive.HiveColumnHandle; import io.trino.plugin.hive.HivePartitionKey; -import io.trino.plugin.hive.ReaderColumns; +import io.trino.plugin.hive.TransformConnectorPageSource; import io.trino.plugin.hive.parquet.ParquetReaderConfig; import io.trino.plugin.hudi.model.HudiFileFormat; import io.trino.spi.TrinoException; @@ -65,7 +65,6 @@ import java.util.Optional; import java.util.OptionalLong; import java.util.TimeZone; -import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.Slices.utf8Slice; @@ -75,7 +74,14 @@ import static io.trino.parquet.ParquetTypeUtils.getDescriptors; import static io.trino.parquet.predicate.PredicateUtils.buildPredicate; import static io.trino.parquet.predicate.PredicateUtils.getFilteredRowGroups; -import static io.trino.plugin.hive.HivePageSourceProvider.projectBaseColumns; +import static io.trino.plugin.hive.HiveColumnHandle.FILE_MODIFIED_TIME_COLUMN_NAME; +import static io.trino.plugin.hive.HiveColumnHandle.FILE_MODIFIED_TIME_TYPE_SIGNATURE; +import static io.trino.plugin.hive.HiveColumnHandle.FILE_SIZE_COLUMN_NAME; +import static io.trino.plugin.hive.HiveColumnHandle.FILE_SIZE_TYPE_SIGNATURE; +import static io.trino.plugin.hive.HiveColumnHandle.PARTITION_COLUMN_NAME; +import static io.trino.plugin.hive.HiveColumnHandle.PARTITION_TYPE_SIGNATURE; +import static io.trino.plugin.hive.HiveColumnHandle.PATH_COLUMN_NAME; +import static io.trino.plugin.hive.HiveColumnHandle.PATH_TYPE; import static io.trino.plugin.hive.parquet.ParquetPageSourceFactory.ParquetReaderProvider; import static io.trino.plugin.hive.parquet.ParquetPageSourceFactory.createDataSource; import static io.trino.plugin.hive.parquet.ParquetPageSourceFactory.createParquetPageSource; @@ -91,6 +97,7 @@ import static 
io.trino.plugin.hudi.HudiSessionProperties.shouldUseParquetColumnNames; import static io.trino.plugin.hudi.HudiUtil.getHudiFileFormat; import static io.trino.spi.predicate.Utils.nativeValueToBlock; +import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; import static io.trino.spi.type.StandardTypes.BIGINT; import static io.trino.spi.type.StandardTypes.BOOLEAN; import static io.trino.spi.type.StandardTypes.DATE; @@ -103,6 +110,7 @@ import static io.trino.spi.type.StandardTypes.TINYINT; import static io.trino.spi.type.StandardTypes.VARBINARY; import static io.trino.spi.type.StandardTypes.VARCHAR; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; import static java.lang.Double.parseDouble; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.parseFloat; @@ -112,7 +120,6 @@ import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; -import static java.util.stream.Collectors.toUnmodifiableList; public class HudiPageSourceProvider implements ConnectorPageSourceProvider @@ -147,7 +154,7 @@ public ConnectorPageSource createPageSource( HudiSplit split = (HudiSplit) connectorSplit; String path = split.location(); HudiFileFormat hudiFileFormat = getHudiFileFormat(path); - if (!HudiFileFormat.PARQUET.equals(hudiFileFormat)) { + if (HudiFileFormat.PARQUET != hudiFileFormat) { throw new TrinoException(HUDI_UNSUPPORTED_FILE_FORMAT, format("File format %s not supported", hudiFileFormat)); } @@ -158,10 +165,10 @@ public ConnectorPageSource createPageSource( // for partition columns, separate blocks will be created List regularColumns = hiveColumns.stream() .filter(columnHandle -> !columnHandle.isPartitionKey() && !columnHandle.isHidden()) - .collect(Collectors.toList()); + .collect(toList()); TrinoFileSystem fileSystem = fileSystemFactory.create(session); TrinoInputFile inputFile = fileSystem.newInputFile(Location.of(path), split.fileSize()); - ConnectorPageSource dataPageSource = createPageSource( + ConnectorPageSource pageSource = createPageSource( session, regularColumns, split, @@ -171,14 +178,34 @@ public ConnectorPageSource createPageSource( .withVectorizedDecodingEnabled(isParquetVectorizedDecodingEnabled(session)), timeZone); - return new HudiPageSource( - toPartitionName(split.partitionKeys()), - hiveColumns, - convertPartitionValues(hiveColumns, split.partitionKeys()), // create blocks for partition values - dataPageSource, - path, - split.fileSize(), - split.fileModifiedTime()); + Map partitionBlocks = convertPartitionValues(hiveColumns, split.partitionKeys()); + + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + int delegateIndex = 0; + for (HiveColumnHandle column : hiveColumns) { + if (partitionBlocks.containsKey(column.getName())) { + transforms.constantValue(partitionBlocks.get(column.getName())); + } + else if (column.getName().equals(PARTITION_COLUMN_NAME)) { + transforms.constantValue(nativeValueToBlock(PARTITION_TYPE_SIGNATURE, utf8Slice(toPartitionName(split.partitionKeys())))); + } + else if (column.getName().equals(PATH_COLUMN_NAME)) { + transforms.constantValue(nativeValueToBlock(PATH_TYPE, utf8Slice(path))); + } + else if (column.getName().equals(FILE_SIZE_COLUMN_NAME)) { + transforms.constantValue(nativeValueToBlock(FILE_SIZE_TYPE_SIGNATURE, split.fileSize())); + } + else if (column.getName().equals(FILE_MODIFIED_TIME_COLUMN_NAME)) { + long packedTimestamp = packDateTimeWithZone(split.fileModifiedTime(), 
UTC_KEY); + transforms.constantValue(nativeValueToBlock(FILE_MODIFIED_TIME_TYPE_SIGNATURE, packedTimestamp)); + } + else { + transforms.column(delegateIndex); + delegateIndex++; + } + } + + return transforms.build(pageSource); } private static ConnectorPageSource createPageSource( @@ -226,17 +253,12 @@ private static ConnectorPageSource createPageSource( DOMAIN_COMPACTION_THRESHOLD, options); - Optional readerProjections = projectBaseColumns(columns); - List baseColumns = readerProjections.map(projection -> - projection.get().stream() - .map(HiveColumnHandle.class::cast) - .collect(toUnmodifiableList())) - .orElse(columns); ParquetDataSourceId dataSourceId = dataSource.getId(); ParquetDataSource finalDataSource = dataSource; - ParquetReaderProvider parquetReaderProvider = fields -> new ParquetReader( + ParquetReaderProvider parquetReaderProvider = (fields, appendRowNumberColumn) -> new ParquetReader( Optional.ofNullable(fileMetaData.getCreatedBy()), fields, + appendRowNumberColumn, rowGroups, finalDataSource, timeZone, @@ -245,7 +267,7 @@ private static ConnectorPageSource createPageSource( exception -> handleException(dataSourceId, exception), Optional.of(parquetPredicate), Optional.empty()); - return createParquetPageSource(baseColumns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); + return createParquetPageSource(columns, fileSchema, messageColumn, useColumnNames, parquetReaderProvider); } catch (IOException | RuntimeException e) { try { @@ -277,7 +299,7 @@ private static TrinoException handleException(ParquetDataSourceId dataSourceId, return new TrinoException(HUDI_CURSOR_ERROR, format("Failed to read Parquet file: %s", dataSourceId), exception); } - private Map convertPartitionValues( + private static Map convertPartitionValues( List allColumns, List partitionKeys) { diff --git a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiPageSource.java b/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiPageSource.java deleted file mode 100644 index 8bb46f8fb3f5..000000000000 --- a/plugin/trino-hudi/src/test/java/io/trino/plugin/hudi/TestHudiPageSource.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
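Note: the builder wiring above is positional: exactly one declaration (constantValue or column) is made per output column, in output order, so channel i of the built page source corresponds to the i-th entry of hiveColumns. A minimal sketch of that invariant, using only the builder methods shown in this diff and a hypothetical two-column layout:

    // Channel 0 becomes a constant; channel 1 passes through delegate channel 0.
    // BIGINT here is the Type constant io.trino.spi.type.BigintType.BIGINT.
    TransformConnectorPageSource.Builder builder = TransformConnectorPageSource.builder();
    builder.constantValue(nativeValueToBlock(BIGINT, 42L)); // hypothetical partition key value
    builder.column(0); // first delegate channel
    ConnectorPageSource wrapped = builder.build(dataPageSource);
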
- */ -package io.trino.plugin.hudi; - -import io.trino.spi.connector.ConnectorPageSource; -import org.junit.jupiter.api.Test; - -import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden; - -public class TestHudiPageSource -{ - @Test - public void testEverythingImplemented() - { - assertAllMethodsOverridden(ConnectorPageSource.class, HudiPageSource.class); - } -} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ConstantPopulatingPageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ConstantPopulatingPageSource.java deleted file mode 100644 index c0b40f7e51c8..000000000000 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/ConstantPopulatingPageSource.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.iceberg; - -import com.google.common.collect.ImmutableList; -import io.trino.spi.Page; -import io.trino.spi.block.Block; -import io.trino.spi.block.RunLengthEncodedBlock; -import io.trino.spi.connector.ConnectorPageSource; -import io.trino.spi.metrics.Metrics; - -import java.io.IOException; -import java.util.List; -import java.util.OptionalLong; -import java.util.concurrent.CompletableFuture; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public class ConstantPopulatingPageSource - implements ConnectorPageSource -{ - private final ConnectorPageSource delegate; - private final Block[] constantColumns; - private final int[] targetChannelToSourceChannel; - - private ConstantPopulatingPageSource(ConnectorPageSource delegate, Block[] constantColumns, int[] targetChannelToSourceChannel) - { - this.delegate = requireNonNull(delegate, "delegate is null"); - this.constantColumns = requireNonNull(constantColumns, "constantColumns is null"); - this.targetChannelToSourceChannel = requireNonNull(targetChannelToSourceChannel, "targetChannelToSourceChannel is null"); - } - - @Override - public long getCompletedBytes() - { - return delegate.getCompletedBytes(); - } - - @Override - public OptionalLong getCompletedPositions() - { - return delegate.getCompletedPositions(); - } - - @Override - public long getReadTimeNanos() - { - return delegate.getReadTimeNanos(); - } - - @Override - public boolean isFinished() - { - return delegate.isFinished(); - } - - @Override - public Page getNextPage() - { - Page delegatePage = delegate.getNextPage(); - - if (delegatePage == null) { - return null; - } - - int size = constantColumns.length; - Block[] blocks = new Block[size]; - for (int targetChannel = 0; targetChannel < size; targetChannel++) { - Block constantValue = constantColumns[targetChannel]; - if (constantValue != null) { - blocks[targetChannel] = RunLengthEncodedBlock.create(constantValue, delegatePage.getPositionCount()); - } - else { - blocks[targetChannel] = delegatePage.getBlock(targetChannelToSourceChannel[targetChannel]); - } - } - - return new Page(delegatePage.getPositionCount(), 
blocks); - } - - @Override - public long getMemoryUsage() - { - return delegate.getMemoryUsage(); - } - - @Override - public void close() - throws IOException - { - delegate.close(); - } - - @Override - public CompletableFuture isBlocked() - { - return delegate.isBlocked(); - } - - @Override - public Metrics getMetrics() - { - return delegate.getMetrics(); - } - - public static Builder builder() - { - return new Builder(); - } - - public static class Builder - { - private final ImmutableList.Builder columns = ImmutableList.builder(); - - private Builder() - { } - - public Builder addConstantColumn(Block value) - { - columns.add(new ConstantColumn(value)); - return this; - } - - public Builder addDelegateColumn(int sourceChannel) - { - columns.add(new DelegateColumn(sourceChannel)); - return this; - } - - public ConnectorPageSource build(ConnectorPageSource delegate) - { - List columns = this.columns.build(); - Block[] constantValues = new Block[columns.size()]; - int[] delegateIndexes = new int[columns.size()]; - - // If no constant columns are added and the delegate columns are in order, nothing to do - boolean isRequired = false; - - for (int columnChannel = 0; columnChannel < columns.size(); columnChannel++) { - ColumnType column = columns.get(columnChannel); - if (column instanceof ConstantColumn) { - constantValues[columnChannel] = ((ConstantColumn) column).value(); - isRequired = true; - } - else if (column instanceof DelegateColumn) { - int delegateChannel = ((DelegateColumn) column).sourceChannel(); - delegateIndexes[columnChannel] = delegateChannel; - if (columnChannel != delegateChannel) { - isRequired = true; - } - } - else { - throw new IllegalStateException("Unknown ConstantPopulatingPageSource ColumnType " + column); - } - } - - if (!isRequired) { - return delegate; - } - - return new ConstantPopulatingPageSource(delegate, constantValues, delegateIndexes); - } - } - - public interface ColumnType {} - - private record ConstantColumn(Block value) - implements ColumnType - { - private ConstantColumn - { - requireNonNull(value, "value is null"); - checkArgument(value.getPositionCount() == 1, "ConstantColumn may only contain one value"); - } - } - - private record DelegateColumn(int sourceChannel) - implements ColumnType - { - } -} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java index 65fbad86c794..e0d53d9316d3 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java @@ -18,6 +18,7 @@ import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import org.apache.iceberg.Schema; import org.apache.iceberg.avro.Avro; @@ -49,11 +50,7 @@ public class IcebergAvroPageSource private final List columnNames; private final List columnTypes; private final Map icebergTypes; - /** - * Indicates whether the column at each index should be populated with the - * indices of its rows - */ - private final List rowIndexLocations; + private final boolean appendRowNumberColumn; private final PageBuilder pageBuilder; private final AggregatedMemoryContext memoryUsage; @@ -69,16 +66,16 @@ public IcebergAvroPageSource( Optional nameMapping, List columnNames, List columnTypes, - List rowIndexLocations, + boolean 
appendRowNumberColumn, AggregatedMemoryContext memoryUsage) { this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null")); this.columnTypes = ImmutableList.copyOf(requireNonNull(columnTypes, "columnTypes is null")); - this.rowIndexLocations = ImmutableList.copyOf(requireNonNull(rowIndexLocations, "rowIndexLocations is null")); + this.appendRowNumberColumn = appendRowNumberColumn; this.memoryUsage = requireNonNull(memoryUsage, "memoryUsage is null"); checkArgument( - columnNames.size() == rowIndexLocations.size() && columnNames.size() == columnTypes.size(), - "names, rowIndexLocations, and types must correspond one-to-one-to-one"); + columnNames.size() == columnTypes.size(), + "names and types must correspond one-to-one"); // The column orders in the generated schema might be different from the original order Schema readSchema = fileSchema.select(columnNames); @@ -90,15 +87,10 @@ public IcebergAvroPageSource( AvroIterable avroReader = builder.build(); icebergTypes = readSchema.columns().stream() .collect(toImmutableMap(Types.NestedField::name, Types.NestedField::type)); - pageBuilder = new PageBuilder(columnTypes); + pageBuilder = new PageBuilder(appendRowNumberColumn ? ImmutableList.builder().addAll(columnTypes).add(BIGINT).build() : columnTypes); recordIterator = avroReader.iterator(); } - private boolean isIndexColumn(int column) - { - return rowIndexLocations.get(column); - } - @Override public long getCompletedBytes() { @@ -118,7 +110,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (!recordIterator.hasNext()) { return null; @@ -131,13 +123,11 @@ public Page getNextPage() pageBuilder.declarePosition(); Record record = recordIterator.next(); for (int channel = 0; channel < columnTypes.size(); channel++) { - if (isIndexColumn(channel)) { - BIGINT.writeLong(pageBuilder.getBlockBuilder(channel), rowId); - } - else { - String name = columnNames.get(channel); - serializeToTrinoBlock(columnTypes.get(channel), icebergTypes.get(name), pageBuilder.getBlockBuilder(channel), record.getField(name)); - } + String name = columnNames.get(channel); + serializeToTrinoBlock(columnTypes.get(channel), icebergTypes.get(name), pageBuilder.getBlockBuilder(channel), record.getField(name)); + } + if (appendRowNumberColumn) { + BIGINT.writeLong(pageBuilder.getBlockBuilder(columnTypes.size()), rowId); } rowId++; } @@ -146,7 +136,7 @@ public Page getNextPage() readBytes += page.getSizeInBytes(); readTimeNanos += System.nanoTime() - start; - return page; + return SourcePage.create(page); } @Override diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSource.java deleted file mode 100644 index ec5bd9e55d85..000000000000 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSource.java +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
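Note: appendRowNumberColumn replaces the per-channel rowIndexLocations list with a single convention: when the flag is set, the reader appends one trailing BIGINT channel carrying the row position. Consumers read it from the end of the page, as the transforms added later in this diff do; a minimal sketch of that consumer side, assuming the SourcePage accessors used elsewhere in this diff:

    // Mirrors GetRowPositionFromSource below: the appended row-number
    // channel is always the last channel of the source page.
    static Block rowPositions(SourcePage page)
    {
        return page.getBlock(page.getChannelCount() - 1);
    }
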
- */ -package io.trino.plugin.iceberg; - -import io.trino.plugin.hive.ReaderProjectionsAdapter; -import io.trino.plugin.iceberg.delete.RowPredicate; -import io.trino.spi.Page; -import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; -import io.trino.spi.block.RowBlock; -import io.trino.spi.connector.ConnectorPageSource; -import io.trino.spi.metrics.Metrics; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.List; -import java.util.Optional; -import java.util.OptionalLong; -import java.util.concurrent.CompletableFuture; -import java.util.function.Function; -import java.util.function.Supplier; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Throwables.throwIfInstanceOf; -import static io.trino.plugin.base.util.Closables.closeAllSuppress; -import static io.trino.plugin.iceberg.IcebergErrorCode.ICEBERG_BAD_DATA; -import static java.util.Objects.requireNonNull; - -public class IcebergPageSource - implements ConnectorPageSource -{ - private final int[] expectedColumnIndexes; - private final ConnectorPageSource delegate; - private final Optional projectionsAdapter; - private final Supplier> deletePredicate; - private final Function rowIdBlockFactory; - // The $row_id's index in 'expectedColumns', or -1 if there isn't one - // this column with contain row position populated in the source, and must be wrapped with constant data for full row id - private int rowIdColumnIndex = -1; - // Maps the Iceberg field ids of unmodified columns to their indexes in updateRowIdChildColumnIndexes - - public IcebergPageSource( - List expectedColumns, - List requiredColumns, - ConnectorPageSource delegate, - Optional projectionsAdapter, - Supplier> deletePredicate, - Function rowIdBlockFactory) - { - // expectedColumns should contain columns which should be in the final Page - // requiredColumns should include all expectedColumns as well as any columns needed by the DeleteFilter - requireNonNull(expectedColumns, "expectedColumns is null"); - requireNonNull(requiredColumns, "requiredColumns is null"); - this.expectedColumnIndexes = new int[expectedColumns.size()]; - for (int i = 0; i < expectedColumns.size(); i++) { - IcebergColumnHandle expectedColumn = expectedColumns.get(i); - checkArgument(expectedColumn.equals(requiredColumns.get(i)), "Expected columns must be a prefix of required columns"); - expectedColumnIndexes[i] = i; - - if (expectedColumn.isMergeRowIdColumn()) { - this.rowIdColumnIndex = i; - } - } - - this.delegate = requireNonNull(delegate, "delegate is null"); - this.projectionsAdapter = requireNonNull(projectionsAdapter, "projectionsAdapter is null"); - this.deletePredicate = requireNonNull(deletePredicate, "deletePredicate is null"); - this.rowIdBlockFactory = requireNonNull(rowIdBlockFactory, "rowIdBlockFactory is null"); - } - - @Override - public long getCompletedBytes() - { - return delegate.getCompletedBytes(); - } - - @Override - public OptionalLong getCompletedPositions() - { - return delegate.getCompletedPositions(); - } - - @Override - public long getReadTimeNanos() - { - return delegate.getReadTimeNanos(); - } - - @Override - public boolean isFinished() - { - return delegate.isFinished(); - } - - @Override - public CompletableFuture isBlocked() - { - return delegate.isBlocked(); - } - - @Override - public Page getNextPage() - { - try { - Page dataPage = delegate.getNextPage(); - if (dataPage == null) { - return null; - } - - Optional deleteFilterPredicate = deletePredicate.get(); - 
if (deleteFilterPredicate.isPresent()) { - dataPage = deleteFilterPredicate.get().filterPage(dataPage); - } - - if (projectionsAdapter.isPresent()) { - dataPage = projectionsAdapter.get().adaptPage(dataPage); - } - - dataPage = withRowIdBlock(dataPage); - dataPage = dataPage.getColumns(expectedColumnIndexes); - - return dataPage; - } - catch (RuntimeException e) { - closeWithSuppression(e); - throwIfInstanceOf(e, TrinoException.class); - throw new TrinoException(ICEBERG_BAD_DATA, e); - } - } - - /** - * The $row_id column used for updates is a composite column of at least one other column in the Page. - * The indexes of the columns needed for the $row_id are in the updateRowIdChildColumnIndexes array. - * - * @param page The raw Page from the Parquet/ORC reader. - * @return A Page where the $row_id channel has been populated. - */ - private Page withRowIdBlock(Page page) - { - if (rowIdColumnIndex == -1) { - return page; - } - - RowBlock rowIdBlock = rowIdBlockFactory.apply(page.getBlock(rowIdColumnIndex)); - Block[] fullPage = new Block[page.getChannelCount()]; - for (int channel = 0; channel < page.getChannelCount(); channel++) { - fullPage[channel] = channel == rowIdColumnIndex ? rowIdBlock : page.getBlock(channel); - } - return new Page(page.getPositionCount(), fullPage); - } - - @Override - public void close() - { - try { - delegate.close(); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - @Override - public String toString() - { - return delegate.toString(); - } - - @Override - public long getMemoryUsage() - { - return delegate.getMemoryUsage(); - } - - @Override - public Metrics getMetrics() - { - return delegate.getMetrics(); - } - - protected void closeWithSuppression(Throwable throwable) - { - closeAllSuppress(throwable, this); - } -} diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java index 021af69723a6..a6f1b11b959e 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java @@ -47,11 +47,8 @@ import io.trino.parquet.reader.ParquetReader; import io.trino.parquet.reader.RowGroupInfo; import io.trino.plugin.base.metrics.FileFormatDataSourceStats; -import io.trino.plugin.hive.ReaderColumns; -import io.trino.plugin.hive.ReaderPageSource; -import io.trino.plugin.hive.ReaderProjectionsAdapter; +import io.trino.plugin.hive.TransformConnectorPageSource; import io.trino.plugin.hive.orc.OrcPageSource; -import io.trino.plugin.hive.orc.OrcPageSource.ColumnAdaptation; import io.trino.plugin.hive.parquet.ParquetPageSource; import io.trino.plugin.iceberg.IcebergParquetColumnIOConverter.FieldContext; import io.trino.plugin.iceberg.delete.DeleteFile; @@ -75,6 +72,7 @@ import io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.EmptyPageSource; import io.trino.spi.connector.FixedPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; @@ -99,7 +97,6 @@ import org.apache.iceberg.types.Types; import org.apache.iceberg.util.StructLikeWrapper; import org.apache.parquet.column.ColumnDescriptor; -import org.apache.parquet.io.ColumnIO; import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.schema.GroupType; import 
org.apache.parquet.schema.MessageType; @@ -117,10 +114,14 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; +import java.util.function.ObjLongConsumer; import java.util.function.Supplier; +import java.util.stream.IntStream; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Suppliers.memoize; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -167,25 +168,22 @@ import static io.trino.plugin.iceberg.IcebergUtil.getPartitionValues; import static io.trino.plugin.iceberg.IcebergUtil.schemaFromHandles; import static io.trino.plugin.iceberg.util.OrcIcebergIds.fileColumnsByIcebergId; -import static io.trino.plugin.iceberg.util.OrcTypeConverter.ICEBERG_BINARY_TYPE; import static io.trino.plugin.iceberg.util.OrcTypeConverter.ORC_ICEBERG_ID_KEY; import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static io.trino.spi.predicate.Utils.nativeValueToBlock; -import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; import static io.trino.spi.type.TimeZoneKey.UTC_KEY; -import static io.trino.spi.type.UuidType.UUID; import static java.lang.Math.min; import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Locale.ENGLISH; +import static java.util.Objects.checkIndex; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; import static java.util.function.Predicate.not; import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.mapping; -import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toUnmodifiableList; import static org.apache.iceberg.FileContent.EQUALITY_DELETES; import static org.apache.iceberg.FileContent.POSITION_DELETES; @@ -290,16 +288,7 @@ public ConnectorPageSource createPageSource( long dataSequenceNumber, Optional nameMapping) { - Set deleteFilterRequiredColumns = requiredColumnsForDeletes(tableSchema, deletes); - String partition = partitionSpec.partitionToPath(partitionData); Map> partitionKeys = getPartitionKeys(partitionData, partitionSpec); - - List requiredColumns = new ArrayList<>(icebergColumns); - - deleteFilterRequiredColumns.stream() - .filter(not(icebergColumns::contains)) - .forEach(requiredColumns::add); - TupleDomain effectivePredicate = getUnenforcedPredicate( tableSchema, partitionKeys, @@ -310,14 +299,15 @@ public ConnectorPageSource createPageSource( return new EmptyPageSource(); } + // exit early when only reading partition keys from a simple split + String partition = partitionSpec.partitionToPath(partitionData); TrinoFileSystem fileSystem = fileSystemFactory.create(session.getIdentity(), fileIoProperties); - TrinoInputFile inputfile = isUseFileSizeFromMetadata(session) + TrinoInputFile inputFile = isUseFileSizeFromMetadata(session) ? 
fileSystem.newInputFile(Location.of(path), fileSize) : fileSystem.newInputFile(Location.of(path)); - try { if (effectivePredicate.isAll() && - start == 0 && length == inputfile.length() && + start == 0 && length == inputFile.length() && deletes.isEmpty() && icebergColumns.stream().allMatch(column -> partitionKeys.containsKey(column.getId()))) { return generatePages( @@ -330,12 +320,21 @@ public ConnectorPageSource createPageSource( throw new UncheckedIOException(e); } + List requiredColumns = new ArrayList<>(icebergColumns); + + Set deleteFilterRequiredColumns = requiredColumnsForDeletes(tableSchema, deletes); + deleteFilterRequiredColumns.stream() + .filter(not(icebergColumns::contains)) + .forEach(requiredColumns::add); + ReaderPageSourceWithRowPositions readerPageSourceWithRowPositions = createDataPageSource( session, - inputfile, + inputFile, start, length, fileSize, + partitionSpec.specId(), + partitionDataJson, fileFormat, tableSchema, requiredColumns, @@ -343,36 +342,36 @@ public ConnectorPageSource createPageSource( nameMapping, partition, partitionKeys); - ReaderPageSource dataPageSource = readerPageSourceWithRowPositions.readerPageSource(); - - Optional projectionsAdapter = dataPageSource.getReaderColumns().map(readerColumns -> - new ReaderProjectionsAdapter( - requiredColumns, - readerColumns, - column -> ((IcebergColumnHandle) column).getType(), - IcebergPageSourceProvider::applyProjection)); - - List readColumns = dataPageSource.getReaderColumns() - .map(readerColumns -> readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toList())) - .orElse(requiredColumns); - - Supplier> deletePredicate = memoize(() -> getDeleteManager(partitionSpec, partitionData) - .getDeletePredicate( - path, - dataSequenceNumber, - deletes, - readColumns, - tableSchema, - readerPageSourceWithRowPositions, - (deleteFile, deleteColumns, tupleDomain) -> openDeletes(session, fileSystem, deleteFile, deleteColumns, tupleDomain))); - - return new IcebergPageSource( - icebergColumns, - requiredColumns, - dataPageSource.get(), - projectionsAdapter, - deletePredicate, - MergeRowIdBlockFactory.create(utf8Slice(inputfile.location().toString()), partitionSpec.specId(), utf8Slice(partitionDataJson))); + + ConnectorPageSource pageSource = readerPageSourceWithRowPositions.pageSource(); + + // filter out deleted rows + if (!deletes.isEmpty()) { + Supplier> deletePredicate = memoize(() -> getDeleteManager(partitionSpec, partitionData) + .getDeletePredicate( + path, + dataSequenceNumber, + deletes, + requiredColumns, + tableSchema, + readerPageSourceWithRowPositions, + (deleteFile, deleteColumns, tupleDomain) -> openDeletes(session, fileSystem, deleteFile, deleteColumns, tupleDomain))); + pageSource = TransformConnectorPageSource.create(pageSource, page -> { + try { + Optional rowPredicate = deletePredicate.get(); + rowPredicate.ifPresent(predicate -> predicate.applyFilter(page)); + if (icebergColumns.size() == page.getChannelCount()) { + return page; + } + return new PrefixColumnsSourcePage(page, icebergColumns.size()); + } + catch (RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(ICEBERG_BAD_DATA, e); + } + }); + } + return pageSource; } private DeleteManager getDeleteManager(PartitionSpec partitionSpec, PartitionData partitionData) @@ -474,6 +473,8 @@ private ConnectorPageSource openDeletes( 0, delete.fileSizeInBytes(), delete.fileSizeInBytes(), + 0, + "", IcebergFileFormat.fromIceberg(delete.format()), schemaFromHandles(columns), columns, @@ -481,8 
+482,7 @@ private ConnectorPageSource openDeletes( Optional.empty(), "", ImmutableMap.of()) - .readerPageSource() - .get(); + .pageSource(); } private ReaderPageSourceWithRowPositions createDataPageSource( @@ -491,6 +491,8 @@ private ReaderPageSourceWithRowPositions createDataPageSource( long start, long length, long fileSize, + int partitionSpecId, + String partitionData, IcebergFileFormat fileFormat, Schema fileSchema, List dataColumns, @@ -504,6 +506,8 @@ private ReaderPageSourceWithRowPositions createDataPageSource( inputFile, start, length, + partitionSpecId, + partitionData, dataColumns, predicate, orcReaderOptions @@ -525,6 +529,8 @@ private ReaderPageSourceWithRowPositions createDataPageSource( start, length, fileSize, + partitionSpecId, + partitionData, dataColumns, parquetReaderOptions .withMaxReadBlockSize(getParquetMaxReadBlockSize(session)) @@ -544,6 +550,8 @@ private ReaderPageSourceWithRowPositions createDataPageSource( inputFile, start, length, + partitionSpecId, + partitionData, fileSchema, nameMapping, partition, @@ -590,6 +598,8 @@ private static ReaderPageSourceWithRowPositions createOrcPageSource( TrinoInputFile inputFile, long start, long length, + int partitionSpecId, + String partitionData, List columns, TupleDomain effectivePredicate, OrcReaderOptions options, @@ -612,76 +622,84 @@ private static ReaderPageSourceWithRowPositions createOrcPageSource( .setBloomFiltersEnabled(options.isBloomFiltersEnabled()); Map effectivePredicateDomains = effectivePredicate.getDomains() .orElseThrow(() -> new IllegalArgumentException("Effective predicate is none")); + for (IcebergColumnHandle column : columns) { + for (Map.Entry domainEntry : effectivePredicateDomains.entrySet()) { + IcebergColumnHandle predicateColumn = domainEntry.getKey(); + OrcColumn predicateOrcColumn = fileColumnsByIcebergId.get(predicateColumn.getId()); + if (predicateOrcColumn != null && column.getBaseColumnIdentity().equals(predicateColumn.getBaseColumnIdentity())) { + predicateBuilder.addColumn(predicateOrcColumn.getColumnId(), domainEntry.getValue()); + } + } + } - Optional baseColumnProjections = projectBaseColumns(columns); Map>> projectionsByFieldId = columns.stream() .collect(groupingBy( column -> column.getBaseColumnIdentity().getId(), mapping(IcebergColumnHandle::getPath, toUnmodifiableList()))); - List readBaseColumns = baseColumnProjections - .map(readerColumns -> (List) readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toImmutableList())) - .orElse(columns); - List fileReadColumns = new ArrayList<>(readBaseColumns.size()); - List fileReadTypes = new ArrayList<>(readBaseColumns.size()); - List projectedLayouts = new ArrayList<>(readBaseColumns.size()); - List columnAdaptations = new ArrayList<>(readBaseColumns.size()); - - for (IcebergColumnHandle column : readBaseColumns) { - verify(column.isBaseColumn(), "Column projections must be based from a root column"); - OrcColumn orcColumn = fileColumnsByIcebergId.get(column.getId()); + List baseColumns = new ArrayList<>(columns.size()); + Map baseColumnIdToOrdinal = new HashMap<>(); + List fileReadColumns = new ArrayList<>(columns.size()); + List fileReadTypes = new ArrayList<>(columns.size()); + List projectedLayouts = new ArrayList<>(columns.size()); + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + boolean appendRowNumberColumn = false; + for (IcebergColumnHandle column : columns) { if (column.isIsDeletedColumn()) { - 
columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(BOOLEAN, false))); + transforms.constantValue(nativeValueToBlock(BOOLEAN, false)); } else if (partitionKeys.containsKey(column.getId())) { Type trinoType = column.getType(); - columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock( + transforms.constantValue(nativeValueToBlock( trinoType, - deserializePartitionValue(trinoType, partitionKeys.get(column.getId()).orElse(null), column.getName())))); + deserializePartitionValue(trinoType, partitionKeys.get(column.getId()).orElse(null), column.getName()))); } else if (column.isPartitionColumn()) { - columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(PARTITION.getType(), utf8Slice(partition)))); + transforms.constantValue(nativeValueToBlock(PARTITION.getType(), utf8Slice(partition))); } else if (column.isPathColumn()) { - columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(FILE_PATH.getType(), utf8Slice(inputFile.location().toString())))); + transforms.constantValue(nativeValueToBlock(FILE_PATH.getType(), utf8Slice(inputFile.location().toString()))); } else if (column.isFileModifiedTimeColumn()) { - columnAdaptations.add(ColumnAdaptation.constantColumn(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(inputFile.lastModified().toEpochMilli(), UTC_KEY)))); + transforms.constantValue(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(inputFile.lastModified().toEpochMilli(), UTC_KEY))); } else if (column.isMergeRowIdColumn()) { - // The merge $row_id is a composite of the row position and constant file information. The final value is assembled in IcebergPageSource - columnAdaptations.add(ColumnAdaptation.positionColumn()); + appendRowNumberColumn = true; + transforms.transform(MergeRowIdTransform.create(utf8Slice(inputFile.location().toString()), partitionSpecId, utf8Slice(partitionData))); } else if (column.isRowPositionColumn()) { - columnAdaptations.add(ColumnAdaptation.positionColumn()); + appendRowNumberColumn = true; + transforms.transform(new GetRowPositionFromSource()); } - else if (orcColumn != null) { - Type readType = getOrcReadType(column.getType(), typeManager); - - if (column.getType() == UUID && !"UUID".equals(orcColumn.getAttributes().get(ICEBERG_BINARY_TYPE))) { - throw new TrinoException(ICEBERG_BAD_DATA, format("Expected ORC column for UUID data to be annotated with %s=UUID: %s", ICEBERG_BINARY_TYPE, orcColumn)); + else if (!fileColumnsByIcebergId.containsKey(column.getBaseColumnIdentity().getId())) { + transforms.constantValue(column.getType().createNullBlock()); + } + else { + IcebergColumnHandle baseColumn = column.getBaseColumn(); + Integer ordinal = baseColumnIdToOrdinal.get(baseColumn.getId()); + if (ordinal == null) { + ordinal = baseColumns.size(); + baseColumns.add(baseColumn); + baseColumnIdToOrdinal.put(baseColumn.getId(), ordinal); + + OrcColumn orcBaseColumn = requireNonNull(fileColumnsByIcebergId.get(baseColumn.getId())); + fileReadColumns.add(orcBaseColumn); + fileReadTypes.add(getOrcReadType(baseColumn.getType(), typeManager)); + projectedLayouts.add(IcebergOrcProjectedLayout.createProjectedLayout( + orcBaseColumn, + projectionsByFieldId.get(baseColumn.getId()))); } - List> fieldIdProjections = projectionsByFieldId.get(column.getId()); - ProjectedLayout projectedLayout = IcebergOrcProjectedLayout.createProjectedLayout(orcColumn, fieldIdProjections); - - int sourceIndex = fileReadColumns.size(); - 
columnAdaptations.add(ColumnAdaptation.sourceColumn(sourceIndex)); - fileReadColumns.add(orcColumn); - fileReadTypes.add(readType); - projectedLayouts.add(projectedLayout); - - for (Map.Entry domainEntry : effectivePredicateDomains.entrySet()) { - IcebergColumnHandle predicateColumn = domainEntry.getKey(); - OrcColumn predicateOrcColumn = fileColumnsByIcebergId.get(predicateColumn.getId()); - if (predicateOrcColumn != null && column.getColumnIdentity().equals(predicateColumn.getBaseColumnIdentity())) { - predicateBuilder.addColumn(predicateOrcColumn.getColumnId(), domainEntry.getValue()); - } + if (column.isBaseColumn()) { + transforms.column(ordinal); + } + else { + transforms.dereferenceField(ImmutableList.builder() + .add(ordinal) + .addAll(applyProjection(column, baseColumn)) + .build()); } - } - else { - columnAdaptations.add(ColumnAdaptation.nullColumn(column.getType())); } } @@ -691,6 +709,7 @@ else if (orcColumn != null) { fileReadColumns, fileReadTypes, projectedLayouts, + appendRowNumberColumn, predicateBuilder.build(), start, length, @@ -698,20 +717,21 @@ else if (orcColumn != null) { memoryUsage, INITIAL_BATCH_SIZE, exception -> handleException(orcDataSourceId, exception), - new IdBasedFieldMapperFactory(readBaseColumns)); + new IdBasedFieldMapperFactory(baseColumns)); + + ConnectorPageSource pageSource = new OrcPageSource( + recordReader, + orcDataSource, + Optional.empty(), + Optional.empty(), + memoryUsage, + stats, + reader.getCompressionKind()); + + pageSource = transforms.build(pageSource); return new ReaderPageSourceWithRowPositions( - new ReaderPageSource( - new OrcPageSource( - recordReader, - columnAdaptations, - orcDataSource, - Optional.empty(), - Optional.empty(), - memoryUsage, - stats, - reader.getCompressionKind()), - baseColumnProjections), + pageSource, recordReader.getStartRowPosition(), recordReader.getEndRowPosition()); } @@ -862,7 +882,9 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( long start, long length, long fileSize, - List regularColumns, + int partitionSpecId, + String partitionData, + List columns, ParquetReaderOptions options, TupleDomain effectivePredicate, FileFormatDataSourceStats fileFormatDataSourceStats, @@ -884,98 +906,101 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource( } // Mapping from Iceberg field ID to Parquet fields. - Map parquetIdToField = createParquetIdToFieldMapping(fileSchema); - - Optional baseColumnProjections = projectBaseColumns(regularColumns); - List readBaseColumns = baseColumnProjections - .map(readerColumns -> (List) readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toImmutableList())) - .orElse(regularColumns); - - List parquetFields = readBaseColumns.stream() - .map(column -> parquetIdToField.get(column.getId())) - .toList(); + Map parquetIdToFieldName = createParquetIdToFieldMapping(fileSchema); - MessageType requestedSchema = getMessageType(regularColumns, fileSchema.getName(), parquetIdToField); + MessageType requestedSchema = getMessageType(columns, fileSchema.getName(), parquetIdToFieldName); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, requestedSchema); TupleDomain parquetTupleDomain = options.isIgnoreStatistics() ? 
TupleDomain.all() : getParquetTupleDomain(descriptorsByPath, effectivePredicate); TupleDomainParquetPredicate parquetPredicate = buildPredicate(requestedSchema, parquetTupleDomain, descriptorsByPath, UTC); - List rowGroups = getFilteredRowGroups( - start, - length, - dataSource, - parquetMetadata, - ImmutableList.of(parquetTupleDomain), - ImmutableList.of(parquetPredicate), - descriptorsByPath, - UTC, - ICEBERG_DOMAIN_COMPACTION_THRESHOLD, - options); - Optional startRowPosition = Optional.empty(); - Optional endRowPosition = Optional.empty(); - if (!rowGroups.isEmpty()) { - startRowPosition = Optional.of(rowGroups.getFirst().fileRowOffset()); - RowGroupInfo lastRowGroup = rowGroups.getLast(); - endRowPosition = Optional.of(lastRowGroup.fileRowOffset() + lastRowGroup.prunedBlockMetadata().getRowCount()); - } - MessageColumnIO messageColumnIO = getColumnIO(fileSchema, requestedSchema); - ParquetPageSource.Builder pageSourceBuilder = ParquetPageSource.builder(); - int parquetSourceChannel = 0; - + Map baseColumnIdToOrdinal = new HashMap<>(); + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + boolean appendRowNumberColumn = false; + int nextOrdinal = 0; ImmutableList.Builder parquetColumnFieldsBuilder = ImmutableList.builder(); - for (int columnIndex = 0; columnIndex < readBaseColumns.size(); columnIndex++) { - IcebergColumnHandle column = readBaseColumns.get(columnIndex); + for (IcebergColumnHandle column : columns) { if (column.isIsDeletedColumn()) { - pageSourceBuilder.addConstantColumn(nativeValueToBlock(BOOLEAN, false)); + transforms.constantValue(nativeValueToBlock(BOOLEAN, false)); } else if (partitionKeys.containsKey(column.getId())) { Type trinoType = column.getType(); - pageSourceBuilder.addConstantColumn(nativeValueToBlock( + transforms.constantValue(nativeValueToBlock( trinoType, deserializePartitionValue(trinoType, partitionKeys.get(column.getId()).orElse(null), column.getName()))); } else if (column.isPartitionColumn()) { - pageSourceBuilder.addConstantColumn(nativeValueToBlock(PARTITION.getType(), utf8Slice(partition))); + transforms.constantValue(nativeValueToBlock(PARTITION.getType(), utf8Slice(partition))); } else if (column.isPathColumn()) { - pageSourceBuilder.addConstantColumn(nativeValueToBlock(FILE_PATH.getType(), utf8Slice(inputFile.location().toString()))); + transforms.constantValue(nativeValueToBlock(FILE_PATH.getType(), utf8Slice(inputFile.location().toString()))); } else if (column.isFileModifiedTimeColumn()) { - pageSourceBuilder.addConstantColumn(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(inputFile.lastModified().toEpochMilli(), UTC_KEY))); + transforms.constantValue(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(inputFile.lastModified().toEpochMilli(), UTC_KEY))); } else if (column.isMergeRowIdColumn()) { - // The merge $row_id is a composite of the row position and constant file information. 
The final value is assembled in IcebergPageSource - pageSourceBuilder.addRowIndexColumn(); + appendRowNumberColumn = true; + transforms.transform(MergeRowIdTransform.create(utf8Slice(inputFile.location().toString()), partitionSpecId, utf8Slice(partitionData))); } else if (column.isRowPositionColumn()) { - pageSourceBuilder.addRowIndexColumn(); + appendRowNumberColumn = true; + transforms.transform(new GetRowPositionFromSource()); + } + else if (!parquetIdToFieldName.containsKey(column.getBaseColumn().getId())) { + transforms.constantValue(column.getType().createNullBlock()); } else { - org.apache.parquet.schema.Type parquetField = parquetFields.get(columnIndex); - Type trinoType = column.getBaseType(); - if (parquetField == null) { - pageSourceBuilder.addNullColumn(trinoType); - continue; + IcebergColumnHandle baseColumn = column.getBaseColumn(); + Integer ordinal = baseColumnIdToOrdinal.get(baseColumn.getId()); + if (ordinal == null) { + String parquetFieldName = requireNonNull(parquetIdToFieldName.get(baseColumn.getId())).getName(); + + // The top level columns are already mapped by name/id appropriately. + Optional field = IcebergParquetColumnIOConverter.constructField( + new FieldContext(baseColumn.getType(), baseColumn.getColumnIdentity()), + messageColumnIO.getChild(parquetFieldName)); + if (field.isEmpty()) { + // base column is missing so return a null + transforms.constantValue(column.getType().createNullBlock()); + continue; + } + + ordinal = nextOrdinal; + nextOrdinal++; + baseColumnIdToOrdinal.put(baseColumn.getId(), ordinal); + + parquetColumnFieldsBuilder.add(new Column(parquetFieldName, field.get())); + } + if (column.isBaseColumn()) { + transforms.column(ordinal); } - // The top level columns are already mapped by name/id appropriately. 
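Note: the ordinal bookkeeping above is the same idiom in the ORC, Parquet, and Avro paths of this file: each distinct base column is scheduled for a single read from the file, and every projected sub-field references that read by ordinal plus a dereference path. A condensed sketch of the idiom, with the reader-specific scheduling and transform calls left as comments:

    import java.util.HashMap;
    import java.util.Map;

    Map<Integer, Integer> baseColumnIdToOrdinal = new HashMap<>();
    int nextOrdinal = 0;
    for (IcebergColumnHandle column : columns) {
        IcebergColumnHandle baseColumn = column.getBaseColumn();
        Integer ordinal = baseColumnIdToOrdinal.get(baseColumn.getId());
        if (ordinal == null) {
            ordinal = nextOrdinal++;
            baseColumnIdToOrdinal.put(baseColumn.getId(), ordinal);
            // schedule baseColumn as the ordinal-th column to read from the file
        }
        if (column.isBaseColumn()) {
            // transforms.column(ordinal)
        }
        else {
            // transforms.dereferenceField(ordinal, then the field path within baseColumn)
        }
    }
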
- ColumnIO columnIO = messageColumnIO.getChild(parquetField.getName()); - Optional field = IcebergParquetColumnIOConverter.constructField(new FieldContext(trinoType, column.getColumnIdentity()), columnIO); - if (field.isEmpty()) { - pageSourceBuilder.addNullColumn(trinoType); - continue; + else { + transforms.dereferenceField(ImmutableList.builder() + .add(ordinal) + .addAll(applyProjection(column, baseColumn)) + .build()); } - parquetColumnFieldsBuilder.add(new Column(parquetField.getName(), field.get())); - pageSourceBuilder.addSourceColumn(parquetSourceChannel); - parquetSourceChannel++; } } + List rowGroups = getFilteredRowGroups( + start, + length, + dataSource, + parquetMetadata, + ImmutableList.of(parquetTupleDomain), + ImmutableList.of(parquetPredicate), + descriptorsByPath, + UTC, + ICEBERG_DOMAIN_COMPACTION_THRESHOLD, + options); + ParquetDataSourceId dataSourceId = dataSource.getId(); ParquetReader parquetReader = new ParquetReader( Optional.ofNullable(fileMetaData.getCreatedBy()), parquetColumnFieldsBuilder.build(), + appendRowNumberColumn, rowGroups, dataSource, UTC, @@ -984,10 +1009,20 @@ else if (column.isRowPositionColumn()) { exception -> handleException(dataSourceId, exception), Optional.empty(), Optional.empty()); + + ConnectorPageSource pageSource = new ParquetPageSource(parquetReader); + pageSource = transforms.build(pageSource); + + Optional startRowPosition = Optional.empty(); + Optional endRowPosition = Optional.empty(); + if (!rowGroups.isEmpty()) { + startRowPosition = Optional.of(rowGroups.getFirst().fileRowOffset()); + RowGroupInfo lastRowGroup = rowGroups.getLast(); + endRowPosition = Optional.of(lastRowGroup.fileRowOffset() + lastRowGroup.prunedBlockMetadata().getRowCount()); + } + return new ReaderPageSourceWithRowPositions( - new ReaderPageSource( - pageSourceBuilder.build(parquetReader), - baseColumnProjections), + pageSource, startRowPosition, endRowPosition); } @@ -1040,10 +1075,7 @@ else if (type instanceof GroupType groupType) { private static MessageType getMessageType(List regularColumns, String fileSchemaName, Map parquetIdToField) { - return projectSufficientColumns(regularColumns) - .map(readerColumns -> readerColumns.get().stream().map(IcebergColumnHandle.class::cast).toList()) - .orElse(regularColumns) - .stream() + return projectSufficientColumns(regularColumns).stream() .map(column -> getColumnType(column, parquetIdToField)) .filter(Optional::isPresent) .map(Optional::get) @@ -1056,24 +1088,17 @@ private static ReaderPageSourceWithRowPositions createAvroPageSource( TrinoInputFile inputFile, long start, long length, + int partitionSpecId, + String partitionData, Schema fileSchema, Optional nameMapping, String partition, List columns) { - ConstantPopulatingPageSource.Builder constantPopulatingPageSourceBuilder = ConstantPopulatingPageSource.builder(); - int avroSourceChannel = 0; - - Optional baseColumnProjections = projectBaseColumns(columns); - - List readBaseColumns = baseColumnProjections - .map(readerColumns -> (List) readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toImmutableList())) - .orElse(columns); - InputFile file = new ForwardingInputFile(inputFile); OptionalLong fileModifiedTime = OptionalLong.empty(); try { - if (readBaseColumns.stream().anyMatch(IcebergColumnHandle::isFileModifiedTimeColumn)) { + if (columns.stream().anyMatch(IcebergColumnHandle::isFileModifiedTimeColumn)) { fileModifiedTime = OptionalLong.of(inputFile.lastModified().toEpochMilli()); } } @@ -1095,55 +1120,70 @@ private static 
ReaderPageSourceWithRowPositions createAvroPageSource( ImmutableList.Builder columnNames = ImmutableList.builder(); ImmutableList.Builder columnTypes = ImmutableList.builder(); - ImmutableList.Builder rowIndexChannels = ImmutableList.builder(); - - for (IcebergColumnHandle column : readBaseColumns) { - verify(column.isBaseColumn(), "Column projections must be based from a root column"); - org.apache.avro.Schema.Field field = fileColumnsByIcebergId.get(column.getId()); + TransformConnectorPageSource.Builder transforms = TransformConnectorPageSource.builder(); + boolean appendRowNumberColumn = false; + Map baseColumnIdToOrdinal = new HashMap<>(); + int nextOrdinal = 0; + for (IcebergColumnHandle column : columns) { if (column.isPartitionColumn()) { - constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(PARTITION.getType(), utf8Slice(partition))); + transforms.constantValue(nativeValueToBlock(PARTITION.getType(), utf8Slice(partition))); } else if (column.isPathColumn()) { - constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(FILE_PATH.getType(), utf8Slice(file.location()))); + transforms.constantValue(nativeValueToBlock(FILE_PATH.getType(), utf8Slice(file.location()))); } else if (column.isFileModifiedTimeColumn()) { - constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(fileModifiedTime.orElseThrow(), UTC_KEY))); + transforms.constantValue(nativeValueToBlock(FILE_MODIFIED_TIME.getType(), packDateTimeWithZone(fileModifiedTime.orElseThrow(), UTC_KEY))); + } + else if (column.isMergeRowIdColumn()) { + appendRowNumberColumn = true; + transforms.transform(MergeRowIdTransform.create(utf8Slice(file.location()), partitionSpecId, utf8Slice(partitionData))); } - // For delete - else if (column.isMergeRowIdColumn() || column.isRowPositionColumn()) { - // The merge $row_id is a composite of the row position and constant file information. 
The final value is assembled in IcebergPageSource - rowIndexChannels.add(true); - columnNames.add(ROW_POSITION.name()); - columnTypes.add(BIGINT); - constantPopulatingPageSourceBuilder.addDelegateColumn(avroSourceChannel); - avroSourceChannel++; + else if (column.isRowPositionColumn()) { + appendRowNumberColumn = true; + transforms.transform(new GetRowPositionFromSource()); } - else if (field == null) { - constantPopulatingPageSourceBuilder.addConstantColumn(nativeValueToBlock(column.getType(), null)); + else if (!fileColumnsByIcebergId.containsKey(column.getBaseColumn().getId())) { + transforms.constantValue(nativeValueToBlock(column.getType(), null)); } else { - rowIndexChannels.add(false); - columnNames.add(column.getName()); - columnTypes.add(column.getType()); - constantPopulatingPageSourceBuilder.addDelegateColumn(avroSourceChannel); - avroSourceChannel++; + IcebergColumnHandle baseColumn = column.getBaseColumn(); + Integer ordinal = baseColumnIdToOrdinal.get(baseColumn.getId()); + if (ordinal == null) { + ordinal = nextOrdinal; + nextOrdinal++; + baseColumnIdToOrdinal.put(baseColumn.getId(), ordinal); + + columnNames.add(baseColumn.getName()); + columnTypes.add(baseColumn.getType()); + } + + if (column.isBaseColumn()) { + transforms.column(ordinal); + } + else { + transforms.dereferenceField(ImmutableList.builder() + .add(ordinal) + .addAll(applyProjection(column, baseColumn)) + .build()); + } } } + ConnectorPageSource pageSource = new IcebergAvroPageSource( + file, + start, + length, + fileSchema, + nameMapping, + columnNames.build(), + columnTypes.build(), + appendRowNumberColumn, + newSimpleAggregatedMemoryContext()); + pageSource = transforms.build(pageSource); + return new ReaderPageSourceWithRowPositions( - new ReaderPageSource( - constantPopulatingPageSourceBuilder.build(new IcebergAvroPageSource( - file, - start, - length, - fileSchema, - nameMapping, - columnNames.build(), - columnTypes.build(), - rowIndexChannels.build(), - newSimpleAggregatedMemoryContext())), - baseColumnProjections), + pageSource, Optional.empty(), Optional.empty()); } @@ -1246,52 +1286,17 @@ public ProjectedLayout getFieldLayout(OrcColumn orcColumn) } } - /** - * Creates a mapping between the input {@code columns} and base columns if required. - */ - private static Optional projectBaseColumns(List columns) - { - requireNonNull(columns, "columns is null"); - - // No projection is required if all columns are base columns - if (columns.stream().allMatch(IcebergColumnHandle::isBaseColumn)) { - return Optional.empty(); - } - - ImmutableList.Builder projectedColumns = ImmutableList.builder(); - ImmutableList.Builder outputColumnMapping = ImmutableList.builder(); - Map mappedFieldIds = new HashMap<>(); - int projectedColumnCount = 0; - - for (IcebergColumnHandle column : columns) { - int baseColumnId = column.getBaseColumnIdentity().getId(); - Integer mapped = mappedFieldIds.get(baseColumnId); - - if (mapped == null) { - projectedColumns.add(column.getBaseColumn()); - mappedFieldIds.put(baseColumnId, projectedColumnCount); - outputColumnMapping.add(projectedColumnCount); - projectedColumnCount++; - } - else { - outputColumnMapping.add(mapped); - } - } - - return Optional.of(new ReaderColumns(projectedColumns.build(), outputColumnMapping.build())); - } - /** * Creates a set of sufficient columns for the input projected columns and prepares a mapping between the two. * For example, if input columns include columns "a.b" and "a.b.c", then they will be projected * from a single column "a.b". 
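To make the javadoc example concrete, here is a minimal standalone sketch of the sufficient-column idea, with dereference paths modeled as lists of field ordinals (so "a.b" becomes [1] and "a.b.c" becomes [1, 2]). The representation and names are assumptions for illustration, not the Iceberg implementation; the real code additionally records a mapping from each input column back to its chosen sufficient column.

import java.util.ArrayList;
import java.util.List;

public class SufficientColumnsSketch
{
    // Keep only the shortest prefixes: any path that extends an already-kept
    // path can be read from that shorter ("sufficient") column.
    public static List<List<Integer>> projectSufficient(List<List<Integer>> paths)
    {
        List<List<Integer>> sufficient = new ArrayList<>();
        for (List<Integer> candidate : paths) {
            // drop previously kept paths that the candidate is a prefix of
            sufficient.removeIf(picked -> isPrefix(candidate, picked));
            // keep the candidate only if no kept path is a prefix of it
            if (sufficient.stream().noneMatch(picked -> isPrefix(picked, candidate))) {
                sufficient.add(candidate);
            }
        }
        return sufficient;
    }

    private static boolean isPrefix(List<Integer> prefix, List<Integer> path)
    {
        return path.size() >= prefix.size() && path.subList(0, prefix.size()).equals(prefix);
    }

    public static void main(String[] args)
    {
        // "a.b" = [1], "a.b.c" = [1, 2] -> only [1] survives, matching the javadoc example
        System.out.println(projectSufficient(List.of(List.of(1, 2), List.of(1))));
    }
}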
*/ - private static Optional projectSufficientColumns(List columns) + private static List projectSufficientColumns(List columns) { requireNonNull(columns, "columns is null"); if (columns.stream().allMatch(IcebergColumnHandle::isBaseColumn)) { - return Optional.empty(); + return columns; } ImmutableBiMap.Builder dereferenceChainsBuilder = ImmutableBiMap.builder(); @@ -1303,8 +1308,7 @@ private static Optional projectSufficientColumns(List dereferenceChains = dereferenceChainsBuilder.build(); - List sufficientColumns = new ArrayList<>(); - ImmutableList.Builder outputColumnMapping = ImmutableList.builder(); + List sufficientColumns = new ArrayList<>(); Map pickedColumns = new HashMap<>(); @@ -1322,23 +1326,15 @@ private static Optional projectSufficientColumns(List getColumnType(IcebergColumnHandle column, Map parquetIdToField) @@ -1414,11 +1410,14 @@ private static TrinoException handleException(ParquetDataSourceId dataSourceId, return new TrinoException(ICEBERG_CURSOR_ERROR, format("Failed to read Parquet file: %s", dataSourceId), exception); } - public record ReaderPageSourceWithRowPositions(ReaderPageSource readerPageSource, Optional startRowPosition, Optional endRowPosition) + public record ReaderPageSourceWithRowPositions( + ConnectorPageSource pageSource, + Optional startRowPosition, + Optional endRowPosition) { public ReaderPageSourceWithRowPositions { - requireNonNull(readerPageSource, "readerPageSource is null"); + requireNonNull(pageSource, "pageSource is null"); requireNonNull(startRowPosition, "startRowPosition is null"); requireNonNull(endRowPosition, "endRowPosition is null"); } @@ -1477,26 +1476,113 @@ public int hashCode() } } - private record MergeRowIdBlockFactory(VariableWidthBlock filePath, IntArrayBlock partitionSpecId, VariableWidthBlock partitionData) - implements Function + private record MergeRowIdTransform(VariableWidthBlock filePath, IntArrayBlock partitionSpecId, VariableWidthBlock partitionData) + implements Function { - private static Function create(Slice filePath, int partitionSpecId, Slice partitionData) + private static Function create(Slice filePath, int partitionSpecId, Slice partitionData) { - return new MergeRowIdBlockFactory( + return new MergeRowIdTransform( new VariableWidthBlock(1, filePath, new int[] {0, filePath.length()}, Optional.empty()), new IntArrayBlock(1, Optional.empty(), new int[] {partitionSpecId}), new VariableWidthBlock(1, partitionData, new int[] {0, partitionData.length()}, Optional.empty())); } @Override - public RowBlock apply(Block rowPosition) + public Block apply(SourcePage page) { - return RowBlock.fromFieldBlocks(rowPosition.getPositionCount(), new Block[] { + Block rowPosition = page.getBlock(page.getChannelCount() - 1); + Block[] fields = new Block[] { RunLengthEncodedBlock.create(filePath, rowPosition.getPositionCount()), rowPosition, RunLengthEncodedBlock.create(partitionSpecId, rowPosition.getPositionCount()), RunLengthEncodedBlock.create(partitionData, rowPosition.getPositionCount()) - }); + }; + return RowBlock.fromFieldBlocks(rowPosition.getPositionCount(), fields); + } + } + + private record GetRowPositionFromSource() + implements Function + { + @Override + public Block apply(SourcePage page) + { + return page.getBlock(page.getChannelCount() - 1); + } + } + + private record PrefixColumnsSourcePage(SourcePage sourcePage, int channelCount, int[] channels) + implements SourcePage + { + private PrefixColumnsSourcePage + { + requireNonNull(sourcePage, "sourcePage is null"); + checkArgument(channelCount >= 0, "channelCount 
is negative"); + checkArgument(channelCount < sourcePage.getChannelCount(), "channelCount is greater than or equal to sourcePage channel count"); + checkArgument(channels.length == channelCount, "channels length does not match channelCount"); + } + + private PrefixColumnsSourcePage(SourcePage sourcePage, int channelCount) + { + this(sourcePage, channelCount, IntStream.range(0, channelCount).toArray()); + } + + @Override + public int getPositionCount() + { + return sourcePage.getPositionCount(); + } + + @Override + public long getSizeInBytes() + { + return sourcePage.getSizeInBytes(); + } + + @Override + public long getRetainedSizeInBytes() + { + return sourcePage.getRetainedSizeInBytes(); + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer consumer) + { + sourcePage.retainedBytesForEachPart(consumer); + } + + @Override + public int getChannelCount() + { + return channelCount; + } + + @Override + public Block getBlock(int channel) + { + checkIndex(channel, channelCount); + return sourcePage.getBlock(channel); + } + + @Override + public Page getPage() + { + return sourcePage.getColumns(channels); + } + + @Override + public Page getColumns(int[] channels) + { + for (int channel : channels) { + checkIndex(channel, channelCount); + } + return sourcePage.getColumns(channels); + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + sourcePage.selectPositions(positions, offset, size); } } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/EqualityDeleteFilter.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/EqualityDeleteFilter.java index f59c6c6a0b68..f1845de3c987 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/EqualityDeleteFilter.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/EqualityDeleteFilter.java @@ -19,9 +19,9 @@ import com.google.common.util.concurrent.ListenableFutureTask; import io.trino.plugin.iceberg.IcebergColumnHandle; import io.trino.plugin.iceberg.delete.DeleteManager.DeletePageSourceProvider; -import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import org.apache.iceberg.Schema; @@ -116,7 +116,7 @@ private void readEqualityDeletesInternal(DeleteFile deleteFile, List columns, long data { int filePosChannel = rowPositionChannel(columns); return (page, position) -> { - long filePos = BIGINT.getLong(page.getBlock(filePosChannel), position); + Block block = page.getBlock(filePosChannel); + long filePos = BIGINT.getLong(block, position); return !deletedRows.contains(filePos); }; } @@ -66,7 +67,7 @@ public static void readPositionDeletes(ConnectorPageSource pageSource, Slice tar // entries for a single path. The comparison cost is minimal due if the // path values are dictionary encoded, since we only do the comparison once. 
while (!pageSource.isFinished()) { - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (page == null) { continue; } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/RowPredicate.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/RowPredicate.java index 8ce6bbd8773a..8ce3e628d754 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/RowPredicate.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/RowPredicate.java @@ -14,14 +14,14 @@ package io.trino.plugin.iceberg.delete; import com.google.errorprone.annotations.ThreadSafe; -import io.trino.spi.Page; +import io.trino.spi.connector.SourcePage; import static java.util.Objects.requireNonNull; @ThreadSafe public interface RowPredicate { - boolean test(Page page, int position); + boolean test(SourcePage page, int position); default RowPredicate and(RowPredicate other) { @@ -29,7 +29,7 @@ default RowPredicate and(RowPredicate other) return (page, position) -> test(page, position) && other.test(page, position); } - default Page filterPage(Page page) + default void applyFilter(SourcePage page) { int positionCount = page.getPositionCount(); int[] retained = new int[positionCount]; @@ -40,9 +40,8 @@ default Page filterPage(Page page) retainedCount++; } } - if (retainedCount == positionCount) { - return page; + if (retainedCount != positionCount) { + page.selectPositions(retained, 0, retainedCount); } - return page.getPositions(retained, 0, retainedCount); } } diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/TrinoRow.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/TrinoRow.java index 1852b8b5e98d..490ad34cc235 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/TrinoRow.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/delete/TrinoRow.java @@ -13,7 +13,7 @@ */ package io.trino.plugin.iceberg.delete; -import io.trino.spi.Page; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import org.apache.iceberg.StructLike; @@ -27,7 +27,7 @@ final class TrinoRow { private final Object[] values; - public TrinoRow(Type[] types, Page page, int position) + public TrinoRow(Type[] types, SourcePage page, int position) { checkArgument(types.length == page.getChannelCount(), "mismatched types for page"); values = new Object[types.length]; diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/functions/tablechanges/TableChangesFunctionProcessor.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/functions/tablechanges/TableChangesFunctionProcessor.java index d93f661bf231..ba9eca5bf69b 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/functions/tablechanges/TableChangesFunctionProcessor.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/functions/tablechanges/TableChangesFunctionProcessor.java @@ -23,6 +23,7 @@ import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.SourcePage; import io.trino.spi.function.table.TableFunctionProcessorState; import io.trino.spi.function.table.TableFunctionSplitProcessor; import io.trino.spi.predicate.TupleDomain; @@ -153,7 +154,7 @@ public TableFunctionProcessorState process() return FINISHED; } - Page dataPage = pageSource.getNextPage(); + SourcePage dataPage = 
pageSource.getNextSourcePage(); if (dataPage == null) { return TableFunctionProcessorState.Processed.produced(EMPTY_PAGE); } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java index 452018889e3a..0f85b2f457ab 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/IcebergTestUtils.java @@ -43,10 +43,10 @@ import io.trino.plugin.iceberg.catalog.file.FileMetastoreTableOperationsProvider; import io.trino.plugin.iceberg.catalog.hms.TrinoHiveCatalog; import io.trino.plugin.iceberg.fileio.ForwardingInputFile; -import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.catalog.CatalogName; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.TestingTypeManager; import io.trino.spi.type.Type; import io.trino.testing.QueryRunner; @@ -115,14 +115,15 @@ private static boolean checkOrcFileSorting(Supplier dataSourceSup try (OrcRecordReader recordReader = orcReader.createRecordReader( List.of(sortColumn), List.of(sortColumnType), + false, OrcPredicate.TRUE, UTC, newSimpleAggregatedMemoryContext(), INITIAL_BATCH_SIZE, RuntimeException::new)) { Comparable previousMax = null; - for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { - Block block = page.getLoadedPage().getBlock(0); + for (SourcePage page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { + Block block = page.getBlock(0); for (int position = 0; position < block.getPositionCount(); position++) { Comparable current = (Comparable) readNativeValue(sortColumnType, block, position); if (previousMax != null && previousMax.compareTo(current) > 0) { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergNodeLocalDynamicSplitPruning.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergNodeLocalDynamicSplitPruning.java index 01bcecfe8d50..91fb514201a8 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergNodeLocalDynamicSplitPruning.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergNodeLocalDynamicSplitPruning.java @@ -42,6 +42,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.DynamicFilter; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; @@ -183,7 +184,7 @@ public void testDynamicSplitPruningOnUnpartitionedTable() keyColumnHandle, Domain.singleValue(INTEGER, 1L))); try (ConnectorPageSource emptyPageSource = createTestingPageSource(transaction, icebergConfig, split, tableHandle, ImmutableList.of(keyColumnHandle, dataColumnHandle), getDynamicFilter(splitPruningPredicate))) { - assertThat(emptyPageSource.getNextPage()).isNull(); + assertThat(emptyPageSource.getNextSourcePage()).isNull(); } TupleDomain nonSelectivePredicate = TupleDomain.withColumnDomains( @@ -191,7 +192,7 @@ public void testDynamicSplitPruningOnUnpartitionedTable() keyColumnHandle, Domain.singleValue(INTEGER, (long) keyColumnValue))); try (ConnectorPageSource nonEmptyPageSource = createTestingPageSource(transaction, icebergConfig, split, tableHandle, ImmutableList.of(keyColumnHandle, dataColumnHandle), 
getDynamicFilter(nonSelectivePredicate))) { - Page page = nonEmptyPageSource.getNextPage(); + SourcePage page = nonEmptyPageSource.getNextSourcePage(); assertThat(page).isNotNull(); assertThat(page.getPositionCount()).isEqualTo(1); assertThat(INTEGER.getInt(page.getBlock(0), 0)).isEqualTo(keyColumnValue); @@ -240,11 +241,11 @@ public void testDynamicSplitPruningOnUnpartitionedTable() transaction); try (ConnectorPageSource emptyPageSource = createTestingPageSource(transaction, icebergConfig, split, tableHandle, ImmutableList.of(keyColumnHandle, dataColumnHandle), getDynamicFilter(splitPruningPredicate))) { - assertThat(emptyPageSource.getNextPage()).isNull(); + assertThat(emptyPageSource.getNextSourcePage()).isNull(); } try (ConnectorPageSource nonEmptyPageSource = createTestingPageSource(transaction, icebergConfig, split, tableHandle, ImmutableList.of(keyColumnHandle, dataColumnHandle), getDynamicFilter(nonSelectivePredicate))) { - Page page = nonEmptyPageSource.getNextPage(); + SourcePage page = nonEmptyPageSource.getNextSourcePage(); assertThat(page).isNotNull(); assertThat(page.getPositionCount()).isEqualTo(1); assertThat(INTEGER.getInt(page.getBlock(0), 0)).isEqualTo(keyColumnValue); @@ -369,7 +370,7 @@ public void testDynamicSplitPruningWithExplicitPartitionFilter() tableHandle, ImmutableList.of(dateColumnHandle, receiptColumnHandle, amountColumnHandle), getDynamicFilter(partitionPredicate))) { - assertThat(emptyPageSource.getNextPage()).isNull(); + assertThat(emptyPageSource.getNextSourcePage()).isNull(); } } @@ -389,7 +390,7 @@ public void testDynamicSplitPruningWithExplicitPartitionFilter() tableHandle, ImmutableList.of(dateColumnHandle, receiptColumnHandle, amountColumnHandle), getDynamicFilter(partitionPredicate))) { - Page page = nonEmptyPageSource.getNextPage(); + SourcePage page = nonEmptyPageSource.getNextSourcePage(); assertThat(page).isNotNull(); assertThat(page.getPositionCount()).isEqualTo(1); assertThat(INTEGER.getInt(page.getBlock(0), 0)).isEqualTo(dateColumnValue); @@ -532,7 +533,7 @@ public void testDynamicSplitPruningWithExplicitPartitionFilterPartitionEvolution tableHandle, ImmutableList.of(yearColumnHandle, monthColumnHandle, receiptColumnHandle, amountColumnHandle), getDynamicFilter(partitionPredicate))) { - assertThat(emptyPageSource.getNextPage()).isNull(); + assertThat(emptyPageSource.getNextSourcePage()).isNull(); } } @@ -554,7 +555,7 @@ public void testDynamicSplitPruningWithExplicitPartitionFilterPartitionEvolution tableHandle, ImmutableList.of(yearColumnHandle, monthColumnHandle, receiptColumnHandle, amountColumnHandle), getDynamicFilter(partitionPredicate))) { - Page page = nonEmptyPageSource.getNextPage(); + SourcePage page = nonEmptyPageSource.getNextSourcePage(); assertThat(page).isNotNull(); assertThat(page.getPositionCount()).isEqualTo(1); assertThat(INTEGER.getInt(page.getBlock(0), 0)).isEqualTo(2023L); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergPageSource.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergPageSource.java deleted file mode 100644 index 15fe7fda0f9b..000000000000 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergPageSource.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
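Condensed, the assertion pattern used in the pruning tests above: a split whose dynamic filter cannot match yields no source page at all, while a matching filter yields the surviving row. The two page sources here stand for the ones the test builds with createTestingPageSource; this is a sketch of the shape, not the test itself.

import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.connector.SourcePage;

import static org.assertj.core.api.Assertions.assertThat;

final class PruningAssertionSketch
{
    private PruningAssertionSketch() {}

    static void assertPruning(ConnectorPageSource prunedSource, ConnectorPageSource matchingSource)
    {
        // fully pruned split: the source produces no data at all
        assertThat(prunedSource.getNextSourcePage()).isNull();

        // matching split: in these tests exactly one row survives
        SourcePage page = matchingSource.getNextSourcePage();
        assertThat(page).isNotNull();
        assertThat(page.getPositionCount()).isEqualTo(1);
    }
}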
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.iceberg; - -import io.trino.spi.connector.ConnectorPageSource; -import org.junit.jupiter.api.Test; - -import static io.trino.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden; - -public class TestIcebergPageSource -{ - @Test - public void testEverythingImplemented() - { - assertAllMethodsOverridden(ConnectorPageSource.class, IcebergPageSource.class); - } - - @Test - public void testEverythingImplementedConstantPopulatingPageSource() - { - assertAllMethodsOverridden(ConnectorPageSource.class, ConstantPopulatingPageSource.class); - } -} diff --git a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryPageSourceProvider.java b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryPageSourceProvider.java index 8b20790aa60c..29bfe600fbe6 100644 --- a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryPageSourceProvider.java +++ b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryPageSourceProvider.java @@ -27,6 +27,7 @@ import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.DynamicFilter; import io.trino.spi.connector.FixedPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.metrics.Metrics; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; @@ -136,7 +137,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (enableLazyDynamicFiltering && dynamicFilter.isAwaitable()) { return null; @@ -146,14 +147,14 @@ public Page getNextPage() close(); return null; } - Page page = delegate.getNextPage(); + SourcePage page = delegate.getNextSourcePage(); if (page == null) { return null; } completedPositions += page.getPositionCount(); if (!predicate.isAll()) { - page = applyFilter(page, predicate.transformKeys(columns::indexOf).getDomains().get()); + applyFilter(page, predicate.transformKeys(columns::indexOf).getDomains().get()); } rows += page.getPositionCount(); return page; @@ -191,7 +192,7 @@ public Metrics getMetrics() } } - private static Page applyFilter(Page page, Map domains) + private static void applyFilter(SourcePage page, Map domains) { int[] positions = new int[page.getPositionCount()]; int length = 0; @@ -200,10 +201,10 @@ private static Page applyFilter(Page page, Map domains) positions[length++] = position; } } - return page.getPositions(positions, 0, length); + page.selectPositions(positions, 0, length); } - private static boolean positionMatchesPredicate(Page page, int position, Map domains) + private static boolean positionMatchesPredicate(SourcePage page, int position, Map domains) { for (Map.Entry entry : domains.entrySet()) { int channel = entry.getKey(); diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java index c12501c681b2..d853688e1961 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoPageSource.java @@ 
-30,6 +30,7 @@ import io.trino.spi.block.SqlMap; import io.trino.spi.block.SqlRow; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.ArrayType; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; @@ -141,7 +142,7 @@ public long getMemoryUsage() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { verify(pageBuilder.isEmpty()); for (int i = 0; i < ROWS_PER_REQUEST; i++) { @@ -161,7 +162,7 @@ public Page getNextPage() Page page = pageBuilder.build(); pageBuilder.reset(); - return page; + return SourcePage.create(page); } private void appendTo(Type type, Object value, BlockBuilder output) diff --git a/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/CountQueryPageSource.java b/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/CountQueryPageSource.java index aeacfd87e313..d31ac749ba14 100644 --- a/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/CountQueryPageSource.java +++ b/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/CountQueryPageSource.java @@ -14,8 +14,8 @@ package io.trino.plugin.opensearch; import io.trino.plugin.opensearch.client.OpenSearchClient; -import io.trino.spi.Page; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -59,12 +59,12 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { int batch = toIntExact(Math.min(BATCH_SIZE, remaining)); remaining -= batch; - return new Page(batch); + return SourcePage.create(batch); } @Override diff --git a/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/PassthroughQueryPageSource.java b/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/PassthroughQueryPageSource.java index 6e18b85551d1..7eb951bda41b 100644 --- a/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/PassthroughQueryPageSource.java +++ b/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/PassthroughQueryPageSource.java @@ -13,13 +13,12 @@ */ package io.trino.plugin.opensearch; -import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.trino.plugin.opensearch.client.OpenSearchClient; -import io.trino.spi.Page; -import io.trino.spi.PageBuilder; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import java.io.IOException; @@ -62,7 +61,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (done) { return null; @@ -70,11 +69,10 @@ public Page getNextPage() done = true; - PageBuilder page = new PageBuilder(1, ImmutableList.of(VARCHAR)); - page.declarePosition(); - BlockBuilder column = page.getBlockBuilder(0); - VARCHAR.writeSlice(column, Slices.utf8Slice(result)); - return page.build(); + Slice slice = Slices.utf8Slice(result); + BlockBuilder column = VARCHAR.createBlockBuilder(null, 1, result.length()); + VARCHAR.writeSlice(column, slice); + return SourcePage.create(column.build()); } @Override diff --git a/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/ScanQueryPageSource.java b/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/ScanQueryPageSource.java index 9d05384429aa..e78999ec910e 100644 --- 
a/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/ScanQueryPageSource.java +++ b/plugin/trino-opensearch/src/main/java/io/trino/plugin/opensearch/ScanQueryPageSource.java @@ -23,6 +23,7 @@ import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; @@ -150,7 +151,7 @@ public void close() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { long size = 0; while (size < PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES && iterator.hasNext()) { @@ -186,7 +187,7 @@ public Page getNextPage() columnBuilders[i] = columnBuilders[i].newBlockBuilderLike(null); } - return new Page(blocks); + return SourcePage.create(new Page(blocks)); } private static Map resolveField(Map document, OpenSearchColumnHandle columnHandle) diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java index 7f5144f5cdf9..e7c688898488 100644 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotBrokerPageSource.java @@ -24,6 +24,7 @@ import io.trino.spi.block.PageBuilderStatus; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.SourcePage; import java.util.Arrays; import java.util.Iterator; @@ -98,7 +99,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (finished) { return null; @@ -133,9 +134,9 @@ public Page getNextPage() columnBuilders[i] = columnBuilders[i].newBlockBuilderLike(null); } if (decoders.isEmpty()) { - return new Page(rowCount); + return SourcePage.create(rowCount); } - return new Page(blocks); + return SourcePage.create(new Page(rowCount, blocks)); } @Override diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java index b3bbe9bf7895..d372bd39991c 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotSegmentPageSource.java @@ -19,12 +19,12 @@ import io.trino.plugin.pinot.client.PinotDataFetcher; import io.trino.plugin.pinot.client.PinotDataTableWithSize; import io.trino.plugin.pinot.conversion.PinotTimestamps; -import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.StandardTypes; import io.trino.spi.type.TimestampType; import io.trino.spi.type.Type; @@ -111,7 +111,7 @@ public boolean isFinished() * @return constructed page for pinot data. 
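For page sources that materialize rows eagerly, as the OpenSearch and Pinot sources in these hunks do, the migration is mechanical: build the Page exactly as before and wrap it on the way out with SourcePage.create(Page). A minimal sketch of that shape, assuming a single BIGINT column; the class and method names are illustrative only.

import java.util.List;

import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.connector.SourcePage;

import static io.trino.spi.type.BigintType.BIGINT;

class EagerPageSketch
{
    private final PageBuilder pageBuilder = new PageBuilder(List.of(BIGINT));

    SourcePage nextSourcePage(long[] values)
    {
        for (long value : values) {
            pageBuilder.declarePosition();
            BIGINT.writeLong(pageBuilder.getBlockBuilder(0), value);
        }
        Page page = pageBuilder.build();
        pageBuilder.reset();
        // SourcePage.create(Page) wraps an already-loaded page, as in the diffs above
        return SourcePage.create(page);
    }
}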
*/ @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (isFinished()) { close(); @@ -153,7 +153,7 @@ public Page getNextPage() } } - return pageBuilder.build(); + return SourcePage.create(pageBuilder.build()); } private static Map buildColumnIdToNullRowId(DataTable dataTable, List columnHandles) diff --git a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java index ac61ad881a80..5f4fa8f2f54d 100644 --- a/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java +++ b/plugin/trino-pinot/src/test/java/io/trino/plugin/pinot/TestBrokerQueries.java @@ -18,8 +18,8 @@ import io.trino.plugin.pinot.client.PinotClient.BrokerResultRow; import io.trino.plugin.pinot.client.PinotClient.ResultsIterator; import io.trino.plugin.pinot.query.PinotQueryInfo; -import io.trino.spi.Page; import io.trino.spi.block.Block; +import io.trino.spi.connector.SourcePage; import org.apache.pinot.common.response.broker.BrokerResponseNative; import org.apache.pinot.common.response.broker.ResultTable; import org.apache.pinot.common.utils.DataSchema; @@ -115,7 +115,7 @@ public void testBrokerQuery() testingPinotClient, LIMIT_FOR_BROKER_QUERIES); - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); assertThat(page.getChannelCount()).isEqualTo(columnHandles.size()); assertThat(page.getPositionCount()).isEqualTo(RESPONSE.getResultTable().getRows().size()); Block block = page.getBlock(0); @@ -136,7 +136,7 @@ public void testCountStarBrokerQuery() ImmutableList.of(), testingPinotClient, LIMIT_FOR_BROKER_QUERIES); - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); assertThat(page.getPositionCount()).isEqualTo(RESPONSE.getResultTable().getRows().size()); assertThat(page.getChannelCount()).isEqualTo(0); } @@ -169,7 +169,7 @@ public void testBrokerResponseHasTooManyRows() testingPinotClient, LIMIT_FOR_BROKER_QUERIES); assertThatExceptionOfType(PinotException.class) - .isThrownBy(pageSource::getNextPage) + .isThrownBy(pageSource::getNextSourcePage) .withMessage("Broker query returned '3' rows, maximum allowed is '2' rows. 
with query \"SELECT col_1, col_2, col_3 FROM test_table\""); } } diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java index ff14082a2878..2254aa9f7b0d 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftPageSourceProvider.java @@ -121,6 +121,7 @@ private ParquetReader parquetReader(TrinoInputFile inputFile, List return new ParquetReader( Optional.ofNullable(parquetMetadata.getFileMetaData().getCreatedBy()), fields, + false, rowGroupInfoBuilder.build(), dataSource, timeZone, diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftParquetPageSource.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftParquetPageSource.java index 3371f28c6e98..b1d7c812c0ee 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftParquetPageSource.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftParquetPageSource.java @@ -15,9 +15,9 @@ import io.trino.parquet.ParquetCorruptionException; import io.trino.parquet.reader.ParquetReader; -import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.metrics.Metrics; import java.io.IOException; @@ -66,9 +66,9 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { - Page page; + SourcePage page; try { page = parquetReader.nextPage(); } diff --git a/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java b/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java index 3d7f753d855f..59817adce73e 100644 --- a/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java +++ b/plugin/trino-thrift-testing-server/src/main/java/io/trino/plugin/thrift/server/ThriftTpchService.java @@ -34,9 +34,9 @@ import io.trino.plugin.thrift.api.TrinoThriftTableMetadata; import io.trino.plugin.thrift.api.TrinoThriftTupleDomain; import io.trino.plugin.tpch.DecimalTypeMapping; -import io.trino.spi.Page; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.RecordPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import io.trino.tpch.TpchColumn; @@ -238,9 +238,9 @@ private static TrinoThriftPageResult getRowsInternal(ConnectorPageSource pageSou int skipPages = nextToken != null ? Ints.fromByteArray(nextToken.getId()) : 0; skipPages(pageSource, skipPages); - Page page = null; + SourcePage page = null; while (!pageSource.isFinished() && page == null) { - page = pageSource.getNextPage(); + page = pageSource.getNextSourcePage(); skipPages++; } TrinoThriftId newNextToken = pageSource.isFinished() ? 
null : new TrinoThriftId(Ints.toByteArray(skipPages)); @@ -248,7 +248,7 @@ private static TrinoThriftPageResult getRowsInternal(ConnectorPageSource pageSou return toThriftPage(page, types(tableName, columnNames), newNextToken); } - private static TrinoThriftPageResult toThriftPage(Page page, List columnTypes, @Nullable TrinoThriftId nextToken) + private static TrinoThriftPageResult toThriftPage(SourcePage page, List columnTypes, @Nullable TrinoThriftId nextToken) { if (page == null) { checkState(nextToken == null, "there must be no more data when page is null"); @@ -267,7 +267,7 @@ private static void skipPages(ConnectorPageSource pageSource, int skipPages) { for (int i = 0; i < skipPages; i++) { checkState(!pageSource.isFinished(), "pageSource is unexpectedly finished"); - pageSource.getNextPage(); + pageSource.getNextSourcePage(); } } diff --git a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java index 318ece1ad96c..6388a2028dc2 100644 --- a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java +++ b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftIndexPageSource.java @@ -28,6 +28,7 @@ import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.RecordSet; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; @@ -167,7 +168,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (finished) { return null; @@ -214,7 +215,7 @@ public Page getNextPage() // can get more data sendDataRequest(resultContext, pageResult.getNextToken()); updateSignalAndStatusFutures(); - return page; + return SourcePage.create(page); } // are there more splits available @@ -233,7 +234,10 @@ else if (!dataRequests.isEmpty()) { statusFuture = null; finished = true; } - return page; + if (page == null) { + return null; + } + return SourcePage.create(page); } private boolean loadAllSplits() diff --git a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftPageSource.java b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftPageSource.java index bd36318bdae9..c8ee2bf924d3 100644 --- a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftPageSource.java +++ b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftPageSource.java @@ -24,6 +24,7 @@ import io.trino.spi.Page; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import java.util.List; @@ -125,7 +126,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (future == null) { // no data request in progress @@ -152,7 +153,10 @@ public Page getNextPage() future = null; } - return result; + if (result == null) { + return null; + } + return SourcePage.create(result); } private static boolean canGetMoreData(TrinoThriftId nextToken) diff --git a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/TestThriftIndexPageSource.java b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/TestThriftIndexPageSource.java index 016f8fddcba1..83d374818b24 100644 --- a/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/TestThriftIndexPageSource.java +++ 
b/plugin/trino-thrift/src/test/java/io/trino/plugin/thrift/TestThriftIndexPageSource.java @@ -30,10 +30,10 @@ import io.trino.plugin.thrift.api.TrinoThriftSplitBatch; import io.trino.plugin.thrift.api.TrinoThriftTupleDomain; import io.trino.plugin.thrift.api.datatypes.TrinoThriftInteger; -import io.trino.spi.Page; import io.trino.spi.block.Block; import io.trino.spi.connector.InMemoryRecordSet; import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import org.junit.jupiter.api.Test; @@ -93,7 +93,7 @@ public ListenableFuture getRows(TrinoThriftId splitId, Li MAX_BYTES_PER_RESPONSE, lookupRequestsConcurrency); - assertThat(pageSource.getNextPage()).isNull(); + assertThat(pageSource.getNextSourcePage()).isNull(); assertThat((long) stats.getIndexPageSize().getAllTime().getTotal()).isEqualTo(0); signals.get(0).await(1, SECONDS); signals.get(1).await(1, SECONDS); @@ -110,12 +110,12 @@ public ListenableFuture getRows(TrinoThriftId splitId, Li // at this point first two requests were sent assertThat(pageSource.isFinished()).isFalse(); - assertThat(pageSource.getNextPage()).isNull(); + assertThat(pageSource.getNextSourcePage()).isNull(); assertThat((long) stats.getIndexPageSize().getAllTime().getTotal()).isEqualTo(0); // completing the second request futures.get(1).set(pageResult(20, null)); - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); pageSizeReceived += page.getSizeInBytes(); assertThat((long) stats.getIndexPageSize().getAllTime().getTotal()).isEqualTo(pageSizeReceived); assertThat(page).isNotNull(); @@ -132,7 +132,7 @@ public ListenableFuture getRows(TrinoThriftId splitId, Li // completing the first request futures.get(0).set(pageResult(10, null)); - page = pageSource.getNextPage(); + page = pageSource.getNextSourcePage(); assertThat(page).isNotNull(); pageSizeReceived += page.getSizeInBytes(); assertThat((long) stats.getIndexPageSize().getAllTime().getTotal()).isEqualTo(pageSizeReceived); @@ -143,7 +143,7 @@ public ListenableFuture getRows(TrinoThriftId splitId, Li // completing the third request futures.get(2).set(pageResult(30, null)); - page = pageSource.getNextPage(); + page = pageSource.getNextSourcePage(); assertThat(page).isNotNull(); pageSizeReceived += page.getSizeInBytes(); assertThat((long) stats.getIndexPageSize().getAllTime().getTotal()).isEqualTo(pageSizeReceived); @@ -153,7 +153,7 @@ public ListenableFuture getRows(TrinoThriftId splitId, Li assertThat(pageSource.isFinished()).isTrue(); // after completion - assertThat(pageSource.getNextPage()).isNull(); + assertThat(pageSource.getNextSourcePage()).isNull(); pageSource.close(); } @@ -204,7 +204,7 @@ private static void runGeneralTest(int splits, int lookupRequestsConcurrency, in while (!pageSource.isFinished()) { CompletableFuture blocked = pageSource.isBlocked(); blocked.get(1, SECONDS); - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (page != null) { Block block = page.getBlock(0); for (int position = 0; position < block.getPositionCount(); position++) { @@ -223,7 +223,7 @@ private static void runGeneralTest(int splits, int lookupRequestsConcurrency, in assertThat(actual).isEqualTo(expected); // must be null after finish - assertThat(pageSource.getNextPage()).isNull(); + assertThat(pageSource.getNextSourcePage()).isNull(); pageSource.close(); } diff --git 
a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/LazyRecordPageSource.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSource.java similarity index 68% rename from plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/LazyRecordPageSource.java rename to plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSource.java index a029e92176b5..7dca3dce22ea 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/LazyRecordPageSource.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSource.java @@ -19,13 +19,14 @@ import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.LazyBlock; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.RecordCursor; import io.trino.spi.connector.RecordSet; +import io.trino.spi.connector.SourcePage; import io.trino.spi.type.Type; import java.util.List; +import java.util.function.ObjLongConsumer; import static java.util.Objects.requireNonNull; @@ -33,7 +34,7 @@ * Wraps pages into lazy blocks. This enables counting of materialized bytes * for testing purposes. */ -class LazyRecordPageSource +class TpchPageSource implements ConnectorPageSource { private static final int ROWS_PER_REQUEST = 4096; @@ -43,7 +44,7 @@ class LazyRecordPageSource private final PageBuilder pageBuilder; private boolean closed; - LazyRecordPageSource(int maxRowsPerPage, RecordSet recordSet) + TpchPageSource(int maxRowsPerPage, RecordSet recordSet) { requireNonNull(recordSet, "recordSet is null"); @@ -85,7 +86,7 @@ public boolean isFinished() } @Override - public Page getNextPage() + public SourcePage getNextSourcePage() { if (!closed) { for (int i = 0; i < ROWS_PER_REQUEST && !pageBuilder.isFull() && pageBuilder.getPositionCount() < maxRowsPerPage; i++) { @@ -127,20 +128,78 @@ else if (javaType == Slice.class) { if ((closed && !pageBuilder.isEmpty()) || pageBuilder.isFull() || pageBuilder.getPositionCount() >= maxRowsPerPage) { Page page = pageBuilder.build(); pageBuilder.reset(); - return lazyWrapper(page); + return new TpchSourcePage(page); } return null; } - private Page lazyWrapper(Page page) + private static class TpchSourcePage + implements SourcePage { - Block[] lazyBlocks = new Block[page.getChannelCount()]; - for (int i = 0; i < page.getChannelCount(); ++i) { - Block block = page.getBlock(i); - lazyBlocks[i] = new LazyBlock(page.getPositionCount(), () -> block); + private Page page; + private final boolean[] loaded; + private long sizeInBytes; + + public TpchSourcePage(Page page) + { + this.page = requireNonNull(page, "page is null"); + this.loaded = new boolean[page.getChannelCount()]; } - return new Page(page.getPositionCount(), lazyBlocks); + @Override + public int getPositionCount() + { + return page.getPositionCount(); + } + + @Override + public long getSizeInBytes() + { + return sizeInBytes; + } + + @Override + public long getRetainedSizeInBytes() + { + return page.getRetainedSizeInBytes(); + } + + @Override + public void retainedBytesForEachPart(ObjLongConsumer consumer) + { + for (int i = 0; i < page.getChannelCount(); i++) { + page.getBlock(i).retainedBytesForEachPart(consumer); + } + } + + @Override + public int getChannelCount() + { + return page.getChannelCount(); + } + + @Override + public Block getBlock(int channel) + { + Block block = page.getBlock(channel); + if (!loaded[channel]) { + loaded[channel] = true; + sizeInBytes += block.getSizeInBytes(); + } + return block; + } + + @Override + public Page getPage() 
+ { + return page; + } + + @Override + public void selectPositions(int[] positions, int offset, int size) + { + page = page.getPositions(positions, offset, size); + } } } diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java index 5a15c15a8a0a..444ef9fb9f03 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchPageSourceProvider.java @@ -59,7 +59,7 @@ public ConnectorPageSource createPageSource( TpchSplit tpchSplit = (TpchSplit) split; TpchTableHandle tpchTable = (TpchTableHandle) table; - return new LazyRecordPageSource( + return new TpchPageSource( maxRowsPerPage, getRecordSet( TpchTable.getTable(tpchTable.tableName()), diff --git a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchTables.java b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchTables.java index c8f67320c509..1eb41e7e7ee9 100644 --- a/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchTables.java +++ b/plugin/trino-tpch/src/main/java/io/trino/plugin/tpch/TpchTables.java @@ -17,6 +17,7 @@ import io.trino.spi.Page; import io.trino.spi.connector.ConnectorPageSource; import io.trino.spi.connector.RecordPageSource; +import io.trino.spi.connector.SourcePage; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.Type; import io.trino.tpch.TpchTable; @@ -56,12 +57,12 @@ protected Page computeNext() return endOfData(); } - Page page = pageSource.getNextPage(); + SourcePage page = pageSource.getNextSourcePage(); if (page == null) { return computeNext(); } - return page.getLoadedPage(); + return page.getPage(); } }; }
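The TpchTables hunk above shows the consumer side of the migration: where callers previously forced lazy blocks with getLoadedPage(), a SourcePage consumer calls getPage() once it needs the fully materialized page. A sketch of the resulting drain-to-Page loop, suitable for tests and small inputs; the helper name is illustrative.

import java.util.ArrayList;
import java.util.List;

import io.trino.spi.Page;
import io.trino.spi.connector.ConnectorPageSource;
import io.trino.spi.connector.SourcePage;

final class PageDrain
{
    private PageDrain() {}

    // Materialize every remaining page of a source into memory
    static List<Page> drain(ConnectorPageSource pageSource)
    {
        List<Page> pages = new ArrayList<>();
        while (!pageSource.isFinished()) {
            SourcePage page = pageSource.getNextSourcePage();
            if (page == null) {
                continue; // null means "no page right now", not end of data
            }
            pages.add(page.getPage());
        }
        return pages;
    }
}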