diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java index ccea5d6529c5..01fc2dd11a8f 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java +++ b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetReaders.java @@ -19,64 +19,716 @@ package org.apache.iceberg.flink.data; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; import java.util.List; -import org.apache.flink.types.Row; +import java.util.Map; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.TimestampData; import org.apache.iceberg.Schema; -import org.apache.iceberg.data.parquet.BaseParquetReaders; +import org.apache.iceberg.parquet.ParquetSchemaUtil; import org.apache.iceberg.parquet.ParquetValueReader; import org.apache.iceberg.parquet.ParquetValueReaders; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; -public class FlinkParquetReaders extends BaseParquetReaders { +public class FlinkParquetReaders { + private FlinkParquetReaders() { + } - private static final FlinkParquetReaders INSTANCE = new FlinkParquetReaders(); + public static ParquetValueReader buildReader(Schema expectedSchema, MessageType fileSchema) { + return buildReader(expectedSchema, fileSchema, ImmutableMap.of()); + } - private FlinkParquetReaders() { + @SuppressWarnings("unchecked") + public static ParquetValueReader buildReader(Schema expectedSchema, + MessageType fileSchema, + Map idToConstant) { + if (ParquetSchemaUtil.hasIds(fileSchema)) { + return (ParquetValueReader) + TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, + new ReadBuilder(fileSchema, idToConstant)); + } else { + return (ParquetValueReader) + TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, + new FallbackReadBuilder(fileSchema, idToConstant)); + } + } + + private static class FallbackReadBuilder extends ReadBuilder { + FallbackReadBuilder(MessageType type, Map idToConstant) { + super(type, idToConstant); + } + + @Override + public ParquetValueReader message(Types.StructType expected, MessageType message, + List> fieldReaders) { + // the top level matches by ID, but the remaining IDs are missing + return super.struct(expected, message, fieldReaders); + } + + @Override + public ParquetValueReader struct(Types.StructType ignored, GroupType struct, + List> fieldReaders) { + // the expected struct is ignored because nested 
fields are never found when the + List> newFields = Lists.newArrayListWithExpectedSize( + fieldReaders.size()); + List types = Lists.newArrayListWithExpectedSize(fieldReaders.size()); + List fields = struct.getFields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i); + int fieldD = type().getMaxDefinitionLevel(path(fieldType.getName())) - 1; + newFields.add(ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i))); + types.add(fieldType); + } + + return new RowDataReader(types, newFields); + } + } + + private static class ReadBuilder extends TypeWithSchemaVisitor> { + private final MessageType type; + private final Map idToConstant; + + ReadBuilder(MessageType type, Map idToConstant) { + this.type = type; + this.idToConstant = idToConstant; + } + + @Override + public ParquetValueReader message(Types.StructType expected, MessageType message, + List> fieldReaders) { + return struct(expected, message.asGroupType(), fieldReaders); + } + + @Override + public ParquetValueReader struct(Types.StructType expected, GroupType struct, + List> fieldReaders) { + // match the expected struct's order + Map> readersById = Maps.newHashMap(); + Map typesById = Maps.newHashMap(); + List fields = struct.getFields(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i); + int fieldD = type.getMaxDefinitionLevel(path(fieldType.getName())) - 1; + if (fieldType.getId() != null) { + int id = fieldType.getId().intValue(); + readersById.put(id, ParquetValueReaders.option(fieldType, fieldD, fieldReaders.get(i))); + typesById.put(id, fieldType); + } + } + + List expectedFields = expected != null ? + expected.fields() : ImmutableList.of(); + List> reorderedFields = Lists.newArrayListWithExpectedSize( + expectedFields.size()); + List types = Lists.newArrayListWithExpectedSize(expectedFields.size()); + for (Types.NestedField field : expectedFields) { + int id = field.fieldId(); + if (idToConstant.containsKey(id)) { + // containsKey is used because the constant may be null + reorderedFields.add(ParquetValueReaders.constant(idToConstant.get(id))); + types.add(null); + } else { + ParquetValueReader reader = readersById.get(id); + if (reader != null) { + reorderedFields.add(reader); + types.add(typesById.get(id)); + } else { + reorderedFields.add(ParquetValueReaders.nulls()); + types.add(null); + } + } + } + + return new RowDataReader(types, reorderedFields); + } + + @Override + public ParquetValueReader list(Types.ListType expectedList, GroupType array, + ParquetValueReader elementReader) { + GroupType repeated = array.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type elementType = repeated.getType(0); + int elementD = type.getMaxDefinitionLevel(path(elementType.getName())) - 1; + + return new ArrayReader<>(repeatedD, repeatedR, ParquetValueReaders.option(elementType, elementD, elementReader)); + } + + @Override + public ParquetValueReader map(Types.MapType expectedMap, GroupType map, + ParquetValueReader keyReader, + ParquetValueReader valueReader) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath) - 1; + int repeatedR = type.getMaxRepetitionLevel(repeatedPath) - 1; + + Type keyType = repeatedKeyValue.getType(0); + int keyD = 
type.getMaxDefinitionLevel(path(keyType.getName())) - 1; + Type valueType = repeatedKeyValue.getType(1); + int valueD = type.getMaxDefinitionLevel(path(valueType.getName())) - 1; + + return new MapReader<>(repeatedD, repeatedR, + ParquetValueReaders.option(keyType, keyD, keyReader), + ParquetValueReaders.option(valueType, valueD, valueReader)); + } + + @Override + public ParquetValueReader primitive(org.apache.iceberg.types.Type.PrimitiveType expected, + PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return new StringReader(desc); + case INT_8: + case INT_16: + case INT_32: + case TIME_MICROS: + case DATE: + if (expected != null && expected.typeId() == Types.LongType.get().typeId()) { + return new ParquetValueReaders.IntAsLongReader(desc); + } else { + return new ParquetValueReaders.UnboxedReader<>(desc); + } + case INT_64: + return new ParquetValueReaders.UnboxedReader<>(desc); + case TIMESTAMP_MICROS: + return new TimestampMicroReader(desc); + case DECIMAL: + DecimalLogicalTypeAnnotation decimal = (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation(); + switch (primitive.getPrimitiveTypeName()) { + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return new BinaryDecimalReader(desc, decimal.getScale()); + case INT64: + return new LongDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + case INT32: + return new IntegerDecimalReader(desc, decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return new BytesReader(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return new BytesReader(desc); + case INT32: + if (expected != null && expected.typeId() == org.apache.iceberg.types.Type.TypeID.LONG) { + return new ParquetValueReaders.IntAsLongReader(desc); + } else { + return new ParquetValueReaders.UnboxedReader<>(desc); + } + case FLOAT: + if (expected != null && expected.typeId() == org.apache.iceberg.types.Type.TypeID.DOUBLE) { + return new ParquetValueReaders.FloatAsDoubleReader(desc); + } else { + return new ParquetValueReaders.UnboxedReader<>(desc); + } + case BOOLEAN: + case INT64: + case DOUBLE: + return new ParquetValueReaders.UnboxedReader<>(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + + protected MessageType type() { + return type; + } + } + + private static class BinaryDecimalReader extends ParquetValueReaders.PrimitiveReader { + private final int scale; + + BinaryDecimalReader(ColumnDescriptor desc, int scale) { + super(desc); + this.scale = scale; + } + + @Override + public DecimalData read(DecimalData ignored) { + Binary binary = column.nextBinary(); + BigDecimal bigDecimal = new BigDecimal(new BigInteger(binary.getBytes()), scale); + return DecimalData.fromBigDecimal(bigDecimal, bigDecimal.precision(), scale); + } + } + + private static class IntegerDecimalReader extends ParquetValueReaders.PrimitiveReader { + private final int precision; + private final int scale; + + IntegerDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + 
+ @Override + public DecimalData read(DecimalData ignored) { + return DecimalData.fromUnscaledLong(column.nextInteger(), precision, scale); + } + } + + private static class LongDecimalReader extends ParquetValueReaders.PrimitiveReader { + private final int precision; + private final int scale; + + LongDecimalReader(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public DecimalData read(DecimalData ignored) { + return DecimalData.fromUnscaledLong(column.nextLong(), precision, scale); + } + } + + private static class TimestampMicroReader extends ParquetValueReaders.UnboxedReader { + TimestampMicroReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public TimestampData read(TimestampData ignored) { + long value = readLong(); + return TimestampData.fromEpochMillis(value / 1000, (int) ((value % 1000) * 1000)); + } + + @Override + public long readLong() { + return column.nextLong(); + } + } + + private static class StringReader extends ParquetValueReaders.PrimitiveReader { + StringReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public StringData read(StringData ignored) { + Binary binary = column.nextBinary(); + ByteBuffer buffer = binary.toByteBuffer(); + if (buffer.hasArray()) { + return StringData.fromBytes( + buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); + } else { + return StringData.fromBytes(binary.getBytes()); + } + } } - public static ParquetValueReader buildReader(Schema expectedSchema, MessageType fileSchema) { - return INSTANCE.createReader(expectedSchema, fileSchema); + private static class BytesReader extends ParquetValueReaders.PrimitiveReader { + BytesReader(ColumnDescriptor desc) { + super(desc); + } + + @Override + public byte[] read(byte[] ignored) { + return column.nextBinary().getBytes(); + } } - @Override - protected ParquetValueReader createStructReader(List types, - List> fieldReaders, - Types.StructType structType) { - return new RowReader(types, fieldReaders, structType); + private static class ArrayReader extends ParquetValueReaders.RepeatedReader { + private int readPos = 0; + private int writePos = 0; + + ArrayReader(int definitionLevel, int repetitionLevel, ParquetValueReader reader) { + super(definitionLevel, repetitionLevel, reader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableArrayData newListData(ArrayData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableArrayData) { + return (ReusableArrayData) reuse; + } else { + return new ReusableArrayData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected E getElement(ReusableArrayData list) { + E value = null; + if (readPos < list.capacity()) { + value = (E) list.values[readPos]; + } + + readPos += 1; + + return value; + } + + @Override + protected void addElement(ReusableArrayData reused, E element) { + if (writePos >= reused.capacity()) { + reused.grow(); + } + + reused.values[writePos] = element; + + writePos += 1; + } + + @Override + protected ArrayData buildList(ReusableArrayData list) { + list.setNumElements(writePos); + return list; + } + } + + private static class MapReader extends + ParquetValueReaders.RepeatedKeyValueReader { + private int readPos = 0; + private int writePos = 0; + + private final ParquetValueReaders.ReusableEntry entry = new ParquetValueReaders.ReusableEntry<>(); + private final ParquetValueReaders.ReusableEntry nullEntry = new 
ParquetValueReaders.ReusableEntry<>(); + + MapReader(int definitionLevel, int repetitionLevel, + ParquetValueReader keyReader, ParquetValueReader valueReader) { + super(definitionLevel, repetitionLevel, keyReader, valueReader); + } + + @Override + @SuppressWarnings("unchecked") + protected ReusableMapData newMapData(MapData reuse) { + this.readPos = 0; + this.writePos = 0; + + if (reuse instanceof ReusableMapData) { + return (ReusableMapData) reuse; + } else { + return new ReusableMapData(); + } + } + + @Override + @SuppressWarnings("unchecked") + protected Map.Entry getPair(ReusableMapData map) { + Map.Entry kv = nullEntry; + if (readPos < map.capacity()) { + entry.set((K) map.keys.values[readPos], (V) map.values.values[readPos]); + kv = entry; + } + + readPos += 1; + + return kv; + } + + @Override + protected void addPair(ReusableMapData map, K key, V value) { + if (writePos >= map.capacity()) { + map.grow(); + } + + map.keys.values[writePos] = key; + map.values.values[writePos] = value; + + writePos += 1; + } + + @Override + protected MapData buildMap(ReusableMapData map) { + map.setNumElements(writePos); + return map; + } } - private static class RowReader extends ParquetValueReaders.StructReader { - private final Types.StructType structType; + private static class RowDataReader extends ParquetValueReaders.StructReader { + private final int numFields; - RowReader(List types, List> readers, Types.StructType struct) { + RowDataReader(List types, List> readers) { super(types, readers); - this.structType = struct; + this.numFields = readers.size(); } @Override - protected Row newStructData(Row reuse) { - if (reuse != null) { - return reuse; + protected GenericRowData newStructData(RowData reuse) { + if (reuse instanceof GenericRowData) { + return (GenericRowData) reuse; } else { - return new Row(structType.fields().size()); + return new GenericRowData(numFields); } } @Override - protected Object getField(Row row, int pos) { - return row.getField(pos); + protected Object getField(GenericRowData intermediate, int pos) { + return intermediate.getField(pos); + } + + @Override + protected RowData buildStruct(GenericRowData struct) { + return struct; + } + + @Override + protected void set(GenericRowData row, int pos, Object value) { + row.setField(pos, value); + } + + @Override + protected void setNull(GenericRowData row, int pos) { + row.setField(pos, null); } @Override - protected Row buildStruct(Row row) { - return row; + protected void setBoolean(GenericRowData row, int pos, boolean value) { + row.setField(pos, value); } @Override - protected void set(Row row, int pos, Object value) { + protected void setInteger(GenericRowData row, int pos, int value) { row.setField(pos, value); } + + @Override + protected void setLong(GenericRowData row, int pos, long value) { + row.setField(pos, value); + } + + @Override + protected void setFloat(GenericRowData row, int pos, float value) { + row.setField(pos, value); + } + + @Override + protected void setDouble(GenericRowData row, int pos, double value) { + row.setField(pos, value); + } + } + + private static class ReusableMapData implements MapData { + private final ReusableArrayData keys; + private final ReusableArrayData values; + + private int numElements; + + private ReusableMapData() { + this.keys = new ReusableArrayData(); + this.values = new ReusableArrayData(); + } + + private void grow() { + keys.grow(); + values.grow(); + } + + private int capacity() { + return keys.capacity(); + } + + public void setNumElements(int numElements) { + 
this.numElements = numElements; + keys.setNumElements(numElements); + values.setNumElements(numElements); + } + + @Override + public int size() { + return numElements; + } + + @Override + public ReusableArrayData keyArray() { + return keys; + } + + @Override + public ReusableArrayData valueArray() { + return values; + } + } + + private static class ReusableArrayData implements ArrayData { + private static final Object[] EMPTY = new Object[0]; + + private Object[] values = EMPTY; + private int numElements = 0; + + private void grow() { + if (values.length == 0) { + this.values = new Object[20]; + } else { + Object[] old = values; + this.values = new Object[old.length << 2]; + // copy the old array in case it has values that can be reused + System.arraycopy(old, 0, values, 0, old.length); + } + } + + private int capacity() { + return values.length; + } + + public void setNumElements(int numElements) { + this.numElements = numElements; + } + + @Override + public int size() { + return numElements; + } + + @Override + public boolean isNullAt(int ordinal) { + return null == values[ordinal]; + } + + @Override + public boolean getBoolean(int ordinal) { + return (boolean) values[ordinal]; + } + + @Override + public byte getByte(int ordinal) { + return (byte) values[ordinal]; + } + + @Override + public short getShort(int ordinal) { + return (short) values[ordinal]; + } + + @Override + public int getInt(int ordinal) { + return (int) values[ordinal]; + } + + @Override + public long getLong(int ordinal) { + return (long) values[ordinal]; + } + + @Override + public float getFloat(int ordinal) { + return (float) values[ordinal]; + } + + @Override + public double getDouble(int ordinal) { + return (double) values[ordinal]; + } + + @Override + public StringData getString(int pos) { + return (StringData) values[pos]; + } + + @Override + public DecimalData getDecimal(int pos, int precision, int scale) { + return (DecimalData) values[pos]; + } + + @Override + public TimestampData getTimestamp(int pos, int precision) { + return (TimestampData) values[pos]; + } + + @SuppressWarnings("unchecked") + @Override + public RawValueData getRawValue(int pos) { + return (RawValueData) values[pos]; + } + + @Override + public byte[] getBinary(int ordinal) { + return (byte[]) values[ordinal]; + } + + @Override + public ArrayData getArray(int ordinal) { + return (ArrayData) values[ordinal]; + } + + @Override + public MapData getMap(int ordinal) { + return (MapData) values[ordinal]; + } + + @Override + public RowData getRow(int pos, int numFields) { + return (RowData) values[pos]; + } + + @Override + public boolean[] toBooleanArray() { + return ArrayUtils.toPrimitive((Boolean[]) values); + } + + @Override + public byte[] toByteArray() { + return ArrayUtils.toPrimitive((Byte[]) values); + } + + @Override + public short[] toShortArray() { + return ArrayUtils.toPrimitive((Short[]) values); + } + + @Override + public int[] toIntArray() { + return ArrayUtils.toPrimitive((Integer[]) values); + } + + @Override + public long[] toLongArray() { + return ArrayUtils.toPrimitive((Long[]) values); + } + + @Override + public float[] toFloatArray() { + return ArrayUtils.toPrimitive((Float[]) values); + } + + @Override + public double[] toDoubleArray() { + return ArrayUtils.toPrimitive((Double[]) values); + } } } diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java index 54b4fea083a1..5185961fa84e 100644 --- 
a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java +++ b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java @@ -19,38 +19,458 @@ package org.apache.iceberg.flink.data; +import java.math.BigDecimal; +import java.util.Iterator; import java.util.List; -import org.apache.flink.types.Row; -import org.apache.iceberg.data.parquet.BaseParquetWriter; +import java.util.Map; +import java.util.NoSuchElementException; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.TimestampData; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.MapType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.RowType.RowField; +import org.apache.flink.table.types.logical.SmallIntType; +import org.apache.flink.table.types.logical.TinyIntType; +import org.apache.iceberg.parquet.ParquetValueReaders; import org.apache.iceberg.parquet.ParquetValueWriter; import org.apache.iceberg.parquet.ParquetValueWriters; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.TypeUtil; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; -public class FlinkParquetWriters extends BaseParquetWriter { +public class FlinkParquetWriters { + private FlinkParquetWriters() { + } - private static final FlinkParquetWriters INSTANCE = new FlinkParquetWriters(); + @SuppressWarnings("unchecked") + public static ParquetValueWriter buildWriter(LogicalType schema, MessageType type) { + return (ParquetValueWriter) ParquetWithFlinkSchemaVisitor.visit(schema, type, new WriteBuilder(type)); + } - private FlinkParquetWriters() { + private static class WriteBuilder extends ParquetWithFlinkSchemaVisitor> { + private final MessageType type; + + WriteBuilder(MessageType type) { + this.type = type; + } + + @Override + public ParquetValueWriter message(RowType sStruct, MessageType message, List> fields) { + return struct(sStruct, message.asGroupType(), fields); + } + + @Override + public ParquetValueWriter struct(RowType sStruct, GroupType struct, + List> fieldWriters) { + List fields = struct.getFields(); + List flinkFields = sStruct.getFields(); + List> writers = Lists.newArrayListWithExpectedSize(fieldWriters.size()); + List flinkTypes = Lists.newArrayList(); + for (int i = 0; i < fields.size(); i += 1) { + writers.add(newOption(struct.getType(i), fieldWriters.get(i))); + flinkTypes.add(flinkFields.get(i).getType()); + } + + return new RowDataWriter(writers, flinkTypes); + } + + @Override + public ParquetValueWriter list(ArrayType sArray, GroupType array, ParquetValueWriter elementWriter) { + GroupType repeated = array.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return 
new ArrayDataWriter<>(repeatedD, repeatedR, + newOption(repeated.getType(0), elementWriter), + sArray.getElementType()); + } + + @Override + public ParquetValueWriter map(MapType sMap, GroupType map, + ParquetValueWriter keyWriter, ParquetValueWriter valueWriter) { + GroupType repeatedKeyValue = map.getFields().get(0).asGroupType(); + String[] repeatedPath = currentPath(); + + int repeatedD = type.getMaxDefinitionLevel(repeatedPath); + int repeatedR = type.getMaxRepetitionLevel(repeatedPath); + + return new MapDataWriter<>(repeatedD, repeatedR, + newOption(repeatedKeyValue.getType(0), keyWriter), + newOption(repeatedKeyValue.getType(1), valueWriter), + sMap.getKeyType(), sMap.getValueType()); + } + + + private ParquetValueWriter newOption(org.apache.parquet.schema.Type fieldType, ParquetValueWriter writer) { + int maxD = type.getMaxDefinitionLevel(path(fieldType.getName())); + return ParquetValueWriters.option(fieldType, maxD, writer); + } + + @Override + public ParquetValueWriter primitive(LogicalType sType, PrimitiveType primitive) { + ColumnDescriptor desc = type.getColumnDescription(currentPath()); + + if (primitive.getOriginalType() != null) { + switch (primitive.getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + return strings(desc); + case DATE: + case INT_8: + case INT_16: + case INT_32: + return ints(sType, desc); + case TIME_MICROS: + return timeMicros(desc); + case INT_64: + return ParquetValueWriters.longs(desc); + case TIMESTAMP_MICROS: + return timestamps(desc); + case DECIMAL: + DecimalLogicalTypeAnnotation decimal = (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation(); + switch (primitive.getPrimitiveTypeName()) { + case INT32: + return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale()); + case INT64: + return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale()); + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale()); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + case BSON: + return byteArrays(desc); + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + primitive.getOriginalType()); + } + } + + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return byteArrays(desc); + case BOOLEAN: + return ParquetValueWriters.booleans(desc); + case INT32: + return ints(sType, desc); + case INT64: + return ParquetValueWriters.longs(desc); + case FLOAT: + return ParquetValueWriters.floats(desc); + case DOUBLE: + return ParquetValueWriters.doubles(desc); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + } + + private static ParquetValueWriters.PrimitiveWriter ints(LogicalType type, ColumnDescriptor desc) { + if (type instanceof TinyIntType) { + return ParquetValueWriters.tinyints(desc); + } else if (type instanceof SmallIntType) { + return ParquetValueWriters.shorts(desc); + } + return ParquetValueWriters.ints(desc); + } + + private static ParquetValueWriters.PrimitiveWriter strings(ColumnDescriptor desc) { + return new StringDataWriter(desc); + } + + private static ParquetValueWriters.PrimitiveWriter timeMicros(ColumnDescriptor desc) { + return new TimeMicrosWriter(desc); + } + + private static class TimeMicrosWriter extends ParquetValueWriters.PrimitiveWriter { + + protected TimeMicrosWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void 
write(int repetitionLevel, Integer value) { + column.writeLong(repetitionLevel, value * 1000); + } + } + + + private static ParquetValueWriters.PrimitiveWriter decimalAsInteger(ColumnDescriptor desc, + int precision, int scale) { + return new IntegerDecimalWriter(desc, precision, scale); + } + private static ParquetValueWriters.PrimitiveWriter decimalAsLong(ColumnDescriptor desc, + int precision, int scale) { + return new LongDecimalWriter(desc, precision, scale); + } + + private static ParquetValueWriters.PrimitiveWriter decimalAsFixed(ColumnDescriptor desc, + int precision, int scale) { + return new FixedDecimalWriter(desc, precision, scale); + } + + private static ParquetValueWriters.PrimitiveWriter timestamps(ColumnDescriptor desc) { + return new TimestampDataWriter(desc); + } + + private static ParquetValueWriters.PrimitiveWriter byteArrays(ColumnDescriptor desc) { + return new ByteArrayWriter(desc); + } + + private static class StringDataWriter extends ParquetValueWriters.PrimitiveWriter { + private StringDataWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, StringData value) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(value.toBytes())); + } + } + + private static class IntegerDecimalWriter extends ParquetValueWriters.PrimitiveWriter { + private final int precision; + private final int scale; + + private IntegerDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, DecimalData decimal) { + Preconditions.checkArgument(decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal); + Preconditions.checkArgument(decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal); + + column.writeInteger(repetitionLevel, (int) decimal.toUnscaledLong()); + } + } + + private static class LongDecimalWriter extends ParquetValueWriters.PrimitiveWriter { + private final int precision; + private final int scale; + + private LongDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + } + + @Override + public void write(int repetitionLevel, DecimalData decimal) { + Preconditions.checkArgument(decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal); + Preconditions.checkArgument(decimal.precision() <= precision, + "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal); + + column.writeLong(repetitionLevel, decimal.toUnscaledLong()); + } + } + + private static class FixedDecimalWriter extends ParquetValueWriters.PrimitiveWriter { + private final int precision; + private final int scale; + private final int length; + private final ThreadLocal bytes; + + private FixedDecimalWriter(ColumnDescriptor desc, int precision, int scale) { + super(desc); + this.precision = precision; + this.scale = scale; + this.length = TypeUtil.decimalRequiredBytes(precision); + this.bytes = ThreadLocal.withInitial(() -> new byte[length]); + } + + @Override + public void write(int repetitionLevel, DecimalData decimal) { + Preconditions.checkArgument(decimal.scale() == scale, + "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal); + Preconditions.checkArgument(decimal.precision() <= precision, + "Cannot write value as 
decimal(%s,%s), too large: %s", precision, scale, decimal); + + BigDecimal bigDecimal = decimal.toBigDecimal(); + + byte fillByte = (byte) (bigDecimal.signum() < 0 ? 0xFF : 0x00); + byte[] unscaled = bigDecimal.unscaledValue().toByteArray(); + byte[] buf = bytes.get(); + int offset = length - unscaled.length; + + for (int i = 0; i < length; i += 1) { + if (i < offset) { + buf[i] = fillByte; + } else { + buf[i] = unscaled[i - offset]; + } + } + + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(buf)); + } } - public static ParquetValueWriter buildWriter(MessageType type) { - return INSTANCE.createWriter(type); + private static class TimestampDataWriter extends ParquetValueWriters.PrimitiveWriter { + private TimestampDataWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, TimestampData value) { + column.writeLong(repetitionLevel, value.getMillisecond() * 1000 + value.getNanoOfMillisecond() / 1000); + } } - @Override - protected ParquetValueWriters.StructWriter createStructWriter(List> writers) { - return new RowWriter(writers); + private static class ByteArrayWriter extends ParquetValueWriters.PrimitiveWriter { + private ByteArrayWriter(ColumnDescriptor desc) { + super(desc); + } + + @Override + public void write(int repetitionLevel, byte[] bytes) { + column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(bytes)); + } + } + + private static class ArrayDataWriter extends ParquetValueWriters.RepeatedWriter { + private final LogicalType elementType; + + private ArrayDataWriter(int definitionLevel, int repetitionLevel, + ParquetValueWriter writer, LogicalType elementType) { + super(definitionLevel, repetitionLevel, writer); + this.elementType = elementType; + } + + @Override + protected Iterator elements(ArrayData list) { + return new ElementIterator<>(list); + } + + private class ElementIterator implements Iterator { + private final int size; + private final ArrayData list; + private int index; + + private ElementIterator(ArrayData list) { + this.list = list; + size = list.size(); + index = 0; + } + + @Override + public boolean hasNext() { + return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public E next() { + if (index >= size) { + throw new NoSuchElementException(); + } + + E element; + if (list.isNullAt(index)) { + element = null; + } else { + element = (E) ArrayData.createElementGetter(elementType).getElementOrNull(list, index); + } + + index += 1; + + return element; + } + } + } + + private static class MapDataWriter extends ParquetValueWriters.RepeatedKeyValueWriter { + private final LogicalType keyType; + private final LogicalType valueType; + + private MapDataWriter(int definitionLevel, int repetitionLevel, + ParquetValueWriter keyWriter, ParquetValueWriter valueWriter, + LogicalType keyType, LogicalType valueType) { + super(definitionLevel, repetitionLevel, keyWriter, valueWriter); + this.keyType = keyType; + this.valueType = valueType; + } + + @Override + protected Iterator> pairs(MapData map) { + return new EntryIterator<>(map); + } + + private class EntryIterator implements Iterator> { + private final int size; + private final ArrayData keys; + private final ArrayData values; + private final ParquetValueReaders.ReusableEntry entry; + private int index; + + private EntryIterator(MapData map) { + size = map.size(); + keys = map.keyArray(); + values = map.valueArray(); + entry = new ParquetValueReaders.ReusableEntry<>(); + index = 0; + } + + @Override + public boolean hasNext() { + 
return index != size; + } + + @Override + @SuppressWarnings("unchecked") + public Map.Entry next() { + if (index >= size) { + throw new NoSuchElementException(); + } + + if (values.isNullAt(index)) { + entry.set((K) ArrayData.createElementGetter(keyType).getElementOrNull(keys, index), null); + } else { + entry.set((K) ArrayData.createElementGetter(keyType).getElementOrNull(keys, index), + (V) ArrayData.createElementGetter(valueType).getElementOrNull(values, index)); + } + + index += 1; + + return entry; + } + } } - private static class RowWriter extends ParquetValueWriters.StructWriter { + private static class RowDataWriter extends ParquetValueWriters.StructWriter { + private final List types; - private RowWriter(List> writers) { + RowDataWriter(List> writers, List types) { super(writers); + this.types = types; } @Override - protected Object get(Row row, int index) { - return row.getField(index); + protected Object get(RowData struct, int index) { + return RowData.createFieldGetter(types.get(index), index).getFieldOrNull(struct); } } } diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/ParquetWithFlinkSchemaVisitor.java b/flink/src/main/java/org/apache/iceberg/flink/data/ParquetWithFlinkSchemaVisitor.java new file mode 100644 index 000000000000..541986f93889 --- /dev/null +++ b/flink/src/main/java/org/apache/iceberg/flink/data/ParquetWithFlinkSchemaVisitor.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.flink.data; + +import java.util.Deque; +import java.util.List; +import org.apache.flink.table.types.logical.ArrayType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.MapType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.RowType.RowField; +import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.OriginalType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +public class ParquetWithFlinkSchemaVisitor { + private final Deque fieldNames = Lists.newLinkedList(); + + public static T visit(LogicalType sType, Type type, ParquetWithFlinkSchemaVisitor visitor) { + Preconditions.checkArgument(sType != null, "Invalid DataType: null"); + if (type instanceof MessageType) { + Preconditions.checkArgument(sType instanceof RowType, "Invalid struct: %s is not a struct", sType); + RowType struct = (RowType) sType; + return visitor.message(struct, (MessageType) type, visitFields(struct, type.asGroupType(), visitor)); + } else if (type.isPrimitive()) { + return visitor.primitive(sType, type.asPrimitiveType()); + } else { + // if not a primitive, the typeId must be a group + GroupType group = type.asGroupType(); + OriginalType annotation = group.getOriginalType(); + if (annotation != null) { + switch (annotation) { + case LIST: + Preconditions.checkArgument(!group.isRepetition(Type.Repetition.REPEATED), + "Invalid list: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid list: does not contain single repeated field: %s", group); + + GroupType repeatedElement = group.getFields().get(0).asGroupType(); + Preconditions.checkArgument(repeatedElement.isRepetition(Type.Repetition.REPEATED), + "Invalid list: inner group is not repeated"); + Preconditions.checkArgument(repeatedElement.getFieldCount() <= 1, + "Invalid list: repeated group is not a single field: %s", group); + + Preconditions.checkArgument(sType instanceof ArrayType, "Invalid list: %s is not an array", sType); + ArrayType array = (ArrayType) sType; + RowType.RowField element = new RowField( + "element", array.getElementType(), "element of " + array.asSummaryString()); + + visitor.fieldNames.push(repeatedElement.getName()); + try { + T elementResult = null; + if (repeatedElement.getFieldCount() > 0) { + elementResult = visitField(element, repeatedElement.getType(0), visitor); + } + + return visitor.list(array, group, elementResult); + + } finally { + visitor.fieldNames.pop(); + } + + case MAP: + Preconditions.checkArgument(!group.isRepetition(Type.Repetition.REPEATED), + "Invalid map: top-level group is repeated: %s", group); + Preconditions.checkArgument(group.getFieldCount() == 1, + "Invalid map: does not contain single repeated field: %s", group); + + GroupType repeatedKeyValue = group.getType(0).asGroupType(); + Preconditions.checkArgument(repeatedKeyValue.isRepetition(Type.Repetition.REPEATED), + "Invalid map: inner group is not repeated"); + Preconditions.checkArgument(repeatedKeyValue.getFieldCount() <= 2, + "Invalid map: repeated group does not have 2 fields"); + + Preconditions.checkArgument(sType instanceof MapType, "Invalid map: %s is not a map", 
sType); + MapType map = (MapType) sType; + RowField keyField = new RowField("key", map.getKeyType(), "key of " + map.asSummaryString()); + RowField valueField = new RowField( + "value", map.getValueType(), "value of " + map.asSummaryString()); + + visitor.fieldNames.push(repeatedKeyValue.getName()); + try { + T keyResult = null; + T valueResult = null; + switch (repeatedKeyValue.getFieldCount()) { + case 2: + // if there are 2 fields, both key and value are projected + keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor); + valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor); + break; + case 1: + // if there is just one, use the name to determine what it is + Type keyOrValue = repeatedKeyValue.getType(0); + if (keyOrValue.getName().equalsIgnoreCase("key")) { + keyResult = visitField(keyField, keyOrValue, visitor); + // value result remains null + } else { + valueResult = visitField(valueField, keyOrValue, visitor); + // key result remains null + } + break; + default: + // both results will remain null + } + + return visitor.map(map, group, keyResult, valueResult); + + } finally { + visitor.fieldNames.pop(); + } + + default: + } + } + Preconditions.checkArgument(sType instanceof RowType, "Invalid struct: %s is not a struct", sType); + RowType struct = (RowType) sType; + return visitor.struct(struct, group, visitFields(struct, group, visitor)); + } + } + + private static T visitField(RowType.RowField sField, Type field, ParquetWithFlinkSchemaVisitor visitor) { + visitor.fieldNames.push(field.getName()); + try { + return visit(sField.getType(), field, visitor); + } finally { + visitor.fieldNames.pop(); + } + } + + private static List visitFields(RowType struct, GroupType group, + ParquetWithFlinkSchemaVisitor visitor) { + List sFields = struct.getFields(); + Preconditions.checkArgument(sFields.size() == group.getFieldCount(), + "Structs do not match: %s and %s", struct, group); + List results = Lists.newArrayListWithExpectedSize(group.getFieldCount()); + for (int i = 0; i < sFields.size(); i += 1) { + Type field = group.getFields().get(i); + RowType.RowField sField = sFields.get(i); + Preconditions.checkArgument(field.getName().equals(AvroSchemaUtil.makeCompatibleName(sField.getName())), + "Structs do not match: field %s != %s", field.getName(), sField.getName()); + results.add(visitField(sField, field, visitor)); + } + + return results; + } + + public T message(RowType sStruct, MessageType message, List fields) { + return null; + } + + public T struct(RowType sStruct, GroupType struct, List fields) { + return null; + } + + public T list(ArrayType sArray, GroupType array, T element) { + return null; + } + + public T map(MapType sMap, GroupType map, T key, T value) { + return null; + } + + public T primitive(LogicalType sPrimitive, PrimitiveType primitive) { + return null; + } + + protected String[] currentPath() { + return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]); + } + + protected String[] path(String name) { + List list = Lists.newArrayList(fieldNames.descendingIterator()); + list.add(name); + return list.toArray(new String[0]); + } + +} diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java b/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java index 843006197fdf..a409a217e0fd 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java +++ b/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java @@ -19,15 +19,27 @@ package 
org.apache.iceberg.flink.data; +import java.math.BigDecimal; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import java.util.Random; +import java.util.Set; import java.util.function.Supplier; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.GenericMapData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.TimestampData; import org.apache.flink.types.Row; import org.apache.iceberg.Schema; import org.apache.iceberg.data.RandomGenericData; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; @@ -64,7 +76,10 @@ private RandomData() { required(23, "koala", Types.IntegerType.get()), required(24, "couch rope", Types.IntegerType.get()) ))), - optional(2, "slide", Types.StringType.get()) + optional(2, "slide", Types.StringType.get()), + optional(25, "binary", Types.BinaryType.get()), + optional(26, "decimal", Types.DecimalType.of(10, 2)), + optional(27, "time micro", Types.TimeType.get()) ); private static Iterable generateData(Schema schema, int numRecords, Supplier supplier) { @@ -88,20 +103,153 @@ public Row next() { }; } + private static Iterable generateRowData(Schema schema, int numRecords, + Supplier supplier) { + return () -> new Iterator() { + private final RandomRowDataGenerator generator = supplier.get(); + private int count = 0; + + @Override + public boolean hasNext() { + return count < numRecords; + } + + @Override + public RowData next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + ++count; + return (RowData) TypeUtil.visit(schema, generator); + } + }; + } + + public static Iterable generateRowData(Schema schema, int numRecords, long seed) { + return generateRowData(schema, numRecords, () -> new RandomRowDataGenerator(seed)); + } + public static Iterable generate(Schema schema, int numRecords, long seed) { return generateData(schema, numRecords, () -> new RandomRowGenerator(seed)); } - public static Iterable generateFallbackData(Schema schema, int numRecords, long seed, long numDictRows) { - return generateData(schema, numRecords, () -> new FallbackGenerator(seed, numDictRows)); + public static Iterable generateFallbackData(Schema schema, int numRecords, long seed, long numDictRows) { + return generateRowData(schema, numRecords, () -> new FallbackGenerator(seed, numDictRows)); } - public static Iterable generateDictionaryEncodableData(Schema schema, int numRecords, long seed) { - return generateData(schema, numRecords, () -> new DictionaryEncodedGenerator(seed)); + public static Iterable generateDictionaryEncodableData(Schema schema, int numRecords, long seed) { + return generateRowData(schema, numRecords, () -> new DictionaryEncodedGenerator(seed)); } - private static class RandomRowGenerator extends RandomGenericData.RandomDataGenerator { + private static class RandomRowDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { + protected final Random random; + private static final int MAX_ENTRIES = 20; + + RandomRowDataGenerator(long seed) { + this.random = new Random(seed); + } + + protected int getMaxEntries() { + return MAX_ENTRIES; + } + + 
@Override + public RowData schema(Schema schema, Supplier structResult) { + return (RowData) structResult.get(); + } + + @Override + public RowData struct(Types.StructType struct, Iterable fieldResults) { + GenericRowData row = new GenericRowData(struct.fields().size()); + + List values = Lists.newArrayList(fieldResults); + for (int i = 0; i < values.size(); i += 1) { + row.setField(i, values.get(i)); + } + return row; + } + + @Override + public Object field(Types.NestedField field, Supplier fieldResult) { + // return null 5% of the time when the value is optional + if (field.isOptional() && random.nextInt(20) == 1) { + return null; + } + return fieldResult.get(); + } + + @Override + public Object list(Types.ListType list, Supplier elementResult) { + int numElements = random.nextInt(20); + Object[] arr = new Object[numElements]; + GenericArrayData result = new GenericArrayData(arr); + + for (int i = 0; i < numElements; i += 1) { + // return null 5% of the time when the value is optional + if (list.isElementOptional() && random.nextInt(20) == 1) { + arr[i] = null; + } else { + arr[i] = elementResult.get(); + } + } + + return result; + } + + @Override + public Object map(Types.MapType map, Supplier keyResult, Supplier valueResult) { + int numEntries = random.nextInt(getMaxEntries()); + + Object[] keysArr = new Object[numEntries]; + Map javaMap = new HashMap<>(); + + Set keySet = Sets.newHashSet(); + for (int i = 0; i < numEntries; i += 1) { + Object key = keyResult.get(); + // ensure no collisions + while (keySet.contains(key)) { + key = keyResult.get(); + } + + keySet.add(key); + keysArr[i] = key; + + if (map.isValueOptional() && random.nextInt(20) == 1) { + javaMap.put(keysArr[i], null); + } else { + javaMap.put(keysArr[i], valueResult.get()); + } + } + + return new GenericMapData(javaMap); + } + + @Override + public Object primitive(Type.PrimitiveType primitive) { + Object obj = randomValue(primitive, random); + switch (primitive.typeId()) { + case STRING: + return StringData.fromString((String) obj); + case DECIMAL: + return DecimalData.fromBigDecimal((BigDecimal) obj, + ((BigDecimal) obj).precision(), + ((BigDecimal) obj).scale()); + case TIMESTAMP: + return TimestampData.fromEpochMillis((Long) obj); + case TIME: + return ((Long) obj).intValue(); + default: + return obj; + } + } + + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + return RandomUtil.generatePrimitive(primitive, random); + } + } + + private static class RandomRowGenerator extends RandomGenericData.RandomDataGenerator { RandomRowGenerator(long seed) { super(seed); } @@ -124,7 +272,7 @@ public Row struct(Types.StructType struct, Iterable fieldResults) { } } - private static class DictionaryEncodedGenerator extends RandomRowGenerator { + private static class DictionaryEncodedGenerator extends RandomRowDataGenerator { DictionaryEncodedGenerator(long seed) { super(seed); } @@ -144,7 +292,7 @@ protected Object randomValue(Type.PrimitiveType primitive, Random random) { } } - private static class FallbackGenerator extends RandomRowGenerator { + private static class FallbackGenerator extends RandomRowDataGenerator { private final long dictionaryEncodedRows; private long rowCount = 0; diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReaderWriter.java b/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReaderWriter.java index 41ea960b72c2..11ad5da76f31 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReaderWriter.java +++ 
b/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetReaderWriter.java @@ -21,13 +21,19 @@ import java.io.File; import java.io.IOException; +import java.util.ArrayList; import java.util.Iterator; -import org.apache.flink.types.Row; +import java.util.List; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.RowData; import org.apache.iceberg.Files; import org.apache.iceberg.Schema; +import org.apache.iceberg.flink.FlinkSchemaUtil; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -41,34 +47,193 @@ public class TestFlinkParquetReaderWriter { @Rule public TemporaryFolder temp = new TemporaryFolder(); - private void testCorrectness(Schema schema, int numRecords, Iterable iterable) throws IOException { + private void testCorrectness(Schema schema, int numRecords, Iterable iterable) throws IOException { File testFile = temp.newFile(); Assert.assertTrue("Delete should succeed", testFile.delete()); - try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) + try (FileAppender writer = Parquet.write(Files.localOutput(testFile)) .schema(schema) - .createWriterFunc(FlinkParquetWriters::buildWriter) + .createWriterFunc(msgType -> FlinkParquetWriters.buildWriter(FlinkSchemaUtil.convert(schema), msgType)) .build()) { writer.addAll(iterable); } - try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) + try (CloseableIterable reader = Parquet.read(Files.localInput(testFile)) .project(schema) .createReaderFunc(type -> FlinkParquetReaders.buildReader(schema, type)) .build()) { - Iterator expected = iterable.iterator(); - Iterator rows = reader.iterator(); + Iterator expected = iterable.iterator(); + Iterator rows = reader.iterator(); for (int i = 0; i < numRecords; i += 1) { Assert.assertTrue("Should have expected number of rows", rows.hasNext()); - Assert.assertEquals(expected.next(), rows.next()); + assertRowData(schema.asStruct(), expected.next(), rows.next()); } Assert.assertFalse("Should not have extra rows", rows.hasNext()); } } + private void assertRowData(Type type, RowData expected, RowData actual) { + List types = new ArrayList<>(); + for (Types.NestedField field : type.asStructType().fields()) { + types.add(field.type()); + } + + for (int i = 0; i < types.size(); i += 1) { + if (expected.isNullAt(i)) { + Assert.assertEquals(expected.isNullAt(i), actual.isNullAt(i)); + continue; + } + switch (types.get(i).typeId()) { + case BOOLEAN: + Assert.assertEquals("boolean value should be equal", expected.getBoolean(i), actual.getBoolean(i)); + break; + case INTEGER: + Assert.assertEquals("int value should be equal", expected.getInt(i), actual.getInt(i)); + break; + case LONG: + Assert.assertEquals("long value should be equal", expected.getLong(i), actual.getLong(i)); + break; + case FLOAT: + Assert.assertEquals("float value should be equal", Float.valueOf(expected.getFloat(i)), + Float.valueOf(actual.getFloat(i))); + break; + case DOUBLE: + Assert.assertEquals("double should be equal", Double.valueOf(expected.getDouble(i)), + Double.valueOf(actual.getDouble(i))); + break; + case DATE: + Assert.assertEquals("date should be equal", expected.getInt(i), expected.getInt(i)); + break; + case TIME: + Assert.assertEquals("time should be equal", expected.getInt(i), expected.getInt(i)); + break; + case 
TIMESTAMP: + Assert.assertEquals("timestamp should be equal", expected.getTimestamp(i, 6), + actual.getTimestamp(i, 6)); + break; + case UUID: + case FIXED: + case BINARY: + Assert.assertArrayEquals("binary should be equal", expected.getBinary(i), actual.getBinary(i)); + break; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) types.get(i); + int precision = decimal.precision(); + int scale = decimal.scale(); + Assert.assertEquals("uuid should be equal", + expected.getDecimal(i, precision, scale), + actual.getDecimal(i, precision, scale)); + break; + case LIST: + ArrayData arrayData1 = expected.getArray(i); + ArrayData arrayData2 = actual.getArray(i); + Assert.assertEquals("array length should be equal", arrayData1.size(), arrayData2.size()); + for (int j = 0; j < arrayData1.size(); j += 1) { + assertArrayValues(types.get(i).asListType().elementType(), arrayData1, arrayData2); + } + break; + case MAP: + ArrayData keyArrayData1 = expected.getMap(i).keyArray(); + ArrayData valueArrayData1 = expected.getMap(i).valueArray(); + ArrayData keyArrayData2 = actual.getMap(i).keyArray(); + ArrayData valueArrayData2 = actual.getMap(i).valueArray(); + Type keyType = types.get(i).asMapType().keyType(); + Type valueType = types.get(i).asMapType().valueType(); + + Assert.assertEquals("map size should be equal", expected.getMap(i).size(), actual.getMap(i).size()); + + for (int j = 0; j < keyArrayData1.size(); j += 1) { + assertArrayValues(keyType, keyArrayData1, keyArrayData2); + assertArrayValues(valueType, valueArrayData1, valueArrayData2); + } + break; + case STRUCT: + int numFields = types.get(i).asStructType().fields().size(); + assertRowData(types.get(i).asStructType(), expected.getRow(i, numFields), actual.getRow(i, numFields)); + break; + } + } + } + + private void assertArrayValues(Type type, ArrayData expected, ArrayData actual) { + for (int i = 0; i < expected.size(); i += 1) { + if (expected.isNullAt(i)) { + Assert.assertEquals(expected.isNullAt(i), actual.isNullAt(i)); + continue; + } + switch (type.typeId()) { + case BOOLEAN: + Assert.assertEquals("boolean value should be equal", expected.getBoolean(i), actual.getBoolean(i)); + break; + case INTEGER: + Assert.assertEquals("int value should be equal", expected.getInt(i), actual.getInt(i)); + break; + case LONG: + Assert.assertEquals("long value should be equal", expected.getLong(i), actual.getLong(i)); + break; + case FLOAT: + Assert.assertEquals("float value should be equal", Float.valueOf(expected.getFloat(i)), + Float.valueOf(actual.getFloat(i))); + break; + case DOUBLE: + Assert.assertEquals("double should be equal", Double.valueOf(expected.getDouble(i)), + Double.valueOf(actual.getDouble(i))); + break; + case DATE: + Assert.assertEquals("date should be equal", expected.getInt(i), expected.getInt(i)); + break; + case TIME: + Assert.assertEquals("time should be equal", expected.getInt(i), expected.getInt(i)); + break; + case TIMESTAMP: + Assert.assertEquals("timestamp should be equal", expected.getTimestamp(i, 6), + actual.getTimestamp(i, 6)); + break; + case UUID: + case FIXED: + case BINARY: + Assert.assertArrayEquals("binary should be equal", expected.getBinary(i), actual.getBinary(i)); + break; + case DECIMAL: + Types.DecimalType decimal = (Types.DecimalType) type; + int precision = decimal.precision(); + int scale = decimal.scale(); + Assert.assertEquals("uuid should be equal", + expected.getDecimal(i, precision, scale), + actual.getDecimal(i, precision, scale)); + break; + case LIST: + ArrayData arrayData1 = 
expected.getArray(i); + ArrayData arrayData2 = actual.getArray(i); + Assert.assertEquals("array length should be equal", arrayData1.size(), arrayData2.size()); + for (int j = 0; j < arrayData1.size(); j += 1) { + assertArrayValues(type.asListType().elementType(), arrayData1, arrayData2); + } + break; + case MAP: + ArrayData keyArrayData1 = expected.getMap(i).keyArray(); + ArrayData valueArrayData1 = expected.getMap(i).valueArray(); + ArrayData keyArrayData2 = actual.getMap(i).keyArray(); + ArrayData valueArrayData2 = actual.getMap(i).valueArray(); + + Assert.assertEquals("map size should be equal", expected.getMap(i).size(), actual.getMap(i).size()); + + for (int j = 0; j < keyArrayData1.size(); j += 1) { + assertArrayValues(type.asMapType().keyType(), keyArrayData1, keyArrayData2); + assertArrayValues(type.asMapType().valueType(), valueArrayData1, valueArrayData2); + } + break; + case STRUCT: + int numFields = type.asStructType().fields().size(); + assertRowData(type.asStructType(), expected.getRow(i, numFields), actual.getRow(i, numFields)); + break; + } + } + } + @Test public void testNormalRowData() throws IOException { - testCorrectness(COMPLEX_SCHEMA, NUM_RECORDS, RandomData.generate(COMPLEX_SCHEMA, NUM_RECORDS, 19981)); + testCorrectness(COMPLEX_SCHEMA, NUM_RECORDS, RandomData.generateRowData(COMPLEX_SCHEMA, NUM_RECORDS, 19981)); } @Test
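
For readers wiring these classes up outside the test, the round-trip pattern exercised by testCorrectness above can be distilled into a minimal standalone sketch. It only reuses calls that already appear in this diff (FlinkParquetWriters.buildWriter, FlinkParquetReaders.buildReader, FlinkSchemaUtil.convert, and the Parquet write/read builders); the class name, schema, file path, and sample row below are illustrative placeholders, not part of the change.

// Minimal round-trip sketch based on the testCorrectness wiring above.
// The schema, output path, and sample row are illustrative only.
import java.io.File;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.iceberg.Files;
import org.apache.iceberg.Schema;
import org.apache.iceberg.flink.FlinkSchemaUtil;
import org.apache.iceberg.flink.data.FlinkParquetReaders;
import org.apache.iceberg.flink.data.FlinkParquetWriters;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.types.Types;

import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.apache.iceberg.types.Types.NestedField.required;

public class FlinkParquetRoundTrip {
  public static void main(String[] args) throws Exception {
    Schema schema = new Schema(
        required(1, "id", Types.LongType.get()),
        optional(2, "data", Types.StringType.get()));
    File file = new File("/tmp/flink-rowdata.parquet");  // placeholder location

    // Write RowData: the writer is built from the Flink logical type plus the Parquet message type.
    try (FileAppender<RowData> writer = Parquet.write(Files.localOutput(file))
        .schema(schema)
        .createWriterFunc(msgType -> FlinkParquetWriters.buildWriter(FlinkSchemaUtil.convert(schema), msgType))
        .build()) {
      GenericRowData row = new GenericRowData(2);
      row.setField(0, 1L);
      row.setField(1, StringData.fromString("a"));
      writer.add(row);
    }

    // Read RowData back: the reader is built from the expected Iceberg schema plus the file's message type.
    try (CloseableIterable<RowData> reader = Parquet.read(Files.localInput(file))
        .project(schema)
        .createReaderFunc(fileType -> FlinkParquetReaders.buildReader(schema, fileType))
        .build()) {
      for (RowData read : reader) {
        System.out.println(read.getLong(0) + " " + read.getString(1));
      }
    }
  }
}

Note the asymmetry visible in the diff: the writer needs both the Flink logical row type obtained from FlinkSchemaUtil.convert and the Parquet MessageType, because RowDataWriter reads fields through RowData.createFieldGetter and therefore must know each field's LogicalType, while the reader only needs the expected Iceberg schema plus the file schema (with FallbackReadBuilder covering files written without field IDs).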