diff --git a/api/src/main/java/org/apache/iceberg/util/StructProjection.java b/api/src/main/java/org/apache/iceberg/util/StructProjection.java index 27b03e2f0b93..704effe6c712 100644 --- a/api/src/main/java/org/apache/iceberg/util/StructProjection.java +++ b/api/src/main/java/org/apache/iceberg/util/StructProjection.java
@@ -155,6 +155,12 @@ public int size() { @Override public <T> T get(int pos, Class<T> javaClass) { + if (struct == null) { + // Return a null struct when projecting a nested required field from an optional struct. + // See more details in issue #2738. + return null; + } + int structPos = positionMap[pos]; if (nestedProjections[pos] != null) {
diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/RowDataProjection.java b/flink/src/main/java/org/apache/iceberg/flink/data/RowDataProjection.java new file mode 100644 index 000000000000..6334a00fd0d7 --- /dev/null +++ b/flink/src/main/java/org/apache/iceberg/flink/data/RowDataProjection.java
@@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.flink.data; + +import java.util.Map; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.TimestampData; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.types.RowKind; +import org.apache.iceberg.Schema; +import org.apache.iceberg.flink.FlinkSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public class RowDataProjection implements RowData { + /** + * Creates a projecting wrapper for {@link RowData} rows. + * <p>
+ * This projection will not project the nested children types of repeated types like lists and maps. + * + * @param schema schema of rows wrapped by this projection + * @param projectedSchema result schema of the projected rows + * @return a wrapper to project rows + */ + public static RowDataProjection create(Schema schema, Schema projectedSchema) { + return RowDataProjection.create(FlinkSchemaUtil.convert(schema), schema.asStruct(), projectedSchema.asStruct()); + } + + /** + * Creates a projecting wrapper for {@link RowData} rows. + * <p>
+ * This projection will not project the nested children types of repeated types like lists and maps. + * + * @param rowType flink row type of rows wrapped by this projection + * @param schema schema of rows wrapped by this projection + * @param projectedSchema result schema of the projected rows + * @return a wrapper to project rows + */ + public static RowDataProjection create(RowType rowType, Types.StructType schema, Types.StructType projectedSchema) { + return new RowDataProjection(rowType, schema, projectedSchema); + } + + private final RowData.FieldGetter[] getters; + private RowData rowData; + + private RowDataProjection(RowType rowType, Types.StructType rowStruct, Types.StructType projectType) { + Map<Integer, Integer> fieldIdToPosition = Maps.newHashMap(); + for (int i = 0; i < rowStruct.fields().size(); i++) { + fieldIdToPosition.put(rowStruct.fields().get(i).fieldId(), i); + } + + this.getters = new RowData.FieldGetter[projectType.fields().size()]; + for (int i = 0; i < getters.length; i++) { + Types.NestedField projectField = projectType.fields().get(i); + Types.NestedField rowField = rowStruct.field(projectField.fieldId()); + + Preconditions.checkNotNull(rowField, + "Cannot locate the project field <%s> in the iceberg struct <%s>", projectField, rowStruct); + + getters[i] = createFieldGetter(rowType, fieldIdToPosition.get(projectField.fieldId()), rowField, projectField); + } + } + + private static RowData.FieldGetter createFieldGetter(RowType rowType, + int position, + Types.NestedField rowField, + Types.NestedField projectField) { + Preconditions.checkArgument(rowField.type().typeId() == projectField.type().typeId(), + "Different iceberg type between row field <%s> and project field <%s>", rowField, projectField); + + switch (projectField.type().typeId()) { + case STRUCT: + RowType nestedRowType = (RowType) rowType.getTypeAt(position); + return row -> { + RowData nestedRow = row.isNullAt(position) ? null : row.getRow(position, nestedRowType.getFieldCount()); + return RowDataProjection + .create(nestedRowType, rowField.type().asStructType(), projectField.type().asStructType()) + .wrap(nestedRow); + }; + + case MAP: + Types.MapType projectedMap = projectField.type().asMapType(); + Types.MapType originalMap = rowField.type().asMapType(); + + boolean keyProjectable = !projectedMap.keyType().isNestedType() || + projectedMap.keyType().equals(originalMap.keyType()); + boolean valueProjectable = !projectedMap.valueType().isNestedType() || + projectedMap.valueType().equals(originalMap.valueType()); + Preconditions.checkArgument(keyProjectable && valueProjectable, + "Cannot project a partial map key or value with non-primitive type. Trying to project <%s> out of <%s>", + projectField, rowField); + + return RowData.createFieldGetter(rowType.getTypeAt(position), position); + + case LIST: + Types.ListType projectedList = projectField.type().asListType(); + Types.ListType originalList = rowField.type().asListType(); + + boolean elementProjectable = !projectedList.elementType().isNestedType() || + projectedList.elementType().equals(originalList.elementType()); + Preconditions.checkArgument(elementProjectable, + "Cannot project a partial list element with non-primitive type. Trying to project <%s> out of <%s>", + projectField, rowField); + + return RowData.createFieldGetter(rowType.getTypeAt(position), position); + + default: + return RowData.createFieldGetter(rowType.getTypeAt(position), position); + } + } + + public RowData wrap(RowData row) { + this.rowData = row; + return this; + } + + private Object getValue(int pos) { + return getters[pos].getFieldOrNull(rowData); + } + + @Override + public int getArity() { + return getters.length; + } + + @Override + public RowKind getRowKind() { + return rowData.getRowKind(); + } + + @Override + public void setRowKind(RowKind kind) { + throw new UnsupportedOperationException("Cannot set row kind in the RowDataProjection"); + } + + @Override + public boolean isNullAt(int pos) { + return rowData == null || getValue(pos) == null; + } + + @Override + public boolean getBoolean(int pos) { + return (boolean) getValue(pos); + } + + @Override + public byte getByte(int pos) { + return (byte) getValue(pos); + } + + @Override + public short getShort(int pos) { + return (short) getValue(pos); + } + + @Override + public int getInt(int pos) { + return (int) getValue(pos); + } + + @Override + public long getLong(int pos) { + return (long) getValue(pos); + } + + @Override + public float getFloat(int pos) { + return (float) getValue(pos); + } + + @Override + public double getDouble(int pos) { + return (double) getValue(pos); + } + + @Override + public StringData getString(int pos) { + return (StringData) getValue(pos); + } + + @Override + public DecimalData getDecimal(int pos, int precision, int scale) { + return (DecimalData) getValue(pos); + } + + @Override + public TimestampData getTimestamp(int pos, int precision) { + return (TimestampData) getValue(pos); + } + + @Override + @SuppressWarnings("unchecked") + public <T> RawValueData<T> getRawValue(int pos) { + return (RawValueData<T>) getValue(pos); + } + + @Override + public byte[] getBinary(int pos) { + return (byte[]) getValue(pos); + } + + @Override + public ArrayData getArray(int pos) { + return (ArrayData) getValue(pos); + } + + @Override + public MapData getMap(int pos) { + return (MapData) getValue(pos); + } + + @Override + public RowData getRow(int pos, int numFields) { + return (RowData) getValue(pos); + } +}
diff --git a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataFileScanTaskReader.java b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataFileScanTaskReader.java index fbdb7bf3cc02..08f2f51e5d9c 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataFileScanTaskReader.java +++ b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataFileScanTaskReader.java
@@ -22,6 +22,7 @@ import java.util.Map; import org.apache.flink.annotation.Internal; import org.apache.flink.table.data.RowData; +import org.apache.flink.table.types.logical.RowType; import org.apache.iceberg.FileScanTask; import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.Schema;
@@ -34,6 +35,7 @@ import org.apache.iceberg.flink.data.FlinkAvroReader; import org.apache.iceberg.flink.data.FlinkOrcReader; import org.apache.iceberg.flink.data.FlinkParquetReaders; +import org.apache.iceberg.flink.data.RowDataProjection; import org.apache.iceberg.flink.data.RowDataUtil; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.CloseableIterator;
@@ -70,9 +72,18 @@ public CloseableIterator<RowData> open(FileScanTask task, InputFilesDecryptor in PartitionUtil.constantsMap(task, RowDataUtil::convertConstant); FlinkDeleteFilter deletes = new FlinkDeleteFilter(task, tableSchema, projectedSchema, inputFilesDecryptor); - return deletes - .filter(newIterable(task, deletes.requiredSchema(), idToConstant, inputFilesDecryptor)) - .iterator(); + CloseableIterable<RowData> iterable = deletes.filter( + newIterable(task, deletes.requiredSchema(), idToConstant, inputFilesDecryptor) + ); + + // Project the RowData to remove the extra meta columns. + if (!projectedSchema.sameSchema(deletes.requiredSchema())) { + RowDataProjection rowDataProjection = RowDataProjection.create( + deletes.requiredRowType(), deletes.requiredSchema().asStruct(), projectedSchema.asStruct()); + iterable = CloseableIterable.transform(iterable, rowDataProjection::wrap); + } + + return iterable.iterator(); } private CloseableIterable<RowData> newIterable(
@@ -156,16 +167,22 @@ private CloseableIterable<RowData> newOrcIterable( } private static class FlinkDeleteFilter extends DeleteFilter<RowData> { + private final RowType requiredRowType; private final RowDataWrapper asStructLike; private final InputFilesDecryptor inputFilesDecryptor; FlinkDeleteFilter(FileScanTask task, Schema tableSchema, Schema requestedSchema, InputFilesDecryptor inputFilesDecryptor) { super(task, tableSchema, requestedSchema); - this.asStructLike = new RowDataWrapper(FlinkSchemaUtil.convert(requiredSchema()), requiredSchema().asStruct()); + this.requiredRowType = FlinkSchemaUtil.convert(requiredSchema()); + this.asStructLike = new RowDataWrapper(requiredRowType, requiredSchema().asStruct()); this.inputFilesDecryptor = inputFilesDecryptor; } + public RowType requiredRowType() { + return requiredRowType; + } + @Override protected StructLike asStructLike(RowData row) { return asStructLike.wrap(row);
diff --git a/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java b/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java index d44f45ab52fd..68b706e2d281 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java +++ b/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java
@@ -38,6 +38,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.util.StructLikeSet; import org.junit.After; import org.junit.Assert;
@@ -125,10 +126,10 @@ public void testSqlChangeLogOnIdKey() throws Exception { ) ); - List<List<Record>> expectedRecordsPerCheckpoint = ImmutableList.of( - ImmutableList.of(record(1, "bbb"), record(2, "bbb")), - ImmutableList.of(record(1, "bbb"), record(2, "ddd")), - ImmutableList.of(record(1, "ddd"), record(2, "ddd")) + List<List<Row>> expectedRecordsPerCheckpoint = ImmutableList.of( + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "bbb")), + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "ddd")), + ImmutableList.of(insertRow(1, "ddd"), insertRow(2, "ddd")) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("id"), inputRowsPerCheckpoint,
@@ -157,10 +158,10 @@ public void testChangeLogOnDataKey() throws Exception { ) ); - List<List<Record>> expectedRecords = ImmutableList.of( - ImmutableList.of(record(1, "bbb"), record(2, "aaa")), - ImmutableList.of(record(1, "aaa"), record(1, "bbb"), record(1, "ccc")), - ImmutableList.of(record(1, "aaa"), record(1, "ccc"), record(2, "aaa"), record(2, "ccc")) + List<List<Row>> expectedRecords = ImmutableList.of( + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "aaa")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "bbb"), insertRow(1, "ccc")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "ccc"), insertRow(2, "aaa"), insertRow(2, "ccc")) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("data"), elementsPerCheckpoint, expectedRecords);
@@ -187,10 +188,10 @@ public void testChangeLogOnIdDataKey() throws Exception { ) ); - List<List<Record>> expectedRecords = ImmutableList.of( - ImmutableList.of(record(1, "bbb"), record(2, "aaa"), record(2, "bbb")), - ImmutableList.of(record(1, "aaa"), record(1, "bbb"), record(1, "ccc"), record(2, "bbb")), - ImmutableList.of(record(1, "aaa"), record(1, "ccc"), record(2, "aaa"), record(2, "bbb")) + List<List<Row>> expectedRecords = ImmutableList.of( + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "aaa"), insertRow(2, "bbb")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "bbb"), insertRow(1, "ccc"), insertRow(2, "bbb")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "ccc"), insertRow(2, "aaa"), insertRow(2, "bbb")) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("data", "id"), elementsPerCheckpoint, expectedRecords);
@@ -213,31 +214,31 @@ public void testPureInsertOnIdKey() throws Exception { ) ); - List<List<Record>> expectedRecords = ImmutableList.of( + List<List<Row>> expectedRecords = ImmutableList.of( ImmutableList.of( - record(1, "aaa"), - record(2, "bbb") + insertRow(1, "aaa"), + insertRow(2, "bbb") ), ImmutableList.of( - record(1, "aaa"), - record(2, "bbb"), - record(3, "ccc"), - record(4, "ddd") + insertRow(1, "aaa"), + insertRow(2, "bbb"), + insertRow(3, "ccc"), + insertRow(4, "ddd") ), ImmutableList.of( - record(1, "aaa"), - record(2, "bbb"), - record(3, "ccc"), - record(4, "ddd"), - record(5, "eee"), - record(6, "fff") + insertRow(1, "aaa"), + insertRow(2, "bbb"), + insertRow(3, "ccc"), + insertRow(4, "ddd"), + insertRow(5, "eee"), + insertRow(6, "fff") ) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("data"), elementsPerCheckpoint, expectedRecords); } - private Record record(int id, String data) { + private static Record record(int id, String data) { return SimpleDataUtil.createRecord(id, data); }
@@ -261,7 +262,7 @@ private Table createTable(String tableName, List<String> key, boolean isPartitio private void testSqlChangeLog(String tableName, List<String> key, List<List<Row>> inputRowsPerCheckpoint, - List<List<Record>> expectedRecordsPerCheckpoint) throws Exception { + List<List<Row>> expectedRecordsPerCheckpoint) throws Exception { String dataId = BoundedTableFactory.registerDataSet(inputRowsPerCheckpoint); sql("CREATE TABLE %s(id INT NOT NULL, data STRING NOT NULL)" + " WITH ('connector'='BoundedSource', 'data-id'='%s')", SOURCE_TABLE, dataId);
@@ -280,9 +281,15 @@ private void testSqlChangeLog(String tableName, for (int i = 0; i < expectedSnapshotNum; i++) { long snapshotId = snapshots.get(i).snapshotId(); - List<Record> expectedRecords = expectedRecordsPerCheckpoint.get(i); + List<Row> expectedRows = expectedRecordsPerCheckpoint.get(i); Assert.assertEquals("Should have the expected records for the checkpoint#" + i, - expectedRowSet(table, expectedRecords), actualRowSet(table, snapshotId)); + expectedRowSet(table, expectedRows), actualRowSet(table, snapshotId)); + } + + if (expectedSnapshotNum > 0) { + Assert.assertEquals("Should have the expected rows in the final table", + Sets.newHashSet(expectedRecordsPerCheckpoint.get(expectedSnapshotNum - 1)), + Sets.newHashSet(sql("SELECT * FROM %s", tableName))); + } }
@@ -296,8 +303,12 @@ private List<Snapshot> findValidSnapshots(Table table) { return validSnapshots; } - private static StructLikeSet expectedRowSet(Table table, List<Record> records) { - return SimpleDataUtil.expectedRowSet(table, records.toArray(new Record[0])); + private static StructLikeSet expectedRowSet(Table table, List<Row> rows) { + Record[] records = new Record[rows.size()]; + for (int i = 0; i < records.length; i++) { + records[i] = record((int) rows.get(i).getField(0), (String) rows.get(i).getField(1)); + } + return SimpleDataUtil.expectedRowSet(table, records); } private static StructLikeSet actualRowSet(Table table, long snapshotId) throws IOException {
diff --git a/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java b/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java index 7099c864cb34..c1d17f5c036a 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java +++ b/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java
@@ -50,6 +50,7 @@ import org.apache.iceberg.ContentFile; import org.apache.iceberg.ManifestFile; import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; import org.apache.iceberg.data.Record; import org.apache.iceberg.flink.data.RowDataUtil; import org.apache.iceberg.flink.source.FlinkInputFormat;
@@ -116,7 +117,11 @@ public static void assertRows(List<Row> results, List<Row> expected) { Assert.assertEquals(expected, results); } - public static void assertRowData(Types.StructType structType, LogicalType rowType, Record expectedRecord, + public static void assertRowData(Schema schema, StructLike expected, RowData actual) { + assertRowData(schema.asStruct(), FlinkSchemaUtil.convert(schema), expected, actual); + } + + public static void assertRowData(Types.StructType structType, LogicalType rowType, StructLike expectedRecord, RowData actualRowData) { if (expectedRecord == null && actualRowData == null) { return;
@@ -131,10 +136,15 @@ public static void assertRowData(Types.StructType structType, LogicalType rowTyp } for (int i = 0; i < types.size(); i += 1) { - Object expected = expectedRecord.get(i); LogicalType logicalType = ((RowType) rowType).getTypeAt(i); - assertEquals(types.get(i), logicalType, expected, - RowData.createFieldGetter(logicalType, i).getFieldOrNull(actualRowData)); + Object expected = expectedRecord.get(i, Object.class); + // A getter created by RowData.createFieldGetter won't return null for a required field. But when projecting a + // nested required field from an optional struct, the projected field should be null if the outer struct value + // is null, so we need to check whether actualRowData is null at this position. For more details see + // issue #2738. + Object actual = actualRowData.isNullAt(i) ? null : + RowData.createFieldGetter(logicalType, i).getFieldOrNull(actualRowData); + assertEquals(types.get(i), logicalType, expected, actual); } }
@@ -213,8 +223,8 @@ private static void assertEquals(Type type, LogicalType logicalType, Object expe assertMapValues(type.asMapType(), logicalType, (Map) expected, (MapData) actual); break; case STRUCT: - Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); - assertRowData(type.asStructType(), logicalType, (Record) expected, (RowData) actual); + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(StructLike.class); + assertRowData(type.asStructType(), logicalType, (StructLike) expected, (RowData) actual); break; case UUID: Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class);
diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/TestRowDataProjection.java b/flink/src/test/java/org/apache/iceberg/flink/data/TestRowDataProjection.java new file mode 100644 index 000000000000..37016adfbdf2 --- /dev/null +++ b/flink/src/test/java/org/apache/iceberg/flink/data/TestRowDataProjection.java
@@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.flink.data; + +import java.util.Iterator; +import org.apache.flink.table.data.RowData; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.flink.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.StructProjection; +import org.junit.Assert; +import org.junit.Test; + +public class TestRowDataProjection { + + @Test + public void testFullProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + generateAndValidate(schema, schema); + } + + @Test + public void testReorderedFullProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Schema reordered = new Schema( + Types.NestedField.optional(1, "data", Types.StringType.get()), + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + generateAndValidate(schema, reordered); + } + + @Test + public void testBasicProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + Schema id = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + Schema data = new Schema( + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + generateAndValidate(schema, id); + generateAndValidate(schema, data); + } + + @Test + public void testEmptyProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + generateAndValidate(schema, schema.select()); + } + + @Test + public void testRename() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Schema renamed = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "renamed", Types.StringType.get()) + ); + generateAndValidate(schema, renamed); + } + + @Test + public void testNestedProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + + // Project id only. + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + generateAndValidate(schema, idOnly); + + // Project lat only. + Schema latOnly = new Schema( + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()) + )) + ); + generateAndValidate(schema, latOnly); + + // Project long only. + Schema longOnly = new Schema( + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + generateAndValidate(schema, longOnly); + + // Project location.
+ Schema locationOnly = schema.select("location"); + generateAndValidate(schema, locationOnly); + } + + @Test + public void testPrimitiveTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()), + Types.NestedField.required(2, "b", Types.BooleanType.get()), + Types.NestedField.optional(3, "i", Types.IntegerType.get()), + Types.NestedField.required(4, "l", Types.LongType.get()), + Types.NestedField.optional(5, "f", Types.FloatType.get()), + Types.NestedField.required(6, "d", Types.DoubleType.get()), + Types.NestedField.optional(7, "date", Types.DateType.get()), + Types.NestedField.optional(8, "time", Types.TimeType.get()), + Types.NestedField.required(9, "ts", Types.TimestampType.withoutZone()), + Types.NestedField.required(10, "ts_tz", Types.TimestampType.withZone()), + Types.NestedField.required(11, "s", Types.StringType.get()), + Types.NestedField.required(12, "fixed", Types.FixedType.ofLength(7)), + Types.NestedField.optional(13, "bytes", Types.BinaryType.get()), + Types.NestedField.required(14, "dec_9_0", Types.DecimalType.of(9, 0)), + Types.NestedField.required(15, "dec_11_2", Types.DecimalType.of(11, 2)), + Types.NestedField.required(16, "dec_38_10", Types.DecimalType.of(38, 10)) // maximum precision + ); + + generateAndValidate(schema, schema); + } + + @Test + public void testPrimitiveMapTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(3, "map", Types.MapType.ofOptional( + 1, 2, Types.IntegerType.get(), Types.StringType.get() + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project map only. + Schema mapOnly = schema.select("map"); + generateAndValidate(schema, mapOnly); + + // Project all. + generateAndValidate(schema, schema); + } + + @Test + public void testNestedMapTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(7, "map", Types.MapType.ofOptional( + 5, 6, + Types.StructType.of( + Types.NestedField.required(1, "key", Types.LongType.get()), + Types.NestedField.required(2, "keyData", Types.LongType.get()) + ), + Types.StructType.of( + Types.NestedField.required(3, "value", Types.LongType.get()), + Types.NestedField.required(4, "valueData", Types.LongType.get()) + ) + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project map only. + Schema mapOnly = schema.select("map"); + generateAndValidate(schema, mapOnly); + + // Project all. + generateAndValidate(schema, schema); + + // Project partial map key. + Schema partialMapKey = new Schema( + Types.NestedField.optional(7, "map", Types.MapType.ofOptional( + 5, 6, + Types.StructType.of( + Types.NestedField.required(1, "key", Types.LongType.get()) + ), + Types.StructType.of( + Types.NestedField.required(3, "value", Types.LongType.get()), + Types.NestedField.required(4, "valueData", Types.LongType.get()) + ) + )) + ); + AssertHelpers.assertThrows("Should not allow to project a partial map key with non-primitive type.", + IllegalArgumentException.class, "Cannot project a partial map key or value", + () -> generateAndValidate(schema, partialMapKey) + ); + + // Project partial map value.
+ Schema partialMapValue = new Schema( + Types.NestedField.optional(7, "map", Types.MapType.ofOptional( + 5, 6, + Types.StructType.of( + Types.NestedField.required(1, "key", Types.LongType.get()), + Types.NestedField.required(2, "keyData", Types.LongType.get()) + ), + Types.StructType.of( + Types.NestedField.required(3, "value", Types.LongType.get()) + ) + )) + ); + AssertHelpers.assertThrows("Should not allow to project a partial map value with non-primitive type.", + IllegalArgumentException.class, "Cannot project a partial map key or value", + () -> generateAndValidate(schema, partialMapValue) + ); + } + + @Test + public void testPrimitiveListTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(2, "list", Types.ListType.ofOptional( + 1, Types.StringType.get() + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project list only. + Schema listOnly = schema.select("list"); + generateAndValidate(schema, listOnly); + + // Project all. + generateAndValidate(schema, schema); + } + + @Test + public void testNestedListTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "list", Types.ListType.ofOptional( + 4, Types.StructType.of( + Types.NestedField.required(1, "nestedListField1", Types.LongType.get()), + Types.NestedField.required(2, "nestedListField2", Types.LongType.get()), + Types.NestedField.required(3, "nestedListField3", Types.LongType.get()) + ) + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project list only. + Schema listOnly = schema.select("list"); + generateAndValidate(schema, listOnly); + + // Project all. + generateAndValidate(schema, schema); + + // Project partial list value. + Schema partialList = new Schema( + Types.NestedField.optional(5, "list", Types.ListType.ofOptional( + 4, Types.StructType.of( + Types.NestedField.required(2, "nestedListField2", Types.LongType.get()) + ) + )) + ); + AssertHelpers.assertThrows("Should not allow to project a partial list element with non-primitive type.", + IllegalArgumentException.class, "Cannot project a partial list element", + () -> generateAndValidate(schema, partialList) + ); + } + + private void generateAndValidate(Schema schema, Schema projectSchema) { + int numRecords = 100; + Iterable<Record> recordList = RandomGenericData.generate(schema, numRecords, 102L); + Iterable<RowData> rowDataList = RandomRowData.generate(schema, numRecords, 102L); + + StructProjection structProjection = StructProjection.create(schema, projectSchema); + RowDataProjection rowDataProjection = RowDataProjection.create(schema, projectSchema); + + Iterator<Record> recordIter = recordList.iterator(); + Iterator<RowData> rowDataIter = rowDataList.iterator(); + + for (int i = 0; i < numRecords; i++) { + Assert.assertTrue("Should have more records", recordIter.hasNext()); + Assert.assertTrue("Should have more RowData", rowDataIter.hasNext()); + + StructLike expected = structProjection.wrap(recordIter.next()); + RowData actual = rowDataProjection.wrap(rowDataIter.next()); + + TestHelpers.assertRowData(projectSchema, expected, actual); + } + + Assert.assertFalse("Shouldn't have more records", recordIter.hasNext()); + Assert.assertFalse("Shouldn't have more RowData", rowDataIter.hasNext()); + } +}
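Reviewer note: the StructProjection guard in the first hunk is easiest to see with a concrete case. Below is a minimal sketch, not part of the patch, assuming generic Records and the hypothetical class name NullStructSketch; the schemas mirror the location example in TestRowDataProjection. Before the guard, reading a projected required field through a null optional struct could throw a NullPointerException; with it, the read yields null.

import org.apache.iceberg.Schema;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.data.Record;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.StructProjection;

public class NullStructSketch {
  public static void main(String[] args) {
    Schema schema = new Schema(
        Types.NestedField.required(0, "id", Types.LongType.get()),
        Types.NestedField.optional(3, "location", Types.StructType.of(
            Types.NestedField.required(1, "lat", Types.FloatType.get()),
            Types.NestedField.required(2, "long", Types.FloatType.get()))));
    // Project the required "lat" field out of the optional "location" struct.
    Schema latOnly = new Schema(
        Types.NestedField.optional(3, "location", Types.StructType.of(
            Types.NestedField.required(1, "lat", Types.FloatType.get()))));

    Record record = GenericRecord.create(schema);
    record.setField("id", 1L);
    record.setField("location", null); // the optional struct is absent

    StructLike projected = StructProjection.create(schema, latOnly).wrap(record);
    StructLike location = projected.get(0, StructLike.class);
    // With the guard, a projection over a null struct reads as null fields
    // instead of dereferencing the missing struct.
    Float lat = location == null ? null : location.get(0, Float.class);
    System.out.println(lat); // null
  }
}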
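A minimal usage sketch of the new RowDataProjection wrapper (hypothetical values and class name; the id/data schema matches testBasicProjection above):

import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.StringData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.flink.data.RowDataProjection;
import org.apache.iceberg.types.Types;

public class RowDataProjectionSketch {
  public static void main(String[] args) {
    Schema schema = new Schema(
        Types.NestedField.required(0, "id", Types.LongType.get()),
        Types.NestedField.optional(1, "data", Types.StringType.get()));
    Schema dataOnly = schema.select("data");

    // The wrapper is built once per (schema, projection) pair and reused.
    RowDataProjection projection = RowDataProjection.create(schema, dataOnly);

    RowData row = GenericRowData.of(1L, StringData.fromString("aaa"));
    RowData projected = projection.wrap(row);

    System.out.println(projected.getArity());   // 1
    System.out.println(projected.getString(0)); // aaa
  }
}

Since wrap() mutates and returns the shared wrapper rather than copying the row, each projected row has to be consumed before the next wrap() call.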
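The RowDataFileScanTaskReader change follows a shape worth calling out: equality deletes can force extra columns into the required read schema, and the projection strips them back out before rows leave the reader. A standalone sketch of that pattern, using only APIs from this patch and the hypothetical helper name stripExtraColumns:

import org.apache.flink.table.data.RowData;
import org.apache.iceberg.Schema;
import org.apache.iceberg.flink.FlinkSchemaUtil;
import org.apache.iceberg.flink.data.RowDataProjection;
import org.apache.iceberg.io.CloseableIterable;

public class ReaderProjectionSketch {
  // Drops columns that were only read to evaluate deletes, returning rows in
  // the schema the caller actually requested.
  static CloseableIterable<RowData> stripExtraColumns(
      CloseableIterable<RowData> rows, Schema requiredSchema, Schema projectedSchema) {
    if (projectedSchema.sameSchema(requiredSchema)) {
      return rows; // nothing to strip
    }
    RowDataProjection projection = RowDataProjection.create(
        FlinkSchemaUtil.convert(requiredSchema), requiredSchema.asStruct(), projectedSchema.asStruct());
    // One wrapper instance is reused for every row, which is fine for streaming
    // consumption but not for buffering unconsumed rows.
    return CloseableIterable.transform(rows, projection::wrap);
  }
}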