From 71306218bb55be80535d67c41ffa5672ae3d9ad5 Mon Sep 17 00:00:00 2001
From: Pratham Desai
Date: Thu, 16 Apr 2020 01:58:15 -0700
Subject: [PATCH 1/3] Add test for dereferences on data containing null rows

---
 .../hive/TestHiveIntegrationSmokeTest.java    | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java
index b48f0809b0c4..9374b9411888 100644
--- a/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/TestHiveIntegrationSmokeTest.java
@@ -3017,6 +3017,50 @@ private void testRows(Session session, HiveStorageFormat format)
         assertUpdate(session, "DROP TABLE " + tableName);
     }
 
+    @Test
+    public void testRowsWithNulls()
+    {
+        testRowsWithNulls(getSession(), HiveStorageFormat.ORC);
+        testRowsWithNulls(getSession(), HiveStorageFormat.PARQUET);
+    }
+
+    private void testRowsWithNulls(Session session, HiveStorageFormat format)
+    {
+        String tableName = "test_dereferences_with_nulls";
+        @Language("SQL") String createTable = "" +
+                "CREATE TABLE " + tableName + "\n" +
+                "(col0 BIGINT, col1 row(f0 BIGINT, f1 BIGINT), col2 row(f0 BIGINT, f1 ROW(f0 BIGINT, f1 BIGINT)))\n" +
+                "WITH (format = '" + format + "')";
+
+        assertUpdate(session, createTable);
+
+        @Language("SQL") String insertTable = "" +
+                "INSERT INTO " + tableName + " VALUES \n" +
+                "row(1, row(2, 3), row(4, row(5, 6))),\n" +
+                "row(7, row(8, 9), row(10, row(11, NULL))),\n" +
+                "row(NULL, NULL, row(12, NULL)),\n" +
+                "row(13, row(NULL, 14), NULL),\n" +
+                "row(15, row(16, NULL), row(NULL, row(17, 18)))";
+
+        assertUpdate(session, insertTable, 5);
+
+        assertQuery(
+                session,
+                format("SELECT col0, col1.f0, col2.f1.f1 FROM %s", tableName),
+                "SELECT * FROM \n" +
+                        "  (SELECT 1, 2, 6) UNION\n" +
+                        "  (SELECT 7, 8, NULL) UNION\n" +
+                        "  (SELECT NULL, NULL, NULL) UNION\n" +
+                        "  (SELECT 13, NULL, NULL) UNION\n" +
+                        "  (SELECT 15, 16, 18)");
+
+        assertQuery(session, format("SELECT col0 FROM %s WHERE col2.f1.f1 IS NOT NULL", tableName), "SELECT * FROM UNNEST(array[1, 15])");
+
+        assertQuery(session, format("SELECT col0, col1.f0, col1.f1 FROM %s WHERE col2.f1.f1 = 18", tableName), "SELECT 15, 16, NULL");
+
+        assertUpdate(session, "DROP TABLE " + tableName);
+    }
+
     @Test
     public void testComplex()
     {

From cfc6df17179c17e1877fe886c0bee82f34a74b82 Mon Sep 17 00:00:00 2001
From: Pratham Desai
Date: Thu, 16 Apr 2020 01:59:05 -0700
Subject: [PATCH 2/3] Pushdown dereference projections in ORC reader

---
 .../plugin/hive/orc/OrcPageSourceFactory.java | 46 ++++++++++-
 .../main/java/io/prestosql/orc/OrcReader.java | 80 +++++++++++++++++++
 .../io/prestosql/orc/OrcRecordReader.java     |  9 ++-
 .../prestosql/orc/reader/ColumnReaders.java   | 10 ++-
 .../orc/reader/ListColumnReader.java          |  3 +-
 .../prestosql/orc/reader/MapColumnReader.java |  5 +-
 .../orc/reader/StructColumnReader.java        |  9 ++-
 7 files changed, 152 insertions(+), 10 deletions(-)

diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
index 601f54bee51e..033fcbec1dc8 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
@@ -30,6 +30,7 @@
 import io.prestosql.plugin.hive.FileFormatDataSourceStats;
 import io.prestosql.plugin.hive.HdfsEnvironment;
 import io.prestosql.plugin.hive.HiveColumnHandle;
+import io.prestosql.plugin.hive.HiveColumnProjectionInfo;
 import io.prestosql.plugin.hive.HivePageSourceFactory;
 import io.prestosql.plugin.hive.ReaderProjections;
 import io.prestosql.plugin.hive.orc.OrcPageSource.ColumnAdaptation;
@@ -58,12 +59,15 @@
 import java.util.Optional;
 import java.util.Properties;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Strings.nullToEmpty;
 import static com.google.common.collect.Maps.uniqueIndex;
 import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
 import static io.prestosql.orc.OrcReader.INITIAL_BATCH_SIZE;
+import static io.prestosql.orc.OrcReader.ProjectedLayout.createProjectedLayout;
+import static io.prestosql.orc.OrcReader.ProjectedLayout.fullyProjectedLayout;
 import static io.prestosql.orc.metadata.OrcType.OrcTypeKind.INT;
 import static io.prestosql.orc.metadata.OrcType.OrcTypeKind.LONG;
 import static io.prestosql.orc.metadata.OrcType.OrcTypeKind.STRUCT;
@@ -90,6 +94,8 @@
 import static java.lang.String.format;
 import static java.util.Locale.ENGLISH;
 import static java.util.Objects.requireNonNull;
+import static java.util.stream.Collectors.mapping;
+import static java.util.stream.Collectors.toList;
 import static org.apache.hadoop.hive.ql.io.AcidUtils.isFullAcidTable;
 
 public class OrcPageSourceFactory
@@ -162,6 +168,7 @@ public Optional createPageSource(
                 projectedReaderColumns
                         .map(ReaderProjections::getReaderColumns)
                         .orElse(columns),
+                columns,
                 isUseOrcColumnNames(session),
                 isFullAcidTable(Maps.fromProperties(schema)),
                 effectivePredicate,
@@ -190,6 +197,7 @@ private static OrcPageSource createOrcPageSource(
             long length,
             long fileSize,
             List columns,
+            List projections,
             boolean useOrcColumnNames,
             boolean isFullAcid,
             TupleDomain effectivePredicate,
@@ -229,6 +237,7 @@ private static OrcPageSource createOrcPageSource(
         List fileColumns = reader.getRootColumn().getNestedColumns();
         List fileReadColumns = new ArrayList<>(columns.size() + (isFullAcid ? 3 : 0));
         List fileReadTypes = new ArrayList<>(columns.size() + (isFullAcid ? 3 : 0));
+        List fileReadLayouts = new ArrayList<>(columns.size() + (isFullAcid ? 3 : 0));
         if (isFullAcid) {
             verifyAcidSchema(reader, path);
             Map acidColumnsByName = uniqueIndex(fileColumns, orcColumn -> orcColumn.getColumnName().toLowerCase(ENGLISH));
@@ -236,10 +245,15 @@ private static OrcPageSource createOrcPageSource(
             fileReadColumns.add(acidColumnsByName.get(ACID_COLUMN_ORIGINAL_TRANSACTION.toLowerCase(ENGLISH)));
             fileReadTypes.add(BIGINT);
+            fileReadLayouts.add(fullyProjectedLayout());
+
             fileReadColumns.add(acidColumnsByName.get(ACID_COLUMN_BUCKET.toLowerCase(ENGLISH)));
             fileReadTypes.add(INTEGER);
+            fileReadLayouts.add(fullyProjectedLayout());
+
             fileReadColumns.add(acidColumnsByName.get(ACID_COLUMN_ROW_ID.toLowerCase(ENGLISH)));
             fileReadTypes.add(BIGINT);
+            fileReadLayouts.add(fullyProjectedLayout());
         }
 
         Map fileColumnsByName = ImmutableMap.of();
@@ -250,6 +264,25 @@ private static OrcPageSource createOrcPageSource(
             fileColumnsByName = uniqueIndex(fileColumns, orcColumn -> orcColumn.getColumnName().toLowerCase(ENGLISH));
         }
 
+        Map>> projectionsByColumnName = ImmutableMap.of();
+        Map>> projectionsByColumnIndex = ImmutableMap.of();
+        if (useOrcColumnNames || isFullAcid) {
+            projectionsByColumnName = projections.stream()
+                    .collect(Collectors.groupingBy(
+                            HiveColumnHandle::getBaseColumnName,
+                            mapping(
+                                    column -> column.getHiveColumnProjectionInfo().map(HiveColumnProjectionInfo::getDereferenceNames).orElse(ImmutableList.of()),
+                                    toList())));
+        }
+        else {
+            projectionsByColumnIndex = projections.stream()
+                    .collect(Collectors.groupingBy(
+                            HiveColumnHandle::getBaseHiveColumnIndex,
+                            mapping(
+                                    column -> column.getHiveColumnProjectionInfo().map(HiveColumnProjectionInfo::getDereferenceNames).orElse(ImmutableList.of()),
+                                    toList())));
+        }
+
         TupleDomainOrcPredicateBuilder predicateBuilder = TupleDomainOrcPredicate.builder()
                 .setBloomFiltersEnabled(options.isBloomFiltersEnabled());
         Map effectivePredicateDomains = effectivePredicate.getDomains()
@@ -257,11 +290,20 @@ private static OrcPageSource createOrcPageSource(
         List columnAdaptations = new ArrayList<>(columns.size());
         for (HiveColumnHandle column : columns) {
             OrcColumn orcColumn = null;
+            OrcReader.ProjectedLayout projectedLayout = null;
+
             if (useOrcColumnNames || isFullAcid) {
-                orcColumn = fileColumnsByName.get(column.getName().toLowerCase(ENGLISH));
+                String columnName = column.getName().toLowerCase(ENGLISH);
+                orcColumn = fileColumnsByName.get(columnName);
+                if (orcColumn != null) {
+                    projectedLayout = createProjectedLayout(orcColumn, projectionsByColumnName.get(columnName));
+                }
             }
             else if (column.getBaseHiveColumnIndex() < fileColumns.size()) {
                 orcColumn = fileColumns.get(column.getBaseHiveColumnIndex());
+                if (orcColumn != null) {
+                    projectedLayout = createProjectedLayout(orcColumn, projectionsByColumnIndex.get(column.getBaseHiveColumnIndex()));
+                }
             }
 
             Type readType = column.getType();
@@ -270,6 +312,7 @@ else if (column.getBaseHiveColumnIndex() < fileColumns.size()) {
                 columnAdaptations.add(ColumnAdaptation.sourceColumn(sourceIndex));
                 fileReadColumns.add(orcColumn);
                 fileReadTypes.add(readType);
+                fileReadLayouts.add(projectedLayout);
 
                 Domain domain = effectivePredicateDomains.get(column);
                 if (domain != null) {
@@ -284,6 +327,7 @@ else if (column.getBaseHiveColumnIndex() < fileColumns.size()) {
         OrcRecordReader recordReader = reader.createRecordReader(
                 fileReadColumns,
                 fileReadTypes,
+                fileReadLayouts,
                 predicateBuilder.build(),
                 start,
                 length,

diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java b/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java
index b9dbaa93b4e9..963f0749e83c 100644
--- a/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java
+++ b/presto-orc/src/main/java/io/prestosql/orc/OrcReader.java
@@ -15,6 +15,7 @@
 import com.google.common.base.Joiner;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import io.airlift.log.Logger;
 import io.airlift.slice.Slice;
 import io.airlift.units.DataSize;
@@ -38,10 +39,13 @@
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.function.Function;
 import java.util.function.Predicate;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import static com.google.common.base.Throwables.throwIfUnchecked;
@@ -53,7 +57,10 @@
 import static io.prestosql.orc.metadata.PostScript.MAGIC;
 import static java.lang.Math.min;
 import static java.lang.Math.toIntExact;
+import static java.util.Locale.ENGLISH;
 import static java.util.Objects.requireNonNull;
+import static java.util.stream.Collectors.mapping;
+import static java.util.stream.Collectors.toList;
 
 public class OrcReader
 {
@@ -253,10 +260,37 @@ public OrcRecordReader createRecordReader(
             int initialBatchSize,
             Function exceptionTransform)
             throws OrcCorruptionException
+    {
+        return createRecordReader(
+                readColumns,
+                readTypes,
+                Collections.nCopies(readColumns.size(), ProjectedLayout.fullyProjectedLayout()),
+                predicate,
+                offset,
+                length,
+                hiveStorageTimeZone,
+                systemMemoryUsage,
+                initialBatchSize,
+                exceptionTransform);
+    }
+
+    public OrcRecordReader createRecordReader(
+            List readColumns,
+            List readTypes,
+            List readLayouts,
+            OrcPredicate predicate,
+            long offset,
+            long length,
+            DateTimeZone hiveStorageTimeZone,
+            AggregatedMemoryContext systemMemoryUsage,
+            int initialBatchSize,
+            Function exceptionTransform)
+            throws OrcCorruptionException
     {
         return new OrcRecordReader(
                 requireNonNull(readColumns, "readColumns is null"),
                 requireNonNull(readTypes, "readTypes is null"),
+                requireNonNull(readLayouts, "readLayouts is null"),
                 requireNonNull(predicate, "predicate is null"),
                 footer.getNumberOfRows(),
                 footer.getStripes(),
@@ -395,4 +429,50 @@ static void validateFile(
             throw new OrcCorruptionException(e, input.getId(), "Validation failed");
         }
     }
+
+    public static class ProjectedLayout
+    {
+        private final Optional> fieldLayouts;
+
+        private ProjectedLayout(Optional> fieldLayouts)
+        {
+            this.fieldLayouts = requireNonNull(fieldLayouts, "fieldLayouts is null");
+        }
+
+        public ProjectedLayout getFieldLayout(String name)
+        {
+            if (fieldLayouts.isPresent()) {
+                return fieldLayouts.get().get(name);
+            }
+
+            return fullyProjectedLayout();
+        }
+
+        public static ProjectedLayout fullyProjectedLayout()
+        {
+            return new ProjectedLayout(Optional.empty());
+        }
+
+        public static ProjectedLayout createProjectedLayout(OrcColumn root, List> dereferences)
+        {
+            if (dereferences.stream().map(List::size).anyMatch(Predicate.isEqual(0))) {
+                return fullyProjectedLayout();
+            }
+
+            Map>> dereferencesByField = dereferences.stream().collect(
+                    Collectors.groupingBy(
+                            sequence -> sequence.get(0),
+                            mapping(sequence -> sequence.subList(1, sequence.size()), toList())));
+
+            ImmutableMap.Builder fieldLayouts = ImmutableMap.builder();
+            for (OrcColumn nestedColumn : root.getNestedColumns()) {
+                String fieldName = nestedColumn.getColumnName().toLowerCase(ENGLISH);
+                if (dereferencesByField.containsKey(fieldName)) {
+                    fieldLayouts.put(fieldName, createProjectedLayout(nestedColumn, dereferencesByField.get(fieldName)));
+                }
+            }
+
+            return new ProjectedLayout(Optional.of(fieldLayouts.build()));
+        }
+    }
 }

diff --git a/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java b/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java
index 084486f573bf..7bfe116a7696 100644
--- a/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java
+++ b/presto-orc/src/main/java/io/prestosql/orc/OrcRecordReader.java
@@ -119,6 +119,7 @@ public class OrcRecordReader
     public OrcRecordReader(
             List readColumns,
             List readTypes,
+            List readLayouts,
             OrcPredicate predicate,
             long numberOfRows,
             List fileStripes,
@@ -145,6 +146,8 @@ public OrcRecordReader(
         checkArgument(readColumns.stream().distinct().count() == readColumns.size(), "readColumns contains duplicate entries");
         requireNonNull(readTypes, "readTypes is null");
         checkArgument(readColumns.size() == readTypes.size(), "readColumns and readTypes must have the same size");
+        requireNonNull(readLayouts, "readLayouts is null");
+        checkArgument(readColumns.size() == readLayouts.size(), "readColumns and readLayouts must have the same size");
         requireNonNull(predicate, "predicate is null");
         requireNonNull(fileStripes, "fileStripes is null");
         requireNonNull(stripeStats, "stripeStats is null");
@@ -233,7 +236,7 @@ public OrcRecordReader(
                 metadataReader,
                 writeValidation);
 
-        columnReaders = createColumnReaders(readColumns, readTypes, streamReadersSystemMemoryContext, blockFactory);
+        columnReaders = createColumnReaders(readColumns, readTypes, readLayouts, streamReadersSystemMemoryContext, blockFactory);
         currentBytesPerCell = new long[columnReaders.length];
         maxBytesPerCell = new long[columnReaders.length];
         nextBatchSize = initialBatchSize;
@@ -545,6 +548,7 @@ private void validateWritePageChecksum(Page page)
     private ColumnReader[] createColumnReaders(
             List columns,
             List readTypes,
+            List readLayouts,
             AggregatedMemoryContext systemMemoryContext,
             OrcBlockFactory blockFactory)
             throws OrcCorruptionException
@@ -554,7 +558,8 @@ private ColumnReader[] createColumnReaders(
             int columnIndex = i;
             Type readType = readTypes.get(columnIndex);
             OrcColumn column = columns.get(columnIndex);
-            columnReaders[columnIndex] = createColumnReader(readType, column, systemMemoryContext, blockFactory);
+            OrcReader.ProjectedLayout projectedLayout = readLayouts.get(columnIndex);
+            columnReaders[columnIndex] = createColumnReader(readType, column, projectedLayout, systemMemoryContext, blockFactory);
         }
         return columnReaders;
     }

diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReaders.java b/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReaders.java
index 1cc371e39678..80dd8665d6e0 100644
--- a/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReaders.java
+++ b/presto-orc/src/main/java/io/prestosql/orc/reader/ColumnReaders.java
@@ -17,13 +17,19 @@
 import io.prestosql.orc.OrcBlockFactory;
 import io.prestosql.orc.OrcColumn;
 import io.prestosql.orc.OrcCorruptionException;
+import io.prestosql.orc.OrcReader;
 import io.prestosql.spi.type.Type;
 
 public final class ColumnReaders
 {
     private ColumnReaders() {}
 
-    public static ColumnReader createColumnReader(Type type, OrcColumn column, AggregatedMemoryContext systemMemoryContext, OrcBlockFactory blockFactory)
+    public static ColumnReader createColumnReader(
+            Type type,
+            OrcColumn column,
+            OrcReader.ProjectedLayout projectedLayout,
+            AggregatedMemoryContext systemMemoryContext,
+            OrcBlockFactory blockFactory)
             throws OrcCorruptionException
     {
         switch (column.getColumnType()) {
@@ -50,7 +56,7 @@ public static ColumnReader createColumnReader(Type type, OrcColumn column, Aggre
             case LIST:
                 return new ListColumnReader(type, column, systemMemoryContext, blockFactory);
             case STRUCT:
-                return new StructColumnReader(type, column, systemMemoryContext, blockFactory);
+                return new StructColumnReader(type, column, projectedLayout, systemMemoryContext, blockFactory);
             case MAP:
                 return new MapColumnReader(type, column, systemMemoryContext, blockFactory);
             case DECIMAL:

diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/ListColumnReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/ListColumnReader.java
index 43190253bcd4..f4c69e95d531 100644
--- a/presto-orc/src/main/java/io/prestosql/orc/reader/ListColumnReader.java
+++ b/presto-orc/src/main/java/io/prestosql/orc/reader/ListColumnReader.java
@@ -38,6 +38,7 @@
 import java.util.Optional;
 
 import static com.google.common.base.MoreObjects.toStringHelper;
+import static io.prestosql.orc.OrcReader.ProjectedLayout.fullyProjectedLayout;
 import static io.prestosql.orc.metadata.Stream.StreamKind.LENGTH;
 import static io.prestosql.orc.metadata.Stream.StreamKind.PRESENT;
 import static io.prestosql.orc.reader.ColumnReaders.createColumnReader;
@@ -81,7 +82,7 @@ public ListColumnReader(Type type, OrcColumn column, AggregatedMemoryContext sys
         this.column = requireNonNull(column, "column is null");
         this.blockFactory = requireNonNull(blockFactory, "blockFactory is null");
 
-        this.elementColumnReader = createColumnReader(elementType, column.getNestedColumns().get(0), systemMemoryContext, blockFactory);
+        this.elementColumnReader = createColumnReader(elementType, column.getNestedColumns().get(0), fullyProjectedLayout(), systemMemoryContext, blockFactory);
     }
 
     @Override

diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/MapColumnReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/MapColumnReader.java
index 5c1896bc43a8..73ce9a7b5045 100644
--- a/presto-orc/src/main/java/io/prestosql/orc/reader/MapColumnReader.java
+++ b/presto-orc/src/main/java/io/prestosql/orc/reader/MapColumnReader.java
@@ -39,6 +39,7 @@
 import java.util.Optional;
 
 import static com.google.common.base.MoreObjects.toStringHelper;
+import static io.prestosql.orc.OrcReader.ProjectedLayout.fullyProjectedLayout;
 import static io.prestosql.orc.metadata.Stream.StreamKind.LENGTH;
 import static io.prestosql.orc.metadata.Stream.StreamKind.PRESENT;
 import static io.prestosql.orc.reader.ColumnReaders.createColumnReader;
@@ -85,8 +86,8 @@ public MapColumnReader(Type type, OrcColumn column, AggregatedMemoryContext syst
         this.column = requireNonNull(column, "column is null");
         this.blockFactory = requireNonNull(blockFactory, "blockFactory is null");
 
-        this.keyColumnReader = createColumnReader(this.type.getKeyType(), column.getNestedColumns().get(0), systemMemoryContext, blockFactory);
-        this.valueColumnReader = createColumnReader(this.type.getValueType(), column.getNestedColumns().get(1), systemMemoryContext, blockFactory);
+        this.keyColumnReader = createColumnReader(this.type.getKeyType(), column.getNestedColumns().get(0), fullyProjectedLayout(), systemMemoryContext, blockFactory);
+        this.valueColumnReader = createColumnReader(this.type.getValueType(), column.getNestedColumns().get(1), fullyProjectedLayout(), systemMemoryContext, blockFactory);
     }
 
     @Override

diff --git a/presto-orc/src/main/java/io/prestosql/orc/reader/StructColumnReader.java b/presto-orc/src/main/java/io/prestosql/orc/reader/StructColumnReader.java
index bf4f5784746c..688f1360cb5f 100644
--- a/presto-orc/src/main/java/io/prestosql/orc/reader/StructColumnReader.java
+++ b/presto-orc/src/main/java/io/prestosql/orc/reader/StructColumnReader.java
@@ -20,6 +20,7 @@
 import io.prestosql.orc.OrcBlockFactory;
 import io.prestosql.orc.OrcColumn;
 import io.prestosql.orc.OrcCorruptionException;
+import io.prestosql.orc.OrcReader;
 import io.prestosql.orc.metadata.ColumnEncoding;
 import io.prestosql.orc.metadata.ColumnMetadata;
 import io.prestosql.orc.stream.BooleanInputStream;
@@ -74,7 +75,7 @@ public class StructColumnReader
 
     private boolean rowGroupOpen;
 
-    StructColumnReader(Type type, OrcColumn column, AggregatedMemoryContext systemMemoryContext, OrcBlockFactory blockFactory)
+    StructColumnReader(Type type, OrcColumn column, OrcReader.ProjectedLayout readLayout, AggregatedMemoryContext systemMemoryContext, OrcBlockFactory blockFactory)
             throws OrcCorruptionException
     {
         requireNonNull(type, "type is null");
@@ -96,8 +97,12 @@ public class StructColumnReader
             fieldNames.add(fieldName);
             OrcColumn fieldStream = nestedColumns.get(fieldName);
+
             if (fieldStream != null) {
-                structFields.put(fieldName, createColumnReader(field.getType(), fieldStream, systemMemoryContext, blockFactory));
+                OrcReader.ProjectedLayout fieldLayout = readLayout.getFieldLayout(fieldName);
+                if (fieldLayout != null) {
+                    structFields.put(fieldName, createColumnReader(field.getType(), fieldStream, fieldLayout, systemMemoryContext, blockFactory));
+                }
             }
         }
         this.fieldNames = fieldNames.build();
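A standalone sketch, not part of these patches, may help make the ProjectedLayout contract concrete before the next commit builds on it. The Layout class below is only a stand-in for OrcReader.ProjectedLayout so the example compiles and runs on its own; what it mirrors from the code above is the recursive grouping of dereference name paths done by createProjectedLayout and the convention, relied on by StructColumnReader, that getFieldLayout returns null for a struct field no projection touches.

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

public class ProjectedLayoutSketch
{
    static final class Layout
    {
        // Optional.empty() plays the role of fullyProjectedLayout(): read the whole subtree
        private final Optional<Map<String, Layout>> fieldLayouts;

        Layout(Optional<Map<String, Layout>> fieldLayouts)
        {
            this.fieldLayouts = fieldLayouts;
        }

        Layout getFieldLayout(String name)
        {
            if (fieldLayouts.isPresent()) {
                return fieldLayouts.get().get(name); // null means the field is not projected at all
            }
            return fullyProjected();
        }

        static Layout fullyProjected()
        {
            return new Layout(Optional.empty());
        }
    }

    // Mirrors createProjectedLayout: group dereference paths by their first name, recurse on the suffixes
    static Layout createLayout(List<List<String>> dereferences)
    {
        if (dereferences.stream().anyMatch(List::isEmpty)) {
            // An empty path means the column itself is referenced, so everything must be read
            return Layout.fullyProjected();
        }
        Map<String, Layout> fields = dereferences.stream()
                .collect(Collectors.groupingBy(
                        path -> path.get(0),
                        Collectors.collectingAndThen(
                                Collectors.mapping(path -> path.subList(1, path.size()), Collectors.toList()),
                                ProjectedLayoutSketch::createLayout)));
        return new Layout(Optional.of(fields));
    }

    public static void main(String[] args)
    {
        // SELECT col2.f1.f1 FROM t: only the f1.f1 subtree of col2 needs column readers
        Layout col2 = createLayout(List.of(List.of("f1", "f1")));
        System.out.println(col2.getFieldLayout("f0") == null);                      // true: f0 is skipped
        System.out.println(col2.getFieldLayout("f1").getFieldLayout("f0") == null); // true: f1.f0 is skipped
        System.out.println(col2.getFieldLayout("f1").getFieldLayout("f1") != null); // true: f1.f1 gets a reader
    }
}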
From 83b3d5221f1a2f25fdb4d548affe0ad580b992d5 Mon Sep 17 00:00:00 2001
From: Pratham Desai
Date: Mon, 27 Jan 2020 14:27:16 -0800
Subject: [PATCH 3/3] Implement predicate pushdown for nested columns in ORC reader

---
 .../plugin/hive/orc/OrcPageSourceFactory.java |  38 ++-
 .../plugin/hive/orc/TestOrcPredicates.java    | 226 ++++++++++++++++++
 2 files changed, 260 insertions(+), 4 deletions(-)
 create mode 100644 presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcPredicates.java

diff --git a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
index 033fcbec1dc8..ad43b31a9c3d 100644
--- a/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
+++ b/presto-hive/src/main/java/io/prestosql/plugin/hive/orc/OrcPageSourceFactory.java
@@ -63,6 +63,7 @@
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Strings.nullToEmpty;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
 import static com.google.common.collect.Maps.uniqueIndex;
 import static io.prestosql.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
 import static io.prestosql.orc.OrcReader.INITIAL_BATCH_SIZE;
@@ -155,7 +156,6 @@ public Optional createPageSource(
         }
 
         Optional projectedReaderColumns = projectBaseColumns(columns);
-        effectivePredicate = effectivePredicate.transform(column -> column.isBaseColumn() ? column : null);
 
         ConnectorPageSource orcPageSource = createOrcPageSource(
                 hdfsEnvironment,
@@ -291,18 +291,25 @@ private static OrcPageSource createOrcPageSource(
         for (HiveColumnHandle column : columns) {
             OrcColumn orcColumn = null;
             OrcReader.ProjectedLayout projectedLayout = null;
+            Map, Domain> columnDomains = null;
 
             if (useOrcColumnNames || isFullAcid) {
                 String columnName = column.getName().toLowerCase(ENGLISH);
                 orcColumn = fileColumnsByName.get(columnName);
                 if (orcColumn != null) {
                     projectedLayout = createProjectedLayout(orcColumn, projectionsByColumnName.get(columnName));
+                    columnDomains = effectivePredicateDomains.entrySet().stream()
+                            .filter(columnDomain -> columnDomain.getKey().getBaseColumnName().toLowerCase(ENGLISH).equals(columnName))
+                            .collect(toImmutableMap(columnDomain -> columnDomain.getKey().getHiveColumnProjectionInfo(), Map.Entry::getValue));
                 }
             }
             else if (column.getBaseHiveColumnIndex() < fileColumns.size()) {
                 orcColumn = fileColumns.get(column.getBaseHiveColumnIndex());
                 if (orcColumn != null) {
                     projectedLayout = createProjectedLayout(orcColumn, projectionsByColumnIndex.get(column.getBaseHiveColumnIndex()));
+                    columnDomains = effectivePredicateDomains.entrySet().stream()
+                            .filter(columnDomain -> columnDomain.getKey().getBaseHiveColumnIndex() == column.getBaseHiveColumnIndex())
+                            .collect(toImmutableMap(columnDomain -> columnDomain.getKey().getHiveColumnProjectionInfo(), Map.Entry::getValue));
                 }
             }
 
@@ -314,9 +321,12 @@ else if (column.getBaseHiveColumnIndex() < fileColumns.size()) {
                 fileReadTypes.add(readType);
                 fileReadLayouts.add(projectedLayout);
 
-                Domain domain = effectivePredicateDomains.get(column);
-                if (domain != null) {
-                    predicateBuilder.addColumn(orcColumn.getColumnId(), domain);
+                // Add predicates on top-level and nested columns
+                for (Map.Entry, Domain> columnDomain : columnDomains.entrySet()) {
+                    OrcColumn nestedColumn = getNestedColumn(orcColumn, columnDomain.getKey());
+                    if (nestedColumn != null) {
+                        predicateBuilder.addColumn(nestedColumn.getColumnId(), columnDomain.getValue());
+                    }
                 }
             }
             else {
@@ -408,4 +418,24 @@ private static void verifyAcidColumn(OrcReader orcReader, int columnIndex, Strin
             throw new PrestoException(HIVE_BAD_DATA, format("ORC ACID file %s column should be type %s: %s", columnName, columnType, path));
         }
     }
+
+    private static OrcColumn getNestedColumn(OrcColumn baseColumn, Optional projectionInfo)
+    {
+        if (!projectionInfo.isPresent()) {
+            return baseColumn;
+        }
+
+        OrcColumn current = baseColumn;
+        for (String field : projectionInfo.get().getDereferenceNames()) {
+            Optional orcColumn = current.getNestedColumns().stream()
+                    .filter(column -> column.getColumnName().toLowerCase(ENGLISH).equals(field))
+                    .findFirst();
+
+            if (!orcColumn.isPresent()) {
+                return null;
+            }
+            current = orcColumn.get();
+        }
+        return current;
+    }
 }
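Before the new test file, a short standalone illustration, not part of the patch, of the walk that getNestedColumn performs above. Node stands in for OrcColumn and the column ids are invented; the point it demonstrates is that the Domain for a dereferenced column ends up registered under the nested column's id rather than the base struct's, which is what lets the ORC predicate consult that nested column's statistics.

import java.util.List;
import java.util.Locale;
import java.util.Map;

public class NestedColumnWalkSketch
{
    static final class Node
    {
        final int columnId;                 // stand-in for OrcColumnId
        final Map<String, Node> children;   // keyed by lower-case field name

        Node(int columnId, Map<String, Node> children)
        {
            this.columnId = columnId;
            this.children = children;
        }
    }

    // Mirrors getNestedColumn: follow the dereference names, give up (null) on a missing field
    static Node getNestedColumn(Node baseColumn, List<String> dereferenceNames)
    {
        Node current = baseColumn;
        for (String field : dereferenceNames) {
            current = current.children.get(field.toLowerCase(Locale.ENGLISH));
            if (current == null) {
                return null;
            }
        }
        return current;
    }

    public static void main(String[] args)
    {
        // col2 row(f0 bigint, f1 row(f0 bigint, f1 bigint)), with made-up ORC column ids
        Node f1 = new Node(4, Map.of("f0", new Node(5, Map.of()), "f1", new Node(6, Map.of())));
        Node col2 = new Node(2, Map.of("f0", new Node(3, Map.of()), "f1", f1));

        System.out.println(getNestedColumn(col2, List.of("f1", "f1")).columnId); // 6: the domain attaches here
        System.out.println(getNestedColumn(col2, List.of("missing")));           // null: the predicate is not pushed
    }
}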
diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcPredicates.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcPredicates.java
new file mode 100644
index 000000000000..efa08ad07a40
--- /dev/null
+++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/orc/TestOrcPredicates.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.prestosql.plugin.hive.orc;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import io.prestosql.orc.OrcReaderOptions;
+import io.prestosql.orc.OrcWriterOptions;
+import io.prestosql.plugin.hive.AbstractTestHiveFileFormats;
+import io.prestosql.plugin.hive.FileFormatDataSourceStats;
+import io.prestosql.plugin.hive.HiveColumnHandle;
+import io.prestosql.plugin.hive.HiveCompressionCodec;
+import io.prestosql.plugin.hive.HiveConfig;
+import io.prestosql.plugin.hive.HivePageSourceProvider;
+import io.prestosql.plugin.hive.HivePartitionKey;
+import io.prestosql.plugin.hive.NodeVersion;
+import io.prestosql.plugin.hive.TableToPartitionMapping;
+import io.prestosql.spi.Page;
+import io.prestosql.spi.connector.ConnectorPageSource;
+import io.prestosql.spi.connector.ConnectorSession;
+import io.prestosql.spi.predicate.Domain;
+import io.prestosql.spi.predicate.TupleDomain;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapred.FileSplit;
+import org.joda.time.DateTimeZone;
+import org.testng.annotations.Test;
+
+import java.io.File;
+import java.time.Instant;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.OptionalInt;
+import java.util.Properties;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static com.google.common.base.Preconditions.checkState;
+import static io.prestosql.plugin.hive.HiveStorageFormat.ORC;
+import static io.prestosql.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT;
+import static io.prestosql.plugin.hive.HiveTestUtils.TYPE_MANAGER;
+import static io.prestosql.plugin.hive.HiveTestUtils.getHiveSession;
+import static io.prestosql.plugin.hive.parquet.ParquetTester.HIVE_STORAGE_TIME_ZONE;
+import static io.prestosql.spi.type.BigintType.BIGINT;
+import static io.prestosql.testing.StructuralTestUtil.rowBlockOf;
+import static java.util.stream.Collectors.toList;
+import static org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.FILE_INPUT_FORMAT;
+import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_LIB;
+import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaIntObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaLongObjectInspector;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+public class TestOrcPredicates
+        extends AbstractTestHiveFileFormats
+{
+    private static final int NUM_ROWS = 50000;
+    private static final FileFormatDataSourceStats STATS = new FileFormatDataSourceStats();
+
+    // Prepare test columns
+    private static final TestColumn columnPrimitiveInteger = new TestColumn("column_primitive_integer", javaIntObjectInspector, 3, 3);
+    private static final TestColumn columnStruct = new TestColumn(
+            "column1_struct",
+            getStandardStructObjectInspector(ImmutableList.of("field0", "field1"), ImmutableList.of(javaLongObjectInspector, javaLongObjectInspector)),
+            new Long[] {4L, 5L},
+            rowBlockOf(ImmutableList.of(BIGINT, BIGINT), 4L, 5L));
+    private static final TestColumn columnPrimitiveBigInt = new TestColumn("column_primitive_bigint", javaLongObjectInspector, 6L, 6L);
+
+    @Test
+    public void testOrcPredicates()
+            throws Exception
+    {
+        testOrcPredicates(getHiveSession(new HiveConfig(), new OrcReaderConfig().setUseColumnNames(true)));
+        testOrcPredicates(getHiveSession(new HiveConfig(), new OrcReaderConfig()));
+    }
+
+    private void testOrcPredicates(ConnectorSession session)
+            throws Exception
+    {
+        List columnsToWrite = ImmutableList.of(columnPrimitiveInteger, columnStruct, columnPrimitiveBigInt);
+
+        File file = File.createTempFile("test", "orc_predicate");
+        file.delete();
+        try {
+            // Write data
+            OrcFileWriterFactory writerFactory = new OrcFileWriterFactory(HDFS_ENVIRONMENT, TYPE_MANAGER, new NodeVersion("test"), HIVE_STORAGE_TIME_ZONE, false, STATS, new OrcWriterOptions());
+            FileSplit split = createTestFile(file.getAbsolutePath(), ORC, HiveCompressionCodec.NONE, columnsToWrite, session, NUM_ROWS, writerFactory);
+
+            TupleDomain testingPredicate;
+
+            // Verify predicates on base column
+            List columnsToRead = columnsToWrite;
+            // All rows returned for a satisfying predicate
+            testingPredicate = TupleDomain.withColumnDomains(ImmutableMap.of(columnPrimitiveBigInt, Domain.singleValue(BIGINT, 6L)));
+            assertFilteredRows(testingPredicate, columnsToRead, session, split, NUM_ROWS);
+            // No rows returned for a mismatched predicate
+            testingPredicate = TupleDomain.withColumnDomains(ImmutableMap.of(columnPrimitiveBigInt, Domain.singleValue(BIGINT, 1L)));
+            assertFilteredRows(testingPredicate, columnsToRead, session, split, 0);
+
+            // Verify predicates on projected column
+            TestColumn projectedColumn = new TestColumn(
+                    columnStruct.getBaseName(),
+                    columnStruct.getBaseObjectInspector(),
+                    ImmutableList.of("field1"),
+                    ImmutableList.of(1),
+                    javaLongObjectInspector,
+                    5L,
+                    5L,
+                    false);
+
+            columnsToRead = ImmutableList.of(columnPrimitiveBigInt, projectedColumn);
+            // All rows returned for a satisfying predicate
+            testingPredicate = TupleDomain.withColumnDomains(ImmutableMap.of(projectedColumn, Domain.singleValue(BIGINT, 5L)));
+            assertFilteredRows(testingPredicate, columnsToRead, session, split, NUM_ROWS);
+            // No rows returned for a mismatched predicate
+            testingPredicate = TupleDomain.withColumnDomains(ImmutableMap.of(projectedColumn, Domain.singleValue(BIGINT, 6L)));
+            assertFilteredRows(testingPredicate, columnsToRead, session, split, 0);
+        }
+        finally {
+            file.delete();
+        }
+    }
+
+    private void assertFilteredRows(
+            TupleDomain effectivePredicate,
+            List columnsToRead,
+            ConnectorSession session,
+            FileSplit split,
+            int expectedRows)
+    {
+        ConnectorPageSource pageSource = createPageSource(effectivePredicate, columnsToRead, session, split);
+
+        int filteredRows = 0;
+        while (!pageSource.isFinished()) {
+            Page page = pageSource.getNextPage();
+            if (page != null) {
+                filteredRows += page.getPositionCount();
+            }
+        }
+
+        assertEquals(filteredRows, expectedRows);
+    }
+
+    private ConnectorPageSource createPageSource(
+            TupleDomain effectivePredicate,
+            List columnsToRead,
+            ConnectorSession session,
+            FileSplit split)
+    {
+        OrcPageSourceFactory readerFactory = new OrcPageSourceFactory(new OrcReaderOptions(), HDFS_ENVIRONMENT, STATS);
+
+        Properties splitProperties = new Properties();
+        splitProperties.setProperty(FILE_INPUT_FORMAT, ORC.getInputFormat());
+        splitProperties.setProperty(SERIALIZATION_LIB, ORC.getSerDe());
+
+        // Use full columns in split properties
+        ImmutableList.Builder splitPropertiesColumnNames = ImmutableList.builder();
+        ImmutableList.Builder splitPropertiesColumnTypes = ImmutableList.builder();
+        Set baseColumnNames = new HashSet<>();
+        for (TestColumn columnToRead : columnsToRead) {
+            String name = columnToRead.getBaseName();
+            if (!baseColumnNames.contains(name) && !columnToRead.isPartitionKey()) {
+                baseColumnNames.add(name);
+                splitPropertiesColumnNames.add(name);
+                splitPropertiesColumnTypes.add(columnToRead.getBaseObjectInspector().getTypeName());
+            }
+        }
+
+        splitProperties.setProperty("columns", splitPropertiesColumnNames.build().stream().collect(Collectors.joining(",")));
+        splitProperties.setProperty("columns.types", splitPropertiesColumnTypes.build().stream().collect(Collectors.joining(",")));
+
+        List partitionKeys = columnsToRead.stream()
+                .filter(TestColumn::isPartitionKey)
+                .map(input -> new HivePartitionKey(input.getName(), (String) input.getWriteValue()))
+                .collect(toList());
+
+        List columnHandles = getColumnHandles(columnsToRead);
+
+        TupleDomain predicate = effectivePredicate.transform(testColumn -> {
+            Optional handle = columnHandles.stream()
+                    .filter(column -> testColumn.getName().equals(column.getName()))
+                    .findFirst();
+
+            checkState(handle.isPresent(), "Predicate on invalid column");
+            return handle.get();
+        });
+
+        Optional pageSource = HivePageSourceProvider.createHivePageSource(
+                ImmutableSet.of(readerFactory),
+                ImmutableSet.of(),
+                new Configuration(false),
+                session,
+                split.getPath(),
+                OptionalInt.empty(),
+                split.getStart(),
+                split.getLength(),
+                split.getLength(),
+                Instant.now().toEpochMilli(),
+                splitProperties,
+                predicate,
+                columnHandles,
+                partitionKeys,
+                DateTimeZone.getDefault(),
+                TYPE_MANAGER,
+                TableToPartitionMapping.empty(),
+                Optional.empty(),
+                false,
+                Optional.empty());
+
+        assertTrue(pageSource.isPresent());
+        return pageSource.get();
+    }
+}
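To close the series, a standalone sketch, not part of these patches, of the effect TestOrcPredicates asserts: once a Domain is attached to a nested column's id, the ORC predicate can compare it against that column's min/max statistics and skip whole row groups, which is why the mismatched single-value predicate above returns none of the 50000 rows. The statistics class and numbers below are invented for the example; in the real code the comparison is carried out by TupleDomainOrcPredicate against ORC column statistics.

public class RowGroupPruningSketch
{
    static final class LongStatistics
    {
        final long min;
        final long max;

        LongStatistics(long min, long max)
        {
            this.min = min;
            this.max = max;
        }
    }

    // A single-value equality domain, like Domain.singleValue(BIGINT, 5L) in the test above,
    // can only match a row group whose [min, max] range contains the value
    static boolean rowGroupMayMatch(LongStatistics stats, long value)
    {
        return value >= stats.min && value <= stats.max;
    }

    public static void main(String[] args)
    {
        // Every row written by the test stores 5 in column1_struct.field1
        LongStatistics field1Stats = new LongStatistics(5, 5);
        System.out.println(rowGroupMayMatch(field1Stats, 5)); // true: the row group is read and all rows survive
        System.out.println(rowGroupMayMatch(field1Stats, 6)); // false: the row group is skipped and zero rows are returned
    }
}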