diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java
index 54b4fea083a1..75104cacd9a5 100644
--- a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java
+++ b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkParquetWriters.java
@@ -19,38 +19,436 @@
 
 package org.apache.iceberg.flink.data;
 
+import java.util.Iterator;
 import java.util.List;
-import org.apache.flink.types.Row;
-import org.apache.iceberg.data.parquet.BaseParquetWriter;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.DecimalData;
+import org.apache.flink.table.data.MapData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.data.TimestampData;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.RowType.RowField;
+import org.apache.flink.table.types.logical.SmallIntType;
+import org.apache.flink.table.types.logical.TinyIntType;
+import org.apache.iceberg.parquet.ParquetValueReaders;
 import org.apache.iceberg.parquet.ParquetValueWriter;
 import org.apache.iceberg.parquet.ParquetValueWriters;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.util.DecimalUtil;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
 import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
 
-public class FlinkParquetWriters extends BaseParquetWriter<Row> {
+public class FlinkParquetWriters {
+  private FlinkParquetWriters() {
+  }
 
-  private static final FlinkParquetWriters INSTANCE = new FlinkParquetWriters();
+  @SuppressWarnings("unchecked")
+  public static <T> ParquetValueWriter<T> buildWriter(LogicalType schema, MessageType type) {
+    return (ParquetValueWriter<T>) ParquetWithFlinkSchemaVisitor.visit(schema, type, new WriteBuilder(type));
+  }
 
-  private FlinkParquetWriters() {
+  private static class WriteBuilder extends ParquetWithFlinkSchemaVisitor<ParquetValueWriter<?>> {
+    private final MessageType type;
+
+    WriteBuilder(MessageType type) {
+      this.type = type;
+    }
+
+    @Override
+    public ParquetValueWriter<?> message(RowType sStruct, MessageType message, List<ParquetValueWriter<?>> fields) {
+      return struct(sStruct, message.asGroupType(), fields);
+    }
+
+    @Override
+    public ParquetValueWriter<?> struct(RowType sStruct, GroupType struct,
+                                        List<ParquetValueWriter<?>> fieldWriters) {
+      List<Type> fields = struct.getFields();
+      List<RowField> flinkFields = sStruct.getFields();
+      List<ParquetValueWriter<?>> writers = Lists.newArrayListWithExpectedSize(fieldWriters.size());
+      List<LogicalType> flinkTypes = Lists.newArrayList();
+      for (int i = 0; i < fields.size(); i += 1) {
+        writers.add(newOption(struct.getType(i), fieldWriters.get(i)));
+        flinkTypes.add(flinkFields.get(i).getType());
+      }
+
+      return new RowDataWriter(writers, flinkTypes);
+    }
+
+    @Override
+    public ParquetValueWriter<?> list(ArrayType sArray, GroupType array, ParquetValueWriter<?> elementWriter) {
+      GroupType repeated = array.getFields().get(0).asGroupType();
+      String[] repeatedPath = currentPath();
+
+      int repeatedD = type.getMaxDefinitionLevel(repeatedPath);
+      int repeatedR = type.getMaxRepetitionLevel(repeatedPath);
+
+      return new ArrayDataWriter<>(repeatedD, repeatedR,
+          newOption(repeated.getType(0), elementWriter),
+          sArray.getElementType());
+    }
+
+    @Override
+    public ParquetValueWriter<?> map(MapType sMap, GroupType map,
+                                     ParquetValueWriter<?> keyWriter, ParquetValueWriter<?> valueWriter) {
+      GroupType repeatedKeyValue = map.getFields().get(0).asGroupType();
+      String[] repeatedPath = currentPath();
+
+      int repeatedD = type.getMaxDefinitionLevel(repeatedPath);
+      int repeatedR = type.getMaxRepetitionLevel(repeatedPath);
+
+      return new MapDataWriter<>(repeatedD, repeatedR,
+          newOption(repeatedKeyValue.getType(0), keyWriter),
+          newOption(repeatedKeyValue.getType(1), valueWriter),
+          sMap.getKeyType(), sMap.getValueType());
+    }
+
+
+    private ParquetValueWriter<?> newOption(org.apache.parquet.schema.Type fieldType, ParquetValueWriter<?> writer) {
+      int maxD = type.getMaxDefinitionLevel(path(fieldType.getName()));
+      return ParquetValueWriters.option(fieldType, maxD, writer);
+    }
+
+    @Override
+    public ParquetValueWriter<?> primitive(LogicalType sType, PrimitiveType primitive) {
+      ColumnDescriptor desc = type.getColumnDescription(currentPath());
+
+      if (primitive.getOriginalType() != null) {
+        switch (primitive.getOriginalType()) {
+          case ENUM:
+          case JSON:
+          case UTF8:
+            return strings(desc);
+          case DATE:
+          case INT_8:
+          case INT_16:
+          case INT_32:
+            return ints(sType, desc);
+          case INT_64:
+            return ParquetValueWriters.longs(desc);
+          case TIME_MICROS:
+            return timeMicros(desc);
+          case TIMESTAMP_MICROS:
+            return timestamps(desc);
+          case DECIMAL:
+            DecimalLogicalTypeAnnotation decimal = (DecimalLogicalTypeAnnotation) primitive.getLogicalTypeAnnotation();
+            switch (primitive.getPrimitiveTypeName()) {
+              case INT32:
+                return decimalAsInteger(desc, decimal.getPrecision(), decimal.getScale());
+              case INT64:
+                return decimalAsLong(desc, decimal.getPrecision(), decimal.getScale());
+              case BINARY:
+              case FIXED_LEN_BYTE_ARRAY:
+                return decimalAsFixed(desc, decimal.getPrecision(), decimal.getScale());
+              default:
+                throw new UnsupportedOperationException(
+                    "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName());
+            }
+          case BSON:
+            return byteArrays(desc);
+          default:
+            throw new UnsupportedOperationException(
+                "Unsupported logical type: " + primitive.getOriginalType());
+        }
+      }
+
+      switch (primitive.getPrimitiveTypeName()) {
+        case FIXED_LEN_BYTE_ARRAY:
+        case BINARY:
+          return byteArrays(desc);
+        case BOOLEAN:
+          return ParquetValueWriters.booleans(desc);
+        case INT32:
+          return ints(sType, desc);
+        case INT64:
+          return ParquetValueWriters.longs(desc);
+        case FLOAT:
+          return ParquetValueWriters.floats(desc);
+        case DOUBLE:
+          return ParquetValueWriters.doubles(desc);
+        default:
+          throw new UnsupportedOperationException("Unsupported type: " + primitive);
+      }
+    }
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<?> ints(LogicalType type, ColumnDescriptor desc) {
+    if (type instanceof TinyIntType) {
+      return ParquetValueWriters.tinyints(desc);
+    } else if (type instanceof SmallIntType) {
+      return ParquetValueWriters.shorts(desc);
+    }
+    return ParquetValueWriters.ints(desc);
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<StringData> strings(ColumnDescriptor desc) {
+    return new StringDataWriter(desc);
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<Integer> timeMicros(ColumnDescriptor desc) {
+    return new TimeMicrosWriter(desc);
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<DecimalData> decimalAsInteger(ColumnDescriptor desc,
+                                                                                   int precision, int scale) {
+    return new IntegerDecimalWriter(desc, precision, scale);
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<DecimalData> decimalAsLong(ColumnDescriptor desc,
+                                                                                int precision, int scale) {
+    return new LongDecimalWriter(desc, precision, scale);
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<DecimalData> decimalAsFixed(ColumnDescriptor desc,
+                                                                                 int precision, int scale) {
+    return new FixedDecimalWriter(desc, precision, scale);
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<TimestampData> timestamps(ColumnDescriptor desc) {
+    return new TimestampDataWriter(desc);
+  }
+
+  private static ParquetValueWriters.PrimitiveWriter<byte[]> byteArrays(ColumnDescriptor desc) {
+    return new ByteArrayWriter(desc);
+  }
+
+  private static class StringDataWriter extends ParquetValueWriters.PrimitiveWriter<StringData> {
+    private StringDataWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, StringData value) {
+      column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(value.toBytes()));
+    }
+  }
+
+  private static class TimeMicrosWriter extends ParquetValueWriters.PrimitiveWriter<Integer> {
+    private TimeMicrosWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, Integer value) {
+      long micros = Long.valueOf(value) * 1000;
+      column.writeLong(repetitionLevel, micros);
+    }
+  }
+
+  private static class IntegerDecimalWriter extends ParquetValueWriters.PrimitiveWriter<DecimalData> {
+    private final int precision;
+    private final int scale;
+
+    private IntegerDecimalWriter(ColumnDescriptor desc, int precision, int scale) {
+      super(desc);
+      this.precision = precision;
+      this.scale = scale;
+    }
+
+    @Override
+    public void write(int repetitionLevel, DecimalData decimal) {
+      Preconditions.checkArgument(decimal.scale() == scale,
+          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal);
+      Preconditions.checkArgument(decimal.precision() <= precision,
+          "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal);
+
+      column.writeInteger(repetitionLevel, (int) decimal.toUnscaledLong());
+    }
+  }
+
+  private static class LongDecimalWriter extends ParquetValueWriters.PrimitiveWriter<DecimalData> {
+    private final int precision;
+    private final int scale;
+
+    private LongDecimalWriter(ColumnDescriptor desc, int precision, int scale) {
+      super(desc);
+      this.precision = precision;
+      this.scale = scale;
+    }
+
+    @Override
+    public void write(int repetitionLevel, DecimalData decimal) {
+      Preconditions.checkArgument(decimal.scale() == scale,
+          "Cannot write value as decimal(%s,%s), wrong scale: %s", precision, scale, decimal);
+      Preconditions.checkArgument(decimal.precision() <= precision,
+          "Cannot write value as decimal(%s,%s), too large: %s", precision, scale, decimal);
+
+      column.writeLong(repetitionLevel, decimal.toUnscaledLong());
+    }
+  }
+
+  private static class FixedDecimalWriter extends ParquetValueWriters.PrimitiveWriter<DecimalData> {
+    private final int precision;
+    private final int scale;
+    private final ThreadLocal<byte[]> bytes;
+
+    private FixedDecimalWriter(ColumnDescriptor desc, int precision, int scale) {
+      super(desc);
+      this.precision = precision;
+      this.scale = scale;
+      this.bytes = ThreadLocal.withInitial(() -> new byte[TypeUtil.decimalRequiredBytes(precision)]);
+    }
+
+    @Override
+    public void write(int repetitionLevel, DecimalData decimal) {
+      byte[] binary = DecimalUtil.toReusedFixLengthBytes(precision, scale, decimal.toBigDecimal(), bytes.get());
+      column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(binary));
+    }
+  }
+
+  private static class TimestampDataWriter extends ParquetValueWriters.PrimitiveWriter<TimestampData> {
+    private TimestampDataWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, TimestampData value) {
+      column.writeLong(repetitionLevel, value.getMillisecond() * 1000 + value.getNanoOfMillisecond() / 1000);
+    }
+  }
+
+  private static class ByteArrayWriter extends ParquetValueWriters.PrimitiveWriter<byte[]> {
+    private ByteArrayWriter(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public void write(int repetitionLevel, byte[] bytes) {
+      column.writeBinary(repetitionLevel, Binary.fromReusedByteArray(bytes));
+    }
   }
 
-  public static ParquetValueWriter<Row> buildWriter(MessageType type) {
-    return INSTANCE.createWriter(type);
+  private static class ArrayDataWriter<E> extends ParquetValueWriters.RepeatedWriter<ArrayData, E> {
+    private final LogicalType elementType;
+
+    private ArrayDataWriter(int definitionLevel, int repetitionLevel,
+                            ParquetValueWriter<E> writer, LogicalType elementType) {
+      super(definitionLevel, repetitionLevel, writer);
+      this.elementType = elementType;
+    }
+
+    @Override
+    protected Iterator<E> elements(ArrayData list) {
+      return new ElementIterator<>(list);
+    }
+
+    private class ElementIterator<E> implements Iterator<E> {
+      private final int size;
+      private final ArrayData list;
+      private int index;
+
+      private ElementIterator(ArrayData list) {
+        this.list = list;
+        size = list.size();
+        index = 0;
+      }
+
+      @Override
+      public boolean hasNext() {
+        return index != size;
+      }
+
+      @Override
+      @SuppressWarnings("unchecked")
+      public E next() {
+        if (index >= size) {
+          throw new NoSuchElementException();
+        }
+
+        E element;
+        if (list.isNullAt(index)) {
+          element = null;
+        } else {
+          element = (E) ArrayData.createElementGetter(elementType).getElementOrNull(list, index);
+        }
+
+        index += 1;
+
+        return element;
+      }
+    }
   }
 
-  @Override
-  protected ParquetValueWriters.StructWriter<Row> createStructWriter(List<ParquetValueWriter<?>> writers) {
-    return new RowWriter(writers);
+  private static class MapDataWriter<K, V> extends ParquetValueWriters.RepeatedKeyValueWriter<MapData, K, V> {
+    private final LogicalType keyType;
+    private final LogicalType valueType;
+
+    private MapDataWriter(int definitionLevel, int repetitionLevel,
+                          ParquetValueWriter<K> keyWriter, ParquetValueWriter<V> valueWriter,
+                          LogicalType keyType, LogicalType valueType) {
+      super(definitionLevel, repetitionLevel, keyWriter, valueWriter);
+      this.keyType = keyType;
+      this.valueType = valueType;
+    }
+
+    @Override
+    protected Iterator<Map.Entry<K, V>> pairs(MapData map) {
+      return new EntryIterator<>(map);
+    }
+
+    private class EntryIterator<K, V> implements Iterator<Map.Entry<K, V>> {
+      private final int size;
+      private final ArrayData keys;
+      private final ArrayData values;
+      private final ParquetValueReaders.ReusableEntry<K, V> entry;
+      private int index;
+
+      private EntryIterator(MapData map) {
+        size = map.size();
+        keys = map.keyArray();
+        values = map.valueArray();
+        entry = new ParquetValueReaders.ReusableEntry<>();
+        index = 0;
+      }
+
+      @Override
+      public boolean hasNext() {
+        return index != size;
+      }
+
+      @Override
+      @SuppressWarnings("unchecked")
+      public Map.Entry<K, V> next() {
+        if (index >= size) {
+          throw new NoSuchElementException();
+        }
+
+        if (values.isNullAt(index)) {
+          entry.set((K) ArrayData.createElementGetter(keyType).getElementOrNull(keys, index), null);
+        } else {
+          entry.set((K) ArrayData.createElementGetter(keyType).getElementOrNull(keys, index),
+              (V) ArrayData.createElementGetter(valueType).getElementOrNull(values, index));
+        }
+
+        index += 1;
+
+        return entry;
+      }
+    }
   }
 
-  private static class RowWriter extends ParquetValueWriters.StructWriter<Row> {
+  private static class RowDataWriter extends ParquetValueWriters.StructWriter<RowData> {
+    private final List<LogicalType> types;
 
-    private RowWriter(List<ParquetValueWriter<?>> writers) {
+    RowDataWriter(List<ParquetValueWriter<?>> writers, List<LogicalType> types) {
       super(writers);
+      this.types = types;
     }
 
     @Override
-    protected Object get(Row row, int index) {
-      return row.getField(index);
+    protected Object get(RowData struct, int index) {
+      return RowData.createFieldGetter(types.get(index), index).getFieldOrNull(struct);
     }
   }
 }
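Note: a minimal usage sketch of the new entry point, distilled from TestFlinkParquetWriter further down in this patch; the caller-supplied `schema`, `rows`, and `outputFile` are illustrative and not part of the change.

    // Write Flink RowData rows to a local Parquet file using the RowData writer tree.
    LogicalType logicalType = FlinkSchemaUtil.convert(schema);  // Iceberg schema -> Flink logical type
    try (FileAppender<RowData> appender = Parquet.write(Files.localOutput(outputFile))
        .schema(schema)
        .createWriterFunc(msgType -> FlinkParquetWriters.buildWriter(logicalType, msgType))
        .build()) {
      appender.addAll(rows);  // rows: Iterable<RowData>
    }
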
diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/ParquetWithFlinkSchemaVisitor.java b/flink/src/main/java/org/apache/iceberg/flink/data/ParquetWithFlinkSchemaVisitor.java
new file mode 100644
index 000000000000..541986f93889
--- /dev/null
+++ b/flink/src/main/java/org/apache/iceberg/flink/data/ParquetWithFlinkSchemaVisitor.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.flink.data;
+
+import java.util.Deque;
+import java.util.List;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.logical.RowType.RowField;
+import org.apache.iceberg.avro.AvroSchemaUtil;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.OriginalType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+
+public class ParquetWithFlinkSchemaVisitor<T> {
+  private final Deque<String> fieldNames = Lists.newLinkedList();
+
+  public static <T> T visit(LogicalType sType, Type type, ParquetWithFlinkSchemaVisitor<T> visitor) {
+    Preconditions.checkArgument(sType != null, "Invalid DataType: null");
+    if (type instanceof MessageType) {
+      Preconditions.checkArgument(sType instanceof RowType, "Invalid struct: %s is not a struct", sType);
+      RowType struct = (RowType) sType;
+      return visitor.message(struct, (MessageType) type, visitFields(struct, type.asGroupType(), visitor));
+    } else if (type.isPrimitive()) {
+      return visitor.primitive(sType, type.asPrimitiveType());
+    } else {
+      // if not a primitive, the typeId must be a group
+      GroupType group = type.asGroupType();
+      OriginalType annotation = group.getOriginalType();
+      if (annotation != null) {
+        switch (annotation) {
+          case LIST:
+            Preconditions.checkArgument(!group.isRepetition(Type.Repetition.REPEATED),
+                "Invalid list: top-level group is repeated: %s", group);
+            Preconditions.checkArgument(group.getFieldCount() == 1,
+                "Invalid list: does not contain single repeated field: %s", group);
+
+            GroupType repeatedElement = group.getFields().get(0).asGroupType();
+            Preconditions.checkArgument(repeatedElement.isRepetition(Type.Repetition.REPEATED),
+                "Invalid list: inner group is not repeated");
+            Preconditions.checkArgument(repeatedElement.getFieldCount() <= 1,
+                "Invalid list: repeated group is not a single field: %s", group);
+
+            Preconditions.checkArgument(sType instanceof ArrayType, "Invalid list: %s is not an array", sType);
+            ArrayType array = (ArrayType) sType;
+            RowType.RowField element = new RowField(
+                "element", array.getElementType(), "element of " + array.asSummaryString());
+
+            visitor.fieldNames.push(repeatedElement.getName());
+            try {
+              T elementResult = null;
+              if (repeatedElement.getFieldCount() > 0) {
+                elementResult = visitField(element, repeatedElement.getType(0), visitor);
+              }
+
+              return visitor.list(array, group, elementResult);
+
+            } finally {
+              visitor.fieldNames.pop();
+            }
+
+          case MAP:
+            Preconditions.checkArgument(!group.isRepetition(Type.Repetition.REPEATED),
+                "Invalid map: top-level group is repeated: %s", group);
+            Preconditions.checkArgument(group.getFieldCount() == 1,
+                "Invalid map: does not contain single repeated field: %s", group);
+
+            GroupType repeatedKeyValue = group.getType(0).asGroupType();
+            Preconditions.checkArgument(repeatedKeyValue.isRepetition(Type.Repetition.REPEATED),
+                "Invalid map: inner group is not repeated");
+            Preconditions.checkArgument(repeatedKeyValue.getFieldCount() <= 2,
+                "Invalid map: repeated group does not have 2 fields");
+
+            Preconditions.checkArgument(sType instanceof MapType, "Invalid map: %s is not a map", sType);
+            MapType map = (MapType) sType;
+            RowField keyField = new RowField("key", map.getKeyType(), "key of " + map.asSummaryString());
+            RowField valueField = new RowField(
+                "value", map.getValueType(), "value of " + map.asSummaryString());
+
+            visitor.fieldNames.push(repeatedKeyValue.getName());
+            try {
+              T keyResult = null;
+              T valueResult = null;
+              switch (repeatedKeyValue.getFieldCount()) {
+                case 2:
+                  // if there are 2 fields, both key and value are projected
+                  keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor);
+                  valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor);
+                  break;
+                case 1:
+                  // if there is just one, use the name to determine what it is
+                  Type keyOrValue = repeatedKeyValue.getType(0);
+                  if (keyOrValue.getName().equalsIgnoreCase("key")) {
+                    keyResult = visitField(keyField, keyOrValue, visitor);
+                    // value result remains null
+                  } else {
+                    valueResult = visitField(valueField, keyOrValue, visitor);
+                    // key result remains null
+                  }
+                  break;
+                default:
+                  // both results will remain null
+              }
+
+              return visitor.map(map, group, keyResult, valueResult);
+
+            } finally {
+              visitor.fieldNames.pop();
+            }
+
+          default:
+        }
+      }
+      Preconditions.checkArgument(sType instanceof RowType, "Invalid struct: %s is not a struct", sType);
+      RowType struct = (RowType) sType;
+      return visitor.struct(struct, group, visitFields(struct, group, visitor));
+    }
+  }
+
+  private static <T> T visitField(RowType.RowField sField, Type field, ParquetWithFlinkSchemaVisitor<T> visitor) {
+    visitor.fieldNames.push(field.getName());
+    try {
+      return visit(sField.getType(), field, visitor);
+    } finally {
+      visitor.fieldNames.pop();
+    }
+  }
+
+  private static <T> List<T> visitFields(RowType struct, GroupType group,
+                                         ParquetWithFlinkSchemaVisitor<T> visitor) {
+    List<RowField> sFields = struct.getFields();
+    Preconditions.checkArgument(sFields.size() == group.getFieldCount(),
+        "Structs do not match: %s and %s", struct, group);
+    List<T> results = Lists.newArrayListWithExpectedSize(group.getFieldCount());
+    for (int i = 0; i < sFields.size(); i += 1) {
+      Type field = group.getFields().get(i);
+      RowType.RowField sField = sFields.get(i);
+      Preconditions.checkArgument(field.getName().equals(AvroSchemaUtil.makeCompatibleName(sField.getName())),
+          "Structs do not match: field %s != %s", field.getName(), sField.getName());
+      results.add(visitField(sField, field, visitor));
+    }
+
+    return results;
+  }
+
+  public T message(RowType sStruct, MessageType message, List<T> fields) {
+    return null;
+  }
+
+  public T struct(RowType sStruct, GroupType struct, List<T> fields) {
+    return null;
+  }
+
+  public T list(ArrayType sArray, GroupType array, T element) {
+    return null;
+  }
+
+  public T map(MapType sMap, GroupType map, T key, T value) {
+    return null;
+  }
+
+  public T primitive(LogicalType sPrimitive, PrimitiveType primitive) {
+    return null;
+  }
+
+  protected String[] currentPath() {
+    return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]);
+  }
+
+  protected String[] path(String name) {
+    List<String> list = Lists.newArrayList(fieldNames.descendingIterator());
+    list.add(name);
+    return list.toArray(new String[0]);
+  }
+
+}
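Note: to illustrate how the visitor callbacks and currentPath() compose, here is a hypothetical subclass that collects the dotted path of every primitive column; it is not part of the patch and assumes the same imports as the visitor above.

    // Collects "a.b.c"-style column paths while the visitor walks the Parquet schema.
    class ColumnPathCollector extends ParquetWithFlinkSchemaVisitor<Void> {
      private final List<String> paths = Lists.newArrayList();

      @Override
      public Void primitive(LogicalType sPrimitive, PrimitiveType primitive) {
        // currentPath() reflects the field names pushed by visitField/list/map on the way down.
        paths.add(String.join(".", currentPath()));
        return null;
      }

      List<String> paths() {
        return paths;
      }
    }
    // usage: ParquetWithFlinkSchemaVisitor.visit(FlinkSchemaUtil.convert(schema), messageType, new ColumnPathCollector());
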
diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java b/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java
deleted file mode 100644
index b1e14c6c0fc5..000000000000
--- a/flink/src/test/java/org/apache/iceberg/flink/data/RandomData.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iceberg.flink.data;
-
-import java.util.Iterator;
-import java.util.List;
-import java.util.NoSuchElementException;
-import java.util.function.Supplier;
-import org.apache.flink.types.Row;
-import org.apache.iceberg.Schema;
-import org.apache.iceberg.data.RandomGenericData;
-import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.types.TypeUtil;
-import org.apache.iceberg.types.Types;
-
-import static org.apache.iceberg.types.Types.NestedField.optional;
-import static org.apache.iceberg.types.Types.NestedField.required;
-
-public class RandomData {
-  private RandomData() {
-  }
-
-  static final Schema COMPLEX_SCHEMA = new Schema(
-      required(1, "roots", Types.LongType.get()),
-      optional(3, "lime", Types.ListType.ofRequired(4, Types.DoubleType.get())),
-      required(5, "strict", Types.StructType.of(
-          required(9, "tangerine", Types.StringType.get()),
-          optional(6, "hopeful", Types.StructType.of(
-              required(7, "steel", Types.FloatType.get()),
-              required(8, "lantern", Types.DateType.get())
-          )),
-          optional(10, "vehement", Types.LongType.get())
-      )),
-      optional(11, "metamorphosis", Types.MapType.ofRequired(12, 13,
-          Types.StringType.get(), Types.TimestampType.withZone())),
-      required(14, "winter", Types.ListType.ofOptional(15, Types.StructType.of(
-          optional(16, "beet", Types.DoubleType.get()),
-          required(17, "stamp", Types.FloatType.get()),
-          optional(18, "wheeze", Types.StringType.get())
-      ))),
-      optional(19, "renovate", Types.MapType.ofRequired(20, 21,
-          Types.StringType.get(), Types.StructType.of(
-              optional(22, "jumpy", Types.DoubleType.get()),
-              required(23, "koala", Types.IntegerType.get()),
-              required(24, "couch rope", Types.IntegerType.get())
-          ))),
-      optional(2, "slide", Types.StringType.get())
-  );
-
-  private static Iterable<Row> generateData(Schema schema, int numRecords, Supplier<RandomRowGenerator> supplier) {
-    return () -> new Iterator<Row>() {
-      private final RandomRowGenerator generator = supplier.get();
-      private int count = 0;
-
-      @Override
-      public boolean hasNext() {
-        return count < numRecords;
-      }
-
-      @Override
-      public Row next() {
-        if (!hasNext()) {
-          throw new NoSuchElementException();
-        }
-        ++count;
-        return (Row) TypeUtil.visit(schema, generator);
-      }
-    };
-  }
-
-  public static Iterable<Row> generate(Schema schema, int numRecords, long seed) {
-    return generateData(schema, numRecords, () -> new RandomRowGenerator(seed));
-  }
-
-  private static class RandomRowGenerator extends RandomGenericData.RandomDataGenerator<Row> {
-    RandomRowGenerator(long seed) {
-      super(seed);
-    }
-
-    @Override
-    public Row schema(Schema schema, Supplier<Object> structResult) {
-      return (Row) structResult.get();
-    }
-
-    @Override
-    public Row struct(Types.StructType struct, Iterable<Object> fieldResults) {
-      Row row = new Row(struct.fields().size());
-
-      List<Object> values = Lists.newArrayList(fieldResults);
-      for (int i = 0; i < values.size(); i += 1) {
-        row.setField(i, values.get(i));
-      }
-
-      return row;
-    }
-  }
-}
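Note: the deleted Row-based generator is superseded by a RowData-based helper, RandomRowData, whose diff is not shown in this excerpt; judging from the test below, the assumed API either generates RowData directly or converts generic Records:

    // Assumed RandomRowData API, inferred from TestFlinkParquetWriter below.
    Iterable<RowData> direct = RandomRowData.generate(schema, 100, 19981);    // seed-driven RowData
    Iterable<RowData> adapted = RandomRowData.convert(schema,                 // adapt generic Records
        RandomGenericData.generateDictionaryEncodableRecords(schema, 100, 21124));
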
diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetWriter.java b/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetWriter.java
new file mode 100644
index 000000000000..d0e18b840624
--- /dev/null
+++ b/flink/src/test/java/org/apache/iceberg/flink/data/TestFlinkParquetWriter.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.flink.data;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Iterator;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.data.DataTest;
+import org.apache.iceberg.data.RandomGenericData;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.data.parquet.GenericParquetReaders;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.io.FileAppender;
+import org.apache.iceberg.parquet.Parquet;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.rules.TemporaryFolder;
+
+public class TestFlinkParquetWriter extends DataTest {
+  private static final int NUM_RECORDS = 100;
+
+  @Rule
+  public TemporaryFolder temp = new TemporaryFolder();
+
+  private void writeAndValidate(Iterable<RowData> iterable, Schema schema) throws IOException {
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    LogicalType logicalType = FlinkSchemaUtil.convert(schema);
+
+    try (FileAppender<RowData> writer = Parquet.write(Files.localOutput(testFile))
+        .schema(schema)
+        .createWriterFunc(msgType -> FlinkParquetWriters.buildWriter(logicalType, msgType))
+        .build()) {
+      writer.addAll(iterable);
+    }
+
+    try (CloseableIterable<Record> reader = Parquet.read(Files.localInput(testFile))
+        .project(schema)
+        .createReaderFunc(msgType -> GenericParquetReaders.buildReader(schema, msgType))
+        .build()) {
+      Iterator<RowData> expected = iterable.iterator();
+      Iterator<Record> actual = reader.iterator();
+      LogicalType rowType = FlinkSchemaUtil.convert(schema);
+      for (int i = 0; i < NUM_RECORDS; i += 1) {
+        Assert.assertTrue("Should have expected number of rows", actual.hasNext());
+        TestHelpers.assertRowData(schema.asStruct(), rowType, actual.next(), expected.next());
+      }
+      Assert.assertFalse("Should not have extra rows", actual.hasNext());
+    }
+  }
+
+  @Override
+  protected void writeAndValidate(Schema schema) throws IOException {
+    writeAndValidate(
+        RandomRowData.generate(schema, NUM_RECORDS, 19981), schema);
+
+    writeAndValidate(RandomRowData.convert(schema,
+        RandomGenericData.generateDictionaryEncodableRecords(schema, NUM_RECORDS, 21124)),
+        schema);
+
+    writeAndValidate(RandomRowData.convert(schema,
+        RandomGenericData.generateFallbackRecords(schema, NUM_RECORDS, 21124, NUM_RECORDS / 20)),
+        schema);
+  }
+}
diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/TestHelpers.java b/flink/src/test/java/org/apache/iceberg/flink/data/TestHelpers.java
index be427ce868b8..7f471eb634ec 100644
--- a/flink/src/test/java/org/apache/iceberg/flink/data/TestHelpers.java
+++ b/flink/src/test/java/org/apache/iceberg/flink/data/TestHelpers.java
@@ -29,7 +29,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
-import java.util.function.Supplier;
 import org.apache.flink.table.data.ArrayData;
 import org.apache.flink.table.data.DecimalData;
 import org.apache.flink.table.data.MapData;
@@ -66,15 +65,12 @@ public static void assertRowData(Types.StructType structType, LogicalType rowTyp
     for (int i = 0; i < types.size(); i += 1) {
       Object expected = expectedRecord.get(i);
       LogicalType logicalType = ((RowType) rowType).getTypeAt(i);
-
-      final int fieldPos = i;
       assertEquals(types.get(i), logicalType, expected,
-          () -> RowData.createFieldGetter(logicalType, fieldPos).getFieldOrNull(actualRowData));
+          RowData.createFieldGetter(logicalType, i).getFieldOrNull(actualRowData));
     }
   }
 
-  private static void assertEquals(Type type, LogicalType logicalType, Object expected, Supplier<Object> supplier) {
-    Object actual = supplier.get();
+  private static void assertEquals(Type type, LogicalType logicalType, Object expected, Object actual) {
 
     if (expected == null && actual == null) {
       return;
@@ -177,9 +173,8 @@ private static void assertArrayValues(Type type, LogicalType logicalType, Collec
       Object expected = expectedElements.get(i);
 
-      final int pos = i;
       assertEquals(type, logicalType, expected,
-          () -> ArrayData.createElementGetter(logicalType).getElementOrNull(actualArray, pos));
+          ArrayData.createElementGetter(logicalType).getElementOrNull(actualArray, i));
     }
   }
@@ -202,7 +197,7 @@ private static void assertMapValues(Types.MapType mapType, LogicalType type, Map
     for (int i = 0; i < actual.size(); i += 1) {
       try {
         Object key = keyGetter.getElementOrNull(actualKeyArrayData, i);
-        assertEquals(keyType, actualKeyType, entry.getKey(), () -> key);
+        assertEquals(keyType, actualKeyType, entry.getKey(), key);
         matchedActualKey = key;
         matchedKeyIndex = i;
         break;
@@ -213,7 +208,7 @@ private static void assertMapValues(Types.MapType mapType, LogicalType type, Map
     Assert.assertNotNull("Should have a matching key", matchedActualKey);
     final int valueIndex = matchedKeyIndex;
     assertEquals(valueType, actualValueType, entry.getValue(),
-        () -> valueGetter.getElementOrNull(actualValueArrayData, valueIndex));
+        valueGetter.getElementOrNull(actualValueArrayData, valueIndex));
   }
 }
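Note: the TestHelpers change above swaps lazy Supplier-based access for eager evaluation through Flink's field getters; a small sketch of that pattern, where `rowType` (a RowType) and `row` (a RowData) stand in for values from the surrounding test code:

    // A FieldGetter is bound once to a field's LogicalType and position, then reused per row.
    RowData.FieldGetter getter = RowData.createFieldGetter(rowType.getTypeAt(0), 0);
    Object first = getter.getFieldOrNull(row);  // null-safe access to field 0
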