[SPARK-42406][PROTOBUF] Fix recursive depth setting for Protobuf functions #40011
@@ -25,7 +25,7 @@ import com.google.protobuf.{ByteString, DynamicMessage}
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.functions.{lit, struct}
-import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{messageA, messageB, messageC, EM, EM2, Employee, EventPerson, EventRecursiveA, EventRecursiveB, EventWithRecursion, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EM, EM2, Employee, EventPerson, EventPersonWrapper, EventRecursiveA, EventRecursiveB, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
 import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
@@ -39,6 +39,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   val testFileDesc = testFile("functions_suite.desc", "protobuf/functions_suite.desc")
   private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"

+  private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
+
   /**
    * Runs the given closure twice. Once with descriptor file and second time with Java class name.
    */
@@ -385,34 +387,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }

-  test("Handle recursive fields in Protobuf schema, A->B->A") {
-    val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA")
Author: This extra code to create actual records is not required, since we are testing an AnalysisException; the test does not need real data.
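For illustration, a minimal sketch of why no real data is needed, assuming this suite's testFileDesc and ScalaTest's intercept: the recursion check runs while the plan is analyzed, before any bytes are deserialized.

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.protobuf.functions

// A single empty byte array is enough to build a plan; analysis fails before
// any Protobuf deserialization happens.
val emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
val e = intercept[AnalysisException] {
  emptyBinaryDF
    .select(functions.from_protobuf($"binary", "recursiveA", testFileDesc).as("msg"))
    .show() // forces analysis; no rows are ever decoded
}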
-    val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB")
-
-    val messageBForA = DynamicMessage
-      .newBuilder(schemaB)
-      .setField(schemaB.findFieldByName("keyB"), "key")
-      .build()
-
-    val messageA = DynamicMessage
-      .newBuilder(schemaA)
-      .setField(schemaA.findFieldByName("keyA"), "key")
-      .setField(schemaA.findFieldByName("messageB"), messageBForA)
-      .build()
-
-    val messageB = DynamicMessage
-      .newBuilder(schemaB)
-      .setField(schemaB.findFieldByName("keyB"), "key")
-      .setField(schemaB.findFieldByName("messageA"), messageA)
-      .build()
-
-    val df = Seq(messageB.toByteArray).toDF("messageB")
-
+  test("Recursive fields in Protobuf should result in an error (B -> A -> B)") {
     checkWithFileAndClassName("recursiveB") {
       case (name, descFilePathOpt) =>
         val e = intercept[AnalysisException] {
-          df.select(
-            from_protobuf_wrapper($"messageB", name, descFilePathOpt).as("messageFromProto"))
+          emptyBinaryDF.select(
+            from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
             .show()
         }
         assert(e.getMessage.contains(
@@ -421,34 +401,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }

-  test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
-    val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC")
-    val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD")
-
-    val messageDForC = DynamicMessage
-      .newBuilder(schemaD)
-      .setField(schemaD.findFieldByName("keyD"), "key")
-      .build()
-
-    val messageC = DynamicMessage
-      .newBuilder(schemaC)
-      .setField(schemaC.findFieldByName("keyC"), "key")
-      .setField(schemaC.findFieldByName("messageD"), messageDForC)
-      .build()
-
-    val messageD = DynamicMessage
-      .newBuilder(schemaD)
-      .setField(schemaD.findFieldByName("keyD"), "key")
-      .addRepeatedField(schemaD.findFieldByName("messageC"), messageC)
-      .build()
-
-    val df = Seq(messageD.toByteArray).toDF("messageD")
-
+  test("Recursive fields in Protobuf should result in an error, C->D->Array(C)") {
     checkWithFileAndClassName("recursiveD") {
       case (name, descFilePathOpt) =>
         val e = intercept[AnalysisException] {
-          df.select(
-            from_protobuf_wrapper($"messageD", name, descFilePathOpt).as("messageFromProto"))
+          emptyBinaryDF.select(
+            from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
             .show()
         }
         assert(e.getMessage.contains(
@@ -457,6 +415,22 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }

+  test("Setting depth to 0 or -1 should trigger error on recursive fields (B -> A -> B)") {
+    for (depth <- Seq("0", "-1")) {
+      val e = intercept[AnalysisException] {
+        emptyBinaryDF.select(
+          functions.from_protobuf(
+            $"binary", "recursiveB", testFileDesc,
+            Map("recursive.fields.max.depth" -> depth).asJava
+          ).as("messageFromProto")
+        ).show()
+      }
+      assert(e.getMessage.contains(
+        "Found recursive reference in Protobuf schema, which can not be processed by Spark"
+      ))
+    }
+  }
+
   test("Handle extra fields : oldProducer -> newConsumer") {
     val testFileDesc = testFile("catalyst_types.desc", "protobuf/catalyst_types.desc")
     val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer")
@@ -818,21 +792,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   }

   test("Fail for recursion field with complex schema without recursive.fields.max.depth") {
-    val aEventWithRecursion = EventWithRecursion.newBuilder().setKey(2).build()
-    val aaEventWithRecursion = EventWithRecursion.newBuilder().setKey(3).build()
-    val aaaEventWithRecursion = EventWithRecursion.newBuilder().setKey(4).build()
-    val c = messageC.newBuilder().setAaa(aaaEventWithRecursion).setKey(12092)
-    val b = messageB.newBuilder().setAa(aaEventWithRecursion).setC(c)
-    val a = messageA.newBuilder().setA(aEventWithRecursion).setB(b).build()
-    val eventWithRecursion = EventWithRecursion.newBuilder().setKey(1).setA(a).build()
-
-    val df = Seq(eventWithRecursion.toByteArray).toDF("protoEvent")
-
     checkWithFileAndClassName("EventWithRecursion") {
       case (name, descFilePathOpt) =>
         val e = intercept[AnalysisException] {
-          df.select(
-            from_protobuf_wrapper($"protoEvent", name, descFilePathOpt).as("messageFromProto"))
+          emptyBinaryDF.select(
+            from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
             .show()
         }
         assert(e.getMessage.contains(
@@ -853,7 +817,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot

     val df = Seq(employee.toByteArray).toDF("protoEvent")
     val options = new java.util.HashMap[String, String]()
-    options.put("recursive.fields.max.depth", "1")
+    options.put("recursive.fields.max.depth", "2")
Author: We need to increase the setting because this PR fixes an off-by-one in the enforcement.
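A sketch of the corrected counting, assuming a DataFrame df of serialized EventPerson bytes (a person with a recursive bff field) and this suite's testFileDesc; the schema shapes shown are inferred from the expected schemas in the tests below.

import scala.collection.JavaConverters._
import org.apache.spark.sql.protobuf.functions

// The option is effectively 1-based: depth N keeps N occurrences of the
// recursive field, with the N-th occurrence typed as "void".
//   depth = 1  =>  struct<name: string, bff: void>
//   depth = 2  =>  struct<name: string, bff: struct<name: string, bff: void>>
val options = Map("recursive.fields.max.depth" -> "2").asJava
val parsed = df.select(
  functions.from_protobuf($"value", "EventPerson", testFileDesc, options).as("person"))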

     val fromProtoDf = df.select(
       functions.from_protobuf($"protoEvent", "Employee", testFileDesc, options) as 'sample)
@@ -908,7 +872,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")

     val options = new java.util.HashMap[String, String]()
-    options.put("recursive.fields.max.depth", "1")
+    options.put("recursive.fields.max.depth", "2")

     val fromProtoDf = df.select(
       functions.from_protobuf($"value",
@@ -1128,15 +1092,15 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     })
   }

-  test("Verify recursive.fields.max.depth Levels 0,1, and 2 with Simple Schema") {
+  test("Verify recursive.fields.max.depth Levels 1,2, and 3 with Simple Schema") {
     val eventPerson3 = EventPerson.newBuilder().setName("person3").build()
     val eventPerson2 = EventPerson.newBuilder().setName("person2").setBff(eventPerson3).build()
     val eventPerson1 = EventPerson.newBuilder().setName("person1").setBff(eventPerson2).build()
     val eventPerson0 = EventPerson.newBuilder().setName("person0").setBff(eventPerson1).build()
     val df = Seq(eventPerson0.toByteArray).toDF("value")

     val optionsZero = new java.util.HashMap[String, String]()
-    optionsZero.put("recursive.fields.max.depth", "0")
+    optionsZero.put("recursive.fields.max.depth", "1")
     val schemaZero = DataType.fromJson(
       s"""{
          |  "type" : "struct",
@@ -1160,10 +1124,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val expectedDfZero = spark.createDataFrame(
       spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), schemaZero)

-    testFromProtobufWithOptions(df, expectedDfZero, optionsZero)
+    testFromProtobufWithOptions(df, expectedDfZero, optionsZero, "EventPerson")

     val optionsOne = new java.util.HashMap[String, String]()
-    optionsOne.put("recursive.fields.max.depth", "1")
+    optionsOne.put("recursive.fields.max.depth", "2")
     val schemaOne = DataType.fromJson(
       s"""{
          |  "type" : "struct",
@@ -1197,10 +1161,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
          |}""".stripMargin).asInstanceOf[StructType]
     val expectedDfOne = spark.createDataFrame(
       spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", null))))), schemaOne)
-    testFromProtobufWithOptions(df, expectedDfOne, optionsOne)
+    testFromProtobufWithOptions(df, expectedDfOne, optionsOne, "EventPerson")

     val optionsTwo = new java.util.HashMap[String, String]()
-    optionsTwo.put("recursive.fields.max.depth", "2")
+    optionsTwo.put("recursive.fields.max.depth", "3")
     val schemaTwo = DataType.fromJson(
       s"""{
          |  "type" : "struct",
@@ -1245,7 +1209,60 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
          |}""".stripMargin).asInstanceOf[StructType]
     val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize(
       Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), schemaTwo)
-    testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo)
+    testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo, "EventPerson")
+
+    // Test recursive level 1 with EventPersonWrapper. In this case the top-level struct
+    // 'EventPersonWrapper' itself does not recurse, unlike 'EventPerson'.
+    // "bff" appears twice: once as allowed recursion and a second time as the terminated "null" type.
+    val wrapperSchemaOne = DataType.fromJson(
+      """
+        |{
+        |  "type" : "struct",
+        |  "fields" : [ {
+        |    "name" : "sample",
+        |    "type" : {
+        |      "type" : "struct",
+        |      "fields" : [ {
+        |        "name" : "person",
+        |        "type" : {
+        |          "type" : "struct",
+        |          "fields" : [ {
+        |            "name" : "name",
+        |            "type" : "string",
+        |            "nullable" : true
+        |          }, {
+        |            "name" : "bff",
+        |            "type" : {
+        |              "type" : "struct",
+        |              "fields" : [ {
+        |                "name" : "name",
+        |                "type" : "string",
+        |                "nullable" : true
+        |              }, {
+        |                "name" : "bff",
+        |                "type" : "void",
+        |                "nullable" : true
+        |              } ]
+        |            },
+        |            "nullable" : true
+        |          } ]
+        |        },
+        |        "nullable" : true
+        |      } ]
+        |    },
+        |    "nullable" : true
+        |  } ]
+        |}
+        |""".stripMargin).asInstanceOf[StructType]
+    val expectedWrapperDfOne = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(Row(Row(Row("person0", Row("person1", null)))))),
+      wrapperSchemaOne)
+    testFromProtobufWithOptions(
+      Seq(EventPersonWrapper.newBuilder().setPerson(eventPerson0).build().toByteArray).toDF(),
+      expectedWrapperDfOne,
+      optionsOne,
+      "EventPersonWrapper"
+    )
   }

   test("Verify exceptions are correctly propagated with errors") {
@@ -1273,9 +1290,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   def testFromProtobufWithOptions(
       df: DataFrame,
       expectedDf: DataFrame,
-      options: java.util.HashMap[String, String]): Unit = {
+      options: java.util.HashMap[String, String],
+      messageName: String): Unit = {
     val fromProtoDf = df.select(
-      functions.from_protobuf($"value", "EventPerson", testFileDesc, options) as 'sample)
+      functions.from_protobuf($"value", messageName, testFileDesc, options) as 'sample)
     assert(expectedDf.schema === fromProtoDf.schema)
     checkAnswer(fromProtoDf, expectedDf)
   }
Author: The main bug fix: we allowed the recursive depth to go over by one. As a result, a setting of '0' did not work as documented (the earlier documentation stated that even the first occurrence would be removed if it was set to 0). Removing the first occurrence does not make sense, since we don't know whether a field is recursive until its second occurrence.
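A sketch of the user-visible behavior after this change, assuming the public org.apache.spark.sql.protobuf.functions API, a DataFrame df with a binary value column, and the recursiveB message from the test descriptor at descFilePath:

import scala.collection.JavaConverters._
import org.apache.spark.sql.protobuf.functions

// A depth below 1 is rejected during analysis, the same as not setting the
// option at all, because the first occurrence of a field cannot be dropped:
// AnalysisException: "Found recursive reference in Protobuf schema, which
// can not be processed by Spark ..."
df.select(functions.from_protobuf($"value", "recursiveB", descFilePath,
  Map("recursive.fields.max.depth" -> "0").asJava))

// With depth N >= 1 the schema keeps exactly N occurrences of the recursive
// field; before this fix the same setting produced N + 1 occurrences.
df.select(functions.from_protobuf($"value", "recursiveB", descFilePath,
  Map("recursive.fields.max.depth" -> "2").asJava))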