diff --git a/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java b/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java index 5ac0b4cfe2c28..4edda3db903a0 100644 --- a/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java +++ b/hudi-common/src/main/java/org/apache/hudi/avro/HoodieAvroUtils.java @@ -18,6 +18,7 @@ package org.apache.hudi.avro; +import java.util.Arrays; import org.apache.hudi.common.config.SerializableSchema; import org.apache.hudi.common.model.HoodieOperation; import org.apache.hudi.common.model.HoodieRecord; @@ -27,6 +28,7 @@ import org.apache.hudi.common.util.collection.Pair; import org.apache.hudi.exception.HoodieException; import org.apache.hudi.exception.HoodieIOException; +import org.apache.hudi.exception.HoodieValidationException; import org.apache.hudi.exception.SchemaCompatibilityException; import org.apache.avro.AvroRuntimeException; @@ -79,7 +81,6 @@ import java.util.TimeZone; import java.util.stream.Collectors; -import static org.apache.avro.Schema.Type.UNION; import static org.apache.hudi.avro.AvroSchemaUtils.createNullableSchema; import static org.apache.hudi.avro.AvroSchemaUtils.resolveNullableSchema; import static org.apache.hudi.avro.AvroSchemaUtils.resolveUnionSchema; @@ -108,6 +109,8 @@ public class HoodieAvroUtils { public static final Schema RECORD_KEY_SCHEMA = initRecordKeySchema(); + private static final String FIELD_LOCATION_DELIMITER = "\\."; + /** * Convert a given avro record to bytes. */ @@ -1033,7 +1036,7 @@ public static int fromJavaDate(Date date) { private static Schema getActualSchemaFromUnion(Schema schema, Object data) { Schema actualSchema; - if (!schema.getType().equals(UNION)) { + if (!schema.getType().equals(Schema.Type.UNION)) { return schema; } if (schema.getTypes().size() == 2 @@ -1064,4 +1067,62 @@ public static boolean gteqAvro1_9() { public static boolean gteqAvro1_10() { return VersionUtil.compareVersions(AVRO_VERSION, "1.10") >= 0; } + + /** + * Given an Avro schema, this method will return the field specified by the path parameter. + * The fieldLocation parameter is an ordered string specifying the location of the nested field to retrieve. + * For example, field1.nestedField1 takes field "field1", and retrieves "nestedField1" from it. + * @param schema is the record to retrieve the schema from + * @param fieldLocation is the location of the field + * @return the field + */ + public static Option getField(Schema schema, String fieldLocation) { + if (schema == null || fieldLocation == null || fieldLocation.isEmpty()) { + return Option.empty(); + } + + List pathList = Arrays.asList(fieldLocation.split(FIELD_LOCATION_DELIMITER)); + pathList.stream() + .filter(pl -> pl.trim().isEmpty()) + .findAny() + .ifPresent(f -> { + throw new HoodieValidationException("Invalid fieldLocation: " + fieldLocation); + }); + if (pathList.size() == 0) { + return Option.empty(); + } + + return getFieldHelper(schema, pathList, 0); + } + + /** + * Helper method that does the actual work for {@link #getField(Schema, String)} by recursively finding the required field. + * + * @param schema top level schema to be evaluated on + * @param pathList field to find, must be built in traversal order, from parent to child. + * @param field keeps track of the index used to access the list pathList + * @return the field + */ + private static Option getFieldHelper(Schema schema, List pathList, int field) { + Field curField = schema.getField(pathList.get(field)); + Schema fieldSchema = curField.schema(); + + if (pathList.size() == field + 1 && curField.name().equals(pathList.get(field))) { + return Option.of(curField); + } + + switch (fieldSchema.getType()) { + case UNION: + // assume UNION is strictly nullable + return getFieldHelper(resolveNullableSchema(fieldSchema), pathList, ++field); + case MAP: + return getFieldHelper(fieldSchema.getValueType(), pathList, ++field); + case RECORD: + return getFieldHelper(fieldSchema, pathList, ++field); + case ARRAY: + return getFieldHelper(fieldSchema.getElementType(), pathList, ++field); + default: + return Option.empty(); + } + } } diff --git a/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java b/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java index 896843b58f28c..790562c9fcc70 100644 --- a/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java +++ b/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java @@ -18,9 +18,12 @@ package org.apache.hudi.avro; +import org.apache.avro.Schema.Field; import org.apache.hudi.common.model.HoodieRecord; import org.apache.hudi.common.testutils.SchemaTestUtil; +import org.apache.hudi.common.util.Option; import org.apache.hudi.exception.HoodieException; +import org.apache.hudi.exception.HoodieValidationException; import org.apache.hudi.exception.SchemaCompatibilityException; import org.apache.avro.AvroRuntimeException; @@ -43,6 +46,7 @@ import static org.apache.hudi.avro.HoodieAvroUtils.getNestedFieldSchemaFromWriteSchema; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -426,4 +430,30 @@ public void testConvertDaysToDate() { int days = HoodieAvroUtils.fromJavaDate(now); assertEquals(now.toLocalDate(), HoodieAvroUtils.toJavaDate(days).toLocalDate()); } + + @Test + public void testGetField() { + Schema nestedSchema = new Schema.Parser().parse(SCHEMA_WITH_NESTED_FIELD); + + // empty schema should return empty option + Option nullSchemaTest = HoodieAvroUtils.getField(null, "nestedField"); + assertFalse(nullSchemaTest.isPresent()); + + // null fieldLocation should return empty option + Option nullFieldLocationTest = HoodieAvroUtils.getField(nestedSchema, null); + assertFalse(nullFieldLocationTest.isPresent()); + + // empty fieldLocation should return empty option + Option emptyFieldLocationTest = HoodieAvroUtils.getField(nestedSchema, ""); + assertFalse(emptyFieldLocationTest.isPresent()); + + // invalid fieldLocation should throw error + assertThrows(HoodieValidationException.class, () -> HoodieAvroUtils.getField(nestedSchema, ".firstname")); + + Option topLevelFieldTest = HoodieAvroUtils.getField(nestedSchema, "firstname"); + assertTrue(topLevelFieldTest.isPresent()); + + Option nestedFieldTest = HoodieAvroUtils.getField(nestedSchema, "student.lastname"); + assertTrue(nestedFieldTest.isPresent()); + } } diff --git a/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/table/HoodieTableFactory.java b/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/table/HoodieTableFactory.java index 1cf66ea3437ef..753304baec1b4 100644 --- a/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/table/HoodieTableFactory.java +++ b/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/table/HoodieTableFactory.java @@ -18,6 +18,8 @@ package org.apache.hudi.table; +import org.apache.avro.Schema; +import org.apache.hudi.avro.HoodieAvroUtils; import org.apache.hudi.common.model.DefaultHoodieRecordPayload; import org.apache.hudi.common.util.StringUtils; import org.apache.hudi.configuration.FlinkOptions; @@ -120,6 +122,7 @@ public Set> optionalOptions() { */ private void sanityCheck(Configuration conf, ResolvedSchema schema) { List fields = schema.getColumnNames(); + Schema inferredSchema = AvroSchemaConverter.convertToSchema(schema.toPhysicalRowDataType().notNull().getLogicalType()); // validate record key in pk absence. if (!schema.getPrimaryKey().isPresent()) { @@ -132,7 +135,7 @@ private void sanityCheck(Configuration conf, ResolvedSchema schema) { } Arrays.stream(recordKeys) - .filter(field -> !fields.contains(field)) + .filter(field -> !HoodieAvroUtils.getField(inferredSchema, field).isPresent()) .findAny() .ifPresent(f -> { throw new HoodieValidationException("Field '" + f + "' specified in option " @@ -142,7 +145,7 @@ private void sanityCheck(Configuration conf, ResolvedSchema schema) { // validate pre_combine key String preCombineField = conf.get(FlinkOptions.PRECOMBINE_FIELD); - if (!fields.contains(preCombineField)) { + if (!HoodieAvroUtils.getField(inferredSchema, preCombineField).isPresent()) { if (OptionsResolver.isDefaultHoodieRecordPayloadClazz(conf)) { throw new HoodieValidationException("Option '" + FlinkOptions.PRECOMBINE_FIELD.key() + "' is required for payload class: " + DefaultHoodieRecordPayload.class.getName()); diff --git a/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/table/TestHoodieTableFactory.java b/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/table/TestHoodieTableFactory.java index f7a35e57f2b09..728ca63b1c02c 100644 --- a/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/table/TestHoodieTableFactory.java +++ b/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/table/TestHoodieTableFactory.java @@ -165,6 +165,35 @@ void testRequiredOptionsForSource() { assertDoesNotThrow(() -> new HoodieTableFactory().createDynamicTableSource(sourceContext6)); assertDoesNotThrow(() -> new HoodieTableFactory().createDynamicTableSink(sourceContext6)); + + // nested pk field is allowed + ResolvedSchema schema6 = SchemaBuilder.instance() + .field("f0", + DataTypes.ROW(DataTypes.FIELD("id", DataTypes.INT()), DataTypes.FIELD("date", DataTypes.VARCHAR(20)))) + .field("f1", DataTypes.VARCHAR(20)) + .field("f2", DataTypes.TIMESTAMP(3)) + .field("ts", DataTypes.TIMESTAMP(3)) + .build(); + this.conf.setString(FlinkOptions.RECORD_KEY_FIELD, "f0.id"); + this.conf.setString(FlinkOptions.PRECOMBINE_FIELD, "f2"); + final MockContext sourceContext7 = MockContext.getInstance(this.conf, schema6, "f2"); + + assertDoesNotThrow(() -> new HoodieTableFactory().createDynamicTableSource(sourceContext7)); + assertDoesNotThrow(() -> new HoodieTableFactory().createDynamicTableSink(sourceContext7)); + + // nested precombine field is allowed + ResolvedSchema schema7 = SchemaBuilder.instance() + .field("f0", DataTypes.INT().notNull()) + .field("f1", DataTypes.VARCHAR(20)) + .field("f2", DataTypes.TIMESTAMP(3)) + .field("ts", DataTypes.ROW(DataTypes.FIELD("year", DataTypes.INT()), DataTypes.FIELD("MONTH", DataTypes.INT()))) + .build(); + this.conf.setString(FlinkOptions.RECORD_KEY_FIELD, "f0"); + this.conf.setString(FlinkOptions.PRECOMBINE_FIELD, "f2.year"); + final MockContext sourceContext8 = MockContext.getInstance(this.conf, schema7, "f2"); + + assertDoesNotThrow(() -> new HoodieTableFactory().createDynamicTableSource(sourceContext8)); + assertDoesNotThrow(() -> new HoodieTableFactory().createDynamicTableSink(sourceContext8)); } @Test