[HUDI-4904] Add support for unraveling proto schemas in ProtoClassBasedSchemaProvider #6761
Changes from all commits: 9c1fa14, 510d525, aad9ec1, 8899274, a922a5b
Diff for `ProtoConversionUtil.java`:

```diff
@@ -37,6 +37,7 @@
 import org.apache.avro.generic.GenericFixed;
 import org.apache.avro.generic.GenericRecord;
 import org.apache.avro.util.Utf8;
+import org.apache.kafka.common.utils.CopyOnWriteMap;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
@@ -45,6 +46,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
```
```diff
@@ -57,10 +59,11 @@ public class ProtoConversionUtil {
    * Creates an Avro {@link Schema} for the provided class. Assumes that the class is a protobuf {@link Message}.
    * @param clazz The protobuf class
    * @param flattenWrappedPrimitives set to true to treat wrapped primitives like nullable fields instead of nested messages.
+   * @param maxRecursionDepth the number of times to unravel a recursive proto schema before spilling the rest to bytes
    * @return An Avro schema
    */
-  public static Schema getAvroSchemaForMessageClass(Class clazz, boolean flattenWrappedPrimitives) {
-    return AvroSupport.get().getSchema(clazz, flattenWrappedPrimitives);
+  public static Schema getAvroSchemaForMessageClass(Class clazz, boolean flattenWrappedPrimitives, int maxRecursionDepth) {
+    return AvroSupport.get().getSchema(clazz, flattenWrappedPrimitives, maxRecursionDepth);
   }
 
   /**
```
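As a usage sketch (not part of the diff), a caller might now pass the recursion cap alongside the existing flag. `SamplePb` below is a hypothetical generated protobuf class, and the import path for `ProtoConversionUtil` is assumed from its location in hudi-utilities:

```java
import org.apache.avro.Schema;
// Assumed package for the utility under review.
import org.apache.hudi.utilities.sources.helpers.ProtoConversionUtil;

public class ProtoSchemaDemo {
  public static void main(String[] args) {
    // SamplePb is a hypothetical generated protobuf class used for illustration.
    // Unravel recursive definitions at most 2 levels deep; anything deeper is
    // spilled into the recursion_overflow record as raw proto bytes.
    Schema schema = ProtoConversionUtil.getAvroSchemaForMessageClass(SamplePb.class, true, 2);
    System.out.println(schema.toString(true));
  }
}
```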
```diff
@@ -80,17 +83,19 @@ public static GenericRecord convertToAvro(Schema schema, Message message) {
    * 2. Convert directly from a protobuf {@link Message} to a {@link GenericRecord} while properly handling enums and wrapped primitives mentioned above.
    */
   private static class AvroSupport {
+    private static final Schema STRING_SCHEMA = Schema.create(Schema.Type.STRING);
+    private static final Schema NULL_SCHEMA = Schema.create(Schema.Type.NULL);
+    private static final String OVERFLOW_DESCRIPTOR_FIELD_NAME = "descriptor_full_name";
+    private static final String OVERFLOW_BYTES_FIELD_NAME = "proto_bytes";
+    private static final Schema RECURSION_OVERFLOW_SCHEMA = Schema.createRecord("recursion_overflow", null, "org.apache.hudi.proto", false,
+        Arrays.asList(new Schema.Field(OVERFLOW_DESCRIPTOR_FIELD_NAME, STRING_SCHEMA, null, ""),
+            new Schema.Field(OVERFLOW_BYTES_FIELD_NAME, Schema.create(Schema.Type.BYTES), null, "".getBytes())));
     private static final AvroSupport INSTANCE = new AvroSupport();
     // A cache of the proto class name paired with whether wrapped primitives should be flattened as the key and the generated avro schema as the value
-    private static final Map<Pair<Class, Boolean>, Schema> SCHEMA_CACHE = new ConcurrentHashMap<>();
+    private static final Map<SchemaCacheKey, Schema> SCHEMA_CACHE = new ConcurrentHashMap<>();
     // A cache with a key as the pair target avro schema and the proto descriptor for the source and the value as an array of proto field descriptors where the order matches the avro ordering.
     // When converting from proto to avro, we want to be able to iterate over the fields in the proto in the same order as they appear in the avro schema.
     private static final Map<Pair<Schema, Descriptors.Descriptor>, Descriptors.FieldDescriptor[]> FIELD_CACHE = new ConcurrentHashMap<>();
 
-    private static final Schema STRINGS = Schema.create(Schema.Type.STRING);
-
-    private static final Schema NULL = Schema.create(Schema.Type.NULL);
     private static final Map<Descriptors.Descriptor, Schema.Type> WRAPPER_DESCRIPTORS_TO_TYPE = getWrapperDescriptorsToType();
 
     private static Map<Descriptors.Descriptor, Schema.Type> getWrapperDescriptorsToType() {
@@ -118,14 +123,15 @@ public GenericRecord convert(Schema schema, Message message) {
       return (GenericRecord) convertObject(schema, message);
     }
 
-    public Schema getSchema(Class c, boolean flattenWrappedPrimitives) {
-      return SCHEMA_CACHE.computeIfAbsent(Pair.of(c, flattenWrappedPrimitives), key -> {
+    public Schema getSchema(Class c, boolean flattenWrappedPrimitives, int maxRecursionDepth) {
+      return SCHEMA_CACHE.computeIfAbsent(new SchemaCacheKey(c, flattenWrappedPrimitives, maxRecursionDepth), key -> {
         try {
           Object descriptor = c.getMethod("getDescriptor").invoke(null);
           if (c.isEnum()) {
             return getEnumSchema((Descriptors.EnumDescriptor) descriptor);
           } else {
-            return getMessageSchema((Descriptors.Descriptor) descriptor, new HashMap<>(), flattenWrappedPrimitives);
+            Descriptors.Descriptor castedDescriptor = (Descriptors.Descriptor) descriptor;
+            return getMessageSchema(castedDescriptor, new CopyOnWriteMap<>(), flattenWrappedPrimitives, getNamespace(castedDescriptor.getFullName()), maxRecursionDepth);
           }
         } catch (Exception e) {
           throw new RuntimeException(e);
```
```diff
@@ -141,24 +147,40 @@ private Schema getEnumSchema(Descriptors.EnumDescriptor enumDescriptor) {
       return Schema.createEnum(enumDescriptor.getName(), null, getNamespace(enumDescriptor.getFullName()), symbols);
     }
 
-    private Schema getMessageSchema(Descriptors.Descriptor descriptor, Map<Descriptors.Descriptor, Schema> seen, boolean flattenWrappedPrimitives) {
-      if (seen.containsKey(descriptor)) {
-        return seen.get(descriptor);
+    /**
+     * Translates a Proto Message descriptor into an Avro Schema
+     * @param descriptor the descriptor for the proto message
+     * @param recursionDepths a map of the descriptor to the number of times it has been encountered in this depth first traversal of the schema.
+     *                        This is used to cap the number of times we recurse on a schema.
+     * @param flattenWrappedPrimitives if true, treat wrapped primitives as nullable primitives, if false, treat them as proto messages
+     * @param path a string prefixed with the namespace of the original message being translated to avro and containing the current dot separated path tracking progress through the schema.
+     *             This value is used for a namespace when creating Avro records to avoid an error when reusing the same class name when unraveling a recursive schema.
+     * @param maxRecursionDepth the number of times to unravel a recursive proto schema before spilling the rest to bytes
+     * @return an avro schema
+     */
+    private Schema getMessageSchema(Descriptors.Descriptor descriptor, CopyOnWriteMap<Descriptors.Descriptor, Integer> recursionDepths, boolean flattenWrappedPrimitives, String path,
+                                    int maxRecursionDepth) {
+      // Parquet does not handle recursive schemas so we "unravel" the proto N levels
+      Integer currentRecursionCount = recursionDepths.getOrDefault(descriptor, 0);
+      if (currentRecursionCount >= maxRecursionDepth) {
+        return RECURSION_OVERFLOW_SCHEMA;
       }
-      Schema result = Schema.createRecord(descriptor.getName(), null,
-          getNamespace(descriptor.getFullName()), false);
+      // The current path is used as a namespace to avoid record name collisions within recursive schemas
+      Schema result = Schema.createRecord(descriptor.getName(), null, path, false);
 
-      seen.put(descriptor, result);
+      recursionDepths.put(descriptor, ++currentRecursionCount);
 
       List<Schema.Field> fields = new ArrayList<>(descriptor.getFields().size());
       for (Descriptors.FieldDescriptor f : descriptor.getFields()) {
-        fields.add(new Schema.Field(f.getName(), getFieldSchema(f, seen, flattenWrappedPrimitives), null, getDefault(f)));
+        // each branch of the schema traversal requires its own recursion depth tracking so copy the recursionDepths map
+        fields.add(new Schema.Field(f.getName(), getFieldSchema(f, new CopyOnWriteMap<>(recursionDepths), flattenWrappedPrimitives, path, maxRecursionDepth), null, getDefault(f)));
```
A review thread is attached to the line above:

> **Contributor:** not too strong on the suggestion. will leave it to you to decide.
>
> **Contributor (author):** I think this will make the code harder to read. You would need to create copies of all entries related to your current path, add the current field name to the path for your new keys, and then re-look up that path. I would prefer to leave this portion as is.
>
> **Contributor:** sg
```diff
       }
       result.setFields(fields);
       return result;
     }
 
-    private Schema getFieldSchema(Descriptors.FieldDescriptor f, Map<Descriptors.Descriptor, Schema> seen, boolean flattenWrappedPrimitives) {
+    private Schema getFieldSchema(Descriptors.FieldDescriptor f, CopyOnWriteMap<Descriptors.Descriptor, Integer> recursionDepths, boolean flattenWrappedPrimitives, String path,
+                                  int maxRecursionDepth) {
       Function<Schema, Schema> schemaFinalizer = f.isRepeated() ? Schema::createArray : Function.identity();
       switch (f.getType()) {
         case BOOL:
@@ -188,16 +210,18 @@ private Schema getFieldSchema(Descriptors.FieldDescriptor f, Map<Descriptors.Des
         case SFIXED64:
           return schemaFinalizer.apply(Schema.create(Schema.Type.LONG));
         case MESSAGE:
+          String updatedPath = appendFieldNameToPath(path, f.getName());
           if (flattenWrappedPrimitives && WRAPPER_DESCRIPTORS_TO_TYPE.containsKey(f.getMessageType())) {
             // all wrapper types have a single field, so we can get the first field in the message's schema
-            return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL, getFieldSchema(f.getMessageType().getFields().get(0), seen, flattenWrappedPrimitives))));
+            return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL_SCHEMA, getFieldSchema(f.getMessageType().getFields().get(0), recursionDepths, flattenWrappedPrimitives, updatedPath,
+                maxRecursionDepth))));
           }
           // if message field is repeated (like a list), elements are non-null
           if (f.isRepeated()) {
-            return schemaFinalizer.apply(getMessageSchema(f.getMessageType(), seen, flattenWrappedPrimitives));
+            return schemaFinalizer.apply(getMessageSchema(f.getMessageType(), recursionDepths, flattenWrappedPrimitives, updatedPath, maxRecursionDepth));
           }
           // otherwise we create a nullable field schema
-          return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL, getMessageSchema(f.getMessageType(), seen, flattenWrappedPrimitives))));
+          return schemaFinalizer.apply(Schema.createUnion(Arrays.asList(NULL_SCHEMA, getMessageSchema(f.getMessageType(), recursionDepths, flattenWrappedPrimitives, updatedPath, maxRecursionDepth))));
         case GROUP: // groups are deprecated
         default:
           throw new RuntimeException("Unexpected type: " + f.getType());
```
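The per-branch copying can be hard to picture from the diff alone. Below is a self-contained toy sketch (illustration only, not Hudi code) that mirrors the traversal: each descent copies the depth map, so sibling branches each get the full recursion budget, and descent stops once the cap is hit:

```java
import java.util.HashMap;
import java.util.Map;

// Toy model of the unraveling logic: a "descriptor" is just a name that
// references itself; we stop descending once maxRecursionDepth is reached.
public class UnravelSketch {
  static String unravel(String descriptor, Map<String, Integer> recursionDepths, int maxRecursionDepth) {
    int count = recursionDepths.getOrDefault(descriptor, 0);
    if (count >= maxRecursionDepth) {
      return "recursion_overflow"; // spill the rest to bytes
    }
    recursionDepths.put(descriptor, count + 1);
    // each child branch gets its own copy so siblings don't share depth counts
    return descriptor + "(" + unravel(descriptor, new HashMap<>(recursionDepths), maxRecursionDepth) + ")";
  }

  public static void main(String[] args) {
    // Prints: node(node(recursion_overflow)) for maxRecursionDepth = 2
    System.out.println(unravel("node", new HashMap<>(), 2));
  }
}
```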
```diff
@@ -255,6 +279,14 @@ private Object convertObject(Schema schema, Object value) {
       if (value == null) {
         return null;
       }
+      // if we've reached max recursion depth in the provided schema, write out message to bytes
+      if (RECURSION_OVERFLOW_SCHEMA.getFullName().equals(schema.getFullName())) {
+        GenericData.Record overflowRecord = new GenericData.Record(schema);
+        Message messageValue = (Message) value;
+        overflowRecord.put(OVERFLOW_DESCRIPTOR_FIELD_NAME, messageValue.getDescriptorForType().getFullName());
+        overflowRecord.put(OVERFLOW_BYTES_FIELD_NAME, ByteBuffer.wrap(messageValue.toByteArray()));
+        return overflowRecord;
+      }
 
       switch (schema.getType()) {
         case ARRAY:
```
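A downstream reader that knows the original message type could recover the spilled payload from such an overflow record. This is a hypothetical sketch: `SamplePb` is an assumed generated protobuf class, and `parseFrom`/`getDescriptor` are the standard methods protoc generates:

```java
import java.nio.ByteBuffer;

import org.apache.avro.generic.GenericRecord;

import com.google.protobuf.InvalidProtocolBufferException;

public class OverflowReader {
  // Recovers the original message from a recursion_overflow record.
  static SamplePb readOverflow(GenericRecord overflowRecord) throws InvalidProtocolBufferException {
    // Sanity-check that the bytes really hold a SamplePb payload.
    String descriptorName = overflowRecord.get("descriptor_full_name").toString();
    if (!SamplePb.getDescriptor().getFullName().equals(descriptorName)) {
      throw new IllegalArgumentException("Unexpected payload type: " + descriptorName);
    }
    ByteBuffer buffer = (ByteBuffer) overflowRecord.get("proto_bytes");
    byte[] bytes = new byte[buffer.remaining()];
    buffer.get(bytes);
    return SamplePb.parseFrom(bytes);
  }
}
```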
```diff
@@ -305,7 +337,7 @@ private Object convertObject(Schema schema, Object value) {
         Map<Object, Object> mapValue = (Map) value;
         Map<Object, Object> mapCopy = new HashMap<>(mapValue.size());
         for (Map.Entry<Object, Object> entry : mapValue.entrySet()) {
-          mapCopy.put(convertObject(STRINGS, entry.getKey()), convertObject(schema.getValueType(), entry.getValue()));
+          mapCopy.put(convertObject(STRING_SCHEMA, entry.getKey()), convertObject(schema.getValueType(), entry.getValue()));
         }
         return mapCopy;
       case NULL:
```
```diff
@@ -355,5 +387,38 @@ private String getNamespace(String descriptorFullName) {
       int lastDotIndex = descriptorFullName.lastIndexOf('.');
       return descriptorFullName.substring(0, lastDotIndex);
     }
+
+    private String appendFieldNameToPath(String existingPath, String fieldName) {
+      return existingPath + "." + fieldName;
+    }
+
+    private static class SchemaCacheKey {
+      private final String className;
+      private final boolean flattenWrappedPrimitives;
+      private final int maxRecursionDepth;
+
+      SchemaCacheKey(Class clazz, boolean flattenWrappedPrimitives, int maxRecursionDepth) {
+        this.className = clazz.getName();
+        this.flattenWrappedPrimitives = flattenWrappedPrimitives;
+        this.maxRecursionDepth = maxRecursionDepth;
+      }
+
+      @Override
+      public boolean equals(Object o) {
+        if (this == o) {
+          return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+          return false;
+        }
+        SchemaCacheKey that = (SchemaCacheKey) o;
+        return flattenWrappedPrimitives == that.flattenWrappedPrimitives && maxRecursionDepth == that.maxRecursionDepth && className.equals(that.className);
+      }
+
+      @Override
+      public int hashCode() {
+        return Objects.hash(className, flattenWrappedPrimitives, maxRecursionDepth);
+      }
+    }
   }
 }
```
A review thread on the `recursion_overflow` record name:

> **Reviewer:** why rename the field?
>
> **Author:** We're not renaming the field, it's just a name for the schema/type. You can see the output here.
>
> **Reviewer:** L59? https://github.com/apache/hudi/pull/6761/files#diff-c4931766d3118fdebd1fb97fce4b5c36923f6bd66f6226adbfba1dd7ec2ddbd4R59
>
> **Author:** Yes, that's where the name of the type is coming through. Remember that the name of a record and the name of a field are different, but both use "name" in the Avro schema definition. Line 56 is the field name. The record name is more like a class name in Java, where it can be defined once and reused for different variables. You can see this down at line 141, where the type is simply `"type" : [ "null", "org.apache.hudi.proto.recursion_overflow" ]`.
>
> **Reviewer:** ah, got it. thank you for the explanation.
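To illustrate the distinction the author describes, here is a small sketch against the plain Avro API (not PR code; the field default mirrors the one used in the diff) that defines a named record type once and then references it from a nullable union, the way the generated schema does:

```java
import java.util.Arrays;
import java.util.Collections;

import org.apache.avro.Schema;

public class RecordNameVsFieldName {
  public static void main(String[] args) {
    // "recursion_overflow" is the *record* (type) name, like a Java class name.
    Schema overflow = Schema.createRecord("recursion_overflow", null, "org.apache.hudi.proto", false,
        Collections.singletonList(
            // "proto_bytes" is a *field* name inside that record.
            new Schema.Field("proto_bytes", Schema.create(Schema.Type.BYTES), null, "".getBytes())));
    // Elsewhere the named type is referenced inside a nullable union,
    // which renders as ["null", "org.apache.hudi.proto.recursion_overflow"].
    Schema nullable = Schema.createUnion(Arrays.asList(Schema.create(Schema.Type.NULL), overflow));
    System.out.println(overflow.getFullName()); // org.apache.hudi.proto.recursion_overflow
    System.out.println(nullable.toString(true));
  }
}
```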