diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index 52f9f74bd432..53036668ebf5 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -39,12 +39,12 @@ private[sql] class ProtobufOptions(
   val parseMode: ParseMode =
     parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
 
-  // Setting the `recursive.fields.max.depth` to 0 drops all recursive fields,
-  // 1 allows it to be recurse once, and 2 allows it to be recursed twice and so on.
-  // A value of `recursive.fields.max.depth` greater than 10 is not permitted. If it is not
-  // specified, the default value is -1; recursive fields are not permitted. If a protobuf
+  // Setting the `recursive.fields.max.depth` to 1 allows it to recurse once,
+  // and 2 allows it to recurse twice, and so on. A value of `recursive.fields.max.depth`
+  // greater than 10 is not permitted. If it is not specified, the default value is -1;
+  // a value of 0 or below disallows any recursive fields. If a protobuf
   // record has more depth than the allowed value for recursive fields, it will be truncated
-  // and some fields may be discarded.
+  // and the corresponding fields are ignored (dropped).
   val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt
 }
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index bb4aa492f5c4..9666e34bab49 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -118,17 +118,17 @@ object SchemaConverters {
         // discarded.
         // SQL Schema for the protobuf message `message Person { string name = 1; Person bff = 2}`
         // will vary based on the value of "recursive.fields.max.depth".
-        // 0: struct<name: string, bff: null>
-        // 1: struct<name: string, bff: struct<name: string, bff: null>>
-        // 2: struct<name: string, bff: struct<name: string, bff: struct<name: string, bff: null>>> ...
+        // 1: struct<name: string, bff: null>
+        // 2: struct<name: string, bff: struct<name: string, bff: null>>
+        // 3: struct<name: string, bff: struct<name: string, bff: struct<name: string, bff: null>>> ...
         val recordName = fd.getMessageType.getFullName
         val recursiveDepth = existingRecordNames.getOrElse(recordName, 0)
         val recursiveFieldMaxDepth = protobufOptions.recursiveFieldMaxDepth
-        if (existingRecordNames.contains(recordName) && (recursiveFieldMaxDepth < 0 ||
+        if (existingRecordNames.contains(recordName) && (recursiveFieldMaxDepth <= 0 ||
           recursiveFieldMaxDepth > 10)) {
           throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
         } else if (existingRecordNames.contains(recordName) &&
-          recursiveDepth > recursiveFieldMaxDepth) {
+          recursiveDepth >= recursiveFieldMaxDepth) {
           Some(NullType)
         } else {
           val newRecordNames = existingRecordNames + (recordName -> (recursiveDepth + 1))
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc
index 135d489f5206..d16f89350805 100644
Binary files a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
index 449f1b68bb8f..a0698ee39799 100644
--- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -224,11 +224,15 @@ message EM2 {
   Employee em2Manager = 2;
 }
 
-message EventPerson {
+message EventPerson { // Used for simple recursive field testing.
   string name = 1;
   EventPerson bff = 2;
 }
 
+message EventPersonWrapper {
+  EventPerson person = 1;
+}
+
 message OneOfEventWithRecursion {
   string key = 1;
   oneof payload {
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 30b38eafd781..60e13644fc6c 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -25,7 +25,7 @@ import com.google.protobuf.{ByteString, DynamicMessage}
 
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.functions.{lit, struct}
-import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{messageA, messageB, messageC, EM, EM2, Employee, EventPerson, EventRecursiveA, EventRecursiveB, EventWithRecursion, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EM, EM2, Employee, EventPerson, EventPersonWrapper, EventRecursiveA, EventRecursiveB, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
 import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
 import org.apache.spark.sql.protobuf.utils.ProtobufUtils
 import org.apache.spark.sql.test.SharedSparkSession
@@ -39,6 +39,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
 
   val testFileDesc = testFile("functions_suite.desc", "protobuf/functions_suite.desc")
   private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
 
+  private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")
+
   /**
    * Runs the given closure twice. Once with descriptor file and second time with Java class name.
   */
@@ -385,34 +387,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
-  test("Handle recursive fields in Protobuf schema, A->B->A") {
-    val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA")
-    val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB")
-
-    val messageBForA = DynamicMessage
-      .newBuilder(schemaB)
-      .setField(schemaB.findFieldByName("keyB"), "key")
-      .build()
-
-    val messageA = DynamicMessage
-      .newBuilder(schemaA)
-      .setField(schemaA.findFieldByName("keyA"), "key")
-      .setField(schemaA.findFieldByName("messageB"), messageBForA)
-      .build()
-
-    val messageB = DynamicMessage
-      .newBuilder(schemaB)
-      .setField(schemaB.findFieldByName("keyB"), "key")
-      .setField(schemaB.findFieldByName("messageA"), messageA)
-      .build()
-
-    val df = Seq(messageB.toByteArray).toDF("messageB")
-
+  test("Recursive fields in Protobuf should result in an error (B -> A -> B)") {
     checkWithFileAndClassName("recursiveB") { case (name, descFilePathOpt) =>
       val e = intercept[AnalysisException] {
-        df.select(
-          from_protobuf_wrapper($"messageB", name, descFilePathOpt).as("messageFromProto"))
+        emptyBinaryDF.select(
+          from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
           .show()
       }
       assert(e.getMessage.contains(
@@ -421,34 +401,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
-  test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
-    val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC")
-    val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD")
-
-    val messageDForC = DynamicMessage
-      .newBuilder(schemaD)
-      .setField(schemaD.findFieldByName("keyD"), "key")
-      .build()
-
-    val messageC = DynamicMessage
-      .newBuilder(schemaC)
-      .setField(schemaC.findFieldByName("keyC"), "key")
-      .setField(schemaC.findFieldByName("messageD"), messageDForC)
-      .build()
-
-    val messageD = DynamicMessage
-      .newBuilder(schemaD)
-      .setField(schemaD.findFieldByName("keyD"), "key")
-      .addRepeatedField(schemaD.findFieldByName("messageC"), messageC)
-      .build()
-
-    val df = Seq(messageD.toByteArray).toDF("messageD")
-
+  test("Recursive fields in Protobuf should result in an error (C -> D -> Array(C))") {
     checkWithFileAndClassName("recursiveD") { case (name, descFilePathOpt) =>
       val e = intercept[AnalysisException] {
-        df.select(
-          from_protobuf_wrapper($"messageD", name, descFilePathOpt).as("messageFromProto"))
+        emptyBinaryDF.select(
+          from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
           .show()
       }
       assert(e.getMessage.contains(
@@ -457,6 +415,22 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
+  test("Setting depth to 0 or -1 should trigger error on recursive fields (B -> A -> B)") {
+    for (depth <- Seq("0", "-1")) {
+      val e = intercept[AnalysisException] {
+        emptyBinaryDF.select(
+          functions.from_protobuf(
+            $"binary", "recursiveB", testFileDesc,
+            Map("recursive.fields.max.depth" -> depth).asJava
+          ).as("messageFromProto")
+        ).show()
+      }
+      assert(e.getMessage.contains(
+        "Found recursive reference in Protobuf schema, which can not be processed by Spark"
+      ))
+    }
+  }
+
   test("Handle extra fields : oldProducer -> newConsumer") {
     val testFileDesc = testFile("catalyst_types.desc", "protobuf/catalyst_types.desc")
     val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer")
@@ -818,21 +792,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
   }
 
   test("Fail for recursion field with complex schema without recursive.fields.max.depth") {
-    val aEventWithRecursion = EventWithRecursion.newBuilder().setKey(2).build()
-    val aaEventWithRecursion = EventWithRecursion.newBuilder().setKey(3).build()
-    val aaaEventWithRecursion = EventWithRecursion.newBuilder().setKey(4).build()
-    val c = messageC.newBuilder().setAaa(aaaEventWithRecursion).setKey(12092)
-    val b = messageB.newBuilder().setAa(aaEventWithRecursion).setC(c)
-    val a = messageA.newBuilder().setA(aEventWithRecursion).setB(b).build()
-    val eventWithRecursion = EventWithRecursion.newBuilder().setKey(1).setA(a).build()
-
-    val df = Seq(eventWithRecursion.toByteArray).toDF("protoEvent")
-
     checkWithFileAndClassName("EventWithRecursion") { case (name, descFilePathOpt) =>
       val e = intercept[AnalysisException] {
-        df.select(
-          from_protobuf_wrapper($"protoEvent", name, descFilePathOpt).as("messageFromProto"))
+        emptyBinaryDF.select(
+          from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
           .show()
       }
       assert(e.getMessage.contains(
@@ -853,7 +817,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val df = Seq(employee.toByteArray).toDF("protoEvent")
 
     val options = new java.util.HashMap[String, String]()
-    options.put("recursive.fields.max.depth", "1")
+    options.put("recursive.fields.max.depth", "2")
 
     val fromProtoDf = df.select(
       functions.from_protobuf($"protoEvent", "Employee", testFileDesc, options) as 'sample)
@@ -908,7 +872,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")
 
     val options = new java.util.HashMap[String, String]()
-    options.put("recursive.fields.max.depth", "1")
+    options.put("recursive.fields.max.depth", "2")
 
     val fromProtoDf = df.select(
       functions.from_protobuf($"value",
@@ -1128,7 +1092,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     })
   }
 
-  test("Verify recursive.fields.max.depth Levels 0,1, and 2 with Simple Schema") {
+  test("Verify recursive.fields.max.depth Levels 1, 2, and 3 with Simple Schema") {
     val eventPerson3 = EventPerson.newBuilder().setName("person3").build()
     val eventPerson2 = EventPerson.newBuilder().setName("person2").setBff(eventPerson3).build()
     val eventPerson1 = EventPerson.newBuilder().setName("person1").setBff(eventPerson2).build()
@@ -1136,7 +1100,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val df = Seq(eventPerson0.toByteArray).toDF("value")
 
     val optionsZero = new java.util.HashMap[String, String]()
-    optionsZero.put("recursive.fields.max.depth", "0")
+    optionsZero.put("recursive.fields.max.depth", "1")
     val schemaZero = DataType.fromJson(
       s"""{
          |  "type" : "struct",
@@ -1160,10 +1124,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     val expectedDfZero = spark.createDataFrame(
       spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), schemaZero)
-    testFromProtobufWithOptions(df, expectedDfZero, optionsZero)
+    testFromProtobufWithOptions(df, expectedDfZero, optionsZero, "EventPerson")
 
     val optionsOne = new java.util.HashMap[String, String]()
-    optionsOne.put("recursive.fields.max.depth", "1")
+    optionsOne.put("recursive.fields.max.depth", "2")
     val schemaOne = DataType.fromJson(
       s"""{
          |  "type" : "struct",
@@ -1197,10 +1161,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
          |}""".stripMargin).asInstanceOf[StructType]
     val expectedDfOne = spark.createDataFrame(
spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", null))))), schemaOne) - testFromProtobufWithOptions(df, expectedDfOne, optionsOne) + testFromProtobufWithOptions(df, expectedDfOne, optionsOne, "EventPerson") val optionsTwo = new java.util.HashMap[String, String]() - optionsTwo.put("recursive.fields.max.depth", "2") + optionsTwo.put("recursive.fields.max.depth", "3") val schemaTwo = DataType.fromJson( s"""{ | "type" : "struct", @@ -1245,7 +1209,60 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot |}""".stripMargin).asInstanceOf[StructType] val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize( Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), schemaTwo) - testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo) + testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo, "EventPerson") + + // Test recursive level 1 with EventPersonWrapper. In this case the top level struct + // 'EventPersonWrapper' itself does not recurse unlike 'EventPerson'. + // "bff" appears twice: Once allowed recursion and second time as terminated "null" type. + val wrapperSchemaOne = DataType.fromJson( + """ + |{ + | "type" : "struct", + | "fields" : [ { + | "name" : "sample", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "person", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : { + | "type" : "struct", + | "fields" : [ { + | "name" : "name", + | "type" : "string", + | "nullable" : true + | }, { + | "name" : "bff", + | "type" : "void", + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + | }, + | "nullable" : true + | } ] + |} + |""".stripMargin).asInstanceOf[StructType] + val expectedWrapperDfOne = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(Row(Row("person0", Row("person1", null)))))), + wrapperSchemaOne) + testFromProtobufWithOptions( + Seq(EventPersonWrapper.newBuilder().setPerson(eventPerson0).build().toByteArray).toDF(), + expectedWrapperDfOne, + optionsOne, + "EventPersonWrapper" + ) } test("Verify exceptions are correctly propagated with errors") { @@ -1273,9 +1290,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot def testFromProtobufWithOptions( df: DataFrame, expectedDf: DataFrame, - options: java.util.HashMap[String, String]): Unit = { + options: java.util.HashMap[String, String], + messageName: String): Unit = { val fromProtoDf = df.select( - functions.from_protobuf($"value", "EventPerson", testFileDesc, options) as 'sample) + functions.from_protobuf($"value", messageName, testFileDesc, options) as 'sample) assert(expectedDf.schema === fromProtoDf.schema) checkAnswer(fromProtoDf, expectedDf) }