@@ -39,12 +39,12 @@ private[sql] class ProtobufOptions(
val parseMode: ParseMode =
parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)

// Setting the `recursive.fields.max.depth` to 0 drops all recursive fields,
// 1 allows it to be recurse once, and 2 allows it to be recursed twice and so on.
// A value of `recursive.fields.max.depth` greater than 10 is not permitted. If it is not
// specified, the default value is -1; recursive fields are not permitted. If a protobuf
// Setting `recursive.fields.max.depth` to 1 allows it to be recursed once,
// and 2 allows it to be recursed twice, and so on. A value of `recursive.fields.max.depth`
// greater than 10 is not permitted. If it is not specified, the default value is -1.
// A value of 0 or below disallows any recursive fields. If a protobuf
// record has more depth than the allowed value for recursive fields, it will be truncated
// and some fields may be discarded.
// and corresponding fields are ignored (dropped).
val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt
}
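For context, a minimal sketch (not part of this diff) of how a caller would pass this option to `from_protobuf`. The message name "Person", the descriptor path, and the DataFrame `df` with a binary "value" column are placeholders:

  // Allow the recursive record to appear twice on a path before it is truncated to null.
  // Assumes `import spark.implicits._` is in scope for the $"value" syntax.
  import scala.collection.JavaConverters._
  import org.apache.spark.sql.protobuf.functions

  val options = Map("recursive.fields.max.depth" -> "2").asJava
  val parsed = df.select(
    functions.from_protobuf($"value", "Person", "/path/to/person.desc", options).as("person"))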

@@ -118,17 +118,17 @@ object SchemaConverters {
// discarded.
// SQL Schema for the protobuf message `message Person { string name = 1; Person bff = 2}`
// will vary based on the value of "recursive.fields.max.depth".
// 0: struct<name: string, bff: null>
// 1: struct<name string, bff: <name: string, bff: null>>
// 2: struct<name string, bff: <name: string, bff: struct<name: string, bff: null>>> ...
// 1: struct<name: string, bff: null>
// 2: struct<name: string, bff: struct<name: string, bff: null>>
// 3: struct<name: string, bff: struct<name: string, bff: struct<name: string, bff: null>>> ...
val recordName = fd.getMessageType.getFullName
val recursiveDepth = existingRecordNames.getOrElse(recordName, 0)
val recursiveFieldMaxDepth = protobufOptions.recursiveFieldMaxDepth
if (existingRecordNames.contains(recordName) && (recursiveFieldMaxDepth < 0 ||
if (existingRecordNames.contains(recordName) && (recursiveFieldMaxDepth <= 0 ||
recursiveFieldMaxDepth > 10)) {
throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
} else if (existingRecordNames.contains(recordName) &&
recursiveDepth > recursiveFieldMaxDepth) {
recursiveDepth >= recursiveFieldMaxDepth) {
Author:
The main bug fix:
We allowed the recursive depth to go over by one. As a result, a setting of '0' did not work as documented (the earlier documentation stated that even the first occurrence would be removed if it was set to 0).
Removing the first occurrence does not make sense, since we don't know a field is recursive until its second occurrence. See the short sketch after this hunk for how the corrected check behaves.

Some(NullType)
} else {
val newRecordNames = existingRecordNames + (recordName -> (recursiveDepth + 1))
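A rough sketch of the corrected depth check, using the `Person`/`bff` example from the comment above (illustrative values only, not the actual SchemaConverters code path):

  // First visit of "Person": not yet in existingRecordNames, so it is expanded and
  // recorded with depth 1. On the second visit (reached via the recursive `bff` field):
  val recursiveDepth = 1                                    // read back from existingRecordNames
  val truncate = recursiveDepth >= recursiveFieldMaxDepth   // the corrected '>=' check
  // recursive.fields.max.depth = 1  -> 1 >= 1, so `bff` becomes NullType
  // recursive.fields.max.depth = 2  -> 1 >= 2 is false, so `bff` is expanded one more level
  // Under the old '>' comparison a setting of N allowed N + 1 occurrences of the record,
  // which is the off-by-one the author comment above refers to.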
Binary file not shown.
@@ -224,11 +224,15 @@ message EM2 {
Employee em2Manager = 2;
}

message EventPerson {
message EventPerson { // Used for simple recursive field testing.
string name = 1;
EventPerson bff = 2;
}

message EventPersonWrapper {
EventPerson person = 1;
}

message OneOfEventWithRecursion {
string key = 1;
oneof payload {
@@ -25,7 +25,7 @@ import com.google.protobuf.{ByteString, DynamicMessage}

import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{messageA, messageB, messageC, EM, EM2, Employee, EventPerson, EventRecursiveA, EventRecursiveB, EventWithRecursion, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EM, EM2, Employee, EventPerson, EventPersonWrapper, EventRecursiveA, EventRecursiveB, IC, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}
import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
@@ -39,6 +39,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
val testFileDesc = testFile("functions_suite.desc", "protobuf/functions_suite.desc")
private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"

private def emptyBinaryDF = Seq(Array[Byte]()).toDF("binary")

/**
* Runs the given closure twice. Once with descriptor file and second time with Java class name.
*/
@@ -385,34 +387,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
}

test("Handle recursive fields in Protobuf schema, A->B->A") {
val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA")
Author:
This extra code to create actual records is not required since we are testing an AnalysisException. It does not need real data.

val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB")

val messageBForA = DynamicMessage
.newBuilder(schemaB)
.setField(schemaB.findFieldByName("keyB"), "key")
.build()

val messageA = DynamicMessage
.newBuilder(schemaA)
.setField(schemaA.findFieldByName("keyA"), "key")
.setField(schemaA.findFieldByName("messageB"), messageBForA)
.build()

val messageB = DynamicMessage
.newBuilder(schemaB)
.setField(schemaB.findFieldByName("keyB"), "key")
.setField(schemaB.findFieldByName("messageA"), messageA)
.build()

val df = Seq(messageB.toByteArray).toDF("messageB")

test("Recursive fields in Protobuf should result in an error (B -> A -> B)") {
checkWithFileAndClassName("recursiveB") {
case (name, descFilePathOpt) =>
val e = intercept[AnalysisException] {
df.select(
from_protobuf_wrapper($"messageB", name, descFilePathOpt).as("messageFromProto"))
emptyBinaryDF.select(
from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
.show()
}
assert(e.getMessage.contains(
@@ -421,34 +401,12 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
}

test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC")
val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD")

val messageDForC = DynamicMessage
.newBuilder(schemaD)
.setField(schemaD.findFieldByName("keyD"), "key")
.build()

val messageC = DynamicMessage
.newBuilder(schemaC)
.setField(schemaC.findFieldByName("keyC"), "key")
.setField(schemaC.findFieldByName("messageD"), messageDForC)
.build()

val messageD = DynamicMessage
.newBuilder(schemaD)
.setField(schemaD.findFieldByName("keyD"), "key")
.addRepeatedField(schemaD.findFieldByName("messageC"), messageC)
.build()

val df = Seq(messageD.toByteArray).toDF("messageD")

test("Recursive fields in Protobuf should result in an error, C->D->Array(C)") {
checkWithFileAndClassName("recursiveD") {
case (name, descFilePathOpt) =>
val e = intercept[AnalysisException] {
df.select(
from_protobuf_wrapper($"messageD", name, descFilePathOpt).as("messageFromProto"))
emptyBinaryDF.select(
from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
.show()
}
assert(e.getMessage.contains(
@@ -457,6 +415,22 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
}

test("Setting depth to 0 or -1 should trigger error on recursive fields (B -> A -> B)") {
for (depth <- Seq("0", "-1")) {
val e = intercept[AnalysisException] {
emptyBinaryDF.select(
functions.from_protobuf(
$"binary", "recursiveB", testFileDesc,
Map("recursive.fields.max.depth" -> depth).asJava
).as("messageFromProto")
).show()
}
assert(e.getMessage.contains(
"Found recursive reference in Protobuf schema, which can not be processed by Spark"
))
}
}

test("Handle extra fields : oldProducer -> newConsumer") {
val testFileDesc = testFile("catalyst_types.desc", "protobuf/catalyst_types.desc")
val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer")
@@ -818,21 +792,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}

test("Fail for recursion field with complex schema without recursive.fields.max.depth") {
val aEventWithRecursion = EventWithRecursion.newBuilder().setKey(2).build()
val aaEventWithRecursion = EventWithRecursion.newBuilder().setKey(3).build()
val aaaEventWithRecursion = EventWithRecursion.newBuilder().setKey(4).build()
val c = messageC.newBuilder().setAaa(aaaEventWithRecursion).setKey(12092)
val b = messageB.newBuilder().setAa(aaEventWithRecursion).setC(c)
val a = messageA.newBuilder().setA(aEventWithRecursion).setB(b).build()
val eventWithRecursion = EventWithRecursion.newBuilder().setKey(1).setA(a).build()

val df = Seq(eventWithRecursion.toByteArray).toDF("protoEvent")

checkWithFileAndClassName("EventWithRecursion") {
case (name, descFilePathOpt) =>
val e = intercept[AnalysisException] {
df.select(
from_protobuf_wrapper($"protoEvent", name, descFilePathOpt).as("messageFromProto"))
emptyBinaryDF.select(
from_protobuf_wrapper($"binary", name, descFilePathOpt).as("messageFromProto"))
.show()
}
assert(e.getMessage.contains(
@@ -853,7 +817,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot

val df = Seq(employee.toByteArray).toDF("protoEvent")
val options = new java.util.HashMap[String, String]()
options.put("recursive.fields.max.depth", "1")
options.put("recursive.fields.max.depth", "2")
Author:
We need to increase the setting because this PR fixes an off-by-one in how the limit was enforced.
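To make the change concrete, a short note (hypothetical values, following the corrected semantics described earlier): a setting of N now lets the recursive Employee record appear N times along a path, so the schema this test expects, previously produced by "1", is now produced by "2".

  // Mapping from old to new setting for the same resulting schema (assumed from the
  // corrected '>=' check): newDepth = oldDepth + 1 for any oldDepth >= 0.
  val oldDepth = 1
  val newDepth = oldDepth + 1   // hence "2" below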


val fromProtoDf = df.select(
functions.from_protobuf($"protoEvent", "Employee", testFileDesc, options) as 'sample)
@@ -908,7 +872,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")

val options = new java.util.HashMap[String, String]()
options.put("recursive.fields.max.depth", "1")
options.put("recursive.fields.max.depth", "2")

val fromProtoDf = df.select(
functions.from_protobuf($"value",
@@ -1128,15 +1092,15 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
})
}

test("Verify recursive.fields.max.depth Levels 0,1, and 2 with Simple Schema") {
test("Verify recursive.fields.max.depth Levels 1,2, and 3 with Simple Schema") {
val eventPerson3 = EventPerson.newBuilder().setName("person3").build()
val eventPerson2 = EventPerson.newBuilder().setName("person2").setBff(eventPerson3).build()
val eventPerson1 = EventPerson.newBuilder().setName("person1").setBff(eventPerson2).build()
val eventPerson0 = EventPerson.newBuilder().setName("person0").setBff(eventPerson1).build()
val df = Seq(eventPerson0.toByteArray).toDF("value")

val optionsZero = new java.util.HashMap[String, String]()
optionsZero.put("recursive.fields.max.depth", "0")
optionsZero.put("recursive.fields.max.depth", "1")
val schemaZero = DataType.fromJson(
s"""{
| "type" : "struct",
@@ -1160,10 +1124,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
val expectedDfZero = spark.createDataFrame(
spark.sparkContext.parallelize(Seq(Row(Row("person0", null)))), schemaZero)

testFromProtobufWithOptions(df, expectedDfZero, optionsZero)
testFromProtobufWithOptions(df, expectedDfZero, optionsZero, "EventPerson")

val optionsOne = new java.util.HashMap[String, String]()
optionsOne.put("recursive.fields.max.depth", "1")
optionsOne.put("recursive.fields.max.depth", "2")
val schemaOne = DataType.fromJson(
s"""{
| "type" : "struct",
@@ -1197,10 +1161,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
|}""".stripMargin).asInstanceOf[StructType]
val expectedDfOne = spark.createDataFrame(
spark.sparkContext.parallelize(Seq(Row(Row("person0", Row("person1", null))))), schemaOne)
testFromProtobufWithOptions(df, expectedDfOne, optionsOne)
testFromProtobufWithOptions(df, expectedDfOne, optionsOne, "EventPerson")

val optionsTwo = new java.util.HashMap[String, String]()
optionsTwo.put("recursive.fields.max.depth", "2")
optionsTwo.put("recursive.fields.max.depth", "3")
val schemaTwo = DataType.fromJson(
s"""{
| "type" : "struct",
@@ -1245,7 +1209,60 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
|}""".stripMargin).asInstanceOf[StructType]
val expectedDfTwo = spark.createDataFrame(spark.sparkContext.parallelize(
Seq(Row(Row("person0", Row("person1", Row("person2", null)))))), schemaTwo)
testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo)
testFromProtobufWithOptions(df, expectedDfTwo, optionsTwo, "EventPerson")

// Test recursive level 1 with EventPersonWrapper. In this case the top-level struct
// 'EventPersonWrapper' itself does not recurse, unlike 'EventPerson'.
// "bff" appears twice: once as the allowed recursion and a second time as the terminated "null" type.
val wrapperSchemaOne = DataType.fromJson(
"""
|{
| "type" : "struct",
| "fields" : [ {
| "name" : "sample",
| "type" : {
| "type" : "struct",
| "fields" : [ {
| "name" : "person",
| "type" : {
| "type" : "struct",
| "fields" : [ {
| "name" : "name",
| "type" : "string",
| "nullable" : true
| }, {
| "name" : "bff",
| "type" : {
| "type" : "struct",
| "fields" : [ {
| "name" : "name",
| "type" : "string",
| "nullable" : true
| }, {
| "name" : "bff",
| "type" : "void",
| "nullable" : true
| } ]
| },
| "nullable" : true
| } ]
| },
| "nullable" : true
| } ]
| },
| "nullable" : true
| } ]
|}
|""".stripMargin).asInstanceOf[StructType]
val expectedWrapperDfOne = spark.createDataFrame(
spark.sparkContext.parallelize(Seq(Row(Row(Row("person0", Row("person1", null)))))),
wrapperSchemaOne)
testFromProtobufWithOptions(
Seq(EventPersonWrapper.newBuilder().setPerson(eventPerson0).build().toByteArray).toDF(),
expectedWrapperDfOne,
optionsOne,
"EventPersonWrapper"
)
}

test("Verify exceptions are correctly propagated with errors") {
@@ -1273,9 +1290,10 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
def testFromProtobufWithOptions(
df: DataFrame,
expectedDf: DataFrame,
options: java.util.HashMap[String, String]): Unit = {
options: java.util.HashMap[String, String],
messageName: String): Unit = {
val fromProtoDf = df.select(
functions.from_protobuf($"value", "EventPerson", testFileDesc, options) as 'sample)
functions.from_protobuf($"value", messageName, testFileDesc, options) as 'sample)
assert(expectedDf.schema === fromProtoDf.schema)
checkAnswer(fromProtoDf, expectedDf)
}