-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-41396][SQL][PROTOBUF] OneOf field support and recursion checks #38922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
1266857
d38cc71
e2dc559
f0d2e5f
c8c7bd7
2337892
5340bb4
f71d1ea
b0eba7f
660c354
ae82005
a48f7d6
dd47096
231f0a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -157,6 +157,8 @@ private[sql] class ProtobufDeserializer( | |
|
|
||
| case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) | ||
|
|
||
| case (MESSAGE, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) | ||
|
|
||
| // TODO: we can avoid boxing if future version of Protobuf provide primitive accessors. | ||
| case (BOOLEAN, BooleanType) => | ||
| (updater, ordinal, value) => updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) | ||
|
|
@@ -235,7 +237,7 @@ private[sql] class ProtobufDeserializer( | |
| writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage]) | ||
| updater.set(ordinal, row) | ||
|
|
||
| case (MESSAGE, ArrayType(st: StructType, containsNull)) => | ||
| case (MESSAGE, ArrayType(st: DataType, containsNull)) => | ||
|
||
| newArrayWriter(protoType, protoPath, catalystPath, st, containsNull) | ||
|
|
||
| case (ENUM, StringType) => | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,12 @@ private[sql] class ProtobufOptions( | |
|
|
||
| val parseMode: ParseMode = | ||
| parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) | ||
|
|
||
| // Setting the circularReferenceDepth to 0 allows the field to be recursed once, setting | ||
| // it to 1 allows it to be recursed twice, and setting it to 2 allows it to be recursed | ||
| // thrice. circularReferenceDepth value greater than 2 is not allowed. If the not | ||
| // specified, it will default to -1, which disables recursive fields. | ||
|
||
| val circularReferenceDepth: Int = parameters.getOrElse("circularReferenceDepth", "-1").toInt | ||
|
||
| } | ||
|
|
||
| private[sql] object ProtobufOptions { | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -40,19 +40,25 @@ object SchemaConverters { | |||||
| * | ||||||
| * @since 3.4.0 | ||||||
| */ | ||||||
| def toSqlType(descriptor: Descriptor): SchemaType = { | ||||||
| toSqlTypeHelper(descriptor) | ||||||
| def toSqlType( | ||||||
| descriptor: Descriptor, | ||||||
| protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)): SchemaType = { | ||||||
| toSqlTypeHelper(descriptor, protobufOptions) | ||||||
| } | ||||||
|
|
||||||
| def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized { | ||||||
| def toSqlTypeHelper( | ||||||
| descriptor: Descriptor, | ||||||
| protobufOptions: ProtobufOptions): SchemaType = ScalaReflectionLock.synchronized { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not related to this PR, but why would we lock
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I just noticed. Not sure if if we need. |
||||||
| SchemaType( | ||||||
| StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, Set.empty)).toArray), | ||||||
| StructType(descriptor.getFields.asScala.flatMap( | ||||||
| structFieldFor(_, Map.empty, protobufOptions: ProtobufOptions)).toArray), | ||||||
| nullable = true) | ||||||
| } | ||||||
|
|
||||||
| def structFieldFor( | ||||||
| fd: FieldDescriptor, | ||||||
| existingRecordNames: Set[String]): Option[StructField] = { | ||||||
| existingRecordNames: Map[String, Int], | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we add comments to explain what map key and value means here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cloud-fan added a comment. |
||||||
| protobufOptions: ProtobufOptions): Option[StructField] = { | ||||||
| import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ | ||||||
| val dataType = fd.getJavaType match { | ||||||
| case INT => Some(IntegerType) | ||||||
|
|
@@ -81,9 +87,17 @@ object SchemaConverters { | |||||
| fd.getMessageType.getFields.forEach { field => | ||||||
| field.getName match { | ||||||
| case "key" => | ||||||
| keyType = structFieldFor(field, existingRecordNames).get.dataType | ||||||
| keyType = | ||||||
| structFieldFor( | ||||||
| field, | ||||||
| existingRecordNames, | ||||||
| protobufOptions).get.dataType | ||||||
| case "value" => | ||||||
| valueType = structFieldFor(field, existingRecordNames).get.dataType | ||||||
| valueType = | ||||||
| structFieldFor( | ||||||
| field, | ||||||
| existingRecordNames, | ||||||
| protobufOptions).get.dataType | ||||||
| } | ||||||
| } | ||||||
| return Option( | ||||||
|
|
@@ -92,14 +106,26 @@ object SchemaConverters { | |||||
| MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType, | ||||||
| nullable = false)) | ||||||
| case MESSAGE => | ||||||
| if (existingRecordNames.contains(fd.getFullName)) { | ||||||
| // Setting the circularReferenceDepth to 0 allows the field to be recursed once, setting | ||||||
| // it to 1 allows it to be recursed twice, and setting it to 2 allows it to be recursed | ||||||
| // thrice. circularReferenceDepth value greater than 2 is not allowed. If the not | ||||||
| // specified, it will default to -1, which disables recursive fields. | ||||||
| val recordName = fd.getMessageType.getFullName | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
are they same? The previous code uses There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. I think the previous code was incorrect. We need to verify if a same Protobuf type was seen before in this DFS traversal.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cloud-fan fd.getFullName gives a fully qualified name along with a field name, we needed the fully qualified type name. we made this decision above. here is the difference. @rangadi previous code |
||||||
| if (existingRecordNames.contains(recordName) && | ||||||
|
||||||
| protobufOptions.circularReferenceDepth < 0 ) { | ||||||
| throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString()) | ||||||
| } else if (existingRecordNames.contains(recordName) && | ||||||
| existingRecordNames.getOrElse(recordName, 0) | ||||||
| > protobufOptions.circularReferenceDepth) { | ||||||
| return Some(StructField(fd.getName, NullType, nullable = false)) | ||||||
|
||||||
| } | ||||||
| val newRecordNames = existingRecordNames + fd.getFullName | ||||||
|
|
||||||
| val newRecordNames = existingRecordNames + | ||||||
| (recordName -> (existingRecordNames.getOrElse(recordName, 0) + 1)) | ||||||
|
|
||||||
| Option( | ||||||
| fd.getMessageType.getFields.asScala | ||||||
| .flatMap(structFieldFor(_, newRecordNames)) | ||||||
| .flatMap(structFieldFor(_, newRecordNames, protobufOptions)) | ||||||
| .toSeq) | ||||||
| .filter(_.nonEmpty) | ||||||
| .map(StructType.apply) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -170,4 +170,118 @@ message timeStampMsg { | |
| message durationMsg { | ||
| string key = 1; | ||
| Duration duration = 2; | ||
| } | ||
| } | ||
|
|
||
| message OneOfEvent { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you testing more OneOf and recusion in the same message? Could you split them into separate messages?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rangadi I see a lot of use cases for the "payload" Oneof the field and recursive fields in it. So I thought combining Oneof with recursion would be a good test. will separate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combined one is fine, we could keep it. Better to have a simpler separate tests as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
| string key = 1; | ||
| oneof payload { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do one-of fields look like in spark schema? Could you give an example? I could not see the schema in the unit tests.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rangadi the "Oneof" field is of message type, Oneof will be converted to a struct type. |
||
| int32 col_1 = 2; | ||
| string col_2 = 3; | ||
| int64 col_3 = 4; | ||
| } | ||
| repeated string col_4 = 5; | ||
| } | ||
|
|
||
| message EventWithRecursion { | ||
| int32 key = 1; | ||
| messageA a = 2; | ||
| } | ||
| message messageA { | ||
| EventWithRecursion a = 1; | ||
| messageB b = 2; | ||
| } | ||
| message messageB { | ||
| EventWithRecursion aa = 1; | ||
| messageC c = 2; | ||
| } | ||
| message messageC { | ||
| EventWithRecursion aaa = 1; | ||
| int32 key= 2; | ||
| } | ||
|
|
||
| message Employee { | ||
| string firstName = 1; | ||
| string lastName = 2; | ||
| oneof role { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need so many fields for 'OneOf'? How about just 2 or 3? It will simplify testing. |
||
| IC ic = 3; | ||
| EM em = 4; | ||
| EM2 em2 = 5; | ||
| Director dir = 6; | ||
| SeniorDirector sDir = 7; | ||
| VP vp = 8; | ||
| SVP svp = 9; | ||
| CTO cto = 10; | ||
| CEO ceo = 11; | ||
| } | ||
| } | ||
|
|
||
| message IC { | ||
| repeated string skills = 1; | ||
| Employee icManager = 2; // EM or EM2 or Director.. | ||
| } | ||
|
|
||
| message EM { | ||
| int64 teamsize = 1; | ||
| Employee emManager = 2; // EM2 or Director.. | ||
| } | ||
|
|
||
| message EM2 { | ||
| int64 teamsize = 1; | ||
| Employee em2Manager = 2; // Director or Senior Director.. | ||
| } | ||
|
|
||
| message Director { | ||
| int64 teamsize = 1; | ||
| Employee dirManager = 2; // Senior Director or VP.. | ||
| } | ||
|
|
||
| message SeniorDirector { | ||
| int64 teamsize = 1; | ||
| Employee sdManager = 2; // VP or SVP... | ||
| } | ||
|
|
||
| message VP { | ||
| int64 teamsize = 1; | ||
| Employee vpManager = 2; // SVP or CTO... | ||
| } | ||
|
|
||
| message SVP { | ||
| int64 teamsize = 1; | ||
| Employee svpManager = 2; // CTO or CET | ||
| } | ||
|
|
||
| message CTO { | ||
| int64 teamsize = 1; | ||
| Employee ctoManager = 2; // CEO | ||
| } | ||
|
|
||
| message CEO { | ||
| int64 teamsize = 1; | ||
| Employee ceoManager = 2; // null | ||
| } | ||
|
|
||
| message OneOfEventWithRecursion { | ||
| string key = 1; | ||
| oneof payload { | ||
| EventRecursiveA recursiveA = 3; | ||
| EventRecursiveB recursiveB = 6; | ||
| } | ||
| string value = 7; | ||
| } | ||
|
|
||
| message EventRecursiveA { | ||
| OneOfEventWithRecursion recursiveA = 1; | ||
| string key = 2; | ||
| } | ||
|
|
||
| message EventRecursiveB { | ||
| string key = 1; | ||
| string value = 2; | ||
| OneOfEventWithRecursion recursiveA = 3; | ||
| } | ||
|
|
||
| message Status { | ||
| int32 id = 1; | ||
| Timestamp trade_time = 2; | ||
| Status status = 3; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this for? For handling limited recursion?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment about we might be dropping data here? It will not be easy to see for a future reader.
We could have an option to error our if the actual data has more recursion than the configure.