diff --git a/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java b/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java index d50c45e22e6b..985663bf224e 100644 --- a/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java +++ b/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java @@ -25,8 +25,17 @@ import org.apache.iceberg.relocated.com.google.common.collect.Sets; class GetProjectedIds extends TypeUtil.SchemaVisitor> { + private final boolean includeStructIds; private final Set fieldIds = Sets.newHashSet(); + GetProjectedIds() { + this(false); + } + + GetProjectedIds(boolean includeStructIds) { + this.includeStructIds = includeStructIds; + } + @Override public Set schema(Schema schema, Set structResult) { return fieldIds; @@ -39,7 +48,7 @@ public Set struct(Types.StructType struct, List> fieldResu @Override public Set field(Types.NestedField field, Set fieldResult) { - if (fieldResult == null) { + if ((includeStructIds && field.type().isStructType()) || field.type().isPrimitiveType()) { fieldIds.add(field.fieldId()); } return fieldIds; diff --git a/api/src/main/java/org/apache/iceberg/types/PruneColumns.java b/api/src/main/java/org/apache/iceberg/types/PruneColumns.java index f58670365ea4..2944ec7bb5c0 100644 --- a/api/src/main/java/org/apache/iceberg/types/PruneColumns.java +++ b/api/src/main/java/org/apache/iceberg/types/PruneColumns.java @@ -24,13 +24,26 @@ import org.apache.iceberg.Schema; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.types.Types.ListType; +import org.apache.iceberg.types.Types.MapType; +import org.apache.iceberg.types.Types.StructType; class PruneColumns extends TypeUtil.SchemaVisitor { private final Set selected; + private final boolean selectFullTypes; - PruneColumns(Set selected) { + /** + * Visits a schema and returns only the fields selected by the id set. + *
<p>
+ * When selectFullTypes is false selecting list or map types is undefined and forbidden. + * + * @param selected ids of elements to return + * @param selectFullTypes whether to select all subfields of a selected nested type + */ + PruneColumns(Set selected, boolean selectFullTypes) { Preconditions.checkNotNull(selected, "Selected field ids cannot be null"); this.selected = selected; + this.selectFullTypes = selectFullTypes; } @Override @@ -77,10 +90,19 @@ public Type struct(Types.StructType struct, List fieldResults) { @Override public Type field(Types.NestedField field, Type fieldResult) { if (selected.contains(field.fieldId())) { - return field.type(); + if (selectFullTypes) { + return field.type(); + } else if (field.type().isStructType()) { + return projectSelectedStruct(fieldResult); + } else { + Preconditions.checkArgument(!field.type().isNestedType(), + "Cannot explicitly project List or Map types, %s:%s of type %s was selected", + field.fieldId(), field.name(), field.type()); + // Selected non-struct field + return field.type(); + } } else if (fieldResult != null) { - // this isn't necessarily the same as field.type() because a struct may not have all - // fields selected. + // This field wasn't selected but a subfield was so include that return fieldResult; } return null; @@ -89,15 +111,19 @@ public Type field(Types.NestedField field, Type fieldResult) { @Override public Type list(Types.ListType list, Type elementResult) { if (selected.contains(list.elementId())) { - return list; - } else if (elementResult != null) { - if (list.elementType() == elementResult) { + if (selectFullTypes) { return list; - } else if (list.isElementOptional()) { - return Types.ListType.ofOptional(list.elementId(), elementResult); + } else if (list.elementType().isStructType()) { + StructType projectedStruct = projectSelectedStruct(elementResult); + return projectList(list, projectedStruct); } else { - return Types.ListType.ofRequired(list.elementId(), elementResult); + Preconditions.checkArgument(list.elementType().isPrimitiveType(), + "Cannot explicitly project List or Map types, List element %s of type %s was selected", + list.elementId(), list.elementType()); + return list; } + } else if (elementResult != null) { + return projectList(list, elementResult); } return null; } @@ -105,15 +131,19 @@ public Type list(Types.ListType list, Type elementResult) { @Override public Type map(Types.MapType map, Type ignored, Type valueResult) { if (selected.contains(map.valueId())) { - return map; - } else if (valueResult != null) { - if (map.valueType() == valueResult) { + if (selectFullTypes) { return map; - } else if (map.isValueOptional()) { - return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueResult); + } else if (map.valueType().isStructType()) { + Type projectedStruct = projectSelectedStruct(valueResult); + return projectMap(map, projectedStruct); } else { - return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueResult); + Preconditions.checkArgument(map.valueType().isPrimitiveType(), + "Cannot explicitly project List or Map types, Map value %s of type %s was selected", + map.valueId(), map.valueType()); + return map; } + } else if (valueResult != null) { + return projectMap(map, valueResult); } else if (selected.contains(map.keyId())) { // right now, maps can't be selected without values return map; @@ -125,4 +155,44 @@ public Type map(Types.MapType map, Type ignored, Type valueResult) { public Type primitive(Type.PrimitiveType primitive) { return null; 
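// Illustrative sketch (hypothetical schema, field ids and class name; not part of this patch):
// the selectFullTypes flag is what separates the two public entry points in TypeUtil.
// select() keeps a selected struct together with all of its sub-fields, while the new
// project() returns a selected struct as empty unless its sub-fields are also selected.
import org.apache.iceberg.Schema;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;

import static org.apache.iceberg.types.Types.NestedField.required;

public class ProjectVsSelectExample {
  public static void main(String[] args) {
    Schema schema = new Schema(
        required(1, "id", Types.LongType.get()),
        required(2, "location", Types.StructType.of(
            required(3, "lat", Types.DoubleType.get()),
            required(4, "long", Types.DoubleType.get()))));

    // selectFullTypes = true: selecting the struct id keeps lat and long
    Schema full = TypeUtil.select(schema, Sets.newHashSet(2));

    // selectFullTypes = false: selecting only the struct id yields an empty struct
    Schema empty = TypeUtil.project(schema, Sets.newHashSet(2));

    // sub-fields must be listed explicitly to survive project()
    Schema latOnly = TypeUtil.project(schema, Sets.newHashSet(2, 3));

    System.out.println(full.asStruct());    // location with lat and long
    System.out.println(empty.asStruct());   // location as an empty struct
    System.out.println(latOnly.asStruct()); // location with only lat
  }
}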
} + + private ListType projectList(ListType list, Type elementResult) { + Preconditions.checkArgument(elementResult != null, "Cannot project a list when the element result is null"); + if (list.elementType() == elementResult) { + return list; + } else if (list.isElementOptional()) { + return Types.ListType.ofOptional(list.elementId(), elementResult); + } else { + return Types.ListType.ofRequired(list.elementId(), elementResult); + } + } + + private MapType projectMap(MapType map, Type valueResult) { + Preconditions.checkArgument(valueResult != null, "Attempted to project a map without a defined map value type"); + if (map.valueType() == valueResult) { + return map; + } else if (map.isValueOptional()) { + return Types.MapType.ofOptional(map.keyId(), map.valueId(), map.keyType(), valueResult); + } else { + return Types.MapType.ofRequired(map.keyId(), map.valueId(), map.keyType(), valueResult); + } + } + + /** + * If select full types is disabled we need to recreate the struct with only the selected + * subfields. If no subfields are selected we return an empty struct. + * @param projectedField subfields already selected in this projection + * @return projected struct + */ + private StructType projectSelectedStruct(Type projectedField) { + Preconditions.checkArgument(projectedField == null || projectedField.isStructType()); + // the struct was selected, ensure at least an empty struct is returned + if (projectedField == null) { + // no sub-fields were selected but the struct was, return an empty struct + return Types.StructType.of(); + } else { + // sub-fields were selected so return the projected struct + return projectedField.asStructType(); + } + } } diff --git a/api/src/main/java/org/apache/iceberg/types/TypeUtil.java b/api/src/main/java/org/apache/iceberg/types/TypeUtil.java index 5185038f66fb..80515fdececa 100644 --- a/api/src/main/java/org/apache/iceberg/types/TypeUtil.java +++ b/api/src/main/java/org/apache/iceberg/types/TypeUtil.java @@ -19,6 +19,7 @@ package org.apache.iceberg.types; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -42,6 +43,46 @@ public class TypeUtil { private TypeUtil() { } + /** + * Project extracts particular fields from a schema by ID. + *
<p>
+ * Unlike {@link TypeUtil#select(Schema, Set)}, project will pick out only the fields enumerated. Structs that are + * explicitly projected are empty unless sub-fields are explicitly projected. Maps and lists cannot be explicitly + * selected in fieldIds. + * @param schema to project fields from + * @param fieldIds list of explicit fields to extract + * @return the schema with all fields fields not selected removed + */ + public static Schema project(Schema schema, Set fieldIds) { + Preconditions.checkNotNull(schema, "Schema cannot be null"); + + Types.StructType result = project(schema.asStruct(), fieldIds); + if (schema.asStruct().equals(result)) { + return schema; + } else if (result != null) { + if (schema.getAliases() != null) { + return new Schema(result.fields(), schema.getAliases()); + } else { + return new Schema(result.fields()); + } + } + return new Schema(Collections.emptyList(), schema.getAliases()); + } + + public static Types.StructType project(Types.StructType struct, Set fieldIds) { + Preconditions.checkNotNull(struct, "Struct cannot be null"); + Preconditions.checkNotNull(fieldIds, "Field ids cannot be null"); + + Type result = visit(struct, new PruneColumns(fieldIds, false)); + if (struct.equals(result)) { + return struct; + } else if (result != null) { + return result.asStructType(); + } + + return Types.StructType.of(); + } + public static Schema select(Schema schema, Set fieldIds) { Preconditions.checkNotNull(schema, "Schema cannot be null"); @@ -63,8 +104,8 @@ public static Types.StructType select(Types.StructType struct, Set fiel Preconditions.checkNotNull(struct, "Struct cannot be null"); Preconditions.checkNotNull(fieldIds, "Field ids cannot be null"); - Type result = visit(struct, new PruneColumns(fieldIds)); - if (struct == result) { + Type result = visit(struct, new PruneColumns(fieldIds, true)); + if (struct.equals(result)) { return struct; } else if (result != null) { return result.asStructType(); @@ -74,30 +115,30 @@ public static Types.StructType select(Types.StructType struct, Set fiel } public static Set getProjectedIds(Schema schema) { - return ImmutableSet.copyOf(getIdsInternal(schema.asStruct())); + return ImmutableSet.copyOf(getIdsInternal(schema.asStruct(), true)); } public static Set getProjectedIds(Type type) { if (type.isPrimitiveType()) { return ImmutableSet.of(); } - return ImmutableSet.copyOf(getIdsInternal(type)); + return ImmutableSet.copyOf(getIdsInternal(type, true)); } - private static Set getIdsInternal(Type type) { - return visit(type, new GetProjectedIds()); + private static Set getIdsInternal(Type type, boolean includeStructIds) { + return visit(type, new GetProjectedIds(includeStructIds)); } public static Types.StructType selectNot(Types.StructType struct, Set fieldIds) { - Set projectedIds = getIdsInternal(struct); + Set projectedIds = getIdsInternal(struct, false); projectedIds.removeAll(fieldIds); - return select(struct, projectedIds); + return project(struct, projectedIds); } public static Schema selectNot(Schema schema, Set fieldIds) { - Set projectedIds = getIdsInternal(schema.asStruct()); + Set projectedIds = getIdsInternal(schema.asStruct(), false); projectedIds.removeAll(fieldIds); - return select(schema, projectedIds); + return project(schema, projectedIds); } public static Schema join(Schema left, Schema right) { diff --git a/api/src/main/java/org/apache/iceberg/util/StructProjection.java b/api/src/main/java/org/apache/iceberg/util/StructProjection.java index be05b0fe2db5..704effe6c712 100644 --- 
a/api/src/main/java/org/apache/iceberg/util/StructProjection.java +++ b/api/src/main/java/org/apache/iceberg/util/StructProjection.java @@ -42,7 +42,7 @@ public class StructProjection implements StructLike { */ public static StructProjection create(Schema schema, Set ids) { StructType structType = schema.asStruct(); - return new StructProjection(structType, TypeUtil.select(structType, ids)); + return new StructProjection(structType, TypeUtil.project(structType, ids)); } /** @@ -58,12 +58,30 @@ public static StructProjection create(Schema dataSchema, Schema projectedSchema) return new StructProjection(dataSchema.asStruct(), projectedSchema.asStruct()); } + /** + * Creates a projecting wrapper for {@link StructLike} rows. + *
<p>
+ * This projection allows missing fields and does not work with repeated types like lists and maps. + * + * @param structType type of rows wrapped by this projection + * @param projectedStructType result type of the projected rows + * @return a wrapper to project rows + */ + public static StructProjection createAllowMissing(StructType structType, StructType projectedStructType) { + return new StructProjection(structType, projectedStructType, true); + } + private final StructType type; private final int[] positionMap; private final StructProjection[] nestedProjections; private StructLike struct; private StructProjection(StructType structType, StructType projection) { + this(structType, projection, false); + } + + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private StructProjection(StructType structType, StructType projection, boolean allowMissing) { this.type = projection; this.positionMap = new int[projection.fields().size()]; this.nestedProjections = new StructProjection[projection.fields().size()]; @@ -116,7 +134,10 @@ private StructProjection(StructType structType, StructType projection) { } } - if (!found) { + if (!found && projectedField.isOptional() && allowMissing) { + positionMap[pos] = -1; + nestedProjections[pos] = null; + } else if (!found) { throw new IllegalArgumentException(String.format("Cannot find field %s in %s", projectedField, structType)); } } @@ -134,11 +155,23 @@ public int size() { @Override public T get(int pos, Class javaClass) { + if (struct == null) { + // Return a null struct when projecting a nested required field from an optional struct. + // See more details in issue #2738. + return null; + } + + int structPos = positionMap[pos]; + if (nestedProjections[pos] != null) { - return javaClass.cast(nestedProjections[pos].wrap(struct.get(positionMap[pos], StructLike.class))); + return javaClass.cast(nestedProjections[pos].wrap(struct.get(structPos, StructLike.class))); } - return struct.get(positionMap[pos], javaClass); + if (structPos != -1) { + return struct.get(structPos, javaClass); + } else { + return null; + } } @Override diff --git a/api/src/test/java/org/apache/iceberg/types/TestTypeUtil.java b/api/src/test/java/org/apache/iceberg/types/TestTypeUtil.java index c11c859edacf..210efd352f5b 100644 --- a/api/src/test/java/org/apache/iceberg/types/TestTypeUtil.java +++ b/api/src/test/java/org/apache/iceberg/types/TestTypeUtil.java @@ -20,12 +20,16 @@ package org.apache.iceberg.types; +import java.util.Set; +import org.apache.iceberg.AssertHelpers; import org.apache.iceberg.Schema; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.types.Types.IntegerType; import org.junit.Assert; import org.junit.Test; +import static org.apache.iceberg.types.Types.NestedField.optional; import static org.apache.iceberg.types.Types.NestedField.required; @@ -103,6 +107,326 @@ public void testAssignIncreasingFreshIdNewIdentifier() { Sets.newHashSet(sourceSchema.findField("a").fieldId()), actualSchema.identifierFieldIds()); } + @Test + public void testProject() { + Schema schema = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(11, "A", Types.IntegerType.get()), + required(12, "someStruct", Types.StructType.of( + required(13, "b", Types.IntegerType.get()), + required(14, "B", Types.IntegerType.get()), + required(15, "anotherStruct", Types.StructType.of( + required(16, "c", Types.IntegerType.get()), + required(17, "C", 
Types.IntegerType.get())) + ))))); + + Schema expectedTop = new Schema( + Lists.newArrayList( + required(11, "A", Types.IntegerType.get()))); + + Schema actualTop = TypeUtil.project(schema, Sets.newHashSet(11)); + Assert.assertEquals(expectedTop.asStruct(), actualTop.asStruct()); + + Schema expectedDepthOne = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(12, "someStruct", Types.StructType.of( + required(13, "b", Types.IntegerType.get()))))); + + Schema actualDepthOne = TypeUtil.project(schema, Sets.newHashSet(10, 12, 13)); + Assert.assertEquals(expectedDepthOne.asStruct(), actualDepthOne.asStruct()); + + Schema expectedDepthTwo = new Schema( + Lists.newArrayList( + required(11, "A", Types.IntegerType.get()), + required(12, "someStruct", Types.StructType.of( + required(15, "anotherStruct", Types.StructType.of( + required(17, "C", Types.IntegerType.get())) + ))))); + + Schema actualDepthTwo = TypeUtil.project(schema, Sets.newHashSet(11, 12, 15, 17)); + Schema actualDepthTwoChildren = TypeUtil.project(schema, Sets.newHashSet(11, 17)); + Assert.assertEquals(expectedDepthTwo.asStruct(), actualDepthTwo.asStruct()); + Assert.assertEquals(expectedDepthTwo.asStruct(), actualDepthTwoChildren.asStruct()); + } + + @Test + public void testProjectNaturallyEmpty() { + Schema schema = new Schema( + Lists.newArrayList( + required(12, "someStruct", Types.StructType.of( + required(15, "anotherStruct", Types.StructType.of( + required(20, "empty", Types.StructType.of()) + )))))); + + Schema expectedDepthOne = new Schema( + Lists.newArrayList( + required(12, "someStruct", Types.StructType.of()))); + + Schema actualDepthOne = TypeUtil.project(schema, Sets.newHashSet(12)); + Assert.assertEquals(expectedDepthOne.asStruct(), actualDepthOne.asStruct()); + + Schema expectedDepthTwo = new Schema( + Lists.newArrayList( + required(12, "someStruct", Types.StructType.of( + required(15, "anotherStruct", Types.StructType.of()))))); + + Schema actualDepthTwo = TypeUtil.project(schema, Sets.newHashSet(12, 15)); + Assert.assertEquals(expectedDepthTwo.asStruct(), actualDepthTwo.asStruct()); + + Schema expectedDepthThree = new Schema( + Lists.newArrayList( + required(12, "someStruct", Types.StructType.of( + required(15, "anotherStruct", Types.StructType.of( + required(20, "empty", Types.StructType.of()) + )))))); + + Schema actualDepthThree = TypeUtil.project(schema, Sets.newHashSet(12, 15, 20)); + Schema actualDepthThreeChildren = TypeUtil.project(schema, Sets.newHashSet(20)); + Assert.assertEquals(expectedDepthThree.asStruct(), actualDepthThree.asStruct()); + Assert.assertEquals(expectedDepthThree.asStruct(), actualDepthThreeChildren.asStruct()); + } + + @Test + public void testProjectEmpty() { + Schema schema = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(11, "A", Types.IntegerType.get()), + required(12, "someStruct", Types.StructType.of( + required(13, "b", Types.IntegerType.get()), + required(14, "B", Types.IntegerType.get()), + required(15, "anotherStruct", Types.StructType.of( + required(16, "c", Types.IntegerType.get()), + required(17, "C", Types.IntegerType.get())) + ))))); + + Schema expectedDepthOne = new Schema( + Lists.newArrayList( + required(12, "someStruct", Types.StructType.of()))); + + Schema actualDepthOne = TypeUtil.project(schema, Sets.newHashSet(12)); + Assert.assertEquals(expectedDepthOne.asStruct(), actualDepthOne.asStruct()); + + Schema expectedDepthTwo = new Schema( + Lists.newArrayList( + required(12, 
"someStruct", Types.StructType.of( + required(15, "anotherStruct", Types.StructType.of()))))); + + Schema actualDepthTwo = TypeUtil.project(schema, Sets.newHashSet(12, 15)); + Assert.assertEquals(expectedDepthTwo.asStruct(), actualDepthTwo.asStruct()); + } + + @Test + public void testSelect() { + Schema schema = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(11, "A", Types.IntegerType.get()), + required(12, "someStruct", Types.StructType.of( + required(13, "b", Types.IntegerType.get()), + required(14, "B", Types.IntegerType.get()), + required(15, "anotherStruct", Types.StructType.of( + required(16, "c", Types.IntegerType.get()), + required(17, "C", Types.IntegerType.get())) + ))))); + + Schema expectedTop = new Schema( + Lists.newArrayList( + required(11, "A", Types.IntegerType.get()))); + + Schema actualTop = TypeUtil.select(schema, Sets.newHashSet(11)); + Assert.assertEquals(expectedTop.asStruct(), actualTop.asStruct()); + + Schema expectedDepthOne = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(12, "someStruct", Types.StructType.of( + required(13, "b", Types.IntegerType.get()), + required(14, "B", Types.IntegerType.get()), + required(15, "anotherStruct", Types.StructType.of( + required(16, "c", Types.IntegerType.get()), + required(17, "C", Types.IntegerType.get()))))))); + + Schema actualDepthOne = TypeUtil.select(schema, Sets.newHashSet(10, 12)); + Assert.assertEquals(expectedDepthOne.asStruct(), actualDepthOne.asStruct()); + + Schema expectedDepthTwo = new Schema( + Lists.newArrayList( + required(11, "A", Types.IntegerType.get()), + required(12, "someStruct", Types.StructType.of( + required(15, "anotherStruct", Types.StructType.of( + required(17, "C", Types.IntegerType.get())) + ))))); + + Schema actualDepthTwo = TypeUtil.select(schema, Sets.newHashSet(11, 17)); + Assert.assertEquals(expectedDepthTwo.asStruct(), actualDepthTwo.asStruct()); + } + + @Test + public void testProjectMap() { + // We can't partially project keys because it changes key equality + Schema schema = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(11, "A", Types.IntegerType.get()), + required(12, "map", Types.MapType.ofRequired(13, 14, + Types.StructType.of( + optional(100, "x", Types.IntegerType.get()), + optional(101, "y", Types.IntegerType.get())), + Types.StructType.of( + required(200, "z", Types.IntegerType.get()), + optional(201, "innerMap", Types.MapType.ofOptional(202, 203, + Types.IntegerType.get(), + Types.StructType.of( + required(300, "foo", Types.IntegerType.get()), + required(301, "bar", Types.IntegerType.get()))))))))); + + Assert.assertThrows("Cannot project maps explicitly", IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(12))); + + Assert.assertThrows("Cannot project maps explicitly", IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(201))); + + Schema expectedTopLevel = new Schema( + Lists.newArrayList(required(10, "a", Types.IntegerType.get()))); + Schema actualTopLevel = TypeUtil.project(schema, Sets.newHashSet(10)); + Assert.assertEquals(expectedTopLevel.asStruct(), actualTopLevel.asStruct()); + + Schema expectedDepthOne = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(12, "map", Types.MapType.ofRequired(13, 14, + Types.StructType.of( + optional(100, "x", Types.IntegerType.get()), + optional(101, "y", Types.IntegerType.get())), + Types.StructType.of())))); 
+ Schema actualDepthOne = TypeUtil.project(schema, Sets.newHashSet(10, 13, 14, 100, 101)); + Schema actualDepthOneNoKeys = TypeUtil.project(schema, Sets.newHashSet(10, 13, 14)); + Assert.assertEquals(expectedDepthOne.asStruct(), actualDepthOne.asStruct()); + Assert.assertEquals(expectedDepthOne.asStruct(), actualDepthOneNoKeys.asStruct()); + + Schema expectedDepthTwo = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(12, "map", Types.MapType.ofRequired(13, 14, + Types.StructType.of( + optional(100, "x", Types.IntegerType.get()), + optional(101, "y", Types.IntegerType.get())), + Types.StructType.of( + required(200, "z", Types.IntegerType.get()), + optional(201, "innerMap", Types.MapType.ofOptional(202, 203, + Types.IntegerType.get(), + Types.StructType.of()))))))); + Schema actualDepthTwo = TypeUtil.project(schema, Sets.newHashSet(10, 13, 14, 100, 101, 200, 202, 203)); + Assert.assertEquals(expectedDepthTwo.asStruct(), actualDepthTwo.asStruct()); + } + + @Test + public void testGetProjectedIds() { + Schema schema = new Schema( + Lists.newArrayList( + required(10, "a", Types.IntegerType.get()), + required(11, "A", Types.IntegerType.get()), + required(35, "emptyStruct", Types.StructType.of()), + required(12, "someStruct", Types.StructType.of( + required(13, "b", Types.IntegerType.get()), + required(14, "B", Types.IntegerType.get()), + required(15, "anotherStruct", Types.StructType.of( + required(16, "c", Types.IntegerType.get()), + required(17, "C", Types.IntegerType.get())) + ))))); + + Set expectedIds = Sets.newHashSet(10, 11, 35, 12, 13, 14, 15, 16, 17); + Set actualIds = TypeUtil.getProjectedIds(schema); + + Assert.assertEquals(expectedIds, actualIds); + } + + @Test + public void testProjectListNested() { + Schema schema = new Schema( + Lists.newArrayList( + required(12, "list", Types.ListType.ofRequired(13, + Types.ListType.ofRequired(14, + Types.MapType.ofRequired(15, 16, + IntegerType.get(), + Types.StructType.of( + required(17, "x", Types.IntegerType.get()), + required(18, "y", Types.IntegerType.get()) + ))))))); + + AssertHelpers.assertThrows("Cannot explicitly project List", + IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(12)) + ); + + AssertHelpers.assertThrows("Cannot explicitly project List", + IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(13)) + ); + + AssertHelpers.assertThrows("Cannot explicitly project Map", + IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(14)) + ); + + Schema expected = new Schema( + Lists.newArrayList( + required(12, "list", Types.ListType.ofRequired(13, + Types.ListType.ofRequired(14, + Types.MapType.ofRequired(15, 16, + IntegerType.get(), + Types.StructType.of())))))); + + Schema actual = TypeUtil.project(schema, Sets.newHashSet(16)); + Assert.assertEquals(expected.asStruct(), actual.asStruct()); + } + + @Test + public void testProjectMapNested() { + Schema schema = new Schema( + Lists.newArrayList( + required(12, "map", Types.MapType.ofRequired(13, 14, + Types.IntegerType.get(), + Types.MapType.ofRequired(15, 16, + Types.IntegerType.get(), + Types.ListType.ofRequired(17, + Types.StructType.of( + required(18, "x", Types.IntegerType.get()), + required(19, "y", Types.IntegerType.get()) + ))))))); + + + AssertHelpers.assertThrows("Cannot explicitly project Map", + IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(12)) + ); + + AssertHelpers.assertThrows("Cannot explicitly project 
Map", + IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(14)) + ); + + AssertHelpers.assertThrows("Cannot explicitly project List", + IllegalArgumentException.class, + () -> TypeUtil.project(schema, Sets.newHashSet(16)) + ); + + Schema expected = new Schema( + Lists.newArrayList( + required(12, "map", Types.MapType.ofRequired(13, 14, + Types.IntegerType.get(), + Types.MapType.ofRequired(15, 16, + Types.IntegerType.get(), + Types.ListType.ofRequired(17, + Types.StructType.of())))))); + + Schema actual = TypeUtil.project(schema, Sets.newHashSet(17)); + Assert.assertEquals(expected.asStruct(), actual.asStruct()); + } + @Test(expected = IllegalArgumentException.class) public void testReassignIdsIllegalArgumentException() { Schema schema = new Schema( @@ -128,4 +452,34 @@ public void testValidateSchemaViaIndexByName() { TypeUtil.indexByName(Types.StructType.of(nestedType)); } + + @Test + public void testSelectNot() { + Schema schema = new Schema( + Lists.newArrayList( + required(1, "id", Types.LongType.get()), + required(2, "location", Types.StructType.of( + required(3, "lat", Types.DoubleType.get()), + required(4, "long", Types.DoubleType.get()) + )))); + + Schema expectedNoPrimitive = new Schema( + Lists.newArrayList( + required(2, "location", Types.StructType.of( + required(3, "lat", Types.DoubleType.get()), + required(4, "long", Types.DoubleType.get()) + )))); + + Schema actualNoPrimitve = TypeUtil.selectNot(schema, Sets.newHashSet(1)); + Assert.assertEquals(expectedNoPrimitive.asStruct(), actualNoPrimitve.asStruct()); + + // Expected legacy behavior is to completely remove structs if their elements are removed + Schema expectedNoStructElements = new Schema(required(1, "id", Types.LongType.get())); + Schema actualNoStructElements = TypeUtil.selectNot(schema, Sets.newHashSet(3, 4)); + Assert.assertEquals(expectedNoStructElements.asStruct(), actualNoStructElements.asStruct()); + + // Expected legacy behavior is to ignore selectNot on struct elements. + Schema actualNoStruct = TypeUtil.selectNot(schema, Sets.newHashSet(2)); + Assert.assertEquals(schema.asStruct(), actualNoStruct.asStruct()); + } } diff --git a/core/src/main/java/org/apache/iceberg/AllDataFilesTable.java b/core/src/main/java/org/apache/iceberg/AllDataFilesTable.java index fc2dc699a1bd..d1b084f0206c 100644 --- a/core/src/main/java/org/apache/iceberg/AllDataFilesTable.java +++ b/core/src/main/java/org/apache/iceberg/AllDataFilesTable.java @@ -29,6 +29,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types.StructType; import org.apache.iceberg.util.ParallelIterable; import org.apache.iceberg.util.ThreadPools; @@ -56,8 +57,9 @@ public TableScan newScan() { @Override public Schema schema() { - Schema schema = new Schema(DataFile.getType(table().spec().partitionType()).fields()); - if (table().spec().fields().size() < 1) { + StructType partitionType = Partitioning.partitionType(table()); + Schema schema = new Schema(DataFile.getType(partitionType).fields()); + if (partitionType.fields().size() < 1) { // avoid returning an empty struct, which is not always supported. 
instead, drop the partition field (id 102) return TypeUtil.selectNot(schema, Sets.newHashSet(102)); } else { diff --git a/core/src/main/java/org/apache/iceberg/AllEntriesTable.java b/core/src/main/java/org/apache/iceberg/AllEntriesTable.java index c1b714534def..84c1609fd4e7 100644 --- a/core/src/main/java/org/apache/iceberg/AllEntriesTable.java +++ b/core/src/main/java/org/apache/iceberg/AllEntriesTable.java @@ -29,6 +29,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types.StructType; import org.apache.iceberg.util.ParallelIterable; import org.apache.iceberg.util.ThreadPools; @@ -55,8 +56,9 @@ public TableScan newScan() { @Override public Schema schema() { - Schema schema = ManifestEntry.getSchema(table().spec().partitionType()); - if (table().spec().fields().size() < 1) { + StructType partitionType = Partitioning.partitionType(table()); + Schema schema = ManifestEntry.getSchema(partitionType); + if (partitionType.fields().size() < 1) { // avoid returning an empty struct, which is not always supported. instead, drop the partition field (id 102) return TypeUtil.selectNot(schema, Sets.newHashSet(102)); } else { diff --git a/core/src/main/java/org/apache/iceberg/BaseTableScan.java b/core/src/main/java/org/apache/iceberg/BaseTableScan.java index 356d909f6bba..524276b57427 100644 --- a/core/src/main/java/org/apache/iceberg/BaseTableScan.java +++ b/core/src/main/java/org/apache/iceberg/BaseTableScan.java @@ -296,7 +296,7 @@ private Schema lazyColumnProjection() { } requiredFieldIds.addAll(selectedIds); - return TypeUtil.select(schema, requiredFieldIds); + return TypeUtil.project(schema, requiredFieldIds); } else if (context.projectedSchema() != null) { return context.projectedSchema(); diff --git a/core/src/main/java/org/apache/iceberg/DataFilesTable.java b/core/src/main/java/org/apache/iceberg/DataFilesTable.java index d6b80ee66587..f931c0650081 100644 --- a/core/src/main/java/org/apache/iceberg/DataFilesTable.java +++ b/core/src/main/java/org/apache/iceberg/DataFilesTable.java @@ -30,6 +30,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types.StructType; /** * A {@link Table} implementation that exposes a table's data files as rows. @@ -51,8 +52,9 @@ public TableScan newScan() { @Override public Schema schema() { - Schema schema = new Schema(DataFile.getType(table().spec().partitionType()).fields()); - if (table().spec().fields().size() < 1) { + StructType partitionType = Partitioning.partitionType(table()); + Schema schema = new Schema(DataFile.getType(partitionType).fields()); + if (partitionType.fields().size() < 1) { // avoid returning an empty struct, which is not always supported. 
instead, drop the partition field return TypeUtil.selectNot(schema, Sets.newHashSet(DataFile.PARTITION_ID)); } else { diff --git a/core/src/main/java/org/apache/iceberg/ManifestEntriesTable.java b/core/src/main/java/org/apache/iceberg/ManifestEntriesTable.java index 7bae3491a787..a44fc6421428 100644 --- a/core/src/main/java/org/apache/iceberg/ManifestEntriesTable.java +++ b/core/src/main/java/org/apache/iceberg/ManifestEntriesTable.java @@ -29,6 +29,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types.StructType; import org.apache.iceberg.util.StructProjection; /** @@ -54,8 +55,9 @@ public TableScan newScan() { @Override public Schema schema() { - Schema schema = ManifestEntry.getSchema(table().spec().partitionType()); - if (table().spec().fields().size() < 1) { + StructType partitionType = Partitioning.partitionType(table()); + Schema schema = ManifestEntry.getSchema(partitionType); + if (partitionType.fields().size() < 1) { // avoid returning an empty struct, which is not always supported. instead, drop the partition field (id 102) return TypeUtil.selectNot(schema, Sets.newHashSet(102)); } else { diff --git a/core/src/main/java/org/apache/iceberg/MetadataColumns.java b/core/src/main/java/org/apache/iceberg/MetadataColumns.java index e1cf096cd003..af7b655b2bfe 100644 --- a/core/src/main/java/org/apache/iceberg/MetadataColumns.java +++ b/core/src/main/java/org/apache/iceberg/MetadataColumns.java @@ -38,6 +38,12 @@ private MetadataColumns() { Integer.MAX_VALUE - 2, "_pos", Types.LongType.get(), "Ordinal position of a row in the source data file"); public static final NestedField IS_DELETED = NestedField.required( Integer.MAX_VALUE - 3, "_deleted", Types.BooleanType.get(), "Whether the row has been deleted"); + public static final NestedField SPEC_ID = NestedField.required( + Integer.MAX_VALUE - 4, "_spec_id", Types.IntegerType.get(), "Spec ID used to track the file containing a row"); + // the partition column type is not static and depends on all specs in the table + public static final int PARTITION_COLUMN_ID = Integer.MAX_VALUE - 5; + public static final String PARTITION_COLUMN_NAME = "_partition"; + public static final String PARTITION_COLUMN_DOC = "Partition to which a row belongs to"; // IDs Integer.MAX_VALUE - (101-200) are used for reserved columns public static final NestedField DELETE_FILE_PATH = NestedField.required( @@ -51,24 +57,39 @@ private MetadataColumns() { private static final Map META_COLUMNS = ImmutableMap.of( FILE_PATH.name(), FILE_PATH, ROW_POSITION.name(), ROW_POSITION, - IS_DELETED.name(), IS_DELETED); + IS_DELETED.name(), IS_DELETED, + SPEC_ID.name(), SPEC_ID + ); - private static final Set META_IDS = META_COLUMNS.values().stream().map(NestedField::fieldId) - .collect(ImmutableSet.toImmutableSet()); + private static final Set META_IDS = ImmutableSet.of( + FILE_PATH.fieldId(), + ROW_POSITION.fieldId(), + IS_DELETED.fieldId(), + SPEC_ID.fieldId(), + PARTITION_COLUMN_ID + ); public static Set metadataFieldIds() { return META_IDS; } - public static NestedField get(String name) { - return META_COLUMNS.get(name); + public static NestedField metadataColumn(Table table, String name) { + if (name.equals(PARTITION_COLUMN_NAME)) { + return Types.NestedField.optional( + PARTITION_COLUMN_ID, + PARTITION_COLUMN_NAME, + Partitioning.partitionType(table), + PARTITION_COLUMN_DOC); + } else { + return META_COLUMNS.get(name); + } } public static 
boolean isMetadataColumn(String name) { - return META_COLUMNS.containsKey(name); + return name.equals(PARTITION_COLUMN_NAME) || META_COLUMNS.containsKey(name); } public static boolean nonMetadataColumn(String name) { - return !META_COLUMNS.containsKey(name); + return !isMetadataColumn(name); } } diff --git a/core/src/main/java/org/apache/iceberg/Partitioning.java b/core/src/main/java/org/apache/iceberg/Partitioning.java index d393fe180507..6233fbfc4356 100644 --- a/core/src/main/java/org/apache/iceberg/Partitioning.java +++ b/core/src/main/java/org/apache/iceberg/Partitioning.java @@ -19,9 +19,21 @@ package org.apache.iceberg; +import java.util.Collections; +import java.util.Comparator; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.transforms.PartitionSpecVisitor; +import org.apache.iceberg.transforms.Transform; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.transforms.UnknownTransform; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; public class Partitioning { private Partitioning() { @@ -177,4 +189,80 @@ public Void alwaysNull(int fieldId, String sourceName, int sourceId) { return null; } } + + /** + * Builds a common partition type for all specs in a table. + *
<p>
+ * Whenever a table has multiple specs, the partition type is a struct containing + * all columns that have ever been a part of any spec in the table. + * + * @param table a table with one or many specs + * @return the constructed common partition type + */ + public static StructType partitionType(Table table) { + // we currently don't know the output type of unknown transforms + List> unknownTransforms = collectUnknownTransforms(table); + ValidationException.check(unknownTransforms.isEmpty(), + "Cannot build table partition type, unknown transforms: %s", unknownTransforms); + + if (table.specs().size() == 1) { + return table.spec().partitionType(); + } + + Map fieldMap = Maps.newHashMap(); + List structFields = Lists.newArrayList(); + + // sort the spec IDs in descending order to pick up the most recent field names + List specIds = table.specs().keySet().stream() + .sorted(Collections.reverseOrder()) + .collect(Collectors.toList()); + + for (Integer specId : specIds) { + PartitionSpec spec = table.specs().get(specId); + + for (PartitionField field : spec.fields()) { + int fieldId = field.fieldId(); + PartitionField existingField = fieldMap.get(fieldId); + + if (existingField == null) { + fieldMap.put(fieldId, field); + NestedField structField = spec.partitionType().field(fieldId); + structFields.add(structField); + } else { + // verify the fields are compatible as they may conflict in v1 tables + ValidationException.check(equivalentIgnoringNames(field, existingField), + "Conflicting partition fields: ['%s', '%s']", + field, existingField); + } + } + } + + List sortedStructFields = structFields.stream() + .sorted(Comparator.comparingInt(NestedField::fieldId)) + .collect(Collectors.toList()); + return StructType.of(sortedStructFields); + } + + private static List> collectUnknownTransforms(Table table) { + List> unknownTransforms = Lists.newArrayList(); + + table.specs().values().forEach(spec -> { + spec.fields().stream() + .map(PartitionField::transform) + .filter(transform -> transform instanceof UnknownTransform) + .forEach(unknownTransforms::add); + }); + + return unknownTransforms; + } + + private static boolean equivalentIgnoringNames(PartitionField field, PartitionField anotherField) { + return field.fieldId() == anotherField.fieldId() && + field.sourceId() == anotherField.sourceId() && + compatibleTransforms(field.transform(), anotherField.transform()); + } + + private static boolean compatibleTransforms(Transform t1, Transform t2) { + return t1.equals(t2) || t1.equals(Transforms.alwaysNull()) || t2.equals(Transforms.alwaysNull()); + } } diff --git a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java index 57e2c2709137..91089c3b6714 100644 --- a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java +++ b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java @@ -19,12 +19,14 @@ package org.apache.iceberg.avro; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import org.apache.avro.JsonProperties; import org.apache.avro.Schema; +import org.apache.avro.Schema.Type; import org.apache.avro.SchemaNormalization; import org.apache.iceberg.mapping.NameMapping; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; @@ -81,15 +83,26 @@ public Schema record(Schema record, List names, List fields) { Schema fieldSchema = fields.get(field.pos()); // All primitives are selected by selecting the field, but map and list - 
// types can be selected by projecting the keys, values, or elements. + // types can be selected by projecting the keys, values, or elements. Empty + // Structs can be selected by selecting the record itself instead of its children. // This creates two conditions where the field should be selected: if the // id is selected or if the result of the field is non-null. The only // case where the converted field is non-null is when a map or list is // selected by lower IDs. if (selectedIds.contains(fieldId)) { - filteredFields.add(copyField(field, field.schema(), fieldId)); + if (fieldSchema != null) { + hasChange = true; // Sub-fields may be different + filteredFields.add(copyField(field, fieldSchema, fieldId)); + } else { + if (isRecord(field.schema())) { + hasChange = true; // Sub-fields are now empty + filteredFields.add(copyField(field, makeEmptyCopy(field.schema()), fieldId)); + } else { + filteredFields.add(copyField(field, field.schema(), fieldId)); + } + } } else if (fieldSchema != null) { - hasChange = true; + hasChange = true; // Sub-fields may be different filteredFields.add(copyField(field, fieldSchema, fieldId)); } } @@ -259,6 +272,26 @@ private static Schema copyRecord(Schema record, List newFields) { return copy; } + private boolean isRecord(Schema field) { + if (AvroSchemaUtil.isOptionSchema(field)) { + return AvroSchemaUtil.fromOption(field).getType().equals(Type.RECORD); + } else { + return field.getType().equals(Type.RECORD); + } + } + + private static Schema makeEmptyCopy(Schema field) { + if (AvroSchemaUtil.isOptionSchema(field)) { + Schema innerSchema = AvroSchemaUtil.fromOption(field); + Schema emptyRecord = Schema.createRecord(innerSchema.getName(), innerSchema.getDoc(), innerSchema.getNamespace(), + innerSchema.isError(), Collections.emptyList()); + return AvroSchemaUtil.toOption(emptyRecord); + } else { + return Schema.createRecord(field.getName(), field.getDoc(), field.getNamespace(), field.isError(), + Collections.emptyList()); + } + } + private static Schema.Field copyField(Schema.Field field, Schema newSchema, Integer fieldId) { Schema newSchemaReordered; // if the newSchema is an optional schema, make sure the NULL option is always the first diff --git a/core/src/main/java/org/apache/iceberg/encryption/InputFilesDecryptor.java b/core/src/main/java/org/apache/iceberg/encryption/InputFilesDecryptor.java new file mode 100644 index 000000000000..6c1e0eb8b250 --- /dev/null +++ b/core/src/main/java/org/apache/iceberg/encryption/InputFilesDecryptor.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.encryption; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +public class InputFilesDecryptor { + + private final Map decryptedInputFiles; + + public InputFilesDecryptor(CombinedScanTask combinedTask, FileIO io, EncryptionManager encryption) { + Map keyMetadata = Maps.newHashMap(); + combinedTask.files().stream() + .flatMap(fileScanTask -> Stream.concat(Stream.of(fileScanTask.file()), fileScanTask.deletes().stream())) + .forEach(file -> keyMetadata.put(file.path().toString(), file.keyMetadata())); + Stream encrypted = keyMetadata.entrySet().stream() + .map(entry -> EncryptedFiles.encryptedInput(io.newInputFile(entry.getKey()), entry.getValue())); + + // decrypt with the batch call to avoid multiple RPCs to a key server, if possible + Iterable decryptedFiles = encryption.decrypt(encrypted::iterator); + + Map files = Maps.newHashMapWithExpectedSize(keyMetadata.size()); + decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted)); + this.decryptedInputFiles = Collections.unmodifiableMap(files); + } + + public InputFile getInputFile(FileScanTask task) { + Preconditions.checkArgument(!task.isDataTask(), "Invalid task type"); + return decryptedInputFiles.get(task.file().path().toString()); + } + + public InputFile getInputFile(String location) { + return decryptedInputFiles.get(location); + } +} diff --git a/core/src/main/java/org/apache/iceberg/util/PartitionUtil.java b/core/src/main/java/org/apache/iceberg/util/PartitionUtil.java index 929f77af4e78..02c8b302dad5 100644 --- a/core/src/main/java/org/apache/iceberg/util/PartitionUtil.java +++ b/core/src/main/java/org/apache/iceberg/util/PartitionUtil.java @@ -36,10 +36,15 @@ private PartitionUtil() { } public static Map constantsMap(FileScanTask task) { - return constantsMap(task, (type, constant) -> constant); + return constantsMap(task, null, (type, constant) -> constant); } public static Map constantsMap(FileScanTask task, BiFunction convertConstant) { + return constantsMap(task, null, convertConstant); + } + + public static Map constantsMap(FileScanTask task, Types.StructType partitionType, + BiFunction convertConstant) { PartitionSpec spec = task.spec(); StructLike partitionData = task.file().partition(); @@ -51,6 +56,22 @@ private PartitionUtil() { MetadataColumns.FILE_PATH.fieldId(), convertConstant.apply(Types.StringType.get(), task.file().path())); + // add _spec_id + idToConstant.put( + MetadataColumns.SPEC_ID.fieldId(), + convertConstant.apply(Types.IntegerType.get(), task.file().specId())); + + // add _partition + if (partitionType != null) { + if (partitionType.fields().size() > 0) { + StructLike coercedPartition = coercePartition(partitionType, spec, partitionData); + idToConstant.put(MetadataColumns.PARTITION_COLUMN_ID, convertConstant.apply(partitionType, coercedPartition)); + } else { + // use null as some query engines may not be able to handle empty structs + idToConstant.put(MetadataColumns.PARTITION_COLUMN_ID, null); + } + } + List partitionFields = spec.partitionType().fields(); List fields = spec.fields(); for (int pos = 0; pos < fields.size(); pos += 1) { @@ -63,4 +84,11 @@ private 
PartitionUtil() { return idToConstant; } + + // adapts the provided partition data to match the table partition type + private static StructLike coercePartition(Types.StructType partitionType, PartitionSpec spec, StructLike partition) { + StructProjection projection = StructProjection.createAllowMissing(spec.partitionType(), partitionType); + projection.wrap(partition); + return projection; + } } diff --git a/core/src/test/java/org/apache/iceberg/TestPartitioning.java b/core/src/test/java/org/apache/iceberg/TestPartitioning.java new file mode 100644 index 000000000000..2610ad5c01cd --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/TestPartitioning.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg; + +import java.io.File; +import java.io.IOException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import static org.apache.iceberg.types.Types.NestedField.required; + +public class TestPartitioning { + + private static final int V1_FORMAT_VERSION = 1; + private static final int V2_FORMAT_VERSION = 2; + private static final Schema SCHEMA = new Schema( + required(1, "id", Types.IntegerType.get()), + required(2, "data", Types.StringType.get()), + required(3, "category", Types.StringType.get()) + ); + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + private File tableDir = null; + + @Before + public void setupTableDir() throws IOException { + this.tableDir = temp.newFolder(); + } + + @After + public void cleanupTables() { + TestTables.clearTables(); + } + + @Test + public void testPartitionTypeWithSpecEvolutionInV1Tables() { + PartitionSpec initialSpec = PartitionSpec.builderFor(SCHEMA) + .identity("data") + .build(); + TestTables.TestTable table = TestTables.create(tableDir, "test", SCHEMA, initialSpec, V1_FORMAT_VERSION); + + table.updateSpec() + .addField(Expressions.bucket("category", 8)) + .commit(); + + Assert.assertEquals("Should have 2 specs", 2, table.specs().size()); + + StructType expectedType = StructType.of( + NestedField.optional(1000, "data", Types.StringType.get()), + NestedField.optional(1001, "category_bucket_8", Types.IntegerType.get()) + ); + StructType actualType = Partitioning.partitionType(table); + Assert.assertEquals("Types must match", expectedType, actualType); + } + + @Test + public void testPartitionTypeWithSpecEvolutionInV2Tables() { + PartitionSpec initialSpec = 
PartitionSpec.builderFor(SCHEMA) + .identity("data") + .build(); + TestTables.TestTable table = TestTables.create(tableDir, "test", SCHEMA, initialSpec, V2_FORMAT_VERSION); + + table.updateSpec() + .removeField("data") + .addField("category") + .commit(); + + Assert.assertEquals("Should have 2 specs", 2, table.specs().size()); + + StructType expectedType = StructType.of( + NestedField.optional(1000, "data", Types.StringType.get()), + NestedField.optional(1001, "category", Types.StringType.get()) + ); + StructType actualType = Partitioning.partitionType(table); + Assert.assertEquals("Types must match", expectedType, actualType); + } + + @Test + public void testPartitionTypeWithRenamesInV1Table() { + PartitionSpec initialSpec = PartitionSpec.builderFor(SCHEMA) + .identity("data", "p1") + .build(); + TestTables.TestTable table = TestTables.create(tableDir, "test", SCHEMA, initialSpec, V1_FORMAT_VERSION); + + table.updateSpec() + .addField("category") + .commit(); + + table.updateSpec() + .renameField("p1", "p2") + .commit(); + + StructType expectedType = StructType.of( + NestedField.optional(1000, "p2", Types.StringType.get()), + NestedField.optional(1001, "category", Types.StringType.get()) + ); + StructType actualType = Partitioning.partitionType(table); + Assert.assertEquals("Types must match", expectedType, actualType); + } + + @Test + public void testPartitionTypeWithAddingBackSamePartitionFieldInV1Table() { + PartitionSpec initialSpec = PartitionSpec.builderFor(SCHEMA) + .identity("data") + .build(); + TestTables.TestTable table = TestTables.create(tableDir, "test", SCHEMA, initialSpec, V1_FORMAT_VERSION); + + table.updateSpec() + .removeField("data") + .commit(); + + table.updateSpec() + .addField("data") + .commit(); + + // in v1, we use void transforms instead of dropping partition fields + StructType expectedType = StructType.of( + NestedField.optional(1000, "data_1000", Types.StringType.get()), + NestedField.optional(1001, "data", Types.StringType.get()) + ); + StructType actualType = Partitioning.partitionType(table); + Assert.assertEquals("Types must match", expectedType, actualType); + } + + @Test + public void testPartitionTypeWithAddingBackSamePartitionFieldInV2Table() { + PartitionSpec initialSpec = PartitionSpec.builderFor(SCHEMA) + .identity("data") + .build(); + TestTables.TestTable table = TestTables.create(tableDir, "test", SCHEMA, initialSpec, V2_FORMAT_VERSION); + + table.updateSpec() + .removeField("data") + .commit(); + + table.updateSpec() + .addField("data") + .commit(); + + // in v2, we should be able to reuse the original partition spec + StructType expectedType = StructType.of( + NestedField.optional(1000, "data", Types.StringType.get()) + ); + StructType actualType = Partitioning.partitionType(table); + Assert.assertEquals("Types must match", expectedType, actualType); + } + + @Test + public void testPartitionTypeWithIncompatibleSpecEvolution() { + PartitionSpec initialSpec = PartitionSpec.builderFor(SCHEMA) + .identity("data") + .build(); + TestTables.TestTable table = TestTables.create(tableDir, "test", SCHEMA, initialSpec, V1_FORMAT_VERSION); + + PartitionSpec newSpec = PartitionSpec.builderFor(table.schema()) + .identity("category") + .build(); + + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata current = ops.current(); + ops.commit(current, current.updatePartitionSpec(newSpec)); + + Assert.assertEquals("Should have 2 specs", 2, table.specs().size()); + + AssertHelpers.assertThrows("Should complain about incompatible specs", 
+ ValidationException.class, "Conflicting partition fields", + () -> Partitioning.partitionType(table)); + } +} diff --git a/core/src/test/java/org/apache/iceberg/TestSchemaUpdate.java b/core/src/test/java/org/apache/iceberg/TestSchemaUpdate.java index 4aa5bfd335ec..4b0bed1cde48 100644 --- a/core/src/test/java/org/apache/iceberg/TestSchemaUpdate.java +++ b/core/src/test/java/org/apache/iceberg/TestSchemaUpdate.java @@ -93,7 +93,7 @@ public void testDeleteFields() { Schema del = new SchemaUpdate(SCHEMA, 19).deleteColumn(name).apply(); Assert.assertEquals("Should match projection with '" + name + "' removed", - TypeUtil.select(SCHEMA, selected).asStruct(), del.asStruct()); + TypeUtil.project(SCHEMA, selected).asStruct(), del.asStruct()); } } diff --git a/core/src/test/java/org/apache/iceberg/avro/TestReadProjection.java b/core/src/test/java/org/apache/iceberg/avro/TestReadProjection.java index e71034483bdc..6c86155a0acd 100644 --- a/core/src/test/java/org/apache/iceberg/avro/TestReadProjection.java +++ b/core/src/test/java/org/apache/iceberg/avro/TestReadProjection.java @@ -526,4 +526,183 @@ public void testListOfStructsProjection() throws IOException { AssertHelpers.assertEmptyAvroField(projectedP2, "y"); Assert.assertNull("Should project null z", projectedP2.get("z")); } + + @Test + public void testEmptyStructProjection() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + + Record record = new Record(AvroSchemaUtil.convert(writeSchema, "table")); + record.put("id", 34L); + Record location = new Record( + AvroSchemaUtil.fromOption(record.getSchema().getField("location").schema())); + location.put("lat", 52.995143f); + location.put("long", -1.539054f); + record.put("location", location); + + Schema emptyStruct = new Schema( + Types.NestedField.required(3, "location", Types.StructType.of()) + ); + + Record projected = writeAndRead("empty_proj", writeSchema, emptyStruct, record); + AssertHelpers.assertEmptyAvroField(projected, "id"); + Record result = (Record) projected.get("location"); + + Assert.assertEquals("location should be in the 0th position", result, projected.get(0)); + Assert.assertNotNull("Should contain an empty record", result); + AssertHelpers.assertEmptyAvroField(result, "lat"); + AssertHelpers.assertEmptyAvroField(result, "long"); + } + + @Test + public void testEmptyStructRequiredProjection() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.required(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + + Record record = new Record(AvroSchemaUtil.convert(writeSchema, "table")); + record.put("id", 34L); + Record location = new Record(record.getSchema().getField("location").schema()); + location.put("lat", 52.995143f); + location.put("long", -1.539054f); + record.put("location", location); + + Schema emptyStruct = new Schema( + Types.NestedField.required(3, "location", Types.StructType.of()) + ); + + Record projected = writeAndRead("empty_req_proj", writeSchema, emptyStruct, record); + AssertHelpers.assertEmptyAvroField(projected, "id"); + Record result = (Record) projected.get("location"); + 
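    // The required struct is still materialized after the empty projection; only its sub-fields are pruned away.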
Assert.assertEquals("location should be in the 0th position", result, projected.get(0)); + Assert.assertNotNull("Should contain an empty record", result); + AssertHelpers.assertEmptyAvroField(result, "lat"); + AssertHelpers.assertEmptyAvroField(result, "long"); + } + + @Test + public void testRequiredEmptyStructInRequiredStruct() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.required(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()), + Types.NestedField.required(4, "empty", Types.StructType.of()) + )) + ); + + Record record = new Record(AvroSchemaUtil.convert(writeSchema, "table")); + record.put("id", 34L); + Record location = new Record(record.getSchema().getField("location").schema()); + location.put("lat", 52.995143f); + location.put("long", -1.539054f); + record.put("location", location); + + Schema emptyStruct = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.required(3, "location", Types.StructType.of( + Types.NestedField.required(4, "empty", Types.StructType.of()) + )) + ); + + Record projected = writeAndRead("req_empty_req_proj", writeSchema, emptyStruct, record); + Assert.assertEquals("Should project id", 34L, projected.get("id")); + Record result = (Record) projected.get("location"); + Assert.assertEquals("location should be in the 1st position", result, projected.get(1)); + Assert.assertNotNull("Should contain an empty record", result); + AssertHelpers.assertEmptyAvroField(result, "lat"); + AssertHelpers.assertEmptyAvroField(result, "long"); + Assert.assertNotNull("Should project empty", result.getSchema().getField("empty")); + Assert.assertNotNull("Empty should not be null", result.get("empty")); + Assert.assertEquals("Empty should be empty", 0, + ((Record) result.get("empty")).getSchema().getFields().size()); + } + + @Test + public void testEmptyNestedStructProjection() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(3, "outer", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.optional(2, "inner", Types.StructType.of( + Types.NestedField.required(5, "lon", Types.FloatType.get()) + ) + ) + )) + ); + + Record record = new Record(AvroSchemaUtil.convert(writeSchema, "table")); + record.put("id", 34L); + Record outer = new Record( + AvroSchemaUtil.fromOption(record.getSchema().getField("outer").schema())); + Record inner = new Record(AvroSchemaUtil.fromOption(outer.getSchema().getField("inner").schema())); + inner.put("lon", 32.14f); + outer.put("lat", 52.995143f); + outer.put("inner", inner); + record.put("outer", outer); + + Schema emptyStruct = new Schema( + Types.NestedField.required(3, "outer", Types.StructType.of( + Types.NestedField.required(2, "inner", Types.StructType.of()) + ))); + + Record projected = writeAndRead("nested_empty_proj", writeSchema, emptyStruct, record); + AssertHelpers.assertEmptyAvroField(projected, "id"); + Record outerResult = (Record) projected.get("outer"); + Assert.assertEquals("Outer should be in the 0th position", outerResult, projected.get(0)); + Assert.assertNotNull("Should contain the outer record", outerResult); + AssertHelpers.assertEmptyAvroField(outerResult, "lat"); + Record innerResult = (Record) outerResult.get("inner"); + Assert.assertEquals("Inner 
should be in the 0th position", innerResult, outerResult.get(0)); + Assert.assertNotNull("Should contain the inner record", innerResult); + AssertHelpers.assertEmptyAvroField(innerResult, "lon"); + } + + @Test + public void testEmptyNestedStructRequiredProjection() throws Exception { + Schema writeSchema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.required(3, "outer", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "inner", Types.StructType.of( + Types.NestedField.required(5, "lon", Types.FloatType.get()) + ) + ) + )) + ); + + Record record = new Record(AvroSchemaUtil.convert(writeSchema, "table")); + record.put("id", 34L); + Record outer = new Record(record.getSchema().getField("outer").schema()); + Record inner = new Record(outer.getSchema().getField("inner").schema()); + inner.put("lon", 32.14f); + outer.put("lat", 52.995143f); + outer.put("inner", inner); + record.put("outer", outer); + + Schema emptyStruct = new Schema( + Types.NestedField.required(3, "outer", Types.StructType.of( + Types.NestedField.required(2, "inner", Types.StructType.of()) + ))); + + Record projected = writeAndRead("nested_empty_req_proj", writeSchema, emptyStruct, record); + AssertHelpers.assertEmptyAvroField(projected, "id"); + Record outerResult = (Record) projected.get("outer"); + Assert.assertEquals("Outer should be in the 0th position", outerResult, projected.get(0)); + Assert.assertNotNull("Should contain the outer record", outerResult); + AssertHelpers.assertEmptyAvroField(outerResult, "lat"); + Record innerResult = (Record) outerResult.get("inner"); + Assert.assertEquals("Inner should be in the 0th position", innerResult, outerResult.get(0)); + Assert.assertNotNull("Should contain the inner record", innerResult); + AssertHelpers.assertEmptyAvroField(innerResult, "lon"); + } } diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/RowDataProjection.java b/flink/src/main/java/org/apache/iceberg/flink/data/RowDataProjection.java new file mode 100644 index 000000000000..6334a00fd0d7 --- /dev/null +++ b/flink/src/main/java/org/apache/iceberg/flink/data/RowDataProjection.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.flink.data; + +import java.util.Map; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.DecimalData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.data.RawValueData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.data.TimestampData; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.types.RowKind; +import org.apache.iceberg.Schema; +import org.apache.iceberg.flink.FlinkSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +public class RowDataProjection implements RowData { + /** + * Creates a projecting wrapper for {@link RowData} rows. + *
<p>
+ * This projection will not project the nested children types of repeated types like lists and maps. + * + * @param schema schema of rows wrapped by this projection + * @param projectedSchema result schema of the projected rows + * @return a wrapper to project rows + */ + public static RowDataProjection create(Schema schema, Schema projectedSchema) { + return RowDataProjection.create(FlinkSchemaUtil.convert(schema), schema.asStruct(), projectedSchema.asStruct()); + } + + /** + * Creates a projecting wrapper for {@link RowData} rows. + *
<p>
+ * This projection will not project the nested children types of repeated types like lists and maps. + * + * @param rowType flink row type of rows wrapped by this projection + * @param schema schema of rows wrapped by this projection + * @param projectedSchema result schema of the projected rows + * @return a wrapper to project rows + */ + public static RowDataProjection create(RowType rowType, Types.StructType schema, Types.StructType projectedSchema) { + return new RowDataProjection(rowType, schema, projectedSchema); + } + + private final RowData.FieldGetter[] getters; + private RowData rowData; + + private RowDataProjection(RowType rowType, Types.StructType rowStruct, Types.StructType projectType) { + Map fieldIdToPosition = Maps.newHashMap(); + for (int i = 0; i < rowStruct.fields().size(); i++) { + fieldIdToPosition.put(rowStruct.fields().get(i).fieldId(), i); + } + + this.getters = new RowData.FieldGetter[projectType.fields().size()]; + for (int i = 0; i < getters.length; i++) { + Types.NestedField projectField = projectType.fields().get(i); + Types.NestedField rowField = rowStruct.field(projectField.fieldId()); + + Preconditions.checkNotNull(rowField, + "Cannot locate the project field <%s> in the iceberg struct <%s>", projectField, rowStruct); + + getters[i] = createFieldGetter(rowType, fieldIdToPosition.get(projectField.fieldId()), rowField, projectField); + } + } + + private static RowData.FieldGetter createFieldGetter(RowType rowType, + int position, + Types.NestedField rowField, + Types.NestedField projectField) { + Preconditions.checkArgument(rowField.type().typeId() == projectField.type().typeId(), + "Different iceberg type between row field <%s> and project field <%s>", rowField, projectField); + + switch (projectField.type().typeId()) { + case STRUCT: + RowType nestedRowType = (RowType) rowType.getTypeAt(position); + return row -> { + RowData nestedRow = row.isNullAt(position) ? null : row.getRow(position, nestedRowType.getFieldCount()); + return RowDataProjection + .create(nestedRowType, rowField.type().asStructType(), projectField.type().asStructType()) + .wrap(nestedRow); + }; + + case MAP: + Types.MapType projectedMap = projectField.type().asMapType(); + Types.MapType originalMap = rowField.type().asMapType(); + + boolean keyProjectable = !projectedMap.keyType().isNestedType() || + projectedMap.keyType().equals(originalMap.keyType()); + boolean valueProjectable = !projectedMap.valueType().isNestedType() || + projectedMap.valueType().equals(originalMap.valueType()); + Preconditions.checkArgument(keyProjectable && valueProjectable, + "Cannot project a partial map key or value with non-primitive type. Trying to project <%s> out of <%s>", + projectField, rowField); + + return RowData.createFieldGetter(rowType.getTypeAt(position), position); + + case LIST: + Types.ListType projectedList = projectField.type().asListType(); + Types.ListType originalList = rowField.type().asListType(); + + boolean elementProjectable = !projectedList.elementType().isNestedType() || + projectedList.elementType().equals(originalList.elementType()); + Preconditions.checkArgument(elementProjectable, + "Cannot project a partial list element with non-primitive type. 
Trying to project <%s> out of <%s>", + projectField, rowField); + + return RowData.createFieldGetter(rowType.getTypeAt(position), position); + + default: + return RowData.createFieldGetter(rowType.getTypeAt(position), position); + } + } + + public RowData wrap(RowData row) { + this.rowData = row; + return this; + } + + private Object getValue(int pos) { + return getters[pos].getFieldOrNull(rowData); + } + + @Override + public int getArity() { + return getters.length; + } + + @Override + public RowKind getRowKind() { + return rowData.getRowKind(); + } + + @Override + public void setRowKind(RowKind kind) { + throw new UnsupportedOperationException("Cannot set row kind in the RowDataProjection"); + } + + @Override + public boolean isNullAt(int pos) { + return rowData == null || getValue(pos) == null; + } + + @Override + public boolean getBoolean(int pos) { + return (boolean) getValue(pos); + } + + @Override + public byte getByte(int pos) { + return (byte) getValue(pos); + } + + @Override + public short getShort(int pos) { + return (short) getValue(pos); + } + + @Override + public int getInt(int pos) { + return (int) getValue(pos); + } + + @Override + public long getLong(int pos) { + return (long) getValue(pos); + } + + @Override + public float getFloat(int pos) { + return (float) getValue(pos); + } + + @Override + public double getDouble(int pos) { + return (double) getValue(pos); + } + + @Override + public StringData getString(int pos) { + return (StringData) getValue(pos); + } + + @Override + public DecimalData getDecimal(int pos, int precision, int scale) { + return (DecimalData) getValue(pos); + } + + @Override + public TimestampData getTimestamp(int pos, int precision) { + return (TimestampData) getValue(pos); + } + + @Override + @SuppressWarnings("unchecked") + public RawValueData getRawValue(int pos) { + return (RawValueData) getValue(pos); + } + + @Override + public byte[] getBinary(int pos) { + return (byte[]) getValue(pos); + } + + @Override + public ArrayData getArray(int pos) { + return (ArrayData) getValue(pos); + } + + @Override + public MapData getMap(int pos) { + return (MapData) getValue(pos); + } + + @Override + public RowData getRow(int pos, int numFields) { + return (RowData) getValue(pos); + } +} diff --git a/flink/src/main/java/org/apache/iceberg/flink/source/DataIterator.java b/flink/src/main/java/org/apache/iceberg/flink/source/DataIterator.java index f74a8968fab8..d470b0752304 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/source/DataIterator.java +++ b/flink/src/main/java/org/apache/iceberg/flink/source/DataIterator.java @@ -21,64 +21,38 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.nio.ByteBuffer; import java.util.Iterator; -import java.util.Map; -import java.util.stream.Stream; +import org.apache.flink.annotation.Internal; import org.apache.iceberg.CombinedScanTask; import org.apache.iceberg.FileScanTask; -import org.apache.iceberg.encryption.EncryptedFiles; -import org.apache.iceberg.encryption.EncryptedInputFile; import org.apache.iceberg.encryption.EncryptionManager; +import org.apache.iceberg.encryption.InputFilesDecryptor; import org.apache.iceberg.io.CloseableIterator; import org.apache.iceberg.io.FileIO; -import org.apache.iceberg.io.InputFile; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; /** - * Base class of Flink iterators. 
+ * Flink data iterator that reads {@link CombinedScanTask} into a {@link CloseableIterator} * - * @param is the Java class returned by this iterator whose objects contain one or more rows. + * @param is the output data type returned by this iterator. */ -abstract class DataIterator implements CloseableIterator { +@Internal +public class DataIterator implements CloseableIterator { - private Iterator tasks; - private final Map inputFiles; + private final FileScanTaskReader fileScanTaskReader; + private final InputFilesDecryptor inputFilesDecryptor; + private Iterator tasks; private CloseableIterator currentIterator; - DataIterator(CombinedScanTask task, FileIO io, EncryptionManager encryption) { - this.tasks = task.files().iterator(); - - Map keyMetadata = Maps.newHashMap(); - task.files().stream() - .flatMap(fileScanTask -> Stream.concat(Stream.of(fileScanTask.file()), fileScanTask.deletes().stream())) - .forEach(file -> keyMetadata.put(file.path().toString(), file.keyMetadata())); - Stream encrypted = keyMetadata.entrySet().stream() - .map(entry -> EncryptedFiles.encryptedInput(io.newInputFile(entry.getKey()), entry.getValue())); - - // decrypt with the batch call to avoid multiple RPCs to a key server, if possible - Iterable decryptedFiles = encryption.decrypt(encrypted::iterator); - - Map files = Maps.newHashMapWithExpectedSize(task.files().size()); - decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted)); - this.inputFiles = ImmutableMap.copyOf(files); + public DataIterator(FileScanTaskReader fileScanTaskReader, CombinedScanTask task, + FileIO io, EncryptionManager encryption) { + this.fileScanTaskReader = fileScanTaskReader; + this.inputFilesDecryptor = new InputFilesDecryptor(task, io, encryption); + this.tasks = task.files().iterator(); this.currentIterator = CloseableIterator.empty(); } - InputFile getInputFile(FileScanTask task) { - Preconditions.checkArgument(!task.isDataTask(), "Invalid task type"); - - return inputFiles.get(task.file().path().toString()); - } - - InputFile getInputFile(String location) { - return inputFiles.get(location); - } - @Override public boolean hasNext() { updateCurrentIterator(); @@ -106,7 +80,9 @@ private void updateCurrentIterator() { } } - abstract CloseableIterator openTaskIterator(FileScanTask scanTask) throws IOException; + private CloseableIterator openTaskIterator(FileScanTask scanTask) { + return fileScanTaskReader.open(scanTask, inputFilesDecryptor); + } @Override public void close() throws IOException { diff --git a/flink/src/main/java/org/apache/iceberg/flink/source/FileScanTaskReader.java b/flink/src/main/java/org/apache/iceberg/flink/source/FileScanTaskReader.java new file mode 100644 index 000000000000..04273016ee2d --- /dev/null +++ b/flink/src/main/java/org/apache/iceberg/flink/source/FileScanTaskReader.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.flink.source; + +import java.io.Serializable; +import org.apache.flink.annotation.Internal; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.encryption.InputFilesDecryptor; +import org.apache.iceberg.io.CloseableIterator; + +/** + * Read a {@link FileScanTask} into a {@link CloseableIterator} + * + * @param is the output data type returned by this iterator. + */ +@Internal +public interface FileScanTaskReader extends Serializable { + CloseableIterator open(FileScanTask fileScanTask, InputFilesDecryptor inputFilesDecryptor); +} diff --git a/flink/src/main/java/org/apache/iceberg/flink/source/FlinkInputFormat.java b/flink/src/main/java/org/apache/iceberg/flink/source/FlinkInputFormat.java index 1bad1c25952e..8b757ac31606 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/source/FlinkInputFormat.java +++ b/flink/src/main/java/org/apache/iceberg/flink/source/FlinkInputFormat.java @@ -42,21 +42,22 @@ public class FlinkInputFormat extends RichInputFormat private static final long serialVersionUID = 1L; private final TableLoader tableLoader; - private final Schema tableSchema; private final FileIO io; private final EncryptionManager encryption; private final ScanContext context; + private final RowDataFileScanTaskReader rowDataReader; - private transient RowDataIterator iterator; + private transient DataIterator iterator; private transient long currentReadCount = 0L; FlinkInputFormat(TableLoader tableLoader, Schema tableSchema, FileIO io, EncryptionManager encryption, ScanContext context) { this.tableLoader = tableLoader; - this.tableSchema = tableSchema; this.io = io; this.encryption = encryption; this.context = context; + this.rowDataReader = new RowDataFileScanTaskReader(tableSchema, + context.project(), context.nameMapping(), context.caseSensitive()); } @VisibleForTesting @@ -91,9 +92,7 @@ public void configure(Configuration parameters) { @Override public void open(FlinkInputSplit split) { - this.iterator = new RowDataIterator( - split.getTask(), io, encryption, tableSchema, context.project(), context.nameMapping(), - context.caseSensitive()); + this.iterator = new DataIterator<>(rowDataReader, split.getTask(), io, encryption); } @Override diff --git a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataIterator.java b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataFileScanTaskReader.java similarity index 62% rename from flink/src/main/java/org/apache/iceberg/flink/source/RowDataIterator.java rename to flink/src/main/java/org/apache/iceberg/flink/source/RowDataFileScanTaskReader.java index 5a568144d1f7..08f2f51e5d9c 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataIterator.java +++ b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataFileScanTaskReader.java @@ -20,24 +20,25 @@ package org.apache.iceberg.flink.source; import java.util.Map; +import org.apache.flink.annotation.Internal; import org.apache.flink.table.data.RowData; -import org.apache.iceberg.CombinedScanTask; +import org.apache.flink.table.types.logical.RowType; import org.apache.iceberg.FileScanTask; 
import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.Schema; import org.apache.iceberg.StructLike; import org.apache.iceberg.avro.Avro; import org.apache.iceberg.data.DeleteFilter; -import org.apache.iceberg.encryption.EncryptionManager; +import org.apache.iceberg.encryption.InputFilesDecryptor; import org.apache.iceberg.flink.FlinkSchemaUtil; import org.apache.iceberg.flink.RowDataWrapper; import org.apache.iceberg.flink.data.FlinkAvroReader; import org.apache.iceberg.flink.data.FlinkOrcReader; import org.apache.iceberg.flink.data.FlinkParquetReaders; +import org.apache.iceberg.flink.data.RowDataProjection; import org.apache.iceberg.flink.data.RowDataUtil; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.CloseableIterator; -import org.apache.iceberg.io.FileIO; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.mapping.NameMappingParser; import org.apache.iceberg.orc.ORC; @@ -47,16 +48,16 @@ import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.util.PartitionUtil; -class RowDataIterator extends DataIterator { +@Internal +public class RowDataFileScanTaskReader implements FileScanTaskReader { private final Schema tableSchema; private final Schema projectedSchema; private final String nameMapping; private final boolean caseSensitive; - RowDataIterator(CombinedScanTask task, FileIO io, EncryptionManager encryption, Schema tableSchema, - Schema projectedSchema, String nameMapping, boolean caseSensitive) { - super(task, io, encryption); + public RowDataFileScanTaskReader(Schema tableSchema, Schema projectedSchema, + String nameMapping, boolean caseSensitive) { this.tableSchema = tableSchema; this.projectedSchema = projectedSchema; this.nameMapping = nameMapping; @@ -64,34 +65,44 @@ class RowDataIterator extends DataIterator { } @Override - protected CloseableIterator openTaskIterator(FileScanTask task) { + public CloseableIterator open(FileScanTask task, InputFilesDecryptor inputFilesDecryptor) { Schema partitionSchema = TypeUtil.select(projectedSchema, task.spec().identitySourceIds()); Map idToConstant = partitionSchema.columns().isEmpty() ? ImmutableMap.of() : PartitionUtil.constantsMap(task, RowDataUtil::convertConstant); - FlinkDeleteFilter deletes = new FlinkDeleteFilter(task, tableSchema, projectedSchema); - CloseableIterable iterable = deletes.filter(newIterable(task, deletes.requiredSchema(), idToConstant)); + FlinkDeleteFilter deletes = new FlinkDeleteFilter(task, tableSchema, projectedSchema, inputFilesDecryptor); + CloseableIterable iterable = deletes.filter( + newIterable(task, deletes.requiredSchema(), idToConstant, inputFilesDecryptor) + ); + + // Project the RowData to remove the extra meta columns. 
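+ // Illustrative example: if an equality delete is keyed on 'id' but the query projects only 'data',
+ // deletes.requiredSchema() will contain both columns while projectedSchema contains only 'data';
+ // the RowDataProjection wrapper then drops the extra column from every emitted row.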
+ if (!projectedSchema.sameSchema(deletes.requiredSchema())) { + RowDataProjection rowDataProjection = RowDataProjection.create( + deletes.requiredRowType(), deletes.requiredSchema().asStruct(), projectedSchema.asStruct()); + iterable = CloseableIterable.transform(iterable, rowDataProjection::wrap); + } return iterable.iterator(); } - private CloseableIterable newIterable(FileScanTask task, Schema schema, Map idToConstant) { + private CloseableIterable newIterable( + FileScanTask task, Schema schema, Map idToConstant, InputFilesDecryptor inputFilesDecryptor) { CloseableIterable iter; if (task.isDataTask()) { throw new UnsupportedOperationException("Cannot read data task."); } else { switch (task.file().format()) { case PARQUET: - iter = newParquetIterable(task, schema, idToConstant); + iter = newParquetIterable(task, schema, idToConstant, inputFilesDecryptor); break; case AVRO: - iter = newAvroIterable(task, schema, idToConstant); + iter = newAvroIterable(task, schema, idToConstant, inputFilesDecryptor); break; case ORC: - iter = newOrcIterable(task, schema, idToConstant); + iter = newOrcIterable(task, schema, idToConstant, inputFilesDecryptor); break; default: @@ -103,8 +114,9 @@ private CloseableIterable newIterable(FileScanTask task, Schema schema, return iter; } - private CloseableIterable newAvroIterable(FileScanTask task, Schema schema, Map idToConstant) { - Avro.ReadBuilder builder = Avro.read(getInputFile(task)) + private CloseableIterable newAvroIterable( + FileScanTask task, Schema schema, Map idToConstant, InputFilesDecryptor inputFilesDecryptor) { + Avro.ReadBuilder builder = Avro.read(inputFilesDecryptor.getInputFile(task)) .reuseContainers() .project(schema) .split(task.start(), task.length()) @@ -117,9 +129,9 @@ private CloseableIterable newAvroIterable(FileScanTask task, Schema sch return builder.build(); } - private CloseableIterable newParquetIterable(FileScanTask task, Schema schema, - Map idToConstant) { - Parquet.ReadBuilder builder = Parquet.read(getInputFile(task)) + private CloseableIterable newParquetIterable( + FileScanTask task, Schema schema, Map idToConstant, InputFilesDecryptor inputFilesDecryptor) { + Parquet.ReadBuilder builder = Parquet.read(inputFilesDecryptor.getInputFile(task)) .reuseContainers() .split(task.start(), task.length()) .project(schema) @@ -135,11 +147,12 @@ private CloseableIterable newParquetIterable(FileScanTask task, Schema return builder.build(); } - private CloseableIterable newOrcIterable(FileScanTask task, Schema schema, Map idToConstant) { + private CloseableIterable newOrcIterable( + FileScanTask task, Schema schema, Map idToConstant, InputFilesDecryptor inputFilesDecryptor) { Schema readSchemaWithoutConstantAndMetadataFields = TypeUtil.selectNot(schema, Sets.union(idToConstant.keySet(), MetadataColumns.metadataFieldIds())); - ORC.ReadBuilder builder = ORC.read(getInputFile(task)) + ORC.ReadBuilder builder = ORC.read(inputFilesDecryptor.getInputFile(task)) .project(readSchemaWithoutConstantAndMetadataFields) .split(task.start(), task.length()) .createReaderFunc(readOrcSchema -> new FlinkOrcReader(schema, readOrcSchema, idToConstant)) @@ -153,12 +166,21 @@ private CloseableIterable newOrcIterable(FileScanTask task, Schema sche return builder.build(); } - private class FlinkDeleteFilter extends DeleteFilter { + private static class FlinkDeleteFilter extends DeleteFilter { + private final RowType requiredRowType; private final RowDataWrapper asStructLike; + private final InputFilesDecryptor inputFilesDecryptor; - 
FlinkDeleteFilter(FileScanTask task, Schema tableSchema, Schema requestedSchema) { + FlinkDeleteFilter(FileScanTask task, Schema tableSchema, Schema requestedSchema, + InputFilesDecryptor inputFilesDecryptor) { super(task, tableSchema, requestedSchema); - this.asStructLike = new RowDataWrapper(FlinkSchemaUtil.convert(requiredSchema()), requiredSchema().asStruct()); + this.requiredRowType = FlinkSchemaUtil.convert(requiredSchema()); + this.asStructLike = new RowDataWrapper(requiredRowType, requiredSchema().asStruct()); + this.inputFilesDecryptor = inputFilesDecryptor; + } + + public RowType requiredRowType() { + return requiredRowType; } @Override @@ -168,7 +190,7 @@ protected StructLike asStructLike(RowData row) { @Override protected InputFile getInputFile(String location) { - return RowDataIterator.this.getInputFile(location); + return inputFilesDecryptor.getInputFile(location); } } } diff --git a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java index a6cd374c3044..752035e4ea3b 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java +++ b/flink/src/main/java/org/apache/iceberg/flink/source/RowDataRewriter.java @@ -99,6 +99,7 @@ public static class RewriteMap extends RichMapFunction taskWriterFactory; + private final RowDataFileScanTaskReader rowDataReader; public RewriteMap(Schema schema, String nameMapping, FileIO io, boolean caseSensitive, EncryptionManager encryptionManager, TaskWriterFactory taskWriterFactory) { @@ -108,6 +109,7 @@ public RewriteMap(Schema schema, String nameMapping, FileIO io, boolean caseSens this.caseSensitive = caseSensitive; this.encryptionManager = encryptionManager; this.taskWriterFactory = taskWriterFactory; + this.rowDataReader = new RowDataFileScanTaskReader(schema, schema, nameMapping, caseSensitive); } @Override @@ -122,8 +124,8 @@ public void open(Configuration parameters) { public List map(CombinedScanTask task) throws Exception { // Initialize the task writer. 
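+ // The shared RowDataFileScanTaskReader is created once per RewriteMap and reused for every task;
+ // the DataIterator below only handles task iteration and input-file decryption.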
this.writer = taskWriterFactory.create(); - try (RowDataIterator iterator = - new RowDataIterator(task, io, encryptionManager, schema, schema, nameMapping, caseSensitive)) { + try (DataIterator iterator = + new DataIterator<>(rowDataReader, task, io, encryptionManager)) { while (iterator.hasNext()) { RowData rowData = iterator.next(); writer.write(rowData); diff --git a/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java b/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java index d44f45ab52fd..68b706e2d281 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java +++ b/flink/src/test/java/org/apache/iceberg/flink/TestChangeLogTable.java @@ -38,6 +38,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.util.StructLikeSet; import org.junit.After; import org.junit.Assert; @@ -125,10 +126,10 @@ public void testSqlChangeLogOnIdKey() throws Exception { ) ); - List> expectedRecordsPerCheckpoint = ImmutableList.of( - ImmutableList.of(record(1, "bbb"), record(2, "bbb")), - ImmutableList.of(record(1, "bbb"), record(2, "ddd")), - ImmutableList.of(record(1, "ddd"), record(2, "ddd")) + List> expectedRecordsPerCheckpoint = ImmutableList.of( + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "bbb")), + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "ddd")), + ImmutableList.of(insertRow(1, "ddd"), insertRow(2, "ddd")) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("id"), inputRowsPerCheckpoint, @@ -157,10 +158,10 @@ public void testChangeLogOnDataKey() throws Exception { ) ); - List> expectedRecords = ImmutableList.of( - ImmutableList.of(record(1, "bbb"), record(2, "aaa")), - ImmutableList.of(record(1, "aaa"), record(1, "bbb"), record(1, "ccc")), - ImmutableList.of(record(1, "aaa"), record(1, "ccc"), record(2, "aaa"), record(2, "ccc")) + List> expectedRecords = ImmutableList.of( + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "aaa")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "bbb"), insertRow(1, "ccc")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "ccc"), insertRow(2, "aaa"), insertRow(2, "ccc")) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("data"), elementsPerCheckpoint, expectedRecords); @@ -187,10 +188,10 @@ public void testChangeLogOnIdDataKey() throws Exception { ) ); - List> expectedRecords = ImmutableList.of( - ImmutableList.of(record(1, "bbb"), record(2, "aaa"), record(2, "bbb")), - ImmutableList.of(record(1, "aaa"), record(1, "bbb"), record(1, "ccc"), record(2, "bbb")), - ImmutableList.of(record(1, "aaa"), record(1, "ccc"), record(2, "aaa"), record(2, "bbb")) + List> expectedRecords = ImmutableList.of( + ImmutableList.of(insertRow(1, "bbb"), insertRow(2, "aaa"), insertRow(2, "bbb")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "bbb"), insertRow(1, "ccc"), insertRow(2, "bbb")), + ImmutableList.of(insertRow(1, "aaa"), insertRow(1, "ccc"), insertRow(2, "aaa"), insertRow(2, "bbb")) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("data", "id"), elementsPerCheckpoint, expectedRecords); @@ -213,31 +214,31 @@ public void testPureInsertOnIdKey() throws Exception { ) ); - List> expectedRecords = ImmutableList.of( + List> expectedRecords = ImmutableList.of( ImmutableList.of( - record(1, "aaa"), - record(2, "bbb") + insertRow(1, "aaa"), + 
insertRow(2, "bbb") ), ImmutableList.of( - record(1, "aaa"), - record(2, "bbb"), - record(3, "ccc"), - record(4, "ddd") + insertRow(1, "aaa"), + insertRow(2, "bbb"), + insertRow(3, "ccc"), + insertRow(4, "ddd") ), ImmutableList.of( - record(1, "aaa"), - record(2, "bbb"), - record(3, "ccc"), - record(4, "ddd"), - record(5, "eee"), - record(6, "fff") + insertRow(1, "aaa"), + insertRow(2, "bbb"), + insertRow(3, "ccc"), + insertRow(4, "ddd"), + insertRow(5, "eee"), + insertRow(6, "fff") ) ); testSqlChangeLog(TABLE_NAME, ImmutableList.of("data"), elementsPerCheckpoint, expectedRecords); } - private Record record(int id, String data) { + private static Record record(int id, String data) { return SimpleDataUtil.createRecord(id, data); } @@ -261,7 +262,7 @@ private Table createTable(String tableName, List key, boolean isPartitio private void testSqlChangeLog(String tableName, List key, List> inputRowsPerCheckpoint, - List> expectedRecordsPerCheckpoint) throws Exception { + List> expectedRecordsPerCheckpoint) throws Exception { String dataId = BoundedTableFactory.registerDataSet(inputRowsPerCheckpoint); sql("CREATE TABLE %s(id INT NOT NULL, data STRING NOT NULL)" + " WITH ('connector'='BoundedSource', 'data-id'='%s')", SOURCE_TABLE, dataId); @@ -280,9 +281,15 @@ private void testSqlChangeLog(String tableName, for (int i = 0; i < expectedSnapshotNum; i++) { long snapshotId = snapshots.get(i).snapshotId(); - List expectedRecords = expectedRecordsPerCheckpoint.get(i); + List expectedRows = expectedRecordsPerCheckpoint.get(i); Assert.assertEquals("Should have the expected records for the checkpoint#" + i, - expectedRowSet(table, expectedRecords), actualRowSet(table, snapshotId)); + expectedRowSet(table, expectedRows), actualRowSet(table, snapshotId)); + } + + if (expectedSnapshotNum > 0) { + Assert.assertEquals("Should have the expected rows in the final table", + Sets.newHashSet(expectedRecordsPerCheckpoint.get(expectedSnapshotNum - 1)), + Sets.newHashSet(sql("SELECT * FROM %s", tableName))); } } @@ -296,8 +303,12 @@ private List findValidSnapshots(Table table) { return validSnapshots; } - private static StructLikeSet expectedRowSet(Table table, List records) { - return SimpleDataUtil.expectedRowSet(table, records.toArray(new Record[0])); + private static StructLikeSet expectedRowSet(Table table, List rows) { + Record[] records = new Record[rows.size()]; + for (int i = 0; i < records.length; i++) { + records[i] = record((int) rows.get(i).getField(0), (String) rows.get(i).getField(1)); + } + return SimpleDataUtil.expectedRowSet(table, records); } private static StructLikeSet actualRowSet(Table table, long snapshotId) throws IOException { diff --git a/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java b/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java index 7099c864cb34..c1d17f5c036a 100644 --- a/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java +++ b/flink/src/test/java/org/apache/iceberg/flink/TestHelpers.java @@ -50,6 +50,7 @@ import org.apache.iceberg.ContentFile; import org.apache.iceberg.ManifestFile; import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; import org.apache.iceberg.data.Record; import org.apache.iceberg.flink.data.RowDataUtil; import org.apache.iceberg.flink.source.FlinkInputFormat; @@ -116,7 +117,11 @@ public static void assertRows(List results, List expected) { Assert.assertEquals(expected, results); } - public static void assertRowData(Types.StructType structType, LogicalType rowType, Record expectedRecord, + public static 
void assertRowData(Schema schema, StructLike expected, RowData actual) { + assertRowData(schema.asStruct(), FlinkSchemaUtil.convert(schema), expected, actual); + } + + public static void assertRowData(Types.StructType structType, LogicalType rowType, StructLike expectedRecord, RowData actualRowData) { if (expectedRecord == null && actualRowData == null) { return; @@ -131,10 +136,15 @@ public static void assertRowData(Types.StructType structType, LogicalType rowTyp } for (int i = 0; i < types.size(); i += 1) { - Object expected = expectedRecord.get(i); LogicalType logicalType = ((RowType) rowType).getTypeAt(i); - assertEquals(types.get(i), logicalType, expected, - RowData.createFieldGetter(logicalType, i).getFieldOrNull(actualRowData)); + Object expected = expectedRecord.get(i, Object.class); + // The RowData.createFieldGetter won't return null for the required field. But in the projection case, if we are + // projecting a nested required field from an optional struct, then we should give a null for the projected field + // if the outer struct value is null. So we need to check the nullable for actualRowData here. For more details + // please see issue #2738. + Object actual = actualRowData.isNullAt(i) ? null : + RowData.createFieldGetter(logicalType, i).getFieldOrNull(actualRowData); + assertEquals(types.get(i), logicalType, expected, actual); } } @@ -213,8 +223,8 @@ private static void assertEquals(Type type, LogicalType logicalType, Object expe assertMapValues(type.asMapType(), logicalType, (Map) expected, (MapData) actual); break; case STRUCT: - Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(Record.class); - assertRowData(type.asStructType(), logicalType, (Record) expected, (RowData) actual); + Assertions.assertThat(expected).as("Should expect a Record").isInstanceOf(StructLike.class); + assertRowData(type.asStructType(), logicalType, (StructLike) expected, (RowData) actual); break; case UUID: Assertions.assertThat(expected).as("Should expect a UUID").isInstanceOf(UUID.class); diff --git a/flink/src/test/java/org/apache/iceberg/flink/data/TestRowDataProjection.java b/flink/src/test/java/org/apache/iceberg/flink/data/TestRowDataProjection.java new file mode 100644 index 000000000000..37016adfbdf2 --- /dev/null +++ b/flink/src/test/java/org/apache/iceberg/flink/data/TestRowDataProjection.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.flink.data; + +import java.util.Iterator; +import org.apache.flink.table.data.RowData; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.data.RandomGenericData; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.flink.TestHelpers; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.StructProjection; +import org.junit.Assert; +import org.junit.Test; + +public class TestRowDataProjection { + + @Test + public void testFullProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + generateAndValidate(schema, schema); + } + + @Test + public void testReorderedFullProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Schema reordered = new Schema( + Types.NestedField.optional(1, "data", Types.StringType.get()), + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + + generateAndValidate(schema, reordered); + } + + @Test + public void testBasicProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + Schema id = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + Schema data = new Schema( + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + generateAndValidate(schema, id); + generateAndValidate(schema, data); + } + + @Test + public void testEmptyProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + generateAndValidate(schema, schema.select()); + } + + @Test + public void testRename() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()) + ); + + Schema renamed = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "renamed", Types.StringType.get()) + ); + generateAndValidate(schema, renamed); + } + + @Test + public void testNestedProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + + // Project id only. + Schema idOnly = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()) + ); + generateAndValidate(schema, idOnly); + + // Project lat only. + Schema latOnly = new Schema( + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()) + )) + ); + generateAndValidate(schema, latOnly); + + // Project long only. + Schema longOnly = new Schema( + Types.NestedField.optional(3, "location", Types.StructType.of( + Types.NestedField.required(2, "long", Types.FloatType.get()) + )) + ); + generateAndValidate(schema, longOnly); + + // Project location. 
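+ // Selecting by name keeps the full 'location' struct, including both lat and long, unlike the single-field projections above.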
+ Schema locationOnly = schema.select("location"); + generateAndValidate(schema, locationOnly); + } + + @Test + public void testPrimitiveTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(1, "data", Types.StringType.get()), + Types.NestedField.required(2, "b", Types.BooleanType.get()), + Types.NestedField.optional(3, "i", Types.IntegerType.get()), + Types.NestedField.required(4, "l", Types.LongType.get()), + Types.NestedField.optional(5, "f", Types.FloatType.get()), + Types.NestedField.required(6, "d", Types.DoubleType.get()), + Types.NestedField.optional(7, "date", Types.DateType.get()), + Types.NestedField.optional(8, "time", Types.TimeType.get()), + Types.NestedField.required(9, "ts", Types.TimestampType.withoutZone()), + Types.NestedField.required(10, "ts_tz", Types.TimestampType.withZone()), + Types.NestedField.required(11, "s", Types.StringType.get()), + Types.NestedField.required(12, "fixed", Types.FixedType.ofLength(7)), + Types.NestedField.optional(13, "bytes", Types.BinaryType.get()), + Types.NestedField.required(14, "dec_9_0", Types.DecimalType.of(9, 0)), + Types.NestedField.required(15, "dec_11_2", Types.DecimalType.of(11, 2)), + Types.NestedField.required(16, "dec_38_10", Types.DecimalType.of(38, 10))// maximum precision + ); + + generateAndValidate(schema, schema); + } + + @Test + public void testPrimitiveMapTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(3, "map", Types.MapType.ofOptional( + 1, 2, Types.IntegerType.get(), Types.StringType.get() + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project map only. + Schema mapOnly = schema.select("map"); + generateAndValidate(schema, mapOnly); + + // Project all. + generateAndValidate(schema, schema); + } + + @Test + public void testNestedMapTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(7, "map", Types.MapType.ofOptional( + 5, 6, + Types.StructType.of( + Types.NestedField.required(1, "key", Types.LongType.get()), + Types.NestedField.required(2, "keyData", Types.LongType.get()) + ), + Types.StructType.of( + Types.NestedField.required(3, "value", Types.LongType.get()), + Types.NestedField.required(4, "valueData", Types.LongType.get()) + ) + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project map only. + Schema mapOnly = schema.select("map"); + generateAndValidate(schema, mapOnly); + + // Project all. + generateAndValidate(schema, schema); + + // Project partial map key. + Schema partialMapKey = new Schema( + Types.NestedField.optional(7, "map", Types.MapType.ofOptional( + 5, 6, + Types.StructType.of( + Types.NestedField.required(1, "key", Types.LongType.get()) + ), + Types.StructType.of( + Types.NestedField.required(3, "value", Types.LongType.get()), + Types.NestedField.required(4, "valueData", Types.LongType.get()) + ) + )) + ); + AssertHelpers.assertThrows("Should not allow to project a partial map key with non-primitive type.", + IllegalArgumentException.class, "Cannot project a partial map key or value", + () -> generateAndValidate(schema, partialMapKey) + ); + + // Project partial map key. 
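+ // Project partial map value: selecting only part of the struct-typed value should be rejected as well.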
+ Schema partialMapValue = new Schema( + Types.NestedField.optional(7, "map", Types.MapType.ofOptional( + 5, 6, + Types.StructType.of( + Types.NestedField.required(1, "key", Types.LongType.get()), + Types.NestedField.required(2, "keyData", Types.LongType.get()) + ), + Types.StructType.of( + Types.NestedField.required(3, "value", Types.LongType.get()) + ) + )) + ); + AssertHelpers.assertThrows("Should not allow to project a partial map value with non-primitive type.", + IllegalArgumentException.class, "Cannot project a partial map key or value", + () -> generateAndValidate(schema, partialMapValue) + ); + } + + @Test + public void testPrimitiveListTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(2, "list", Types.ListType.ofOptional( + 1, Types.StringType.get() + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project list only. + Schema mapOnly = schema.select("list"); + generateAndValidate(schema, mapOnly); + + // Project all. + generateAndValidate(schema, schema); + } + + @Test + public void testNestedListTypeProjection() { + Schema schema = new Schema( + Types.NestedField.required(0, "id", Types.LongType.get()), + Types.NestedField.optional(5, "list", Types.ListType.ofOptional( + 4, Types.StructType.of( + Types.NestedField.required(1, "nestedListField1", Types.LongType.get()), + Types.NestedField.required(2, "nestedListField2", Types.LongType.get()), + Types.NestedField.required(3, "nestedListField3", Types.LongType.get()) + ) + )) + ); + + // Project id only. + Schema idOnly = schema.select("id"); + generateAndValidate(schema, idOnly); + + // Project list only. + Schema mapOnly = schema.select("list"); + generateAndValidate(schema, mapOnly); + + // Project all. + generateAndValidate(schema, schema); + + // Project partial list value. 
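+ // Only nestedListField2 of the three struct fields is selected, so the list element is a partial struct and the projection should be rejected.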
+ Schema partialList = new Schema( + Types.NestedField.optional(5, "list", Types.ListType.ofOptional( + 4, Types.StructType.of( + Types.NestedField.required(2, "nestedListField2", Types.LongType.get()) + ) + )) + ); + AssertHelpers.assertThrows("Should not allow to project a partial list element with non-primitive type.", + IllegalArgumentException.class, "Cannot project a partial list element", + () -> generateAndValidate(schema, partialList) + ); + } + + private void generateAndValidate(Schema schema, Schema projectSchema) { + int numRecords = 100; + Iterable recordList = RandomGenericData.generate(schema, numRecords, 102L); + Iterable rowDataList = RandomRowData.generate(schema, numRecords, 102L); + + StructProjection structProjection = StructProjection.create(schema, projectSchema); + RowDataProjection rowDataProjection = RowDataProjection.create(schema, projectSchema); + + Iterator recordIter = recordList.iterator(); + Iterator rowDataIter = rowDataList.iterator(); + + for (int i = 0; i < numRecords; i++) { + Assert.assertTrue("Should have more records", recordIter.hasNext()); + Assert.assertTrue("Should have more RowData", rowDataIter.hasNext()); + + StructLike expected = structProjection.wrap(recordIter.next()); + RowData actual = rowDataProjection.wrap(rowDataIter.next()); + + TestHelpers.assertRowData(projectSchema, expected, actual); + } + + Assert.assertFalse("Shouldn't have more record", recordIter.hasNext()); + Assert.assertFalse("Shouldn't have more RowData", rowDataIter.hasNext()); + } +} diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java b/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java index f181875ad3ac..aafa2902bc64 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java @@ -19,12 +19,14 @@ package org.apache.iceberg.parquet; +import java.util.Collections; import java.util.List; import java.util.Set; import org.apache.iceberg.relocated.com.google.common.base.Objects; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; @@ -49,12 +51,22 @@ public Type message(MessageType message, List fields) { Type field = fields.get(i); Integer fieldId = getId(originalField); if (fieldId != null && selectedIds.contains(fieldId)) { - builder.addField(originalField); + if (field != null) { + hasChange = true; + builder.addField(field); + } else { + if (isStruct(originalField)) { + hasChange = true; + builder.addField(originalField.asGroupType().withNewFields(Collections.emptyList())); + } else { + builder.addField(originalField); + } + } fieldCount += 1; } else if (field != null) { + hasChange = true; builder.addField(field); fieldCount += 1; - hasChange = true; } } @@ -141,4 +153,15 @@ public Type primitive(PrimitiveType primitive) { private Integer getId(Type type) { return type.getId() == null ? 
null : type.getId().intValue(); } + + private boolean isStruct(Type field) { + if (field.isPrimitive()) { + return false; + } else { + GroupType groupType = field.asGroupType(); + LogicalTypeAnnotation logicalTypeAnnotation = groupType.getLogicalTypeAnnotation(); + return !logicalTypeAnnotation.equals(LogicalTypeAnnotation.mapType()) && + !logicalTypeAnnotation.equals(LogicalTypeAnnotation.listType()); + } + } } diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java index c8b33dd2f706..b58745c7a00d 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java @@ -24,24 +24,32 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.stream.Stream; import org.apache.avro.generic.GenericData; import org.apache.avro.util.Utf8; import org.apache.iceberg.CombinedScanTask; import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Partitioning; +import org.apache.iceberg.Schema; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; import org.apache.iceberg.encryption.EncryptedFiles; import org.apache.iceberg.encryption.EncryptedInputFile; -import org.apache.iceberg.encryption.EncryptionManager; import org.apache.iceberg.io.CloseableIterator; -import org.apache.iceberg.io.FileIO; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StructType; import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.PartitionUtil; import org.apache.spark.rdd.InputFileBlockHolder; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.UTF8String; import org.slf4j.Logger; @@ -55,6 +63,7 @@ abstract class BaseDataReader implements Closeable { private static final Logger LOG = LoggerFactory.getLogger(BaseDataReader.class); + private final Table table; private final Iterator tasks; private final Map inputFiles; @@ -62,17 +71,18 @@ abstract class BaseDataReader implements Closeable { private T current = null; private FileScanTask currentTask = null; - BaseDataReader(CombinedScanTask task, FileIO io, EncryptionManager encryptionManager) { + BaseDataReader(Table table, CombinedScanTask task) { + this.table = table; this.tasks = task.files().iterator(); Map keyMetadata = Maps.newHashMap(); task.files().stream() .flatMap(fileScanTask -> Stream.concat(Stream.of(fileScanTask.file()), fileScanTask.deletes().stream())) .forEach(file -> keyMetadata.put(file.path().toString(), file.keyMetadata())); Stream encrypted = keyMetadata.entrySet().stream() - .map(entry -> EncryptedFiles.encryptedInput(io.newInputFile(entry.getKey()), entry.getValue())); + .map(entry -> EncryptedFiles.encryptedInput(table.io().newInputFile(entry.getKey()), entry.getValue())); // decrypt with the batch call to avoid multiple RPCs to a key server, if possible - Iterable decryptedFiles = encryptionManager.decrypt(encrypted::iterator); + Iterable 
decryptedFiles = table.encryption().decrypt(encrypted::iterator); Map files = Maps.newHashMapWithExpectedSize(task.files().size()); decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted)); @@ -132,6 +142,15 @@ protected InputFile getInputFile(String location) { return inputFiles.get(location); } + protected Map constantsMap(FileScanTask task, Schema readSchema) { + if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) { + StructType partitionType = Partitioning.partitionType(table); + return PartitionUtil.constantsMap(task, partitionType, BaseDataReader::convertConstant); + } else { + return PartitionUtil.constantsMap(task, BaseDataReader::convertConstant); + } + } + protected static Object convertConstant(Type type, Object value) { if (value == null) { return null; @@ -155,6 +174,24 @@ protected static Object convertConstant(Type type, Object value) { return ByteBuffers.toByteArray((ByteBuffer) value); case BINARY: return ByteBuffers.toByteArray((ByteBuffer) value); + case STRUCT: + StructType structType = (StructType) type; + + if (structType.fields().isEmpty()) { + return new GenericInternalRow(); + } + + List fields = structType.fields(); + Object[] values = new Object[fields.size()]; + StructLike struct = (StructLike) value; + + for (int index = 0; index < fields.size(); index++) { + NestedField field = fields.get(index); + Type fieldType = field.type(); + values[index] = convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass())); + } + + return new GenericInternalRow(values); default: } return value; diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java index 8cfe46b598fc..e4bd3ceba6ce 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java @@ -41,7 +41,6 @@ import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders; import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; import org.apache.iceberg.types.TypeUtil; -import org.apache.iceberg.util.PartitionUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -52,7 +51,7 @@ class BatchDataReader extends BaseDataReader { private final int batchSize; BatchDataReader(CombinedScanTask task, Table table, Schema expectedSchema, boolean caseSensitive, int size) { - super(task, table.io(), table.encryption()); + super(table, task); this.expectedSchema = expectedSchema; this.nameMapping = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); this.caseSensitive = caseSensitive; @@ -66,7 +65,7 @@ CloseableIterator open(FileScanTask task) { // update the current file for Spark's filename() function InputFileBlockHolder.set(file.path().toString(), task.start(), task.length()); - Map idToConstant = PartitionUtil.constantsMap(task, BatchDataReader::convertConstant); + Map idToConstant = constantsMap(task, expectedSchema); CloseableIterable iter; InputFile location = getInputFile(task); diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java index d4328addc759..ce2226f4f75e 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java 
@@ -26,7 +26,6 @@ import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import org.apache.iceberg.io.CloseableIterator; -import org.apache.iceberg.util.PartitionUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; @@ -44,7 +43,7 @@ CloseableIterator open(FileScanTask task) { // schema or rows returned by readers Schema requiredSchema = matches.requiredSchema(); - Map idToConstant = PartitionUtil.constantsMap(task, RowDataReader::convertConstant); + Map idToConstant = constantsMap(task, expectedSchema); DataFile file = task.file(); // update the current file for Spark's filename() function diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java index 391d4a053490..8770e17aa015 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java @@ -44,7 +44,6 @@ import org.apache.iceberg.spark.data.SparkOrcReader; import org.apache.iceberg.spark.data.SparkParquetReaders; import org.apache.iceberg.types.TypeUtil; -import org.apache.iceberg.util.PartitionUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; @@ -56,7 +55,7 @@ class RowDataReader extends BaseDataReader { private final boolean caseSensitive; RowDataReader(CombinedScanTask task, Table table, Schema expectedSchema, boolean caseSensitive) { - super(task, table.io(), table.encryption()); + super(table, task); this.tableSchema = table.schema(); this.expectedSchema = expectedSchema; this.nameMapping = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); @@ -69,7 +68,7 @@ CloseableIterator open(FileScanTask task) { // schema or rows returned by readers Schema requiredSchema = deletes.requiredSchema(); - Map idToConstant = PartitionUtil.constantsMap(task, RowDataReader::convertConstant); + Map idToConstant = constantsMap(task, expectedSchema); DataFile file = task.file(); // update the current file for Spark's filename() function diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java index e4ca09f1fec8..10b9d6f3030c 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java +++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java @@ -68,6 +68,8 @@ public abstract class TestIcebergSourceTablesBase extends SparkTestBase { optional(2, "data", Types.StringType.get()) ); + private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + @Rule public TemporaryFolder temp = new TemporaryFolder(); @@ -147,6 +149,31 @@ public void testEntriesTable() throws Exception { TestHelpers.assertEqualsSafe(entriesTable.schema().asStruct(), expected.get(0), actual.get(0)); } + @Test + public void testEntriesTablePartitionedPrune() throws Exception { + TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); + + List records = Lists.newArrayList(new SimpleRecord(1, "1")); + + Dataset inputDf = spark.createDataFrame(records, SimpleRecord.class); + inputDf.select("id", "data").write() + .format("iceberg") + .mode("append") + .save(loadLocation(tableIdentifier)); + + table.refresh(); + + List actual = spark.read() + .format("iceberg") 
+ .load(loadLocation(tableIdentifier, "entries")) + .select("status") + .collectAsList(); + + Assert.assertEquals("Results should contain only one status", 1, actual.size()); + Assert.assertEquals("That status should be Added (1)", 1, actual.get(0).getInt(0)); + } + @Test public void testEntriesTableDataFilePrune() throws Exception { TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_test"); @@ -312,7 +339,7 @@ public void testCountEntriesTable() { @Test public void testFilesTable() throws Exception { TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); Table entriesTable = loadTable(tableIdentifier, "entries"); Table filesTable = loadTable(tableIdentifier, "files"); @@ -362,7 +389,7 @@ public void testFilesTableWithSnapshotIdInheritance() throws Exception { spark.sql("DROP TABLE IF EXISTS parquet_table"); TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_inheritance_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); table.updateProperties() .set(TableProperties.SNAPSHOT_ID_INHERITANCE_ENABLED, "true") .commit(); @@ -422,7 +449,7 @@ public void testEntriesTableWithSnapshotIdInheritance() throws Exception { spark.sql("DROP TABLE IF EXISTS parquet_table"); TableIdentifier tableIdentifier = TableIdentifier.of("db", "entries_inheritance_test"); - PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); + PartitionSpec spec = SPEC; Table table = createTable(tableIdentifier, SCHEMA, spec); table.updateProperties() @@ -523,7 +550,7 @@ public void testFilesUnpartitionedTable() throws Exception { @Test public void testAllMetadataTablesWithStagedCommits() throws Exception { TableIdentifier tableIdentifier = TableIdentifier.of("db", "stage_aggregate_table_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); table.updateProperties().set(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true").commit(); spark.conf().set("spark.wap.id", "1234567"); @@ -567,7 +594,7 @@ public void testAllMetadataTablesWithStagedCommits() throws Exception { @Test public void testAllDataFilesTable() throws Exception { TableIdentifier tableIdentifier = TableIdentifier.of("db", "files_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); Table entriesTable = loadTable(tableIdentifier, "entries"); Table filesTable = loadTable(tableIdentifier, "all_data_files"); @@ -831,7 +858,7 @@ public void testPrunedSnapshotsTable() { @Test public void testManifestsTable() { TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); Table manifestTable = loadTable(tableIdentifier, "manifests"); Dataset df1 = spark.createDataFrame( Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), SimpleRecord.class); @@ -878,7 +905,7 @@ public void testManifestsTable() { @Test public void testPruneManifestsTable() { TableIdentifier 
tableIdentifier = TableIdentifier.of("db", "manifests_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); Table manifestTable = loadTable(tableIdentifier, "manifests"); Dataset df1 = spark.createDataFrame( Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(null, "b")), SimpleRecord.class); @@ -938,7 +965,7 @@ public void testPruneManifestsTable() { @Test public void testAllManifestsTable() { TableIdentifier tableIdentifier = TableIdentifier.of("db", "manifests_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); Table manifestTable = loadTable(tableIdentifier, "all_manifests"); Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); @@ -1034,7 +1061,7 @@ public void testUnpartitionedPartitionsTable() { @Test public void testPartitionsTable() { TableIdentifier tableIdentifier = TableIdentifier.of("db", "partitions_test"); - Table table = createTable(tableIdentifier, SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("id").build()); + Table table = createTable(tableIdentifier, SCHEMA, SPEC); Table partitionsTable = loadTable(tableIdentifier, "partitions"); Dataset df1 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class); Dataset df2 = spark.createDataFrame(Lists.newArrayList(new SimpleRecord(2, "b")), SimpleRecord.class); diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkBaseDataReader.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkBaseDataReader.java index 51b47cbd972d..8bae666c0475 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkBaseDataReader.java +++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkBaseDataReader.java @@ -30,7 +30,6 @@ import java.util.stream.IntStream; import java.util.stream.StreamSupport; import org.apache.avro.generic.GenericData; -import org.apache.hadoop.conf.Configuration; import org.apache.iceberg.AppendFiles; import org.apache.iceberg.BaseCombinedScanTask; import org.apache.iceberg.DataFile; @@ -39,8 +38,6 @@ import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.Table; -import org.apache.iceberg.encryption.PlaintextEncryptionManager; -import org.apache.iceberg.hadoop.HadoopFileIO; import org.apache.iceberg.io.CloseableIterator; import org.apache.iceberg.io.FileAppender; import org.apache.iceberg.parquet.Parquet; @@ -59,7 +56,7 @@ public abstract class TestSparkBaseDataReader { @Rule public TemporaryFolder temp = new TemporaryFolder(); - private static final Configuration CONFD = new Configuration(); + private Table table; // Simulates the closeable iterator of data to be read private static class CloseableIntegerRange implements CloseableIterator { @@ -92,10 +89,8 @@ public Integer next() { private static class ClosureTrackingReader extends BaseDataReader { private Map tracker = new HashMap<>(); - ClosureTrackingReader(List tasks) { - super(new BaseCombinedScanTask(tasks), - new HadoopFileIO(CONFD), - new PlaintextEncryptionManager()); + ClosureTrackingReader(Table table, List tasks) { + super(table, new BaseCombinedScanTask(tasks)); } @Override @@ -124,7 +119,7 @@ public void testClosureOnDataExhaustion() throws IOException { Integer recordPerTask = 10; List tasks = 
createFileScanTasks(totalTasks, recordPerTask); - ClosureTrackingReader reader = new ClosureTrackingReader(tasks); + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); int countRecords = 0; while (reader.next()) { @@ -151,7 +146,7 @@ public void testClosureDuringIteration() throws IOException { FileScanTask firstTask = tasks.get(0); FileScanTask secondTask = tasks.get(1); - ClosureTrackingReader reader = new ClosureTrackingReader(tasks); + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); // Total of 2 elements Assert.assertTrue(reader.next()); @@ -175,7 +170,7 @@ public void testClosureWithoutAnyRead() throws IOException { Integer recordPerTask = 10; List tasks = createFileScanTasks(totalTasks, recordPerTask); - ClosureTrackingReader reader = new ClosureTrackingReader(tasks); + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); reader.close(); @@ -191,7 +186,7 @@ public void testExplicitClosure() throws IOException { Integer recordPerTask = 10; List tasks = createFileScanTasks(totalTasks, recordPerTask); - ClosureTrackingReader reader = new ClosureTrackingReader(tasks); + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); Integer halfDataSize = (totalTasks * recordPerTask) / 2; for (int i = 0; i < halfDataSize; i++) { @@ -217,7 +212,7 @@ public void testIdempotentExplicitClosure() throws IOException { Integer recordPerTask = 10; List tasks = createFileScanTasks(totalTasks, recordPerTask); - ClosureTrackingReader reader = new ClosureTrackingReader(tasks); + ClosureTrackingReader reader = new ClosureTrackingReader(table, tasks); // Total 100 elements, only 5 iterators have been created for (int i = 0; i < 45; i++) { @@ -250,7 +245,7 @@ private List createFileScanTasks(Integer totalTasks, Integer recor ); try { - Table table = TestTables.create(location, desc, schema, PartitionSpec.unpartitioned()); + this.table = TestTables.create(location, desc, schema, PartitionSpec.unpartitioned()); // Important: use the table's schema for the rest of the test // When tables are created, the column ids are reassigned. Schema tableSchema = table.schema(); diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java index 633a17143f52..483dbcfc5b2f 100644 --- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java +++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -149,7 +149,7 @@ private Schema schemaWithMetadataColumns() { // metadata columns List fields = metaColumns.stream() .distinct() - .map(MetadataColumns::get) + .map(name -> MetadataColumns.metadataColumn(table, name)) .collect(Collectors.toList()); Schema meta = new Schema(fields); diff --git a/spark3/src/test/java/org/apache/iceberg/spark/source/SparkTestTable.java b/spark3/src/test/java/org/apache/iceberg/spark/source/SparkTestTable.java new file mode 100644 index 000000000000..afb1136f4fa5 --- /dev/null +++ b/spark3/src/test/java/org/apache/iceberg/spark/source/SparkTestTable.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.source; + +import org.apache.iceberg.MetadataColumns; +import org.apache.iceberg.Table; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +// TODO: remove this class once we compile against Spark 3.2 +public class SparkTestTable extends SparkTable { + + private final String[] metadataColumnNames; + + public SparkTestTable(Table icebergTable, String[] metadataColumnNames, boolean refreshEagerly) { + super(icebergTable, refreshEagerly); + this.metadataColumnNames = metadataColumnNames; + } + + @Override + public StructType schema() { + StructType schema = super.schema(); + if (metadataColumnNames != null) { + for (String columnName : metadataColumnNames) { + Types.NestedField metadataColumn = MetadataColumns.metadataColumn(table(), columnName); + schema = schema.add(columnName, SparkSchemaUtil.convert(metadataColumn.type())); + } + } + return schema; + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + SparkScanBuilder scanBuilder = (SparkScanBuilder) super.newScanBuilder(options); + if (metadataColumnNames != null) { + scanBuilder.withMetadataColumns(metadataColumnNames); + } + return scanBuilder; + } +} diff --git a/spark3/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java b/spark3/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java new file mode 100644 index 000000000000..ea9818cae9d9 --- /dev/null +++ b/spark3/src/test/java/org/apache/iceberg/spark/source/TestMetadataTablesWithPartitionEvolution.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.MetadataTableType; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkCatalog; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.SparkSessionCatalog; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.parser.ParseException; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructType; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import static org.apache.iceberg.FileFormat.AVRO; +import static org.apache.iceberg.FileFormat.ORC; +import static org.apache.iceberg.FileFormat.PARQUET; +import static org.apache.iceberg.MetadataTableType.ALL_DATA_FILES; +import static org.apache.iceberg.MetadataTableType.ALL_ENTRIES; +import static org.apache.iceberg.MetadataTableType.ENTRIES; +import static org.apache.iceberg.MetadataTableType.FILES; +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; + +@RunWith(Parameterized.class) +public class TestMetadataTablesWithPartitionEvolution extends SparkCatalogTestBase { + + @Parameters(name = "catalog = {0}, impl = {1}, conf = {2}, fileFormat = {3}, formatVersion = {4}") + public static Object[][] parameters() { + return new Object[][] { + { "testhive", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default" + ), + ORC, + formatVersion() + }, + { "testhadoop", SparkCatalog.class.getName(), + ImmutableMap.of( + "type", "hadoop" + ), + PARQUET, + formatVersion() + }, + { "spark_catalog", SparkSessionCatalog.class.getName(), + ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "clients", "1", + "parquet-enabled", "false", + "cache-enabled", "false" // Spark will delete tables using v1, leaving the cache out of sync + ), + AVRO, + formatVersion() + } + }; + } + + private static int formatVersion() { + return RANDOM.nextInt(2) + 1; + } + + private static final Random RANDOM = ThreadLocalRandom.current(); + + private final FileFormat fileFormat; + private final int formatVersion; + + public TestMetadataTablesWithPartitionEvolution(String catalogName, String implementation, Map config, + FileFormat fileFormat, int formatVersion) { + super(catalogName, implementation, config); + this.fileFormat = fileFormat; + this.formatVersion = formatVersion; + } + + @After + public void removeTable() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testFilesMetadataTable() throws ParseException { + sql("CREATE TABLE %s (id bigint 
NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + Dataset df = loadMetadataTable(tableType); + Assert.assertTrue("Partition must be skipped", df.schema().getFieldIndex("partition").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(new Object[]{null}), row("b1")), + "STRUCT", + tableType); + } + + table.updateSpec() + .addField(Expressions.bucket("category", 8)) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .renameField("category_bucket_8", "category_bucket_8_another_name") + .commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + + @Test + public void testEntriesMetadataTable() throws ParseException { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables while the current spec is still unpartitioned + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + Dataset df = loadMetadataTable(tableType); + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + Assert.assertTrue("Partition must be skipped", dataFileType.getFieldIndex("").isEmpty()); + } + + Table table = validationCatalog.loadTable(tableIdent); + + table.updateSpec() + .addField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after adding the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(new Object[]{null}), row("b1")), + "STRUCT", + tableType); + } + + table.updateSpec() + .addField(Expressions.bucket("category", 8)) + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 
'b1')", tableName); + + // verify the metadata tables after adding the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .removeField("data") + .commit(); + sql("REFRESH TABLE %s", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + // verify the metadata tables after dropping the first partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + + table.updateSpec() + .renameField("category_bucket_8", "category_bucket_8_another_name") + .commit(); + sql("REFRESH TABLE %s", tableName); + + // verify the metadata tables after renaming the second partition column + for (MetadataTableType tableType : Arrays.asList(ENTRIES, ALL_ENTRIES)) { + assertPartitions( + ImmutableList.of(row(null, null), row(null, 2), row("b1", null), row("b1", 2)), + "STRUCT", + tableType); + } + } + + @Test + public void testMetadataTablesWithUnknownTransforms() { + sql("CREATE TABLE %s (id bigint NOT NULL, category string, data string) USING iceberg", tableName); + initTable(); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", tableName); + + Table table = validationCatalog.loadTable(tableIdent); + + PartitionSpec unknownSpec = PartitionSpecParser.fromJson(table.schema(), + "{ \"spec-id\": 1, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, base.updatePartitionSpec(unknownSpec)); + + sql("REFRESH TABLE %s", tableName); + + for (MetadataTableType tableType : Arrays.asList(FILES, ALL_DATA_FILES, ENTRIES, ALL_ENTRIES)) { + AssertHelpers.assertThrows("Should complain about the partition type", + ValidationException.class, "Cannot build table partition type, unknown transforms", + () -> loadMetadataTable(tableType)); + } + } + + private void assertPartitions(List expectedPartitions, String expectedTypeAsString, + MetadataTableType tableType) throws ParseException { + Dataset df = loadMetadataTable(tableType); + + DataType expectedType = spark.sessionState().sqlParser().parseDataType(expectedTypeAsString); + switch (tableType) { + case FILES: + case ALL_DATA_FILES: + DataType actualFilesType = df.schema().apply("partition").dataType(); + Assert.assertEquals("Partition type must match", expectedType, actualFilesType); + break; + + case ENTRIES: + case ALL_ENTRIES: + StructType dataFileType = (StructType) df.schema().apply("data_file").dataType(); + DataType actualEntriesType = dataFileType.apply("partition").dataType(); + Assert.assertEquals("Partition type must match", expectedType, actualEntriesType); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + + switch (tableType) { + case FILES: + case ALL_DATA_FILES: + List actualFilesPartitions = df.orderBy("partition") + .select("partition.*") + .collectAsList(); + assertEquals("Partitions must match", expectedPartitions, rowsToJava(actualFilesPartitions)); + break; + + case ENTRIES: + case ALL_ENTRIES: + List actualEntriesPartitions = df.orderBy("data_file.partition") + 
.select("data_file.partition.*") + .collectAsList(); + assertEquals("Partitions must match", expectedPartitions, rowsToJava(actualEntriesPartitions)); + break; + + default: + throw new UnsupportedOperationException("Unsupported metadata table type: " + tableType); + } + } + + private Dataset loadMetadataTable(MetadataTableType tableType) { + return spark.read().format("iceberg").load(tableName + "." + tableType.name()); + } + + private void initTable() { + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%s')", tableName, DEFAULT_FILE_FORMAT, fileFormat.name()); + sql("ALTER TABLE %s SET TBLPROPERTIES('%s' '%d')", tableName, FORMAT_VERSION, formatVersion); + } +} diff --git a/spark3/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java b/spark3/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java index 92013d396c1a..027c88cd4df6 100644 --- a/spark3/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java +++ b/spark3/src/test/java/org/apache/iceberg/spark/source/TestSparkCatalog.java @@ -19,7 +19,6 @@ package org.apache.iceberg.spark.source; -import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkSessionCatalog; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.connector.catalog.Identifier; @@ -31,7 +30,14 @@ public class TestSparkCatalog exten @Override public Table loadTable(Identifier ident) throws NoSuchTableException { - TestTables.TestTable table = TestTables.load(Spark3Util.identifierToTableIdentifier(ident).toString()); - return new SparkTable(table, false); + String[] parts = ident.name().split("\\$", 2); + if (parts.length == 2) { + TestTables.TestTable table = TestTables.load(parts[0]); + String[] metadataColumns = parts[1].split(","); + return new SparkTestTable(table, metadataColumns, false); + } else { + TestTables.TestTable table = TestTables.load(ident.name()); + return new SparkTestTable(table, null, false); + } } } diff --git a/spark3/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java b/spark3/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java new file mode 100644 index 000000000000..b29d281863cb --- /dev/null +++ b/spark3/src/test/java/org/apache/iceberg/spark/source/TestSparkMetadataColumns.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import java.io.IOException; +import java.util.List; +import org.apache.iceberg.AssertHelpers; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.HasTableOperations; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.PartitionSpecParser; +import org.apache.iceberg.Schema; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableOperations; +import org.apache.iceberg.UpdateProperties; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.expressions.Expressions; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.spark.SparkTestBase; +import org.apache.iceberg.types.Types; +import org.junit.After; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.apache.iceberg.TableProperties.DEFAULT_FILE_FORMAT; +import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.ORC_VECTORIZATION_ENABLED; +import static org.apache.iceberg.TableProperties.PARQUET_VECTORIZATION_ENABLED; + +@RunWith(Parameterized.class) +public class TestSparkMetadataColumns extends SparkTestBase { + + private static final String TABLE_NAME = "test_table"; + private static final Schema SCHEMA = new Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "category", Types.StringType.get()), + Types.NestedField.optional(3, "data", Types.StringType.get()) + ); + private static final PartitionSpec UNKNOWN_SPEC = PartitionSpecParser.fromJson(SCHEMA, + "{ \"spec-id\": 1, \"fields\": [ { \"name\": \"id_zero\", \"transform\": \"zero\", \"source-id\": 1 } ] }"); + + @Parameterized.Parameters(name = "fileFormat = {0}, vectorized = {1}, formatVersion = {2}") + public static Object[][] parameters() { + return new Object[][] { + { FileFormat.PARQUET, false, 1}, + { FileFormat.PARQUET, true, 1}, + { FileFormat.PARQUET, false, 2}, + { FileFormat.PARQUET, true, 2}, + { FileFormat.AVRO, false, 1}, + { FileFormat.AVRO, false, 2}, + { FileFormat.ORC, false, 1}, + { FileFormat.ORC, true, 1}, + { FileFormat.ORC, false, 2}, + { FileFormat.ORC, true, 2}, + }; + } + + @Rule + public TemporaryFolder temp = new TemporaryFolder(); + + private final FileFormat fileFormat; + private final boolean vectorized; + private final int formatVersion; + + private Table table = null; + + public TestSparkMetadataColumns(FileFormat fileFormat, boolean vectorized, int formatVersion) { + this.fileFormat = fileFormat; + this.vectorized = vectorized; + this.formatVersion = formatVersion; + } + + @BeforeClass + public static void setupSpark() { + ImmutableMap config = ImmutableMap.of( + "type", "hive", + "default-namespace", "default", + "cache-enabled", "true" + ); + spark.conf().set("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.source.TestSparkCatalog"); + config.forEach((key, value) -> spark.conf().set("spark.sql.catalog.spark_catalog." 
+ key, value)); + } + + @Before + public void setupTable() throws IOException { + createAndInitTable(); + } + + @After + public void dropTable() { + TestTables.clearTables(); + } + + // TODO: remove testing workarounds once we compile against Spark 3.2 + + @Test + public void testSpecAndPartitionMetadataColumns() { + // TODO: support metadata structs in vectorized ORC reads + Assume.assumeFalse(fileFormat == FileFormat.ORC && vectorized); + + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .addField("data") + .commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .addField(Expressions.bucket("category", 8)) + .commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .removeField("data") + .commit(); + sql("INSERT INTO TABLE %s VALUES (1, 'a1', 'b1')", TABLE_NAME); + + table.refresh(); + table.updateSpec() + .renameField("category_bucket_8", "category_bucket_8_another_name") + .commit(); + + List expected = ImmutableList.of( + row(0, row(null, null)), + row(1, row("b1", null)), + row(2, row("b1", 2)), + row(3, row(null, 2)) + ); + assertEquals("Rows must match", expected, + sql("SELECT _spec_id, _partition FROM `%s$_spec_id,_partition` ORDER BY _spec_id", TABLE_NAME)); + } + + @Test + public void testPartitionMetadataColumnWithUnknownTransforms() { + // replace the table spec to include an unknown transform + TableOperations ops = ((HasTableOperations) table).operations(); + TableMetadata base = ops.current(); + ops.commit(base, base.updatePartitionSpec(UNKNOWN_SPEC)); + + AssertHelpers.assertThrows("Should fail to query the partition metadata column", + ValidationException.class, "Cannot build table partition type, unknown transforms", + () -> sql("SELECT _partition FROM `%s$_partition`", TABLE_NAME)); + } + + private void createAndInitTable() throws IOException { + this.table = TestTables.create(temp.newFolder(), TABLE_NAME, SCHEMA, PartitionSpec.unpartitioned()); + + UpdateProperties updateProperties = table.updateProperties(); + updateProperties.set(FORMAT_VERSION, String.valueOf(formatVersion)); + updateProperties.set(DEFAULT_FILE_FORMAT, fileFormat.name()); + + switch (fileFormat) { + case PARQUET: + updateProperties.set(PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + case ORC: + updateProperties.set(ORC_VECTORIZATION_ENABLED, String.valueOf(vectorized)); + break; + default: + Preconditions.checkState(!vectorized, "File format %s does not support vectorized reads", fileFormat); + } + + updateProperties.commit(); + } +}
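// Usage sketch, not part of the patch, assuming the TestSparkCatalog/SparkTestTable
// workaround configured in TestSparkMetadataColumns above: until Iceberg compiles
// against Spark 3.2's metadata-column API, the test catalog splits the identifier on
// '$' and exposes the listed metadata columns, so _spec_id and _partition can be
// selected like ordinary columns.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class MetadataColumnQuerySketch {
  // "test_table" is the TestTables table created in the test; the suffix after '$'
  // names the metadata columns TestSparkCatalog should attach to the scan.
  static Dataset<Row> specAndPartition(SparkSession spark) {
    return spark.sql(
        "SELECT _spec_id, _partition FROM `test_table$_spec_id,_partition` ORDER BY _spec_id");
  }
}
// The returned _partition struct is typed by Partitioning.partitionType, so its fields
// span every partition spec the table has had, with nulls for fields a file's spec does
// not populate, as exercised by testSpecAndPartitionMetadataColumns above.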