diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSessionProperties.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSessionProperties.java index ebc1118725bf..ea6510f3c173 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSessionProperties.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/HiveSessionProperties.java @@ -55,6 +55,7 @@ public final class HiveSessionProperties private static final String ORC_TINY_STRIPE_THRESHOLD = "orc_tiny_stripe_threshold"; private static final String ORC_MAX_READ_BLOCK_SIZE = "orc_max_read_block_size"; private static final String ORC_LAZY_READ_SMALL_RANGES = "orc_lazy_read_small_ranges"; + private static final String ORC_NESTED_LAZY_ENABLED = "orc_nested_lazy_enabled"; private static final String ORC_STRING_STATISTICS_LIMIT = "orc_string_statistics_limit"; private static final String ORC_OPTIMIZED_WRITER_VALIDATE = "orc_optimized_writer_validate"; private static final String ORC_OPTIMIZED_WRITER_VALIDATE_PERCENTAGE = "orc_optimized_writer_validate_percentage"; @@ -168,6 +169,11 @@ public HiveSessionProperties( "Experimental: ORC: Read small file segments lazily", orcReaderConfig.isLazyReadSmallRanges(), false), + booleanProperty( + ORC_NESTED_LAZY_ENABLED, + "Experimental: ORC: Lazily read nested data", + orcReaderConfig.isNestedLazy(), + false), dataSizeProperty( ORC_STRING_STATISTICS_LIMIT, "ORC: Maximum size of string statistics; drop if exceeding", @@ -385,6 +391,11 @@ public static boolean getOrcLazyReadSmallRanges(ConnectorSession session) return session.getProperty(ORC_LAZY_READ_SMALL_RANGES, Boolean.class); } + public static boolean isOrcNestedLazy(ConnectorSession session) + { + return session.getProperty(ORC_NESTED_LAZY_ENABLED, Boolean.class); + } + public static DataSize getOrcStringStatisticsLimit(ConnectorSession session) { return session.getProperty(ORC_STRING_STATISTICS_LIMIT, DataSize.class); diff --git 
a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSource.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSource.java index fa874b05554e..e6922d126d3d 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSource.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSource.java @@ -13,27 +13,26 @@ */ package io.prestosql.plugin.hive.orc; +import com.google.common.collect.ImmutableList; import io.prestosql.memory.context.AggregatedMemoryContext; import io.prestosql.orc.OrcCorruptionException; import io.prestosql.orc.OrcDataSource; +import io.prestosql.orc.OrcDataSourceId; import io.prestosql.orc.OrcRecordReader; import io.prestosql.plugin.hive.FileFormatDataSourceStats; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.LazyBlock; -import io.prestosql.spi.block.LazyBlockLoader; import io.prestosql.spi.block.RunLengthEncodedBlock; import io.prestosql.spi.connector.ConnectorPageSource; import io.prestosql.spi.type.Type; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Map; +import java.util.List; import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkState; -import static io.prestosql.orc.OrcReader.MAX_BATCH_SIZE; +import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.plugin.hive.HiveErrorCode.HIVE_BAD_DATA; import static io.prestosql.plugin.hive.HiveErrorCode.HIVE_CURSOR_ERROR; import static java.lang.String.format; @@ -43,12 +42,9 @@ public class OrcPageSource implements ConnectorPageSource { private final OrcRecordReader recordReader; + private final List columnAdaptations; private final OrcDataSource orcDataSource; - private final Block[] constantBlocks; - private final int[] hiveColumnIndexes; - - private int batchId; private boolean closed; private final 
AggregatedMemoryContext systemMemoryContext; @@ -57,31 +53,15 @@ public class OrcPageSource public OrcPageSource( OrcRecordReader recordReader, + List columnAdaptations, OrcDataSource orcDataSource, - Map includedColumns, AggregatedMemoryContext systemMemoryContext, FileFormatDataSourceStats stats) { this.recordReader = requireNonNull(recordReader, "recordReader is null"); + this.columnAdaptations = ImmutableList.copyOf(requireNonNull(columnAdaptations, "columnAdaptations is null")); this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null"); - - int size = requireNonNull(includedColumns, "includedColumns is null").size(); - this.stats = requireNonNull(stats, "stats is null"); - - this.constantBlocks = new Block[size]; - this.hiveColumnIndexes = new int[size]; - - int blockIndex = 0; - for (Map.Entry entry : includedColumns.entrySet()) { - hiveColumnIndexes[blockIndex] = entry.getKey(); - if (!recordReader.isColumnPresent(hiveColumnIndexes[blockIndex])) { - Type type = entry.getValue(); - constantBlocks[blockIndex] = RunLengthEncodedBlock.create(type, null, MAX_BATCH_SIZE); - } - blockIndex++; - } - this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -106,37 +86,36 @@ public boolean isFinished() @Override public Page getNextPage() { + Page page; try { - batchId++; - int batchSize = recordReader.nextBatch(); - if (batchSize <= 0) { - close(); - return null; - } - - Block[] blocks = new Block[hiveColumnIndexes.length]; - for (int fieldId = 0; fieldId < blocks.length; fieldId++) { - if (constantBlocks[fieldId] != null) { - blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize); - } - else { - blocks[fieldId] = new LazyBlock(batchSize, new OrcBlockLoader(hiveColumnIndexes[fieldId])); - } - } - return new Page(batchSize, blocks); + page = recordReader.nextPage(); } - catch (PrestoException e) { + catch (IOException | RuntimeException e) { closeWithSuppression(e); - throw e; + throw 
handleException(orcDataSource.getId(), e); } - catch (OrcCorruptionException e) { - closeWithSuppression(e); - throw new PrestoException(HIVE_BAD_DATA, e); + + if (page == null) { + close(); + return null; } - catch (IOException | RuntimeException e) { - closeWithSuppression(e); - throw new PrestoException(HIVE_CURSOR_ERROR, format("Failed to read ORC file: %s", orcDataSource.getId()), e); + + Block[] blocks = new Block[columnAdaptations.size()]; + for (int i = 0; i < columnAdaptations.size(); i++) { + blocks[i] = columnAdaptations.get(i).block(page); } + return new Page(page.getPositionCount(), blocks); + } + + static PrestoException handleException(OrcDataSourceId dataSourceId, Exception exception) + { + if (exception instanceof PrestoException) { + return (PrestoException) exception; + } + if (exception instanceof OrcCorruptionException) { + return new PrestoException(HIVE_BAD_DATA, exception); + } + return new PrestoException(HIVE_CURSOR_ERROR, format("Failed to read ORC file: %s", dataSourceId), exception); } @Override @@ -161,7 +140,8 @@ public void close() public String toString() { return toStringHelper(this) - .add("hiveColumnIndexes", hiveColumnIndexes) + .add("orcDataSource", orcDataSource.getId()) + .add("columns", columnAdaptations) .toString(); } @@ -171,7 +151,7 @@ public long getSystemMemoryUsage() return systemMemoryContext.getBytes(); } - protected void closeWithSuppression(Throwable throwable) + private void closeWithSuppression(Throwable throwable) { requireNonNull(throwable, "throwable is null"); try { @@ -185,36 +165,73 @@ protected void closeWithSuppression(Throwable throwable) } } - private final class OrcBlockLoader - implements LazyBlockLoader + public interface ColumnAdaptation { - private final int expectedBatchId = batchId; - private final int columnIndex; - private boolean loaded; + Block block(Page sourcePage); - public OrcBlockLoader(int columnIndex) + static ColumnAdaptation nullColumn(Type type) { - this.columnIndex = columnIndex; 
+ return new NullColumn(type); + } + + static ColumnAdaptation sourceColumn(int index) + { + return new SourceColumn(index); + } + } + + private static class NullColumn + implements ColumnAdaptation + { + private final Type type; + private final Block nullBlock; + + public NullColumn(Type type) + { + this.type = requireNonNull(type, "type is null"); + this.nullBlock = type.createBlockBuilder(null, 1, 0) + .appendNull() + .build(); } @Override - public final void load(LazyBlock lazyBlock) + public Block block(Page sourcePage) { - checkState(!loaded, "Already loaded"); - checkState(batchId == expectedBatchId); + return new RunLengthEncodedBlock(nullBlock, sourcePage.getPositionCount()); + } - try { - Block block = recordReader.readBlock(columnIndex); - lazyBlock.setBlock(block); - } - catch (OrcCorruptionException e) { - throw new PrestoException(HIVE_BAD_DATA, e); - } - catch (IOException | RuntimeException e) { - throw new PrestoException(HIVE_CURSOR_ERROR, format("Failed to read ORC file: %s", orcDataSource.getId()), e); - } + @Override + public String toString() + { + return toStringHelper(this) + .add("type", type) + .toString(); + } + } + + private static class SourceColumn + implements ColumnAdaptation + { + private final int index; + + public SourceColumn(int index) + { + checkArgument(index >= 0, "index is negative"); + this.index = index; + } - loaded = true; + @Override + public Block block(Page sourcePage) + { + return sourcePage.getBlock(index); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("index", index) + .toString(); } } } diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java index 2c95ff6d55ee..c033cb077177 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java @@ 
-16,22 +16,24 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.memory.context.AggregatedMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcDataSource; import io.prestosql.orc.OrcDataSourceId; -import io.prestosql.orc.OrcPredicate; import io.prestosql.orc.OrcReader; import io.prestosql.orc.OrcReaderOptions; import io.prestosql.orc.OrcRecordReader; import io.prestosql.orc.TupleDomainOrcPredicate; -import io.prestosql.orc.TupleDomainOrcPredicate.ColumnReference; +import io.prestosql.orc.TupleDomainOrcPredicate.TupleDomainOrcPredicateBuilder; import io.prestosql.plugin.hive.FileFormatDataSourceStats; import io.prestosql.plugin.hive.HdfsEnvironment; import io.prestosql.plugin.hive.HiveColumnHandle; import io.prestosql.plugin.hive.HivePageSourceFactory; +import io.prestosql.plugin.hive.orc.OrcPageSource.ColumnAdaptation; import io.prestosql.spi.PrestoException; import io.prestosql.spi.connector.ConnectorPageSource; import io.prestosql.spi.connector.ConnectorSession; import io.prestosql.spi.connector.FixedPageSource; +import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.type.Type; import org.apache.hadoop.conf.Configuration; @@ -46,14 +48,16 @@ import java.io.FileNotFoundException; import java.io.IOException; +import java.util.ArrayList; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Properties; import java.util.regex.Pattern; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.nullToEmpty; +import static com.google.common.collect.Maps.uniqueIndex; import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.prestosql.orc.OrcReader.INITIAL_BATCH_SIZE; import static io.prestosql.plugin.hive.HiveColumnHandle.ColumnType.REGULAR; @@ -67,8 
+71,11 @@ import static io.prestosql.plugin.hive.HiveSessionProperties.getOrcStreamBufferSize; import static io.prestosql.plugin.hive.HiveSessionProperties.getOrcTinyStripeThreshold; import static io.prestosql.plugin.hive.HiveSessionProperties.isOrcBloomFiltersEnabled; +import static io.prestosql.plugin.hive.HiveSessionProperties.isOrcNestedLazy; +import static io.prestosql.plugin.hive.orc.OrcPageSource.handleException; import static io.prestosql.plugin.hive.util.HiveUtil.isDeserializerClass; import static java.lang.String.format; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; public class OrcPageSourceFactory @@ -139,6 +146,7 @@ public Optional createPageSource( .withTinyStripeThreshold(getOrcTinyStripeThreshold(session)) .withMaxReadBlockSize(getOrcMaxReadBlockSize(session)) .withLazyReadSmallRanges(getOrcLazyReadSmallRanges(session)) + .withNestedLazy(isOrcNestedLazy(session)) .withBloomFiltersEnabled(isOrcBloomFiltersEnabled(session)), stats)); } @@ -158,6 +166,11 @@ private static OrcPageSource createOrcPageSource( OrcReaderOptions options, FileFormatDataSourceStats stats) { + for (HiveColumnHandle column : columns) { + checkArgument(column.getColumnType() == REGULAR, "column type must be regular: %s", column); + } + checkArgument(!effectivePredicate.isNone()); + OrcDataSource orcDataSource; try { FileSystem fileSystem = hdfsEnvironment.getFileSystem(sessionUser, path, configuration); @@ -181,34 +194,65 @@ private static OrcPageSource createOrcPageSource( try { OrcReader reader = new OrcReader(orcDataSource, options); - List physicalColumns = getPhysicalHiveColumnHandles(columns, useOrcColumnNames, reader, path); - ImmutableMap.Builder includedColumnsBuilder = ImmutableMap.builder(); - ImmutableList.Builder> columnReferences = ImmutableList.builder(); - for (HiveColumnHandle column : physicalColumns) { - if (column.getColumnType() == REGULAR) { - Type type = column.getType(); - 
includedColumnsBuilder.put(column.getHiveColumnIndex(), type); - columnReferences.add(new ColumnReference<>(column, column.getHiveColumnIndex(), type)); - } + if (useOrcColumnNames) { + verifyFileHasColumnNames(reader.getColumnNames(), path); } - ImmutableMap includedColumns = includedColumnsBuilder.build(); + List fileColumns = reader.getRootColumn().getNestedColumns(); + Map fileColumnsByName = ImmutableMap.of(); + if (useOrcColumnNames) { + // Convert column names read from ORC files to lower case to be consistent with those stored in Hive Metastore + fileColumnsByName = uniqueIndex(fileColumns, orcColumn -> orcColumn.getColumnName().toLowerCase(ENGLISH)); + } + + TupleDomainOrcPredicateBuilder predicateBuilder = TupleDomainOrcPredicate.builder() + .setBloomFiltersEnabled(options.isBloomFiltersEnabled()); + Map effectivePredicateDomains = effectivePredicate.getDomains() + .orElseThrow(() -> new IllegalArgumentException("Effective predicate is none")); + List fileReadColumns = new ArrayList<>(columns.size()); + List fileReadTypes = new ArrayList<>(columns.size()); + List columnAdaptations = new ArrayList<>(columns.size()); + for (HiveColumnHandle column : columns) { + OrcColumn orcColumn = null; + if (useOrcColumnNames) { + orcColumn = fileColumnsByName.get(column.getName().toLowerCase(ENGLISH)); + } + else if (column.getHiveColumnIndex() < fileColumns.size()) { + orcColumn = fileColumns.get(column.getHiveColumnIndex()); + } + + Type readType = column.getType(); + if (orcColumn != null) { + int sourceIndex = fileReadColumns.size(); + columnAdaptations.add(ColumnAdaptation.sourceColumn(sourceIndex)); + fileReadColumns.add(orcColumn); + fileReadTypes.add(readType); - OrcPredicate predicate = new TupleDomainOrcPredicate<>(effectivePredicate, columnReferences.build(), options.isBloomFiltersEnabled()); + Domain domain = effectivePredicateDomains.get(column); + if (domain != null) { + predicateBuilder.addColumn(orcColumn.getColumnId(), domain); + } + } + else { + 
columnAdaptations.add(ColumnAdaptation.nullColumn(readType)); + } + } OrcRecordReader recordReader = reader.createRecordReader( - includedColumns, - predicate, + fileReadColumns, + fileReadTypes, + predicateBuilder.build(), start, length, hiveStorageTimeZone, systemMemoryUsage, - INITIAL_BATCH_SIZE); + INITIAL_BATCH_SIZE, + exception -> handleException(orcDataSource.getId(), exception)); return new OrcPageSource( recordReader, + columnAdaptations, orcDataSource, - includedColumns, systemMemoryUsage, stats); } @@ -234,31 +278,6 @@ private static String splitError(Throwable t, Path path, long start, long length return format("Error opening Hive split %s (offset=%s, length=%s): %s", path, start, length, t.getMessage()); } - private static List getPhysicalHiveColumnHandles(List columns, boolean useOrcColumnNames, OrcReader reader, Path path) - { - if (!useOrcColumnNames) { - return columns; - } - - verifyFileHasColumnNames(reader.getColumnNames(), path); - - Map physicalNameOrdinalMap = buildPhysicalNameOrdinalMap(reader); - int nextMissingColumnIndex = physicalNameOrdinalMap.size(); - - ImmutableList.Builder physicalColumns = ImmutableList.builder(); - for (HiveColumnHandle column : columns) { - Integer physicalOrdinal = physicalNameOrdinalMap.get(column.getName()); - if (physicalOrdinal == null) { - // if the column is missing from the file, assign it a column number larger - // than the number of columns in the file so the reader will fill it with nulls - physicalOrdinal = nextMissingColumnIndex; - nextMissingColumnIndex++; - } - physicalColumns.add(new HiveColumnHandle(column.getName(), column.getHiveType(), column.getType(), physicalOrdinal, column.getColumnType(), column.getComment())); - } - return physicalColumns.build(); - } - private static void verifyFileHasColumnNames(List physicalColumnNames, Path path) { if (!physicalColumnNames.isEmpty() && physicalColumnNames.stream().allMatch(physicalColumnName -> 
DEFAULT_HIVE_COLUMN_NAME_PATTERN.matcher(physicalColumnName).matches())) { @@ -267,18 +286,4 @@ private static void verifyFileHasColumnNames(List physicalColumnNames, P "ORC file does not contain column names in the footer: " + path); } } - - private static Map buildPhysicalNameOrdinalMap(OrcReader reader) - { - ImmutableMap.Builder physicalNameOrdinalMap = ImmutableMap.builder(); - - int ordinal = 0; - for (String physicalColumnName : reader.getColumnNames()) { - // Convert column names read from ORC files to lower case to be consistent with those stored in Hive Metastore - physicalNameOrdinalMap.put(physicalColumnName.toLowerCase(Locale.ENGLISH), ordinal); - ordinal++; - } - - return physicalNameOrdinalMap.build(); - } } diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcReaderConfig.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcReaderConfig.java index 6cc4ece90692..7410a0f355c8 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcReaderConfig.java +++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcReaderConfig.java @@ -136,4 +136,20 @@ public OrcReaderConfig setLazyReadSmallRanges(boolean lazyReadSmallRanges) options = options.withLazyReadSmallRanges(lazyReadSmallRanges); return this; } + + @Deprecated + public boolean isNestedLazy() + { + return options.isNestedLazy(); + } + + // TODO remove config option once efficacy is proven + @Deprecated + @Config("hive.orc.nested-lazy") + @ConfigDescription("ORC lazily read nested data") + public OrcReaderConfig setNestedLazy(boolean nestedLazy) + { + options = options.withNestedLazy(nestedLazy); + return this; + } } diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/util/TempFileReader.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/util/TempFileReader.java index aa0ec283a281..66d8b602cdd4 100644 --- a/presto-hive/src/main/java/io/prestosql/plugin/hive/util/TempFileReader.java +++ 
b/presto-hive/src/main/java/io/prestosql/plugin/hive/util/TempFileReader.java @@ -21,14 +21,11 @@ import io.prestosql.orc.OrcRecordReader; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; -import io.prestosql.spi.block.Block; import io.prestosql.spi.type.Type; import java.io.IOException; import java.io.InterruptedIOException; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.prestosql.orc.OrcReader.INITIAL_BATCH_SIZE; @@ -39,31 +36,25 @@ public class TempFileReader extends AbstractIterator { - private final int columnCount; private final OrcRecordReader reader; public TempFileReader(List types, OrcDataSource dataSource) { requireNonNull(types, "types is null"); - this.columnCount = types.size(); try { OrcReader orcReader = new OrcReader(dataSource, new OrcReaderOptions()); - - Map includedColumns = new HashMap<>(); - for (int i = 0; i < types.size(); i++) { - includedColumns.put(i, types.get(i)); - } - reader = orcReader.createRecordReader( - includedColumns, + orcReader.getRootColumn().getNestedColumns(), + types, OrcPredicate.TRUE, UTC, newSimpleAggregatedMemoryContext(), - INITIAL_BATCH_SIZE); + INITIAL_BATCH_SIZE, + TempFileReader::handleException); } catch (IOException e) { - throw new PrestoException(HIVE_WRITER_DATA_ERROR, "Failed to read temporary data"); + throw handleException(e); } } @@ -75,19 +66,21 @@ protected Page computeNext() throw new InterruptedIOException(); } - int batchSize = reader.nextBatch(); - if (batchSize <= 0) { + Page page = reader.nextPage(); + if (page == null) { return endOfData(); } - Block[] blocks = new Block[columnCount]; - for (int i = 0; i < columnCount; i++) { - blocks[i] = reader.readBlock(i).getLoadedBlock(); - } - return new Page(batchSize, blocks); + // eagerly load the page + return page.getLoadedPage(); } catch (IOException e) { - throw new 
PrestoException(HIVE_WRITER_DATA_ERROR, "Failed to read temporary data"); + throw handleException(e); } } + + private static PrestoException handleException(Exception e) + { + return new PrestoException(HIVE_WRITER_DATA_ERROR, "Failed to read temporary data", e); + } } diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcReaderConfig.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcReaderConfig.java index 036f0b69d19c..f60eeee6863d 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcReaderConfig.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcReaderConfig.java @@ -37,7 +37,8 @@ public void testDefaults() .setStreamBufferSize(new DataSize(8, Unit.MEGABYTE)) .setTinyStripeThreshold(new DataSize(8, Unit.MEGABYTE)) .setMaxBlockSize(new DataSize(16, Unit.MEGABYTE)) - .setLazyReadSmallRanges(true)); + .setLazyReadSmallRanges(true) + .setNestedLazy(true)); } @Test @@ -52,6 +53,7 @@ public void testExplicitPropertyMappings() .put("hive.orc.tiny-stripe-threshold", "61kB") .put("hive.orc.max-read-block-size", "66kB") .put("hive.orc.lazy-read-small-ranges", "false") + .put("hive.orc.nested-lazy", "false") .build(); OrcReaderConfig expected = new OrcReaderConfig() @@ -62,7 +64,8 @@ public void testExplicitPropertyMappings() .setStreamBufferSize(new DataSize(55, Unit.KILOBYTE)) .setTinyStripeThreshold(new DataSize(61, Unit.KILOBYTE)) .setMaxBlockSize(new DataSize(66, Unit.KILOBYTE)) - .setLazyReadSmallRanges(false); + .setLazyReadSmallRanges(false) + .setNestedLazy(false); assertFullMapping(properties, expected); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcBlockFactory.java b/presto-orc/src/main/java/io/prestosql/orc/OrcBlockFactory.java new file mode 100644 index 000000000000..c4b721b49dc5 --- /dev/null +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcBlockFactory.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.orc; + +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.LazyBlock; +import io.prestosql.spi.block.LazyBlockLoader; + +import java.io.IOException; +import java.util.function.Consumer; +import java.util.function.Function; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class OrcBlockFactory +{ + private final Function exceptionTransform; + private final boolean nestedLazy; + private int currentPageId; + + public OrcBlockFactory(Function exceptionTransform, boolean nestedLazy) + { + this.exceptionTransform = requireNonNull(exceptionTransform, "exceptionTransform is null"); + this.nestedLazy = nestedLazy; + } + + public void nextPage() + { + currentPageId++; + } + + public Block createBlock(int positionCount, OrcBlockReader reader, Consumer onBlockLoaded) + { + return new LazyBlock(positionCount, new OrcBlockLoader(reader, onBlockLoaded)); + } + + public NestedBlockFactory createNestedBlockFactory(Consumer onBlockLoaded) + { + return new NestedBlockFactory(nestedLazy, onBlockLoaded); + } + + public interface OrcBlockReader + { + Block readBlock() + throws IOException; + } + + public class NestedBlockFactory + { + private final boolean lazy; + private final Consumer onBlockLoaded; + + private NestedBlockFactory(boolean lazy, Consumer onBlockLoaded) + { + this.lazy = lazy; + this.onBlockLoaded = requireNonNull(onBlockLoaded, "onBlockLoaded is null"); + } + + public Block 
createBlock(int positionCount, OrcBlockReader reader) + { + if (lazy) { + return new LazyBlock(positionCount, new OrcBlockLoader(reader, onBlockLoaded)); + } + + try { + Block block = reader.readBlock(); + onBlockLoaded.accept(block); + return block; + } + catch (Exception e) { + throw exceptionTransform.apply(e); + } + } + } + + private final class OrcBlockLoader + implements LazyBlockLoader + { + private final int expectedPageId = currentPageId; + private final OrcBlockReader blockReader; + private final Consumer onBlockLoaded; + private boolean loaded; + + public OrcBlockLoader(OrcBlockReader blockReader, Consumer onBlockLoaded) + { + this.blockReader = requireNonNull(blockReader, "blockReader is null"); + this.onBlockLoaded = requireNonNull(onBlockLoaded, "onBlockLoaded is null"); + } + + @Override + public final void load(LazyBlock lazyBlock) + { + checkState(!loaded, "Already loaded"); + checkState(currentPageId == expectedPageId, "ORC reader has been advanced beyond block"); + + try { + Block block = blockReader.readBlock(); + lazyBlock.setBlock(block); + onBlockLoaded.accept(block); + } + catch (IOException | RuntimeException e) { + throw exceptionTransform.apply(e); + } + + loaded = true; + } + } +} diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcColumn.java b/presto-orc/src/main/java/io/prestosql/orc/OrcColumn.java new file mode 100644 index 000000000000..26562fef51ed --- /dev/null +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcColumn.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.orc; + +import com.google.common.collect.ImmutableList; +import io.prestosql.orc.metadata.OrcColumnId; +import io.prestosql.orc.metadata.OrcType.OrcTypeKind; + +import java.util.List; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public final class OrcColumn +{ + private final String path; + private final OrcColumnId columnId; + private final OrcTypeKind columnType; + private final String columnName; + private final OrcDataSourceId orcDataSourceId; + private final List nestedColumns; + + public OrcColumn( + String path, + OrcColumnId columnId, + String columnName, + OrcTypeKind columnType, + OrcDataSourceId orcDataSourceId, + List nestedColumns) + { + this.path = requireNonNull(path, "path is null"); + this.columnId = requireNonNull(columnId, "columnId is null"); + this.columnName = requireNonNull(columnName, "columnName is null"); + this.columnType = requireNonNull(columnType, "columnType is null"); + this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSourceId is null"); + this.nestedColumns = ImmutableList.copyOf(requireNonNull(nestedColumns, "nestedColumns is null")); + } + + public String getPath() + { + return path; + } + + public OrcColumnId getColumnId() + { + return columnId; + } + + public OrcTypeKind getColumnType() + { + return columnType; + } + + public String getColumnName() + { + return columnName; + } + + public OrcDataSourceId getOrcDataSourceId() + { + return orcDataSourceId; + } + + public List getNestedColumns() + { + return nestedColumns; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("path", path) + .add("columnId", columnId) + .add("streamType", columnType) + .add("dataSource", orcDataSourceId) + .toString(); + } +} diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcPredicate.java 
b/presto-orc/src/main/java/io/prestosql/orc/OrcPredicate.java index 46470b59582a..d04dde4f47f5 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/OrcPredicate.java +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcPredicate.java @@ -13,10 +13,9 @@ */ package io.prestosql.orc; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.statistics.ColumnStatistics; -import java.util.Map; - public interface OrcPredicate { OrcPredicate TRUE = (numberOfRows, statisticsByColumnIndex) -> true; @@ -26,8 +25,7 @@ public interface OrcPredicate * * @param numberOfRows the number of rows in the segment; this can be used with * {@code ColumnStatistics} to determine if a column is only null - * @param statisticsByColumnIndex statistics for column by ordinal position - * in the file; this will match the field order from the hive metastore + * @param allColumnStatistics column statistics */ - boolean matches(long numberOfRows, Map statisticsByColumnIndex); + boolean matches(long numberOfRows, ColumnMetadata allColumnStatistics); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java b/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java index b4d630949d98..d6434c5434ea 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java @@ -14,33 +14,42 @@ package io.prestosql.orc; import com.google.common.base.Joiner; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.airlift.units.DataSize; import io.prestosql.memory.context.AggregatedMemoryContext; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.CompressionKind; import io.prestosql.orc.metadata.ExceptionWrappingMetadataReader; import io.prestosql.orc.metadata.Footer; import io.prestosql.orc.metadata.Metadata; +import io.prestosql.orc.metadata.OrcColumnId; import 
io.prestosql.orc.metadata.OrcMetadataReader; +import io.prestosql.orc.metadata.OrcType; +import io.prestosql.orc.metadata.OrcType.OrcTypeKind; import io.prestosql.orc.metadata.PostScript; import io.prestosql.orc.metadata.PostScript.HiveWriterVersion; import io.prestosql.orc.stream.OrcChunkLoader; import io.prestosql.orc.stream.OrcInputStream; +import io.prestosql.spi.Page; import io.prestosql.spi.type.Type; import org.joda.time.DateTimeZone; import java.io.IOException; import java.io.InputStream; import java.util.List; -import java.util.Map; import java.util.Optional; +import java.util.function.Function; import java.util.function.Predicate; +import java.util.stream.IntStream; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.SizeOf.SIZE_OF_BYTE; import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.prestosql.orc.OrcDecompressor.createOrcDecompressor; +import static io.prestosql.orc.metadata.OrcColumnId.ROOT_COLUMN; import static io.prestosql.orc.metadata.PostScript.MAGIC; import static java.lang.Math.min; import static java.lang.Math.toIntExact; @@ -67,6 +76,7 @@ public class OrcReader private final Optional decompressor; private final Footer footer; private final Metadata metadata; + private final OrcColumn rootColumn; private final Optional writeValidation; @@ -165,10 +175,12 @@ private OrcReader( try (InputStream footerInputStream = new OrcInputStream(OrcChunkLoader.create(orcDataSource.getId(), footerSlice, decompressor, newSimpleAggregatedMemoryContext()))) { this.footer = metadataReader.readFooter(hiveWriterVersion, footerInputStream); } - if (footer.getTypes().isEmpty()) { + if (footer.getTypes().size() == 0) { throw new OrcCorruptionException(orcDataSource.getId(), "File has no columns"); } + this.rootColumn = createOrcColumn("", "", new OrcColumnId(0), footer.getTypes(), 
orcDataSource.getId()); + validateWrite(validation -> validation.getColumnNames().equals(getColumnNames()), "Unexpected column names"); validateWrite(validation -> validation.getRowGroupMaxRowCount() == footer.getRowsInRowGroup(), "Unexpected rows in group"); if (writeValidation.isPresent()) { @@ -180,7 +192,7 @@ private OrcReader( public List getColumnNames() { - return footer.getTypes().get(0).getFieldNames(); + return footer.getTypes().get(ROOT_COLUMN).getFieldNames(); } public Footer getFooter() @@ -193,6 +205,11 @@ public Metadata getMetadata() return metadata; } + public OrcColumn getRootColumn() + { + return rootColumn; + } + public int getBufferSize() { return bufferSize; @@ -203,24 +220,43 @@ public CompressionKind getCompressionKind() return compressionKind; } - public OrcRecordReader createRecordReader(Map includedColumns, OrcPredicate predicate, DateTimeZone hiveStorageTimeZone, AggregatedMemoryContext systemMemoryUsage, int initialBatchSize) + public OrcRecordReader createRecordReader( + List readColumns, + List readTypes, + OrcPredicate predicate, + DateTimeZone hiveStorageTimeZone, + AggregatedMemoryContext systemMemoryUsage, + int initialBatchSize, + Function exceptionTransform) throws OrcCorruptionException { - return createRecordReader(includedColumns, predicate, 0, orcDataSource.getSize(), hiveStorageTimeZone, systemMemoryUsage, initialBatchSize); + return createRecordReader( + readColumns, + readTypes, + predicate, + 0, + orcDataSource.getSize(), + hiveStorageTimeZone, + systemMemoryUsage, + initialBatchSize, + exceptionTransform); } public OrcRecordReader createRecordReader( - Map includedColumns, + List readColumns, + List readTypes, OrcPredicate predicate, long offset, long length, DateTimeZone hiveStorageTimeZone, AggregatedMemoryContext systemMemoryUsage, - int initialBatchSize) + int initialBatchSize, + Function exceptionTransform) throws OrcCorruptionException { return new OrcRecordReader( - requireNonNull(includedColumns, 
"includedColumns is null"), + requireNonNull(readColumns, "readColumns is null"), + requireNonNull(readTypes, "readTypes is null"), requireNonNull(predicate, "predicate is null"), footer.getNumberOfRows(), footer.getStripes(), @@ -239,7 +275,8 @@ public OrcRecordReader createRecordReader( footer.getUserMetadata(), systemMemoryUsage, writeValidation, - initialBatchSize); + initialBatchSize, + exceptionTransform); } private static OrcDataSource wrapWithCacheIfTiny(OrcDataSource dataSource, DataSize maxCacheSize) @@ -254,6 +291,38 @@ private static OrcDataSource wrapWithCacheIfTiny(OrcDataSource dataSource, DataS return new CachingOrcDataSource(dataSource, desiredOffset -> diskRange); } + private static OrcColumn createOrcColumn( + String parentStreamName, + String fieldName, + OrcColumnId columnId, + ColumnMetadata types, + OrcDataSourceId orcDataSourceId) + { + String path = fieldName.isEmpty() ? parentStreamName : parentStreamName + "." + fieldName; + OrcType orcType = types.get(columnId); + + List nestedColumns = ImmutableList.of(); + if (orcType.getOrcTypeKind() == OrcTypeKind.STRUCT) { + nestedColumns = IntStream.range(0, orcType.getFieldCount()) + .mapToObj(fieldId -> createOrcColumn( + path, + orcType.getFieldName(fieldId), + orcType.getFieldTypeIndex(fieldId), + types, + orcDataSourceId)) + .collect(toImmutableList()); + } + else if (orcType.getOrcTypeKind() == OrcTypeKind.LIST) { + nestedColumns = ImmutableList.of(createOrcColumn(path, "item", orcType.getFieldTypeIndex(0), types, orcDataSourceId)); + } + else if (orcType.getOrcTypeKind() == OrcTypeKind.MAP) { + nestedColumns = ImmutableList.of( + createOrcColumn(path, "key", orcType.getFieldTypeIndex(0), types, orcDataSourceId), + createOrcColumn(path, "value", orcType.getFieldTypeIndex(1), types, orcDataSourceId)); + } + return new OrcColumn(path, columnId, fieldName, orcType.getOrcTypeKind(), orcDataSourceId, nestedColumns); + } + /** * Does the file start with the ORC magic bytes? 
*/ @@ -299,19 +368,26 @@ private void validateWrite(Predicate test, String messageFor static void validateFile( OrcWriteValidation writeValidation, OrcDataSource input, - List types, + List readTypes, DateTimeZone hiveStorageTimeZone) throws OrcCorruptionException { - ImmutableMap.Builder readTypes = ImmutableMap.builder(); - for (int columnIndex = 0; columnIndex < types.size(); columnIndex++) { - readTypes.put(columnIndex, types.get(columnIndex)); - } try { OrcReader orcReader = new OrcReader(input, new OrcReaderOptions(), Optional.of(writeValidation)); - try (OrcRecordReader orcRecordReader = orcReader.createRecordReader(readTypes.build(), OrcPredicate.TRUE, hiveStorageTimeZone, newSimpleAggregatedMemoryContext(), INITIAL_BATCH_SIZE)) { - while (orcRecordReader.nextBatch() >= 0) { - // ignored + try (OrcRecordReader orcRecordReader = orcReader.createRecordReader( + orcReader.getRootColumn().getNestedColumns(), + readTypes, + OrcPredicate.TRUE, + hiveStorageTimeZone, + newSimpleAggregatedMemoryContext(), + INITIAL_BATCH_SIZE, + exception -> { + throwIfUnchecked(exception); + return new RuntimeException(exception); + })) { + for (Page page = orcRecordReader.nextPage(); page != null; page = orcRecordReader.nextPage()) { + // fully load the page + page.getLoadedPage(); } } } diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcReaderOptions.java b/presto-orc/src/main/java/io/prestosql/orc/OrcReaderOptions.java index 0b96a04602d3..f4f7969b8cfa 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/OrcReaderOptions.java +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcReaderOptions.java @@ -27,6 +27,7 @@ public class OrcReaderOptions private static final DataSize DEFAULT_STREAM_BUFFER_SIZE = new DataSize(8, MEGABYTE); private static final DataSize DEFAULT_MAX_BLOCK_SIZE = new DataSize(16, MEGABYTE); private static final boolean DEFAULT_LAZY_READ_SMALL_RANGES = true; + private static final boolean DEFAULT_NESTED_LAZY = true; private final boolean 
bloomFiltersEnabled; @@ -36,6 +37,7 @@ public class OrcReaderOptions private final DataSize streamBufferSize; private final DataSize maxBlockSize; private final boolean lazyReadSmallRanges; + private final boolean nestedLazy; public OrcReaderOptions() { @@ -46,6 +48,7 @@ public OrcReaderOptions() streamBufferSize = DEFAULT_STREAM_BUFFER_SIZE; maxBlockSize = DEFAULT_MAX_BLOCK_SIZE; lazyReadSmallRanges = DEFAULT_LAZY_READ_SMALL_RANGES; + nestedLazy = DEFAULT_NESTED_LAZY; } private OrcReaderOptions( @@ -55,7 +58,8 @@ private OrcReaderOptions( DataSize tinyStripeThreshold, DataSize streamBufferSize, DataSize maxBlockSize, - boolean lazyReadSmallRanges) + boolean lazyReadSmallRanges, + boolean nestedLazy) { this.maxMergeDistance = requireNonNull(maxMergeDistance, "maxMergeDistance is null"); this.maxBufferSize = requireNonNull(maxBufferSize, "maxBufferSize is null"); @@ -64,6 +68,7 @@ private OrcReaderOptions( this.maxBlockSize = requireNonNull(maxBlockSize, "maxBlockSize is null"); this.lazyReadSmallRanges = requireNonNull(lazyReadSmallRanges, "lazyReadSmallRanges is null"); this.bloomFiltersEnabled = bloomFiltersEnabled; + this.nestedLazy = nestedLazy; } public boolean isBloomFiltersEnabled() @@ -101,6 +106,11 @@ public boolean isLazyReadSmallRanges() return lazyReadSmallRanges; } + public boolean isNestedLazy() + { + return nestedLazy; + } + public OrcReaderOptions withBloomFiltersEnabled(boolean bloomFiltersEnabled) { return new OrcReaderOptions( @@ -110,7 +120,8 @@ public OrcReaderOptions withBloomFiltersEnabled(boolean bloomFiltersEnabled) tinyStripeThreshold, streamBufferSize, maxBlockSize, - lazyReadSmallRanges); + lazyReadSmallRanges, + nestedLazy); } public OrcReaderOptions withMaxMergeDistance(DataSize maxMergeDistance) @@ -122,7 +133,8 @@ public OrcReaderOptions withMaxMergeDistance(DataSize maxMergeDistance) tinyStripeThreshold, streamBufferSize, maxBlockSize, - lazyReadSmallRanges); + lazyReadSmallRanges, + nestedLazy); } public OrcReaderOptions 
withMaxBufferSize(DataSize maxBufferSize) @@ -134,7 +146,8 @@ public OrcReaderOptions withMaxBufferSize(DataSize maxBufferSize) tinyStripeThreshold, streamBufferSize, maxBlockSize, - lazyReadSmallRanges); + lazyReadSmallRanges, + nestedLazy); } public OrcReaderOptions withTinyStripeThreshold(DataSize tinyStripeThreshold) @@ -146,7 +159,8 @@ public OrcReaderOptions withTinyStripeThreshold(DataSize tinyStripeThreshold) tinyStripeThreshold, streamBufferSize, maxBlockSize, - lazyReadSmallRanges); + lazyReadSmallRanges, + nestedLazy); } public OrcReaderOptions withStreamBufferSize(DataSize streamBufferSize) @@ -158,7 +172,8 @@ public OrcReaderOptions withStreamBufferSize(DataSize streamBufferSize) tinyStripeThreshold, streamBufferSize, maxBlockSize, - lazyReadSmallRanges); + lazyReadSmallRanges, + nestedLazy); } public OrcReaderOptions withMaxReadBlockSize(DataSize maxBlockSize) @@ -170,7 +185,8 @@ public OrcReaderOptions withMaxReadBlockSize(DataSize maxBlockSize) tinyStripeThreshold, streamBufferSize, maxBlockSize, - lazyReadSmallRanges); + lazyReadSmallRanges, + nestedLazy); } // TODO remove config option once efficacy is proven @@ -184,6 +200,22 @@ public OrcReaderOptions withLazyReadSmallRanges(boolean lazyReadSmallRanges) tinyStripeThreshold, streamBufferSize, maxBlockSize, - lazyReadSmallRanges); + lazyReadSmallRanges, + nestedLazy); + } + + // TODO remove config option once efficacy is proven + @Deprecated + public OrcReaderOptions withNestedLazy(boolean nestedLazy) + { + return new OrcReaderOptions( + bloomFiltersEnabled, + maxMergeDistance, + maxBufferSize, + tinyStripeThreshold, + streamBufferSize, + maxBlockSize, + lazyReadSmallRanges, + nestedLazy); } } diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java b/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java index 61bdca4ceeb3..51aee0fdca4a 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java +++ 
b/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java @@ -27,15 +27,14 @@ import io.prestosql.orc.OrcWriteValidation.WriteChecksum; import io.prestosql.orc.OrcWriteValidation.WriteChecksumBuilder; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.MetadataReader; import io.prestosql.orc.metadata.OrcType; -import io.prestosql.orc.metadata.OrcType.OrcTypeKind; import io.prestosql.orc.metadata.PostScript.HiveWriterVersion; import io.prestosql.orc.metadata.StripeInformation; import io.prestosql.orc.metadata.statistics.ColumnStatistics; import io.prestosql.orc.metadata.statistics.StripeStatistics; -import io.prestosql.orc.reader.StreamReader; -import io.prestosql.orc.reader.StreamReaders; +import io.prestosql.orc.reader.ColumnReader; import io.prestosql.orc.stream.InputStreamSources; import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; @@ -47,11 +46,12 @@ import java.io.IOException; import java.time.ZoneId; import java.util.ArrayList; +import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -61,6 +61,7 @@ import static io.prestosql.orc.OrcReader.MAX_BATCH_SIZE; import static io.prestosql.orc.OrcRecordReader.LinearProbeRangeFinder.createTinyStripesRangeFinder; import static io.prestosql.orc.OrcWriteValidation.WriteChecksumBuilder.createWriteChecksumBuilder; +import static io.prestosql.orc.reader.ColumnReaders.createColumnReader; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; @@ -74,13 +75,13 @@ public class OrcRecordReader private final OrcDataSource orcDataSource; - private final StreamReader[] streamReaders; + private final ColumnReader[] columnReaders; + private final long[] currentBytesPerCell; private 
final long[] maxBytesPerCell; private long maxCombinedBytesPerRow; private final long totalRowCount; private final long splitLength; - private final Set presentColumns; private final long maxBlockBytes; private long currentPosition; private long currentStripePosition; @@ -106,6 +107,8 @@ public class OrcRecordReader private final AggregatedMemoryContext systemMemoryUsage; + private final OrcBlockFactory blockFactory; + private final Optional writeValidation; private final Optional writeChecksumBuilder; private final Optional rowGroupStatisticsValidation; @@ -113,16 +116,17 @@ public class OrcRecordReader private final Optional fileStatisticsValidation; public OrcRecordReader( - Map includedColumns, + List readColumns, + List readTypes, OrcPredicate predicate, long numberOfRows, List fileStripes, - List fileStats, - List stripeStats, + Optional> fileStats, + List> stripeStats, OrcDataSource orcDataSource, long splitOffset, long splitLength, - List types, + ColumnMetadata orcTypes, Optional decompressor, int rowsInRowGroup, DateTimeZone hiveStorageTimeZone, @@ -132,40 +136,32 @@ public OrcRecordReader( Map userMetadata, AggregatedMemoryContext systemMemoryUsage, Optional writeValidation, - int initialBatchSize) + int initialBatchSize, + Function exceptionTransform) throws OrcCorruptionException { - requireNonNull(includedColumns, "includedColumns is null"); + requireNonNull(readColumns, "readColumns is null"); + checkArgument(readColumns.stream().distinct().count() == readColumns.size(), "readColumns contains duplicate entries"); + requireNonNull(readTypes, "readTypes is null"); + checkArgument(readColumns.size() == readTypes.size(), "readColumns and readTypes must have the same size"); requireNonNull(predicate, "predicate is null"); requireNonNull(fileStripes, "fileStripes is null"); requireNonNull(stripeStats, "stripeStats is null"); requireNonNull(orcDataSource, "orcDataSource is null"); - requireNonNull(types, "types is null"); + requireNonNull(orcTypes, "types 
is null"); requireNonNull(decompressor, "decompressor is null"); requireNonNull(hiveStorageTimeZone, "hiveStorageTimeZone is null"); requireNonNull(userMetadata, "userMetadata is null"); requireNonNull(systemMemoryUsage, "systemMemoryUsage is null"); + requireNonNull(exceptionTransform, "exceptionTransform is null"); this.writeValidation = requireNonNull(writeValidation, "writeValidation is null"); - this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(includedColumns)); - this.rowGroupStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(includedColumns)); - this.stripeStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(includedColumns)); - this.fileStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(includedColumns)); + this.writeChecksumBuilder = writeValidation.map(validation -> createWriteChecksumBuilder(orcTypes, readTypes)); + this.rowGroupStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(orcTypes, readTypes)); + this.stripeStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(orcTypes, readTypes)); + this.fileStatisticsValidation = writeValidation.map(validation -> validation.createWriteStatisticsBuilder(orcTypes, readTypes)); this.systemMemoryUsage = systemMemoryUsage.newAggregatedMemoryContext(); - - // reduce the included columns to the set that is also present - ImmutableSet.Builder presentColumns = ImmutableSet.builder(); - ImmutableMap.Builder presentColumnsAndTypes = ImmutableMap.builder(); - OrcType root = types.get(0); - for (Map.Entry entry : includedColumns.entrySet()) { - // an old file can have less columns since columns can be added - // after the file was written - if (entry.getKey() < root.getFieldCount()) { - presentColumns.add(entry.getKey()); - 
presentColumnsAndTypes.put(entry.getKey(), entry.getValue()); - } - } - this.presentColumns = presentColumns.build(); + this.blockFactory = new OrcBlockFactory(exceptionTransform, options.isNestedLazy()); requireNonNull(options, "options is null"); this.maxBlockBytes = options.getMaxBlockSize().toBytes(); @@ -179,7 +175,7 @@ public OrcRecordReader( Optional stats = Optional.empty(); // ignore all stripe stats if too few or too many if (stripeStats.size() == fileStripes.size()) { - stats = Optional.of(stripeStats.get(i)); + stats = stripeStats.get(i); } stripeInfos.add(new StripeInfo(fileStripes.get(i), stats)); } @@ -189,11 +185,11 @@ public OrcRecordReader( long fileRowCount = 0; ImmutableList.Builder stripes = ImmutableList.builder(); ImmutableList.Builder stripeFilePositions = ImmutableList.builder(); - if (predicate.matches(numberOfRows, getStatisticsByColumnOrdinal(root, fileStats))) { + if (!fileStats.isPresent() || predicate.matches(numberOfRows, fileStats.get())) { // select stripes that start within the specified split for (StripeInfo info : stripeInfos) { StripeInformation stripe = info.getStripe(); - if (splitContainsStripe(splitOffset, splitLength, stripe) && isStripeIncluded(root, stripe, info.getStats(), predicate)) { + if (splitContainsStripe(splitOffset, splitLength, stripe) && isStripeIncluded(stripe, info.getStats(), predicate)) { stripes.add(stripe); stripeFilePositions.add(fileRowCount); totalRowCount += stripe.getNumberOfRows(); @@ -228,16 +224,17 @@ public OrcRecordReader( orcDataSource, hiveStorageTimeZone.toTimeZone().toZoneId(), decompressor, - types, - this.presentColumns, + orcTypes, + ImmutableSet.copyOf(readColumns), rowsInRowGroup, predicate, hiveWriterVersion, metadataReader, writeValidation); - streamReaders = createStreamReaders(orcDataSource, types, presentColumnsAndTypes.build(), streamReadersSystemMemoryContext); - maxBytesPerCell = new long[streamReaders.length]; + columnReaders = createColumnReaders(readColumns, readTypes, 
streamReadersSystemMemoryContext, blockFactory); + currentBytesPerCell = new long[columnReaders.length]; + maxBytesPerCell = new long[columnReaders.length]; nextBatchSize = initialBatchSize; } @@ -248,7 +245,6 @@ private static boolean splitContainsStripe(long splitOffset, long splitLength, S } private static boolean isStripeIncluded( - OrcType rootStructType, StripeInformation stripe, Optional stripeStats, OrcPredicate predicate) @@ -256,8 +252,7 @@ private static boolean isStripeIncluded( // if there are no stats, include the column return stripeStats .map(StripeStatistics::getColumnStatistics) - .map(columnStats -> getStatisticsByColumnOrdinal(rootStructType, columnStats)) - .map(statsByColumn -> predicate.matches(stripe.getNumberOfRows(), statsByColumn)) + .map(columnStats -> predicate.matches(stripe.getNumberOfRows(), columnStats)) .orElse(true); } @@ -333,7 +328,7 @@ public void close() { try (Closer closer = Closer.create()) { closer.register(orcDataSource); - for (StreamReader column : streamReaders) { + for (ColumnReader column : columnReaders) { if (column != null) { closer.register(column::close); } @@ -352,22 +347,18 @@ public void close() validateWrite(validation -> validation.getChecksum().getStripeHash() == actualChecksum.getStripeHash(), "Invalid stripes checksum"); } if (fileStatisticsValidation.isPresent()) { - List columnStatistics = fileStatisticsValidation.get().build(); + Optional> columnStatistics = fileStatisticsValidation.get().build(); writeValidation.get().validateFileStatistics(orcDataSource.getId(), columnStatistics); } } - public boolean isColumnPresent(int hiveColumnIndex) - { - return presentColumns.contains(hiveColumnIndex); - } - - public int nextBatch() + public Page nextPage() throws IOException { // update position for current row group (advancing resets them) filePosition += currentBatchSize; currentPosition += currentBatchSize; + currentBatchSize = 0; // if next row is within the current group return if (nextRowInGroup >= 
currentGroupRowCount) { @@ -375,7 +366,7 @@ public int nextBatch() if (!advanceToNextRowGroup()) { filePosition = fileRowCount; currentPosition = totalRowCount; - return -1; + return null; } } @@ -392,7 +383,7 @@ public int nextBatch() nextBatchSize = min(currentBatchSize * BATCH_SIZE_GROWTH_FACTOR, MAX_BATCH_SIZE); currentBatchSize = toIntExact(min(currentBatchSize, currentGroupRowCount - nextRowInGroup)); - for (StreamReader column : streamReaders) { + for (ColumnReader column : columnReaders) { if (column != null) { column.prepareNextRead(currentBatchSize); } @@ -400,22 +391,34 @@ public int nextBatch() nextRowInGroup += currentBatchSize; validateWritePageChecksum(); - return currentBatchSize; + + // create a lazy page + blockFactory.nextPage(); + Arrays.fill(currentBytesPerCell, 0); + Block[] blocks = new Block[columnReaders.length]; + for (int i = 0; i < columnReaders.length; i++) { + int columnIndex = i; + blocks[columnIndex] = blockFactory.createBlock( + currentBatchSize, + columnReaders[columnIndex]::readBlock, + block -> blockLoaded(columnIndex, block)); + } + return new Page(currentBatchSize, blocks); } - public Block readBlock(int columnIndex) - throws IOException + private void blockLoaded(int columnIndex, Block block) { - Block block = streamReaders[columnIndex].readBlock(); - if (block.getPositionCount() > 0) { - long bytesPerCell = block.getSizeInBytes() / block.getPositionCount(); - if (maxBytesPerCell[columnIndex] < bytesPerCell) { - maxCombinedBytesPerRow = maxCombinedBytesPerRow - maxBytesPerCell[columnIndex] + bytesPerCell; - maxBytesPerCell[columnIndex] = bytesPerCell; - maxBatchSize = toIntExact(min(maxBatchSize, max(1, maxBlockBytes / maxCombinedBytesPerRow))); - } + if (block.getPositionCount() <= 0) { + return; + } + + currentBytesPerCell[columnIndex] += block.getSizeInBytes() / currentBatchSize; + if (maxBytesPerCell[columnIndex] < currentBytesPerCell[columnIndex]) { + long delta = currentBytesPerCell[columnIndex] - 
maxBytesPerCell[columnIndex]; + maxCombinedBytesPerRow += delta; + maxBytesPerCell[columnIndex] = currentBytesPerCell[columnIndex]; + maxBatchSize = toIntExact(min(maxBatchSize, max(1, maxBlockBytes / maxCombinedBytesPerRow))); } - return block; } public Map getUserMetadata() @@ -432,7 +435,7 @@ private boolean advanceToNextRowGroup() if (rowGroupStatisticsValidation.isPresent()) { StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.get(); long offset = stripes.get(currentStripe).getOffset(); - writeValidation.get().validateRowGroupStatistics(orcDataSource.getId(), offset, currentRowGroup, statisticsValidation.build()); + writeValidation.get().validateRowGroupStatistics(orcDataSource.getId(), offset, currentRowGroup, statisticsValidation.build().get()); statisticsValidation.reset(); } } @@ -458,7 +461,7 @@ private boolean advanceToNextRowGroup() // give reader data streams from row group InputStreamSources rowGroupStreamSources = currentRowGroup.getStreamSources(); - for (StreamReader column : streamReaders) { + for (ColumnReader column : columnReaders) { if (column != null) { column.startRowGroup(rowGroupStreamSources); } @@ -478,7 +481,7 @@ private void advanceToNextStripe() if (stripeStatisticsValidation.isPresent()) { StatisticsValidation statisticsValidation = stripeStatisticsValidation.get(); long offset = stripes.get(currentStripe).getOffset(); - writeValidation.get().validateStripeStatistics(orcDataSource.getId(), offset, statisticsValidation.build()); + writeValidation.get().validateStripeStatistics(orcDataSource.getId(), offset, statisticsValidation.build().get()); statisticsValidation.reset(); } } @@ -499,9 +502,9 @@ private void advanceToNextStripe() if (stripe != null) { // Give readers access to dictionary streams InputStreamSources dictionaryStreamSources = stripe.getDictionaryStreamSources(); - List columnEncodings = stripe.getColumnEncodings(); + ColumnMetadata columnEncodings = stripe.getColumnEncodings(); ZoneId timeZone = 
stripe.getTimeZone(); - for (StreamReader column : streamReaders) { + for (ColumnReader column : columnReaders) { if (column != null) { column.startStripe(timeZone, dictionaryStreamSources, columnEncodings); } @@ -528,9 +531,11 @@ private void validateWritePageChecksum() throws IOException { if (writeChecksumBuilder.isPresent()) { - Block[] blocks = new Block[streamReaders.length]; - for (int columnIndex = 0; columnIndex < streamReaders.length; columnIndex++) { - blocks[columnIndex] = readBlock(columnIndex); + Block[] blocks = new Block[columnReaders.length]; + for (int columnIndex = 0; columnIndex < columnReaders.length; columnIndex++) { + Block block = columnReaders[columnIndex].readBlock(); + blocks[columnIndex] = block; + blockLoaded(columnIndex, block); } Page page = new Page(currentBatchSize, blocks); writeChecksumBuilder.get().addPage(page); @@ -540,67 +545,25 @@ private void validateWritePageChecksum() } } - private static StreamReader[] createStreamReaders( - OrcDataSource orcDataSource, - List types, - Map includedColumns, - AggregatedMemoryContext systemMemoryContext) + private ColumnReader[] createColumnReaders( + List columns, + List readTypes, + AggregatedMemoryContext systemMemoryContext, + OrcBlockFactory blockFactory) throws OrcCorruptionException { - List streamDescriptors = createStreamDescriptor("", "", 0, types, orcDataSource).getNestedStreams(); - - OrcType rowType = types.get(0); - StreamReader[] streamReaders = new StreamReader[rowType.getFieldCount()]; - for (int columnId = 0; columnId < rowType.getFieldCount(); columnId++) { - Type type = includedColumns.get(columnId); - if (type != null) { - StreamDescriptor streamDescriptor = streamDescriptors.get(columnId); - streamReaders[columnId] = StreamReaders.createStreamReader(type, streamDescriptor, systemMemoryContext); - } - } - return streamReaders; - } - - private static StreamDescriptor createStreamDescriptor(String parentStreamName, String fieldName, int typeId, List types, OrcDataSource 
dataSource) - { - OrcType type = types.get(typeId); - - if (!fieldName.isEmpty()) { - parentStreamName += "." + fieldName; - } - - ImmutableList.Builder nestedStreams = ImmutableList.builder(); - if (type.getOrcTypeKind() == OrcTypeKind.STRUCT) { - for (int i = 0; i < type.getFieldCount(); ++i) { - nestedStreams.add(createStreamDescriptor(parentStreamName, type.getFieldName(i), type.getFieldTypeIndex(i), types, dataSource)); - } - } - else if (type.getOrcTypeKind() == OrcTypeKind.LIST) { - nestedStreams.add(createStreamDescriptor(parentStreamName, "item", type.getFieldTypeIndex(0), types, dataSource)); - } - else if (type.getOrcTypeKind() == OrcTypeKind.MAP) { - nestedStreams.add(createStreamDescriptor(parentStreamName, "key", type.getFieldTypeIndex(0), types, dataSource)); - nestedStreams.add(createStreamDescriptor(parentStreamName, "value", type.getFieldTypeIndex(1), types, dataSource)); - } - return new StreamDescriptor(parentStreamName, typeId, fieldName, type.getOrcTypeKind(), dataSource, nestedStreams.build()); - } - - private static Map getStatisticsByColumnOrdinal(OrcType rootStructType, List fileStats) - { - requireNonNull(rootStructType, "rootStructType is null"); - checkArgument(rootStructType.getOrcTypeKind() == OrcTypeKind.STRUCT); - requireNonNull(fileStats, "fileStats is null"); - - ImmutableMap.Builder statistics = ImmutableMap.builder(); - for (int ordinal = 0; ordinal < rootStructType.getFieldCount(); ordinal++) { - if (fileStats.size() > ordinal) { - ColumnStatistics element = fileStats.get(rootStructType.getFieldTypeIndex(ordinal)); - if (element != null) { - statistics.put(ordinal, element); - } - } - } - return statistics.build(); + ColumnReader[] columnReaders = new ColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); i++) { + int columnIndex = i; + Type readType = readTypes.get(columnIndex); + OrcColumn column = columns.get(columnIndex); + columnReaders[columnIndex] = createColumnReader( + readType, + column, + 
systemMemoryContext, + blockFactory.createNestedBlockFactory(block -> blockLoaded(columnIndex, block))); + } + return columnReaders; } /** @@ -610,7 +573,7 @@ private static Map getStatisticsByColumnOrdinal(OrcTy long getStreamReaderRetainedSizeInBytes() { long totalRetainedSizeInBytes = 0; - for (StreamReader column : streamReaders) { + for (ColumnReader column : columnReaders) { if (column != null) { totalRetainedSizeInBytes += column.getRetainedSizeInBytes(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcWriteValidation.java b/presto-orc/src/main/java/io/prestosql/orc/OrcWriteValidation.java index f6790382a30c..a99b045e03fc 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/OrcWriteValidation.java +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcWriteValidation.java @@ -19,7 +19,10 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.slice.XxHash64; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; +import io.prestosql.orc.metadata.OrcType; import io.prestosql.orc.metadata.PostScript.HiveWriterVersion; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.StripeInformation; @@ -60,6 +63,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.SortedMap; import java.util.function.Function; @@ -73,6 +77,7 @@ import static io.prestosql.orc.OrcWriteValidation.OrcWriteValidationMode.BOTH; import static io.prestosql.orc.OrcWriteValidation.OrcWriteValidationMode.DETAILED; import static io.prestosql.orc.OrcWriteValidation.OrcWriteValidationMode.HASHED; +import static io.prestosql.orc.metadata.OrcColumnId.ROOT_COLUMN; import static io.prestosql.orc.metadata.OrcMetadataReader.maxStringTruncateToValidRange; import static io.prestosql.orc.metadata.OrcMetadataReader.minStringTruncateToValidRange; import static 
io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED; @@ -108,7 +113,7 @@ public enum OrcWriteValidationMode private final WriteChecksum checksum; private final Map> rowGroupStatistics; private final Map stripeStatistics; - private final List fileStatistics; + private final Optional> fileStatistics; private final int stringStatisticsLimitInBytes; private OrcWriteValidation( @@ -121,7 +126,7 @@ private OrcWriteValidation( WriteChecksum checksum, Map> rowGroupStatistics, Map stripeStatistics, - List fileStatistics, + Optional> fileStatistics, int stringStatisticsLimitInBytes) { this.version = version; @@ -188,13 +193,24 @@ public WriteChecksum getChecksum() return checksum; } - public void validateFileStatistics(OrcDataSourceId orcDataSourceId, List actualFileStatistics) + public void validateFileStatistics(OrcDataSourceId orcDataSourceId, Optional> actualFileStatistics) throws OrcCorruptionException { - validateColumnStatisticsEquivalent(orcDataSourceId, "file", actualFileStatistics, fileStatistics); + // file stats will be absent when no rows are written + if (!fileStatistics.isPresent()) { + if (actualFileStatistics.isPresent()) { + throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected file statistics"); + } + return; + } + if (!actualFileStatistics.isPresent()) { + throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: expected file statistics"); + } + + validateColumnStatisticsEquivalent(orcDataSourceId, "file", actualFileStatistics.get(), fileStatistics.get()); } - public void validateStripeStatistics(OrcDataSourceId orcDataSourceId, List actualStripes, List actualStripeStatistics) + public void validateStripeStatistics(OrcDataSourceId orcDataSourceId, List actualStripes, List> actualStripeStatistics) throws OrcCorruptionException { requireNonNull(actualStripes, "actualStripes is null"); @@ -206,12 +222,12 @@ public void validateStripeStatistics(OrcDataSourceId orcDataSourceId, List actual) + public void 
validateStripeStatistics(OrcDataSourceId orcDataSourceId, long stripeOffset, ColumnMetadata actual) throws OrcCorruptionException { StripeStatistics expected = stripeStatistics.get(stripeOffset); @@ -240,16 +256,16 @@ public void validateRowGroupStatistics(OrcDataSourceId orcDataSourceId, long str for (int rowGroupIndex = 0; rowGroupIndex < expectedRowGroupStatistics.size(); rowGroupIndex++) { RowGroupStatistics expectedRowGroup = expectedRowGroupStatistics.get(rowGroupIndex); if (expectedRowGroup.getValidationMode() != HASHED) { - Map expectedStatistics = expectedRowGroup.getColumnStatistics(); - Set actualColumns = actualRowGroupStatistics.keySet().stream() - .map(StreamId::getColumn) + Map expectedStatistics = expectedRowGroup.getColumnStatistics(); + Set actualColumns = actualRowGroupStatistics.keySet().stream() + .map(StreamId::getColumnId) .collect(Collectors.toSet()); if (!expectedStatistics.keySet().equals(actualColumns)) { throw new OrcCorruptionException(orcDataSourceId, "Unexpected column in row group %s in stripe at offset %s", rowGroupIndex, stripeOffset); } for (Entry> entry : actualRowGroupStatistics.entrySet()) { ColumnStatistics actual = entry.getValue().get(rowGroupIndex).getColumnStatistics(); - ColumnStatistics expected = expectedStatistics.get(entry.getKey().getColumn()); + ColumnStatistics expected = expectedStatistics.get(entry.getKey().getColumnId()); validateColumnStatisticsEquivalent(orcDataSourceId, "Row group " + rowGroupIndex + " in stripe at offset " + stripeOffset, actual, expected); } } @@ -269,14 +285,14 @@ private static RowGroupStatistics buildActualRowGroupStatistics(int rowGroupInde BOTH, actualRowGroupStatistics.entrySet() .stream() - .collect(Collectors.toMap(entry -> entry.getKey().getColumn(), entry -> entry.getValue().get(rowGroupIndex).getColumnStatistics()))); + .collect(Collectors.toMap(entry -> entry.getKey().getColumnId(), entry -> entry.getValue().get(rowGroupIndex).getColumnStatistics()))); } public void 
validateRowGroupStatistics( OrcDataSourceId orcDataSourceId, long stripeOffset, int rowGroupIndex, - List actual) + ColumnMetadata actual) throws OrcCorruptionException { List rowGroups = rowGroupStatistics.get(stripeOffset); @@ -288,16 +304,19 @@ public void validateRowGroupStatistics( } RowGroupStatistics expectedRowGroup = rowGroups.get(rowGroupIndex); - RowGroupStatistics actualRowGroup = new RowGroupStatistics(BOTH, IntStream.range(1, actual.size()).boxed().collect(toImmutableMap(identity(), actual::get))); + RowGroupStatistics actualRowGroup = new RowGroupStatistics(BOTH, IntStream.range(1, actual.size()).mapToObj(OrcColumnId::new).collect(toImmutableMap(identity(), actual::get))); if (expectedRowGroup.getValidationMode() != HASHED) { - Map expectedByColumnIndex = expectedRowGroup.getColumnStatistics(); + Map expectedByColumnIndex = expectedRowGroup.getColumnStatistics(); // new writer does not write row group stats for column zero (table row column) - List expected = IntStream.range(1, actual.size()) - .mapToObj(expectedByColumnIndex::get) - .collect(toImmutableList()); - actual = actual.subList(1, actual.size()); + ColumnMetadata expected = new ColumnMetadata<>(IntStream.range(1, actual.size()) + .mapToObj(OrcColumnId::new) + .map(expectedByColumnIndex::get) + .collect(toImmutableList())); + actual = new ColumnMetadata<>(actual.stream() + .skip(1) + .collect(toImmutableList())); validateColumnStatisticsEquivalent(orcDataSourceId, "Row group " + rowGroupIndex + " in stripe at offset " + stripeOffset, actual, expected); } @@ -309,29 +328,17 @@ public void validateRowGroupStatistics( } } - public StatisticsValidation createWriteStatisticsBuilder(Map readColumns) + public StatisticsValidation createWriteStatisticsBuilder(ColumnMetadata orcTypes, List readTypes) { - requireNonNull(readColumns, "readColumns is null"); - checkArgument(!readColumns.isEmpty(), "readColumns is empty"); - int columnCount = readColumns.keySet().stream() - .mapToInt(Integer::intValue) - 
.max().getAsInt() + 1; - checkArgument(readColumns.size() == columnCount, "statistics validation requires all columns to be read"); - - ImmutableList.Builder types = ImmutableList.builder(); - for (int column = 0; column < columnCount; column++) { - Type type = readColumns.get(column); - checkArgument(type != null, "statistics validation requires all columns to be read"); - types.add(type); - } - return new StatisticsValidation(types.build()); + checkArgument(readTypes.size() == orcTypes.get(ROOT_COLUMN).getFieldCount(), "statistics validation requires all columns to be read"); + return new StatisticsValidation(readTypes); } private static void validateColumnStatisticsEquivalent( OrcDataSourceId orcDataSourceId, String name, - List actualColumnStatistics, - List expectedColumnStatistics) + ColumnMetadata actualColumnStatistics, + ColumnMetadata expectedColumnStatistics) throws OrcCorruptionException { requireNonNull(name, "name is null"); @@ -341,8 +348,9 @@ private static void validateColumnStatisticsEquivalent( throw new OrcCorruptionException(orcDataSourceId, "Write validation failed: unexpected number of columns in %s statistics", name); } for (int i = 0; i < actualColumnStatistics.size(); i++) { - ColumnStatistics actual = actualColumnStatistics.get(i); - ColumnStatistics expected = expectedColumnStatistics.get(i); + OrcColumnId columnId = new OrcColumnId(i); + ColumnStatistics actual = actualColumnStatistics.get(columnId); + ColumnStatistics expected = expectedColumnStatistics.get(columnId); validateColumnStatisticsEquivalent(orcDataSourceId, name + " column " + i, actual, expected); } } @@ -465,22 +473,10 @@ private WriteChecksumBuilder(List types) this.columnHashes = columnHashes.build(); } - public static WriteChecksumBuilder createWriteChecksumBuilder(Map readColumns) + public static WriteChecksumBuilder createWriteChecksumBuilder(ColumnMetadata orcTypes, List readTypes) { - requireNonNull(readColumns, "readColumns is null"); - 
checkArgument(!readColumns.isEmpty(), "readColumns is empty"); - int columnCount = readColumns.keySet().stream() - .mapToInt(Integer::intValue) - .max().getAsInt() + 1; - checkArgument(readColumns.size() == columnCount, "checksum requires all columns to be read"); - - ImmutableList.Builder types = ImmutableList.builder(); - for (int column = 0; column < columnCount; column++) { - Type type = readColumns.get(column); - checkArgument(type != null, "checksum requires all columns to be read"); - types.add(type); - } - return new WriteChecksumBuilder(types.build()); + checkArgument(readTypes.size() == orcTypes.get(ROOT_COLUMN).getFieldCount(), "checksum requires all columns to be read"); + return new WriteChecksumBuilder(readTypes); } public void addStripe(int rowCount) @@ -601,15 +597,15 @@ public void addPage(Page page) } } - public List build() + public Optional> build() { - ImmutableList.Builder statisticsBuilders = ImmutableList.builder(); - // if there are no rows, there will be no stats - if (rowCount > 0) { - statisticsBuilders.add(new ColumnStatistics(rowCount, 0, null, null, null, null, null, null, null, null)); - columnStatisticsValidations.forEach(validation -> validation.build(statisticsBuilders)); + if (rowCount == 0) { + return Optional.empty(); } - return statisticsBuilders.build(); + ImmutableList.Builder statisticsBuilders = ImmutableList.builder(); + statisticsBuilders.add(new ColumnStatistics(rowCount, 0, null, null, null, null, null, null, null, null)); + columnStatisticsValidations.forEach(validation -> validation.build(statisticsBuilders)); + return Optional.of(new ColumnMetadata<>(statisticsBuilders.build())); } } @@ -773,10 +769,10 @@ private static class RowGroupStatistics private static final int INSTANCE_SIZE = ClassLayout.parseClass(RowGroupStatistics.class).instanceSize(); private final OrcWriteValidationMode validationMode; - private final SortedMap columnStatistics; + private final SortedMap columnStatistics; private final long hash; - 
public RowGroupStatistics(OrcWriteValidationMode validationMode, Map columnStatistics) + public RowGroupStatistics(OrcWriteValidationMode validationMode, Map columnStatistics) { this.validationMode = validationMode; @@ -798,12 +794,12 @@ else if (validationMode == BOTH) { } } - private static long hashColumnStatistics(SortedMap columnStatistics) + private static long hashColumnStatistics(SortedMap columnStatistics) { StatisticsHasher statisticsHasher = new StatisticsHasher(); statisticsHasher.putInt(columnStatistics.size()); - for (Entry entry : columnStatistics.entrySet()) { - statisticsHasher.putInt(entry.getKey()) + for (Entry entry : columnStatistics.entrySet()) { + statisticsHasher.putInt(entry.getKey().getId()) .putOptionalHashable(entry.getValue()); } return statisticsHasher.hash(); @@ -814,7 +810,7 @@ public OrcWriteValidationMode getValidationMode() return validationMode; } - public Map getColumnStatistics() + public Map getColumnStatistics() { verify(validationMode != HASHED, "columnStatistics are not available in HASHED mode"); return columnStatistics; @@ -843,7 +839,7 @@ public static class OrcWriteValidationBuilder private List currentRowGroupStatistics = new ArrayList<>(); private final Map> rowGroupStatisticsByStripe = new HashMap<>(); private final Map stripeStatistics = new HashMap<>(); - private List fileStatistics; + private Optional> fileStatistics = Optional.empty(); private long retainedSize = INSTANCE_SIZE; public OrcWriteValidationBuilder(OrcWriteValidationMode validationMode, List types) @@ -908,7 +904,7 @@ public OrcWriteValidationBuilder addPage(Page page) return this; } - public void addRowGroupStatistics(Map columnStatistics) + public void addRowGroupStatistics(Map columnStatistics) { RowGroupStatistics rowGroupStatistics = new RowGroupStatistics(validationMode, columnStatistics); currentRowGroupStatistics.add(rowGroupStatistics); @@ -928,7 +924,7 @@ public void addStripeStatistics(long stripStartOffset, StripeStatistics columnSt 
currentRowGroupStatistics = new ArrayList<>(); } - public void setFileStatistics(List fileStatistics) + public void setFileStatistics(Optional> fileStatistics) { this.fileStatistics = fileStatistics; } diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcWriter.java b/presto-orc/src/main/java/io/prestosql/orc/OrcWriter.java index 9298059ae007..c7348ec62f31 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/OrcWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/OrcWriter.java @@ -23,10 +23,12 @@ import io.prestosql.orc.OrcWriteValidation.OrcWriteValidationMode; import io.prestosql.orc.OrcWriterStats.FlushReason; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; import io.prestosql.orc.metadata.Footer; import io.prestosql.orc.metadata.Metadata; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.OrcMetadataWriter; import io.prestosql.orc.metadata.OrcType; import io.prestosql.orc.metadata.Stream; @@ -69,6 +71,7 @@ import static io.prestosql.orc.OrcWriterStats.FlushReason.MAX_BYTES; import static io.prestosql.orc.OrcWriterStats.FlushReason.MAX_ROWS; import static io.prestosql.orc.metadata.ColumnEncoding.ColumnEncodingKind.DIRECT; +import static io.prestosql.orc.metadata.OrcColumnId.ROOT_COLUMN; import static io.prestosql.orc.metadata.PostScript.MAGIC; import static io.prestosql.orc.stream.OrcDataOutput.createDataOutput; import static io.prestosql.orc.writer.ColumnWriters.createColumnWriter; @@ -104,7 +107,7 @@ public final class OrcWriter private final DateTimeZone hiveStorageTimeZone; private final List closedStripes = new ArrayList<>(); - private final List orcTypes; + private final ColumnMetadata orcTypes; private final List columnWriters; private final DictionaryCompressionOptimizer dictionaryCompressionOptimizer; @@ -160,16 +163,16 @@ public OrcWriter( this.stats = 
requireNonNull(stats, "stats is null"); requireNonNull(columnNames, "columnNames is null"); - this.orcTypes = OrcType.createOrcRowType(0, columnNames, types); + this.orcTypes = OrcType.createRootOrcType(columnNames, types); recordValidation(validation -> validation.setColumnNames(columnNames)); // create column writers - OrcType rootType = orcTypes.get(0); + OrcType rootType = orcTypes.get(ROOT_COLUMN); checkArgument(rootType.getFieldCount() == types.size()); ImmutableList.Builder columnWriters = ImmutableList.builder(); ImmutableSet.Builder sliceColumnWriters = ImmutableSet.builder(); for (int fieldId = 0; fieldId < types.size(); fieldId++) { - int fieldColumnIndex = rootType.getFieldTypeIndex(fieldId); + OrcColumnId fieldColumnIndex = rootType.getFieldTypeIndex(fieldId); Type fieldType = types.get(fieldId); ColumnWriter columnWriter = createColumnWriter(fieldColumnIndex, orcTypes, fieldType, compression, maxCompressionBufferSize, hiveStorageTimeZone, options.getMaxStringStatisticsLimit()); columnWriters.add(columnWriter); @@ -311,7 +314,7 @@ else if (dictionaryCompressionOptimizer.isFull(bufferedBytes)) { private void finishRowGroup() { - Map columnStatistics = new HashMap<>(); + Map columnStatistics = new HashMap<>(); columnWriters.forEach(columnWriter -> columnStatistics.putAll(columnWriter.finishRowGroup())); recordValidation(validation -> validation.addRowGroupStatistics(columnStatistics)); rowGroupRowCount = 0; @@ -401,24 +404,24 @@ private List bufferStripeData(long stripeStartOffset, FlushReason allStreams.add(dataStream.getStream()); } - Map columnEncodings = new HashMap<>(); + Map columnEncodings = new HashMap<>(); columnWriters.forEach(columnWriter -> columnEncodings.putAll(columnWriter.getColumnEncodings())); - Map columnStatistics = new HashMap<>(); + Map columnStatistics = new HashMap<>(); columnWriters.forEach(columnWriter -> columnStatistics.putAll(columnWriter.getColumnStripeStatistics())); // the 0th column is a struct column for the whole row - 
columnEncodings.put(0, new ColumnEncoding(DIRECT, 0)); - columnStatistics.put(0, new ColumnStatistics((long) stripeRowCount, 0, null, null, null, null, null, null, null, null)); + columnEncodings.put(ROOT_COLUMN, new ColumnEncoding(DIRECT, 0)); + columnStatistics.put(ROOT_COLUMN, new ColumnStatistics((long) stripeRowCount, 0, null, null, null, null, null, null, null, null)); // add footer Optional timeZone = Optional.of(hiveStorageTimeZone.toTimeZone().toZoneId()); - StripeFooter stripeFooter = new StripeFooter(allStreams, toDenseList(columnEncodings, orcTypes.size()), timeZone); + StripeFooter stripeFooter = new StripeFooter(allStreams, toColumnMetadata(columnEncodings, orcTypes.size()), timeZone); Slice footer = metadataWriter.writeStripeFooter(stripeFooter); outputData.add(createDataOutput(footer)); // create final stripe statistics - StripeStatistics statistics = new StripeStatistics(toDenseList(columnStatistics, orcTypes.size())); + StripeStatistics statistics = new StripeStatistics(toColumnMetadata(columnStatistics, orcTypes.size())); recordValidation(validation -> validation.addStripeStatistics(stripeStartOffset, statistics)); StripeInformation stripeInformation = new StripeInformation(stripeRowCount, stripeStartOffset, indexLength, dataLength, footer.length()); ClosedStripe closedStripe = new ClosedStripe(stripeInformation, statistics); @@ -457,6 +460,7 @@ private List bufferFileFooter() Metadata metadata = new Metadata(closedStripes.stream() .map(ClosedStripe::getStatistics) + .map(Optional::of) .collect(toList())); Slice metadataSlice = metadataWriter.writeMetadata(metadata); outputData.add(createDataOutput(metadataSlice)); @@ -465,11 +469,10 @@ private List bufferFileFooter() .mapToLong(stripe -> stripe.getStripeInformation().getNumberOfRows()) .sum(); - List fileStats = toFileStats( - closedStripes.stream() - .map(ClosedStripe::getStatistics) - .map(StripeStatistics::getColumnStatistics) - .collect(toList())); + Optional> fileStats = 
toFileStats(closedStripes.stream() + .map(ClosedStripe::getStatistics) + .map(StripeStatistics::getColumnStatistics) + .collect(toList())); recordValidation(validation -> validation.setFileStatistics(fileStats)); Map userMetadata = this.userMetadata.entrySet().stream() @@ -512,32 +515,32 @@ public void validate(OrcDataSource input) validateFile(validationBuilder.build(), input, types, hiveStorageTimeZone); } - private static List toDenseList(Map data, int expectedSize) + private static ColumnMetadata toColumnMetadata(Map data, int expectedSize) { checkArgument(data.size() == expectedSize); List list = new ArrayList<>(expectedSize); for (int i = 0; i < expectedSize; i++) { - list.add(data.get(i)); + list.add(data.get(new OrcColumnId(i))); } - return ImmutableList.copyOf(list); + return new ColumnMetadata<>(ImmutableList.copyOf(list)); } - private static List toFileStats(List> stripes) + private static Optional> toFileStats(List> stripes) { if (stripes.isEmpty()) { - return ImmutableList.of(); + return Optional.empty(); } int columnCount = stripes.get(0).size(); checkArgument(stripes.stream().allMatch(stripe -> columnCount == stripe.size())); ImmutableList.Builder fileStats = ImmutableList.builder(); for (int i = 0; i < columnCount; i++) { - int column = i; + OrcColumnId columnId = new OrcColumnId(i); fileStats.add(ColumnStatistics.mergeColumnStatistics(stripes.stream() - .map(stripe -> stripe.get(column)) + .map(stripe -> stripe.get(columnId)) .collect(toList()))); } - return fileStats.build(); + return Optional.of(new ColumnMetadata<>(fileStats.build())); } private static class ClosedStripe diff --git a/presto-orc/src/main/java/io/prestosql/orc/StreamDescriptor.java b/presto-orc/src/main/java/io/prestosql/orc/StreamDescriptor.java deleted file mode 100644 index 00a67492ad56..000000000000 --- a/presto-orc/src/main/java/io/prestosql/orc/StreamDescriptor.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you 
may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.orc; - -import com.google.common.collect.ImmutableList; -import io.prestosql.orc.metadata.OrcType.OrcTypeKind; - -import java.util.List; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static java.util.Objects.requireNonNull; - -public final class StreamDescriptor -{ - private final String streamName; - private final int streamId; - private final OrcTypeKind streamType; - private final String fieldName; - private final OrcDataSource orcDataSource; - private final List nestedStreams; - - public StreamDescriptor(String streamName, int streamId, String fieldName, OrcTypeKind streamType, OrcDataSource orcDataSource, List nestedStreams) - { - this.streamName = requireNonNull(streamName, "streamName is null"); - this.streamId = streamId; - this.fieldName = requireNonNull(fieldName, "fieldName is null"); - this.streamType = requireNonNull(streamType, "type is null"); - this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null"); - this.nestedStreams = ImmutableList.copyOf(requireNonNull(nestedStreams, "nestedStreams is null")); - } - - public String getStreamName() - { - return streamName; - } - - public int getStreamId() - { - return streamId; - } - - public OrcTypeKind getStreamType() - { - return streamType; - } - - public String getFieldName() - { - return fieldName; - } - - public OrcDataSourceId getOrcDataSourceId() - { - return orcDataSource.getId(); - } - - public OrcDataSource getOrcDataSource() - { - return 
orcDataSource; - } - - public List getNestedStreams() - { - return nestedStreams; - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("streamName", streamName) - .add("streamId", streamId) - .add("streamType", streamType) - .add("dataSource", orcDataSource.getId()) - .toString(); - } -} diff --git a/presto-orc/src/main/java/io/prestosql/orc/StreamId.java b/presto-orc/src/main/java/io/prestosql/orc/StreamId.java index 81011dd60dd4..53229b5d2cb3 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/StreamId.java +++ b/presto-orc/src/main/java/io/prestosql/orc/StreamId.java @@ -13,33 +13,36 @@ */ package io.prestosql.orc; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; import java.util.Objects; import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; public final class StreamId { - private final int column; + private final OrcColumnId columnId; private final StreamKind streamKind; public StreamId(Stream stream) { - this.column = stream.getColumn(); + requireNonNull(stream, "stream is null"); + this.columnId = stream.getColumnId(); this.streamKind = stream.getStreamKind(); } - public StreamId(int column, StreamKind streamKind) + public StreamId(OrcColumnId columnId, StreamKind streamKind) { - this.column = column; + this.columnId = columnId; this.streamKind = streamKind; } - public int getColumn() + public OrcColumnId getColumnId() { - return column; + return columnId; } public StreamKind getStreamKind() @@ -50,29 +53,28 @@ public StreamKind getStreamKind() @Override public int hashCode() { - return Objects.hash(column, streamKind); + return Objects.hash(columnId, streamKind); } @Override - public boolean equals(Object obj) + public boolean equals(Object o) { - if (this == obj) { + if (this == o) { return true; } - if (obj == null || getClass() != obj.getClass()) { + if (o == null || 
getClass() != o.getClass()) { return false; } - - StreamId other = (StreamId) obj; - return column == other.column && - streamKind == other.streamKind; + StreamId streamId = (StreamId) o; + return Objects.equals(columnId, streamId.columnId) && + streamKind == streamId.streamKind; } @Override public String toString() { return toStringHelper(this) - .add("column", column) + .add("columnId", columnId) .add("streamKind", streamKind) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/Stripe.java b/presto-orc/src/main/java/io/prestosql/orc/Stripe.java index f43800ce5018..86478a76ed80 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/Stripe.java +++ b/presto-orc/src/main/java/io/prestosql/orc/Stripe.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.InputStreamSources; import java.time.ZoneId; @@ -27,11 +28,11 @@ public class Stripe { private final long rowCount; private final ZoneId timeZone; - private final List columnEncodings; + private final ColumnMetadata columnEncodings; private final List rowGroups; private final InputStreamSources dictionaryStreamSources; - public Stripe(long rowCount, ZoneId timeZone, List columnEncodings, List rowGroups, InputStreamSources dictionaryStreamSources) + public Stripe(long rowCount, ZoneId timeZone, ColumnMetadata columnEncodings, List rowGroups, InputStreamSources dictionaryStreamSources) { this.rowCount = rowCount; this.timeZone = requireNonNull(timeZone, "timeZone is null"); @@ -50,7 +51,7 @@ public ZoneId getTimeZone() return timeZone; } - public List getColumnEncodings() + public ColumnMetadata getColumnEncodings() { return columnEncodings; } diff --git a/presto-orc/src/main/java/io/prestosql/orc/StripeReader.java b/presto-orc/src/main/java/io/prestosql/orc/StripeReader.java index 327374301935..548e6ee81318 100644 --- 
a/presto-orc/src/main/java/io/prestosql/orc/StripeReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/StripeReader.java @@ -24,7 +24,9 @@ import io.prestosql.orc.checkpoint.StreamCheckpoint; import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.ColumnEncoding.ColumnEncodingKind; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.MetadataReader; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.OrcType; import io.prestosql.orc.metadata.OrcType.OrcTypeKind; import io.prestosql.orc.metadata.PostScript.HiveWriterVersion; @@ -46,6 +48,8 @@ import java.io.IOException; import java.io.InputStream; import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; @@ -76,9 +80,9 @@ public class StripeReader private final OrcDataSource orcDataSource; private final ZoneId defaultTimeZone; private final Optional decompressor; - private final List types; + private final ColumnMetadata types; private final HiveWriterVersion hiveWriterVersion; - private final Set includedOrcColumns; + private final Set includedOrcColumnIds; private final int rowsInRowGroup; private final OrcPredicate predicate; private final MetadataReader metadataReader; @@ -87,8 +91,8 @@ public class StripeReader public StripeReader(OrcDataSource orcDataSource, ZoneId defaultTimeZone, Optional decompressor, - List types, - Set includedColumns, + ColumnMetadata types, + Set readColumns, int rowsInRowGroup, OrcPredicate predicate, HiveWriterVersion hiveWriterVersion, @@ -98,8 +102,8 @@ public StripeReader(OrcDataSource orcDataSource, this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null"); this.defaultTimeZone = requireNonNull(defaultTimeZone, "defaultTimeZone is null"); this.decompressor = requireNonNull(decompressor, "decompressor is null"); - this.types = ImmutableList.copyOf(requireNonNull(types, 
"types is null")); - this.includedOrcColumns = getIncludedOrcColumns(types, requireNonNull(includedColumns, "includedColumns is null")); + this.types = requireNonNull(types, "types is null"); + this.includedOrcColumnIds = getIncludeColumns(requireNonNull(readColumns, "readColumns is null")); this.rowsInRowGroup = rowsInRowGroup; this.predicate = requireNonNull(predicate, "predicate is null"); this.hiveWriterVersion = requireNonNull(hiveWriterVersion, "hiveWriterVersion is null"); @@ -112,7 +116,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste { // read the stripe footer StripeFooter stripeFooter = readStripeFooter(stripe, systemMemoryUsage); - List columnEncodings = stripeFooter.getColumnEncodings(); + ColumnMetadata columnEncodings = stripeFooter.getColumnEncodings(); if (writeValidation.isPresent()) { writeValidation.get().validateTimeZone(orcDataSource.getId(), stripeFooter.getTimeZone().orElse(null)); } @@ -121,7 +125,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste // get streams for selected columns Map streams = new HashMap<>(); for (Stream stream : stripeFooter.getStreams()) { - if (includedOrcColumns.contains(stream.getColumn()) && isSupportedStreamType(stream, types.get(stream.getColumn()).getOrcTypeKind())) { + if (includedOrcColumnIds.contains(stream.getColumnId()) && isSupportedStreamType(stream, types.get(stream.getColumnId()).getOrcTypeKind())) { streams.put(new StreamId(stream), stream); } } @@ -137,7 +141,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste Map streamsData = readDiskRanges(stripe.getOffset(), diskRanges, systemMemoryUsage); // read the bloom filter for each column - Map> bloomFilterIndexes = readBloomFilterIndexes(streams, streamsData); + Map> bloomFilterIndexes = readBloomFilterIndexes(streams, streamsData); // read the row index for each column Map> columnIndexes = readColumnIndexes(streams, streamsData, bloomFilterIndexes); 
@@ -183,7 +187,7 @@ public Stripe readStripe(StripeInformation stripe, AggregatedMemoryContext syste ImmutableMap.Builder diskRangesBuilder = ImmutableMap.builder(); for (Entry entry : getDiskRanges(stripeFooter.getStreams()).entrySet()) { StreamId streamId = entry.getKey(); - if (streams.keySet().contains(streamId)) { + if (streams.containsKey(streamId)) { diskRangesBuilder.put(entry); } } @@ -268,13 +272,13 @@ private Map readDiskRanges(long stripeOffset, Map> createValueStreams(Map streams, Map streamsData, List columnEncodings) + private Map> createValueStreams(Map streams, Map streamsData, ColumnMetadata columnEncodings) { ImmutableMap.Builder> valueStreams = ImmutableMap.builder(); for (Entry entry : streams.entrySet()) { StreamId streamId = entry.getKey(); Stream stream = entry.getValue(); - ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumn()).getColumnEncodingKind(); + ColumnEncodingKind columnEncoding = columnEncodings.get(stream.getColumnId()).getColumnEncodingKind(); // skip index and empty streams if (isIndexStream(stream) || stream.getLength() == 0) { @@ -282,20 +286,20 @@ private Map> createValueStreams(Map streams, Map> valueStreams, List columnEncodings) + private InputStreamSources createDictionaryStreamSources(Map streams, Map> valueStreams, ColumnMetadata columnEncodings) { ImmutableMap.Builder> dictionaryStreamBuilder = ImmutableMap.builder(); for (Entry entry : streams.entrySet()) { StreamId streamId = entry.getKey(); Stream stream = entry.getValue(); - int column = stream.getColumn(); + OrcColumnId column = stream.getColumnId(); // only process dictionary streams ColumnEncodingKind columnEncoding = columnEncodings.get(column).getColumnEncodingKind(); @@ -309,7 +313,7 @@ private InputStreamSources createDictionaryStreamSources(Map s continue; } - OrcTypeKind columnType = types.get(stream.getColumn()).getOrcTypeKind(); + OrcTypeKind columnType = types.get(stream.getColumnId()).getOrcTypeKind(); StreamCheckpoint 
streamCheckpoint = getDictionaryStreamCheckpoint(streamId, columnType, columnEncoding); InputStreamSource streamSource = createCheckpointStreamSource(valueStream, streamCheckpoint); @@ -324,13 +328,13 @@ private List createRowGroups( Map> valueStreams, Map> columnIndexes, Set selectedRowGroups, - List encodings) + ColumnMetadata encodings) throws InvalidCheckpointException { ImmutableList.Builder rowGroupBuilder = ImmutableList.builder(); for (int rowGroupId : selectedRowGroups) { - Map checkpoints = getStreamCheckpoints(includedOrcColumns, types, decompressor.isPresent(), rowGroupId, encodings, streams, columnIndexes); + Map checkpoints = getStreamCheckpoints(includedOrcColumnIds, types, decompressor.isPresent(), rowGroupId, encodings, streams, columnIndexes); int rowOffset = rowGroupId * rowsInRowGroup; int rowsInGroup = Math.min(rowsInStripe - rowOffset, rowsInRowGroup); long minAverageRowBytes = columnIndexes @@ -384,28 +388,28 @@ static boolean isIndexStream(Stream stream) return stream.getStreamKind() == ROW_INDEX || stream.getStreamKind() == DICTIONARY_COUNT || stream.getStreamKind() == BLOOM_FILTER || stream.getStreamKind() == BLOOM_FILTER_UTF8; } - private Map> readBloomFilterIndexes(Map streams, Map streamsData) + private Map> readBloomFilterIndexes(Map streams, Map streamsData) throws IOException { - HashMap> bloomFilters = new HashMap<>(); + HashMap> bloomFilters = new HashMap<>(); for (Entry entry : streams.entrySet()) { Stream stream = entry.getValue(); if (stream.getStreamKind() == BLOOM_FILTER_UTF8) { OrcInputStream inputStream = new OrcInputStream(streamsData.get(entry.getKey())); - bloomFilters.put(stream.getColumn(), metadataReader.readBloomFilterIndexes(inputStream)); + bloomFilters.put(stream.getColumnId(), metadataReader.readBloomFilterIndexes(inputStream)); } } for (Entry entry : streams.entrySet()) { Stream stream = entry.getValue(); - if (stream.getStreamKind() == BLOOM_FILTER && !bloomFilters.containsKey(stream.getColumn())) { + if 
(stream.getStreamKind() == BLOOM_FILTER && !bloomFilters.containsKey(stream.getColumnId())) { OrcInputStream inputStream = new OrcInputStream(streamsData.get(entry.getKey())); - bloomFilters.put(entry.getKey().getColumn(), metadataReader.readBloomFilterIndexes(inputStream)); + bloomFilters.put(entry.getKey().getColumnId(), metadataReader.readBloomFilterIndexes(inputStream)); } } return ImmutableMap.copyOf(bloomFilters); } - private Map> readColumnIndexes(Map streams, Map streamsData, Map> bloomFilterIndexes) + private Map> readColumnIndexes(Map streams, Map streamsData, Map> bloomFilterIndexes) throws IOException { ImmutableMap.Builder> columnIndexes = ImmutableMap.builder(); @@ -413,7 +417,7 @@ private Map> readColumnIndexes(Map bloomFilters = bloomFilterIndexes.get(entry.getKey().getColumn()); + List bloomFilters = bloomFilterIndexes.get(entry.getKey().getColumnId()); List rowGroupIndexes = metadataReader.readRowIndexes(hiveWriterVersion, inputStream); if (bloomFilters != null && !bloomFilters.isEmpty()) { ImmutableList.Builder newRowGroupIndexes = ImmutableList.builder(); @@ -440,7 +444,7 @@ private Set selectRowGroups(StripeInformation stripe, Map statistics = getRowGroupStatistics(types.get(0), columnIndexes, rowGroup); + ColumnMetadata statistics = getRowGroupStatistics(types, columnIndexes, rowGroup); if (predicate.matches(rows, statistics)) { selectedRowGroups.add(rowGroup); } @@ -449,24 +453,25 @@ private Set selectRowGroups(StripeInformation stripe, Map getRowGroupStatistics(OrcType rootStructType, Map> columnIndexes, int rowGroup) + private static ColumnMetadata getRowGroupStatistics(ColumnMetadata types, Map> columnIndexes, int rowGroup) { - requireNonNull(rootStructType, "rootStructType is null"); - checkArgument(rootStructType.getOrcTypeKind() == OrcTypeKind.STRUCT); requireNonNull(columnIndexes, "columnIndexes is null"); checkArgument(rowGroup >= 0, "rowGroup is negative"); - Map> columnIndexesByField = columnIndexes.entrySet().stream() - 
.collect(toImmutableMap(entry -> entry.getKey().getColumn(), Entry::getValue)); + Map> rowGroupIndexesByColumn = columnIndexes.entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().getColumnId().getId(), Entry::getValue)); - ImmutableMap.Builder statistics = ImmutableMap.builder(); - for (int ordinal = 0; ordinal < rootStructType.getFieldCount(); ordinal++) { - List rowGroupIndexes = columnIndexesByField.get(rootStructType.getFieldTypeIndex(ordinal)); + List statistics = new ArrayList<>(types.size()); + for (int columnIndex = 0; columnIndex < types.size(); columnIndex++) { + List rowGroupIndexes = rowGroupIndexesByColumn.get(columnIndex); if (rowGroupIndexes != null) { - statistics.put(ordinal, rowGroupIndexes.get(rowGroup).getColumnStatistics()); + statistics.add(rowGroupIndexes.get(rowGroup).getColumnStatistics()); + } + else { + statistics.add(null); } } - return statistics.build(); + return new ColumnMetadata<>(statistics); } private static boolean isDictionary(Stream stream, ColumnEncodingKind columnEncoding) @@ -489,25 +494,18 @@ private static Map getDiskRanges(List streams) return streamDiskRanges.build(); } - private static Set getIncludedOrcColumns(List types, Set includedColumns) + private static Set getIncludeColumns(Set includedColumns) { - Set includes = new LinkedHashSet<>(); - - OrcType root = types.get(0); - for (int includedColumn : includedColumns) { - includeOrcColumnsRecursive(types, includes, root.getFieldTypeIndex(includedColumn)); - } - - return includes; + Set result = new LinkedHashSet<>(); + includeColumnsRecursive(result, includedColumns); + return result; } - private static void includeOrcColumnsRecursive(List types, Set result, int typeId) + private static void includeColumnsRecursive(Set result, Collection readColumns) { - result.add(typeId); - OrcType type = types.get(typeId); - int children = type.getFieldCount(); - for (int i = 0; i < children; ++i) { - includeOrcColumnsRecursive(types, result, 
type.getFieldTypeIndex(i)); + for (OrcColumn column : readColumns) { + result.add(column.getColumnId()); + includeColumnsRecursive(result, column.getNestedColumns()); } } diff --git a/presto-orc/src/main/java/io/prestosql/orc/TupleDomainOrcPredicate.java b/presto-orc/src/main/java/io/prestosql/orc/TupleDomainOrcPredicate.java index 10879107a647..ddc5eed4cf67 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/TupleDomainOrcPredicate.java +++ b/presto-orc/src/main/java/io/prestosql/orc/TupleDomainOrcPredicate.java @@ -16,13 +16,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import io.prestosql.orc.metadata.ColumnMetadata; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.statistics.BloomFilter; import io.prestosql.orc.metadata.statistics.BooleanStatistics; import io.prestosql.orc.metadata.statistics.ColumnStatistics; import io.prestosql.orc.metadata.statistics.RangeStatistics; import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.Range; -import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.predicate.ValueSet; import io.prestosql.spi.type.DateType; import io.prestosql.spi.type.DecimalType; @@ -30,14 +31,13 @@ import io.prestosql.spi.type.VarbinaryType; import io.prestosql.spi.type.VarcharType; +import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.function.Function; import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.Chars.isCharType; @@ -58,44 +58,34 @@ import static java.lang.Float.intBitsToFloat; import static java.util.Objects.requireNonNull; -public class TupleDomainOrcPredicate 
+public class TupleDomainOrcPredicate implements OrcPredicate { - private final TupleDomain effectivePredicate; - private final List> columnReferences; - + private final List columnDomains; private final boolean orcBloomFiltersEnabled; - public TupleDomainOrcPredicate(TupleDomain effectivePredicate, List> columnReferences, boolean orcBloomFiltersEnabled) + public static TupleDomainOrcPredicateBuilder builder() { - this.effectivePredicate = requireNonNull(effectivePredicate, "effectivePredicate is null"); - this.columnReferences = ImmutableList.copyOf(requireNonNull(columnReferences, "columnReferences is null")); + return new TupleDomainOrcPredicateBuilder(); + } + + private TupleDomainOrcPredicate(List columnDomains, boolean orcBloomFiltersEnabled) + { + this.columnDomains = ImmutableList.copyOf(requireNonNull(columnDomains, "columnDomains is null")); this.orcBloomFiltersEnabled = orcBloomFiltersEnabled; } @Override - public boolean matches(long numberOfRows, Map statisticsByColumnIndex) + public boolean matches(long numberOfRows, ColumnMetadata allColumnStatistics) { - Optional> optionalEffectivePredicateDomains = effectivePredicate.getDomains(); - if (!optionalEffectivePredicateDomains.isPresent()) { - // effective predicate is none, so skip this section - return false; - } - Map effectivePredicateDomains = optionalEffectivePredicateDomains.get(); - - for (ColumnReference columnReference : columnReferences) { - Domain predicateDomain = effectivePredicateDomains.get(columnReference.getColumn()); - if (predicateDomain == null) { - // no predicate on this column, so we can't exclude this section - continue; - } - ColumnStatistics columnStatistics = statisticsByColumnIndex.get(columnReference.getOrdinal()); + for (ColumnDomain column : columnDomains) { + ColumnStatistics columnStatistics = allColumnStatistics.get(column.getColumnId()); if (columnStatistics == null) { // no statistics for this column, so we can't exclude this section continue; } - if 
(!columnOverlaps(columnReference, predicateDomain, numberOfRows, columnStatistics)) { + if (!columnOverlaps(column.getDomain(), numberOfRows, columnStatistics)) { return false; } } @@ -104,9 +94,9 @@ public boolean matches(long numberOfRows, Map statist return true; } - private boolean columnOverlaps(ColumnReference columnReference, Domain predicateDomain, long numberOfRows, ColumnStatistics columnStatistics) + private boolean columnOverlaps(Domain predicateDomain, long numberOfRows, ColumnStatistics columnStatistics) { - Domain stripeDomain = getDomain(columnReference.getType(), numberOfRows, columnStatistics); + Domain stripeDomain = getDomain(predicateDomain.getType(), numberOfRows, columnStatistics); if (!stripeDomain.overlaps(predicateDomain)) { // there is no overlap between the predicate and this column return false; @@ -265,42 +255,57 @@ private static > Domain createDomain(Type type, boole return Domain.create(ValueSet.all(type), hasNullValue); } - public static class ColumnReference + public static class TupleDomainOrcPredicateBuilder { - private final C column; - private final int ordinal; - private final Type type; + private final List columns = new ArrayList<>(); + private boolean bloomFiltersEnabled; - public ColumnReference(C column, int ordinal, Type type) + public TupleDomainOrcPredicateBuilder addColumn(OrcColumnId columnId, Domain domain) { - this.column = requireNonNull(column, "column is null"); - checkArgument(ordinal >= 0, "ordinal is negative"); - this.ordinal = ordinal; - this.type = requireNonNull(type, "type is null"); + requireNonNull(domain, "domain is null"); + columns.add(new ColumnDomain(columnId, domain)); + return this; } - public C getColumn() + public TupleDomainOrcPredicateBuilder setBloomFiltersEnabled(boolean bloomFiltersEnabled) + { + this.bloomFiltersEnabled = bloomFiltersEnabled; + return this; + } + + public TupleDomainOrcPredicate build() + { + return new TupleDomainOrcPredicate(columns, bloomFiltersEnabled); + } + } + + 
private static class ColumnDomain + { + private final OrcColumnId columnId; + private final Domain domain; + + public ColumnDomain(OrcColumnId columnId, Domain domain) { - return column; + this.columnId = requireNonNull(columnId, "columnId is null"); + this.domain = requireNonNull(domain, "domain is null"); } - public int getOrdinal() + public OrcColumnId getColumnId() { - return ordinal; + return columnId; } - public Type getType() + public Domain getDomain() { - return type; + return domain; } @Override public String toString() { return toStringHelper(this) - .add("column", column) - .add("ordinal", ordinal) - .add("type", type) + .add("columnId", columnId) + .add("domain", domain) .toString(); } } diff --git a/presto-orc/src/main/java/io/prestosql/orc/checkpoint/Checkpoints.java b/presto-orc/src/main/java/io/prestosql/orc/checkpoint/Checkpoints.java index 1a259c345e2c..b7254bd32888 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/checkpoint/Checkpoints.java +++ b/presto-orc/src/main/java/io/prestosql/orc/checkpoint/Checkpoints.java @@ -20,6 +20,8 @@ import io.prestosql.orc.StreamId; import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.ColumnEncoding.ColumnEncodingKind; +import io.prestosql.orc.metadata.ColumnMetadata; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.OrcType; import io.prestosql.orc.metadata.OrcType.OrcTypeKind; import io.prestosql.orc.metadata.RowGroupIndex; @@ -47,73 +49,73 @@ public final class Checkpoints private Checkpoints() {} public static Map getStreamCheckpoints( - Set columns, - List columnTypes, + Set columns, + ColumnMetadata columnTypes, boolean compressed, int rowGroupId, - List columnEncodings, + ColumnMetadata columnEncodings, Map streams, Map> columnIndexes) throws InvalidCheckpointException { - ImmutableSetMultimap.Builder streamKindsBuilder = ImmutableSetMultimap.builder(); + ImmutableSetMultimap.Builder streamKindsBuilder = ImmutableSetMultimap.builder(); for 
(Stream stream : streams.values()) { - streamKindsBuilder.put(stream.getColumn(), stream.getStreamKind()); + streamKindsBuilder.put(stream.getColumnId(), stream.getStreamKind()); } - SetMultimap streamKinds = streamKindsBuilder.build(); + SetMultimap streamKinds = streamKindsBuilder.build(); ImmutableMap.Builder checkpoints = ImmutableMap.builder(); for (Map.Entry> entry : columnIndexes.entrySet()) { - int column = entry.getKey().getColumn(); + OrcColumnId columnId = entry.getKey().getColumnId(); - if (!columns.contains(column)) { + if (!columns.contains(columnId)) { continue; } List positionsList = entry.getValue().get(rowGroupId).getPositions(); - ColumnEncodingKind columnEncoding = columnEncodings.get(column).getColumnEncodingKind(); - OrcTypeKind columnType = columnTypes.get(column).getOrcTypeKind(); - Set availableStreams = streamKinds.get(column); + ColumnEncodingKind columnEncoding = columnEncodings.get(columnId).getColumnEncodingKind(); + OrcTypeKind columnType = columnTypes.get(columnId).getOrcTypeKind(); + Set availableStreams = streamKinds.get(columnId); - ColumnPositionsList columnPositionsList = new ColumnPositionsList(column, columnType, positionsList); + ColumnPositionsList columnPositionsList = new ColumnPositionsList(columnId, columnType, positionsList); switch (columnType) { case BOOLEAN: - checkpoints.putAll(getBooleanColumnCheckpoints(column, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getBooleanColumnCheckpoints(columnId, compressed, availableStreams, columnPositionsList)); break; case BYTE: - checkpoints.putAll(getByteColumnCheckpoints(column, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getByteColumnCheckpoints(columnId, compressed, availableStreams, columnPositionsList)); break; case SHORT: case INT: case LONG: case DATE: - checkpoints.putAll(getLongColumnCheckpoints(column, columnEncoding, compressed, availableStreams, columnPositionsList)); + 
checkpoints.putAll(getLongColumnCheckpoints(columnId, columnEncoding, compressed, availableStreams, columnPositionsList)); break; case FLOAT: - checkpoints.putAll(getFloatColumnCheckpoints(column, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getFloatColumnCheckpoints(columnId, compressed, availableStreams, columnPositionsList)); break; case DOUBLE: - checkpoints.putAll(getDoubleColumnCheckpoints(column, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getDoubleColumnCheckpoints(columnId, compressed, availableStreams, columnPositionsList)); break; case TIMESTAMP: - checkpoints.putAll(getTimestampColumnCheckpoints(column, columnEncoding, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getTimestampColumnCheckpoints(columnId, columnEncoding, compressed, availableStreams, columnPositionsList)); break; case BINARY: case STRING: case VARCHAR: case CHAR: - checkpoints.putAll(getSliceColumnCheckpoints(column, columnEncoding, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getSliceColumnCheckpoints(columnId, columnEncoding, compressed, availableStreams, columnPositionsList)); break; case LIST: case MAP: - checkpoints.putAll(getListOrMapColumnCheckpoints(column, columnEncoding, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getListOrMapColumnCheckpoints(columnId, columnEncoding, compressed, availableStreams, columnPositionsList)); break; case STRUCT: - checkpoints.putAll(getStructColumnCheckpoints(column, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getStructColumnCheckpoints(columnId, compressed, availableStreams, columnPositionsList)); break; case DECIMAL: - checkpoints.putAll(getDecimalColumnCheckpoints(column, columnEncoding, compressed, availableStreams, columnPositionsList)); + checkpoints.putAll(getDecimalColumnCheckpoints(columnId, columnEncoding, compressed, availableStreams, columnPositionsList)); break; 
default: throw new IllegalArgumentException("Unsupported column type " + columnType); @@ -147,7 +149,7 @@ else if (columnEncoding == DICTIONARY) { } private static Map getBooleanColumnCheckpoints( - int column, + OrcColumnId columnId, boolean compressed, Set availableStreams, ColumnPositionsList positionsList) @@ -155,18 +157,18 @@ private static Map getBooleanColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), new BooleanStreamCheckpoint(compressed, positionsList)); } return checkpoints.build(); } private static Map getByteColumnCheckpoints( - int column, + OrcColumnId columnId, boolean compressed, Set availableStreams, ColumnPositionsList positionsList) @@ -174,18 +176,18 @@ private static Map getByteColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), new ByteStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), new ByteStreamCheckpoint(compressed, positionsList)); } return checkpoints.build(); } private static Map getLongColumnCheckpoints( - int column, + OrcColumnId columnId, ColumnEncodingKind encoding, boolean compressed, Set availableStreams, @@ -194,18 +196,18 @@ private static Map 
getLongColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), createLongStreamCheckpoint(encoding, compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), createLongStreamCheckpoint(encoding, compressed, positionsList)); } return checkpoints.build(); } private static Map getFloatColumnCheckpoints( - int column, + OrcColumnId columnId, boolean compressed, Set availableStreams, ColumnPositionsList positionsList) @@ -213,18 +215,18 @@ private static Map getFloatColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), new FloatStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), new FloatStreamCheckpoint(compressed, positionsList)); } return checkpoints.build(); } private static Map getDoubleColumnCheckpoints( - int column, + OrcColumnId columnId, boolean compressed, Set availableStreams, ColumnPositionsList positionsList) @@ -232,18 +234,18 @@ private static Map getDoubleColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, 
positionsList)); } if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), new DoubleStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), new DoubleStreamCheckpoint(compressed, positionsList)); } return checkpoints.build(); } private static Map getTimestampColumnCheckpoints( - int column, + OrcColumnId columnId, ColumnEncodingKind encoding, boolean compressed, Set availableStreams, @@ -252,22 +254,22 @@ private static Map getTimestampColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), createLongStreamCheckpoint(encoding, compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), createLongStreamCheckpoint(encoding, compressed, positionsList)); } if (availableStreams.contains(SECONDARY)) { - checkpoints.put(new StreamId(column, SECONDARY), createLongStreamCheckpoint(encoding, compressed, positionsList)); + checkpoints.put(new StreamId(columnId, SECONDARY), createLongStreamCheckpoint(encoding, compressed, positionsList)); } return checkpoints.build(); } private static Map getSliceColumnCheckpoints( - int column, + OrcColumnId columnId, ColumnEncodingKind encoding, boolean compressed, Set availableStreams, @@ -276,21 +278,21 @@ private static Map getSliceColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (encoding == DIRECT || encoding 
== DIRECT_V2) { if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), new ByteArrayStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), new ByteArrayStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(LENGTH)) { - checkpoints.put(new StreamId(column, LENGTH), createLongStreamCheckpoint(encoding, compressed, positionsList)); + checkpoints.put(new StreamId(columnId, LENGTH), createLongStreamCheckpoint(encoding, compressed, positionsList)); } } else if (encoding == DICTIONARY || encoding == DICTIONARY_V2) { if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), createLongStreamCheckpoint(encoding, compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), createLongStreamCheckpoint(encoding, compressed, positionsList)); } } else { @@ -301,7 +303,7 @@ else if (encoding == DICTIONARY || encoding == DICTIONARY_V2) { } private static Map getListOrMapColumnCheckpoints( - int column, + OrcColumnId columnId, ColumnEncodingKind encoding, boolean compressed, Set availableStreams, @@ -310,18 +312,18 @@ private static Map getListOrMapColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(LENGTH)) { - checkpoints.put(new StreamId(column, LENGTH), createLongStreamCheckpoint(encoding, compressed, positionsList)); + checkpoints.put(new StreamId(columnId, LENGTH), createLongStreamCheckpoint(encoding, compressed, positionsList)); } return checkpoints.build(); } private static Map getStructColumnCheckpoints( - int column, + OrcColumnId columnId, boolean compressed, Set availableStreams, ColumnPositionsList positionsList) @@ -329,14 
+331,14 @@ private static Map getStructColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } return checkpoints.build(); } private static Map getDecimalColumnCheckpoints( - int column, + OrcColumnId columnId, ColumnEncodingKind encoding, boolean compressed, Set availableStreams, @@ -345,15 +347,15 @@ private static Map getDecimalColumnCheckpoints( ImmutableMap.Builder checkpoints = ImmutableMap.builder(); if (availableStreams.contains(PRESENT)) { - checkpoints.put(new StreamId(column, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, PRESENT), new BooleanStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(DATA)) { - checkpoints.put(new StreamId(column, DATA), new DecimalStreamCheckpoint(compressed, positionsList)); + checkpoints.put(new StreamId(columnId, DATA), new DecimalStreamCheckpoint(compressed, positionsList)); } if (availableStreams.contains(SECONDARY)) { - checkpoints.put(new StreamId(column, SECONDARY), createLongStreamCheckpoint(encoding, compressed, positionsList)); + checkpoints.put(new StreamId(columnId, SECONDARY), createLongStreamCheckpoint(encoding, compressed, positionsList)); } return checkpoints.build(); @@ -374,14 +376,14 @@ private static StreamCheckpoint createLongStreamCheckpoint(ColumnEncodingKind en public static class ColumnPositionsList { - private final int column; + private final OrcColumnId columnId; private final OrcTypeKind columnType; private final List positionsList; private int index; - private ColumnPositionsList(int column, OrcTypeKind columnType, List positionsList) + private ColumnPositionsList(OrcColumnId columnId, OrcTypeKind columnType, List positionsList) { - 
this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.columnType = requireNonNull(columnType, "columnType is null"); this.positionsList = ImmutableList.copyOf(requireNonNull(positionsList, "positionsList is null")); } @@ -399,9 +401,7 @@ public boolean hasNextPosition() public int nextPosition() { if (!hasNextPosition()) { - throw new InvalidCheckpointException("Not enough positions for column %s and sequence %s, of type %s, checkpoints", - column, - columnType); + throw new InvalidCheckpointException("Not enough positions for column %s:%s checkpoints", columnId, columnType); } return positionsList.get(index++); diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/ColumnMetadata.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/ColumnMetadata.java new file mode 100644 index 000000000000..08a06b64fc43 --- /dev/null +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/ColumnMetadata.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.orc.metadata; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static java.util.Collections.unmodifiableList; +import static java.util.Objects.requireNonNull; + +public class ColumnMetadata +{ + private final List metadata; + + public ColumnMetadata(List metadata) + { + // the metadata list may contain nulls + this.metadata = unmodifiableList(new ArrayList<>(requireNonNull(metadata, "metadata is null"))); + } + + public T get(OrcColumnId columnId) + { + return metadata.get(columnId.getId()); + } + + public int size() + { + return metadata.size(); + } + + @Override + public String toString() + { + return metadata.toString(); + } + + public Stream stream() + { + return metadata.stream(); + } +} diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/ExceptionWrappingMetadataReader.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/ExceptionWrappingMetadataReader.java index b4175d2b3712..59448c3f52ac 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/ExceptionWrappingMetadataReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/ExceptionWrappingMetadataReader.java @@ -77,7 +77,7 @@ public Footer readFooter(HiveWriterVersion hiveWriterVersion, InputStream inputS } @Override - public StripeFooter readStripeFooter(List types, InputStream inputStream) + public StripeFooter readStripeFooter(ColumnMetadata types, InputStream inputStream) throws IOException { try { diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/Footer.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/Footer.java index 08c4a3a3b41e..655bfbff9af8 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/Footer.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/Footer.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static 
com.google.common.collect.Maps.transformValues; @@ -31,17 +32,22 @@ public class Footer private final long numberOfRows; private final int rowsInRowGroup; private final List stripes; - private final List types; - private final List fileStats; + private final ColumnMetadata types; + private final Optional> fileStats; private final Map userMetadata; - public Footer(long numberOfRows, int rowsInRowGroup, List stripes, List types, List fileStats, Map userMetadata) + public Footer(long numberOfRows, + int rowsInRowGroup, + List stripes, + ColumnMetadata types, + Optional> fileStats, + Map userMetadata) { this.numberOfRows = numberOfRows; this.rowsInRowGroup = rowsInRowGroup; this.stripes = ImmutableList.copyOf(requireNonNull(stripes, "stripes is null")); - this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); - this.fileStats = ImmutableList.copyOf(requireNonNull(fileStats, "columnStatistics is null")); + this.types = requireNonNull(types, "types is null"); + this.fileStats = requireNonNull(fileStats, "fileStats is null"); requireNonNull(userMetadata, "userMetadata is null"); this.userMetadata = ImmutableMap.copyOf(transformValues(userMetadata, Slices::copyOf)); } @@ -61,12 +67,12 @@ public List getStripes() return stripes; } - public List getTypes() + public ColumnMetadata getTypes() { return types; } - public List getFileStats() + public Optional> getFileStats() { return fileStats; } diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/Metadata.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/Metadata.java index ffa67d05ca95..1732b5fbceca 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/Metadata.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/Metadata.java @@ -16,17 +16,18 @@ import io.prestosql.orc.metadata.statistics.StripeStatistics; import java.util.List; +import java.util.Optional; public class Metadata { - private final List stripeStatistics; + private final List> stripeStatistics; - public 
Metadata(List stripeStatistics) + public Metadata(List> stripeStatistics) { this.stripeStatistics = stripeStatistics; } - public List getStripeStatsList() + public List> getStripeStatsList() { return stripeStatistics; } diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/MetadataReader.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/MetadataReader.java index cb95c23ec61e..e934a137f740 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/MetadataReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/MetadataReader.java @@ -31,7 +31,7 @@ Metadata readMetadata(HiveWriterVersion hiveWriterVersion, InputStream inputStre Footer readFooter(HiveWriterVersion hiveWriterVersion, InputStream inputStream) throws IOException; - StripeFooter readStripeFooter(List types, InputStream inputStream) + StripeFooter readStripeFooter(ColumnMetadata types, InputStream inputStream) throws IOException; List readRowIndexes(HiveWriterVersion hiveWriterVersion, InputStream inputStream) diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcColumnId.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcColumnId.java new file mode 100644 index 000000000000..99d3a05de693 --- /dev/null +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcColumnId.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.orc.metadata; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; + +public class OrcColumnId + implements Comparable +{ + public static final OrcColumnId ROOT_COLUMN = new OrcColumnId(0); + private final int id; + + public OrcColumnId(int id) + { + checkArgument(id >= 0, "id is negative"); + this.id = id; + } + + public int getId() + { + return id; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + OrcColumnId that = (OrcColumnId) o; + return id == that.id; + } + + @Override + public int hashCode() + { + return Objects.hash(id); + } + + @Override + public int compareTo(OrcColumnId o) + { + return Integer.compare(id, o.id); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("id", id) + .toString(); + } +} diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataReader.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataReader.java index 4d6783564639..1629ce4b4423 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataReader.java @@ -112,16 +112,17 @@ public Metadata readMetadata(HiveWriterVersion hiveWriterVersion, InputStream in return new Metadata(toStripeStatistics(hiveWriterVersion, metadata.getStripeStatsList())); } - private static List toStripeStatistics(HiveWriterVersion hiveWriterVersion, List types) + private static List> toStripeStatistics(HiveWriterVersion hiveWriterVersion, List types) { return types.stream() .map(stripeStatistics -> toStripeStatistics(hiveWriterVersion, stripeStatistics)) .collect(toImmutableList()); } - private static StripeStatistics toStripeStatistics(HiveWriterVersion hiveWriterVersion, OrcProto.StripeStatistics 
stripeStatistics) + private static Optional toStripeStatistics(HiveWriterVersion hiveWriterVersion, OrcProto.StripeStatistics stripeStatistics) { - return new StripeStatistics(toColumnStatistics(hiveWriterVersion, stripeStatistics.getColStatsList(), false)); + return toColumnStatistics(hiveWriterVersion, stripeStatistics.getColStatsList(), false) + .map(StripeStatistics::new); } @Override @@ -158,7 +159,7 @@ private static StripeInformation toStripeInformation(OrcProto.StripeInformation } @Override - public StripeFooter readStripeFooter(List types, InputStream inputStream) + public StripeFooter readStripeFooter(ColumnMetadata types, InputStream inputStream) throws IOException { CodedInputStream input = CodedInputStream.newInstance(inputStream); @@ -172,7 +173,7 @@ public StripeFooter readStripeFooter(List types, InputStream inputStrea private static Stream toStream(OrcProto.Stream stream) { - return new Stream(stream.getColumn(), toStreamKind(stream.getKind()), toIntExact(stream.getLength()), true); + return new Stream(new OrcColumnId(stream.getColumn()), toStreamKind(stream.getKind()), toIntExact(stream.getLength()), true); } private static List toStream(List streams) @@ -187,11 +188,11 @@ private static ColumnEncoding toColumnEncoding(OrcProto.ColumnEncoding columnEnc return new ColumnEncoding(toColumnEncodingKind(columnEncoding.getKind()), columnEncoding.getDictionarySize()); } - private static List toColumnEncoding(List columnEncodings) + private static ColumnMetadata toColumnEncoding(List columnEncodings) { - return columnEncodings.stream() + return new ColumnMetadata<>(columnEncodings.stream() .map(OrcMetadataReader::toColumnEncoding) - .collect(toImmutableList()); + .collect(toImmutableList())); } @Override @@ -293,14 +294,14 @@ else if (statistics.hasBinaryStatistics()) { null); } - private static List toColumnStatistics(HiveWriterVersion hiveWriterVersion, List columnStatistics, boolean isRowGroup) + private static Optional> 
toColumnStatistics(HiveWriterVersion hiveWriterVersion, List columnStatistics, boolean isRowGroup) { - if (columnStatistics == null) { - return ImmutableList.of(); + if (columnStatistics == null || columnStatistics.isEmpty()) { + return Optional.empty(); } - return columnStatistics.stream() + return Optional.of(new ColumnMetadata<>(columnStatistics.stream() .map(statistics -> toColumnStatistics(hiveWriterVersion, statistics, isRowGroup)) - .collect(toImmutableList()); + .collect(toImmutableList()))); } private static Map toUserMetadata(List metadataList) @@ -491,16 +492,23 @@ private static OrcType toType(OrcProto.Type type) precision = Optional.of(type.getPrecision()); scale = Optional.of(type.getScale()); } - return new OrcType(toTypeKind(type.getKind()), type.getSubtypesList(), type.getFieldNamesList(), length, precision, scale); + return new OrcType(toTypeKind(type.getKind()), toOrcColumnId(type.getSubtypesList()), type.getFieldNamesList(), length, precision, scale); } - private static List toType(List types) + private static List toOrcColumnId(List columnIds) { - return types.stream() - .map(OrcMetadataReader::toType) + return columnIds.stream() + .map(OrcColumnId::new) .collect(toImmutableList()); } + private static ColumnMetadata toType(List types) + { + return new ColumnMetadata<>(types.stream() + .map(OrcMetadataReader::toType) + .collect(toImmutableList())); + } + private static OrcTypeKind toTypeKind(OrcProto.Type.Kind typeKind) { switch (typeKind) { diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataWriter.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataWriter.java index 03a7abf747bb..c63109706f02 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcMetadataWriter.java @@ -35,6 +35,7 @@ import java.time.ZoneId; import java.util.List; import java.util.Map.Entry; +import java.util.Optional; import java.util.TimeZone; 
import static java.lang.Math.toIntExact; @@ -86,6 +87,7 @@ public int writeMetadata(SliceOutput output, Metadata metadata) { OrcProto.Metadata metadataProtobuf = OrcProto.Metadata.newBuilder() .addAllStripeStats(metadata.getStripeStatsList().stream() + .map(Optional::get) .map(OrcMetadataWriter::toStripeStatistics) .collect(toList())) .build(); @@ -115,7 +117,7 @@ public int writeFooter(SliceOutput output, Footer footer) .addAllTypes(footer.getTypes().stream() .map(OrcMetadataWriter::toType) .collect(toList())) - .addAllStatistics(footer.getFileStats().stream() + .addAllStatistics(footer.getFileStats().map(ColumnMetadata::stream).orElseGet(java.util.stream.Stream::empty) .map(OrcMetadataWriter::toColumnStatistics) .collect(toList())) .addAllMetadata(footer.getUserMetadata().entrySet().stream() @@ -144,7 +146,9 @@ private static Type toType(OrcType type) { Builder builder = Type.newBuilder() .setKind(toTypeKind(type.getOrcTypeKind())) - .addAllSubtypes(type.getFieldTypeIndexes()) + .addAllSubtypes(type.getFieldTypeIndexes().stream() + .map(OrcColumnId::getId) + .collect(toList())) .addAllFieldNames(type.getFieldNames()); if (type.getLength().isPresent()) { @@ -298,7 +302,7 @@ public int writeStripeFooter(SliceOutput output, StripeFooter footer) private static OrcProto.Stream toStream(Stream stream) { return OrcProto.Stream.newBuilder() - .setColumn(stream.getColumn()) + .setColumn(stream.getColumnId().getId()) .setKind(toStreamKind(stream.getStreamKind())) .setLength(stream.getLength()) .build(); diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcType.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcType.java index a70630adca2a..53e1182d1a8f 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcType.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/OrcType.java @@ -75,7 +75,7 @@ public enum OrcTypeKind } private final OrcTypeKind orcTypeKind; - private final List fieldTypeIndexes; + private final List 
fieldTypeIndexes; private final List fieldNames; private final Optional length; private final Optional precision; @@ -96,12 +96,12 @@ private OrcType(OrcTypeKind orcTypeKind, int precision, int scale) this(orcTypeKind, ImmutableList.of(), ImmutableList.of(), Optional.empty(), Optional.of(precision), Optional.of(scale)); } - private OrcType(OrcTypeKind orcTypeKind, List fieldTypeIndexes, List fieldNames) + private OrcType(OrcTypeKind orcTypeKind, List fieldTypeIndexes, List fieldNames) { this(orcTypeKind, fieldTypeIndexes, fieldNames, Optional.empty(), Optional.empty(), Optional.empty()); } - public OrcType(OrcTypeKind orcTypeKind, List fieldTypeIndexes, List fieldNames, Optional length, Optional precision, Optional scale) + public OrcType(OrcTypeKind orcTypeKind, List fieldTypeIndexes, List fieldNames, Optional length, Optional precision, Optional scale) { this.orcTypeKind = requireNonNull(orcTypeKind, "typeKind is null"); this.fieldTypeIndexes = ImmutableList.copyOf(requireNonNull(fieldTypeIndexes, "fieldTypeIndexes is null")); @@ -127,12 +127,12 @@ public int getFieldCount() return fieldTypeIndexes.size(); } - public int getFieldTypeIndex(int field) + public OrcColumnId getFieldTypeIndex(int field) { return fieldTypeIndexes.get(field); } - public List getFieldTypeIndexes() + public List getFieldTypeIndexes() { return fieldTypeIndexes; } @@ -243,7 +243,7 @@ private static List createOrcArrayType(int nextFieldTypeIndex, Type ite List itemTypes = toOrcType(nextFieldTypeIndex, itemType); List orcTypes = new ArrayList<>(); - orcTypes.add(new OrcType(OrcTypeKind.LIST, ImmutableList.of(nextFieldTypeIndex), ImmutableList.of("item"))); + orcTypes.add(new OrcType(OrcTypeKind.LIST, ImmutableList.of(new OrcColumnId(nextFieldTypeIndex)), ImmutableList.of("item"))); orcTypes.addAll(itemTypes); return orcTypes; } @@ -255,19 +255,27 @@ private static List createOrcMapType(int nextFieldTypeIndex, Type keyTy List valueTypes = toOrcType(nextFieldTypeIndex + keyTypes.size(), 
valueType); List orcTypes = new ArrayList<>(); - orcTypes.add(new OrcType(OrcTypeKind.MAP, ImmutableList.of(nextFieldTypeIndex, nextFieldTypeIndex + keyTypes.size()), ImmutableList.of("key", "value"))); + orcTypes.add(new OrcType( + OrcTypeKind.MAP, + ImmutableList.of(new OrcColumnId(nextFieldTypeIndex), new OrcColumnId(nextFieldTypeIndex + keyTypes.size())), + ImmutableList.of("key", "value"))); orcTypes.addAll(keyTypes); orcTypes.addAll(valueTypes); return orcTypes; } - public static List createOrcRowType(int nextFieldTypeIndex, List fieldNames, List fieldTypes) + public static ColumnMetadata createRootOrcType(List fieldNames, List fieldTypes) + { + return new ColumnMetadata<>(createOrcRowType(0, fieldNames, fieldTypes)); + } + + private static List createOrcRowType(int nextFieldTypeIndex, List fieldNames, List fieldTypes) { nextFieldTypeIndex++; - List fieldTypeIndexes = new ArrayList<>(); + List fieldTypeIndexes = new ArrayList<>(); List> fieldTypesList = new ArrayList<>(); for (Type fieldType : fieldTypes) { - fieldTypeIndexes.add(nextFieldTypeIndex); + fieldTypeIndexes.add(new OrcColumnId(nextFieldTypeIndex)); List fieldOrcTypes = toOrcType(nextFieldTypeIndex, fieldType); fieldTypesList.add(fieldOrcTypes); nextFieldTypeIndex += fieldOrcTypes.size(); diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/Stream.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/Stream.java index f01a36b037f0..5aaea5c15ed8 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/Stream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/Stream.java @@ -31,22 +31,22 @@ public enum StreamKind BLOOM_FILTER_UTF8, } - private final int column; + private final OrcColumnId columnId; private final StreamKind streamKind; private final int length; private final boolean useVInts; - public Stream(int column, StreamKind streamKind, int length, boolean useVInts) + public Stream(OrcColumnId columnId, StreamKind streamKind, int length, boolean useVInts) { - 
this.column = column; + this.columnId = columnId; this.streamKind = requireNonNull(streamKind, "streamKind is null"); this.length = length; this.useVInts = useVInts; } - public int getColumn() + public OrcColumnId getColumnId() { - return column; + return columnId; } public StreamKind getStreamKind() @@ -68,7 +68,7 @@ public boolean isUseVInts() public String toString() { return toStringHelper(this) - .add("column", column) + .add("column", columnId) .add("streamKind", streamKind) .add("length", length) .add("useVInts", useVInts) diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/StripeFooter.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/StripeFooter.java index ff1e949ebe72..040b3d9eb6a1 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/StripeFooter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/StripeFooter.java @@ -24,17 +24,17 @@ public class StripeFooter { private final List streams; - private final List columnEncodings; + private final ColumnMetadata columnEncodings; private final Optional timeZone; - public StripeFooter(List streams, List columnEncodings, Optional timeZone) + public StripeFooter(List streams, ColumnMetadata columnEncodings, Optional timeZone) { this.streams = ImmutableList.copyOf(requireNonNull(streams, "streams is null")); - this.columnEncodings = ImmutableList.copyOf(requireNonNull(columnEncodings, "columnEncodings is null")); + this.columnEncodings = requireNonNull(columnEncodings, "columnEncodings is null"); this.timeZone = requireNonNull(timeZone, "timeZone is null"); } - public List getColumnEncodings() + public ColumnMetadata getColumnEncodings() { return columnEncodings; } diff --git a/presto-orc/src/main/java/io/prestosql/orc/metadata/statistics/StripeStatistics.java b/presto-orc/src/main/java/io/prestosql/orc/metadata/statistics/StripeStatistics.java index 208ef0448845..481fe08bc8d3 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/metadata/statistics/StripeStatistics.java 
+++ b/presto-orc/src/main/java/io/prestosql/orc/metadata/statistics/StripeStatistics.java @@ -13,10 +13,9 @@ */ package io.prestosql.orc.metadata.statistics; -import com.google.common.collect.ImmutableList; +import io.prestosql.orc.metadata.ColumnMetadata; import org.openjdk.jol.info.ClassLayout; -import java.util.List; import java.util.Objects; import static java.util.Objects.requireNonNull; @@ -25,16 +24,16 @@ public class StripeStatistics { private static final int INSTANCE_SIZE = ClassLayout.parseClass(StripeStatistics.class).instanceSize(); - private final List columnStatistics; + private final ColumnMetadata columnStatistics; private final long retainedSizeInBytes; - public StripeStatistics(List columnStatistics) + public StripeStatistics(ColumnMetadata columnStatistics) { - this.columnStatistics = ImmutableList.copyOf(requireNonNull(columnStatistics, "columnStatistics is null")); + this.columnStatistics = requireNonNull(columnStatistics, "columnStatistics is null"); this.retainedSizeInBytes = INSTANCE_SIZE + columnStatistics.stream().mapToLong(ColumnStatistics::getRetainedSizeInBytes).sum(); } - public List getColumnStatistics() + public ColumnMetadata getColumnStatistics() { return columnStatistics; } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/BooleanStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/BooleanColumnReader.java similarity index 87% rename from presto-orc/src/main/java/io/prestosql/orc/reader/BooleanStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/BooleanColumnReader.java index 6a95f9fa6b3e..56c1e5c06c60 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/BooleanStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/BooleanColumnReader.java @@ -14,9 +14,10 @@ package io.prestosql.orc.reader; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import 
io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.InputStreamSource; import io.prestosql.orc.stream.InputStreamSources; @@ -31,7 +32,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -46,12 +46,12 @@ import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static java.util.Objects.requireNonNull; -public class BooleanStreamReader - implements StreamReader +public class BooleanColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(BooleanStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(BooleanColumnReader.class).instanceSize(); - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private int readOffset; private int nextBatchSize; @@ -71,13 +71,13 @@ public class BooleanStreamReader private final LocalMemoryContext systemMemoryContext; - public BooleanStreamReader(Type type, StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext) + public BooleanColumnReader(Type type, OrcColumn column, LocalMemoryContext systemMemoryContext) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, BooleanType.class::isInstance); + verifyStreamType(column, type, BooleanType.class::isInstance); - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -104,7 +104,7 @@ public Block readBlock() } if (readOffset > 0) { if (dataStream == null) { - throw new 
OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } dataStream.skip(readOffset); } @@ -113,7 +113,7 @@ public Block readBlock() Block block; if (dataStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); block = RunLengthEncodedBlock.create(BOOLEAN, null, nextBatchSize); @@ -175,7 +175,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { presentStreamSource = missingStreamSource(BooleanInputStream.class); dataStreamSource = missingStreamSource(BooleanInputStream.class); @@ -192,8 +192,8 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - dataStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, BooleanInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + dataStreamSource = dataStreamSources.getInputStreamSource(column, DATA, BooleanInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -208,7 +208,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git 
a/presto-orc/src/main/java/io/prestosql/orc/reader/ByteStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/ByteColumnReader.java similarity index 87% rename from presto-orc/src/main/java/io/prestosql/orc/reader/ByteStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/ByteColumnReader.java index bcbeaa33f679..b8b2faa5a49d 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/ByteStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/ByteColumnReader.java @@ -14,9 +14,10 @@ package io.prestosql.orc.reader; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.ByteInputStream; import io.prestosql.orc.stream.InputStreamSource; @@ -32,7 +33,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -46,12 +46,12 @@ import static io.prestosql.spi.type.TinyintType.TINYINT; import static java.util.Objects.requireNonNull; -public class ByteStreamReader - implements StreamReader +public class ByteColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(ByteStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ByteColumnReader.class).instanceSize(); - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private int readOffset; private int nextBatchSize; @@ -71,13 +71,13 @@ public class ByteStreamReader private final LocalMemoryContext systemMemoryContext; - public ByteStreamReader(Type type, StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext) 
+ public ByteColumnReader(Type type, OrcColumn column, LocalMemoryContext systemMemoryContext) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, TinyintType.class::isInstance); + verifyStreamType(column, type, TinyintType.class::isInstance); - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -104,7 +104,7 @@ public Block readBlock() } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } dataStream.skip(readOffset); } @@ -113,7 +113,7 @@ public Block readBlock() Block block; if (dataStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); block = RunLengthEncodedBlock.create(TINYINT, null, nextBatchSize); @@ -176,7 +176,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { presentStreamSource = missingStreamSource(BooleanInputStream.class); dataStreamSource = missingStreamSource(ByteInputStream.class); @@ -193,8 +193,8 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = 
dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - dataStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, ByteInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + dataStreamSource = dataStreamSources.getInputStreamSource(column, DATA, ByteInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -209,7 +209,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/StreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReader.java similarity index 89% rename from presto-orc/src/main/java/io/prestosql/orc/reader/StreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReader.java index a9221244fab6..241517b985c1 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/StreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReader.java @@ -14,21 +14,21 @@ package io.prestosql.orc.reader; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.InputStreamSources; import io.prestosql.spi.block.Block; import java.io.IOException; import java.time.ZoneId; -import java.util.List; -public interface StreamReader +public interface ColumnReader { Block readBlock() throws IOException; void prepareNextRead(int batchSize); - void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) throws IOException; void startRowGroup(InputStreamSources dataStreamSources) diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReaders.java 
b/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReaders.java new file mode 100644 index 000000000000..587f15b4a1b4 --- /dev/null +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReaders.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.orc.reader; + +import io.prestosql.memory.context.AggregatedMemoryContext; +import io.prestosql.orc.OrcBlockFactory.NestedBlockFactory; +import io.prestosql.orc.OrcColumn; +import io.prestosql.orc.OrcCorruptionException; +import io.prestosql.spi.type.Type; + +public final class ColumnReaders +{ + private ColumnReaders() {} + + public static ColumnReader createColumnReader(Type type, OrcColumn column, AggregatedMemoryContext systemMemoryContext, NestedBlockFactory blockFactory) + throws OrcCorruptionException + { + switch (column.getColumnType()) { + case BOOLEAN: + return new BooleanColumnReader(type, column, systemMemoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); + case BYTE: + return new ByteColumnReader(type, column, systemMemoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); + case SHORT: + case INT: + case LONG: + case DATE: + return new LongColumnReader(type, column, systemMemoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); + case FLOAT: + return new FloatColumnReader(type, column, systemMemoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); + case DOUBLE: + return 
new DoubleColumnReader(type, column, systemMemoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); + case BINARY: + case STRING: + case VARCHAR: + case CHAR: + return new SliceColumnReader(type, column, systemMemoryContext); + case TIMESTAMP: + return new TimestampColumnReader(type, column, systemMemoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); + case LIST: + return new ListColumnReader(type, column, systemMemoryContext, blockFactory); + case STRUCT: + return new StructColumnReader(type, column, systemMemoryContext, blockFactory); + case MAP: + return new MapColumnReader(type, column, systemMemoryContext, blockFactory); + case DECIMAL: + return new DecimalColumnReader(type, column, systemMemoryContext.newLocalMemoryContext(ColumnReaders.class.getSimpleName())); + case UNION: + default: + throw new IllegalArgumentException("Unsupported type: " + column.getColumnType()); + } + } +} diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/DecimalStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/DecimalColumnReader.java similarity index 90% rename from presto-orc/src/main/java/io/prestosql/orc/reader/DecimalStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/DecimalColumnReader.java index a4e749f1d7a4..528e4e91f968 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/DecimalStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/DecimalColumnReader.java @@ -16,9 +16,10 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.DecimalInputStream; import 
io.prestosql.orc.stream.InputStreamSource; @@ -38,7 +39,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -55,13 +55,13 @@ import static io.prestosql.spi.type.DoubleType.DOUBLE; import static java.util.Objects.requireNonNull; -public class DecimalStreamReader - implements StreamReader +public class DecimalColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(DecimalStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(DecimalColumnReader.class).instanceSize(); private final DecimalType type; - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private int readOffset; private int nextBatchSize; @@ -87,14 +87,14 @@ public class DecimalStreamReader private final LocalMemoryContext systemMemoryContext; - public DecimalStreamReader(Type type, StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext) + public DecimalColumnReader(Type type, OrcColumn column, LocalMemoryContext systemMemoryContext) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, DecimalType.class::isInstance); + verifyStreamType(column, type, DecimalType.class::isInstance); this.type = (DecimalType) type; - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -118,7 +118,7 @@ public Block readBlock() Block block; if (decimalStream == null && scaleStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but 
present stream is missing"); } presentStream.skip(nextBatchSize); block = RunLengthEncodedBlock.create(type, null, nextBatchSize); @@ -152,10 +152,10 @@ private void checkDataStreamsArePresent() throws OrcCorruptionException { if (decimalStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but decimal stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but decimal stream is missing"); } if (scaleStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but scale stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but scale stream is missing"); } } @@ -306,7 +306,7 @@ private void seekToOffset() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { presentStreamSource = missingStreamSource(BooleanInputStream.class); decimalStreamSource = missingStreamSource(DecimalInputStream.class); @@ -325,9 +325,9 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - decimalStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, DecimalInputStream.class); - scaleStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, SECONDARY, LongInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + decimalStreamSource = dataStreamSources.getInputStreamSource(column, DATA, DecimalInputStream.class); + scaleStreamSource = dataStreamSources.getInputStreamSource(column, SECONDARY, 
LongInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -343,7 +343,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/DoubleStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/DoubleColumnReader.java similarity index 87% rename from presto-orc/src/main/java/io/prestosql/orc/reader/DoubleStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/DoubleColumnReader.java index fcb64d58fbd5..800c2eb5b512 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/DoubleStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/DoubleColumnReader.java @@ -14,9 +14,10 @@ package io.prestosql.orc.reader; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.DoubleInputStream; import io.prestosql.orc.stream.InputStreamSource; @@ -32,7 +33,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -47,12 +47,12 @@ import static io.prestosql.spi.type.DoubleType.DOUBLE; import static java.util.Objects.requireNonNull; -public class DoubleStreamReader - implements StreamReader +public class DoubleColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(DoubleStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(DoubleColumnReader.class).instanceSize(); - private final StreamDescriptor streamDescriptor; + 
private final OrcColumn column; private int readOffset; private int nextBatchSize; @@ -72,13 +72,13 @@ public class DoubleStreamReader private final LocalMemoryContext systemMemoryContext; - public DoubleStreamReader(Type type, StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext) + public DoubleColumnReader(Type type, OrcColumn column, LocalMemoryContext systemMemoryContext) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, DoubleType.class::isInstance); + verifyStreamType(column, type, DoubleType.class::isInstance); - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -105,7 +105,7 @@ public Block readBlock() } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } dataStream.skip(readOffset); } @@ -114,7 +114,7 @@ public Block readBlock() Block block; if (dataStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); block = RunLengthEncodedBlock.create(DOUBLE, null, nextBatchSize); @@ -178,7 +178,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { presentStreamSource = 
missingStreamSource(BooleanInputStream.class); dataStreamSource = missingStreamSource(DoubleInputStream.class); @@ -195,8 +195,8 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - dataStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, DoubleInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + dataStreamSource = dataStreamSources.getInputStreamSource(column, DATA, DoubleInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -211,7 +211,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/FloatStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/FloatColumnReader.java similarity index 87% rename from presto-orc/src/main/java/io/prestosql/orc/reader/FloatStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/FloatColumnReader.java index df8f11aecca4..f694b971bb10 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/FloatStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/FloatColumnReader.java @@ -14,9 +14,10 @@ package io.prestosql.orc.reader; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.FloatInputStream; import io.prestosql.orc.stream.InputStreamSource; @@ -32,7 
+33,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -46,12 +46,12 @@ import static io.prestosql.spi.type.RealType.REAL; import static java.util.Objects.requireNonNull; -public class FloatStreamReader - implements StreamReader +public class FloatColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(FloatStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(FloatColumnReader.class).instanceSize(); - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private int readOffset; private int nextBatchSize; @@ -71,13 +71,13 @@ public class FloatStreamReader private final LocalMemoryContext systemMemoryContext; - public FloatStreamReader(Type type, StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext) + public FloatColumnReader(Type type, OrcColumn column, LocalMemoryContext systemMemoryContext) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, RealType.class::isInstance); + verifyStreamType(column, type, RealType.class::isInstance); - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -104,7 +104,7 @@ public Block readBlock() } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } dataStream.skip(readOffset); } @@ -113,7 +113,7 @@ public Block readBlock() Block block; if (dataStream == null) { if (presentStream == null) { - 
throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); block = RunLengthEncodedBlock.create(REAL, null, nextBatchSize); @@ -177,7 +177,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { presentStreamSource = missingStreamSource(BooleanInputStream.class); dataStreamSource = missingStreamSource(FloatInputStream.class); @@ -194,8 +194,8 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - dataStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, FloatInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + dataStreamSource = dataStreamSources.getInputStreamSource(column, DATA, FloatInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -210,7 +210,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/ListStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/ListColumnReader.java similarity index 75% rename from presto-orc/src/main/java/io/prestosql/orc/reader/ListStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/ListColumnReader.java index fd3cf8ea4126..ea936706337b 100644 --- 
a/presto-orc/src/main/java/io/prestosql/orc/reader/ListStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/ListColumnReader.java @@ -15,9 +15,11 @@ import com.google.common.io.Closer; import io.prestosql.memory.context.AggregatedMemoryContext; +import io.prestosql.orc.OrcBlockFactory.NestedBlockFactory; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.InputStreamSource; import io.prestosql.orc.stream.InputStreamSources; @@ -33,29 +35,29 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static io.prestosql.orc.metadata.Stream.StreamKind.LENGTH; import static io.prestosql.orc.metadata.Stream.StreamKind.PRESENT; +import static io.prestosql.orc.reader.ColumnReaders.createColumnReader; import static io.prestosql.orc.reader.ReaderUtils.convertLengthVectorToOffsetVector; import static io.prestosql.orc.reader.ReaderUtils.unpackLengthNulls; import static io.prestosql.orc.reader.ReaderUtils.verifyStreamType; -import static io.prestosql.orc.reader.StreamReaders.createStreamReader; import static io.prestosql.orc.stream.MissingInputStreamSource.missingStreamSource; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; -public class ListStreamReader - implements StreamReader +public class ListColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(ListStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(ListColumnReader.class).instanceSize(); private final Type elementType; - private final StreamDescriptor 
streamDescriptor; + private final OrcColumn column; + private final NestedBlockFactory blockFactory; - private final StreamReader elementStreamReader; + private final ColumnReader elementColumnReader; private int readOffset; private int nextBatchSize; @@ -70,15 +72,16 @@ public class ListStreamReader private boolean rowGroupOpen; - public ListStreamReader(Type type, StreamDescriptor streamDescriptor, AggregatedMemoryContext systemMemoryContext) + public ListColumnReader(Type type, OrcColumn column, AggregatedMemoryContext systemMemoryContext, NestedBlockFactory blockFactory) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, ArrayType.class::isInstance); + verifyStreamType(column, type, ArrayType.class::isInstance); elementType = ((ArrayType) type).getElementType(); - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); - this.elementStreamReader = createStreamReader(elementType, streamDescriptor.getNestedStreams().get(0), systemMemoryContext); + this.column = requireNonNull(column, "column is null"); + this.blockFactory = requireNonNull(blockFactory, "blockFactory is null"); + this.elementColumnReader = createColumnReader(elementType, column.getNestedColumns().get(0), systemMemoryContext, blockFactory); } @Override @@ -104,10 +107,10 @@ public Block readBlock() } if (readOffset > 0) { if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is not present"); } long elementSkipSize = lengthStream.sum(readOffset); - elementStreamReader.prepareNextRead(toIntExact(elementSkipSize)); + elementColumnReader.prepareNextRead(toIntExact(elementSkipSize)); } } @@ -117,7 +120,7 @@ public Block readBlock() boolean[] nullVector = null; if (presentStream == null) { if (lengthStream == null) { - throw 
new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.next(offsetVector, nextBatchSize); } @@ -126,7 +129,7 @@ public Block readBlock() int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.next(offsetVector, nextBatchSize - nullValues); unpackLengthNulls(offsetVector, nullVector, nextBatchSize - nullValues); @@ -138,8 +141,8 @@ public Block readBlock() Block elements; if (elementCount > 0) { - elementStreamReader.prepareNextRead(elementCount); - elements = elementStreamReader.readBlock(); + elementColumnReader.prepareNextRead(elementCount); + elements = blockFactory.createBlock(elementCount, elementColumnReader::readBlock); } else { elements = elementType.createBlockBuilder(null, 0).build(); @@ -162,7 +165,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) throws IOException { presentStreamSource = missingStreamSource(BooleanInputStream.class); @@ -176,15 +179,15 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour rowGroupOpen = false; - elementStreamReader.startStripe(timeZone, dictionaryStreamSources, encoding); + elementColumnReader.startStripe(timeZone, dictionaryStreamSources, encoding); } @Override public void startRowGroup(InputStreamSources dataStreamSources) throws IOException { - 
presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - lengthStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, LENGTH, LongInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + lengthStreamSource = dataStreamSources.getInputStreamSource(column, LENGTH, LongInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -194,14 +197,14 @@ public void startRowGroup(InputStreamSources dataStreamSources) rowGroupOpen = false; - elementStreamReader.startRowGroup(dataStreamSources); + elementColumnReader.startRowGroup(dataStreamSources); } @Override public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } @@ -209,7 +212,7 @@ public String toString() public void close() { try (Closer closer = Closer.create()) { - closer.register(elementStreamReader::close); + closer.register(elementColumnReader::close); } catch (IOException e) { throw new UncheckedIOException(e); @@ -219,6 +222,6 @@ public void close() @Override public long getRetainedSizeInBytes() { - return INSTANCE_SIZE + elementStreamReader.getRetainedSizeInBytes(); + return INSTANCE_SIZE + elementColumnReader.getRetainedSizeInBytes(); } } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/LongStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/LongColumnReader.java similarity index 90% rename from presto-orc/src/main/java/io/prestosql/orc/reader/LongStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/LongColumnReader.java index c115e7b1ab99..859d560ad28a 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/LongStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/LongColumnReader.java @@ -14,9 +14,10 @@ package io.prestosql.orc.reader; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; 
import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.InputStreamSource; import io.prestosql.orc.stream.InputStreamSources; @@ -37,7 +38,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -53,13 +53,13 @@ import static io.prestosql.orc.stream.MissingInputStreamSource.missingStreamSource; import static java.util.Objects.requireNonNull; -public class LongStreamReader - implements StreamReader +public class LongColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(LongStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(LongColumnReader.class).instanceSize(); private final Type type; - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private int readOffset; private int nextBatchSize; @@ -82,14 +82,14 @@ public class LongStreamReader private final LocalMemoryContext systemMemoryContext; - public LongStreamReader(Type type, StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext) + public LongColumnReader(Type type, OrcColumn column, LocalMemoryContext systemMemoryContext) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, t -> t instanceof BigintType || t instanceof IntegerType || t instanceof SmallintType || t instanceof DateType); + verifyStreamType(column, type, t -> t instanceof BigintType || t instanceof IntegerType || t instanceof SmallintType || t instanceof DateType); this.type = type; - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is 
null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -116,7 +116,7 @@ public Block readBlock() } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } dataStream.skip(readOffset); } @@ -125,7 +125,7 @@ public Block readBlock() Block block; if (dataStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); block = RunLengthEncodedBlock.create(type, null, nextBatchSize); @@ -251,7 +251,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { presentStreamSource = missingStreamSource(BooleanInputStream.class); dataStreamSource = missingStreamSource(LongInputStream.class); @@ -268,8 +268,8 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - dataStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, LongInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + dataStreamSource = dataStreamSources.getInputStreamSource(column, DATA, LongInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -284,7 +284,7 @@ public void 
startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/MapStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/MapColumnReader.java similarity index 75% rename from presto-orc/src/main/java/io/prestosql/orc/reader/MapStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/MapColumnReader.java index b578db71f4f5..5b4d679f377c 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/MapStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/MapColumnReader.java @@ -15,9 +15,11 @@ import com.google.common.io.Closer; import io.prestosql.memory.context.AggregatedMemoryContext; +import io.prestosql.orc.OrcBlockFactory.NestedBlockFactory; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.InputStreamSource; import io.prestosql.orc.stream.InputStreamSources; @@ -34,30 +36,30 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; import static io.prestosql.orc.metadata.Stream.StreamKind.LENGTH; import static io.prestosql.orc.metadata.Stream.StreamKind.PRESENT; +import static io.prestosql.orc.reader.ColumnReaders.createColumnReader; import static io.prestosql.orc.reader.ReaderUtils.convertLengthVectorToOffsetVector; import static io.prestosql.orc.reader.ReaderUtils.unpackLengthNulls; import static io.prestosql.orc.reader.ReaderUtils.verifyStreamType; -import static io.prestosql.orc.reader.StreamReaders.createStreamReader; import static 
io.prestosql.orc.stream.MissingInputStreamSource.missingStreamSource; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; -public class MapStreamReader - implements StreamReader +public class MapColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(MapStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(MapColumnReader.class).instanceSize(); private final MapType type; - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; + private final NestedBlockFactory blockFactory; - private final StreamReader keyStreamReader; - private final StreamReader valueStreamReader; + private final ColumnReader keyColumnReader; + private final ColumnReader valueColumnReader; private int readOffset; private int nextBatchSize; @@ -74,16 +76,17 @@ public class MapStreamReader private boolean rowGroupOpen; - public MapStreamReader(Type type, StreamDescriptor streamDescriptor, AggregatedMemoryContext systemMemoryContext) + public MapColumnReader(Type type, OrcColumn column, AggregatedMemoryContext systemMemoryContext, NestedBlockFactory blockFactory) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, MapType.class::isInstance); + verifyStreamType(column, type, MapType.class::isInstance); this.type = (MapType) type; - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); - this.keyStreamReader = createStreamReader(this.type.getKeyType(), streamDescriptor.getNestedStreams().get(0), systemMemoryContext); - this.valueStreamReader = createStreamReader(this.type.getValueType(), streamDescriptor.getNestedStreams().get(1), systemMemoryContext); + this.column = requireNonNull(column, "column is null"); + this.blockFactory = requireNonNull(blockFactory, "blockFactory is null"); + this.keyColumnReader = createColumnReader(this.type.getKeyType(), 
column.getNestedColumns().get(0), systemMemoryContext, blockFactory); + this.valueColumnReader = createColumnReader(this.type.getValueType(), column.getNestedColumns().get(1), systemMemoryContext, blockFactory); } @Override @@ -109,11 +112,11 @@ public Block readBlock() } if (readOffset > 0) { if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is not present"); } long entrySkipSize = lengthStream.sum(readOffset); - keyStreamReader.prepareNextRead(toIntExact(entrySkipSize)); - valueStreamReader.prepareNextRead(toIntExact(entrySkipSize)); + keyColumnReader.prepareNextRead(toIntExact(entrySkipSize)); + valueColumnReader.prepareNextRead(toIntExact(entrySkipSize)); } } @@ -124,7 +127,7 @@ public Block readBlock() if (presentStream == null) { if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.next(offsetVector, nextBatchSize); } @@ -133,7 +136,7 @@ public Block readBlock() int nullValues = presentStream.getUnsetBits(nextBatchSize, nullVector); if (nullValues != nextBatchSize) { if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is not present"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is not present"); } lengthStream.next(offsetVector, nextBatchSize - nullValues); unpackLengthNulls(offsetVector, nullVector, nextBatchSize - nullValues); @@ -149,10 +152,10 @@ public Block readBlock() Block keys; Block values; if (entryCount > 0) { - keyStreamReader.prepareNextRead(entryCount); - 
valueStreamReader.prepareNextRead(entryCount); - keys = keyStreamReader.readBlock(); - values = valueStreamReader.readBlock(); + keyColumnReader.prepareNextRead(entryCount); + valueColumnReader.prepareNextRead(entryCount); + keys = keyColumnReader.readBlock(); + values = blockFactory.createBlock(entryCount, valueColumnReader::readBlock); } else { keys = type.getKeyType().createBlockBuilder(null, 0).build(); @@ -221,7 +224,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) throws IOException { presentStreamSource = missingStreamSource(BooleanInputStream.class); @@ -235,16 +238,16 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour rowGroupOpen = false; - keyStreamReader.startStripe(timeZone, dictionaryStreamSources, encoding); - valueStreamReader.startStripe(timeZone, dictionaryStreamSources, encoding); + keyColumnReader.startStripe(timeZone, dictionaryStreamSources, encoding); + valueColumnReader.startStripe(timeZone, dictionaryStreamSources, encoding); } @Override public void startRowGroup(InputStreamSources dataStreamSources) throws IOException { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - lengthStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, LENGTH, LongInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + lengthStreamSource = dataStreamSources.getInputStreamSource(column, LENGTH, LongInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -254,15 +257,15 @@ public void startRowGroup(InputStreamSources dataStreamSources) rowGroupOpen = false; - keyStreamReader.startRowGroup(dataStreamSources); - valueStreamReader.startRowGroup(dataStreamSources); + 
keyColumnReader.startRowGroup(dataStreamSources); + valueColumnReader.startRowGroup(dataStreamSources); } @Override public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } @@ -270,8 +273,8 @@ public String toString() public void close() { try (Closer closer = Closer.create()) { - closer.register(keyStreamReader::close); - closer.register(valueStreamReader::close); + closer.register(keyColumnReader::close); + closer.register(valueColumnReader::close); } catch (IOException e) { throw new UncheckedIOException(e); @@ -281,6 +284,6 @@ public void close() @Override public long getRetainedSizeInBytes() { - return INSTANCE_SIZE + keyStreamReader.getRetainedSizeInBytes() + valueStreamReader.getRetainedSizeInBytes(); + return INSTANCE_SIZE + keyColumnReader.getRetainedSizeInBytes() + valueColumnReader.getRetainedSizeInBytes(); } } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/ReaderUtils.java b/presto-orc/src/main/java/io/prestosql/orc/reader/ReaderUtils.java index beacb3fedf38..9c2a15214d13 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/ReaderUtils.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/ReaderUtils.java @@ -13,8 +13,8 @@ */ package io.prestosql.orc.reader; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.spi.type.Type; import java.util.function.Predicate; @@ -25,7 +25,7 @@ final class ReaderUtils { private ReaderUtils() {} - public static void verifyStreamType(StreamDescriptor streamDescriptor, Type actual, Predicate validTypes) + public static void verifyStreamType(OrcColumn column, Type actual, Predicate validTypes) throws OrcCorruptionException { if (validTypes.test(actual)) { @@ -33,11 +33,11 @@ public static void verifyStreamType(StreamDescriptor streamDescriptor, Type actu } throw new OrcCorruptionException( - streamDescriptor.getOrcDataSourceId(), + 
column.getOrcDataSourceId(), "Can not read SQL type %s from ORC stream %s of type %s", actual, - streamDescriptor.getStreamName(), - streamDescriptor.getStreamType()); + column.getPath(), + column.getColumnType()); } public static int minNonNullValueSize(int nonNullCount) diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/SliceStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/SliceColumnReader.java similarity index 80% rename from presto-orc/src/main/java/io/prestosql/orc/reader/SliceStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/SliceColumnReader.java index 89fe236e163b..007f430660b1 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/SliceStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/SliceColumnReader.java @@ -16,10 +16,11 @@ import com.google.common.io.Closer; import io.airlift.slice.Slice; import io.prestosql.memory.context.AggregatedMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.ColumnEncoding.ColumnEncodingKind; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.InputStreamSources; import io.prestosql.spi.block.Block; import io.prestosql.spi.type.CharType; @@ -31,7 +32,6 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.time.ZoneId; -import java.util.List; import static com.google.common.base.MoreObjects.toStringHelper; import static io.prestosql.orc.metadata.ColumnEncoding.ColumnEncodingKind.DICTIONARY; @@ -46,28 +46,28 @@ import static io.prestosql.spi.type.Varchars.isVarcharType; import static java.util.Objects.requireNonNull; -public class SliceStreamReader - implements StreamReader +public class SliceColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = 
ClassLayout.parseClass(SliceStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceColumnReader.class).instanceSize(); - private final StreamDescriptor streamDescriptor; - private final SliceDirectStreamReader directReader; - private final SliceDictionaryStreamReader dictionaryReader; - private StreamReader currentReader; + private final OrcColumn column; + private final SliceDirectColumnReader directReader; + private final SliceDictionaryColumnReader dictionaryReader; + private ColumnReader currentReader; - public SliceStreamReader(Type type, StreamDescriptor streamDescriptor, AggregatedMemoryContext systemMemoryContext) + public SliceColumnReader(Type type, OrcColumn column, AggregatedMemoryContext systemMemoryContext) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, t -> t instanceof VarcharType || t instanceof CharType || t instanceof VarbinaryType); + verifyStreamType(column, type, t -> t instanceof VarcharType || t instanceof CharType || t instanceof VarbinaryType); - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); int maxCodePointCount = getMaxCodePointCount(type); boolean charType = isCharType(type); - directReader = new SliceDirectStreamReader(streamDescriptor, maxCodePointCount, charType); - dictionaryReader = new SliceDictionaryStreamReader(streamDescriptor, systemMemoryContext.newLocalMemoryContext(SliceStreamReader.class.getSimpleName()), maxCodePointCount, charType); + directReader = new SliceDirectColumnReader(column, maxCodePointCount, charType); + dictionaryReader = new SliceDictionaryColumnReader(column, systemMemoryContext.newLocalMemoryContext(SliceColumnReader.class.getSimpleName()), maxCodePointCount, charType); } @Override @@ -84,10 +84,10 @@ public void prepareNextRead(int batchSize) } @Override - public void startStripe(ZoneId timeZone, 
InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) throws IOException { - ColumnEncodingKind columnEncodingKind = encoding.get(streamDescriptor.getStreamId()).getColumnEncodingKind(); + ColumnEncodingKind columnEncodingKind = encoding.get(column.getColumnId()).getColumnEncodingKind(); if (columnEncodingKind == DIRECT || columnEncodingKind == DIRECT_V2) { currentReader = directReader; } @@ -112,7 +112,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDictionaryStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDictionaryColumnReader.java similarity index 89% rename from presto-orc/src/main/java/io/prestosql/orc/reader/SliceDictionaryStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/SliceDictionaryColumnReader.java index a9a061cacbaf..b3bfce369a1c 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDictionaryStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDictionaryColumnReader.java @@ -15,9 +15,10 @@ import io.airlift.slice.Slice; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.ByteArrayInputStream; import io.prestosql.orc.stream.InputStreamSource; @@ -33,7 +34,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -47,22 
+47,22 @@ import static io.prestosql.orc.metadata.Stream.StreamKind.LENGTH; import static io.prestosql.orc.metadata.Stream.StreamKind.PRESENT; import static io.prestosql.orc.reader.ReaderUtils.minNonNullValueSize; -import static io.prestosql.orc.reader.SliceStreamReader.computeTruncatedLength; +import static io.prestosql.orc.reader.SliceColumnReader.computeTruncatedLength; import static io.prestosql.orc.stream.MissingInputStreamSource.missingStreamSource; import static java.lang.Math.toIntExact; import static java.util.Arrays.fill; import static java.util.Objects.requireNonNull; -public class SliceDictionaryStreamReader - implements StreamReader +public class SliceDictionaryColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceDictionaryStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceDictionaryColumnReader.class).instanceSize(); private static final byte[] EMPTY_DICTIONARY_DATA = new byte[0]; // add one extra entry for null after strip/rowGroup dictionary private static final int[] EMPTY_DICTIONARY_OFFSETS = new int[2]; - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private final int maxCodePointCount; private final boolean isCharType; @@ -97,11 +97,11 @@ public class SliceDictionaryStreamReader private final LocalMemoryContext systemMemoryContext; - public SliceDictionaryStreamReader(StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext, int maxCodePointCount, boolean isCharType) + public SliceDictionaryColumnReader(OrcColumn column, LocalMemoryContext systemMemoryContext, int maxCodePointCount, boolean isCharType) { this.maxCodePointCount = maxCodePointCount; this.isCharType = isCharType; - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, 
"systemMemoryContext is null"); } @@ -128,7 +128,7 @@ public Block readBlock() } if (readOffset > 0) { if (dataStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } dataStream.skip(readOffset); } @@ -137,7 +137,7 @@ public Block readBlock() Block block; if (dataStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); block = readAllNullsBlock(); @@ -237,7 +237,7 @@ private void openRowGroup() // read the lengths LongInputStream lengthStream = dictionaryLengthStreamSource.openStream(); if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Dictionary is not empty but dictionary length stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Dictionary is not empty but dictionary length stream is missing"); } lengthStream.next(dictionaryLength, dictionarySize); @@ -313,11 +313,11 @@ private static void readDictionary( } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { - dictionaryDataStreamSource = dictionaryStreamSources.getInputStreamSource(streamDescriptor, DICTIONARY_DATA, ByteArrayInputStream.class); - dictionaryLengthStreamSource = dictionaryStreamSources.getInputStreamSource(streamDescriptor, LENGTH, LongInputStream.class); - dictionarySize = encoding.get(streamDescriptor.getStreamId()).getDictionarySize(); + dictionaryDataStreamSource = 
dictionaryStreamSources.getInputStreamSource(column, DICTIONARY_DATA, ByteArrayInputStream.class); + dictionaryLengthStreamSource = dictionaryStreamSources.getInputStreamSource(column, LENGTH, LongInputStream.class); + dictionarySize = encoding.get(column.getColumnId()).getDictionarySize(); dictionaryOpen = false; presentStreamSource = missingStreamSource(BooleanInputStream.class); @@ -335,8 +335,8 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - dataStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, LongInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + dataStreamSource = dataStreamSources.getInputStreamSource(column, DATA, LongInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -351,7 +351,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDirectStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDirectColumnReader.java similarity index 86% rename from presto-orc/src/main/java/io/prestosql/orc/reader/SliceDirectStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/SliceDirectColumnReader.java index ad503d1860e4..27f17a973907 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDirectStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/SliceDirectColumnReader.java @@ -16,9 +16,10 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.units.DataSize; +import io.prestosql.orc.OrcColumn; import 
io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.ByteArrayInputStream; import io.prestosql.orc.stream.InputStreamSource; @@ -34,7 +35,6 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -47,22 +47,22 @@ import static io.prestosql.orc.metadata.Stream.StreamKind.PRESENT; import static io.prestosql.orc.reader.ReaderUtils.convertLengthVectorToOffsetVector; import static io.prestosql.orc.reader.ReaderUtils.unpackLengthNulls; -import static io.prestosql.orc.reader.SliceStreamReader.computeTruncatedLength; +import static io.prestosql.orc.reader.SliceColumnReader.computeTruncatedLength; import static io.prestosql.orc.stream.MissingInputStreamSource.missingStreamSource; import static io.prestosql.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -public class SliceDirectStreamReader - implements StreamReader +public class SliceDirectColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceDirectStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceDirectColumnReader.class).instanceSize(); private static final int ONE_GIGABYTE = toIntExact(new DataSize(1, GIGABYTE).toBytes()); private final int maxCodePointCount; private final boolean isCharType; - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private int readOffset; private int nextBatchSize; @@ -81,12 +81,12 @@ public class SliceDirectStreamReader private boolean rowGroupOpen; - public 
SliceDirectStreamReader(StreamDescriptor streamDescriptor, int maxCodePointCount, boolean isCharType) + public SliceDirectColumnReader(OrcColumn column, int maxCodePointCount, boolean isCharType) { this.maxCodePointCount = maxCodePointCount; this.isCharType = isCharType; - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); } @Override @@ -112,12 +112,12 @@ public Block readBlock() } if (readOffset > 0) { if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but length stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but length stream is missing"); } long dataSkipSize = lengthStream.sum(readOffset); if (dataSkipSize > 0) { if (dataStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } dataStream.skip(dataSkipSize); } @@ -126,7 +126,7 @@ public Block readBlock() if (lengthStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); Block nullValueBlock = readAllNullsBlock(); @@ -157,7 +157,7 @@ public Block readBlock() } if (lengthStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but length stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but length stream is missing"); } if (nullCount == 0) { isNullVector = null; @@ -183,10 +183,10 @@ public Block readBlock() } if (totalLength > ONE_GIGABYTE) { throw 
new PrestoException(GENERIC_INTERNAL_ERROR, - format("Values in column \"%s\" are too large to process for Presto. %s column values are larger than 1GB [%s]", streamDescriptor.getFieldName(), nextBatchSize, streamDescriptor.getOrcDataSourceId())); + format("Values in column \"%s\" are too large to process for Presto. %s column values are larger than 1GB [%s]", column.getPath(), nextBatchSize, column.getOrcDataSourceId())); } if (dataStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but data stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but data stream is missing"); } // allocate enough space to read @@ -246,7 +246,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { presentStreamSource = missingStreamSource(BooleanInputStream.class); lengthStreamSource = missingStreamSource(LongInputStream.class); @@ -265,9 +265,9 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - lengthStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, LENGTH, LongInputStream.class); - dataByteSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, ByteArrayInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + lengthStreamSource = dataStreamSources.getInputStreamSource(column, LENGTH, LongInputStream.class); + dataByteSource = dataStreamSources.getInputStreamSource(column, DATA, ByteArrayInputStream.class); readOffset = 0; nextBatchSize = 0; 
@@ -283,7 +283,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/StreamReaders.java b/presto-orc/src/main/java/io/prestosql/orc/reader/StreamReaders.java deleted file mode 100644 index bba7971fcadb..000000000000 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/StreamReaders.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.prestosql.orc.reader; - -import io.prestosql.memory.context.AggregatedMemoryContext; -import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; -import io.prestosql.spi.type.Type; - -public final class StreamReaders -{ - private StreamReaders() {} - - public static StreamReader createStreamReader(Type type, StreamDescriptor streamDescriptor, AggregatedMemoryContext systemMemoryContext) - throws OrcCorruptionException - { - switch (streamDescriptor.getStreamType()) { - case BOOLEAN: - return new BooleanStreamReader(type, streamDescriptor, systemMemoryContext.newLocalMemoryContext(StreamReaders.class.getSimpleName())); - case BYTE: - return new ByteStreamReader(type, streamDescriptor, systemMemoryContext.newLocalMemoryContext(StreamReaders.class.getSimpleName())); - case SHORT: - case INT: - case LONG: - case DATE: - return new LongStreamReader(type, streamDescriptor, systemMemoryContext.newLocalMemoryContext(StreamReaders.class.getSimpleName())); - case FLOAT: - return new FloatStreamReader(type, streamDescriptor, systemMemoryContext.newLocalMemoryContext(StreamReaders.class.getSimpleName())); - case DOUBLE: - return new DoubleStreamReader(type, streamDescriptor, systemMemoryContext.newLocalMemoryContext(StreamReaders.class.getSimpleName())); - case BINARY: - case STRING: - case VARCHAR: - case CHAR: - return new SliceStreamReader(type, streamDescriptor, systemMemoryContext); - case TIMESTAMP: - return new TimestampStreamReader(type, streamDescriptor, systemMemoryContext.newLocalMemoryContext(StreamReaders.class.getSimpleName())); - case LIST: - return new ListStreamReader(type, streamDescriptor, systemMemoryContext); - case STRUCT: - return new StructStreamReader(type, streamDescriptor, systemMemoryContext); - case MAP: - return new MapStreamReader(type, streamDescriptor, systemMemoryContext); - case DECIMAL: - return new DecimalStreamReader(type, streamDescriptor, 
systemMemoryContext.newLocalMemoryContext(StreamReaders.class.getSimpleName())); - case UNION: - default: - throw new IllegalArgumentException("Unsupported type: " + streamDescriptor.getStreamType()); - } - } -} diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/StructStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/StructColumnReader.java similarity index 77% rename from presto-orc/src/main/java/io/prestosql/orc/reader/StructStreamReader.java rename to presto-orc/src/main/java/io/prestosql/orc/reader/StructColumnReader.java index 2affc273b2c1..2919bd8f0745 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/StructStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/StructColumnReader.java @@ -17,9 +17,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.io.Closer; import io.prestosql.memory.context.AggregatedMemoryContext; +import io.prestosql.orc.OrcBlockFactory.NestedBlockFactory; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.InputStreamSource; import io.prestosql.orc.stream.InputStreamSources; @@ -46,19 +48,20 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.prestosql.orc.metadata.Stream.StreamKind.PRESENT; +import static io.prestosql.orc.reader.ColumnReaders.createColumnReader; import static io.prestosql.orc.reader.ReaderUtils.verifyStreamType; -import static io.prestosql.orc.reader.StreamReaders.createStreamReader; import static io.prestosql.orc.stream.MissingInputStreamSource.missingStreamSource; import static java.util.Objects.requireNonNull; -public class StructStreamReader - implements StreamReader +public class 
StructColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(StructStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(StructColumnReader.class).instanceSize(); - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; + private final NestedBlockFactory blockFactory; - private final Map structFields; + private final Map structFields; private final RowType type; private final ImmutableList fieldNames; @@ -71,29 +74,30 @@ public class StructStreamReader private boolean rowGroupOpen; - StructStreamReader(Type type, StreamDescriptor streamDescriptor, AggregatedMemoryContext systemMemoryContext) + StructColumnReader(Type type, OrcColumn column, AggregatedMemoryContext systemMemoryContext, NestedBlockFactory blockFactory) throws OrcCorruptionException { requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, RowType.class::isInstance); + verifyStreamType(column, type, RowType.class::isInstance); this.type = (RowType) type; - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); + this.blockFactory = requireNonNull(blockFactory, "blockFactory is null"); - Map nestedStreams = streamDescriptor.getNestedStreams().stream() - .collect(toImmutableMap(stream -> stream.getFieldName().toLowerCase(Locale.ENGLISH), stream -> stream)); + Map nestedColumns = column.getNestedColumns().stream() + .collect(toImmutableMap(stream -> stream.getColumnName().toLowerCase(Locale.ENGLISH), stream -> stream)); ImmutableList.Builder fieldNames = ImmutableList.builder(); - ImmutableMap.Builder structFields = ImmutableMap.builder(); + ImmutableMap.Builder structFields = ImmutableMap.builder(); for (Field field : this.type.getFields()) { String fieldName = field.getName() .orElseThrow(() -> new IllegalArgumentException("ROW type does not have field names declared: " + 
type)) .toLowerCase(Locale.ENGLISH); fieldNames.add(fieldName); - StreamDescriptor fieldStream = nestedStreams.get(fieldName); + OrcColumn fieldStream = nestedColumns.get(fieldName); if (fieldStream != null) { - structFields.put(fieldName, createStreamReader(field.getType(), fieldStream, systemMemoryContext)); + structFields.put(fieldName, createColumnReader(field.getType(), fieldStream, systemMemoryContext, blockFactory)); } } this.fieldNames = fieldNames.build(); @@ -121,7 +125,7 @@ public Block readBlock() // and use this as the skip size for the field readers readOffset = presentStream.countBitsSet(readOffset); } - for (StreamReader structField : structFields.values()) { + for (ColumnReader structField : structFields.values()) { structField.prepareNextRead(readOffset); } } @@ -170,7 +174,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) throws IOException { presentStreamSource = missingStreamSource(BooleanInputStream.class); @@ -182,7 +186,7 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour rowGroupOpen = false; - for (StreamReader structField : structFields.values()) { + for (ColumnReader structField : structFields.values()) { structField.startStripe(timeZone, dictionaryStreamSources, encoding); } } @@ -191,7 +195,7 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour public void startRowGroup(InputStreamSources dataStreamSources) throws IOException { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -200,7 +204,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) 
rowGroupOpen = false; - for (StreamReader structField : structFields.values()) { + for (ColumnReader structField : structFields.values()) { structField.startRowGroup(dataStreamSources); } } @@ -209,22 +213,21 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } private Block[] getBlocksForType(int positionCount) - throws IOException { Block[] blocks = new Block[fieldNames.size()]; for (int i = 0; i < fieldNames.size(); i++) { String fieldName = fieldNames.get(i); - StreamReader streamReader = structFields.get(fieldName); - if (streamReader != null) { - streamReader.prepareNextRead(positionCount); - blocks[i] = streamReader.readBlock(); + ColumnReader columnReader = structFields.get(fieldName); + if (columnReader != null) { + columnReader.prepareNextRead(positionCount); + blocks[i] = blockFactory.createBlock(positionCount, columnReader::readBlock); } else { blocks[i] = RunLengthEncodedBlock.create(type.getFields().get(i).getType(), null, positionCount); @@ -237,7 +240,7 @@ private Block[] getBlocksForType(int positionCount) public void close() { try (Closer closer = Closer.create()) { - for (StreamReader structField : structFields.values()) { + for (ColumnReader structField : structFields.values()) { closer.register(structField::close); } } @@ -250,7 +253,7 @@ public void close() public long getRetainedSizeInBytes() { long retainedSizeInBytes = INSTANCE_SIZE; - for (StreamReader structField : structFields.values()) { + for (ColumnReader structField : structFields.values()) { retainedSizeInBytes += structField.getRetainedSizeInBytes(); } return retainedSizeInBytes; diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/TimestampStreamReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/TimestampColumnReader.java similarity index 83% rename from presto-orc/src/main/java/io/prestosql/orc/reader/TimestampStreamReader.java 
rename to presto-orc/src/main/java/io/prestosql/orc/reader/TimestampColumnReader.java index 8e727d2529f6..03b2d39266e7 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/reader/TimestampStreamReader.java +++ b/presto-orc/src/main/java/io/prestosql/orc/reader/TimestampColumnReader.java @@ -14,9 +14,10 @@ package io.prestosql.orc.reader; import io.prestosql.memory.context.LocalMemoryContext; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcCorruptionException; -import io.prestosql.orc.StreamDescriptor; import io.prestosql.orc.metadata.ColumnEncoding; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.stream.BooleanInputStream; import io.prestosql.orc.stream.InputStreamSource; import io.prestosql.orc.stream.InputStreamSources; @@ -33,7 +34,6 @@ import java.io.IOException; import java.time.ZoneId; import java.time.ZonedDateTime; -import java.util.List; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -46,14 +46,14 @@ import static io.prestosql.spi.type.TimestampType.TIMESTAMP; import static java.util.Objects.requireNonNull; -public class TimestampStreamReader - implements StreamReader +public class TimestampColumnReader + implements ColumnReader { - private static final int INSTANCE_SIZE = ClassLayout.parseClass(TimestampStreamReader.class).instanceSize(); + private static final int INSTANCE_SIZE = ClassLayout.parseClass(TimestampColumnReader.class).instanceSize(); private static final int MILLIS_PER_SECOND = 1000; - private final StreamDescriptor streamDescriptor; + private final OrcColumn column; private long baseTimestampInSeconds; @@ -80,13 +80,13 @@ public class TimestampStreamReader private final LocalMemoryContext systemMemoryContext; - public TimestampStreamReader(Type type, StreamDescriptor streamDescriptor, LocalMemoryContext systemMemoryContext) + public TimestampColumnReader(Type type, OrcColumn column, LocalMemoryContext systemMemoryContext) throws OrcCorruptionException 
{ requireNonNull(type, "type is null"); - verifyStreamType(streamDescriptor, type, TimestampType.class::isInstance); + verifyStreamType(column, type, TimestampType.class::isInstance); - this.streamDescriptor = requireNonNull(streamDescriptor, "stream is null"); + this.column = requireNonNull(column, "column is null"); this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -113,10 +113,10 @@ public Block readBlock() } if (readOffset > 0) { if (secondsStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but seconds stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but seconds stream is missing"); } if (nanosStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but nanos stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but nanos stream is missing"); } secondsStream.skip(readOffset); @@ -127,7 +127,7 @@ public Block readBlock() Block block; if (secondsStream == null && nanosStream == null) { if (presentStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is null but present stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing"); } presentStream.skip(nextBatchSize); block = RunLengthEncodedBlock.create(TIMESTAMP, null, nextBatchSize); @@ -158,10 +158,10 @@ private Block readNonNullBlock() throws IOException { if (secondsStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but seconds stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but seconds stream is missing"); } if (nanosStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is 
not null but nanos stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but nanos stream is missing"); } long[] values = new long[nextBatchSize]; @@ -175,10 +175,10 @@ private Block readNullBlock(boolean[] isNull) throws IOException { if (secondsStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but seconds stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but seconds stream is missing"); } if (nanosStream == null) { - throw new OrcCorruptionException(streamDescriptor.getOrcDataSourceId(), "Value is not null but nanos stream is missing"); + throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is not null but nanos stream is missing"); } long[] values = new long[isNull.length]; @@ -201,7 +201,7 @@ private void openRowGroup() } @Override - public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, List encoding) + public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSources, ColumnMetadata encoding) { baseTimestampInSeconds = ZonedDateTime.of(2015, 1, 1, 0, 0, 0, 0, timeZone).toEpochSecond(); @@ -222,9 +222,9 @@ public void startStripe(ZoneId timeZone, InputStreamSources dictionaryStreamSour @Override public void startRowGroup(InputStreamSources dataStreamSources) { - presentStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, PRESENT, BooleanInputStream.class); - secondsStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, DATA, LongInputStream.class); - nanosStreamSource = dataStreamSources.getInputStreamSource(streamDescriptor, SECONDARY, LongInputStream.class); + presentStreamSource = dataStreamSources.getInputStreamSource(column, PRESENT, BooleanInputStream.class); + secondsStreamSource = dataStreamSources.getInputStreamSource(column, DATA, LongInputStream.class); + nanosStreamSource = 
dataStreamSources.getInputStreamSource(column, SECONDARY, LongInputStream.class); readOffset = 0; nextBatchSize = 0; @@ -240,7 +240,7 @@ public void startRowGroup(InputStreamSources dataStreamSources) public String toString() { return toStringHelper(this) - .addValue(streamDescriptor) + .addValue(column) .toString(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/BooleanOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/BooleanOutputStream.java index 5bdd75a98a03..b1b698bec1cd 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/BooleanOutputStream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/BooleanOutputStream.java @@ -18,6 +18,7 @@ import io.prestosql.orc.checkpoint.BooleanStreamCheckpoint; import io.prestosql.orc.checkpoint.ByteStreamCheckpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import org.openjdk.jol.info.ClassLayout; import java.util.ArrayList; @@ -155,10 +156,10 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { checkState(closed); - return byteOutputStream.getStreamDataOutput(column); + return byteOutputStream.getStreamDataOutput(columnId); } @Override diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/ByteArrayOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/ByteArrayOutputStream.java index d9288f937259..a2f39337a2f0 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/ByteArrayOutputStream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/ByteArrayOutputStream.java @@ -18,6 +18,7 @@ import io.prestosql.orc.OrcOutputBuffer; import io.prestosql.orc.checkpoint.ByteArrayStreamCheckpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import 
io.prestosql.orc.metadata.Stream.StreamKind; import org.openjdk.jol.info.ClassLayout; @@ -81,9 +82,9 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { - return new StreamDataOutput(buffer::writeDataTo, new Stream(column, streamKind, toIntExact(buffer.getOutputDataSize()), false)); + return new StreamDataOutput(buffer::writeDataTo, new Stream(columnId, streamKind, toIntExact(buffer.getOutputDataSize()), false)); } @Override diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/ByteOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/ByteOutputStream.java index ece62da19793..46d48d1e7ff6 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/ByteOutputStream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/ByteOutputStream.java @@ -18,6 +18,7 @@ import io.prestosql.orc.OrcOutputBuffer; import io.prestosql.orc.checkpoint.ByteStreamCheckpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import org.openjdk.jol.info.ClassLayout; @@ -147,9 +148,9 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { - return new StreamDataOutput(buffer::writeDataTo, new Stream(column, DATA, toIntExact(buffer.getOutputDataSize()), false)); + return new StreamDataOutput(buffer::writeDataTo, new Stream(columnId, DATA, toIntExact(buffer.getOutputDataSize()), false)); } @Override diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/DecimalOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/DecimalOutputStream.java index 66f500d92e98..fbc28203edb3 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/DecimalOutputStream.java +++ 
b/presto-orc/src/main/java/io/prestosql/orc/stream/DecimalOutputStream.java @@ -18,6 +18,7 @@ import io.prestosql.orc.OrcOutputBuffer; import io.prestosql.orc.checkpoint.DecimalStreamCheckpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import io.prestosql.spi.type.Decimals; import org.openjdk.jol.info.ClassLayout; @@ -109,9 +110,9 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { - return new StreamDataOutput(buffer::writeDataTo, new Stream(column, DATA, toIntExact(buffer.getOutputDataSize()), true)); + return new StreamDataOutput(buffer::writeDataTo, new Stream(columnId, DATA, toIntExact(buffer.getOutputDataSize()), true)); } @Override diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/DoubleOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/DoubleOutputStream.java index 955e068ec37e..e1fd3015d440 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/DoubleOutputStream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/DoubleOutputStream.java @@ -17,6 +17,7 @@ import io.prestosql.orc.OrcOutputBuffer; import io.prestosql.orc.checkpoint.DoubleStreamCheckpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import org.openjdk.jol.info.ClassLayout; @@ -69,9 +70,9 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { - return new StreamDataOutput(buffer::writeDataTo, new Stream(column, DATA, toIntExact(buffer.getOutputDataSize()), false)); + return new StreamDataOutput(buffer::writeDataTo, new Stream(columnId, DATA, toIntExact(buffer.getOutputDataSize()), false)); } @Override diff --git 
a/presto-orc/src/main/java/io/prestosql/orc/stream/FloatOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/FloatOutputStream.java index cb19ffcba0f4..f0585313defd 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/FloatOutputStream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/FloatOutputStream.java @@ -17,6 +17,7 @@ import io.prestosql.orc.OrcOutputBuffer; import io.prestosql.orc.checkpoint.FloatStreamCheckpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import org.openjdk.jol.info.ClassLayout; @@ -69,9 +70,9 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { - return new StreamDataOutput(buffer::writeDataTo, new Stream(column, DATA, toIntExact(buffer.getOutputDataSize()), false)); + return new StreamDataOutput(buffer::writeDataTo, new Stream(columnId, DATA, toIntExact(buffer.getOutputDataSize()), false)); } @Override diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/InputStreamSources.java b/presto-orc/src/main/java/io/prestosql/orc/stream/InputStreamSources.java index eb9f2f629639..a79f91899f7f 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/InputStreamSources.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/InputStreamSources.java @@ -14,7 +14,7 @@ package io.prestosql.orc.stream; import com.google.common.collect.ImmutableMap; -import io.prestosql.orc.StreamDescriptor; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.StreamId; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -33,19 +33,19 @@ public InputStreamSources(Map> streamSources) this.streamSources = ImmutableMap.copyOf(requireNonNull(streamSources, "streamSources is null")); } - public > InputStreamSource getInputStreamSource(StreamDescriptor streamDescriptor, StreamKind streamKind, Class 
streamType) + public > InputStreamSource getInputStreamSource(OrcColumn column, StreamKind streamKind, Class streamType) { - requireNonNull(streamDescriptor, "streamDescriptor is null"); + requireNonNull(column, "column is null"); requireNonNull(streamType, "streamType is null"); - InputStreamSource streamSource = streamSources.get(new StreamId(streamDescriptor.getStreamId(), streamKind)); + InputStreamSource streamSource = streamSources.get(new StreamId(column.getColumnId(), streamKind)); if (streamSource == null) { streamSource = missingStreamSource(streamType); } checkArgument(streamType.isAssignableFrom(streamSource.getStreamType()), "%s must be of type %s, not %s", - streamDescriptor, + column, streamType.getName(), streamSource.getStreamType().getName()); diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV1.java b/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV1.java index bed5e213abb0..d9ea2e583a0c 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV1.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV1.java @@ -19,6 +19,7 @@ import io.prestosql.orc.checkpoint.LongStreamCheckpoint; import io.prestosql.orc.checkpoint.LongStreamV1Checkpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; import org.openjdk.jol.info.ClassLayout; @@ -188,9 +189,9 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { - return new StreamDataOutput(buffer::writeDataTo, new Stream(column, streamKind, toIntExact(buffer.getOutputDataSize()), true)); + return new StreamDataOutput(buffer::writeDataTo, new Stream(columnId, streamKind, toIntExact(buffer.getOutputDataSize()), true)); } @Override diff --git 
a/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV2.java b/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV2.java index d1b3f8b1193f..25d89adb18f2 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV2.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/LongOutputStreamV2.java @@ -20,6 +20,7 @@ import io.prestosql.orc.checkpoint.LongStreamCheckpoint; import io.prestosql.orc.checkpoint.LongStreamV2Checkpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; import org.openjdk.jol.info.ClassLayout; @@ -747,9 +748,9 @@ public List getCheckpoints() } @Override - public StreamDataOutput getStreamDataOutput(int column) + public StreamDataOutput getStreamDataOutput(OrcColumnId columnId) { - return new StreamDataOutput(buffer::writeDataTo, new Stream(column, streamKind, toIntExact(buffer.getOutputDataSize()), true)); + return new StreamDataOutput(buffer::writeDataTo, new Stream(columnId, streamKind, toIntExact(buffer.getOutputDataSize()), true)); } @Override diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/PresentOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/PresentOutputStream.java index f208369d056e..46ea61bfd4af 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/PresentOutputStream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/PresentOutputStream.java @@ -16,6 +16,7 @@ import io.prestosql.orc.OrcOutputBuffer; import io.prestosql.orc.checkpoint.BooleanStreamCheckpoint; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import org.openjdk.jol.info.ClassLayout; @@ -101,15 +102,15 @@ public Optional> getCheckpoints() return Optional.of(booleanOutputStream.getCheckpoints()); } - public Optional getStreamDataOutput(int 
column) + public Optional getStreamDataOutput(OrcColumnId columnId) { checkArgument(closed); if (booleanOutputStream == null) { return Optional.empty(); } - StreamDataOutput streamDataOutput = booleanOutputStream.getStreamDataOutput(column); + StreamDataOutput streamDataOutput = booleanOutputStream.getStreamDataOutput(columnId); // rewrite the DATA stream created by the boolean output stream to a PRESENT stream - Stream stream = new Stream(column, PRESENT, toIntExact(streamDataOutput.size()), streamDataOutput.getStream().isUseVInts()); + Stream stream = new Stream(columnId, PRESENT, toIntExact(streamDataOutput.size()), streamDataOutput.getStream().isUseVInts()); return Optional.of(new StreamDataOutput( sliceOutput -> { streamDataOutput.writeData(sliceOutput); diff --git a/presto-orc/src/main/java/io/prestosql/orc/stream/ValueOutputStream.java b/presto-orc/src/main/java/io/prestosql/orc/stream/ValueOutputStream.java index 8c4a42dec631..51442cb58191 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/stream/ValueOutputStream.java +++ b/presto-orc/src/main/java/io/prestosql/orc/stream/ValueOutputStream.java @@ -14,6 +14,7 @@ package io.prestosql.orc.stream; import io.prestosql.orc.checkpoint.StreamCheckpoint; +import io.prestosql.orc.metadata.OrcColumnId; import java.util.List; @@ -25,7 +26,7 @@ public interface ValueOutputStream List getCheckpoints(); - StreamDataOutput getStreamDataOutput(int column); + StreamDataOutput getStreamDataOutput(OrcColumnId columnId); /** * This method returns the size of the flushed data plus any unflushed data. 
diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/BooleanColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/BooleanColumnWriter.java index 443a9da0de57..49f22d380551 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/BooleanColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/BooleanColumnWriter.java @@ -20,6 +20,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -50,7 +51,7 @@ public class BooleanColumnWriter private static final int INSTANCE_SIZE = ClassLayout.parseClass(BooleanColumnWriter.class).instanceSize(); private static final ColumnEncoding COLUMN_ENCODING = new ColumnEncoding(DIRECT, 0); - private final int column; + private final OrcColumnId columnId; private final Type type; private final boolean compressed; private final BooleanOutputStream dataStream; @@ -62,10 +63,9 @@ public class BooleanColumnWriter private boolean closed; - public BooleanColumnWriter(int column, Type type, CompressionKind compression, int bufferSize) + public BooleanColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.dataStream = new BooleanOutputStream(compression, bufferSize); @@ -73,9 +73,9 @@ public BooleanColumnWriter(int column, Type type, CompressionKind compression, i } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, COLUMN_ENCODING); + return 
ImmutableMap.of(columnId, COLUMN_ENCODING); } @Override @@ -107,13 +107,13 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics = statisticsBuilder.buildColumnStatistics(); rowGroupColumnStatistics.add(statistics); statisticsBuilder = new BooleanStatisticsBuilder(); - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -125,10 +125,10 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ -151,7 +151,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -172,8 +172,8 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/ByteColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/ByteColumnWriter.java index bf1e0aa2fc52..b5127cd5e809 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/ByteColumnWriter.java +++ 
b/presto-orc/src/main/java/io/prestosql/orc/writer/ByteColumnWriter.java @@ -22,6 +22,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -51,7 +52,7 @@ public class ByteColumnWriter private static final int INSTANCE_SIZE = ClassLayout.parseClass(ByteColumnWriter.class).instanceSize(); private static final ColumnEncoding COLUMN_ENCODING = new ColumnEncoding(DIRECT, 0); - private final int column; + private final OrcColumnId columnId; private final Type type; private final boolean compressed; private final ByteOutputStream dataStream; @@ -63,10 +64,9 @@ public class ByteColumnWriter private boolean closed; - public ByteColumnWriter(int column, Type type, CompressionKind compression, int bufferSize) + public ByteColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.dataStream = new ByteOutputStream(compression, bufferSize); @@ -74,9 +74,9 @@ public ByteColumnWriter(int column, Type type, CompressionKind compression, int } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, COLUMN_ENCODING); + return ImmutableMap.of(columnId, COLUMN_ENCODING); } @Override @@ -107,13 +107,13 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics = new ColumnStatistics((long) nonNullValueCount, 0, null, null, null, 
null, null, null, null, null); rowGroupColumnStatistics.add(statistics); nonNullValueCount = 0; - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -125,10 +125,10 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ -151,7 +151,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -172,8 +172,8 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriter.java index 2b5483f53cb8..62a08704b3a1 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriter.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; +import io.prestosql.orc.metadata.OrcColumnId; import 
io.prestosql.orc.metadata.statistics.ColumnStatistics; import io.prestosql.orc.stream.StreamDataOutput; import io.prestosql.spi.block.Block; @@ -31,17 +32,17 @@ default List getNestedColumnWriters() return ImmutableList.of(); } - Map getColumnEncodings(); + Map getColumnEncodings(); void beginRowGroup(); void writeBlock(Block block); - Map finishRowGroup(); + Map finishRowGroup(); void close(); - Map getColumnStripeStatistics(); + Map getColumnStripeStatistics(); /** * Write index streams to the output and return the streams in the diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriters.java b/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriters.java index dfce456f40e5..857ec0dc8cbf 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriters.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/ColumnWriters.java @@ -15,7 +15,9 @@ import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.OrcType; import io.prestosql.orc.metadata.statistics.BinaryStatisticsBuilder; import io.prestosql.orc.metadata.statistics.DateStatisticsBuilder; @@ -23,8 +25,6 @@ import io.prestosql.spi.type.Type; import org.joda.time.DateTimeZone; -import java.util.List; - import static java.util.Objects.requireNonNull; public final class ColumnWriters @@ -32,8 +32,8 @@ public final class ColumnWriters private ColumnWriters() {} public static ColumnWriter createColumnWriter( - int columnIndex, - List orcTypes, + OrcColumnId columnId, + ColumnMetadata orcTypes, Type type, CompressionKind compression, int bufferSize, @@ -41,47 +41,47 @@ public static ColumnWriter createColumnWriter( DataSize stringStatisticsLimit) { requireNonNull(type, "type is null"); - OrcType orcType = orcTypes.get(columnIndex); + OrcType orcType = orcTypes.get(columnId); 
switch (orcType.getOrcTypeKind()) { case BOOLEAN: - return new BooleanColumnWriter(columnIndex, type, compression, bufferSize); + return new BooleanColumnWriter(columnId, type, compression, bufferSize); case FLOAT: - return new FloatColumnWriter(columnIndex, type, compression, bufferSize); + return new FloatColumnWriter(columnId, type, compression, bufferSize); case DOUBLE: - return new DoubleColumnWriter(columnIndex, type, compression, bufferSize); + return new DoubleColumnWriter(columnId, type, compression, bufferSize); case BYTE: - return new ByteColumnWriter(columnIndex, type, compression, bufferSize); + return new ByteColumnWriter(columnId, type, compression, bufferSize); case DATE: - return new LongColumnWriter(columnIndex, type, compression, bufferSize, DateStatisticsBuilder::new); + return new LongColumnWriter(columnId, type, compression, bufferSize, DateStatisticsBuilder::new); case SHORT: case INT: case LONG: - return new LongColumnWriter(columnIndex, type, compression, bufferSize, IntegerStatisticsBuilder::new); + return new LongColumnWriter(columnId, type, compression, bufferSize, IntegerStatisticsBuilder::new); case DECIMAL: - return new DecimalColumnWriter(columnIndex, type, compression, bufferSize); + return new DecimalColumnWriter(columnId, type, compression, bufferSize); case TIMESTAMP: - return new TimestampColumnWriter(columnIndex, type, compression, bufferSize, hiveStorageTimeZone); + return new TimestampColumnWriter(columnId, type, compression, bufferSize, hiveStorageTimeZone); case BINARY: - return new SliceDirectColumnWriter(columnIndex, type, compression, bufferSize, BinaryStatisticsBuilder::new); + return new SliceDirectColumnWriter(columnId, type, compression, bufferSize, BinaryStatisticsBuilder::new); case CHAR: case VARCHAR: case STRING: - return new SliceDictionaryColumnWriter(columnIndex, type, compression, bufferSize, stringStatisticsLimit); + return new SliceDictionaryColumnWriter(columnId, type, compression, bufferSize, 
stringStatisticsLimit); case LIST: { - int fieldColumnIndex = orcType.getFieldTypeIndex(0); + OrcColumnId fieldColumnIndex = orcType.getFieldTypeIndex(0); Type fieldType = type.getTypeParameters().get(0); ColumnWriter elementWriter = createColumnWriter(fieldColumnIndex, orcTypes, fieldType, compression, bufferSize, hiveStorageTimeZone, stringStatisticsLimit); - return new ListColumnWriter(columnIndex, compression, bufferSize, elementWriter); + return new ListColumnWriter(columnId, compression, bufferSize, elementWriter); } case MAP: { @@ -101,17 +101,17 @@ public static ColumnWriter createColumnWriter( bufferSize, hiveStorageTimeZone, stringStatisticsLimit); - return new MapColumnWriter(columnIndex, compression, bufferSize, keyWriter, valueWriter); + return new MapColumnWriter(columnId, compression, bufferSize, keyWriter, valueWriter); } case STRUCT: { ImmutableList.Builder fieldWriters = ImmutableList.builder(); for (int fieldId = 0; fieldId < orcType.getFieldCount(); fieldId++) { - int fieldColumnIndex = orcType.getFieldTypeIndex(fieldId); + OrcColumnId fieldColumnIndex = orcType.getFieldTypeIndex(fieldId); Type fieldType = type.getTypeParameters().get(fieldId); fieldWriters.add(createColumnWriter(fieldColumnIndex, orcTypes, fieldType, compression, bufferSize, hiveStorageTimeZone, stringStatisticsLimit)); } - return new StructColumnWriter(columnIndex, compression, bufferSize, fieldWriters.build()); + return new StructColumnWriter(columnId, compression, bufferSize, fieldWriters.build()); } } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/DecimalColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/DecimalColumnWriter.java index b9a74740c0d6..bd3ac3792c59 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/DecimalColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/DecimalColumnWriter.java @@ -22,6 +22,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import 
io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -57,7 +58,7 @@ public class DecimalColumnWriter implements ColumnWriter { private static final int INSTANCE_SIZE = ClassLayout.parseClass(DecimalColumnWriter.class).instanceSize(); - private final int column; + private final OrcColumnId columnId; private final DecimalType type; private final ColumnEncoding columnEncoding; private final boolean compressed; @@ -72,10 +73,9 @@ public class DecimalColumnWriter private boolean closed; - public DecimalColumnWriter(int column, Type type, CompressionKind compression, int bufferSize) + public DecimalColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = (DecimalType) requireNonNull(type, "type is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.columnEncoding = new ColumnEncoding(DIRECT_V2, 0); @@ -91,9 +91,9 @@ public DecimalColumnWriter(int column, Type type, CompressionKind compression, i } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, columnEncoding); + return ImmutableMap.of(columnId, columnEncoding); } @Override @@ -143,7 +143,7 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics; @@ -157,7 +157,7 @@ public Map finishRowGroup() } rowGroupColumnStatistics.add(statistics); - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -170,10 +170,10 @@ public void close() } @Override - 
public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ -198,7 +198,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -221,9 +221,9 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); - outputDataStreams.add(scaleStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); + outputDataStreams.add(scaleStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/DoubleColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/DoubleColumnWriter.java index 6feba45469ab..2c71c9fd2a37 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/DoubleColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/DoubleColumnWriter.java @@ -21,6 +21,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import 
io.prestosql.orc.metadata.Stream.StreamKind; @@ -51,7 +52,7 @@ public class DoubleColumnWriter private static final int INSTANCE_SIZE = ClassLayout.parseClass(DoubleColumnWriter.class).instanceSize(); private static final ColumnEncoding COLUMN_ENCODING = new ColumnEncoding(DIRECT, 0); - private final int column; + private final OrcColumnId columnId; private final Type type; private final boolean compressed; private final DoubleOutputStream dataStream; @@ -63,10 +64,9 @@ public class DoubleColumnWriter private boolean closed; - public DoubleColumnWriter(int column, Type type, CompressionKind compression, int bufferSize) + public DoubleColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.dataStream = new DoubleOutputStream(compression, bufferSize); @@ -74,9 +74,9 @@ public DoubleColumnWriter(int column, Type type, CompressionKind compression, in } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, COLUMN_ENCODING); + return ImmutableMap.of(columnId, COLUMN_ENCODING); } @Override @@ -108,13 +108,13 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics = statisticsBuilder.buildColumnStatistics(); rowGroupColumnStatistics.add(statistics); statisticsBuilder = new DoubleStatisticsBuilder(); - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -126,10 +126,10 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, 
ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ -152,7 +152,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -173,8 +173,8 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/FloatColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/FloatColumnWriter.java index a99f7b367234..52912d3ae2cf 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/FloatColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/FloatColumnWriter.java @@ -21,6 +21,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -52,7 +53,7 @@ public class FloatColumnWriter private static final int INSTANCE_SIZE = ClassLayout.parseClass(FloatColumnWriter.class).instanceSize(); private static final ColumnEncoding COLUMN_ENCODING = new ColumnEncoding(DIRECT, 0); - 
private final int column; + private final OrcColumnId columnId; private final Type type; private final boolean compressed; private final FloatOutputStream dataStream; @@ -64,10 +65,9 @@ public class FloatColumnWriter private boolean closed; - public FloatColumnWriter(int column, Type type, CompressionKind compression, int bufferSize) + public FloatColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.dataStream = new FloatOutputStream(compression, bufferSize); @@ -75,9 +75,9 @@ public FloatColumnWriter(int column, Type type, CompressionKind compression, int } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, COLUMN_ENCODING); + return ImmutableMap.of(columnId, COLUMN_ENCODING); } @Override @@ -110,13 +110,13 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics = statisticsBuilder.buildColumnStatistics(); rowGroupColumnStatistics.add(statistics); statisticsBuilder = new DoubleStatisticsBuilder(); - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -128,10 +128,10 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ -154,7 +154,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = 
metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -175,8 +175,8 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/ListColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/ListColumnWriter.java index 78e30c72d712..eba74d63dd39 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/ListColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/ListColumnWriter.java @@ -21,6 +21,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -50,7 +51,7 @@ public class ListColumnWriter implements ColumnWriter { private static final int INSTANCE_SIZE = ClassLayout.parseClass(ListColumnWriter.class).instanceSize(); - private final int column; + private final OrcColumnId columnId; private final boolean compressed; private final ColumnEncoding columnEncoding; private final LongOutputStream lengthStream; @@ -63,10 +64,9 @@ public class ListColumnWriter private boolean closed; - public ListColumnWriter(int column, CompressionKind compression, int bufferSize, 
ColumnWriter elementWriter) + public ListColumnWriter(OrcColumnId columnId, CompressionKind compression, int bufferSize, ColumnWriter elementWriter) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.columnEncoding = new ColumnEncoding(DIRECT_V2, 0); this.elementWriter = requireNonNull(elementWriter, "elementWriter is null"); @@ -84,10 +84,10 @@ public List getNestedColumnWriters() } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - ImmutableMap.Builder encodings = ImmutableMap.builder(); - encodings.put(column, columnEncoding); + ImmutableMap.Builder encodings = ImmutableMap.builder(); + encodings.put(columnId, columnEncoding); encodings.putAll(elementWriter.getColumnEncodings()); return encodings.build(); } @@ -131,7 +131,7 @@ private void writeColumnarArray(ColumnarArray columnarArray) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); @@ -139,8 +139,8 @@ public Map finishRowGroup() rowGroupColumnStatistics.add(statistics); nonNullValueCount = 0; - ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); - columnStatistics.put(column, statistics); + ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, statistics); columnStatistics.putAll(elementWriter.finishRowGroup()); return columnStatistics.build(); } @@ -155,11 +155,11 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); - columnStatistics.put(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, 
ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); columnStatistics.putAll(elementWriter.getColumnStripeStatistics()); return columnStatistics.build(); } @@ -184,7 +184,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); ImmutableList.Builder indexStreams = ImmutableList.builder(); indexStreams.add(new StreamDataOutput(slice, stream)); @@ -209,8 +209,8 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(lengthStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(lengthStream.getStreamDataOutput(columnId)); outputDataStreams.addAll(elementWriter.getDataStreams()); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/LongColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/LongColumnWriter.java index c16bfe4a879c..4fd7fc4fd8ae 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/LongColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/LongColumnWriter.java @@ -21,6 +21,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -52,7 +53,7 @@ public class LongColumnWriter implements ColumnWriter { private static final int INSTANCE_SIZE = 
ClassLayout.parseClass(LongColumnWriter.class).instanceSize(); - private final int column; + private final OrcColumnId columnId; private final Type type; private final boolean compressed; private final ColumnEncoding columnEncoding; @@ -66,10 +67,9 @@ public class LongColumnWriter private boolean closed; - public LongColumnWriter(int column, Type type, CompressionKind compression, int bufferSize, Supplier statisticsBuilderSupplier) + public LongColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize, Supplier statisticsBuilderSupplier) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.columnEncoding = new ColumnEncoding(DIRECT_V2, 0); @@ -80,9 +80,9 @@ public LongColumnWriter(int column, Type type, CompressionKind compression, int } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, columnEncoding); + return ImmutableMap.of(columnId, columnEncoding); } @Override @@ -114,13 +114,13 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics = statisticsBuilder.buildColumnStatistics(); rowGroupColumnStatistics.add(statistics); statisticsBuilder = statisticsBuilderSupplier.get(); - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -132,10 +132,10 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ 
-158,7 +158,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -179,8 +179,8 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/MapColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/MapColumnWriter.java index 2df072d5295b..847c79d4c079 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/MapColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/MapColumnWriter.java @@ -21,6 +21,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -50,7 +51,7 @@ public class MapColumnWriter implements ColumnWriter { private static final int INSTANCE_SIZE = ClassLayout.parseClass(MapColumnWriter.class).instanceSize(); - private final int column; + private final OrcColumnId columnId; private final boolean compressed; private final ColumnEncoding columnEncoding; private final LongOutputStream lengthStream; @@ -64,10 +65,9 @@ public class MapColumnWriter private boolean closed; - 
public MapColumnWriter(int column, CompressionKind compression, int bufferSize, ColumnWriter keyWriter, ColumnWriter valueWriter) + public MapColumnWriter(OrcColumnId columnId, CompressionKind compression, int bufferSize, ColumnWriter keyWriter, ColumnWriter valueWriter) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.columnEncoding = new ColumnEncoding(DIRECT_V2, 0); this.keyWriter = requireNonNull(keyWriter, "keyWriter is null"); @@ -88,10 +88,10 @@ public List getNestedColumnWriters() } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - ImmutableMap.Builder encodings = ImmutableMap.builder(); - encodings.put(column, columnEncoding); + ImmutableMap.Builder encodings = ImmutableMap.builder(); + encodings.put(columnId, columnEncoding); encodings.putAll(keyWriter.getColumnEncodings()); encodings.putAll(valueWriter.getColumnEncodings()); return encodings.build(); @@ -138,7 +138,7 @@ private void writeColumnarMap(ColumnarMap columnarMap) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); @@ -146,8 +146,8 @@ public Map finishRowGroup() rowGroupColumnStatistics.add(statistics); nonNullValueCount = 0; - ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); - columnStatistics.put(column, statistics); + ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, statistics); columnStatistics.putAll(keyWriter.finishRowGroup()); columnStatistics.putAll(valueWriter.finishRowGroup()); return columnStatistics.build(); @@ -164,11 +164,11 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); - columnStatistics.put(column, 
ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); columnStatistics.putAll(keyWriter.getColumnStripeStatistics()); columnStatistics.putAll(valueWriter.getColumnStripeStatistics()); return columnStatistics.build(); @@ -194,7 +194,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); ImmutableList.Builder indexStreams = ImmutableList.builder(); indexStreams.add(new StreamDataOutput(slice, stream)); @@ -220,8 +220,8 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(lengthStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(lengthStream.getStreamDataOutput(columnId)); outputDataStreams.addAll(keyWriter.getDataStreams()); outputDataStreams.addAll(valueWriter.getDataStreams()); return outputDataStreams.build(); diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDictionaryColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDictionaryColumnWriter.java index 66cbb5a2fc88..687511fff04b 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDictionaryColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDictionaryColumnWriter.java @@ -24,6 +24,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import 
io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -66,7 +67,7 @@ public class SliceDictionaryColumnWriter private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceDictionaryColumnWriter.class).instanceSize(); private static final int DIRECT_CONVERSION_CHUNK_MAX_LOGICAL_BYTES = toIntExact(new DataSize(32, MEGABYTE).toBytes()); - private final int column; + private final OrcColumnId columnId; private final Type type; private final CompressionKind compression; private final int bufferSize; @@ -96,10 +97,9 @@ public class SliceDictionaryColumnWriter private boolean directEncoded; private SliceDirectColumnWriter directColumnWriter; - public SliceDictionaryColumnWriter(int column, Type type, CompressionKind compression, int bufferSize, DataSize stringStatisticsLimit) + public SliceDictionaryColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize, DataSize stringStatisticsLimit) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); this.compression = requireNonNull(compression, "compression is null"); this.bufferSize = bufferSize; @@ -160,7 +160,7 @@ public OptionalInt tryConvertToDirect(int maxDirectBytes) checkState(!closed); checkState(!directEncoded); if (directColumnWriter == null) { - directColumnWriter = new SliceDirectColumnWriter(column, type, compression, bufferSize, this::newStringStatisticsBuilder); + directColumnWriter = new SliceDirectColumnWriter(columnId, type, compression, bufferSize, this::newStringStatisticsBuilder); } checkState(directColumnWriter.getBufferedBytes() == 0); @@ -246,13 +246,13 @@ private boolean writeDictionaryRowGroup(Block dictionary, int valueCount, IntBig } @Override - public Map getColumnEncodings() + public Map 
getColumnEncodings() { checkState(closed); if (directEncoded) { return directColumnWriter.getColumnEncodings(); } - return ImmutableMap.of(column, columnEncoding); + return ImmutableMap.of(columnId, columnEncoding); } @Override @@ -296,7 +296,7 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); checkState(inRowGroup); @@ -311,7 +311,7 @@ public Map finishRowGroup() rowGroupValueCount = 0; statisticsBuilder = newStringStatisticsBuilder(); values = new IntBigArray(); - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -329,14 +329,14 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); if (directEncoded) { return directColumnWriter.getColumnStripeStatistics(); } - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroups.stream() + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroups.stream() .map(DictionaryRowGroup::getColumnStatistics) .collect(toList()))); } @@ -463,7 +463,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -489,10 +489,10 @@ public List getDataStreams() // actually write data ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); - outputDataStreams.add(dictionaryLengthStream.getStreamDataOutput(column)); - outputDataStreams.add(dictionaryDataStream.getStreamDataOutput(column)); + 
presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); + outputDataStreams.add(dictionaryLengthStream.getStreamDataOutput(columnId)); + outputDataStreams.add(dictionaryDataStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDirectColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDirectColumnWriter.java index b40720b37f51..0ae2ff3b0bbb 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDirectColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/SliceDirectColumnWriter.java @@ -22,6 +22,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -53,7 +54,7 @@ public class SliceDirectColumnWriter implements ColumnWriter { private static final int INSTANCE_SIZE = ClassLayout.parseClass(SliceDirectColumnWriter.class).instanceSize(); - private final int column; + private final OrcColumnId columnId; private final Type type; private final boolean compressed; private final ColumnEncoding columnEncoding; @@ -68,10 +69,9 @@ public class SliceDirectColumnWriter private boolean closed; - public SliceDirectColumnWriter(int column, Type type, CompressionKind compression, int bufferSize, Supplier statisticsBuilderSupplier) + public SliceDirectColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize, Supplier statisticsBuilderSupplier) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); 
this.compressed = requireNonNull(compression, "compression is null") != NONE; this.columnEncoding = new ColumnEncoding(DIRECT_V2, 0); @@ -83,9 +83,9 @@ public SliceDirectColumnWriter(int column, Type type, CompressionKind compressio } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, columnEncoding); + return ImmutableMap.of(columnId, columnEncoding); } @Override @@ -120,7 +120,7 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); @@ -128,7 +128,7 @@ public Map finishRowGroup() rowGroupColumnStatistics.add(statistics); statisticsBuilder = statisticsBuilderSupplier.get(); - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -142,10 +142,10 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ -170,7 +170,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -193,9 +193,9 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(lengthStream.getStreamDataOutput(column)); - outputDataStreams.add(dataStream.getStreamDataOutput(column)); + 
presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(lengthStream.getStreamDataOutput(columnId)); + outputDataStreams.add(dataStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/StructColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/StructColumnWriter.java index c4e792f41a9a..fbc7641ca33c 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/StructColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/StructColumnWriter.java @@ -20,6 +20,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -49,7 +50,7 @@ public class StructColumnWriter private static final int INSTANCE_SIZE = ClassLayout.parseClass(StructColumnWriter.class).instanceSize(); private static final ColumnEncoding COLUMN_ENCODING = new ColumnEncoding(DIRECT, 0); - private final int column; + private final OrcColumnId columnId; private final boolean compressed; private final PresentOutputStream presentStream; private final List structFields; @@ -60,10 +61,9 @@ public class StructColumnWriter private boolean closed; - public StructColumnWriter(int column, CompressionKind compression, int bufferSize, List structFields) + public StructColumnWriter(OrcColumnId columnId, CompressionKind compression, int bufferSize, List structFields) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.structFields = ImmutableList.copyOf(requireNonNull(structFields, "structFields is null")); this.presentStream = new
PresentOutputStream(compression, bufferSize); @@ -82,10 +82,10 @@ public List getNestedColumnWriters() } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - ImmutableMap.Builder encodings = ImmutableMap.builder(); - encodings.put(column, COLUMN_ENCODING); + ImmutableMap.Builder encodings = ImmutableMap.builder(); + encodings.put(columnId, COLUMN_ENCODING); structFields.stream() .map(ColumnWriter::getColumnEncodings) .forEach(encodings::putAll); @@ -132,15 +132,15 @@ private void writeColumnarRow(ColumnarRow columnarRow) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics = new ColumnStatistics((long) nonNullValueCount, 0, null, null, null, null, null, null, null, null); rowGroupColumnStatistics.add(statistics); nonNullValueCount = 0; - ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); - columnStatistics.put(column, statistics); + ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, statistics); structFields.stream() .map(ColumnWriter::finishRowGroup) .forEach(columnStatistics::putAll); @@ -156,11 +156,11 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); - columnStatistics.put(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + ImmutableMap.Builder columnStatistics = ImmutableMap.builder(); + columnStatistics.put(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); structFields.stream() .map(ColumnWriter::getColumnStripeStatistics) .forEach(columnStatistics::putAll); @@ -185,7 +185,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream 
stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); ImmutableList.Builder indexStreams = ImmutableList.builder(); indexStreams.add(new StreamDataOutput(slice, stream)); @@ -210,7 +210,7 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); for (ColumnWriter structField : structFields) { outputDataStreams.addAll(structField.getDataStreams()); } diff --git a/presto-orc/src/main/java/io/prestosql/orc/writer/TimestampColumnWriter.java b/presto-orc/src/main/java/io/prestosql/orc/writer/TimestampColumnWriter.java index a7f6195a39e6..7c5f24c5642e 100644 --- a/presto-orc/src/main/java/io/prestosql/orc/writer/TimestampColumnWriter.java +++ b/presto-orc/src/main/java/io/prestosql/orc/writer/TimestampColumnWriter.java @@ -21,6 +21,7 @@ import io.prestosql.orc.metadata.ColumnEncoding; import io.prestosql.orc.metadata.CompressedMetadataWriter; import io.prestosql.orc.metadata.CompressionKind; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.RowGroupIndex; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -56,7 +57,7 @@ public class TimestampColumnWriter private static final int MILLIS_PER_SECOND = 1000; private static final int MILLIS_TO_NANOS_TRAILING_ZEROS = 5; - private final int column; + private final OrcColumnId columnId; private final Type type; private final boolean compressed; private final ColumnEncoding columnEncoding; @@ -71,10 +72,9 @@ public class TimestampColumnWriter private boolean closed; - public TimestampColumnWriter(int column, Type type, CompressionKind compression, int bufferSize, DateTimeZone hiveStorageTimeZone) + public TimestampColumnWriter(OrcColumnId columnId, Type type, CompressionKind compression, int bufferSize, DateTimeZone 
hiveStorageTimeZone) { - checkArgument(column >= 0, "column is negative"); - this.column = column; + this.columnId = requireNonNull(columnId, "columnId is null"); this.type = requireNonNull(type, "type is null"); this.compressed = requireNonNull(compression, "compression is null") != NONE; this.columnEncoding = new ColumnEncoding(DIRECT_V2, 0); @@ -85,9 +85,9 @@ public TimestampColumnWriter(int column, Type type, CompressionKind compression, } @Override - public Map getColumnEncodings() + public Map getColumnEncodings() { - return ImmutableMap.of(column, columnEncoding); + return ImmutableMap.of(columnId, columnEncoding); } @Override @@ -147,13 +147,13 @@ public void writeBlock(Block block) } @Override - public Map finishRowGroup() + public Map finishRowGroup() { checkState(!closed); ColumnStatistics statistics = new ColumnStatistics((long) nonNullValueCount, 0, null, null, null, null, null, null, null, null); rowGroupColumnStatistics.add(statistics); nonNullValueCount = 0; - return ImmutableMap.of(column, statistics); + return ImmutableMap.of(columnId, statistics); } @Override @@ -166,10 +166,10 @@ public void close() } @Override - public Map getColumnStripeStatistics() + public Map getColumnStripeStatistics() { checkState(closed); - return ImmutableMap.of(column, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); + return ImmutableMap.of(columnId, ColumnStatistics.mergeColumnStatistics(rowGroupColumnStatistics)); } @Override @@ -194,7 +194,7 @@ public List getIndexStreams(CompressedMetadataWriter metadataW } Slice slice = metadataWriter.writeRowIndexes(rowGroupIndexes.build()); - Stream stream = new Stream(column, StreamKind.ROW_INDEX, slice.length(), false); + Stream stream = new Stream(columnId, StreamKind.ROW_INDEX, slice.length(), false); return ImmutableList.of(new StreamDataOutput(slice, stream)); } @@ -217,9 +217,9 @@ public List getDataStreams() checkState(closed); ImmutableList.Builder outputDataStreams = ImmutableList.builder(); - 
presentStream.getStreamDataOutput(column).ifPresent(outputDataStreams::add); - outputDataStreams.add(secondsStream.getStreamDataOutput(column)); - outputDataStreams.add(nanosStream.getStreamDataOutput(column)); + presentStream.getStreamDataOutput(columnId).ifPresent(outputDataStreams::add); + outputDataStreams.add(secondsStream.getStreamDataOutput(columnId)); + outputDataStreams.add(nanosStream.getStreamDataOutput(columnId)); return outputDataStreams.build(); } diff --git a/presto-orc/src/test/java/io/prestosql/orc/BenchmarkStreamReaders.java b/presto-orc/src/test/java/io/prestosql/orc/BenchmarkColumnReaders.java similarity index 82% rename from presto-orc/src/test/java/io/prestosql/orc/BenchmarkStreamReaders.java rename to presto-orc/src/test/java/io/prestosql/orc/BenchmarkColumnReaders.java index 5e1e9143170f..54869ab37843 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/BenchmarkStreamReaders.java +++ b/presto-orc/src/test/java/io/prestosql/orc/BenchmarkColumnReaders.java @@ -13,7 +13,8 @@ */ package io.prestosql.orc; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import io.prestosql.spi.type.DecimalType; import io.prestosql.spi.type.SqlDecimal; @@ -79,8 +80,8 @@ @Warmup(iterations = 30, time = 500, timeUnit = MILLISECONDS) @Measurement(iterations = 20, time = 500, timeUnit = MILLISECONDS) @BenchmarkMode(Mode.AverageTime) -@OperationsPerInvocation(BenchmarkStreamReaders.ROWS) -public class BenchmarkStreamReaders +@OperationsPerInvocation(BenchmarkColumnReaders.ROWS) +public class BenchmarkColumnReaders { private static final DecimalType SHORT_DECIMAL_TYPE = createDecimalType(10, 5); private static final DecimalType LONG_DECIMAL_TYPE = createDecimalType(30, 5); @@ -94,12 +95,7 @@ public Object readBooleanNoNull(BooleanNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = 
new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -108,12 +104,7 @@ public Object readBooleanWithNull(BooleanWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -122,12 +113,7 @@ public Object readAllNull(AllNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -136,12 +122,7 @@ public Object readByteNoNull(TinyIntNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -150,12 +131,7 @@ public Object readByteWithNull(TinyIntWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -164,12 +140,7 @@ public Object readShortDecimalNoNull(ShortDecimalNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); 
- } - return blocks; + return readFirstColumn(recordReader); } } @@ -178,12 +149,7 @@ public Object readShortDecimalWithNull(ShortDecimalWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -192,12 +158,7 @@ public Object readLongDecimalNoNull(LongDecimalNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -206,12 +167,7 @@ public Object readLongDecimalWithNull(LongDecimalWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -220,12 +176,7 @@ public Object readDoubleNoNull(DoubleNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -234,12 +185,7 @@ public Object readDoubleWithNull(DoubleWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -248,12 +194,7 @@ 
public Object readFloatNoNull(FloatNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -262,12 +203,7 @@ public Object readFloatWithNull(FloatWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -276,12 +212,7 @@ public Object readLongNoNull(BigintNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -290,12 +221,7 @@ public Object readLongWithNull(BigintWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -304,12 +230,7 @@ public Object readIntNoNull(IntegerNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -318,12 +239,7 @@ public Object readIntWithNull(IntegerWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = 
data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -332,12 +248,7 @@ public Object readShortNoNull(SmallintNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -346,12 +257,7 @@ public Object readShortWithNull(SmallintWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -360,12 +266,7 @@ public Object readSliceDirectNoNull(VarcharDirectNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -374,12 +275,7 @@ public Object readSliceDirectWithNull(VarcharDirectWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -388,12 +284,7 @@ public Object readSliceDictionaryNoNull(VarcharDictionaryNoNullBenchmarkData dat throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while 
(recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -402,12 +293,7 @@ public Object readSliceDictionaryWithNull(VarcharDictionaryWithNullBenchmarkData throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -416,12 +302,7 @@ public Object readTimestampNoNull(TimestampNoNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); } } @@ -430,13 +311,18 @@ public Object readTimestampWithNull(TimestampWithNullBenchmarkData data) throws Throwable { try (OrcRecordReader recordReader = data.createRecordReader()) { - List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); - } - return blocks; + return readFirstColumn(recordReader); + } + } + + private Object readFirstColumn(OrcRecordReader recordReader) + throws IOException + { + List blocks = new ArrayList<>(); + for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { + blocks.add(page.getBlock(0).getLoadedBlock()); } + return blocks; } private abstract static class BenchmarkData @@ -480,11 +366,13 @@ OrcRecordReader createRecordReader() { OrcReader orcReader = new OrcReader(dataSource, new OrcReaderOptions()); return orcReader.createRecordReader( - ImmutableMap.of(0, type), + orcReader.getRootColumn().getNestedColumns(), + ImmutableList.of(type), OrcPredicate.TRUE, UTC, // arbitrary newSimpleAggregatedMemoryContext(), - 
INITIAL_BATCH_SIZE); + INITIAL_BATCH_SIZE, + RuntimeException::new); } } @@ -1183,7 +1071,7 @@ protected Iterator createValues() static { try { // call all versions of the long stream reader to pollute the profile - BenchmarkStreamReaders benchmark = new BenchmarkStreamReaders(); + BenchmarkColumnReaders benchmark = new BenchmarkColumnReaders(); benchmark.readLongNoNull(BigintNoNullBenchmarkData.create()); benchmark.readLongWithNull(BigintWithNullBenchmarkData.create()); benchmark.readIntNoNull(IntegerNoNullBenchmarkData.create()); @@ -1201,7 +1089,7 @@ public static void main(String[] args) { Options options = new OptionsBuilder() .verbosity(VerboseMode.NORMAL) - .include(".*" + BenchmarkStreamReaders.class.getSimpleName() + ".*") + .include(".*" + BenchmarkColumnReaders.class.getSimpleName() + ".*") .build(); new Runner(options).run(); diff --git a/presto-orc/src/test/java/io/prestosql/orc/BenchmarkOrcDecimalReader.java b/presto-orc/src/test/java/io/prestosql/orc/BenchmarkOrcDecimalReader.java index 3eb7c5fff7cc..06ee304b911a 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/BenchmarkOrcDecimalReader.java +++ b/presto-orc/src/test/java/io/prestosql/orc/BenchmarkOrcDecimalReader.java @@ -13,7 +13,8 @@ */ package io.prestosql.orc; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import io.prestosql.spi.type.DecimalType; import io.prestosql.spi.type.SqlDecimal; @@ -72,9 +73,8 @@ public Object readDecimal(BenchmarkData data) { OrcRecordReader recordReader = data.createRecordReader(); List blocks = new ArrayList<>(); - while (recordReader.nextBatch() > 0) { - Block block = recordReader.readBlock(0); - blocks.add(block); + for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { + blocks.add(page.getBlock(0).getLoadedBlock()); } return blocks; } @@ -117,11 +117,13 @@ private OrcRecordReader createRecordReader() 
OrcDataSource dataSource = new FileOrcDataSource(dataPath, READER_OPTIONS); OrcReader orcReader = new OrcReader(dataSource, READER_OPTIONS); return orcReader.createRecordReader( - ImmutableMap.of(0, DECIMAL_TYPE), + orcReader.getRootColumn().getNestedColumns(), + ImmutableList.of(DECIMAL_TYPE), OrcPredicate.TRUE, DateTimeZone.UTC, // arbitrary newSimpleAggregatedMemoryContext(), - INITIAL_BATCH_SIZE); + INITIAL_BATCH_SIZE, + RuntimeException::new); } private List createDecimalValues() diff --git a/presto-orc/src/test/java/io/prestosql/orc/OrcTester.java b/presto-orc/src/test/java/io/prestosql/orc/OrcTester.java index 1fb43fc1f9bd..ab220dce5267 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/OrcTester.java +++ b/presto-orc/src/test/java/io/prestosql/orc/OrcTester.java @@ -134,7 +134,6 @@ import static io.prestosql.spi.type.Varchars.truncateToLength; import static io.prestosql.testing.DateTimeTestingUtils.sqlTimestampOf; import static io.prestosql.testing.TestingConnectorSession.SESSION; -import static java.lang.Math.toIntExact; import static java.util.Arrays.asList; import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.stream.Collectors.toList; @@ -158,6 +157,7 @@ import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.getCharTypeInfo; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; public class OrcTester @@ -457,7 +457,8 @@ private static void assertFileContentsPresto( boolean isFirst = true; int rowsProcessed = 0; Iterator iterator = expectedValues.iterator(); - for (int batchSize = toIntExact(recordReader.nextBatch()); batchSize >= 0; batchSize = toIntExact(recordReader.nextBatch())) { + for (Page page = recordReader.nextPage(); page != null; page = recordReader.nextPage()) { + int batchSize = page.getPositionCount(); if (skipStripe && rowsProcessed < 10000) { assertEquals(advance(iterator, 
batchSize), batchSize); } @@ -466,7 +467,7 @@ else if (skipFirstBatch && isFirst) { isFirst = false; } else { - Block block = recordReader.readBlock(0); + Block block = page.getBlock(0); List data = new ArrayList<>(block.getPositionCount()); for (int position = 0; position < block.getPositionCount(); position++) { @@ -485,6 +486,7 @@ else if (skipFirstBatch && isFirst) { rowsProcessed += batchSize; } assertFalse(iterator.hasNext()); + assertNull(recordReader.nextPage()); assertEquals(recordReader.getReaderPosition(), rowsProcessed); assertEquals(recordReader.getFilePosition(), rowsProcessed); @@ -572,7 +574,14 @@ static OrcRecordReader createCustomOrcRecordReader(TempFile tempFile, OrcPredica assertEquals(orcReader.getColumnNames(), ImmutableList.of("test")); assertEquals(orcReader.getFooter().getRowsInRowGroup(), 10_000); - return orcReader.createRecordReader(ImmutableMap.of(0, type), predicate, HIVE_STORAGE_TIME_ZONE, newSimpleAggregatedMemoryContext(), initialBatchSize); + return orcReader.createRecordReader( + orcReader.getRootColumn().getNestedColumns(), + ImmutableList.of(type), + predicate, + HIVE_STORAGE_TIME_ZONE, + newSimpleAggregatedMemoryContext(), + initialBatchSize, + RuntimeException::new); } public static void writeOrcColumnPresto(File outputFile, CompressionKind compression, Type type, Iterator values, OrcWriterStats stats) diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestCachingOrcDataSource.java b/presto-orc/src/test/java/io/prestosql/orc/TestCachingOrcDataSource.java index 495365ef4827..4e82a2bd82c0 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestCachingOrcDataSource.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestCachingOrcDataSource.java @@ -14,7 +14,6 @@ package io.prestosql.orc; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import io.airlift.units.DataSize; @@ -23,6 +22,7 @@ import 
io.prestosql.orc.metadata.CompressionKind; import io.prestosql.orc.metadata.StripeInformation; import io.prestosql.orc.stream.OrcDataReader; +import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.exec.FileSinkOperator; @@ -210,18 +210,21 @@ private void doIntegration(TestingOrcDataSource orcDataSource, DataSize maxMerge assertInstanceOf(wrapWithCacheIfTinyStripes(orcDataSource, stripes, maxMergeDistance, tinyStripeThreshold), CachingOrcDataSource.class); OrcRecordReader orcRecordReader = orcReader.createRecordReader( - ImmutableMap.of(0, VARCHAR), + orcReader.getRootColumn().getNestedColumns(), + ImmutableList.of(VARCHAR), (numberOfRows, statisticsByColumnIndex) -> true, HIVE_STORAGE_TIME_ZONE, newSimpleAggregatedMemoryContext(), - INITIAL_BATCH_SIZE); + INITIAL_BATCH_SIZE, + RuntimeException::new); int positionCount = 0; while (true) { - int batchSize = orcRecordReader.nextBatch(); - if (batchSize <= 0) { + Page page = orcRecordReader.nextPage(); + if (page == null) { break; } - Block block = orcRecordReader.readBlock(0); + page = page.getLoadedPage(); + Block block = page.getBlock(0); positionCount += block.getPositionCount(); } assertEquals(positionCount, POSITION_COUNT); diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestOrcBloomFilters.java b/presto-orc/src/test/java/io/prestosql/orc/TestOrcBloomFilters.java index 2bb8630dabe3..bd7d1f1d50e0 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestOrcBloomFilters.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestOrcBloomFilters.java @@ -17,7 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Longs; import io.airlift.slice.Slice; -import io.prestosql.orc.TupleDomainOrcPredicate.ColumnReference; +import io.prestosql.orc.metadata.ColumnMetadata; import io.prestosql.orc.metadata.OrcMetadataReader; import io.prestosql.orc.metadata.statistics.BloomFilter; import 
io.prestosql.orc.metadata.statistics.ColumnStatistics; @@ -25,7 +25,6 @@ import io.prestosql.orc.proto.OrcProto; import io.prestosql.orc.protobuf.CodedInputStream; import io.prestosql.spi.predicate.Domain; -import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.Type; import org.apache.orc.util.Murmur3; @@ -47,6 +46,7 @@ import static io.airlift.slice.Slices.wrappedBuffer; import static io.prestosql.orc.TupleDomainOrcPredicate.checkInBloomFilter; import static io.prestosql.orc.TupleDomainOrcPredicate.extractDiscreteValues; +import static io.prestosql.orc.metadata.OrcColumnId.ROOT_COLUMN; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.BooleanType.BOOLEAN; import static io.prestosql.spi.type.DateType.DATE; @@ -283,22 +283,11 @@ public void testExtractValuesFromSingleDomain() // simulate query on a 2 columns where 1 is used as part of the where, with and without bloom filter public void testMatches() { - // stripe column - Domain testingColumnHandleDomain = Domain.singleValue(BIGINT, 1234L); - TupleDomain.ColumnDomain column0 = new TupleDomain.ColumnDomain<>(COLUMN_0, testingColumnHandleDomain); - - // predicate consist of the bigint_0 = 1234 - TupleDomain effectivePredicate = TupleDomain.fromColumnDomains(Optional.of(ImmutableList.of(column0))); - TupleDomain emptyEffectivePredicate = TupleDomain.all(); - - // predicate column references - List> columnReferences = ImmutableList.>builder() - .add(new ColumnReference<>(COLUMN_0, 0, BIGINT)) - .add(new ColumnReference<>(COLUMN_1, 1, BIGINT)) + TupleDomainOrcPredicate predicate = TupleDomainOrcPredicate.builder() + .setBloomFiltersEnabled(true) + .addColumn(ROOT_COLUMN, Domain.singleValue(BIGINT, 1234L)) .build(); - - TupleDomainOrcPredicate predicate = new TupleDomainOrcPredicate<>(effectivePredicate, columnReferences, true); - TupleDomainOrcPredicate emptyPredicate = new 
TupleDomainOrcPredicate<>(emptyEffectivePredicate, columnReferences, true); + TupleDomainOrcPredicate emptyPredicate = TupleDomainOrcPredicate.builder().build(); // assemble a matching and a non-matching bloom filter BloomFilter bloomFilter = new BloomFilter(1000, 0.01); @@ -306,7 +295,7 @@ public void testMatches() bloomFilter.addLong(1234); OrcProto.BloomFilter orcBloomFilter = toOrcBloomFilter(bloomFilter); - Map matchingStatisticsByColumnIndex = ImmutableMap.of(0, new ColumnStatistics( + ColumnMetadata matchingStatisticsByColumnIndex = new ColumnMetadata<>(ImmutableList.of(new ColumnStatistics( null, 0, null, @@ -316,9 +305,9 @@ public void testMatches() null, null, null, - toBloomFilter(orcBloomFilter))); + toBloomFilter(orcBloomFilter)))); - Map nonMatchingStatisticsByColumnIndex = ImmutableMap.of(0, new ColumnStatistics( + ColumnMetadata nonMatchingStatisticsByColumnIndex = new ColumnMetadata<>(ImmutableList.of(new ColumnStatistics( null, 0, null, @@ -328,9 +317,9 @@ public void testMatches() null, null, null, - toBloomFilter(emptyOrcBloomFilter))); + toBloomFilter(emptyOrcBloomFilter)))); - Map withoutBloomFilterStatisticsByColumnIndex = ImmutableMap.of(0, new ColumnStatistics( + ColumnMetadata withoutBloomFilterStatisticsByColumnIndex = new ColumnMetadata<>(ImmutableList.of(new ColumnStatistics( null, 0, null, @@ -340,7 +329,7 @@ public void testMatches() null, null, null, - null)); + null))); assertTrue(predicate.matches(1L, matchingStatisticsByColumnIndex)); assertTrue(predicate.matches(1L, withoutBloomFilterStatisticsByColumnIndex)); diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestOrcLz4.java b/presto-orc/src/test/java/io/prestosql/orc/TestOrcLz4.java index 8e20e2b8ffc1..1278a9a86dff 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestOrcLz4.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestOrcLz4.java @@ -13,15 +13,13 @@ */ package io.prestosql.orc; -import com.google.common.collect.ImmutableMap; +import 
com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; +import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; -import io.prestosql.spi.type.Type; import org.joda.time.DateTimeZone; import org.testng.annotations.Test; -import java.util.Map; - import static com.google.common.io.Resources.getResource; import static com.google.common.io.Resources.toByteArray; import static io.airlift.units.DataSize.Unit.MEGABYTE; @@ -50,32 +48,29 @@ public void testReadLz4() assertEquals(orcReader.getCompressionKind(), LZ4); assertEquals(orcReader.getFooter().getNumberOfRows(), 10_000); - Map includedColumns = ImmutableMap.builder() - .put(0, BIGINT) - .put(1, INTEGER) - .put(2, BIGINT) - .build(); - OrcRecordReader reader = orcReader.createRecordReader( - includedColumns, + orcReader.getRootColumn().getNestedColumns(), + ImmutableList.of(BIGINT, INTEGER, BIGINT), OrcPredicate.TRUE, DateTimeZone.UTC, newSimpleAggregatedMemoryContext(), - INITIAL_BATCH_SIZE); + INITIAL_BATCH_SIZE, + RuntimeException::new); int rows = 0; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize <= 0) { + Page page = reader.nextPage(); + if (page == null) { break; } - rows += batchSize; + page = page.getLoadedPage(); + rows += page.getPositionCount(); - Block xBlock = reader.readBlock(0); - Block yBlock = reader.readBlock(1); - Block zBlock = reader.readBlock(2); + Block xBlock = page.getBlock(0); + Block yBlock = page.getBlock(1); + Block zBlock = page.getBlock(2); - for (int position = 0; position < batchSize; position++) { + for (int position = 0; position < page.getPositionCount(); position++) { BIGINT.getLong(xBlock, position); INTEGER.getLong(yBlock, position); BIGINT.getLong(zBlock, position); diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderMemoryUsage.java b/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderMemoryUsage.java index d29b8ee53caa..8756bf823b4b 100644 --- 
a/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderMemoryUsage.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderMemoryUsage.java @@ -16,7 +16,7 @@ import com.google.common.base.Strings; import io.prestosql.metadata.Metadata; import io.prestosql.orc.metadata.CompressionKind; -import io.prestosql.spi.block.Block; +import io.prestosql.spi.Page; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; @@ -65,17 +65,15 @@ public void testVarcharTypeWithoutNulls() long readerSystemMemoryUsage = reader.getSystemMemoryUsage(); while (true) { - int batchSize = reader.nextBatch(); - if (batchSize == -1) { + Page page = reader.nextPage(); + if (page == null) { break; } - - Block block = reader.readBlock(0); - assertEquals(block.getPositionCount(), batchSize); + page = page.getLoadedPage(); // We only verify the memory usage when the batchSize reaches MAX_BATCH_SIZE as batchSize may be // increasing during the test, which will cause the StreamReader buffer sizes to increase too. - if (batchSize < MAX_BATCH_SIZE) { + if (page.getPositionCount() < MAX_BATCH_SIZE) { continue; } @@ -112,17 +110,15 @@ public void testBigIntTypeWithNulls() long readerSystemMemoryUsage = reader.getSystemMemoryUsage(); while (true) { - int batchSize = reader.nextBatch(); - if (batchSize == -1) { + Page page = reader.nextPage(); + if (page == null) { break; } - - Block block = reader.readBlock(0); - assertEquals(block.getPositionCount(), batchSize); + page = page.getLoadedPage(); // We only verify the memory usage when the batchSize reaches MAX_BATCH_SIZE as batchSize may be // increasing during the test, which will cause the StreamReader buffer sizes to increase too. 
- if (batchSize < MAX_BATCH_SIZE) { + if (page.getPositionCount() < MAX_BATCH_SIZE) { continue; } @@ -161,17 +157,15 @@ public void testMapTypeWithNulls() long readerSystemMemoryUsage = reader.getSystemMemoryUsage(); while (true) { - int batchSize = reader.nextBatch(); - if (batchSize == -1) { + Page page = reader.nextPage(); + if (page == null) { break; } - - Block block = reader.readBlock(0); - assertEquals(block.getPositionCount(), batchSize); + page = page.getLoadedPage(); // We only verify the memory usage when the batchSize reaches MAX_BATCH_SIZE as batchSize may be // increasing during the test, which will cause the StreamReader buffer sizes to increase too. - if (batchSize < MAX_BATCH_SIZE) { + if (page.getPositionCount() < MAX_BATCH_SIZE) { continue; } diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderPositions.java b/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderPositions.java index 8f8a391a3e62..ff6e8e77dfa5 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderPositions.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestOrcReaderPositions.java @@ -18,7 +18,9 @@ import io.airlift.slice.Slice; import io.prestosql.orc.metadata.CompressionKind; import io.prestosql.orc.metadata.Footer; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.statistics.IntegerStatistics; +import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -55,6 +57,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.apache.hadoop.hive.ql.io.orc.CompressionKind.SNAPPY; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -74,13 +77,14 @@ public void testEntireFile() assertEquals(reader.getFilePosition(), reader.getReaderPosition()); for (int i = 0; i < 5; i++) { - 
assertEquals(reader.nextBatch(), 20); + Page page = reader.nextPage().getLoadedPage(); + assertEquals(page.getPositionCount(), 20); assertEquals(reader.getReaderPosition(), i * 20L); assertEquals(reader.getFilePosition(), reader.getReaderPosition()); - assertCurrentBatch(reader, i); + assertCurrentBatch(page, i); } - assertEquals(reader.nextBatch(), -1); + assertNull(reader.nextPage()); assertEquals(reader.getReaderPosition(), 100); assertEquals(reader.getFilePosition(), reader.getReaderPosition()); } @@ -95,11 +99,11 @@ public void testStripeSkipping() createMultiStripeFile(tempFile.getFile()); // test reading second and fourth stripes - OrcPredicate predicate = (numberOfRows, statisticsByColumnIndex) -> { + OrcPredicate predicate = (numberOfRows, allColumnStatistics) -> { if (numberOfRows == 100) { return true; } - IntegerStatistics stats = statisticsByColumnIndex.get(0).getIntegerStatistics(); + IntegerStatistics stats = allColumnStatistics.get(new OrcColumnId(1)).getIntegerStatistics(); return ((stats.getMin() == 60) && (stats.getMax() == 117)) || ((stats.getMin() == 180) && (stats.getMax() == 237)); }; @@ -111,18 +115,21 @@ public void testStripeSkipping() assertEquals(reader.getReaderPosition(), 0); // second stripe - assertEquals(reader.nextBatch(), 20); + Page page = reader.nextPage().getLoadedPage(); + assertEquals(page.getPositionCount(), 20); assertEquals(reader.getReaderPosition(), 0); assertEquals(reader.getFilePosition(), 20); - assertCurrentBatch(reader, 1); + assertCurrentBatch(page, 1); // fourth stripe - assertEquals(reader.nextBatch(), 20); + page = reader.nextPage().getLoadedPage(); + assertEquals(page.getPositionCount(), 20); assertEquals(reader.getReaderPosition(), 20); assertEquals(reader.getFilePosition(), 60); - assertCurrentBatch(reader, 3); + assertCurrentBatch(page, 3); - assertEquals(reader.nextBatch(), -1); + page = reader.nextPage(); + assertNull(page); assertEquals(reader.getReaderPosition(), 40); 
assertEquals(reader.getFilePosition(), 100); } @@ -139,11 +146,11 @@ public void testRowGroupSkipping() createSequentialFile(tempFile.getFile(), rowCount); // test reading two row groups from middle of file - OrcPredicate predicate = (numberOfRows, statisticsByColumnIndex) -> { + OrcPredicate predicate = (numberOfRows, allColumnStatistics) -> { if (numberOfRows == rowCount) { return true; } - IntegerStatistics stats = statisticsByColumnIndex.get(0).getIntegerStatistics(); + IntegerStatistics stats = allColumnStatistics.get(new OrcColumnId(1)).getIntegerStatistics(); return (stats.getMin() == 50_000) || (stats.getMin() == 60_000); }; @@ -155,19 +162,20 @@ public void testRowGroupSkipping() long position = 50_000; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize == -1) { + Page page = reader.nextPage(); + if (page == null) { break; } + page = page.getLoadedPage(); - Block block = reader.readBlock(0); - for (int i = 0; i < batchSize; i++) { + Block block = page.getBlock(0); + for (int i = 0; i < block.getPositionCount(); i++) { assertEquals(BIGINT.getLong(block, i), position + i); } assertEquals(reader.getFilePosition(), position); assertEquals(reader.getReaderPosition(), position); - position += batchSize; + position += page.getPositionCount(); } assertEquals(position, 70_000); @@ -206,14 +214,15 @@ public void testBatchSizesForVariableWidth() int currentStringBytes = baseStringBytes + Integer.BYTES + Byte.BYTES; int rowCountsInCurrentRowGroup = 0; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize == -1) { + Page page = reader.nextPage(); + if (page == null) { break; } + page = page.getLoadedPage(); - rowCountsInCurrentRowGroup += batchSize; + rowCountsInCurrentRowGroup += page.getPositionCount(); - Block block = reader.readBlock(0); + Block block = page.getBlock(0); if (MAX_BATCH_SIZE * currentStringBytes <= READER_OPTIONS.getMaxBlockSize().toBytes()) { // Either we are bounded by 1024 rows per batch, or it is the last batch in 
the row group // For the first 3 row groups, the strings are of length 300, 600, and 900 respectively @@ -261,13 +270,14 @@ public void testBatchSizesForFixedWidth() int rowCountsInCurrentRowGroup = 0; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize == -1) { + Page page = reader.nextPage(); + if (page == null) { break; } - rowCountsInCurrentRowGroup += batchSize; + page = page.getLoadedPage(); + rowCountsInCurrentRowGroup += page.getPositionCount(); - Block block = reader.readBlock(0); + Block block = page.getBlock(0); // 8 bytes per row; 1024 row at most given 1024 X 8B < 1MB assertTrue(block.getPositionCount() == MAX_BATCH_SIZE || rowCountsInCurrentRowGroup == rowsInRowGroup); @@ -323,24 +333,25 @@ public void testBatchSizeGrowth() int expectedBatchSize = INITIAL_BATCH_SIZE; int rowCountsInCurrentRowGroup = 0; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize == -1) { + Page page = reader.nextPage(); + if (page == null) { break; } + page = page.getLoadedPage(); - assertEquals(batchSize, expectedBatchSize); + assertEquals(page.getPositionCount(), expectedBatchSize); assertEquals(reader.getReaderPosition(), totalReadRows); assertEquals(reader.getFilePosition(), reader.getReaderPosition()); - assertCurrentBatch(reader, (int) reader.getReaderPosition(), batchSize); + assertCurrentBatch(page, (int) reader.getReaderPosition(), page.getPositionCount()); if (nextBatchSize > 20 - rowCountsInCurrentRowGroup) { nextBatchSize *= BATCH_SIZE_GROWTH_FACTOR; } else { - nextBatchSize = batchSize * BATCH_SIZE_GROWTH_FACTOR; + nextBatchSize = page.getPositionCount() * BATCH_SIZE_GROWTH_FACTOR; } - rowCountsInCurrentRowGroup += batchSize; - totalReadRows += batchSize; + rowCountsInCurrentRowGroup += page.getPositionCount(); + totalReadRows += page.getPositionCount(); if (rowCountsInCurrentRowGroup == 20) { rowCountsInCurrentRowGroup = 0; } @@ -357,19 +368,17 @@ else if (rowCountsInCurrentRowGroup > 20) { } } - private static void 
assertCurrentBatch(OrcRecordReader reader, int rowIndex, int batchSize) - throws IOException + private static void assertCurrentBatch(Page page, int rowIndex, int batchSize) { - Block block = reader.readBlock(0); + Block block = page.getBlock(0); for (int i = 0; i < batchSize; i++) { assertEquals(BIGINT.getLong(block, i), (rowIndex + i) * 3); } } - private static void assertCurrentBatch(OrcRecordReader reader, int stripe) - throws IOException + private static void assertCurrentBatch(Page page, int stripe) { - Block block = reader.readBlock(0); + Block block = page.getBlock(0); for (int i = 0; i < 20; i++) { assertEquals(BIGINT.getLong(block, i), ((stripe * 20L) + i) * 3); } diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java b/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java index e87b5ecf543a..5791d73a76bb 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestReadBloomFilter.java @@ -14,9 +14,8 @@ package io.prestosql.orc; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.prestosql.orc.TupleDomainOrcPredicate.ColumnReference; -import io.prestosql.spi.predicate.NullableValue; +import io.prestosql.orc.metadata.OrcColumnId; +import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.type.SqlDate; import io.prestosql.spi.type.SqlTimestamp; import io.prestosql.spi.type.SqlVarbinary; @@ -39,7 +38,6 @@ import static io.prestosql.orc.OrcTester.READER_OPTIONS; import static io.prestosql.orc.OrcTester.writeOrcColumnHive; import static io.prestosql.orc.metadata.CompressionKind.LZ4; -import static io.prestosql.spi.predicate.TupleDomain.fromFixedValues; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DateType.DATE; import static io.prestosql.spi.type.DoubleType.DOUBLE; @@ -53,6 +51,7 @@ import static java.lang.Float.floatToIntBits; import static 
java.nio.charset.StandardCharsets.UTF_8; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; public class TestReadBloomFilter { @@ -93,37 +92,36 @@ private static void testType(Type type, List uniqueValues, T inBloomFilte // without predicate a normal block will be created try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, OrcPredicate.TRUE, type, MAX_BATCH_SIZE)) { - assertEquals(recordReader.nextBatch(), 1024); + assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 1024); } // predicate for specific value within the min/max range without bloom filter being enabled - TupleDomainOrcPredicate noBloomFilterPredicate = new TupleDomainOrcPredicate<>( - fromFixedValues(ImmutableMap.of("test", NullableValue.of(type, notInBloomFilter))), - ImmutableList.of(new ColumnReference<>("test", 0, type)), - false); + TupleDomainOrcPredicate noBloomFilterPredicate = TupleDomainOrcPredicate.builder() + .addColumn(new OrcColumnId(1), Domain.singleValue(type, notInBloomFilter)) + .build(); try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, noBloomFilterPredicate, type, MAX_BATCH_SIZE)) { - assertEquals(recordReader.nextBatch(), 1024); + assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 1024); } // predicate for specific value within the min/max range with bloom filter enabled, but a value not in the bloom filter - TupleDomainOrcPredicate notMatchBloomFilterPredicate = new TupleDomainOrcPredicate<>( - fromFixedValues(ImmutableMap.of("test", NullableValue.of(type, notInBloomFilter))), - ImmutableList.of(new ColumnReference<>("test", 0, type)), - true); + TupleDomainOrcPredicate notMatchBloomFilterPredicate = TupleDomainOrcPredicate.builder() + .addColumn(new OrcColumnId(1), Domain.singleValue(type, notInBloomFilter)) + .setBloomFiltersEnabled(true) + .build(); try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, 
notMatchBloomFilterPredicate, type, MAX_BATCH_SIZE)) { - assertEquals(recordReader.nextBatch(), -1); + assertNull(recordReader.nextPage()); } // predicate for specific value within the min/max range with bloom filter enabled, and a value in the bloom filter - TupleDomainOrcPredicate matchBloomFilterPredicate = new TupleDomainOrcPredicate<>( - fromFixedValues(ImmutableMap.of("test", NullableValue.of(type, inBloomFilter))), - ImmutableList.of(new ColumnReference<>("test", 0, type)), - true); + TupleDomainOrcPredicate matchBloomFilterPredicate = TupleDomainOrcPredicate.builder() + .addColumn(new OrcColumnId(1), Domain.singleValue(type, inBloomFilter)) + .setBloomFiltersEnabled(true) + .build(); try (OrcRecordReader recordReader = createCustomOrcRecordReader(tempFile, matchBloomFilterPredicate, type, MAX_BATCH_SIZE)) { - assertEquals(recordReader.nextBatch(), 1024); + assertEquals(recordReader.nextPage().getLoadedPage().getPositionCount(), 1024); } } } @@ -137,6 +135,13 @@ private static OrcRecordReader createCustomOrcRecordReader(TempFile tempFile, Or assertEquals(orcReader.getColumnNames(), ImmutableList.of("test")); assertEquals(orcReader.getFooter().getRowsInRowGroup(), 10_000); - return orcReader.createRecordReader(ImmutableMap.of(0, type), predicate, HIVE_STORAGE_TIME_ZONE, newSimpleAggregatedMemoryContext(), initialBatchSize); + return orcReader.createRecordReader( + orcReader.getRootColumn().getNestedColumns(), + ImmutableList.of(type), + predicate, + HIVE_STORAGE_TIME_ZONE, + newSimpleAggregatedMemoryContext(), + initialBatchSize, + RuntimeException::new); } } diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestSliceDictionaryColumnWriter.java b/presto-orc/src/test/java/io/prestosql/orc/TestSliceDictionaryColumnWriter.java index 5dfb3cac7af8..7965c5a1f9b8 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestSliceDictionaryColumnWriter.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestSliceDictionaryColumnWriter.java @@ -26,6 +26,7 @@ import 
static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.prestosql.orc.OrcWriterOptions.DEFAULT_MAX_COMPRESSION_BUFFER_SIZE; import static io.prestosql.orc.OrcWriterOptions.DEFAULT_MAX_STRING_STATISTICS_LIMIT; +import static io.prestosql.orc.metadata.OrcColumnId.ROOT_COLUMN; import static io.prestosql.spi.type.VarcharType.VARCHAR; import static java.lang.Math.toIntExact; import static org.testng.Assert.assertFalse; @@ -36,7 +37,7 @@ public class TestSliceDictionaryColumnWriter public void testDirectConversion() { SliceDictionaryColumnWriter writer = new SliceDictionaryColumnWriter( - 0, + ROOT_COLUMN, VARCHAR, CompressionKind.NONE, toIntExact(DEFAULT_MAX_COMPRESSION_BUFFER_SIZE.toBytes()), diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestStructStreamReader.java b/presto-orc/src/test/java/io/prestosql/orc/TestStructColumnReader.java similarity index 96% rename from presto-orc/src/test/java/io/prestosql/orc/TestStructStreamReader.java rename to presto-orc/src/test/java/io/prestosql/orc/TestStructColumnReader.java index 105371bc2278..1e7c4b298022 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestStructStreamReader.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestStructColumnReader.java @@ -38,9 +38,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Optional; import static io.airlift.units.DataSize.Unit.MEGABYTE; @@ -58,7 +56,7 @@ import static org.testng.Assert.assertNull; @Test(singleThreaded = true) -public class TestStructStreamReader +public class TestStructColumnReader { private static final Metadata METADATA = createTestMetadataManager(); @@ -264,13 +262,16 @@ private RowBlock read(TempFile tempFile, Type readerType) OrcDataSource orcDataSource = new FileOrcDataSource(tempFile.getFile(), READER_OPTIONS); OrcReader orcReader = new OrcReader(orcDataSource, READER_OPTIONS); - Map includedColumns = new 
HashMap<>(); - includedColumns.put(0, readerType); + OrcRecordReader recordReader = orcReader.createRecordReader( + orcReader.getRootColumn().getNestedColumns(), + ImmutableList.of(readerType), + OrcPredicate.TRUE, + UTC, + newSimpleAggregatedMemoryContext(), + OrcReader.INITIAL_BATCH_SIZE, + RuntimeException::new); - OrcRecordReader recordReader = orcReader.createRecordReader(includedColumns, OrcPredicate.TRUE, UTC, newSimpleAggregatedMemoryContext(), OrcReader.INITIAL_BATCH_SIZE); - - recordReader.nextBatch(); - RowBlock block = (RowBlock) recordReader.readBlock(0); + RowBlock block = (RowBlock) recordReader.nextPage().getLoadedPage().getBlock(0); recordReader.close(); return block; } diff --git a/presto-orc/src/test/java/io/prestosql/orc/TestingOrcPredicate.java b/presto-orc/src/test/java/io/prestosql/orc/TestingOrcPredicate.java index 3c360731c212..5beff66910c5 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/TestingOrcPredicate.java +++ b/presto-orc/src/test/java/io/prestosql/orc/TestingOrcPredicate.java @@ -16,6 +16,8 @@ import com.google.common.collect.Ordering; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import io.prestosql.orc.metadata.ColumnMetadata; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.statistics.BloomFilter; import io.prestosql.orc.metadata.statistics.ColumnStatistics; import io.prestosql.spi.type.ArrayType; @@ -33,7 +35,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.concurrent.ThreadLocalRandom; @@ -126,9 +127,9 @@ public BasicOrcPredicate(Iterable expectedValues, Class type) } @Override - public boolean matches(long numberOfRows, Map statisticsByColumnIndex) + public boolean matches(long numberOfRows, ColumnMetadata allColumnStatistics) { - ColumnStatistics columnStatistics = statisticsByColumnIndex.get(0); + ColumnStatistics columnStatistics = allColumnStatistics.get(new 
OrcColumnId(1)); assertTrue(columnStatistics.hasNumberOfValues()); if (numberOfRows == expectedValues.size()) { diff --git a/presto-orc/src/test/java/io/prestosql/orc/stream/AbstractTestValueStream.java b/presto-orc/src/test/java/io/prestosql/orc/stream/AbstractTestValueStream.java index 1bda0b7aa237..8b7f43bede3a 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/stream/AbstractTestValueStream.java +++ b/presto-orc/src/test/java/io/prestosql/orc/stream/AbstractTestValueStream.java @@ -18,6 +18,7 @@ import io.prestosql.orc.OrcCorruptionException; import io.prestosql.orc.OrcDataSourceId; import io.prestosql.orc.checkpoint.StreamCheckpoint; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; @@ -49,11 +50,11 @@ protected void testWriteValue(List> groups) outputStream.close(); DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1000); - StreamDataOutput streamDataOutput = outputStream.getStreamDataOutput(33); + StreamDataOutput streamDataOutput = outputStream.getStreamDataOutput(new OrcColumnId(33)); streamDataOutput.writeData(sliceOutput); Stream stream = streamDataOutput.getStream(); assertEquals(stream.getStreamKind(), StreamKind.DATA); - assertEquals(stream.getColumn(), 33); + assertEquals(stream.getColumnId(), new OrcColumnId(33)); assertEquals(stream.getLength(), sliceOutput.size()); List checkpoints = outputStream.getCheckpoints(); diff --git a/presto-orc/src/test/java/io/prestosql/orc/stream/TestBooleanStream.java b/presto-orc/src/test/java/io/prestosql/orc/stream/TestBooleanStream.java index 44faa1de2649..b44d901066b7 100644 --- a/presto-orc/src/test/java/io/prestosql/orc/stream/TestBooleanStream.java +++ b/presto-orc/src/test/java/io/prestosql/orc/stream/TestBooleanStream.java @@ -18,6 +18,7 @@ import io.prestosql.orc.OrcCorruptionException; import io.prestosql.orc.OrcDecompressor; import io.prestosql.orc.checkpoint.BooleanStreamCheckpoint; +import 
io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.Stream; import io.prestosql.orc.metadata.Stream.StreamKind; import it.unimi.dsi.fastutil.booleans.BooleanArrayList; @@ -90,11 +91,11 @@ public void testWriteMultiple() outputStream.close(); DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1000); - StreamDataOutput streamDataOutput = outputStream.getStreamDataOutput(33); + StreamDataOutput streamDataOutput = outputStream.getStreamDataOutput(new OrcColumnId(33)); streamDataOutput.writeData(sliceOutput); Stream stream = streamDataOutput.getStream(); assertEquals(stream.getStreamKind(), StreamKind.DATA); - assertEquals(stream.getColumn(), 33); + assertEquals(stream.getColumnId(), new OrcColumnId(33)); assertEquals(stream.getLength(), sliceOutput.size()); BooleanInputStream valueStream = createValueStream(sliceOutput.slice()); diff --git a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcPageSource.java b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcPageSource.java index f2724b833047..de4e975587ec 100644 --- a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcPageSource.java +++ b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcPageSource.java @@ -22,8 +22,6 @@ import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; -import io.prestosql.spi.block.LazyBlock; -import io.prestosql.spi.block.LazyBlockLoader; import io.prestosql.spi.block.RunLengthEncodedBlock; import io.prestosql.spi.connector.UpdatablePageSource; import io.prestosql.spi.type.Type; @@ -41,86 +39,42 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static io.airlift.slice.Slices.utf8Slice; -import static io.prestosql.orc.OrcReader.MAX_BATCH_SIZE; +import static 
io.prestosql.plugin.raptor.legacy.RaptorColumnHandle.SHARD_UUID_COLUMN_TYPE; import static io.prestosql.plugin.raptor.legacy.RaptorErrorCode.RAPTOR_ERROR; -import static io.prestosql.spi.predicate.Utils.nativeValueToBlock; import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.IntegerType.INTEGER; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class OrcPageSource implements UpdatablePageSource { - public static final int NULL_COLUMN = -1; - public static final int ROWID_COLUMN = -2; - public static final int SHARD_UUID_COLUMN = -3; - public static final int BUCKET_NUMBER_COLUMN = -4; - private final Optional shardRewriter; private final OrcRecordReader recordReader; + private final List columnAdaptations; private final OrcDataSource orcDataSource; private final BitSet rowsToDelete; - private final List columnIds; - private final List types; - - private final Block[] constantBlocks; - private final int[] columnIndexes; - private final AggregatedMemoryContext systemMemoryContext; - private int batchId; private boolean closed; public OrcPageSource( Optional shardRewriter, OrcRecordReader recordReader, + List columnAdaptations, OrcDataSource orcDataSource, - List columnIds, - List columnTypes, - List columnIndexes, - UUID shardUuid, - OptionalInt bucketNumber, AggregatedMemoryContext systemMemoryContext) { this.shardRewriter = requireNonNull(shardRewriter, "shardRewriter is null"); this.recordReader = requireNonNull(recordReader, "recordReader is null"); + this.columnAdaptations = ImmutableList.copyOf(requireNonNull(columnAdaptations, "columnAdaptations is null")); this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null"); this.rowsToDelete = new BitSet(toIntExact(recordReader.getFileRowCount())); - checkArgument(columnIds.size() == columnTypes.size(), "ids and types mismatch"); - checkArgument(columnIds.size() == columnIndexes.size(), "ids and indexes mismatch"); - 
int size = columnIds.size(); - - this.columnIds = ImmutableList.copyOf(columnIds); - this.types = ImmutableList.copyOf(columnTypes); - - this.constantBlocks = new Block[size]; - this.columnIndexes = new int[size]; - - requireNonNull(shardUuid, "shardUuid is null"); - - for (int i = 0; i < size; i++) { - this.columnIndexes[i] = columnIndexes.get(i); - if (this.columnIndexes[i] == NULL_COLUMN) { - constantBlocks[i] = buildSingleValueBlock(columnTypes.get(i), null); - } - else if (this.columnIndexes[i] == SHARD_UUID_COLUMN) { - constantBlocks[i] = buildSingleValueBlock(columnTypes.get(i), utf8Slice(shardUuid.toString())); - } - else if (this.columnIndexes[i] == BUCKET_NUMBER_COLUMN) { - if (bucketNumber.isPresent()) { - constantBlocks[i] = buildSingleValueBlock(columnTypes.get(i), (long) bucketNumber.getAsInt()); - } - else { - constantBlocks[i] = buildSingleValueBlock(columnTypes.get(i), null); - } - } - } - this.systemMemoryContext = requireNonNull(systemMemoryContext, "systemMemoryContext is null"); } @@ -145,39 +99,42 @@ public boolean isFinished() @Override public Page getNextPage() { + Page page; try { - batchId++; - int batchSize = recordReader.nextBatch(); - if (batchSize <= 0) { - close(); - return null; - } - long filePosition = recordReader.getFilePosition(); - - Block[] blocks = new Block[columnIndexes.length]; - for (int fieldId = 0; fieldId < blocks.length; fieldId++) { - if (constantBlocks[fieldId] != null) { - blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize); - } - else if (columnIndexes[fieldId] == ROWID_COLUMN) { - blocks[fieldId] = buildSequenceBlock(filePosition, batchSize); - } - else { - blocks[fieldId] = new LazyBlock(batchSize, new OrcBlockLoader(columnIndexes[fieldId])); - } - } - - return new Page(batchSize, blocks); + page = recordReader.nextPage(); } catch (IOException | RuntimeException e) { closeWithSuppression(e); - throw new PrestoException(RAPTOR_ERROR, e); + throw handleException(e); + } + + if (page == null) { + 
close(); + return null; + } + + long filePosition = recordReader.getFilePosition(); + Block[] blocks = new Block[columnAdaptations.size()]; + for (int i = 0; i < columnAdaptations.size(); i++) { + blocks[i] = columnAdaptations.get(i).block(page, filePosition); + } + return new Page(page.getPositionCount(), blocks); + } + + static PrestoException handleException(Exception exception) + { + if (exception instanceof PrestoException) { + return (PrestoException) exception; } + throw new PrestoException(RAPTOR_ERROR, exception); } @Override public void close() { + if (closed) { + return; + } closed = true; try { @@ -192,8 +149,7 @@ public void close() public String toString() { return toStringHelper(this) - .add("columnNames", columnIds) - .add("types", types) + .add("columns", columnAdaptations) .toString(); } @@ -233,48 +189,166 @@ private void closeWithSuppression(Throwable throwable) } } - private static Block buildSequenceBlock(long start, int count) + public interface ColumnAdaptation { - BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(count); - for (int i = 0; i < count; i++) { - BIGINT.writeLong(builder, start + i); + Block block(Page sourcePage, long filePosition); + + static ColumnAdaptation nullColumn(Type type) + { + return new NullColumn(type); + } + + static ColumnAdaptation shardUuidColumn(UUID shardUuid) + { + return new ShardUuidAdaptation(shardUuid); + } + + static ColumnAdaptation bucketNumberColumn(OptionalInt bucketNumber) + { + if (!bucketNumber.isPresent()) { + return nullColumn(INTEGER); + } + return new BucketNumberColumn(bucketNumber.getAsInt()); + } + + static ColumnAdaptation rowIdColumn() + { + return new RowIdColumn(); + } + + static ColumnAdaptation sourceColumn(int index) + { + return new SourceColumn(index); } - return builder.build(); } - private static Block buildSingleValueBlock(Type type, Object value) + private static class ShardUuidAdaptation + implements ColumnAdaptation { - Block block = nativeValueToBlock(type, value); 
- return new RunLengthEncodedBlock(block, MAX_BATCH_SIZE); + private final Block shardUuidBlock; + + public ShardUuidAdaptation(UUID shardUuid) + { + Slice slice = utf8Slice(shardUuid.toString()); + BlockBuilder blockBuilder = SHARD_UUID_COLUMN_TYPE.createBlockBuilder(null, 1, slice.length()); + SHARD_UUID_COLUMN_TYPE.writeSlice(blockBuilder, slice); + this.shardUuidBlock = blockBuilder.build(); + } + + @Override + public Block block(Page sourcePage, long filePosition) + { + return new RunLengthEncodedBlock(shardUuidBlock, sourcePage.getPositionCount()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .toString(); + } + } + + private static class RowIdColumn + implements ColumnAdaptation + { + @Override + public Block block(Page sourcePage, long filePosition) + { + int count = sourcePage.getPositionCount(); + BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(count); + for (int i = 0; i < count; i++) { + BIGINT.writeLong(builder, filePosition + i); + } + return builder.build(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .toString(); + } } - private final class OrcBlockLoader - implements LazyBlockLoader + private static class NullColumn + implements ColumnAdaptation { - private final int expectedBatchId = batchId; - private final int columnIndex; - private boolean loaded; + private final Type type; + private final Block nullBlock; - public OrcBlockLoader(int columnIndex) + public NullColumn(Type type) { - this.columnIndex = columnIndex; + this.type = requireNonNull(type, "type is null"); + this.nullBlock = type.createBlockBuilder(null, 1, 0) + .appendNull() + .build(); } @Override - public final void load(LazyBlock lazyBlock) + public Block block(Page sourcePage, long filePosition) { - checkState(!loaded, "Already loaded"); - checkState(batchId == expectedBatchId); + return new RunLengthEncodedBlock(nullBlock, sourcePage.getPositionCount()); + } - try { - Block block = 
recordReader.readBlock(columnIndex); - lazyBlock.setBlock(block); - } - catch (IOException e) { - throw new PrestoException(RAPTOR_ERROR, e); - } + @Override + public String toString() + { + return toStringHelper(this) + .add("type", type) + .toString(); + } + } + + private static class BucketNumberColumn + implements ColumnAdaptation + { + private final Block bucketNumberBlock; - loaded = true; + public BucketNumberColumn(int bucketNumber) + { + BlockBuilder blockBuilder = INTEGER.createFixedSizeBlockBuilder(1); + INTEGER.writeLong(blockBuilder, bucketNumber); + this.bucketNumberBlock = blockBuilder.build(); + } + + @Override + public Block block(Page sourcePage, long filePosition) + { + return new RunLengthEncodedBlock(bucketNumberBlock, sourcePage.getPositionCount()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .toString(); + } + } + + private static class SourceColumn + implements ColumnAdaptation + { + private final int index; + + public SourceColumn(int index) + { + checkArgument(index >= 0, "index is negative"); + this.index = index; + } + + @Override + public Block block(Page sourcePage, long filePosition) + { + return sourcePage.getBlock(index); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("index", index) + .toString(); } } } diff --git a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcStorageManager.java b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcStorageManager.java index b1ffd853e109..83cf9f72708f 100644 --- a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcStorageManager.java +++ b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/OrcStorageManager.java @@ -15,7 +15,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import 
com.google.common.collect.ImmutableSet; import io.airlift.json.JsonCodec; import io.airlift.slice.Slice; @@ -25,13 +24,16 @@ import io.airlift.units.Duration; import io.prestosql.memory.context.AggregatedMemoryContext; import io.prestosql.orc.FileOrcDataSource; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcDataSource; import io.prestosql.orc.OrcPredicate; import io.prestosql.orc.OrcReader; import io.prestosql.orc.OrcReaderOptions; import io.prestosql.orc.OrcRecordReader; import io.prestosql.orc.TupleDomainOrcPredicate; -import io.prestosql.orc.TupleDomainOrcPredicate.ColumnReference; +import io.prestosql.orc.TupleDomainOrcPredicate.TupleDomainOrcPredicateBuilder; +import io.prestosql.orc.metadata.ColumnMetadata; +import io.prestosql.orc.metadata.OrcColumnId; import io.prestosql.orc.metadata.OrcType; import io.prestosql.plugin.raptor.legacy.RaptorColumnHandle; import io.prestosql.plugin.raptor.legacy.RaptorConnectorId; @@ -43,6 +45,7 @@ import io.prestosql.plugin.raptor.legacy.metadata.ShardInfo; import io.prestosql.plugin.raptor.legacy.metadata.ShardRecorder; import io.prestosql.plugin.raptor.legacy.storage.OrcFileRewriter.OrcFileInfo; +import io.prestosql.plugin.raptor.legacy.storage.OrcPageSource.ColumnAdaptation; import io.prestosql.spi.NodeManager; import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; @@ -93,6 +96,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Maps.uniqueIndex; import static io.airlift.concurrent.MoreFutures.allAsList; import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; @@ -100,6 +104,7 @@ import static io.airlift.units.DataSize.Unit.PETABYTE; import static 
io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.prestosql.orc.OrcReader.INITIAL_BATCH_SIZE; +import static io.prestosql.orc.metadata.OrcColumnId.ROOT_COLUMN; import static io.prestosql.plugin.raptor.legacy.RaptorColumnHandle.isBucketNumberColumn; import static io.prestosql.plugin.raptor.legacy.RaptorColumnHandle.isHiddenColumn; import static io.prestosql.plugin.raptor.legacy.RaptorColumnHandle.isShardRowIdColumn; @@ -108,10 +113,6 @@ import static io.prestosql.plugin.raptor.legacy.RaptorErrorCode.RAPTOR_LOCAL_DISK_FULL; import static io.prestosql.plugin.raptor.legacy.RaptorErrorCode.RAPTOR_RECOVERY_ERROR; import static io.prestosql.plugin.raptor.legacy.RaptorErrorCode.RAPTOR_RECOVERY_TIMEOUT; -import static io.prestosql.plugin.raptor.legacy.storage.OrcPageSource.BUCKET_NUMBER_COLUMN; -import static io.prestosql.plugin.raptor.legacy.storage.OrcPageSource.NULL_COLUMN; -import static io.prestosql.plugin.raptor.legacy.storage.OrcPageSource.ROWID_COLUMN; -import static io.prestosql.plugin.raptor.legacy.storage.OrcPageSource.SHARD_UUID_COLUMN; import static io.prestosql.plugin.raptor.legacy.storage.ShardStats.computeColumnStats; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.BooleanType.BOOLEAN; @@ -244,36 +245,48 @@ public ConnectorPageSource getPageSource( try { OrcReader reader = new OrcReader(dataSource, orcReaderOptions); - Map indexMap = columnIdIndex(reader.getColumnNames()); - ImmutableMap.Builder includedColumns = ImmutableMap.builder(); - ImmutableList.Builder columnIndexes = ImmutableList.builder(); + Map indexMap = columnIdIndex(reader.getRootColumn().getNestedColumns()); + List fileReadColumn = new ArrayList<>(columnIds.size()); + List fileReadTypes = new ArrayList<>(columnIds.size()); + List columnAdaptations = new ArrayList<>(columnIds.size()); for (int i = 0; i < columnIds.size(); i++) { long columnId = columnIds.get(i); + if (isHiddenColumn(columnId)) { - 
columnIndexes.add(toSpecialIndex(columnId)); + columnAdaptations.add(specialColumnAdaptation(columnId, shardUuid, bucketNumber)); continue; } - Integer index = indexMap.get(columnId); - if (index == null) { - columnIndexes.add(NULL_COLUMN); + Type type = toOrcFileType(columnTypes.get(i), typeManager); + OrcColumn fileColumn = indexMap.get(columnId); + if (fileColumn == null) { + columnAdaptations.add(ColumnAdaptation.nullColumn(type)); } else { - columnIndexes.add(index); - includedColumns.put(index, toOrcFileType(columnTypes.get(i), typeManager)); + int sourceIndex = fileReadColumn.size(); + columnAdaptations.add(ColumnAdaptation.sourceColumn(sourceIndex)); + fileReadColumn.add(fileColumn); + fileReadTypes.add(type); } } OrcPredicate predicate = getPredicate(effectivePredicate, indexMap); - OrcRecordReader recordReader = reader.createRecordReader(includedColumns.build(), predicate, UTC, systemMemoryUsage, INITIAL_BATCH_SIZE); + OrcRecordReader recordReader = reader.createRecordReader( + fileReadColumn, + fileReadTypes, + predicate, + UTC, + systemMemoryUsage, + INITIAL_BATCH_SIZE, + OrcPageSource::handleException); Optional shardRewriter = Optional.empty(); if (transactionId.isPresent()) { shardRewriter = Optional.of(createShardRewriter(transactionId.getAsLong(), bucketNumber, shardUuid)); } - return new OrcPageSource(shardRewriter, recordReader, dataSource, columnIds, columnTypes, columnIndexes.build(), shardUuid, bucketNumber, systemMemoryUsage); + return new OrcPageSource(shardRewriter, recordReader, columnAdaptations, dataSource, systemMemoryUsage); } catch (IOException | RuntimeException e) { closeQuietly(dataSource); @@ -285,16 +298,16 @@ public ConnectorPageSource getPageSource( } } - private static int toSpecialIndex(long columnId) + private static ColumnAdaptation specialColumnAdaptation(long columnId, UUID shardUuid, OptionalInt bucketNumber) { if (isShardRowIdColumn(columnId)) { - return ROWID_COLUMN; + return ColumnAdaptation.rowIdColumn(); } if 
(isShardUuidColumn(columnId)) { - return SHARD_UUID_COLUMN; + return ColumnAdaptation.shardUuidColumn(shardUuid); } if (isBucketNumberColumn(columnId)) { - return BUCKET_NUMBER_COLUMN; + return ColumnAdaptation.bucketNumberColumn(bucketNumber); } throw new PrestoException(RAPTOR_ERROR, "Invalid column ID: " + columnId); } @@ -458,9 +471,9 @@ private List getColumnInfo(OrcReader reader) return getColumnInfoFromOrcColumnTypes(reader.getColumnNames(), reader.getFooter().getTypes()); } - private List getColumnInfoFromOrcColumnTypes(List orcColumnNames, List orcColumnTypes) + private List getColumnInfoFromOrcColumnTypes(List orcColumnNames, ColumnMetadata orcColumnTypes) { - Type rowType = getType(orcColumnTypes, 0); + Type rowType = getType(orcColumnTypes, ROOT_COLUMN); if (orcColumnNames.size() != rowType.getTypeParameters().size()) { throw new PrestoException(RAPTOR_ERROR, "Column names and types do not match"); } @@ -497,9 +510,9 @@ private List getColumnInfoFromOrcUserMetadata(OrcFileMetadata orcFil .collect(toList()); } - private Type getType(List types, int index) + private Type getType(ColumnMetadata types, OrcColumnId columnId) { - OrcType type = types.get(index); + OrcType type = types.get(columnId); switch (type.getOrcTypeKind()) { case BOOLEAN: return BOOLEAN; @@ -561,25 +574,21 @@ static Type toOrcFileType(Type raptorType, TypeManager typeManager) return raptorType; } - private static OrcPredicate getPredicate(TupleDomain effectivePredicate, Map indexMap) + private static OrcPredicate getPredicate(TupleDomain effectivePredicate, Map indexMap) { - ImmutableList.Builder> columns = ImmutableList.builder(); - for (RaptorColumnHandle column : effectivePredicate.getDomains().get().keySet()) { - Integer index = indexMap.get(column.getColumnId()); - if (index != null) { - columns.add(new ColumnReference<>(column, index, column.getColumnType())); + TupleDomainOrcPredicateBuilder predicateBuilder = TupleDomainOrcPredicate.builder(); + 
effectivePredicate.getDomains().get().forEach((columnHandle, value) -> { + OrcColumn fileColumn = indexMap.get(columnHandle.getColumnId()); + if (fileColumn != null) { + predicateBuilder.addColumn(fileColumn.getColumnId(), value); } - } - return new TupleDomainOrcPredicate<>(effectivePredicate, columns.build(), false); + }); + return predicateBuilder.build(); } - private static Map columnIdIndex(List columnNames) + private static Map columnIdIndex(List columns) { - ImmutableMap.Builder map = ImmutableMap.builder(); - for (int i = 0; i < columnNames.size(); i++) { - map.put(Long.valueOf(columnNames.get(i)), i); - } - return map.build(); + return uniqueIndex(columns, column -> Long.valueOf(column.getColumnName())); } private class OrcStoragePageSink diff --git a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/ShardStats.java b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/ShardStats.java index 3c176ddf67dc..1b095acab965 100644 --- a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/ShardStats.java +++ b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/ShardStats.java @@ -13,12 +13,14 @@ */ package io.prestosql.plugin.raptor.legacy.storage; -import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import io.prestosql.orc.OrcColumn; import io.prestosql.orc.OrcPredicate; import io.prestosql.orc.OrcReader; import io.prestosql.orc.OrcRecordReader; import io.prestosql.plugin.raptor.legacy.metadata.ColumnStats; +import io.prestosql.spi.Page; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; import io.prestosql.spi.type.BigintType; @@ -68,36 +70,44 @@ public static Optional computeColumnStats(OrcReader orcReader, long private static ColumnStats doComputeColumnStats(OrcReader orcReader, long columnId, Type type, TypeManager typeManager) throws IOException { - int 
columnIndex = columnIndex(orcReader.getColumnNames(), columnId); - OrcRecordReader reader = orcReader.createRecordReader(ImmutableMap.of(columnIndex, toOrcFileType(type, typeManager)), OrcPredicate.TRUE, UTC, newSimpleAggregatedMemoryContext(), INITIAL_BATCH_SIZE); + OrcColumn column = getColumn(orcReader.getRootColumn().getNestedColumns(), columnId); + Type columnType = toOrcFileType(type, typeManager); + OrcRecordReader reader = orcReader.createRecordReader( + ImmutableList.of(column), + ImmutableList.of(columnType), + OrcPredicate.TRUE, + UTC, + newSimpleAggregatedMemoryContext(), + INITIAL_BATCH_SIZE, + exception -> new PrestoException(RAPTOR_ERROR, "Error reading column: " + columnId, exception)); if (type.equals(BooleanType.BOOLEAN)) { - return indexBoolean(reader, columnIndex, columnId); + return indexBoolean(reader, columnId); } if (type.equals(BigintType.BIGINT) || type.equals(DateType.DATE) || type.equals(TimestampType.TIMESTAMP)) { - return indexLong(type, reader, columnIndex, columnId); + return indexLong(type, reader, columnId); } if (type.equals(DoubleType.DOUBLE)) { - return indexDouble(reader, columnIndex, columnId); + return indexDouble(reader, columnId); } if (type instanceof VarcharType) { - return indexString(type, reader, columnIndex, columnId); + return indexString(type, reader, columnId); } return null; } - private static int columnIndex(List columnNames, long columnId) + private static OrcColumn getColumn(List columnNames, long columnId) { - int index = columnNames.indexOf(String.valueOf(columnId)); - if (index == -1) { - throw new PrestoException(RAPTOR_ERROR, "Missing column ID: " + columnId); - } - return index; + String columnName = String.valueOf(columnId); + return columnNames.stream() + .filter(column -> column.getColumnName().equals(columnName)) + .findFirst() + .orElseThrow(() -> new PrestoException(RAPTOR_ERROR, "Missing column ID: " + columnId)); } - private static ColumnStats indexBoolean(OrcRecordReader reader, int columnIndex, 
long columnId) + private static ColumnStats indexBoolean(OrcRecordReader reader, long columnId) throws IOException { boolean minSet = false; @@ -106,13 +116,13 @@ private static ColumnStats indexBoolean(OrcRecordReader reader, int columnIndex, boolean max = false; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize <= 0) { + Page page = reader.nextPage(); + if (page == null) { break; } - Block block = reader.readBlock(columnIndex); + Block block = page.getBlock(0).getLoadedBlock(); - for (int i = 0; i < batchSize; i++) { + for (int i = 0; i < page.getPositionCount(); i++) { if (block.isNull(i)) { continue; } @@ -133,7 +143,7 @@ private static ColumnStats indexBoolean(OrcRecordReader reader, int columnIndex, maxSet ? max : null); } - private static ColumnStats indexLong(Type type, OrcRecordReader reader, int columnIndex, long columnId) + private static ColumnStats indexLong(Type type, OrcRecordReader reader, long columnId) throws IOException { boolean minSet = false; @@ -142,13 +152,13 @@ private static ColumnStats indexLong(Type type, OrcRecordReader reader, int colu long max = 0; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize <= 0) { + Page page = reader.nextPage(); + if (page == null) { break; } - Block block = reader.readBlock(columnIndex); + Block block = page.getBlock(0).getLoadedBlock(); - for (int i = 0; i < batchSize; i++) { + for (int i = 0; i < page.getPositionCount(); i++) { if (block.isNull(i)) { continue; } @@ -169,7 +179,7 @@ private static ColumnStats indexLong(Type type, OrcRecordReader reader, int colu maxSet ? 
max : null); } - private static ColumnStats indexDouble(OrcRecordReader reader, int columnIndex, long columnId) + private static ColumnStats indexDouble(OrcRecordReader reader, long columnId) throws IOException { boolean minSet = false; @@ -178,13 +188,13 @@ private static ColumnStats indexDouble(OrcRecordReader reader, int columnIndex, double max = 0; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize <= 0) { + Page page = reader.nextPage(); + if (page == null) { break; } - Block block = reader.readBlock(columnIndex); + Block block = page.getBlock(0).getLoadedBlock(); - for (int i = 0; i < batchSize; i++) { + for (int i = 0; i < page.getPositionCount(); i++) { if (block.isNull(i)) { continue; } @@ -218,7 +228,7 @@ private static ColumnStats indexDouble(OrcRecordReader reader, int columnIndex, maxSet ? max : null); } - private static ColumnStats indexString(Type type, OrcRecordReader reader, int columnIndex, long columnId) + private static ColumnStats indexString(Type type, OrcRecordReader reader, long columnId) throws IOException { boolean minSet = false; @@ -227,13 +237,13 @@ private static ColumnStats indexString(Type type, OrcRecordReader reader, int co Slice max = null; while (true) { - int batchSize = reader.nextBatch(); - if (batchSize <= 0) { + Page page = reader.nextPage(); + if (page == null) { break; } - Block block = reader.readBlock(columnIndex); + Block block = page.getBlock(0).getLoadedBlock(); - for (int i = 0; i < batchSize; i++) { + for (int i = 0; i < page.getPositionCount(); i++) { if (block.isNull(i)) { continue; } diff --git a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/StorageManagerConfig.java b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/StorageManagerConfig.java index 1df15bf6f608..78ac31d94e6f 100644 --- a/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/StorageManagerConfig.java +++ 
b/presto-raptor-legacy/src/main/java/io/prestosql/plugin/raptor/legacy/storage/StorageManagerConfig.java @@ -163,6 +163,21 @@ public StorageManagerConfig setOrcLazyReadSmallRanges(boolean orcLazyReadSmallRa return this; } + @Deprecated + public boolean isOrcNestedLazy() + { + return options.isNestedLazy(); + } + + // TODO remove config option once efficacy is proven + @Deprecated + @Config("storage.orc.nested-lazy") + public StorageManagerConfig setOrcNestedLazy(boolean nestedLazy) + { + options = options.withNestedLazy(nestedLazy); + return this; + } + @Min(1) public int getDeletionThreads() { diff --git a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/OrcTestingUtil.java b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/OrcTestingUtil.java index 9094d2c4c30b..fe80e48532eb 100644 --- a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/OrcTestingUtil.java +++ b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/OrcTestingUtil.java @@ -13,11 +13,9 @@ */ package io.prestosql.plugin.raptor.legacy.storage; -import com.google.common.collect.ImmutableMap; import com.google.common.primitives.UnsignedBytes; import io.airlift.units.DataSize; import io.prestosql.orc.FileOrcDataSource; -import io.prestosql.orc.OrcCorruptionException; import io.prestosql.orc.OrcDataSource; import io.prestosql.orc.OrcPredicate; import io.prestosql.orc.OrcReader; @@ -29,9 +27,7 @@ import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; @@ -42,7 +38,7 @@ final class OrcTestingUtil { private OrcTestingUtil() {} - public static final OrcReaderOptions READER_OPTIONS = new OrcReaderOptions() + private static final OrcReaderOptions 
READER_OPTIONS = new OrcReaderOptions() .withMaxReadBlockSize(new DataSize(1, MEGABYTE)) .withMaxMergeDistance(new DataSize(1, MEGABYTE)) .withMaxBufferSize(new DataSize(1, MEGABYTE)) @@ -63,31 +59,14 @@ public static OrcRecordReader createReader(OrcDataSource dataSource, List List columnNames = orcReader.getColumnNames(); assertEquals(columnNames.size(), columnIds.size()); - Map includedColumns = new HashMap<>(); - int ordinal = 0; - for (long columnId : columnIds) { - assertEquals(columnNames.get(ordinal), String.valueOf(columnId)); - includedColumns.put(ordinal, types.get(ordinal)); - ordinal++; - } - - return createRecordReader(orcReader, includedColumns); - } - - public static OrcRecordReader createReaderNoRows(OrcDataSource dataSource) - throws IOException - { - OrcReader orcReader = new OrcReader(dataSource, READER_OPTIONS); - - assertEquals(orcReader.getColumnNames().size(), 0); - - return createRecordReader(orcReader, ImmutableMap.of()); - } - - public static OrcRecordReader createRecordReader(OrcReader orcReader, Map includedColumns) - throws OrcCorruptionException - { - return orcReader.createRecordReader(includedColumns, OrcPredicate.TRUE, DateTimeZone.UTC, newSimpleAggregatedMemoryContext(), MAX_BATCH_SIZE); + return orcReader.createRecordReader( + orcReader.getRootColumn().getNestedColumns(), + types, + OrcPredicate.TRUE, + DateTimeZone.UTC, + newSimpleAggregatedMemoryContext(), + MAX_BATCH_SIZE, + RuntimeException::new); } public static byte[] octets(int... 
values) diff --git a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcFileRewriter.java b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcFileRewriter.java index e57b79b5b584..bb983a8ac161 100644 --- a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcFileRewriter.java +++ b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcFileRewriter.java @@ -58,6 +58,7 @@ import static java.util.UUID.randomUUID; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; @Test(singleThreaded = true) @@ -113,9 +114,10 @@ public void testRewrite() assertEquals(reader.getFileRowCount(), 5); assertEquals(reader.getSplitLength(), file.length()); - assertEquals(reader.nextBatch(), 5); + Page page = reader.nextPage(); + assertEquals(page.getPositionCount(), 5); - Block column0 = reader.readBlock(0); + Block column0 = page.getBlock(0); assertEquals(column0.getPositionCount(), 5); for (int i = 0; i < 5; i++) { assertEquals(column0.isNull(i), false); @@ -126,7 +128,7 @@ public void testRewrite() assertEquals(BIGINT.getLong(column0, 3), 888L); assertEquals(BIGINT.getLong(column0, 4), 999L); - Block column1 = reader.readBlock(1); + Block column1 = page.getBlock(1); assertEquals(column1.getPositionCount(), 5); for (int i = 0; i < 5; i++) { assertEquals(column1.isNull(i), false); @@ -137,7 +139,7 @@ public void testRewrite() assertEquals(createVarcharType(20).getSlice(column1, 3), utf8Slice("world")); assertEquals(createVarcharType(20).getSlice(column1, 4), utf8Slice("done")); - Block column2 = reader.readBlock(2); + Block column2 = page.getBlock(2); assertEquals(column2.getPositionCount(), 5); for (int i = 0; i < 5; i++) { assertEquals(column2.isNull(i), false); @@ -148,7 +150,7 @@ public void testRewrite() 
assertTrue(arrayBlocksEqual(BIGINT, arrayType.getObject(column2, 3), arrayBlockOf(BIGINT, 7, 8))); assertTrue(arrayBlocksEqual(BIGINT, arrayType.getObject(column2, 4), arrayBlockOf(BIGINT, 9, 10))); - Block column3 = reader.readBlock(3); + Block column3 = page.getBlock(3); assertEquals(column3.getPositionCount(), 5); for (int i = 0; i < 5; i++) { assertEquals(column3.isNull(i), false); @@ -159,7 +161,7 @@ public void testRewrite() assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column3, 3), mapBlockOf(createVarcharType(5), BOOLEAN, "k4", true))); assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column3, 4), mapBlockOf(createVarcharType(5), BOOLEAN, "k5", true))); - Block column4 = reader.readBlock(4); + Block column4 = page.getBlock(4); assertEquals(column4.getPositionCount(), 5); for (int i = 0; i < 5; i++) { assertEquals(column4.isNull(i), false); @@ -170,7 +172,7 @@ public void testRewrite() assertTrue(arrayBlocksEqual(arrayType, arrayOfArrayType.getObject(column4, 3), arrayBlockOf(arrayType, null, arrayBlockOf(BIGINT, 8), null))); assertTrue(arrayBlocksEqual(arrayType, arrayOfArrayType.getObject(column4, 4), arrayBlockOf(arrayType, arrayBlockOf(BIGINT, 9, 10)))); - assertEquals(reader.nextBatch(), -1); + assertNull(reader.nextPage()); OrcFileMetadata orcFileMetadata = METADATA_CODEC.fromJson(reader.getUserMetadata().get(OrcFileMetadata.KEY).getBytes()); assertEquals(orcFileMetadata, new OrcFileMetadata(ImmutableMap.builder() @@ -200,9 +202,10 @@ public void testRewrite() assertEquals(reader.getFileRowCount(), 2); assertEquals(reader.getSplitLength(), newFile.length()); - assertEquals(reader.nextBatch(), 2); + Page page = reader.nextPage(); + assertEquals(page.getPositionCount(), 2); - Block column0 = reader.readBlock(0); + Block column0 = page.getBlock(0); assertEquals(column0.getPositionCount(), 2); for (int i = 0; i < 2; i++) { assertEquals(column0.isNull(i), false); @@ -210,7 +213,7 @@ public void 
testRewrite() assertEquals(BIGINT.getLong(column0, 0), 123L); assertEquals(BIGINT.getLong(column0, 1), 456L); - Block column1 = reader.readBlock(1); + Block column1 = page.getBlock(1); assertEquals(column1.getPositionCount(), 2); for (int i = 0; i < 2; i++) { assertEquals(column1.isNull(i), false); @@ -218,7 +221,7 @@ public void testRewrite() assertEquals(createVarcharType(20).getSlice(column1, 0), utf8Slice("hello")); assertEquals(createVarcharType(20).getSlice(column1, 1), utf8Slice("bye")); - Block column2 = reader.readBlock(2); + Block column2 = page.getBlock(2); assertEquals(column2.getPositionCount(), 2); for (int i = 0; i < 2; i++) { assertEquals(column2.isNull(i), false); @@ -226,7 +229,7 @@ public void testRewrite() assertTrue(arrayBlocksEqual(BIGINT, arrayType.getObject(column2, 0), arrayBlockOf(BIGINT, 1, 2))); assertTrue(arrayBlocksEqual(BIGINT, arrayType.getObject(column2, 1), arrayBlockOf(BIGINT, 5, 6))); - Block column3 = reader.readBlock(3); + Block column3 = page.getBlock(3); assertEquals(column3.getPositionCount(), 2); for (int i = 0; i < 2; i++) { assertEquals(column3.isNull(i), false); @@ -234,7 +237,7 @@ public void testRewrite() assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column3, 0), mapBlockOf(createVarcharType(5), BOOLEAN, "k1", true))); assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column3, 1), mapBlockOf(createVarcharType(5), BOOLEAN, "k3", true))); - Block column4 = reader.readBlock(4); + Block column4 = page.getBlock(4); assertEquals(column4.getPositionCount(), 2); for (int i = 0; i < 2; i++) { assertEquals(column4.isNull(i), false); @@ -242,7 +245,7 @@ public void testRewrite() assertTrue(arrayBlocksEqual(arrayType, arrayOfArrayType.getObject(column4, 0), arrayBlockOf(arrayType, arrayBlockOf(BIGINT, 5)))); assertTrue(arrayBlocksEqual(arrayType, arrayOfArrayType.getObject(column4, 1), arrayBlockOf(arrayType, arrayBlockOf(BIGINT, 7)))); - assertEquals(reader.nextBatch(), 
-1); + assertEquals(reader.nextPage(), null); OrcFileMetadata orcFileMetadata = METADATA_CODEC.fromJson(reader.getUserMetadata().get(OrcFileMetadata.KEY).getBytes()); assertEquals(orcFileMetadata, new OrcFileMetadata(ImmutableMap.builder() @@ -279,9 +282,10 @@ public void testRewriteWithoutMetadata() assertEquals(reader.getFileRowCount(), 2); assertEquals(reader.getSplitLength(), file.length()); - assertEquals(reader.nextBatch(), 2); + Page page = reader.nextPage(); + assertEquals(page.getPositionCount(), 2); - Block column0 = reader.readBlock(0); + Block column0 = page.getBlock(0); assertEquals(column0.getPositionCount(), 2); for (int i = 0; i < 2; i++) { assertEquals(column0.isNull(i), false); @@ -289,7 +293,7 @@ public void testRewriteWithoutMetadata() assertEquals(BIGINT.getLong(column0, 0), 123L); assertEquals(BIGINT.getLong(column0, 1), 777L); - Block column1 = reader.readBlock(1); + Block column1 = page.getBlock(1); assertEquals(column1.getPositionCount(), 2); for (int i = 0; i < 2; i++) { assertEquals(column1.isNull(i), false); @@ -315,14 +319,15 @@ public void testRewriteWithoutMetadata() assertEquals(reader.getFileRowCount(), 1); assertEquals(reader.getSplitLength(), newFile.length()); - assertEquals(reader.nextBatch(), 1); + Page page = reader.nextPage(); + assertEquals(page.getPositionCount(), 1); - Block column0 = reader.readBlock(0); + Block column0 = page.getBlock(0); assertEquals(column0.getPositionCount(), 1); assertEquals(column0.isNull(0), false); assertEquals(BIGINT.getLong(column0, 0), 123L); - Block column1 = reader.readBlock(1); + Block column1 = page.getBlock(1); assertEquals(column1.getPositionCount(), 1); assertEquals(column1.isNull(0), false); assertEquals(createVarcharType(20).getSlice(column1, 0), utf8Slice("hello")); diff --git a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcStorageManager.java b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcStorageManager.java 
index de016bfa28c1..fdbea7a9422e 100644 --- a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcStorageManager.java +++ b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestOrcStorageManager.java @@ -104,6 +104,7 @@ import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; import static org.testng.FileAssert.assertDirectory; @@ -228,19 +229,20 @@ public void testWriter() try (OrcDataSource dataSource = manager.openShard(shardUuid, READER_OPTIONS)) { OrcRecordReader reader = createReader(dataSource, columnIds, columnTypes); - assertEquals(reader.nextBatch(), 2); + Page page = reader.nextPage(); + assertEquals(page.getPositionCount(), 2); - Block column0 = reader.readBlock(0); + Block column0 = page.getBlock(0); assertEquals(column0.isNull(0), false); assertEquals(column0.isNull(1), false); assertEquals(BIGINT.getLong(column0, 0), 123L); assertEquals(BIGINT.getLong(column0, 1), 456L); - Block column1 = reader.readBlock(1); + Block column1 = page.getBlock(1); assertEquals(createVarcharType(10).getSlice(column1, 0), utf8Slice("hello")); assertEquals(createVarcharType(10).getSlice(column1, 1), utf8Slice("bye")); - assertEquals(reader.nextBatch(), -1); + assertNull(reader.nextPage()); } } diff --git a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestShardWriter.java b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestShardWriter.java index 4ca2b847ca5e..9b123d97ffef 100644 --- a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestShardWriter.java +++ b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestShardWriter.java @@ -19,6 +19,7 @@ import io.prestosql.RowPagesBuilder; import 
io.prestosql.orc.OrcDataSource; import io.prestosql.orc.OrcRecordReader; +import io.prestosql.spi.Page; import io.prestosql.spi.block.Block; import io.prestosql.spi.classloader.ThreadContextClassLoader; import io.prestosql.spi.type.ArrayType; @@ -54,6 +55,7 @@ import static io.prestosql.tests.StructuralTestUtil.mapBlocksEqual; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; public class TestShardWriter @@ -108,28 +110,29 @@ public void testWriter() assertEquals(reader.getFileRowCount(), reader.getReaderRowCount()); assertEquals(reader.getFilePosition(), reader.getFilePosition()); - assertEquals(reader.nextBatch(), 3); + Page page = reader.nextPage(); + assertEquals(page.getPositionCount(), 3); assertEquals(reader.getReaderPosition(), 0); assertEquals(reader.getFilePosition(), reader.getFilePosition()); - Block column0 = reader.readBlock(0); + Block column0 = page.getBlock(0); assertEquals(column0.isNull(0), false); assertEquals(column0.isNull(1), true); assertEquals(column0.isNull(2), false); assertEquals(BIGINT.getLong(column0, 0), 123L); assertEquals(BIGINT.getLong(column0, 2), 456L); - Block column1 = reader.readBlock(1); + Block column1 = page.getBlock(1); assertEquals(createVarcharType(10).getSlice(column1, 0), utf8Slice("hello")); assertEquals(createVarcharType(10).getSlice(column1, 1), utf8Slice("world")); assertEquals(createVarcharType(10).getSlice(column1, 2), utf8Slice("bye \u2603")); - Block column2 = reader.readBlock(2); + Block column2 = page.getBlock(2); assertEquals(VARBINARY.getSlice(column2, 0), wrappedBuffer(bytes1)); assertEquals(column2.isNull(1), true); assertEquals(VARBINARY.getSlice(column2, 2), wrappedBuffer(bytes3)); - Block column3 = reader.readBlock(3); + Block column3 = page.getBlock(3); assertEquals(column3.isNull(0), false); assertEquals(column3.isNull(1), false); assertEquals(column3.isNull(2), false); @@ 
-137,21 +140,21 @@ public void testWriter() assertEquals(DOUBLE.getDouble(column3, 1), Double.POSITIVE_INFINITY); assertEquals(DOUBLE.getDouble(column3, 2), Double.NaN); - Block column4 = reader.readBlock(4); + Block column4 = page.getBlock(4); assertEquals(column4.isNull(0), false); assertEquals(column4.isNull(1), true); assertEquals(column4.isNull(2), false); assertEquals(BOOLEAN.getBoolean(column4, 0), true); assertEquals(BOOLEAN.getBoolean(column4, 2), false); - Block column5 = reader.readBlock(5); + Block column5 = page.getBlock(5); assertEquals(column5.getPositionCount(), 3); assertTrue(arrayBlocksEqual(BIGINT, arrayType.getObject(column5, 0), arrayBlockOf(BIGINT, 1, 2))); assertTrue(arrayBlocksEqual(BIGINT, arrayType.getObject(column5, 1), arrayBlockOf(BIGINT, 3, null))); assertTrue(arrayBlocksEqual(BIGINT, arrayType.getObject(column5, 2), arrayBlockOf(BIGINT))); - Block column6 = reader.readBlock(6); + Block column6 = page.getBlock(6); assertEquals(column6.getPositionCount(), 3); assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column6, 0), mapBlockOf(createVarcharType(5), BOOLEAN, "k1", true))); @@ -160,14 +163,14 @@ public void testWriter() assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, object, k2)); assertTrue(mapBlocksEqual(createVarcharType(5), BOOLEAN, arrayType.getObject(column6, 2), mapBlockOf(createVarcharType(5), BOOLEAN, "k3", false))); - Block column7 = reader.readBlock(7); + Block column7 = page.getBlock(7); assertEquals(column7.getPositionCount(), 3); assertTrue(arrayBlocksEqual(arrayType, arrayOfArrayType.getObject(column7, 0), arrayBlockOf(arrayType, arrayBlockOf(BIGINT, 5)))); assertTrue(arrayBlocksEqual(arrayType, arrayOfArrayType.getObject(column7, 1), arrayBlockOf(arrayType, null, arrayBlockOf(BIGINT, 6, 7)))); assertTrue(arrayBlocksEqual(arrayType, arrayOfArrayType.getObject(column7, 2), arrayBlockOf(arrayType, arrayBlockOf(BIGINT)))); - assertEquals(reader.nextBatch(), -1); + 
assertNull(reader.nextPage()); assertEquals(reader.getReaderPosition(), 3); assertEquals(reader.getFilePosition(), reader.getFilePosition()); diff --git a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestStorageManagerConfig.java b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestStorageManagerConfig.java index 8e8097b00ccb..f218eb2625d6 100644 --- a/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestStorageManagerConfig.java +++ b/presto-raptor-legacy/src/test/java/io/prestosql/plugin/raptor/legacy/storage/TestStorageManagerConfig.java @@ -52,6 +52,7 @@ public void testDefaults() .setOrcStreamBufferSize(new DataSize(8, MEGABYTE)) .setOrcTinyStripeThreshold(new DataSize(8, MEGABYTE)) .setOrcLazyReadSmallRanges(true) + .setOrcNestedLazy(true) .setDeletionThreads(max(1, getRuntime().availableProcessors() / 2)) .setShardRecoveryTimeout(new Duration(30, SECONDS)) .setMissingShardDiscoveryInterval(new Duration(5, MINUTES)) @@ -81,6 +82,7 @@ public void testExplicitPropertyMappings() .put("storage.orc.stream-buffer-size", "16kB") .put("storage.orc.tiny-stripe-threshold", "15kB") .put("storage.orc.lazy-read-small-ranges", "false") + .put("storage.orc.nested-lazy", "false") .put("storage.max-deletion-threads", "999") .put("storage.shard-recovery-timeout", "1m") .put("storage.missing-shard-discovery-interval", "4m") @@ -107,6 +109,7 @@ public void testExplicitPropertyMappings() .setOrcStreamBufferSize(new DataSize(16, KILOBYTE)) .setOrcTinyStripeThreshold(new DataSize(15, KILOBYTE)) .setOrcLazyReadSmallRanges(false) + .setOrcNestedLazy(false) .setDeletionThreads(999) .setShardRecoveryTimeout(new Duration(1, MINUTES)) .setMissingShardDiscoveryInterval(new Duration(4, MINUTES))