diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java index ed9a5a505ad9f..9f1823a4ebd1a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java @@ -16,15 +16,23 @@ import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.VarcharType; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import javax.inject.Inject; +import java.util.List; + import static com.facebook.presto.hive.HiveType.HIVE_BYTE; import static com.facebook.presto.hive.HiveType.HIVE_DOUBLE; import static com.facebook.presto.hive.HiveType.HIVE_FLOAT; import static com.facebook.presto.hive.HiveType.HIVE_INT; import static com.facebook.presto.hive.HiveType.HIVE_LONG; import static com.facebook.presto.hive.HiveType.HIVE_SHORT; +import static com.facebook.presto.hive.HiveUtil.extractStructFieldTypes; +import static java.lang.Math.min; import static java.util.Objects.requireNonNull; public class HiveCoercionPolicy @@ -62,6 +70,49 @@ public boolean canCoerce(HiveType fromHiveType, HiveType toHiveType) return toHiveType.equals(HIVE_DOUBLE); } - return false; + return canCoerceForList(fromHiveType, toHiveType) || canCoerceForMap(fromHiveType, toHiveType) || canCoerceForStruct(fromHiveType, toHiveType); + } + + private boolean canCoerceForMap(HiveType fromHiveType, HiveType toHiveType) + { + if (!fromHiveType.getCategory().equals(Category.MAP) || !toHiveType.getCategory().equals(Category.MAP)) { + return false; + } + HiveType fromKeyType = HiveType.valueOf(((MapTypeInfo) fromHiveType.getTypeInfo()).getMapKeyTypeInfo().getTypeName()); + HiveType fromValueType = HiveType.valueOf(((MapTypeInfo) fromHiveType.getTypeInfo()).getMapValueTypeInfo().getTypeName()); + HiveType toKeyType = HiveType.valueOf(((MapTypeInfo) toHiveType.getTypeInfo()).getMapKeyTypeInfo().getTypeName()); + HiveType toValueType = HiveType.valueOf(((MapTypeInfo) toHiveType.getTypeInfo()).getMapValueTypeInfo().getTypeName()); + return (fromKeyType.equals(toKeyType) || canCoerce(fromKeyType, toKeyType)) && + (fromValueType.equals(toValueType) || canCoerce(fromValueType, toValueType)); + } + + private boolean canCoerceForList(HiveType fromHiveType, HiveType toHiveType) + { + if (!fromHiveType.getCategory().equals(Category.LIST) || !toHiveType.getCategory().equals(Category.LIST)) { + return false; + } + HiveType fromElementType = HiveType.valueOf(((ListTypeInfo) fromHiveType.getTypeInfo()).getListElementTypeInfo().getTypeName()); + HiveType toElementType = HiveType.valueOf(((ListTypeInfo) toHiveType.getTypeInfo()).getListElementTypeInfo().getTypeName()); + return fromElementType.equals(toElementType) || canCoerce(fromElementType, toElementType); + } + + private boolean canCoerceForStruct(HiveType fromHiveType, HiveType toHiveType) + { + if (!fromHiveType.getCategory().equals(Category.STRUCT) || !toHiveType.getCategory().equals(Category.STRUCT)) { + return false; + } + List fromFieldNames = ((StructTypeInfo) fromHiveType.getTypeInfo()).getAllStructFieldNames(); + List toFieldNames = ((StructTypeInfo) toHiveType.getTypeInfo()).getAllStructFieldNames(); + List fromFieldTypes = extractStructFieldTypes(fromHiveType); + List toFieldTypes = extractStructFieldTypes(toHiveType); + for (int i = 0; i < min(fromFieldTypes.size(), toFieldTypes.size()); i++) { + if (!fromFieldNames.get(i).equals(toFieldNames.get(i))) { + return false; + } + if (!fromFieldTypes.get(i).equals(toFieldTypes.get(i)) && !canCoerce(fromFieldTypes.get(i), toFieldTypes.get(i))) { + return false; + } + } + return true; } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionRecordCursor.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionRecordCursor.java index 442be5042d91e..68b8ea9b1d580 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionRecordCursor.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionRecordCursor.java @@ -14,14 +14,19 @@ package com.facebook.presto.hive; import com.facebook.presto.hive.HivePageSourceProvider.ColumnMapping; +import com.facebook.presto.spi.PageBuilder; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.block.Block; +import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.VarcharType; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import java.util.List; @@ -31,9 +36,14 @@ import static com.facebook.presto.hive.HiveType.HIVE_INT; import static com.facebook.presto.hive.HiveType.HIVE_LONG; import static com.facebook.presto.hive.HiveType.HIVE_SHORT; +import static com.facebook.presto.hive.HiveUtil.extractStructFieldTypes; +import static com.facebook.presto.hive.HiveUtil.isArrayType; +import static com.facebook.presto.hive.HiveUtil.isMapType; +import static com.facebook.presto.hive.HiveUtil.isRowType; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.min; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -43,6 +53,7 @@ public class HiveCoercionRecordCursor private final RecordCursor delegate; private final List columnMappings; private final Coercer[] coercers; + private final BridgingRecordCursor bridgingRecordCursor; public HiveCoercionRecordCursor( List columnMappings, @@ -51,6 +62,7 @@ public HiveCoercionRecordCursor( { requireNonNull(columnMappings, "columns is null"); requireNonNull(typeManager, "typeManager is null"); + this.bridgingRecordCursor = new BridgingRecordCursor(); this.delegate = requireNonNull(delegate, "delegate is null"); this.columnMappings = ImmutableList.copyOf(columnMappings); @@ -63,7 +75,7 @@ public HiveCoercionRecordCursor( ColumnMapping columnMapping = columnMappings.get(columnIndex); if (columnMapping.getCoercionFrom().isPresent()) { - coercers[columnIndex] = createCoercer(typeManager, columnMapping.getCoercionFrom().get(), columnMapping.getHiveColumnHandle().getHiveType()); + coercers[columnIndex] = createCoercer(typeManager, columnMapping.getCoercionFrom().get(), columnMapping.getHiveColumnHandle().getHiveType(), bridgingRecordCursor); } } } @@ -266,7 +278,7 @@ protected void setIsNull(boolean isNull) } } - private static Coercer createCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType) + private static Coercer createCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType, BridgingRecordCursor bridgingRecordCursor) { Type fromType = typeManager.getType(fromHiveType.getTypeSignature()); Type toType = typeManager.getType(toHiveType.getTypeSignature()); @@ -288,6 +300,16 @@ else if (fromHiveType.equals(HIVE_INT) && toHiveType.equals(HIVE_LONG)) { else if (fromHiveType.equals(HIVE_FLOAT) && toHiveType.equals(HIVE_DOUBLE)) { return new FloatToDoubleCoercer(); } + else if (isArrayType(fromType) && isArrayType(toType)) { + return new ListCoercer(typeManager, fromHiveType, toHiveType, bridgingRecordCursor); + } + else if (isMapType(fromType) && isMapType(toType)) { + return new MapCoercer(typeManager, fromHiveType, toHiveType, bridgingRecordCursor); + } + else if (isRowType(fromType) && isRowType(toType)) { + return new StructCoercer(typeManager, fromHiveType, toHiveType, bridgingRecordCursor); + } + throw new PrestoException(NOT_SUPPORTED, format("Unsupported coercion from %s to %s", fromHiveType, toHiveType)); } @@ -367,4 +389,315 @@ public void coerce(RecordCursor delegate, int field) } } } + + private static class ListCoercer + extends Coercer + { + private final Type fromElementType; + private final Type toType; + private final Type toElementType; + private final Coercer elementCoercer; + private final BridgingRecordCursor bridgingRecordCursor; + private final PageBuilder pageBuilder; + + public ListCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType, BridgingRecordCursor bridgingRecordCursor) + { + requireNonNull(typeManager, "typeManage is null"); + requireNonNull(fromHiveType, "fromHiveType is null"); + requireNonNull(toHiveType, "toHiveType is null"); + this.bridgingRecordCursor = requireNonNull(bridgingRecordCursor, "bridgingRecordCursor is null"); + HiveType fromElementHiveType = HiveType.valueOf(((ListTypeInfo) fromHiveType.getTypeInfo()).getListElementTypeInfo().getTypeName()); + HiveType toElementHiveType = HiveType.valueOf(((ListTypeInfo) toHiveType.getTypeInfo()).getListElementTypeInfo().getTypeName()); + this.fromElementType = fromElementHiveType.getType(typeManager); + this.toType = toHiveType.getType(typeManager); + this.toElementType = toElementHiveType.getType(typeManager); + this.elementCoercer = fromElementHiveType.equals(toElementHiveType) ? null : createCoercer(typeManager, fromElementHiveType, toElementHiveType, bridgingRecordCursor); + this.pageBuilder = elementCoercer == null ? null : new PageBuilder(ImmutableList.of(toType)); + } + + @Override + public void coerce(RecordCursor delegate, int field) + { + if (delegate.isNull(field)) { + setIsNull(true); + return; + } + Block block = (Block) delegate.getObject(field); + if (pageBuilder.isFull()) { + pageBuilder.reset(); + } + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder listBuilder = blockBuilder.beginBlockEntry(); + for (int i = 0; i < block.getPositionCount(); i++) { + if (elementCoercer == null) { + toElementType.appendTo(block, i, listBuilder); + } + else { + if (block.isNull(i)) { + listBuilder.appendNull(); + } + else { + rewriteBlock(fromElementType, toElementType, block, i, listBuilder, elementCoercer, bridgingRecordCursor); + } + } + } + blockBuilder.closeEntry(); + pageBuilder.declarePosition(); + setObject(toType.getObject(blockBuilder, blockBuilder.getPositionCount() - 1)); + } + } + + private static class MapCoercer + extends Coercer + { + private final List fromKeyValueTypes; + private final Type toType; + private final List toKeyValueTypes; + private final Coercer[] coercers; + private final BridgingRecordCursor bridgingRecordCursor; + private final PageBuilder pageBuilder; + + public MapCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType, BridgingRecordCursor bridgingRecordCursor) + { + requireNonNull(typeManager, "typeManage is null"); + requireNonNull(fromHiveType, "fromHiveType is null"); + requireNonNull(toHiveType, "toHiveType is null"); + this.bridgingRecordCursor = requireNonNull(bridgingRecordCursor, "bridgingRecordCursor is null"); + HiveType fromKeyHiveType = HiveType.valueOf(((MapTypeInfo) fromHiveType.getTypeInfo()).getMapKeyTypeInfo().getTypeName()); + HiveType fromValueHiveType = HiveType.valueOf(((MapTypeInfo) fromHiveType.getTypeInfo()).getMapValueTypeInfo().getTypeName()); + HiveType toKeyHiveType = HiveType.valueOf(((MapTypeInfo) toHiveType.getTypeInfo()).getMapKeyTypeInfo().getTypeName()); + HiveType toValueHiveType = HiveType.valueOf(((MapTypeInfo) toHiveType.getTypeInfo()).getMapValueTypeInfo().getTypeName()); + this.fromKeyValueTypes = fromHiveType.getType(typeManager).getTypeParameters(); + this.toType = toHiveType.getType(typeManager); + this.toKeyValueTypes = toType.getTypeParameters(); + this.coercers = new Coercer[2]; + coercers[0] = fromKeyHiveType.equals(toKeyHiveType) ? null : createCoercer(typeManager, fromKeyHiveType, toKeyHiveType, bridgingRecordCursor); + coercers[1] = fromValueHiveType.equals(toValueHiveType) ? null : createCoercer(typeManager, fromValueHiveType, toValueHiveType, bridgingRecordCursor); + this.pageBuilder = coercers[0] == null && coercers[1] == null ? null : new PageBuilder(ImmutableList.of(toType)); + } + + @Override + public void coerce(RecordCursor delegate, int field) + { + if (delegate.isNull(field)) { + setIsNull(true); + return; + } + Block block = (Block) delegate.getObject(field); + if (pageBuilder.isFull()) { + pageBuilder.reset(); + } + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder mapBuilder = blockBuilder.beginBlockEntry(); + for (int i = 0; i < block.getPositionCount(); i++) { + int k = i % 2; + if (coercers[k] == null) { + toKeyValueTypes.get(k).appendTo(block, i, mapBuilder); + } + else { + if (block.isNull(i)) { + mapBuilder.appendNull(); + } + else { + rewriteBlock(fromKeyValueTypes.get(k), toKeyValueTypes.get(k), block, i, mapBuilder, coercers[k], bridgingRecordCursor); + } + } + } + blockBuilder.closeEntry(); + pageBuilder.declarePosition(); + setObject(toType.getObject(blockBuilder, blockBuilder.getPositionCount() - 1)); + } + } + + private static class StructCoercer + extends Coercer + { + private final Type toType; + private final List fromFieldTypes; + private final List toFieldTypes; + private final Coercer[] coercers; + private final BridgingRecordCursor bridgingRecordCursor; + private final PageBuilder pageBuilder; + + public StructCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType, BridgingRecordCursor bridgingRecordCursor) + { + requireNonNull(typeManager, "typeManage is null"); + requireNonNull(fromHiveType, "fromHiveType is null"); + requireNonNull(toHiveType, "toHiveType is null"); + this.bridgingRecordCursor = requireNonNull(bridgingRecordCursor, "bridgingRecordCursor is null"); + List fromFieldHiveTypes = extractStructFieldTypes(fromHiveType); + List toFieldHiveTypes = extractStructFieldTypes(toHiveType); + this.fromFieldTypes = fromHiveType.getType(typeManager).getTypeParameters(); + this.toType = toHiveType.getType(typeManager); + this.toFieldTypes = toType.getTypeParameters(); + this.coercers = new Coercer[toFieldHiveTypes.size()]; + for (int i = 0; i < min(fromFieldHiveTypes.size(), toFieldHiveTypes.size()); i++) { + if (!fromFieldTypes.get(i).equals(toFieldTypes.get(i))) { + coercers[i] = createCoercer(typeManager, fromFieldHiveTypes.get(i), toFieldHiveTypes.get(i), bridgingRecordCursor); + } + } + this.pageBuilder = new PageBuilder(ImmutableList.of(toType)); + } + + @Override + public void coerce(RecordCursor delegate, int field) + { + if (delegate.isNull(field)) { + setIsNull(true); + return; + } + Block block = (Block) delegate.getObject(field); + if (pageBuilder.isFull()) { + pageBuilder.reset(); + } + BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(0); + BlockBuilder rowBuilder = blockBuilder.beginBlockEntry(); + for (int i = 0; i < toFieldTypes.size(); i++) { + if (i >= fromFieldTypes.size() || block.isNull(i)) { + rowBuilder.appendNull(); + } + else if (coercers[i] == null) { + toFieldTypes.get(i).appendTo(block, i, rowBuilder); + } + else { + rewriteBlock(fromFieldTypes.get(i), toFieldTypes.get(i), block, i, rowBuilder, coercers[i], bridgingRecordCursor); + } + } + blockBuilder.closeEntry(); + pageBuilder.declarePosition(); + setObject(toType.getObject(blockBuilder, blockBuilder.getPositionCount() - 1)); + } + } + + private static void rewriteBlock( + Type fromType, + Type toType, + Block block, + int position, + BlockBuilder blockBuilder, + Coercer coercer, + BridgingRecordCursor bridgingRecordCursor) + { + Class fromJavaType = fromType.getJavaType(); + if (fromJavaType == long.class) { + bridgingRecordCursor.setValue(fromType.getLong(block, position)); + } + else if (fromJavaType == double.class) { + bridgingRecordCursor.setValue(fromType.getDouble(block, position)); + } + else if (fromJavaType == boolean.class) { + bridgingRecordCursor.setValue(fromType.getBoolean(block, position)); + } + else if (fromJavaType == Slice.class) { + bridgingRecordCursor.setValue(fromType.getSlice(block, position)); + } + else if (fromJavaType == Block.class) { + bridgingRecordCursor.setValue(fromType.getObject(block, position)); + } + else { + bridgingRecordCursor.setValue(null); + } + coercer.reset(); + Class toJaveType = toType.getJavaType(); + if (coercer.isNull(bridgingRecordCursor, 0)) { + blockBuilder.appendNull(); + } + else if (toJaveType == long.class) { + toType.writeLong(blockBuilder, coercer.getLong(bridgingRecordCursor, 0)); + } + else if (toJaveType == double.class) { + toType.writeDouble(blockBuilder, coercer.getDouble(bridgingRecordCursor, 0)); + } + else if (toJaveType == boolean.class) { + toType.writeBoolean(blockBuilder, coercer.getBoolean(bridgingRecordCursor, 0)); + } + else if (toJaveType == Slice.class) { + toType.writeSlice(blockBuilder, coercer.getSlice(bridgingRecordCursor, 0)); + } + else if (toJaveType == Block.class) { + toType.writeObject(blockBuilder, coercer.getObject(bridgingRecordCursor, 0)); + } + else { + throw new PrestoException(NOT_SUPPORTED, format("Unsupported coercion from %s to %s", fromType.getDisplayName(), toType.getDisplayName())); + } + coercer.reset(); + bridgingRecordCursor.close(); + } + + private static class BridgingRecordCursor + implements RecordCursor + { + private Object value; + + public void setValue(Object value) + { + this.value = value; + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public Type getType(int field) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean advanceNextPosition() + { + return true; + } + + @Override + public boolean getBoolean(int field) + { + return (Boolean) value; + } + + @Override + public long getLong(int field) + { + return (Long) value; + } + + @Override + public double getDouble(int field) + { + return (Double) value; + } + + @Override + public Slice getSlice(int field) + { + return (Slice) value; + } + + @Override + public Object getObject(int field) + { + return value; + } + + @Override + public boolean isNull(int field) + { + return value == null; + } + + @Override + public void close() + { + this.value = null; + } + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java index 5b034df2d6165..88ace837f9f74 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java @@ -17,16 +17,25 @@ import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.Page; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.block.ArrayBlock; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.block.BlockBuilder; import com.facebook.presto.spi.block.BlockBuilderStatus; +import com.facebook.presto.spi.block.ColumnarArray; +import com.facebook.presto.spi.block.ColumnarMap; +import com.facebook.presto.spi.block.ColumnarRow; +import com.facebook.presto.spi.block.DictionaryBlock; import com.facebook.presto.spi.block.LazyBlock; import com.facebook.presto.spi.block.LazyBlockLoader; +import com.facebook.presto.spi.block.RowBlock; import com.facebook.presto.spi.block.RunLengthEncodedBlock; import com.facebook.presto.spi.type.DecimalType; +import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeManager; import com.facebook.presto.spi.type.VarcharType; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.joda.time.DateTimeZone; import java.io.IOException; @@ -46,8 +55,13 @@ import static com.facebook.presto.hive.HiveUtil.charPartitionKey; import static com.facebook.presto.hive.HiveUtil.datePartitionKey; import static com.facebook.presto.hive.HiveUtil.doublePartitionKey; +import static com.facebook.presto.hive.HiveUtil.extractStructFieldTypes; import static com.facebook.presto.hive.HiveUtil.floatPartitionKey; import static com.facebook.presto.hive.HiveUtil.integerPartitionKey; +import static com.facebook.presto.hive.HiveUtil.isArrayType; +import static com.facebook.presto.hive.HiveUtil.isHiveNull; +import static com.facebook.presto.hive.HiveUtil.isMapType; +import static com.facebook.presto.hive.HiveUtil.isRowType; import static com.facebook.presto.hive.HiveUtil.longDecimalPartitionKey; import static com.facebook.presto.hive.HiveUtil.shortDecimalPartitionKey; import static com.facebook.presto.hive.HiveUtil.smallintPartitionKey; @@ -55,6 +69,9 @@ import static com.facebook.presto.hive.HiveUtil.tinyintPartitionKey; import static com.facebook.presto.hive.HiveUtil.varcharPartitionKey; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.block.ColumnarArray.toColumnarArray; +import static com.facebook.presto.spi.block.ColumnarMap.toColumnarMap; +import static com.facebook.presto.spi.block.ColumnarRow.toColumnarRow; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.Chars.isCharType; @@ -120,7 +137,7 @@ public HivePageSource( byte[] bytes = columnValue.getBytes(UTF_8); Object prefilledValue; - if (HiveUtil.isHiveNull(bytes)) { + if (isHiveNull(bytes)) { prefilledValue = null; } else if (type.equals(BOOLEAN)) { @@ -287,6 +304,15 @@ else if (fromHiveType.equals(HIVE_INT) && toHiveType.equals(HIVE_LONG)) { else if (fromHiveType.equals(HIVE_FLOAT) && toHiveType.equals(HIVE_DOUBLE)) { return new FloatToDoubleCoercer(); } + else if (isArrayType(fromType) && isArrayType(toType)) { + return new ListCoercer(typeManager, fromHiveType, toHiveType); + } + else if (isMapType(fromType) && isMapType(toType)) { + return new MapCoercer(typeManager, fromHiveType, toHiveType); + } + else if (isRowType(fromType) && isRowType(toType)) { + return new StructCoercer(typeManager, fromHiveType, toHiveType); + } throw new PrestoException(NOT_SUPPORTED, format("Unsupported coercion from %s to %s", fromHiveType, toHiveType)); } @@ -424,6 +450,128 @@ public Block apply(Block block) } } + private static class ListCoercer + implements Function + { + private final Function elementCoercer; + + public ListCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType) + { + requireNonNull(typeManager, "typeManage is null"); + requireNonNull(fromHiveType, "fromHiveType is null"); + requireNonNull(toHiveType, "toHiveType is null"); + HiveType fromElementHiveType = HiveType.valueOf(((ListTypeInfo) fromHiveType.getTypeInfo()).getListElementTypeInfo().getTypeName()); + HiveType toElementHiveType = HiveType.valueOf(((ListTypeInfo) toHiveType.getTypeInfo()).getListElementTypeInfo().getTypeName()); + this.elementCoercer = fromElementHiveType.equals(toElementHiveType) ? null : createCoercer(typeManager, fromElementHiveType, toElementHiveType); + } + + @Override + public Block apply(Block block) + { + if (elementCoercer == null) { + return block; + } + ColumnarArray arrayBlock = toColumnarArray(block); + Block elementsBlock = elementCoercer.apply(arrayBlock.getElementsBlock()); + boolean[] valueIsNull = new boolean[arrayBlock.getPositionCount()]; + int[] offsets = new int[arrayBlock.getPositionCount() + 1]; + for (int i = 0; i < arrayBlock.getPositionCount(); i++) { + valueIsNull[i] = arrayBlock.isNull(i); + offsets[i + 1] = offsets[i] + arrayBlock.getLength(i); + } + return new ArrayBlock(arrayBlock.getPositionCount(), valueIsNull, offsets, elementsBlock); + } + } + + private static class MapCoercer + implements Function + { + private final Type toType; + private final Function keyCoercer; + private final Function valueCoercer; + + public MapCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType) + { + requireNonNull(typeManager, "typeManage is null"); + requireNonNull(fromHiveType, "fromHiveType is null"); + this.toType = requireNonNull(toHiveType, "toHiveType is null").getType(typeManager); + HiveType fromKeyHiveType = HiveType.valueOf(((MapTypeInfo) fromHiveType.getTypeInfo()).getMapKeyTypeInfo().getTypeName()); + HiveType fromValueHiveType = HiveType.valueOf(((MapTypeInfo) fromHiveType.getTypeInfo()).getMapValueTypeInfo().getTypeName()); + HiveType toKeyHiveType = HiveType.valueOf(((MapTypeInfo) toHiveType.getTypeInfo()).getMapKeyTypeInfo().getTypeName()); + HiveType toValueHiveType = HiveType.valueOf(((MapTypeInfo) toHiveType.getTypeInfo()).getMapValueTypeInfo().getTypeName()); + this.keyCoercer = fromKeyHiveType.equals(toKeyHiveType) ? null : createCoercer(typeManager, fromKeyHiveType, toKeyHiveType); + this.valueCoercer = fromValueHiveType.equals(toValueHiveType) ? null : createCoercer(typeManager, fromValueHiveType, toValueHiveType); + } + + @Override + public Block apply(Block block) + { + ColumnarMap mapBlock = toColumnarMap(block); + Block keysBlock = keyCoercer == null ? mapBlock.getKeysBlock() : keyCoercer.apply(mapBlock.getKeysBlock()); + Block valuesBlock = valueCoercer == null ? mapBlock.getValuesBlock() : valueCoercer.apply(mapBlock.getValuesBlock()); + boolean[] valueIsNull = new boolean[mapBlock.getPositionCount()]; + int[] offsets = new int[mapBlock.getPositionCount() + 1]; + for (int i = 0; i < mapBlock.getPositionCount(); i++) { + valueIsNull[i] = mapBlock.isNull(i); + offsets[i + 1] = offsets[i] + mapBlock.getEntryCount(i); + } + return ((MapType) toType).createBlockFromKeyValue(valueIsNull, offsets, keysBlock, valuesBlock); + } + } + + private static class StructCoercer + implements Function + { + private final Function[] coercers; + private final Block[] nullBlocks; + + public StructCoercer(TypeManager typeManager, HiveType fromHiveType, HiveType toHiveType) + { + requireNonNull(typeManager, "typeManage is null"); + requireNonNull(fromHiveType, "fromHiveType is null"); + requireNonNull(toHiveType, "toHiveType is null"); + List fromFieldTypes = extractStructFieldTypes(fromHiveType); + List toFieldTypes = extractStructFieldTypes(toHiveType); + this.coercers = new Function[toFieldTypes.size()]; + this.nullBlocks = new Block[toFieldTypes.size()]; + BlockBuilderStatus blockBuilderStatus = new BlockBuilderStatus(); + for (int i = 0; i < coercers.length; i++) { + if (i >= fromFieldTypes.size()) { + nullBlocks[i] = toFieldTypes.get(i).getType(typeManager).createBlockBuilder(blockBuilderStatus, 1).appendNull().build(); + } + else if (!fromFieldTypes.get(i).equals(toFieldTypes.get(i))) { + coercers[i] = createCoercer(typeManager, fromFieldTypes.get(i), toFieldTypes.get(i)); + } + } + } + + @Override + public Block apply(Block block) + { + ColumnarRow rowBlock = toColumnarRow(block); + Block[] fields = new Block[coercers.length]; + int[] ids = new int[rowBlock.getField(0).getPositionCount()]; + for (int i = 0; i < coercers.length; i++) { + if (coercers[i] != null) { + fields[i] = coercers[i].apply(rowBlock.getField(i)); + } + else if (i < rowBlock.getFieldCount()) { + fields[i] = rowBlock.getField(i); + } + else { + fields[i] = new DictionaryBlock(nullBlocks[i], ids); + } + } + boolean[] valueIsNull = new boolean[rowBlock.getPositionCount()]; + int[] offsets = new int[rowBlock.getPositionCount() + 1]; + for (int i = 0; i < rowBlock.getPositionCount(); i++) { + valueIsNull[i] = rowBlock.isNull(i); + offsets[i + 1] = offsets[i] + (valueIsNull[i] ? 0 : 1); + } + return new RowBlock(0, rowBlock.getPositionCount(), valueIsNull, offsets, fields); + } + } + private final class CoercionLazyBlockLoader implements LazyBlockLoader { @@ -443,6 +591,9 @@ public void load(LazyBlock lazyBlock) return; } + if (block instanceof LazyBlock) { + block = ((LazyBlock) block).getBlock(); + } Block coercedBlock = coercer.apply(block); lazyBlock.setBlock(coercedBlock); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java index 9121ef9d6e3f0..1047acb411f5b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java @@ -47,6 +47,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.InputFormat; import org.apache.hadoop.mapred.JobConf; @@ -107,6 +108,7 @@ import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.filter; import static com.google.common.collect.Lists.newArrayList; import static com.google.common.collect.Lists.transform; @@ -830,4 +832,11 @@ public static void closeWithSuppression(RecordCursor recordCursor, Throwable thr } } } + + public static List extractStructFieldTypes(HiveType hiveType) + { + return ((StructTypeInfo) hiveType.getTypeInfo()).getAllStructFieldTypeInfos().stream() + .map(typeInfo -> HiveType.valueOf(typeInfo.getTypeName())) + .collect(toImmutableList()); + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java index ddddebc2d4aa2..8564d901542f6 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java @@ -75,12 +75,12 @@ import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.MapType; import com.facebook.presto.spi.type.NamedTypeSignature; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.SqlVarbinary; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; -import com.facebook.presto.spi.type.TypeSignatureParameter; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; @@ -148,10 +148,13 @@ import static com.facebook.presto.hive.HiveTableProperties.STORAGE_FORMAT_PROPERTY; import static com.facebook.presto.hive.HiveTestUtils.SESSION; import static com.facebook.presto.hive.HiveTestUtils.TYPE_MANAGER; +import static com.facebook.presto.hive.HiveTestUtils.arrayType; import static com.facebook.presto.hive.HiveTestUtils.getDefaultHiveDataStreamFactories; import static com.facebook.presto.hive.HiveTestUtils.getDefaultHiveFileWriterFactories; import static com.facebook.presto.hive.HiveTestUtils.getDefaultHiveRecordCursorProvider; import static com.facebook.presto.hive.HiveTestUtils.getTypes; +import static com.facebook.presto.hive.HiveTestUtils.mapType; +import static com.facebook.presto.hive.HiveTestUtils.rowType; import static com.facebook.presto.hive.HiveType.HIVE_INT; import static com.facebook.presto.hive.HiveType.HIVE_LONG; import static com.facebook.presto.hive.HiveType.HIVE_STRING; @@ -172,9 +175,6 @@ import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; -import static com.facebook.presto.spi.type.StandardTypes.ARRAY; -import static com.facebook.presto.spi.type.StandardTypes.MAP; -import static com.facebook.presto.spi.type.StandardTypes.ROW; import static com.facebook.presto.spi.type.TimeZoneKey.UTC_KEY; import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; import static com.facebook.presto.spi.type.TinyintType.TINYINT; @@ -229,14 +229,12 @@ public abstract class AbstractTestHiveClient protected static final String TEST_SERVER_VERSION = "test_version"; - private static final Type ARRAY_TYPE = TYPE_MANAGER.getParameterizedType(ARRAY, ImmutableList.of(TypeSignatureParameter.of(createUnboundedVarcharType().getTypeSignature()))); - private static final Type MAP_TYPE = TYPE_MANAGER.getParameterizedType(MAP, ImmutableList.of(TypeSignatureParameter.of(createUnboundedVarcharType().getTypeSignature()), TypeSignatureParameter.of(BIGINT.getTypeSignature()))); - private static final Type ROW_TYPE = TYPE_MANAGER.getParameterizedType( - ROW, - ImmutableList.of( - TypeSignatureParameter.of(new NamedTypeSignature("f_string", createUnboundedVarcharType().getTypeSignature())), - TypeSignatureParameter.of(new NamedTypeSignature("f_bigint", BIGINT.getTypeSignature())), - TypeSignatureParameter.of(new NamedTypeSignature("f_boolean", BOOLEAN.getTypeSignature())))); + private static final Type ARRAY_TYPE = arrayType(createUnboundedVarcharType()); + private static final Type MAP_TYPE = mapType(createUnboundedVarcharType(), BIGINT); + private static final Type ROW_TYPE = rowType(ImmutableList.of( + new NamedTypeSignature("f_string", createUnboundedVarcharType().getTypeSignature()), + new NamedTypeSignature("f_bigint", BIGINT.getTypeSignature()), + new NamedTypeSignature("f_boolean", BOOLEAN.getTypeSignature()))); private static final List CREATE_TABLE_COLUMNS = ImmutableList.builder() .add(new ColumnMetadata("id", BIGINT)) @@ -253,7 +251,35 @@ public abstract class AbstractTestHiveClient .add(new ColumnMetadata("t_row", ROW_TYPE)) .build(); - private static final List MISMATCH_SCHEMA_TABLE_BEFORE = ImmutableList.builder() + private static final MaterializedResult CREATE_TABLE_DATA = + MaterializedResult.resultBuilder(SESSION, BIGINT, createUnboundedVarcharType(), TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, ARRAY_TYPE, MAP_TYPE, ROW_TYPE) + .row(1L, "hello", (byte) 45, (short) 345, 234, 123L, -754.1985f, 43.5, true, ImmutableList.of("apple", "banana"), ImmutableMap.of("one", 1L, "two", 2L), ImmutableList.of("true", 1L, true)) + .row(2L, null, null, null, null, null, null, null, null, null, null, null) + .row(3L, "bye", (byte) 46, (short) 346, 345, 456L, 754.2008f, 98.1, false, ImmutableList.of("ape", "bear"), ImmutableMap.of("three", 3L, "four", 4L), ImmutableList.of("false", 0L, false)) + .build(); + + private static final List CREATE_TABLE_COLUMNS_PARTITIONED = ImmutableList.builder() + .addAll(CREATE_TABLE_COLUMNS) + .add(new ColumnMetadata("ds", createUnboundedVarcharType())) + .build(); + + private static final MaterializedResult CREATE_TABLE_PARTITIONED_DATA = new MaterializedResult( + CREATE_TABLE_DATA.getMaterializedRows().stream() + .map(row -> new MaterializedRow(row.getPrecision(), newArrayList(concat(row.getFields(), ImmutableList.of("2015-07-0" + row.getField(0)))))) + .collect(toList()), + ImmutableList.builder() + .addAll(CREATE_TABLE_DATA.getTypes()) + .add(createUnboundedVarcharType()) + .build()); + + private static final MaterializedResult CREATE_TABLE_PARTITIONED_DATA_2ND = + MaterializedResult.resultBuilder(SESSION, BIGINT, createUnboundedVarcharType(), TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, ARRAY_TYPE, MAP_TYPE, ROW_TYPE, createUnboundedVarcharType()) + .row(4L, "hello", (byte) 45, (short) 345, 234, 123L, 754.1985f, 43.5, true, ImmutableList.of("apple", "banana"), ImmutableMap.of("one", 1L, "two", 2L), ImmutableList.of("true", 1L, true), "2015-07-04") + .row(5L, null, null, null, null, null, null, null, null, null, null, null, "2015-07-04") + .row(6L, "bye", (byte) 46, (short) 346, 345, 456L, -754.2008f, 98.1, false, ImmutableList.of("ape", "bear"), ImmutableMap.of("three", 3L, "four", 4L), ImmutableList.of("false", 0L, false), "2015-07-04") + .build(); + + private static final List MISMATCH_SCHEMA_PRIMITIVE_COLUMN_BEFORE = ImmutableList.builder() .add(new ColumnMetadata("tinyint_to_smallint", TINYINT)) .add(new ColumnMetadata("tinyint_to_integer", TINYINT)) .add(new ColumnMetadata("tinyint_to_bigint", TINYINT)) @@ -263,10 +289,48 @@ public abstract class AbstractTestHiveClient .add(new ColumnMetadata("integer_to_varchar", INTEGER)) .add(new ColumnMetadata("varchar_to_integer", createUnboundedVarcharType())) .add(new ColumnMetadata("float_to_double", REAL)) + .add(new ColumnMetadata("varchar_to_drop_in_row", createUnboundedVarcharType())) + .build(); + + private static final List MISMATCH_SCHEMA_TABLE_BEFORE = ImmutableList.builder() + .addAll(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_BEFORE) + .add(new ColumnMetadata("struct_to_struct", toRowType(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_BEFORE))) + .add(new ColumnMetadata("list_to_list", arrayType(toRowType(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_BEFORE)))) + .add(new ColumnMetadata("map_to_map", mapType(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_BEFORE.get(1).getType(), toRowType(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_BEFORE)))) .add(new ColumnMetadata("ds", createUnboundedVarcharType())) .build(); - private static final List MISMATCH_SCHEMA_TABLE_AFTER = ImmutableList.builder() + private static RowType toRowType(List columns) + { + return rowType(columns.stream() + .map(col -> new NamedTypeSignature(format("f_%s", col.getName()), col.getType().getTypeSignature())) + .collect(toList())); + } + + private static final MaterializedResult MISMATCH_SCHEMA_PRIMITIVE_FIELDS_DATA_BEFORE = + MaterializedResult.resultBuilder(SESSION, TINYINT, TINYINT, TINYINT, SMALLINT, SMALLINT, INTEGER, INTEGER, createUnboundedVarcharType(), REAL, createUnboundedVarcharType()) + .row((byte) -11, (byte) 12, (byte) -13, (short) 14, (short) 15, -16, 17, "2147483647", 18.0f, "2016-08-01") + .row((byte) 21, (byte) -22, (byte) 23, (short) -24, (short) 25, 26, -27, "asdf", -28.0f, "2016-08-02") + .row((byte) -31, (byte) -32, (byte) 33, (short) 34, (short) -35, 36, 37, "-923", 39.5f, "2016-08-03") + .row(null, (byte) 42, (byte) 43, (short) 44, (short) -45, 46, 47, "2147483648", 49.5f, "2016-08-03") + .build(); + + private static final MaterializedResult MISMATCH_SCHEMA_TABLE_DATA_BEFORE = + MaterializedResult.resultBuilder(SESSION, MISMATCH_SCHEMA_TABLE_BEFORE.stream().map(ColumnMetadata::getType).collect(toList())) + .rows(MISMATCH_SCHEMA_PRIMITIVE_FIELDS_DATA_BEFORE.getMaterializedRows() + .stream() + .map(materializedRow -> { + List result = materializedRow.getFields(); + List rowResult = materializedRow.getFields(); + result.add(rowResult); + result.add(Arrays.asList(rowResult, null, rowResult)); + result.add(ImmutableMap.of(rowResult.get(1), rowResult)); + result.add(rowResult.get(9)); + return new MaterializedRow(materializedRow.getPrecision(), result); + }).collect(toList())) + .build(); + + private static final List MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER = ImmutableList.builder() .add(new ColumnMetadata("tinyint_to_smallint", SMALLINT)) .add(new ColumnMetadata("tinyint_to_integer", INTEGER)) .add(new ColumnMetadata("tinyint_to_bigint", BIGINT)) @@ -276,25 +340,24 @@ public abstract class AbstractTestHiveClient .add(new ColumnMetadata("integer_to_varchar", createUnboundedVarcharType())) .add(new ColumnMetadata("varchar_to_integer", INTEGER)) .add(new ColumnMetadata("float_to_double", DOUBLE)) - .add(new ColumnMetadata("ds", createUnboundedVarcharType())) + .add(new ColumnMetadata("varchar_to_drop_in_row", createUnboundedVarcharType())) .build(); - private static final MaterializedResult CREATE_TABLE_DATA = - MaterializedResult.resultBuilder(SESSION, BIGINT, createUnboundedVarcharType(), TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, ARRAY_TYPE, MAP_TYPE, ROW_TYPE) - .row(1L, "hello", (byte) 45, (short) 345, 234, 123L, -754.1985f, 43.5, true, ImmutableList.of("apple", "banana"), ImmutableMap.of("one", 1L, "two", 2L), ImmutableList.of("true", 1L, true)) - .row(2L, null, null, null, null, null, null, null, null, null, null, null) - .row(3L, "bye", (byte) 46, (short) 346, 345, 456L, 754.2008f, 98.1, false, ImmutableList.of("ape", "bear"), ImmutableMap.of("three", 3L, "four", 4L), ImmutableList.of("false", 0L, false)) - .build(); + private static final Type MISMATCH_SCHEMA_ROW_TYPE_APPEND = toRowType(ImmutableList.builder() + .addAll(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER) + .add(new ColumnMetadata(format("%s_append", MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER.get(0).getName()), MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER.get(0).getType())) + .build()); + private static final Type MISMATCH_SCHEMA_ROW_TYPE_DROP = toRowType(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER.subList(0, MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER.size() - 1)); - private static final MaterializedResult MISMATCH_SCHEMA_TABLE_DATA_BEFORE = - MaterializedResult.resultBuilder(SESSION, TINYINT, TINYINT, TINYINT, SMALLINT, SMALLINT, INTEGER, INTEGER, createUnboundedVarcharType(), REAL, createUnboundedVarcharType()) - .row((byte) -11, (byte) 12, (byte) -13, (short) 14, (short) 15, -16, 17, "2147483647", 18.0f, "2016-08-01") - .row((byte) 21, (byte) -22, (byte) 23, (short) -24, (short) 25, 26, -27, "asdf", -28.0f, "2016-08-02") - .row((byte) -31, (byte) -32, (byte) 33, (short) 34, (short) -35, 36, 37, "-923", 39.5f, "2016-08-03") - .row(null, (byte) 42, (byte) 43, (short) 44, (short) -45, 46, 47, "2147483648", 49.5f, "2016-08-03") - .build(); + private static final List MISMATCH_SCHEMA_TABLE_AFTER = ImmutableList.builder() + .addAll(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER) + .add(new ColumnMetadata("struct_to_struct", MISMATCH_SCHEMA_ROW_TYPE_APPEND)) + .add(new ColumnMetadata("list_to_list", arrayType(MISMATCH_SCHEMA_ROW_TYPE_APPEND))) + .add(new ColumnMetadata("map_to_map", mapType(MISMATCH_SCHEMA_PRIMITIVE_COLUMN_AFTER.get(1).getType(), MISMATCH_SCHEMA_ROW_TYPE_DROP))) + .add(new ColumnMetadata("ds", createUnboundedVarcharType())) + .build(); - private static final MaterializedResult MISMATCH_SCHEMA_TABLE_DATA_AFTER = + private static final MaterializedResult MISMATCH_SCHEMA_PRIMITIVE_FIELDS_DATA_AFTER = MaterializedResult.resultBuilder(SESSION, SMALLINT, INTEGER, BIGINT, INTEGER, BIGINT, BIGINT, createUnboundedVarcharType(), INTEGER, DOUBLE, createUnboundedVarcharType()) .row((short) -11, 12, -13L, 14, 15L, -16L, "17", 2147483647, 18.0, "2016-08-01") .row((short) 21, -22, 23L, -24, 25L, 26L, "-27", null, -28.0, "2016-08-02") @@ -302,25 +365,21 @@ public abstract class AbstractTestHiveClient .row(null, 42, 43L, 44, -45L, 46L, "47", null, 49.5, "2016-08-03") .build(); - private static final List CREATE_TABLE_COLUMNS_PARTITIONED = ImmutableList.builder() - .addAll(CREATE_TABLE_COLUMNS) - .add(new ColumnMetadata("ds", createUnboundedVarcharType())) - .build(); - - private static final MaterializedResult CREATE_TABLE_PARTITIONED_DATA = new MaterializedResult( - CREATE_TABLE_DATA.getMaterializedRows().stream() - .map(row -> new MaterializedRow(row.getPrecision(), newArrayList(concat(row.getFields(), ImmutableList.of("2015-07-0" + row.getField(0)))))) - .collect(toList()), - ImmutableList.builder() - .addAll(CREATE_TABLE_DATA.getTypes()) - .add(createUnboundedVarcharType()) - .build()); - - private static final MaterializedResult CREATE_TABLE_PARTITIONED_DATA_2ND = - MaterializedResult.resultBuilder(SESSION, BIGINT, createUnboundedVarcharType(), TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, ARRAY_TYPE, MAP_TYPE, ROW_TYPE, createUnboundedVarcharType()) - .row(4L, "hello", (byte) 45, (short) 345, 234, 123L, 754.1985f, 43.5, true, ImmutableList.of("apple", "banana"), ImmutableMap.of("one", 1L, "two", 2L), ImmutableList.of("true", 1L, true), "2015-07-04") - .row(5L, null, null, null, null, null, null, null, null, null, null, null, "2015-07-04") - .row(6L, "bye", (byte) 46, (short) 346, 345, 456L, -754.2008f, 98.1, false, ImmutableList.of("ape", "bear"), ImmutableMap.of("three", 3L, "four", 4L), ImmutableList.of("false", 0L, false), "2015-07-04") + private static final MaterializedResult MISMATCH_SCHEMA_TABLE_DATA_AFTER = + MaterializedResult.resultBuilder(SESSION, MISMATCH_SCHEMA_TABLE_AFTER.stream().map(ColumnMetadata::getType).collect(toList())) + .rows(MISMATCH_SCHEMA_PRIMITIVE_FIELDS_DATA_AFTER.getMaterializedRows() + .stream() + .map(materializedRow -> { + List result = materializedRow.getFields(); + List appendFieldRowResult = materializedRow.getFields(); + appendFieldRowResult.add(null); + List dropFieldRowResult = materializedRow.getFields().subList(0, materializedRow.getFields().size() - 1); + result.add(appendFieldRowResult); + result.add(Arrays.asList(appendFieldRowResult, null, appendFieldRowResult)); + result.add(ImmutableMap.of(result.get(1), dropFieldRowResult)); + result.add(result.get(9)); + return new MaterializedRow(materializedRow.getPrecision(), result); + }).collect(toList())) .build(); protected Set createTableFormats = difference(ImmutableSet.copyOf(HiveStorageFormat.values()), ImmutableSet.of(AVRO)); @@ -2926,6 +2985,21 @@ else if (rowNumber % 39 == 1) { } } + // STRUCT + index = columnIndex.get("t_struct"); + if (index != null) { + if ((rowNumber % 31) == 0) { + assertNull(row.getField(index)); + } + else { + assertTrue(row.getField(index) instanceof List); + List values = (List) row.getField(index); + assertEquals(values.size(), 2); + assertEquals(values.get(0), "test abc"); + assertEquals(values.get(1), 0.1); + } + } + // MAP>> index = columnIndex.get("t_complex"); if (index != null) { @@ -3154,7 +3228,7 @@ else if (TIMESTAMP.equals(column.getType())) { else if (DATE.equals(column.getType())) { assertInstanceOf(value, SqlDate.class); } - else if (column.getType() instanceof ArrayType) { + else if (column.getType() instanceof ArrayType || column.getType() instanceof RowType) { assertInstanceOf(value, List.class); } else if (column.getType() instanceof MapType) { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java index 8489bbb89ff54..ff4e9f203c2cd 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java @@ -26,7 +26,10 @@ import com.facebook.presto.metadata.FunctionRegistry; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.type.ArrayType; import com.facebook.presto.spi.type.MapType; +import com.facebook.presto.spi.type.NamedTypeSignature; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.spi.type.StandardTypes; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.TypeSignatureParameter; @@ -39,6 +42,8 @@ import java.util.List; import java.util.Set; +import static java.util.stream.Collectors.toList; + public final class HiveTestUtils { private HiveTestUtils() @@ -119,4 +124,20 @@ public static MapType mapType(Type keyType, Type valueType) TypeSignatureParameter.of(keyType.getTypeSignature()), TypeSignatureParameter.of(valueType.getTypeSignature()))); } + + public static ArrayType arrayType(Type elementType) + { + return (ArrayType) TYPE_MANAGER.getParameterizedType( + StandardTypes.ARRAY, + ImmutableList.of(TypeSignatureParameter.of(elementType.getTypeSignature()))); + } + + public static RowType rowType(List elementTypeSignatures) + { + return (RowType) TYPE_MANAGER.getParameterizedType( + StandardTypes.ROW, + ImmutableList.copyOf(elementTypeSignatures.stream() + .map(TypeSignatureParameter::of) + .collect(toList()))); + } } diff --git a/presto-hive/src/test/sql/create-test-hive13.sql b/presto-hive/src/test/sql/create-test-hive13.sql index 0d63bf67ce1f9..54747332d993b 100644 --- a/presto-hive/src/test/sql/create-test-hive13.sql +++ b/presto-hive/src/test/sql/create-test-hive13.sql @@ -15,6 +15,7 @@ CREATE TABLE presto_test_types_textfile ( , t_map MAP , t_array_string ARRAY , t_array_struct ARRAY> +, t_struct STRUCT , t_complex MAP>> ) STORED AS TEXTFILE @@ -40,6 +41,8 @@ SELECT , CASE WHEN n % 31 = 0 THEN NULL ELSE array(named_struct('s_string', 'test abc', 's_double', 0.1), named_struct('s_string' , 'test xyz', 's_double', 0.2)) END +, CASE WHEN n % 31 = 0 THEN NULL ELSE + named_struct('s_string', 'test abc', 's_double', 0.1) END , CASE WHEN n % 33 = 0 THEN NULL ELSE map(1, array(named_struct('s_string', 'test abc', 's_double', 0.1), named_struct('s_string' , 'test xyz', 's_double', 0.2))) END @@ -82,7 +85,6 @@ SELECT * FROM presto_test_types_textfile ; --- Parquet fails when trying to use complex nested types. -- Parquet is missing TIMESTAMP and BINARY. CREATE TABLE presto_test_types_parquet ( t_string STRING @@ -99,6 +101,8 @@ CREATE TABLE presto_test_types_parquet ( , t_map MAP , t_array_string ARRAY , t_array_struct ARRAY> +, t_struct STRUCT +, t_complex MAP>> ) STORED AS PARQUET ; @@ -119,6 +123,8 @@ SELECT , t_map , t_array_string , t_array_struct +, t_struct +, t_complex FROM presto_test_types_textfile ; diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java index f4fea9659d954..d99b97e110d16 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/hive/TestHiveCoercion.java @@ -13,10 +13,14 @@ */ package com.facebook.presto.tests.hive; +import com.facebook.presto.jdbc.PrestoArray; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.prestodb.tempto.ProductTest; import io.prestodb.tempto.Requirement; import io.prestodb.tempto.RequirementsProvider; import io.prestodb.tempto.Requires; +import io.prestodb.tempto.assertions.QueryAssert.Row; import io.prestodb.tempto.configuration.Configuration; import io.prestodb.tempto.fulfillment.table.MutableTableRequirement; import io.prestodb.tempto.fulfillment.table.MutableTablesState; @@ -29,7 +33,11 @@ import org.testng.annotations.Test; import java.sql.Connection; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Optional; import static com.facebook.presto.tests.TestGroups.HIVE_COERCION; @@ -37,6 +45,7 @@ import static com.facebook.presto.tests.TestGroups.JDBC; import static com.facebook.presto.tests.utils.JdbcDriverUtils.usingPrestoJdbcDriver; import static com.facebook.presto.tests.utils.JdbcDriverUtils.usingTeradataJdbcDriver; +import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static io.prestodb.tempto.assertions.QueryAssert.Row.row; import static io.prestodb.tempto.assertions.QueryAssert.assertThat; import static io.prestodb.tempto.context.ThreadLocalTestContextHolder.testContext; @@ -45,12 +54,16 @@ import static io.prestodb.tempto.query.QueryExecutor.defaultQueryExecutor; import static io.prestodb.tempto.query.QueryExecutor.query; import static java.lang.String.format; +import static java.sql.JDBCType.ARRAY; import static java.sql.JDBCType.BIGINT; import static java.sql.JDBCType.DOUBLE; import static java.sql.JDBCType.INTEGER; +import static java.sql.JDBCType.JAVA_OBJECT; import static java.sql.JDBCType.LONGNVARCHAR; import static java.sql.JDBCType.SMALLINT; import static java.sql.JDBCType.VARCHAR; +import static java.util.stream.Collectors.toList; +import static org.testng.Assert.assertEquals; public class TestHiveCoercion extends ProductTest @@ -61,7 +74,7 @@ public class TestHiveCoercion .setNoData() .build(); - public static final HiveTableDefinition HIVE_COERCION_PARQUET = parquetTableDefinitionBuilder() + public static final HiveTableDefinition HIVE_COERCION_PARQUET = tableDefinitionBuilder("PARQUET", Optional.empty(), Optional.empty()) .setNoData() .build(); @@ -84,6 +97,7 @@ public class TestHiveCoercion private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBuilder(String fileFormat, Optional recommendTableName, Optional rowFormat) { String tableName = format(tableNameFormat, recommendTableName.orElse(fileFormat).toLowerCase(Locale.ENGLISH)); + String floatToDoubleType = fileFormat.toLowerCase(Locale.ENGLISH).contains("parquet") ? "DOUBLE" : "FLOAT"; return HiveTableDefinition.builder(tableName) .setCreateTableDDLTemplate("" + "CREATE TABLE %NAME%(" + @@ -94,31 +108,17 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBui " smallint_to_bigint SMALLINT," + " int_to_bigint INT," + " bigint_to_varchar BIGINT," + - " float_to_double FLOAT" + + " float_to_double " + floatToDoubleType + "," + + // all nested primitive/varchar coercions and adding/removing tailing nested fields are covered across row_to_row, list_to_list, and map_to_map + " row_to_row STRUCT," + + " list_to_list ARRAY>," + + " map_to_map MAP>" + ") " + "PARTITIONED BY (id BIGINT) " + (rowFormat.isPresent() ? "ROW FORMAT " + rowFormat.get() + " " : " ") + "STORED AS " + fileFormat); } - private static HiveTableDefinition.HiveTableDefinitionBuilder parquetTableDefinitionBuilder() - { - return HiveTableDefinition.builder("parquet_hive_coercion") - .setCreateTableDDLTemplate("" + - "CREATE TABLE %NAME%(" + - " tinyint_to_smallint TINYINT," + - " tinyint_to_int TINYINT," + - " tinyint_to_bigint TINYINT," + - " smallint_to_int SMALLINT," + - " smallint_to_bigint SMALLINT," + - " int_to_bigint INT," + - " bigint_to_varchar BIGINT," + - " float_to_double DOUBLE" + - ") " + - "PARTITIONED BY (id BIGINT) " + - "STORED AS PARQUET"); - } - private static HiveTableDefinition.HiveTableDefinitionBuilder avroTableDefinitionBuilder() { return HiveTableDefinition.builder("avro_hive_coercion") @@ -259,20 +259,36 @@ public void testHiveCoercionAvro() private void doTestHiveCoercion(HiveTableDefinition tableDefinition) { String tableName = mutableTableInstanceOf(tableDefinition).getNameInDatabase(); + String floatToDoubleType = tableName.toLowerCase(Locale.ENGLISH).contains("parquet") ? "DOUBLE" : "FLOAT"; + executeHiveQuery(format("DROP TABLE IF EXISTS %s_dummy", tableName)); + executeHiveQuery(format("CREATE TABLE %s_dummy (id BIGINT)", tableName)); + executeHiveQuery(format("INSERT INTO TABLE %s_dummy (id) VALUES (1)", tableName)); executeHiveQuery(format("INSERT INTO TABLE %s " + "PARTITION (id=1) " + - "VALUES" + - "(-1, 2, -3, 100, -101, 2323, 12345, 0.5)," + - "(1, -2, null, -100, 101, -2323, -12345, -1.5)", - tableName)); + "SELECT " + + " -1, 2, -3, 100, -101, 2323, 12345, 0.5," + + " named_struct('keep', 'as is', 'ti2si', CAST(-1 as TINYINT), 'si2int', CAST(100 as SMALLINT), 'int2bi', 2323, 'bi2vc', CAST(12345 as BIGINT))," + + " array(named_struct('ti2int', CAST(2 as TINYINT), 'si2bi', CAST(-101 as SMALLINT), 'bi2vc', CAST(12345 as BIGINT), 'remove', 'gone'))," + + " map(CAST(2 as TINYINT), named_struct('ti2bi', CAST(-3 as TINYINT), 'int2bi', 2323, 'float2double', CAST(0.5 as %s)))" + + "FROM %s_dummy limit 1", + tableName, floatToDoubleType, tableName)); + executeHiveQuery(format("INSERT INTO TABLE %s " + + "PARTITION (id=1) " + + "SELECT " + + " 1, -2, null, -100, 101, -2323, -12345, -1.5," + + " named_struct('keep', CAST(null as STRING), 'ti2si', CAST(1 as TINYINT), 'si2int', CAST(-100 as SMALLINT), 'int2bi', -2323, 'bi2vc', CAST(-12345 as BIGINT))," + + " array(named_struct('ti2int', CAST(-2 as TINYINT), 'si2bi', CAST(101 as SMALLINT), 'bi2vc', CAST(-12345 as BIGINT), 'remove', CAST(null as STRING)))," + + " map(CAST(2 as TINYINT), named_struct('ti2bi', CAST(null as TINYINT), 'int2bi', -2323, 'float2double', CAST(-1.5 as %s)))" + + "FROM %s_dummy limit 1", + tableName, floatToDoubleType, tableName)); alterTableColumnTypes(tableName); assertProperAlteredTableSchema(tableName); QueryResult queryResult = query(format("SELECT * FROM %s", tableName)); assertColumnTypes(queryResult); - assertThat(queryResult).containsOnly( + List expectedRows = ImmutableList.of( row( -1, 2, @@ -282,6 +298,9 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition) 2323L, "12345", 0.5, + namedStruct("keep", "as is", "ti2si", (short) -1, "si2int", 100, "int2bi", 2323L, "bi2vc", "12345"), + ImmutableList.of(namedStruct("ti2int", 2, "si2bi", -101L, "bi2vc", "12345")), + ImmutableMap.of(2, namedStruct("ti2bi", -3L, "int2bi", 2323L, "float2double", 0.5, "add", null)), 1), row( 1, @@ -292,7 +311,16 @@ private void doTestHiveCoercion(HiveTableDefinition tableDefinition) -2323L, "-12345", -1.5, + namedStruct("keep", null, "ti2si", (short) 1, "si2int", -100, "int2bi", -2323L, "bi2vc", "-12345"), + ImmutableList.of(namedStruct("ti2int", -2, "si2bi", 101L, "bi2vc", "-12345")), + ImmutableMap.of(2, namedStruct("ti2bi", null, "int2bi", -2323L, "float2double", -1.5, "add", null)), 1)); + // test primitive values + assertThat(queryResult.project(1, 2, 3, 4, 5, 6, 7, 8, 12)).containsOnly(project(expectedRows, 1, 2, 3, 4, 5, 6, 7, 8, 12)); + // test structural values + assertEqualsIgnoreOrder(queryResult.column(9), column(expectedRows, 9), "row_to_row field is not equal"); + assertEqualsIgnoreOrder(extract(queryResult.column(10)), column(expectedRows, 10), "list_to_list field is not equal"); + assertEqualsIgnoreOrder(queryResult.column(11), column(expectedRows, 11), "map_to_map field is not equal"); } private void assertProperAlteredTableSchema(String tableName) @@ -306,6 +334,9 @@ private void assertProperAlteredTableSchema(String tableName) row("int_to_bigint", "bigint"), row("bigint_to_varchar", "varchar"), row("float_to_double", "double"), + row("row_to_row", "row(keep varchar, ti2si smallint, si2int integer, int2bi bigint, bi2vc varchar)"), + row("list_to_list", "array(row(ti2int integer, si2bi bigint, bi2vc varchar))"), + row("map_to_map", "map(integer, row(ti2bi bigint, int2bi bigint, float2double double, add tinyint))"), row("id", "bigint")); } @@ -322,6 +353,9 @@ private void assertColumnTypes(QueryResult queryResult) BIGINT, LONGNVARCHAR, DOUBLE, + JAVA_OBJECT, + ARRAY, + JAVA_OBJECT, BIGINT); } else if (usingTeradataJdbcDriver(connection)) { @@ -334,6 +368,9 @@ else if (usingTeradataJdbcDriver(connection)) { BIGINT, VARCHAR, DOUBLE, + JAVA_OBJECT, + ARRAY, + JAVA_OBJECT, BIGINT); } else { @@ -351,6 +388,9 @@ private static void alterTableColumnTypes(String tableName) executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN int_to_bigint int_to_bigint bigint", tableName)); executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN bigint_to_varchar bigint_to_varchar string", tableName)); executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN float_to_double float_to_double double", tableName)); + executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN row_to_row row_to_row struct", tableName)); + executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN list_to_list list_to_list array>", tableName)); + executeHiveQuery(format("ALTER TABLE %s CHANGE COLUMN map_to_map map_to_map map>", tableName)); } private static TableInstance mutableTableInstanceOf(TableDefinition tableDefinition) @@ -386,4 +426,44 @@ private static QueryResult executeHiveQuery(String query) { return testContext().getDependency(QueryExecutor.class, "hive").executeQuery(query); } + + private static Map namedStruct(Object... objects) + { + assertEquals(objects.length % 2, 0, "number of objects must be even"); + Map struct = new HashMap<>(); + for (int i = 0; i < objects.length; i += 2) { + struct.put(objects[i], objects[i + 1]); + } + return struct; + } + + private static Object[] project(Row row, int... columns) + { + Object[] values = new Object[columns.length]; + for (int i = 0; i < columns.length; i++) { + values[i] = (row.getValues().get(columns[i] - 1)); + } + return values; + } + + private static List project(List rows, int... columns) + { + return rows.stream() + .map(row -> Row.row(project(row, columns))) + .collect(ImmutableList.toImmutableList()); + } + + private static List> extract(List arrays) + { + return arrays.stream() + .map(obj -> Arrays.asList(Object[].class.cast(PrestoArray.class.cast(obj).getArray()))) + .collect(toList()); + } + + private static List column(List rows, int sqlColumnIndex) + { + return rows.stream() + .map(row -> project(row, sqlColumnIndex)[0]) + .collect(toList()); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarArray.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarArray.java index 20308110d3008..51b0150085bb2 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarArray.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarArray.java @@ -34,7 +34,7 @@ public static ColumnarArray toColumnarArray(Block block) } if (!(block instanceof AbstractArrayBlock)) { - throw new IllegalArgumentException("Invalid array block"); + throw new IllegalArgumentException("Invalid array block: " + block.getClass().getName()); } AbstractArrayBlock arrayBlock = (AbstractArrayBlock) block; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarMap.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarMap.java index 01440cb4ea564..e1e34e00ab9f1 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarMap.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarMap.java @@ -35,7 +35,7 @@ public static ColumnarMap toColumnarMap(Block block) } if (!(block instanceof AbstractMapBlock)) { - throw new IllegalArgumentException("Invalid map block"); + throw new IllegalArgumentException("Invalid map block: " + block.getClass().getName()); } AbstractMapBlock mapBlock = (AbstractMapBlock) block; diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarRow.java b/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarRow.java index 9ead281fb5c27..7ad5d1d55e507 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarRow.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/block/ColumnarRow.java @@ -32,7 +32,7 @@ public static ColumnarRow toColumnarRow(Block block) } if (!(block instanceof AbstractRowBlock)) { - throw new IllegalArgumentException("Invalid row block"); + throw new IllegalArgumentException("Invalid row block: " + block.getClass().getName()); } AbstractRowBlock rowBlock = (AbstractRowBlock) block;