diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java index 41dd09150d83b..2f6f623e7ad94 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java @@ -15,11 +15,13 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.TypeSignature; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; import java.util.Objects; import java.util.Optional; @@ -30,6 +32,7 @@ import static com.facebook.presto.hive.HiveType.HIVE_STRING; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; public class HiveColumnHandle @@ -60,6 +63,7 @@ public enum ColumnType private final int hiveColumnIndex; private final ColumnType columnType; private final Optional comment; + private final Optional nestedField; @JsonCreator public HiveColumnHandle( @@ -68,7 +72,8 @@ public HiveColumnHandle( @JsonProperty("typeSignature") TypeSignature typeSignature, @JsonProperty("hiveColumnIndex") int hiveColumnIndex, @JsonProperty("columnType") ColumnType columnType, - @JsonProperty("comment") Optional comment) + @JsonProperty("comment") Optional comment, + @JsonProperty("nestedField") Optional nestedField) { this.name = requireNonNull(name, "name is null"); checkArgument(hiveColumnIndex >= 0 || columnType == PARTITION_KEY || columnType == SYNTHESIZED, "hiveColumnIndex is negative"); @@ -77,6 +82,7 @@ public HiveColumnHandle( this.typeName = requireNonNull(typeSignature, "type is null"); 
this.columnType = requireNonNull(columnType, "columnType is null"); this.comment = requireNonNull(comment, "comment is null"); + this.nestedField = requireNonNull(nestedField, "nestedField is null"); } @JsonProperty @@ -112,12 +118,23 @@ public ColumnMetadata getColumnMetadata(TypeManager typeManager) return new ColumnMetadata(name, typeManager.getType(typeName), null, isHidden()); } + public List getNameList() + { + return asList(name.split("\\.")); + } + @JsonProperty public Optional getComment() { return comment; } + @JsonProperty + public Optional getNestedField() + { + return nestedField; + } + @JsonProperty public TypeSignature getTypeSignature() { @@ -133,7 +150,7 @@ public ColumnType getColumnType() @Override public int hashCode() { - return Objects.hash(name, hiveColumnIndex, hiveType, columnType, comment); + return Objects.hash(name, hiveColumnIndex, hiveType, columnType, comment, nestedField); } @Override @@ -150,13 +167,14 @@ public boolean equals(Object obj) Objects.equals(this.hiveColumnIndex, other.hiveColumnIndex) && Objects.equals(this.hiveType, other.hiveType) && Objects.equals(this.columnType, other.columnType) && + Objects.equals(this.nestedField, other.nestedField) && Objects.equals(this.comment, other.comment); } @Override public String toString() { - return name + ":" + hiveType + ":" + hiveColumnIndex + ":" + columnType; + return name + ":" + hiveType + ":" + hiveColumnIndex + ":" + columnType + ":" + nestedField; } public static HiveColumnHandle updateRowIdHandle() @@ -167,12 +185,12 @@ public static HiveColumnHandle updateRowIdHandle() // plan-time support for row-by-row delete so that planning doesn't fail. This is why we need // rowid handle. Note that in Hive connector, rowid handle is not implemented beyond plan-time. 
- return new HiveColumnHandle(UPDATE_ROW_ID_COLUMN_NAME, HIVE_LONG, BIGINT.getTypeSignature(), -1, SYNTHESIZED, Optional.empty()); + return new HiveColumnHandle(UPDATE_ROW_ID_COLUMN_NAME, HIVE_LONG, BIGINT.getTypeSignature(), -1, SYNTHESIZED, Optional.empty(), Optional.empty()); } public static HiveColumnHandle pathColumnHandle() { - return new HiveColumnHandle(PATH_COLUMN_NAME, PATH_HIVE_TYPE, PATH_TYPE_SIGNATURE, PATH_COLUMN_INDEX, SYNTHESIZED, Optional.empty()); + return new HiveColumnHandle(PATH_COLUMN_NAME, PATH_HIVE_TYPE, PATH_TYPE_SIGNATURE, PATH_COLUMN_INDEX, SYNTHESIZED, Optional.empty(), Optional.empty()); } /** @@ -182,7 +200,7 @@ public static HiveColumnHandle pathColumnHandle() */ public static HiveColumnHandle bucketColumnHandle() { - return new HiveColumnHandle(BUCKET_COLUMN_NAME, BUCKET_HIVE_TYPE, BUCKET_TYPE_SIGNATURE, BUCKET_COLUMN_INDEX, SYNTHESIZED, Optional.empty()); + return new HiveColumnHandle(BUCKET_COLUMN_NAME, BUCKET_HIVE_TYPE, BUCKET_TYPE_SIGNATURE, BUCKET_COLUMN_INDEX, SYNTHESIZED, Optional.empty(), Optional.empty()); } public static boolean isPathColumnHandle(HiveColumnHandle column) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java index 432b135657bfb..a70e37c1e549d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java @@ -45,6 +45,7 @@ import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.DiscretePredicates; import com.facebook.presto.spi.InMemoryRecordSet; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.SchemaTableName; @@ -143,6 +144,7 @@ import static com.facebook.presto.hive.HiveSessionProperties.isSortedWritingEnabled; import static 
com.facebook.presto.hive.HiveSessionProperties.isStatisticsEnabled; import static com.facebook.presto.hive.HiveSessionProperties.isWritingStagingFilesEnabled; +import static com.facebook.presto.hive.HiveStorageFormat.PARQUET; import static com.facebook.presto.hive.HiveTableProperties.AVRO_SCHEMA_URL; import static com.facebook.presto.hive.HiveTableProperties.BUCKETED_BY_PROPERTY; import static com.facebook.presto.hive.HiveTableProperties.BUCKET_COUNT_PROPERTY; @@ -166,6 +168,7 @@ import static com.facebook.presto.hive.HiveUtil.decodeViewData; import static com.facebook.presto.hive.HiveUtil.encodeViewData; import static com.facebook.presto.hive.HiveUtil.getPartitionKeyColumnHandles; +import static com.facebook.presto.hive.HiveUtil.getRegularColumnHandles; import static com.facebook.presto.hive.HiveUtil.hiveColumnHandles; import static com.facebook.presto.hive.HiveUtil.schemaTableName; import static com.facebook.presto.hive.HiveUtil.toPartitionValues; @@ -544,6 +547,32 @@ public Map getColumnHandles(ConnectorSession session, Conn return columnHandles.build(); } + @Override + public Map getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection nestedFields) + { + SchemaTableName tableName = schemaTableName(tableHandle); + Optional table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()); + if (!table.isPresent()) { + throw new TableNotFoundException(tableName); + } + + if (extractHiveStorageFormat(table.get()).equals(PARQUET)) { + List regularColumnHandles = getRegularColumnHandles(table.get()); + Map regularHiveColumnHandles = regularColumnHandles.stream() + .collect(Collectors.toMap(HiveColumnHandle::getName, identity())); + ImmutableMap.Builder nestedColumnHandles = ImmutableMap.builder(); + for (NestedField field : nestedFields) { + HiveColumnHandle hiveColumnHandle = regularHiveColumnHandles.get(field.getBase()); + Optional type = hiveColumnHandle.getHiveType().getFieldType(field); + if (hiveColumnHandle 
!= null) { + nestedColumnHandles.put(field, new HiveColumnHandle(field.getName(), type.get(), type.get().getTypeSignature(), hiveColumnHandle.getHiveColumnIndex(), hiveColumnHandle.getColumnType(), hiveColumnHandle.getComment(), Optional.of(field))); + } + } + return nestedColumnHandles.build(); + } + return ImmutableMap.of(); + } + @SuppressWarnings("TryWithIdenticalCatches") @Override public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) @@ -2174,7 +2203,8 @@ else if (column.isHidden()) { column.getType().getTypeSignature(), ordinal, columnType, - Optional.ofNullable(column.getComment()))); + Optional.ofNullable(column.getComment()), + Optional.empty())); ordinal++; } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java index 930958ac74ee4..6a3dc3ce65828 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java @@ -319,8 +319,14 @@ public static List buildColumnMappings( for (HiveColumnHandle column : columns) { Optional coercionFrom = Optional.ofNullable(columnCoercions.get(column.getHiveColumnIndex())); if (column.getColumnType() == REGULAR) { - checkArgument(regularColumnIndices.add(column.getHiveColumnIndex()), "duplicate hiveColumnIndex in columns list"); - columnMappings.add(regular(column, regularIndex, coercionFrom)); + if (column.getNestedField().isPresent()) { + Optional hiveType = coercionFrom.flatMap(type -> type.getFieldType(column.getNestedField().get())); + columnMappings.add(regular(column, regularIndex, hiveType)); + } + else { + checkArgument(regularColumnIndices.add(column.getHiveColumnIndex()), "duplicate hiveColumnIndex in columns list"); + columnMappings.add(regular(column, regularIndex, coercionFrom)); + } regularIndex++; } else { @@ -365,7 +371,8 @@ public static List 
toColumnHandles(List regular columnMapping.getCoercionFrom().get().getTypeSignature(), columnHandle.getHiveColumnIndex(), columnHandle.getColumnType(), - Optional.empty()); + Optional.empty(), + columnHandle.getNestedField()); }) .collect(toList()); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveType.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveType.java index ea4fbc9cf99f0..67a9e5cf12686 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveType.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveType.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.type.NamedTypeSignature; import com.facebook.presto.spi.type.RowFieldName; @@ -156,6 +157,24 @@ public boolean isSupportedType() return isSupportedType(getTypeInfo()); } + public Optional getFieldType(NestedField nestedField) + { + TypeInfo typeInfo = getTypeInfo(); + for (String field : nestedField.getRemaining()) { + if (!(typeInfo instanceof StructTypeInfo)) { + throw new IllegalArgumentException("Invalid type: " + typeInfo + ". 
expecting RowType"); + } + StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + try { + typeInfo = structTypeInfo.getStructFieldTypeInfo(field); + } + catch (RuntimeException e) { + return Optional.empty(); + } + } + return Optional.of(toHiveType(typeInfo)); + } + public static boolean isSupportedType(TypeInfo typeInfo) { switch (typeInfo.getCategory()) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java index 5d970c7223a2e..d7fd03808a43b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java @@ -830,7 +830,7 @@ public static List getRegularColumnHandles(Table table) // ignore unsupported types rather than failing HiveType hiveType = field.getType(); if (hiveType.isSupportedType()) { - columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), hiveColumnIndex, REGULAR, field.getComment())); + columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), hiveColumnIndex, REGULAR, field.getComment(), Optional.empty())); } hiveColumnIndex++; } @@ -848,7 +848,7 @@ public static List getPartitionKeyColumnHandles(Table table) if (!hiveType.isSupportedType()) { throw new PrestoException(NOT_SUPPORTED, format("Unsupported Hive type %s found in partition keys of table %s.%s", hiveType, table.getDatabaseName(), table.getTableName())); } - columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), -1, PARTITION_KEY, field.getComment())); + columns.add(new HiveColumnHandle(field.getName(), hiveType, hiveType.getTypeSignature(), -1, PARTITION_KEY, field.getComment(), Optional.empty())); } return columns.build(); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java index 
c40ebcf00f44a..51d4097707a63 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactory.java @@ -263,7 +263,7 @@ private static List getPhysicalHiveColumnHandles(List columnNames; private final List types; private final List> fields; + private final List> nestedFields; private final Block[] constantBlocks; private final int[] hiveColumnIndexes; @@ -86,6 +93,7 @@ public ParquetPageSource( int size = columns.size(); this.constantBlocks = new Block[size]; this.hiveColumnIndexes = new int[size]; + this.nestedFields = new ArrayList<>(size); ImmutableList.Builder namesBuilder = ImmutableList.builder(); ImmutableList.Builder typesBuilder = ImmutableList.builder(); @@ -99,15 +107,22 @@ public ParquetPageSource( namesBuilder.add(name); typesBuilder.add(type); + nestedFields.add(column.getNestedField()); hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex(); - if (getParquetType(column, fileSchema, useParquetColumnNames) == null) { + if (getColumnType(column, fileSchema, useParquetColumnNames) == null) { constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_VECTOR_LENGTH); fieldsBuilder.add(Optional.empty()); } else { - String columnName = useParquetColumnNames ? name : fileSchema.getFields().get(column.getHiveColumnIndex()).getName(); - fieldsBuilder.add(constructField(type, lookupColumnByName(messageColumnIO, columnName))); + if (column.getNestedField().isPresent()) { + NestedField nestedField = column.getNestedField().get(); + fieldsBuilder.add(constructField(getNestedType(nestedField, type), lookupColumnByName(messageColumnIO, nestedField.getBase()))); + } + else { + String columnName = useParquetColumnNames ? 
name : fileSchema.getFields().get(column.getHiveColumnIndex()).getName(); + fieldsBuilder.add(constructField(type, lookupColumnByName(messageColumnIO, columnName))); + } } } types = typesBuilder.build(); @@ -157,19 +172,17 @@ public Page getNextPage() blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize); } else { - Type type = types.get(fieldId); Optional field = fields.get(fieldId); - int fieldIndex; - if (useParquetColumnNames) { - fieldIndex = getFieldIndex(fileSchema, columnNames.get(fieldId)); - } - else { - fieldIndex = hiveColumnIndexes[fieldId]; - } - if (fieldIndex != -1 && field.isPresent()) { - blocks[fieldId] = new LazyBlock(batchSize, new ParquetBlockLoader(field.get())); + if (field.isPresent()) { + if (nestedFields.get(fieldId).isPresent()) { + blocks[fieldId] = new LazyBlock(batchSize, new ParquetNestedBlockLoader(field.get(), types.get(fieldId))); + } + else { + blocks[fieldId] = new LazyBlock(batchSize, new ParquetBlockLoader(field.get())); + } } else { + Type type = types.get(fieldId); blocks[fieldId] = RunLengthEncodedBlock.create(type, null, batchSize); } } @@ -250,4 +263,90 @@ public final void load(LazyBlock lazyBlock) loaded = true; } } + + private final class ParquetNestedBlockLoader + implements LazyBlockLoader + { + private final int expectedBatchId = batchId; + private final Field field; + private final Type leafType; + private boolean loaded; + + public ParquetNestedBlockLoader(Field field, Type leafType) + { + this.field = requireNonNull(field, "field is null"); + this.leafType = requireNonNull(leafType, "leafType is null"); + } + + private int getDepth(Type rootType, Type leafType) + { + int depth = 0; + Type type = rootType; + while (!type.equals(leafType)) { + type = type.getTypeParameters().get(0); + ++depth; + } + return depth; + } + + @Override + public final void load(LazyBlock lazyBlock) + { + if (loaded) { + return; + } + + checkState(batchId == expectedBatchId); + + try { + Block block = 
parquetReader.readBlock(field); + + int size = block.getPositionCount(); + boolean[] isNulls = new boolean[size]; + + for (int level = 0; level < getDepth(field.getType(), leafType); ++level) { + ColumnarRow rowBlock = toColumnarRow(block); + int index = 0; + for (int j = 0; j < size; ++j) { + if (!isNulls[j]) { + isNulls[j] = rowBlock.isNull(index); + ++index; + } + } + block = rowBlock.getField(0); + } + + BlockBuilder blockBuilder = leafType.createBlockBuilder(null, size); + int position = 0; + for (int i = 0; i < size; ++i) { + if (isNulls[i]) { + blockBuilder.appendNull(); + } + else { + checkArgument(position < block.getPositionCount(), "current position cannot exceed total position count"); + leafType.appendTo(block, position, blockBuilder); + position++; + } + } + lazyBlock.setBlock(blockBuilder.build()); + } + catch (ParquetCorruptionException e) { + throw new PrestoException(HIVE_BAD_DATA, e); + } + catch (IOException e) { + throw new PrestoException(HIVE_CURSOR_ERROR, e); + } + loaded = true; + } + } + + private Type getNestedType(NestedField nestedField, Type leafType) + { + Type type = leafType; + List names = nestedField.getRemaining(); + for (int i = names.size() - 1; i >= 0; --i) { + type = RowType.from(ImmutableList.of(RowType.field(names.get(i), type))); + } + return type; + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java index b429ce4e4115b..f121ed311ca18 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java @@ -68,13 +68,13 @@ import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static com.facebook.presto.parquet.ParquetTypeUtils.getColumnIO; import static 
com.facebook.presto.parquet.ParquetTypeUtils.getDescriptors; +import static com.facebook.presto.parquet.ParquetTypeUtils.getFieldType; import static com.facebook.presto.parquet.ParquetTypeUtils.getParquetTypeByName; import static com.facebook.presto.parquet.predicate.PredicateUtils.buildPredicate; import static com.facebook.presto.parquet.predicate.PredicateUtils.predicateMatches; import static com.google.common.base.Strings.nullToEmpty; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.PRIMITIVE; public class ParquetPageSourceFactory @@ -158,13 +158,14 @@ public static ParquetPageSource createParquetPageSource( MessageType fileSchema = fileMetaData.getSchema(); dataSource = buildHdfsParquetDataSource(inputStream, path, fileSize, stats); - List fields = columns.stream() + Optional message = columns.stream() .filter(column -> column.getColumnType() == REGULAR) - .map(column -> getParquetType(column, fileSchema, useParquetColumnNames)) + .map(column -> getColumnType(column, fileSchema, useParquetColumnNames)) .filter(Objects::nonNull) - .collect(toList()); + .map(type -> new MessageType(fileSchema.getName(), type)) + .reduce(MessageType::union); - MessageType requestedSchema = new MessageType(fileSchema.getName(), fields); + MessageType requestedSchema = message.orElseGet(() -> new MessageType(fileSchema.getName(), ImmutableList.of())); ImmutableList.Builder footerBlocks = ImmutableList.builder(); for (BlockMetaData block : parquetMetadata.getBlocks()) { @@ -241,7 +242,7 @@ public static TupleDomain getParquetTupleDomain(Map partitionColumns = ImmutableList.of(dsColumn, fileFormatColumn, dummyColumn); List partitions = ImmutableList.builder() diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java 
b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java index 031a9f17aef7b..f713ee4d6c8c3 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileFormats.java @@ -486,7 +486,7 @@ protected List getColumnHandles(List testColumns) int columnIndex = testColumn.isPartitionKey() ? -1 : nextHiveColumnIndex++; HiveType hiveType = HiveType.valueOf(testColumn.getObjectInspector().getTypeName()); - columns.add(new HiveColumnHandle(testColumn.getName(), hiveType, hiveType.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty())); + columns.add(new HiveColumnHandle(testColumn.getName(), hiveType, hiveType.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty(), Optional.empty())); } return columns; } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java index add56250f110d..b49ad4d9fe813 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java @@ -94,7 +94,7 @@ public class TestBackgroundHiveSplitLoader private static final List PARTITION_COLUMNS = ImmutableList.of( new Column("partitionColumn", HIVE_INT, Optional.empty())); private static final List BUCKET_COLUMN_HANDLES = ImmutableList.of( - new HiveColumnHandle("col1", HIVE_INT, INTEGER.getTypeSignature(), 0, ColumnType.REGULAR, Optional.empty())); + new HiveColumnHandle("col1", HIVE_INT, INTEGER.getTypeSignature(), 0, ColumnType.REGULAR, Optional.empty(), Optional.empty())); private static final Optional BUCKET_PROPERTY = Optional.of( new HiveBucketProperty(ImmutableList.of("col1"), BUCKET_COUNT, ImmutableList.of())); 
diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveColumnHandle.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveColumnHandle.java index f9f001aeed720..2ce021719391a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveColumnHandle.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveColumnHandle.java @@ -38,14 +38,14 @@ public void testHiddenColumn() @Test public void testRegularColumn() { - HiveColumnHandle expectedPartitionColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty()); + HiveColumnHandle expectedPartitionColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty(), Optional.empty()); testRoundTrip(expectedPartitionColumn); } @Test public void testPartitionKeyColumn() { - HiveColumnHandle expectedRegularColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, REGULAR, Optional.empty()); + HiveColumnHandle expectedRegularColumn = new HiveColumnHandle("name", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, REGULAR, Optional.empty(), Optional.empty()); testRoundTrip(expectedRegularColumn); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadata.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadata.java index a35efdb715905..55a05928a392f 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadata.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadata.java @@ -34,6 +34,7 @@ public class TestHiveMetadata TypeSignature.parseTypeSignature("varchar"), 0, HiveColumnHandle.ColumnType.PARTITION_KEY, + Optional.empty(), Optional.empty()); @Test(timeOut = 5000) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java 
b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java index 7a45796fda6f6..2a94a2b116897 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSink.java @@ -282,7 +282,7 @@ private static List getColumnHandles() for (int i = 0; i < columns.size(); i++) { LineItemColumn column = columns.get(i); HiveType hiveType = getHiveType(column.getType()); - handles.add(new HiveColumnHandle(column.getColumnName(), hiveType, hiveType.getTypeSignature(), i, REGULAR, Optional.empty())); + handles.add(new HiveColumnHandle(column.getColumnName(), hiveType, hiveType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); } return handles.build(); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java index 26d0441f2e6ee..0bb390c087b45 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java @@ -61,7 +61,7 @@ public void testJsonRoundTrip() Optional.of(new HiveSplit.BucketConversion( 32, 16, - ImmutableList.of(new HiveColumnHandle("col", HIVE_LONG, BIGINT.getTypeSignature(), 5, ColumnType.REGULAR, Optional.of("comment"))))), + ImmutableList.of(new HiveColumnHandle("col", HIVE_LONG, BIGINT.getTypeSignature(), 5, ColumnType.REGULAR, Optional.of("comment"), Optional.empty())))), false); String json = codec.toJson(expected); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestIonSqlQueryBuilder.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestIonSqlQueryBuilder.java index b620e8c2c4cc1..51780c0c6b300 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestIonSqlQueryBuilder.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestIonSqlQueryBuilder.java @@ -56,9 +56,9 @@ public void testBuildSQL() { 
IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("n_nationkey", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("n_name", HIVE_STRING, parseTypeSignature(VARCHAR), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("n_regionkey", HIVE_INT, parseTypeSignature(INTEGER), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("n_nationkey", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("n_name", HIVE_STRING, parseTypeSignature(VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("n_regionkey", HIVE_INT, parseTypeSignature(INTEGER), 2, REGULAR, Optional.empty(), Optional.empty())); assertEquals("SELECT s._1, s._2, s._3 FROM S3Object s", queryBuilder.buildSql(columns, TupleDomain.all())); @@ -81,9 +81,9 @@ public void testDecimalColumns() TypeManager typeManager = new TypeRegistry(); IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(typeManager); List columns = ImmutableList.of( - new HiveColumnHandle("quantity", HiveType.valueOf("decimal(20,0)"), parseTypeSignature(DECIMAL), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("extendedprice", HiveType.valueOf("decimal(20,2)"), parseTypeSignature(DECIMAL), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("discount", HiveType.valueOf("decimal(10,2)"), parseTypeSignature(DECIMAL), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("quantity", HiveType.valueOf("decimal(20,0)"), parseTypeSignature(DECIMAL), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("extendedprice", HiveType.valueOf("decimal(20,2)"), parseTypeSignature(DECIMAL), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("discount", HiveType.valueOf("decimal(10,2)"), parseTypeSignature(DECIMAL), 2, REGULAR, Optional.empty(), Optional.empty())); DecimalType decimalType = 
DecimalType.createDecimalType(10, 2); TupleDomain tupleDomain = withColumnDomains( ImmutableMap.of( @@ -101,8 +101,8 @@ public void testDateColumn() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("t1", HIVE_TIMESTAMP, parseTypeSignature(TIMESTAMP), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("t2", HIVE_DATE, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty())); + new HiveColumnHandle("t1", HIVE_TIMESTAMP, parseTypeSignature(TIMESTAMP), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("t2", HIVE_DATE, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty(), Optional.empty())); TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of( columns.get(1), Domain.create(SortedRangeSet.copyOf(DATE, ImmutableList.of(Range.equal(DATE, (long) DateTimeUtils.parseDate("2001-08-22")))), false))); @@ -114,9 +114,9 @@ public void testNotPushDoublePredicates() { IonSqlQueryBuilder queryBuilder = new IonSqlQueryBuilder(new TypeRegistry()); List columns = ImmutableList.of( - new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty()), - new HiveColumnHandle("extendedprice", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 1, REGULAR, Optional.empty()), - new HiveColumnHandle("discount", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 2, REGULAR, Optional.empty())); + new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(INTEGER), 0, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("extendedprice", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 1, REGULAR, Optional.empty(), Optional.empty()), + new HiveColumnHandle("discount", HIVE_DOUBLE, parseTypeSignature(StandardTypes.DOUBLE), 2, REGULAR, Optional.empty(), Optional.empty())); TupleDomain tupleDomain = withColumnDomains( ImmutableMap.of( columns.get(0), 
Domain.create(ofRanges(Range.lessThan(BIGINT, 50L)), false), diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestJsonHiveHandles.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestJsonHiveHandles.java index 4922d7647e7fa..9ee5469c89b4a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestJsonHiveHandles.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestJsonHiveHandles.java @@ -77,7 +77,7 @@ public void testTableHandleDeserialize() public void testColumnHandleSerialize() throws Exception { - HiveColumnHandle columnHandle = new HiveColumnHandle("column", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), -1, PARTITION_KEY, Optional.of("comment")); + HiveColumnHandle columnHandle = new HiveColumnHandle("column", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), -1, PARTITION_KEY, Optional.of("comment"), Optional.empty()); assertTrue(objectMapper.canSerialize(HiveColumnHandle.class)); String json = objectMapper.writeValueAsString(columnHandle); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java index d712b2a097ab2..ed936ae341758 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcPageSourceMemoryTracking.java @@ -445,7 +445,7 @@ public TestPreparer(String tempFilePath, List testColumns, int numRo HiveType hiveType = HiveType.valueOf(inspector.getTypeName()); Type type = hiveType.getType(TYPE_MANAGER); - columnsBuilder.add(new HiveColumnHandle(testColumn.getName(), hiveType, type.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR, Optional.empty())); + columnsBuilder.add(new HiveColumnHandle(testColumn.getName(), hiveType, type.getTypeSignature(), columnIndex, testColumn.isPartitionKey() ? 
PARTITION_KEY : REGULAR, Optional.empty(), Optional.empty())); typesBuilder.add(type); } columns = columnsBuilder.build(); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestS3SelectRecordCursor.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestS3SelectRecordCursor.java index 2f76e68ccdccd..2a79b67006fe8 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestS3SelectRecordCursor.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestS3SelectRecordCursor.java @@ -48,12 +48,12 @@ public class TestS3SelectRecordCursor { private static final String LAZY_SERDE_CLASS_NAME = LazySimpleSerDe.class.getName(); - private static final HiveColumnHandle ARTICLE_COLUMN = new HiveColumnHandle("article", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle AUTHOR_COLUMN = new HiveColumnHandle("author", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle DATE_ARTICLE_COLUMN = new HiveColumnHandle("date_pub", HIVE_INT, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty()); - private static final HiveColumnHandle QUANTITY_COLUMN = new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), 1, REGULAR, Optional.empty()); + private static final HiveColumnHandle ARTICLE_COLUMN = new HiveColumnHandle("article", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle AUTHOR_COLUMN = new HiveColumnHandle("author", HIVE_STRING, parseTypeSignature(StandardTypes.VARCHAR), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle DATE_ARTICLE_COLUMN = new HiveColumnHandle("date_pub", HIVE_INT, parseTypeSignature(StandardTypes.DATE), 1, REGULAR, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle QUANTITY_COLUMN = 
new HiveColumnHandle("quantity", HIVE_INT, parseTypeSignature(StandardTypes.INTEGER), 1, REGULAR, Optional.empty(), Optional.empty()); private static final HiveColumnHandle[] DEFAULT_TEST_COLUMNS = {ARTICLE_COLUMN, AUTHOR_COLUMN, DATE_ARTICLE_COLUMN, QUANTITY_COLUMN}; - private static final HiveColumnHandle MOCK_HIVE_COLUMN_HANDLE = new HiveColumnHandle("mockName", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty()); + private static final HiveColumnHandle MOCK_HIVE_COLUMN_HANDLE = new HiveColumnHandle("mockName", HiveType.HIVE_FLOAT, parseTypeSignature(StandardTypes.DOUBLE), 88, PARTITION_KEY, Optional.empty(), Optional.empty()); private static final TypeManager MOCK_TYPE_MANAGER = new TestingTypeManager(); private static final Path MOCK_PATH = new Path("mockPath"); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java index 47dfe5aeeb53c..51802fa07e3dc 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java @@ -354,7 +354,7 @@ private static ConnectorPageSource createPageSource( for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); Type columnType = columnTypes.get(i); - columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty())); + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); } RecordCursor recordCursor = cursorProvider @@ -388,7 +388,7 @@ private static ConnectorPageSource createPageSource( for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); Type columnType = columnTypes.get(i); - columnHandles.add(new 
HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty())); + columnHandles.add(new HiveColumnHandle(columnName, toHiveType(typeTranslator, columnType), columnType.getTypeSignature(), i, REGULAR, Optional.empty(), Optional.empty())); } return pageSourceFactory diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/predicate/TestParquetPredicateUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/predicate/TestParquetPredicateUtils.java index 75ca5678529dd..b63b7c805ca24 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/predicate/TestParquetPredicateUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/predicate/TestParquetPredicateUtils.java @@ -56,7 +56,7 @@ public class TestParquetPredicateUtils @Test public void testParquetTupleDomainPrimitiveArray() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_array", HiveType.valueOf("array"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_array", HiveType.valueOf("array"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty(), Optional.empty()); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(new ArrayType(INTEGER)))); MessageType fileSchema = new MessageType("hive_schema", @@ -71,7 +71,7 @@ public void testParquetTupleDomainPrimitiveArray() @Test public void testParquetTupleDomainStructArray() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_array_struct", HiveType.valueOf("array>"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_array_struct", HiveType.valueOf("array>"), parseTypeSignature(StandardTypes.ARRAY), 0, REGULAR, Optional.empty(), Optional.empty()); RowType.Field rowField = new RowType.Field(Optional.of("a"), INTEGER); 
RowType rowType = RowType.from(ImmutableList.of(rowField)); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(new ArrayType(rowType)))); @@ -89,7 +89,7 @@ public void testParquetTupleDomainStructArray() @Test public void testParquetTupleDomainPrimitive() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_primitive", HiveType.valueOf("bigint"), parseTypeSignature(StandardTypes.BIGINT), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_primitive", HiveType.valueOf("bigint"), parseTypeSignature(StandardTypes.BIGINT), 0, REGULAR, Optional.empty(), Optional.empty()); Domain singleValueDomain = Domain.singleValue(BIGINT, 123L); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, singleValueDomain)); @@ -110,7 +110,7 @@ public void testParquetTupleDomainPrimitive() @Test public void testParquetTupleDomainStruct() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_struct", HiveType.valueOf("struct"), parseTypeSignature(StandardTypes.ROW), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_struct", HiveType.valueOf("struct"), parseTypeSignature(StandardTypes.ROW), 0, REGULAR, Optional.empty(), Optional.empty()); RowType.Field rowField = new RowType.Field(Optional.of("my_struct"), INTEGER); RowType rowType = RowType.from(ImmutableList.of(rowField)); TupleDomain domain = withColumnDomains(ImmutableMap.of(columnHandle, Domain.notNull(rowType))); @@ -127,7 +127,7 @@ public void testParquetTupleDomainStruct() @Test public void testParquetTupleDomainMap() { - HiveColumnHandle columnHandle = new HiveColumnHandle("my_map", HiveType.valueOf("map"), parseTypeSignature(StandardTypes.MAP), 0, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle("my_map", HiveType.valueOf("map"), parseTypeSignature(StandardTypes.MAP), 0, REGULAR, Optional.empty(), Optional.empty()); MapType mapType = new MapType( INTEGER, diff 
--git a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java index 219af9e77ce1b..a55246181b54e 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestMetastoreHiveStatisticsProvider.java @@ -94,8 +94,8 @@ public class TestMetastoreHiveStatisticsProvider private static final String COLUMN = "column"; private static final DecimalType DECIMAL = createDecimalType(5, 3); - private static final HiveColumnHandle PARTITION_COLUMN_1 = new HiveColumnHandle("p1", HIVE_STRING, VARCHAR.getTypeSignature(), 0, PARTITION_KEY, Optional.empty()); - private static final HiveColumnHandle PARTITION_COLUMN_2 = new HiveColumnHandle("p2", HIVE_LONG, BIGINT.getTypeSignature(), 1, PARTITION_KEY, Optional.empty()); + private static final HiveColumnHandle PARTITION_COLUMN_1 = new HiveColumnHandle("p1", HIVE_STRING, VARCHAR.getTypeSignature(), 0, PARTITION_KEY, Optional.empty(), Optional.empty()); + private static final HiveColumnHandle PARTITION_COLUMN_2 = new HiveColumnHandle("p2", HIVE_LONG, BIGINT.getTypeSignature(), 1, PARTITION_KEY, Optional.empty(), Optional.empty()); @Test public void testGetPartitionsSample() @@ -611,7 +611,7 @@ public void testGetTableStatistics() .build(); MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((table, hivePartitions) -> ImmutableMap.of(partitionName, statistics)); TestingConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(new HiveClientConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties()); - HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new 
HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty(), Optional.empty()); TableStatistics expected = TableStatistics.builder() .setRowCount(Estimate.of(1000)) .setColumnStatistics( @@ -661,7 +661,7 @@ public void testGetTableStatisticsUnpartitioned() .build(); MetastoreHiveStatisticsProvider statisticsProvider = new MetastoreHiveStatisticsProvider((table, hivePartitions) -> ImmutableMap.of(UNPARTITIONED_ID, statistics)); TestingConnectorSession session = new TestingConnectorSession(new HiveSessionProperties(new HiveClientConfig(), new OrcFileWriterConfig(), new ParquetFileWriterConfig()).getSessionProperties()); - HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty()); + HiveColumnHandle columnHandle = new HiveColumnHandle(COLUMN, HIVE_LONG, BIGINT.getTypeSignature(), 2, REGULAR, Optional.empty(), Optional.empty()); TableStatistics expected = TableStatistics.builder() .setRowCount(Estimate.of(1000)) .setColumnStatistics( diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java b/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java index e3c6799c46cb7..2a2483b51375b 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.block.BlockEncodingSerde; @@ -123,6 +124,13 @@ public interface Metadata */ Map getColumnHandles(Session session, TableHandle tableHandle); + /** + * Gets all nested columns on the specified table, or an empty map if the columns can not be enumerated. 
+ * + * @throws RuntimeException if table handle is no longer valid + */ + Map getNestedColumnHandles(Session session, TableHandle tableHandle, Collection nestedFields); + /** * Gets the metadata for the specified table column. * diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java index b97d332d6f97e..82a6ce6a8d6d7 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -30,6 +30,7 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.ConnectorViewDefinition; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.SchemaTableName; @@ -491,6 +492,14 @@ public Map getColumnHandles(Session session, TableHandle t return map.build(); } + @Override + public Map getNestedColumnHandles(Session session, TableHandle tableHandle, Collection nestedFields) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + return metadata.getNestedColumnHandles(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), nestedFields); + } + @Override public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java index 937b311164a4e..5f296a3f10f2f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/DomainTranslator.java @@ -30,6 +30,7 @@ import 
com.facebook.presto.spi.predicate.Utils; import com.facebook.presto.spi.predicate.ValueSet; import com.facebook.presto.spi.type.Type; +import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.InterpretedFunctionInvoker; import com.facebook.presto.sql.analyzer.ExpressionAnalyzer; @@ -39,6 +40,7 @@ import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.InListExpression; import com.facebook.presto.sql.tree.InPredicate; @@ -79,6 +81,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Iterators.peekingIterator; +import static java.lang.String.join; import static java.util.Collections.emptyList; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; @@ -266,6 +269,24 @@ private static boolean isBetween(Range range) && !range.getHigh().isUpperUnbounded() && range.getHigh().getBound() == Marker.Bound.EXACTLY; } + private static List getDereferenceComponents(DereferenceExpression expression) + { + Expression base = expression.getBase(); + if (base instanceof SymbolReference) { + List result = new ArrayList<>(); + result.add(((SymbolReference) base).getName()); + result.add(expression.getField().getValue()); + return result; + } + else if (base instanceof DereferenceExpression) { + List result = getDereferenceComponents((DereferenceExpression) base); + if (result != null) { + result.add(expression.getField().getValue()); + } + return result; + } + return null; + } /** * Convert an Expression predicate into an ExtractionResult consisting of: * 1) A successfully extracted TupleDomain @@ -321,6 
+342,56 @@ private static Expression complementIfNecessary(Expression expression, boolean c return complement ? new NotExpression(expression) : expression; } + private Type recursiveTypeLookup(DereferenceExpression expression) + { + List fieldNames = getDereferenceComponents(expression); + if (fieldNames == null) { + return null; + } + checkArgument(fieldNames.size() > 0, "Invalide Dereference Expression: %s", expression); + Type type = types.get(new Symbol(fieldNames.get(0))); + checkArgument(type != null, "Types is missing info for dereference expression: %s", expression); + int index = 1; + while (index < fieldNames.size()) { + String field = fieldNames.get(index); + for (TypeSignatureParameter typeSignatureParameter : type.getTypeSignature().getParameters()) { + Optional name = typeSignatureParameter.getNamedTypeSignature().getName(); + if (name.isPresent()) { + if (name.get().equalsIgnoreCase(field)) { + type = metadata.getTypeManager().getType(typeSignatureParameter.getNamedTypeSignature().getTypeSignature()); + break; + } + } + } + index = index + 1; + } + checkArgument(type != null, "Types is missing info for dereference expression: %s", expression); + return type; + } + + private Optional> combineDereferenceTupleDomain(LogicalBinaryExpression.Operator operator, Optional> left, Optional> right) + { + if (left.isPresent() && right.isPresent()) { + switch (operator) { + case AND: + return Optional.of(left.get().intersect(right.get())); + case OR: + return Optional.of(TupleDomain.columnWiseUnion(left.get(), right.get())); + default: + throw new AssertionError("Unknown operator: " + operator); + } + } + + if (left.isPresent()) { + return left; + } + + if (right.isPresent()) { + return right; + } + return Optional.empty(); + } + @Override protected ExtractionResult visitExpression(Expression node, Boolean complement) { @@ -338,11 +409,13 @@ protected ExtractionResult visitLogicalBinaryExpression(LogicalBinaryExpression TupleDomain rightTupleDomain = 
rightResult.getTupleDomain(); LogicalBinaryExpression.Operator operator = complement ? node.getOperator().flip() : node.getOperator(); + Optional> dereferenceTupleDomain = combineDereferenceTupleDomain(operator, leftResult.getDereferenceTupleDomain(), rightResult.getDereferenceTupleDomain()); switch (operator) { case AND: return new ExtractionResult( leftTupleDomain.intersect(rightTupleDomain), - combineConjuncts(leftResult.getRemainingExpression(), rightResult.getRemainingExpression())); + combineConjuncts(leftResult.getRemainingExpression(), rightResult.getRemainingExpression()), + dereferenceTupleDomain); case OR: TupleDomain columnUnionedTupleDomain = TupleDomain.columnWiseUnion(leftTupleDomain, rightTupleDomain); @@ -372,7 +445,7 @@ protected ExtractionResult visitLogicalBinaryExpression(LogicalBinaryExpression } } - return new ExtractionResult(columnUnionedTupleDomain, remainingExpression); + return new ExtractionResult(columnUnionedTupleDomain, remainingExpression, dereferenceTupleDomain); default: throw new AssertionError("Unknown operator: " + node.getOperator()); @@ -394,52 +467,59 @@ protected ExtractionResult visitComparisonExpression(ComparisonExpression node, } NormalizedSimpleComparison normalized = optionalNormalized.get(); - Expression symbolExpression = normalized.getSymbolExpression(); - if (symbolExpression instanceof SymbolReference) { - Symbol symbol = Symbol.from(symbolExpression); - NullableValue value = normalized.getValue(); - Type type = value.getType(); // common type for symbol and value - return createComparisonExtractionResult(normalized.getComparisonOperator(), symbol, type, value.getValue(), complement); - } - else if (symbolExpression instanceof Cast) { - Cast castExpression = (Cast) symbolExpression; - if (!isImplicitCoercion(castExpression)) { - // - // we cannot use non-coercion cast to literal_type on symbol side to build tuple domain - // - // example which illustrates the problem: - // - // let t be of timestamp type: - // - 
// and expression be: - // cast(t as date) == date_literal - // - // after dropping cast we end up with: - // - // t == date_literal - // - // if we build tuple domain based coercion of date_literal to timestamp type we would - // end up with tuple domain with just one time point (cast(date_literal as timestamp). - // While we need range which maps to single date pointed by date_literal. - // - return super.visitComparisonExpression(node, complement); + if (normalized.getSymbolExpression().isPresent()) { + Expression symbolExpression = normalized.getSymbolExpression().get(); + if (symbolExpression instanceof SymbolReference) { + Symbol symbol = Symbol.from(symbolExpression); + NullableValue value = normalized.getValue(); + Type type = value.getType(); // common type for symbol and value + return createComparisonExtractionResult(normalized.getComparisonOperator(), symbol, type, value.getValue(), complement, Optional.empty(), node); } + else if (symbolExpression instanceof Cast) { + Cast castExpression = (Cast) symbolExpression; + if (!isImplicitCoercion(castExpression)) { + // + // we cannot use non-coercion cast to literal_type on symbol side to build tuple domain + // + // example which illustrates the problem: + // + // let t be of timestamp type: + // + // and expression be: + // cast(t as date) == date_literal + // + // after dropping cast we end up with: + // + // t == date_literal + // + // if we build tuple domain based coercion of date_literal to timestamp type we would + // end up with tuple domain with just one time point (cast(date_literal as timestamp). + // While we need range which maps to single date pointed by date_literal. 
+ // + return super.visitComparisonExpression(node, complement); + } - Type castSourceType = typeOf(castExpression.getExpression(), session, metadata, types); // type of expression which is then cast to type of value + Type castSourceType = typeOf(castExpression.getExpression(), session, metadata, types); // type of expression which is then cast to type of value - // we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side - Optional coercedExpression = coerceComparisonWithRounding( - castSourceType, castExpression.getExpression(), normalized.getValue(), normalized.getComparisonOperator()); + // we use saturated floor cast value -> castSourceType to rewrite original expression to new one with one cast peeled off the symbol side + Optional coercedExpression = coerceComparisonWithRounding(castSourceType, castExpression.getExpression(), normalized.getValue(), normalized.getComparisonOperator()); - if (coercedExpression.isPresent()) { - return process(coercedExpression.get(), complement); - } + if (coercedExpression.isPresent()) { + return process(coercedExpression.get(), complement); + } - return super.visitComparisonExpression(node, complement); + return super.visitComparisonExpression(node, complement); + } } - else { - return super.visitComparisonExpression(node, complement); + else if (normalized.getDereferenceExpression().isPresent()) { + DereferenceExpression dereferenceExpression = normalized.getDereferenceExpression().get(); + Type type = recursiveTypeLookup(dereferenceExpression); + if (type == null) { + return super.visitComparisonExpression(node, complement); + } + return createComparisonExtractionResult(normalized.getComparisonOperator(), null, type, normalized.getValue().getValue(), complement, normalized.getDereferenceExpression(), node); } + return super.visitComparisonExpression(node, complement); } /** @@ -462,22 +542,33 @@ private Optional 
toNormalizedSimpleComparison(Compar return Optional.empty(); } - Expression symbolExpression; + Optional symbolExpression = Optional.empty(); + Optional dereferenceExpression = Optional.empty(); ComparisonExpression.Operator comparisonOperator; NullableValue value; if (left instanceof Expression) { - symbolExpression = comparison.getLeft(); + if (left instanceof DereferenceExpression) { + dereferenceExpression = Optional.of((DereferenceExpression) left); + } + else { + symbolExpression = Optional.of(comparison.getLeft()); + } comparisonOperator = comparison.getOperator(); value = new NullableValue(rightType, right); } else { - symbolExpression = comparison.getRight(); + if (right instanceof DereferenceExpression) { + dereferenceExpression = Optional.of((DereferenceExpression) right); + } + else { + symbolExpression = Optional.of(comparison.getRight()); + } comparisonOperator = comparison.getOperator().flip(); value = new NullableValue(leftType, left); } - return Optional.of(new NormalizedSimpleComparison(symbolExpression, comparisonOperator, value)); + return Optional.of(new NormalizedSimpleComparison(symbolExpression, comparisonOperator, value, dereferenceExpression)); } private boolean isImplicitCoercion(Cast cast) @@ -493,7 +584,7 @@ private Map, Type> analyzeExpression(Expression expression) return ExpressionAnalyzer.getExpressionTypes(session, metadata, new SqlParser(), types, expression, emptyList(), WarningCollector.NOOP); } - private static ExtractionResult createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement) + private static ExtractionResult createComparisonExtractionResult(ComparisonExpression.Operator comparisonOperator, Symbol column, Type type, @Nullable Object value, boolean complement, Optional dereferenceExpression, ComparisonExpression node) { if (value == null) { switch (comparisonOperator) { @@ -507,9 +598,7 @@ private static ExtractionResult 
createComparisonExtractionResult(ComparisonExpre case IS_DISTINCT_FROM: Domain domain = complementIfNecessary(Domain.notNull(type), complement); - return new ExtractionResult( - TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), - TRUE_LITERAL); + return buildExtractionResult(column, domain, dereferenceExpression, node); default: throw new AssertionError("Unhandled operator: " + comparisonOperator); @@ -527,9 +616,15 @@ else if (type.isComparable()) { throw new AssertionError("Type cannot be used in a comparison expression (should have been caught in analysis): " + type); } - return new ExtractionResult( - TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), - TRUE_LITERAL); + return buildExtractionResult(column, domain, dereferenceExpression, node); + } + + private static ExtractionResult buildExtractionResult(Symbol column, Domain domain, Optional dereferenceExpression, ComparisonExpression node) + { + if (column == null) { + return new ExtractionResult(TupleDomain.all(), node, Optional.of(TupleDomain.withColumnDomains(ImmutableMap.of(dereferenceExpression.get(), domain)))); + } + return new ExtractionResult(TupleDomain.withColumnDomains(ImmutableMap.of(column, domain)), TRUE_LITERAL); } private static Domain extractOrderableDomain(ComparisonExpression.Operator comparisonOperator, Type type, Object value, boolean complement) @@ -768,18 +863,20 @@ private static Type typeOf(Expression expression, Session session, Metadata meta private static class NormalizedSimpleComparison { - private final Expression symbolExpression; + private final Optional symbolExpression; private final ComparisonExpression.Operator comparisonOperator; private final NullableValue value; + private final Optional dereferenceExpression; - public NormalizedSimpleComparison(Expression symbolExpression, ComparisonExpression.Operator comparisonOperator, NullableValue value) + public NormalizedSimpleComparison(Optional symbolExpression, ComparisonExpression.Operator 
comparisonOperator, NullableValue value, Optional dereferenceExpression) { this.symbolExpression = requireNonNull(symbolExpression, "nameReference is null"); this.comparisonOperator = requireNonNull(comparisonOperator, "comparisonOperator is null"); this.value = requireNonNull(value, "value is null"); + this.dereferenceExpression = requireNonNull(dereferenceExpression, "dereferenceExpression is null"); } - public Expression getSymbolExpression() + public Optional getSymbolExpression() { return symbolExpression; } @@ -793,27 +890,55 @@ public NullableValue getValue() { return value; } + + public Optional getDereferenceExpression() + { + return dereferenceExpression; + } } public static class ExtractionResult { private final TupleDomain tupleDomain; private final Expression remainingExpression; + private final Optional> dereferenceTupleDomain; public ExtractionResult(TupleDomain tupleDomain, Expression remainingExpression) + { + this(tupleDomain, remainingExpression, Optional.empty()); + } + + public ExtractionResult(TupleDomain tupleDomain, Expression remainingExpression, Optional> dereferenceTupleDomain) { this.tupleDomain = requireNonNull(tupleDomain, "tupleDomain is null"); this.remainingExpression = requireNonNull(remainingExpression, "remainingExpression is null"); + this.dereferenceTupleDomain = requireNonNull(dereferenceTupleDomain, "dereferenceTupleDomain is null"); } public TupleDomain getTupleDomain() { - return tupleDomain; + if (!dereferenceTupleDomain.isPresent()) { + return tupleDomain; + } + Optional> dereferenceDomain = dereferenceTupleDomain.get().getDomains(); + if (!dereferenceDomain.isPresent()) { + return tupleDomain; + } + ImmutableMap.Builder domainBuilder = ImmutableMap.builder(); + for (Map.Entry entry : dereferenceDomain.get().entrySet()) { + domainBuilder.put(new Symbol(join(".", getDereferenceComponents(entry.getKey()))), entry.getValue()); + } + return tupleDomain.intersect(TupleDomain.withColumnDomains(domainBuilder.build())); } public 
Expression getRemainingExpression() { return remainingExpression; } + + public Optional> getDereferenceTupleDomain() + { + return dereferenceTupleDomain; + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index ae5418ef4cda4..b8bc64ad05bfd 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -60,6 +60,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneJoinColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneLimitColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneMarkDistinctColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneNestedFields; import com.facebook.presto.sql.planner.iterative.rule.PruneOrderByInAggregation; import com.facebook.presto.sql.planner.iterative.rule.PruneOutputColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneProjectColumns; @@ -343,6 +344,13 @@ public PlanOptimizers( new RemoveRedundantIdentityProjections(), new TransformCorrelatedSingleRowSubqueryToProject())), new CheckSubqueryNodesAreRewritten(), + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.>builder() + .addAll(new PruneNestedFields(metadata, sqlParser).rules()) + .build()), predicatePushDown, new IterativeOptimizer( ruleStats, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneNestedFields.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneNestedFields.java new file mode 100644 index 0000000000000..2a2e902afc4b8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneNestedFields.java @@ -0,0 +1,471 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableHandle; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.NestedField; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.iterative.Rule.Context; +import com.facebook.presto.sql.planner.iterative.Rule.Result; +import com.facebook.presto.sql.planner.plan.Assignments; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.planner.plan.TableScanNode; +import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.SubscriptExpression; +import 
com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.ExpressionExtractor.extractExpressionsNonRecursive; +import static com.facebook.presto.sql.planner.SymbolsExtractor.extractAll; +import static com.facebook.presto.sql.planner.plan.Patterns.filter; +import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.planner.plan.Patterns.tableScan; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toMap; + +public class PruneNestedFields +{ + private final Metadata metadata; + private final SqlParser sqlParser; + + public PruneNestedFields(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + public Set> rules() + { + return ImmutableSet.of( + new PruneProjectFilter(metadata, sqlParser), + new PruneProjectTableScan(metadata, sqlParser)); + } + + @VisibleForTesting + 
public static final class PruneProjectFilter + implements Rule + { + private static final Capture FILTER = newCapture(); + private static final Pattern PATTERN = project() + .with(source().matching(filter().capturedAs(FILTER))); + + private final Metadata metadata; + private final SqlParser sqlParser; + + public PruneProjectFilter(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + @Override + public boolean isEnabled(Session session) + { + //TODO: add session property + return true; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + FilterNode filterNode = captures.get(FILTER); + Map expressions = getDereferenceSymbolMap(node, context, metadata, sqlParser); + List symbols = filterNode.getOutputSymbols().stream().collect(toList()); + Map pushdownExpressions = expressions.entrySet().stream() + .filter(entry -> symbols.contains(getOnlyElement(extractAll(entry.getKey())))) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); + Set dereferences = pushdownExpressions.keySet(); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + Map pushdownDereferences = pushdownExpressions.entrySet().stream().collect(toMap(Map.Entry::getValue, Map.Entry::getKey)); + List predicates = extractDereference(filterNode.getPredicate()); + Map predicateDereferences = predicates.stream() + .distinct() + .filter(expression -> !expressions.containsKey(expression)) + .collect(toMap(expression -> getSymbol(expression, context, metadata, sqlParser), Function.identity())); + Map predicateExpressions = predicateDereferences.entrySet().stream() + .collect(toMap(Map.Entry::getValue, Map.Entry::getKey)); + + Assignments.Builder assignmentsBuilder = Assignments.builder(); + for (Map.Entry entry : node.getAssignments().entrySet()) { + 
assignmentsBuilder.put(entry.getKey(), ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressions), entry.getValue())); + } + Assignments assignments = assignmentsBuilder.build(); + + List outputs = filterNode.getOutputSymbols().stream() + .filter(symbol -> assignments.getMap().containsKey(symbol)) + .collect(toList()); + + Assignments.Builder pushdownBuilder = Assignments.builder(); + pushdownBuilder.putAll(pushdownDereferences).putAll(predicateDereferences).putIdentities(outputs); + ProjectNode child = new ProjectNode(context.getIdAllocator().getNextId(), filterNode.getSource(), pushdownBuilder.build()); + + ImmutableMap.Builder constraintBuilder = ImmutableMap.builder(); + constraintBuilder.putAll(expressions); + constraintBuilder.putAll(predicateExpressions); + + Expression predicate = ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(constraintBuilder.build()), filterNode.getPredicate()); + FilterNode target = new FilterNode(context.getIdAllocator().getNextId(), child, predicate); + + ProjectNode result = new ProjectNode(context.getIdAllocator().getNextId(), target, assignments); + return Result.ofPlanNode(result); + } + } + + @VisibleForTesting + public static final class PruneProjectTableScan + implements Rule + { + private static final Capture TABLE_SCAN = newCapture(); + private static final Pattern PATTERN = project() + .with(source().matching(tableScan().capturedAs(TABLE_SCAN))); + + private final Metadata metadata; + private final SqlParser sqlParser; + + public PruneProjectTableScan(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + @Override + public boolean isEnabled(Session session) + { + //TODO: add session property + return true; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + 
TableScanNode tableScanNode = captures.get(TABLE_SCAN); + Map expressions = getDereferenceSymbolMap(node, context, metadata, sqlParser); + List symbols = tableScanNode.getOutputSymbols().stream().collect(toList()); + Map pushdownExpressions = expressions.entrySet().stream() + .filter(entry -> symbols.contains(getOnlyElement(extractAll(entry.getKey())))) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); + Set dereferences = pushdownExpressions.keySet(); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + NestedFieldTranslator nestedColumnTranslator = new NestedFieldTranslator(tableScanNode.getAssignments(), tableScanNode.getTable(), context.getSession()); + Map nestedColumns = dereferences.stream().collect(Collectors.toMap(Function.identity(), nestedColumnTranslator::toNestedField)); + + Map nestedColumnHandles = metadata.getNestedColumnHandles(context.getSession(), tableScanNode.getTable(), nestedColumns.values()).entrySet().stream() + .filter(entry -> !nestedColumnTranslator.columnHandleExists(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (nestedColumnHandles.isEmpty()) { + return Result.empty(); + } + + if (tableScanNode.getAssignments().values().containsAll(nestedColumnHandles.values())) { + return Result.empty(); + } + ImmutableMap.Builder columnHandleBuilder = ImmutableMap.builder(); + columnHandleBuilder.putAll(tableScanNode.getAssignments()); + + ImmutableMap.Builder symbolExpressionBuilder = ImmutableMap.builder(); + for (Map.Entry entry : nestedColumnHandles.entrySet()) { + NestedField nestedColumn = entry.getKey(); + Expression expression = nestedColumnTranslator.toExpression(nestedColumn); + Symbol symbol = context.getSymbolAllocator().newSymbol(nestedColumn.getName(), getExpressionType(expression, context, metadata, sqlParser)); + symbolExpressionBuilder.put(expression, symbol); + columnHandleBuilder.put(symbol, entry.getValue()); + } + ImmutableMap nestedColumnsMap = 
columnHandleBuilder.build(); + + TableScanNode source = new TableScanNode(context.getIdAllocator().getNextId(), tableScanNode.getTable(), ImmutableList.copyOf(nestedColumnsMap.keySet()), nestedColumnsMap, tableScanNode.getLayout(), tableScanNode.getCurrentConstraint(), tableScanNode.getEnforcedConstraint()); + + Rewriter rewriter = new Rewriter(symbolExpressionBuilder.build()); + Map assignments = node.getAssignments().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> ExpressionTreeRewriter.rewriteWith(rewriter, entry.getValue()))); + ProjectNode target = new ProjectNode(context.getIdAllocator().getNextId(), source, Assignments.copyOf(assignments)); + return Result.ofPlanNode(target); + } + + private class NestedFieldTranslator + { + private final Map symbolToColumnName; + private final Map columnNameToSymbol; + + NestedFieldTranslator(Map columnHandleMap, TableHandle tableHandle, Session session) + { + BiMap symbolToColumnName = HashBiMap.create(columnHandleMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> metadata.getColumnMetadata(session, tableHandle, entry.getValue()).getName()))); + this.symbolToColumnName = symbolToColumnName; + this.columnNameToSymbol = symbolToColumnName.inverse(); + } + + boolean columnHandleExists(NestedField nestedColumn) + { + return columnNameToSymbol.containsKey(nestedColumn.getName()); + } + + NestedField toNestedField(Expression expression) + { + ImmutableList.Builder builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor() + { + @Override + protected Void visitSubscriptExpression(SubscriptExpression node, Void context) + { + return null; + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void context) + { + process(node.getBase(), context); + builder.add(node.getField().getValue()); + return null; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, Void context) + { + Symbol baseName = 
Symbol.from(node); + checkArgument(symbolToColumnName.containsKey(baseName), "base [%s] doesn't exist in assignments [%s]", baseName, symbolToColumnName); + builder.add(symbolToColumnName.get(baseName)); + return null; + } + }.process(expression, null); + + List names = builder.build(); + return new NestedField(names); + } + + Expression toExpression(NestedField nestedColumn) + { + Expression result = null; + for (String part : nestedColumn.getFields()) { + if (result == null) { + checkArgument(columnNameToSymbol.containsKey(part), "element %s doesn't exist in map %s", part, columnNameToSymbol); + result = columnNameToSymbol.get(part).toSymbolReference(); + } + else { + result = new DereferenceExpression(result, new Identifier(part)); + } + } + return result; + } + } + } + + private static class Rewriter + extends ExpressionRewriter + { + private final Map map; + + Rewriter(Map map) + { + this.map = map; + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (map.containsKey(node)) { + return map.get(node).toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + + @Override + public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (map.containsKey(node)) { + return map.get(node).toSymbolReference(); + } + return super.rewriteSymbolReference(node, context, treeRewriter); + } + } + + private static Type getExpressionType(Expression expression, Context context, Metadata metadata, SqlParser sqlParser) + { + Type type = getExpressionTypes(context.getSession(), metadata, sqlParser, context.getSymbolAllocator().getTypes(), expression, emptyList(), WarningCollector.NOOP) + .get(NodeRef.of(expression)); + verify(type != null); + return type; + } + + private static class DereferenceReplacer + extends ExpressionRewriter + { + private final Map expressions; + + DereferenceReplacer(Map 
expressions) + { + this.expressions = expressions; + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (expressions.containsKey(node)) { + return expressions.get(node).toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + } + + private static List extractDereference(Expression expression) + { + ImmutableList.Builder builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableList.Builder context) + { + context.add(node); + return null; + } + }.process(expression, builder); + return builder.build(); + } + + private static Map getDereferenceSymbolMap(ProjectNode node, Context context, Metadata metadata, SqlParser sqlParser) + { + List expressions = extractExpressionsNonRecursive(node).stream() + .flatMap(expression -> extractDereference(expression).stream()) + .map(PruneNestedFields::processSubscriptDereference) + .filter(Objects::nonNull) + .collect(toList()); + + return expressions.stream() + .filter(expression -> !prefixExist(expression, expressions)) + .filter(expression -> expression instanceof DereferenceExpression) + .distinct() + .collect(toMap(Function.identity(), expression -> getSymbol(expression, context, metadata, sqlParser))); + } + + private static Symbol getSymbol(Expression expression, Context context, Metadata metadata, SqlParser sqlParser) + { + return context.getSymbolAllocator().newSymbol(expression, getExpressionType(expression, context, metadata, sqlParser)); + } + + private static boolean prefixExist(Expression expression, final List dereferences) + { + int[] count = {0}; + new DefaultExpressionTraversalVisitor() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, int[] count) + { + if (dereferences.contains(node)) { + count[0] = count[0] + 1; + } + 
process(node.getBase(), count); + return null; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, int[] count) + { + if (dereferences.contains(node)) { + count[0] = count[0] + 1; + } + return null; + } + }.process(expression, count); + + return count[0] > 1; + } + + private static Expression processSubscriptDereference(Expression expression) + { + checkArgument(expression instanceof DereferenceExpression, "Expression: " + expression.toString() + " is not DereferenceExpression"); + SubscriptExpression[] subscriptExpression = new SubscriptExpression[1]; + boolean[] isDereferenceOrSubscript = {true}; + + new DefaultExpressionTraversalVisitor() + { + @Override + protected Void visitSubscriptExpression(SubscriptExpression node, Void context) + { + subscriptExpression[0] = node; + process(node.getBase(), context); + return null; + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void context) + { + if (!(node.getBase() instanceof SymbolReference || node.getBase() instanceof DereferenceExpression || node.getBase() instanceof SubscriptExpression)) { + isDereferenceOrSubscript[0] = false; + } + process(node.getBase(), context); + return null; + } + }.process(expression, null); + + if (isDereferenceOrSubscript[0]) { + return subscriptExpression[0] == null ? 
expression : subscriptExpression[0].getBase(); + } + return null; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java b/presto-main/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java index 2252378f204b9..70f3f33224213 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.block.BlockEncodingSerde; import com.facebook.presto.spi.connector.ConnectorCapabilities; @@ -174,6 +175,12 @@ public Map getColumnHandles(Session session, TableHandle t throw new UnsupportedOperationException(); } + @Override + public Map getNestedColumnHandles(Session session, TableHandle tableHandle, Collection nestedField) + { + throw new UnsupportedOperationException(); + } + @Override public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle, ColumnHandle columnHandle) { diff --git a/presto-parquet/src/main/java/com/facebook/presto/parquet/ParquetTypeUtils.java b/presto-parquet/src/main/java/com/facebook/presto/parquet/ParquetTypeUtils.java index 062d8e28d1cf0..c3777e20472c8 100644 --- a/presto-parquet/src/main/java/com/facebook/presto/parquet/ParquetTypeUtils.java +++ b/presto-parquet/src/main/java/com/facebook/presto/parquet/ParquetTypeUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.parquet; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.TupleDomain; @@ -25,6 +26,7 @@ import com.facebook.presto.spi.type.TimestampType; import 
com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarcharType; +import com.google.common.collect.ImmutableList; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Encoding; import org.apache.parquet.io.ColumnIO; @@ -35,6 +37,7 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.PrimitiveColumnIO; import org.apache.parquet.schema.DecimalMetadata; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.OriginalType; @@ -48,6 +51,7 @@ import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.Iterators.getOnlyElement; import static org.apache.parquet.schema.OriginalType.DECIMAL; import static org.apache.parquet.schema.Type.Repetition.REPEATED; @@ -199,15 +203,15 @@ private static Type createVarcharType(TupleDomain effectivePre return VarcharType.VARCHAR; } - public static int getFieldIndex(MessageType fileSchema, String name) + public static int getFieldIndex(GroupType groupType, String name) { try { - return fileSchema.getFieldIndex(name.toLowerCase(Locale.ENGLISH)); + return groupType.getFieldIndex(name.toLowerCase(Locale.ENGLISH)); } catch (InvalidRecordException e) { - for (org.apache.parquet.schema.Type type : fileSchema.getFields()) { + for (org.apache.parquet.schema.Type type : groupType.getFields()) { if (type.getName().equalsIgnoreCase(name)) { - return fileSchema.getFieldIndex(type.getName()); + return groupType.getFieldIndex(type.getName()); } } return -1; @@ -238,14 +242,14 @@ public static ParquetEncoding getParquetEncoding(Encoding encoding) } } - public static org.apache.parquet.schema.Type getParquetTypeByName(String columnName, MessageType messageType) + public static org.apache.parquet.schema.Type 
getParquetTypeByName(String columnName, GroupType groupType) { - if (messageType.containsField(columnName)) { - return messageType.getType(columnName); + if (groupType.containsField(columnName)) { + return groupType.getType(columnName); } // parquet is case-sensitive, but hive is not. all hive columns get converted to lowercase // check for direct match above but if no match found, try case-insensitive match - for (org.apache.parquet.schema.Type type : messageType.getFields()) { + for (org.apache.parquet.schema.Type type : groupType.getFields()) { if (type.getName().equalsIgnoreCase(columnName)) { return type; } @@ -316,6 +320,37 @@ public static long getShortDecimalValue(byte[] bytes) return value; } + public static org.apache.parquet.schema.Type getFieldType(GroupType baseType, NestedField nestedField) + { + ImmutableList.Builder typeBuilder = ImmutableList.builder(); + org.apache.parquet.schema.Type parentType = baseType; + + for (String field : nestedField.getFields()) { + org.apache.parquet.schema.Type childType = getParquetTypeByName(field, parentType.asGroupType()); + if (childType == null) { + return null; + } + typeBuilder.add(childType); + parentType = childType; + } + List nestedType = typeBuilder.build(); + + if (nestedType.isEmpty()) { + return null; + } + else if (nestedType.size() == 1) { + return getOnlyElement(nestedType.iterator()); + } + else { + org.apache.parquet.schema.Type messageType = nestedType.get(nestedType.size() - 1); + for (int i = nestedType.size() - 2; i >= 0; --i) { + GroupType groupType = nestedType.get(i).asGroupType(); + messageType = new MessageType(groupType.getName(), ImmutableList.of(messageType)); + } + return messageType; + } + } + private static Type getInt32Type(RichColumnDescriptor descriptor) { OriginalType originalType = descriptor.getPrimitiveType().getOriginalType(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/NestedField.java b/presto-spi/src/main/java/com/facebook/presto/spi/NestedField.java new 
file mode 100644 index 0000000000000..8549dbea78135 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/NestedField.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +import com.facebook.presto.spi.type.Type; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class NestedField +{ + private final List fields; + + @JsonCreator + public NestedField(@JsonProperty("fields") List fields) + { + this.fields = requireNonNull(fields); + } + + @JsonProperty + public List getFields() + { + return fields; + } + + public String getBase() + { + return fields.get(0); + } + + public List getRemaining() + { + if (fields.size() <= 1) { + throw new IllegalArgumentException("NestedField must have more than 1 field"); + } + return fields.subList(1, fields.size()); + } + + public String getName() + { + return fields.stream().collect(joining(".")); + } + + @JsonProperty + public Type getType() + { + return null; + } + + @Override + public int hashCode() + { + return Objects.hash(fields); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + NestedField other = 
(NestedField) obj; + return Objects.equals(this.fields, other.fields); + } + + @Override + public String toString() + { + StringBuilder sb = new StringBuilder("NestedFields<"); + sb.append("name='").append(getName()).append("'>"); + return sb.toString(); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java index 1675336032f4b..5fd868d5488cc 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadata.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.ConnectorViewDefinition; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; @@ -184,6 +185,16 @@ default List listTables(ConnectorSession session, Optional getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle); + /** + * Gets all nested columns on the specified table, or an empty map if the columns cannot be enumerated. + * + * @throws RuntimeException if table handle is no longer valid + */ + default Map getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection nestedFields) + { + return emptyMap(); + } + /** * Gets the metadata for the specified table column. 
* diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java index 0f24c86aa2519..ac60d30eeeed7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/connector/classloader/ClassLoaderSafeConnectorMetadata.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.ConnectorViewDefinition; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NestedField; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; import com.facebook.presto.spi.SystemTable; @@ -238,6 +239,14 @@ public Map getColumnHandles(ConnectorSession session, Conn } } + @Override + public Map getNestedColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle, Collection nestedField) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.getNestedColumnHandles(session, tableHandle, nestedField); + } + } + @Override public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) {