@@ -39,7 +39,7 @@ private[protobuf] case class ProtobufDataToCatalyst(
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)

override lazy val dataType: DataType = {
val dt = SchemaConverters.toSqlType(messageDescriptor).dataType
val dt = SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType
parseMode match {
// With PermissiveMode, the output Catalyst row might contain columns of null values for
// corrupt records, even if some of the columns are not nullable in the user-provided schema.
@@ -157,6 +157,8 @@ private[sql] class ProtobufDeserializer(

case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal)

case (MESSAGE, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal)

What is this for? For handling limited recursion?

Contributor Author
yes, correct.

Could you add a comment noting that we might be dropping data here? It will not be easy for a future reader to see.
We could also have an option to error out if the actual data has more recursion than configured.


// TODO: we can avoid boxing if future version of Protobuf provide primitive accessors.
case (BOOLEAN, BooleanType) =>
(updater, ordinal, value) => updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
@@ -235,7 +237,7 @@ private[sql] class ProtobufDeserializer(
writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage])
updater.set(ordinal, row)

case (MESSAGE, ArrayType(st: StructType, containsNull)) =>
case (MESSAGE, ArrayType(st: DataType, containsNull)) =>

Contributor
It's not clear why we need to make this change. It does not look OneOf-related. Could you clarify why this is needed? Does it cover any specific case?

Contributor Author
@mposdev21 Actually, we don't need this here at all; we can add it at the top along with the other ArrayType cases.

newArrayWriter(protoType, protoPath, catalystPath, st, containsNull)

case (ENUM, StringType) =>
@@ -38,6 +38,12 @@ private[sql] class ProtobufOptions(

val parseMode: ParseMode =
parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)

val circularReferenceType: String = parameters.getOrElse("circularReferenceType", "FIELD_NAME")

@SandishKumarHN @baganokodo2022 moving the discussion here (for threading).

> Besides, can we also support a "CircularReferenceType" option with an enum value of [FIELD_NAME, FIELD_TYPE]? The reason is that navigation can go very deep before the same fully-qualified FIELD_NAME is encountered again, while FIELD_TYPE stops recursive navigation much faster. ...

I didn't quite follow the motivation here. Could you give concrete examples for the two different cases?

Contributor Author
@rangadi we already know about the field_name recursive check: using fd.getFullName we detect the recursion and throw an error. Another option is to detect recursion through the field type. Example below.

message A {
  B b = 1;
}

message B {
  A c = 1;
}

In the case of the field_name recursive check the path is A.B.C, so no recursion is found.
In the case of the field_type recursive check the path is MESSAGE.MESSAGE.MESSAGE, so recursion is found and we either throw an error or drop fields beyond the configured recursive depth.
But the field_type check will also throw an error for the case below, since the path is MESSAGE.MESSAGE.MESSAGE.MESSAGE:

message A {
  B b = 1;
}

message B {
  D d = 1;
}

message D {
  E e = 1;
}

message E {
  int32 key = 1;
}

@baganokodo2022's argument is that the field_type-based check gives users an option to cut off recursion more quickly: with a complex nested schema, a recursive field_name may only be found very deep, and we might hit an OOM before reaching it. The field_type-based check finds the circular reference sooner.

@baganokodo2022 please correct me if I'm wrong.
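
For readers following the thread, here is a minimal sketch of the two lookups being debated (not the PR code; the helper name and the surrounding traversal are assumed): the FIELD_NAME check keys the seen-count map on the field's fully-qualified name, while the FIELD_TYPE check keys it on the message type the field points to.

```scala
import com.google.protobuf.Descriptors.FieldDescriptor

// Hypothetical helper, for illustration only: pick the recursion-detection key
// for a message-typed field under the two proposed modes.
def recursionKey(fd: FieldDescriptor, circularReferenceType: String): String =
  circularReferenceType match {
    case "FIELD_NAME" => fd.getFullName                // e.g. "A.b"
    case "FIELD_TYPE" => fd.getMessageType.getFullName // e.g. "B"
  }

// The schema traversal would then count occurrences of the chosen key, e.g.:
//   val newSeen = seen + (key -> (seen.getOrElse(key, 0) + 1))
```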

@rangadi rangadi Dec 8, 2022
> in the case of field_name recursive check it is A.B.C no recursion.

The first example is clearly recursion. What is 'C' here?

> but it will also throw an error for the below case with the field_type check, since it will be MESSAGE.MESSAGE.MESSAGE.MESSAGE

Why is this recursion?

Are our unit tests showing these cases?

Contributor Author
I would have @baganokodo2022 give more details on the field-type case.

We have not yet added unit tests for the field-type case; I would like to discuss this before adding them.

> thread would be A.B.A.aa.D.d.A.aaa.E

What is this thread?

Given this discussion, let's write down the functionality and examples before we implement, so that we are all on the same page.

Contributor Author
@rangadi fd.fullName is able to detect the recursive field even with different field names. I added a unit test; now I'm confused:
"Fail for recursion field with different field names"

:) yeah, field names should not matter at all.
We can do a video chat to clarify all this.

Contributor Author
@SandishKumarHN SandishKumarHN Dec 9, 2022
@rangadi @baganokodo2022 thanks for the quick meeting. The conclusion was to use the descriptor type's full name; I have added unit tests with some complex schemas.

val recordName = fd.getMessageType.getFullName


// User can choose a circularReferenceDepth of 0, 1, or 2.
// Going beyond 3 levels of recursion is not allowed.
val circularReferenceDepth: Int = parameters.getOrElse("circularReferenceDepth", "-1").toInt

Suggestion for renaming this option: "recursive.fields.max.depth"

circularReferenceDepth sounds very much like a code variable name.

If we go with that, we could rename the variable as well to 'recursiveFieldMaxDepth' (but this is your choice).
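
For context, here is a minimal usage sketch of how a caller would pass this option through from_protobuf, mirroring the test further down in this PR; the DataFrame `df`, the descriptor file path `descFilePath`, and the implicits import are assumptions for the example.

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.protobuf.functions.from_protobuf

// Assumes `spark.implicits._` is in scope for the $"..." column syntax and that
// `df` has a binary column "value" holding serialized OneOfEventWithRecursion.
val options = Map("circularReferenceDepth" -> "1").asJava // allow one level of recursion

val parsed = df.select(
  from_protobuf($"value", "OneOfEventWithRecursion", descFilePath, options) as "sample")
```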

}

private[sql] object ProtobufOptions {
@@ -40,19 +40,26 @@ object SchemaConverters {
*
* @since 3.4.0
*/
def toSqlType(descriptor: Descriptor): SchemaType = {
toSqlTypeHelper(descriptor)
def toSqlType(
descriptor: Descriptor,
protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)): SchemaType = {
toSqlTypeHelper(descriptor, protobufOptions)
}

def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized {
def toSqlTypeHelper(
descriptor: Descriptor,
protobufOptions: ProtobufOptions): SchemaType = ScalaReflectionLock.synchronized {

Contributor
not related to this PR, but why would we lock ScalaReflectionLock here?

Yeah, I just noticed. Not sure if we need it.
@SandishKumarHN could we remove this in a follow up?

SchemaType(
StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, Set.empty)).toArray),
StructType(descriptor.getFields.asScala.flatMap(
structFieldFor(_, Map.empty, Map.empty, protobufOptions: ProtobufOptions)).toArray),
nullable = true)
}

def structFieldFor(
fd: FieldDescriptor,
existingRecordNames: Set[String]): Option[StructField] = {
existingRecordNames: Map[String, Int],

Contributor
Can we add a comment to explain what the map key and value mean here?

+1

Contributor Author
@cloud-fan added a comment.

existingRecordTypes: Map[String, Int],

@SandishKumarHN since it is going to be either FIELD_NAME or FIELD_TYPE, do we need to keep both maps?

protobufOptions: ProtobufOptions): Option[StructField] = {
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
val dataType = fd.getJavaType match {
case INT => Some(IntegerType)
@@ -81,9 +88,19 @@ object SchemaConverters {
fd.getMessageType.getFields.forEach { field =>
field.getName match {
case "key" =>
keyType = structFieldFor(field, existingRecordNames).get.dataType
keyType =
structFieldFor(
field,
existingRecordNames,
existingRecordTypes,
protobufOptions).get.dataType
case "value" =>
valueType = structFieldFor(field, existingRecordNames).get.dataType
valueType =
structFieldFor(
field,
existingRecordNames,
existingRecordTypes,
protobufOptions).get.dataType
}
}
return Option(
@@ -92,14 +109,40 @@
MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType,
nullable = false))
case MESSAGE =>
if (existingRecordNames.contains(fd.getFullName)) {
throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
// Setting the circularReferenceDepth to 0 allows the field to be recursed once, setting
// it to 1 allows it to be recursed twice, and setting it to 2 allows it to be recursed
// thrice. A circularReferenceDepth value greater than 2 is not allowed. If not
// specified, it defaults to -1, which disables recursive fields.
if (protobufOptions.circularReferenceType.equals("FIELD_TYPE")) {
if (existingRecordTypes.contains(fd.getType.name()) &&
(protobufOptions.circularReferenceDepth < 0 ||
protobufOptions.circularReferenceDepth >= 3)) {
throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
} else if (existingRecordTypes.contains(fd.getType.name()) &&

name or full name?
also what keeps track of the recursion depth?

Contributor Author
@rangadi we have two maps with incremental counters, one for the field_name-based check and one for the field_type-based check.

@baganokodo2022 baganokodo2022 Dec 8, 2022
@SandishKumarHN and @rangadi, should we error out on -1, the default value, unless users specifically override it?
0 (tolerance) -> drop all recursed fields once encountered
1 (tolerance) -> allow the same field name (type) to be entered twice
2 (tolerance) -> allow the same field name (type) to be entered 3 times

Thoughts?

In my back-ported branch,

```scala
        val recordName = circularReferenceType match {
          case CircularReferenceTypes.FIELD_NAME =>
            fd.getFullName
          case CircularReferenceTypes.FIELD_TYPE =>
            fd.getFullName().substring(0, fd.getFullName().lastIndexOf("."))
        }

        if (circularReferenceTolerance < 0 && existingRecordNames(recordName) > 0) {
          // no tolerance on circular reference
          logError(s"circular reference in protobuf schema detected [no tolerance] - ${recordName}")
          throw new IllegalStateException(
            s"circular reference in protobuf schema detected [no tolerance] - ${recordName}")
        }

        if (existingRecordNames(recordName) > (circularReferenceTolerance max 0)) {
          // stop navigation and drop the repetitive field
          logInfo(s"circular reference in protobuf schema detected [max tolerance breached] " +
            s"field dropped - ${recordName} = ${existingRecordNames(recordName)}")
          Some(NullType)
        } else {
          val newRecordNames: Map[String, Int] = existingRecordNames +
            (recordName -> (1 + existingRecordNames(recordName)))
          Option(
            fd.getMessageType.getFields.asScala
              .flatMap(structFieldFor(_, newRecordNames, protobufOptions))
              .toSeq)
            .filter(_.nonEmpty)
            .map(StructType.apply)
        }
```

(existingRecordTypes.getOrElse(fd.getType.name(), 0)
<= protobufOptions.circularReferenceDepth)) {
return Some(StructField(fd.getName, NullType, nullable = false))
}
} else {
if (existingRecordNames.contains(fd.getFullName) &&
(protobufOptions.circularReferenceDepth < 0 ||
protobufOptions.circularReferenceDepth >= 3)) {
throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
} else if (existingRecordNames.contains(fd.getFullName) &&
existingRecordNames.getOrElse(fd.getFullName, 0)
<= protobufOptions.circularReferenceDepth) {
return Some(StructField(fd.getName, NullType, nullable = false))
}
}
val newRecordNames = existingRecordNames + fd.getFullName

val newRecordNames = existingRecordNames +
(fd.getFullName -> (existingRecordNames.getOrElse(fd.getFullName, 0) + 1))
val newRecordTypes = existingRecordTypes +
(fd.getType.name() -> (existingRecordTypes.getOrElse(fd.getType.name(), 0) + 1))

Option(
fd.getMessageType.getFields.asScala
.flatMap(structFieldFor(_, newRecordNames))
.flatMap(structFieldFor(_, newRecordNames, newRecordTypes, protobufOptions))
.toSeq)
.filter(_.nonEmpty)
.map(StructType.apply)
Binary file modified connector/protobuf/src/test/resources/protobuf/functions_suite.desc
@@ -170,4 +170,40 @@ message timeStampMsg {
message durationMsg {
string key = 1;
Duration duration = 2;
}
}

message OneOfEvent {

Are you testing both OneOf and recursion in the same message? Could you split them into separate messages?

Contributor Author
@rangadi I see a lot of use cases for a "payload" OneOf field with recursive fields inside it, so I thought combining OneOf with recursion would be a good test. Will separate them.

The combined one is fine, we can keep it. Better to have simpler separate tests as well.

nice

string key = 1;
oneof payload {

What do one-of fields look like in the Spark schema? Could you give an example? I could not see the schema in the unit tests.

Contributor Author
@rangadi the "Oneof" field is of message type, Oneof will be converted to a struct type.

int32 col_1 = 2;
string col_2 = 3;
int64 col_3 = 4;
}
repeated string col_4 = 5;
}

message OneOfEventWithRecursion {
string key = 1;
oneof payload {
EventRecursiveA recursiveA = 3;
EventRecursiveB recursiveB = 6;
}
string value = 7;
}

message EventRecursiveA {
OneOfEventWithRecursion recursiveA = 1;
string key = 2;
}

message EventRecursiveB {
string key = 1;
string value = 2;
OneOfEventWithRecursion recursiveA = 3;
}

message Status {
int32 id = 1;
Timestamp trade_time = 2;
Status status = 3;
}
@@ -26,11 +26,11 @@ import com.google.protobuf.{ByteString, DynamicMessage}
import org.apache.spark.sql.{Column, QueryTest, Row}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions.{lit, struct}
import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated
import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.{EventRecursiveA, EventRecursiveB, OneOfEvent, OneOfEventWithRecursion, SimpleMessageRepeated}

Are there tests for recursive fields?

Contributor Author
@rangadi yes:
"Handle recursive fields in Protobuf schema, C->D->Array(C)" and
"Handle recursive fields in Protobuf schema, A->B->A"

Could we move that to different tests?

Contributor Author
@rangadi I didn't understand. These are already two different tests.

import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType}

class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with ProtobufTestBase
with Serializable {
@@ -417,7 +417,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
.show()
}
assert(e.getMessage.contains(
"Found recursive reference in Protobuf schema, which can not be processed by Spark:"
"Found recursive reference in Protobuf schema, which can not be processed by Spark"
))
}
}
@@ -453,7 +453,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
.show()
}
assert(e.getMessage.contains(
"Found recursive reference in Protobuf schema, which can not be processed by Spark:"
"Found recursive reference in Protobuf schema, which can not be processed by Spark"
))
}
}
@@ -693,4 +693,173 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
errorClass = "CANNOT_CONSTRUCT_PROTOBUF_DESCRIPTOR",
parameters = Map("descFilePath" -> testFileDescriptor))
}

test("Verify OneOf field between from_protobuf -> to_protobuf and struct -> from_protobuf") {
val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEvent")
val oneOfEvent = OneOfEvent.newBuilder()
.setKey("key")
.setCol1(123)
.setCol3(109202L)
.setCol2("col2value")
.addCol4("col4value").build()

val df = Seq(oneOfEvent.toByteArray).toDF("value")

checkWithFileAndClassName("OneOfEvent") {
case (name, descFilePathOpt) =>
val fromProtoDf = df.select(
from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample)
val toDf = fromProtoDf.select(
to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto)
val toFromDf = toDf.select(
from_protobuf_wrapper($"toProto", name, descFilePathOpt) as 'fromToProto)
checkAnswer(fromProtoDf, toFromDf)
val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name)
descriptor.getFields.asScala.map(f => {
assert(actualFieldNames.contains(f.getName))
})

val eventFromSpark = OneOfEvent.parseFrom(
toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
// OneOf field: the last set value (by order) will overwrite all previous ones.
assert(eventFromSpark.getCol2.equals("col2value"))
assert(eventFromSpark.getCol3 == 0)
val expectedFields = descriptor.getFields.asScala.map(f => f.getName)
eventFromSpark.getDescriptorForType.getFields.asScala.map(f => {
assert(expectedFields.contains(f.getName))
})

val jsonSchema =
"""{"type":"struct","fields":[{"name":"sample","type":{"type":"struct","fields":
|[{"name":"key","type":"string","nullable":true},{"name":"col_1","type":"integer",
|"nullable":true},{"name":"col_2","type":"string","nullable":true},{"name":"col_3",
|"type":"long","nullable":true},{"name":"col_4","type":{"type":"array",
|"elementType":"string","containsNull":false},"nullable":false}]},"nullable":true}]}
|{"type":"struct","fields":[{"name":"sample","type":{"type":"struct","fields":
|[{"name":"key","type":"string","nullable":true},{"name":"col_1","type":"integer",
|"nullable":true},{"name":"col_2","type":"string","nullable":true},{"name":"col_3",
|"type":"long","nullable":true},{"name":"col_4","type":{"type":"array",
|"elementType":"string","containsNull":false},"nullable":false}]},
|"nullable":true}]}""".stripMargin
val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType]
val data = Seq(Row(Row("key", 123, "col2value", 109202L, Seq("col4value"))))
val dataDf = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
val dataDfToProto = dataDf.select(
to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'toProto)

val eventFromSparkSchema = OneOfEvent.parseFrom(
dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
assert(eventFromSparkSchema.getCol2.isEmpty)
assert(eventFromSparkSchema.getCol3 == 109202L)
eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => {
assert(expectedFields.contains(f.getName))
})
}
}

test("Verify OneOf field with recursive fields between from_protobuf -> to_protobuf " +
"and struct -> from_protobuf") {
val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "OneOfEventWithRecursion")

val recursiveANested = EventRecursiveA.newBuilder()
.setKey("keyNested3").build()
val oneOfEventNested = OneOfEventWithRecursion.newBuilder()
.setKey("keyNested2")
.setValue("valueNested2")
.setRecursiveA(recursiveANested).build()
val recursiveA = EventRecursiveA.newBuilder().setKey("recursiveAKey")
.setRecursiveA(oneOfEventNested).build()
val recursiveB = EventRecursiveB.newBuilder()
.setKey("recursiveBKey")
.setValue("recursiveBvalue").build()
val oneOfEventWithRecursion = OneOfEventWithRecursion.newBuilder()
.setKey("key1")
.setValue("value1")
.setRecursiveB(recursiveB)
.setRecursiveA(recursiveA).build()

val df = Seq(oneOfEventWithRecursion.toByteArray).toDF("value")

val options = new java.util.HashMap[String, String]()
options.put("circularReferenceDepth", "1")

val fromProtoDf = df.select(
functions.from_protobuf($"value",
"OneOfEventWithRecursion",
testFileDesc, options) as 'sample)

val toDf = fromProtoDf.select(
functions.to_protobuf($"sample", "OneOfEventWithRecursion", testFileDesc) as 'toProto)
val toFromDf = toDf.select(
functions.from_protobuf($"toProto",
"OneOfEventWithRecursion",
testFileDesc,
options) as 'fromToProto)

checkAnswer(fromProtoDf, toFromDf)

val actualFieldNames = fromProtoDf.select("sample.*").schema.fields.toSeq.map(f => f.name)
descriptor.getFields.asScala.map(f => {
assert(actualFieldNames.contains(f.getName))
})

val eventFromSpark = OneOfEventWithRecursion.parseFrom(
toDf.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))

assert(eventFromSpark.getRecursiveA.getRecursiveA.getKey.equals("keyNested2"))
assert(eventFromSpark.getRecursiveA.getRecursiveA.getValue.equals("valueNested2"))
assert(eventFromSpark.getRecursiveA.getRecursiveA.getRecursiveA.getKey.isEmpty)

val expectedFields = descriptor.getFields.asScala.map(f => f.getName)
eventFromSpark.getDescriptorForType.getFields.asScala.map(f => {
assert(expectedFields.contains(f.getName))
})

val jsonSchema =
"""{"type":"struct","fields":[{"name":"sample","type":{"type":"struct","fields":
|[{"name":"key","type":"string","nullable":true},{"name":"recursiveA","type":
|{"type":"struct","fields":[{"name":"recursiveA","type":{"type":"struct","fields":
|[{"name":"key","type":"string","nullable":true},{"name":"recursiveA","type":"void",
|"nullable":true},{"name":"recursiveB","type":{"type":"struct","fields":[{"name":"key",
|"type":"string","nullable":true},{"name":"value","type":"string","nullable":true},
|{"name":"recursiveA","type":{"type":"struct","fields":[{"name":"key","type":"string",
|"nullable":true},{"name":"recursiveA","type":"void","nullable":true},{"name":"recursiveB",
|"type":"void","nullable":true},{"name":"value","type":"string","nullable":true}]},
|"nullable":true}]},"nullable":true},{"name":"value","type":"string","nullable":true}]},
|"nullable":true},{"name":"key","type":"string","nullable":true}]},"nullable":true},
|{"name":"recursiveB","type":{"type":"struct","fields":[{"name":"key","type":"string",
|"nullable":true},{"name":"value","type":"string","nullable":true},{"name":"recursiveA",
|"type":{"type":"struct","fields":[{"name":"key","type":"string","nullable":true},
|{"name":"recursiveA","type":{"type":"struct","fields":[{"name":"recursiveA","type":
|{"type":"struct","fields":[{"name":"key","type":"string","nullable":true},
|{"name":"recursiveA","type":"void","nullable":true},{"name":"recursiveB","type":"void",
|"nullable":true},{"name":"value","type":"string","nullable":true}]},"nullable":true},
|{"name":"key","type":"string","nullable":true}]},"nullable":true},{"name":"recursiveB",
|"type":"void","nullable":true},{"name":"value","type":"string","nullable":true}]},
|"nullable":true}]},"nullable":true},{"name":"value","type":"string","nullable":true}]},
|"nullable":true}]}""".stripMargin
val schema = DataType.fromJson(jsonSchema).asInstanceOf[StructType]
val data = Seq(
Row(
Row("key1",
Row(
Row("keyNested2", null, null, "valueNested2"),
"recursiveAKey"),
null,
"value1")
)
)
val dataDf = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
val dataDfToProto = dataDf.select(
functions.to_protobuf($"sample", "OneOfEventWithRecursion", testFileDesc) as 'toProto)

val eventFromSparkSchema = OneOfEventWithRecursion.parseFrom(
dataDfToProto.select("toProto").take(1).toSeq(0).getAs[Array[Byte]](0))
assert(eventFromSpark.getRecursiveA.getRecursiveA.getKey.equals("keyNested2"))
assert(eventFromSpark.getRecursiveA.getRecursiveA.getValue.equals("valueNested2"))
assert(eventFromSpark.getRecursiveA.getRecursiveA.getRecursiveA.getKey.isEmpty)
eventFromSparkSchema.getDescriptorForType.getFields.asScala.map(f => {
assert(expectedFields.contains(f.getName))
})
}

Could you add a test that clearly shows the expected schema, similar to my comment here: https://github.com/apache/spark/pull/38922/files#r1051292604

It is not easy to see from these tests what schema a depth setting of 0 or 2 results in.

}
2 changes: 1 addition & 1 deletion core/src/main/resources/error/error-classes.json
@@ -1016,7 +1016,7 @@
},
"RECURSIVE_PROTOBUF_SCHEMA" : {
"message" : [
"Found recursive reference in Protobuf schema, which can not be processed by Spark: <fieldDescriptor>"
"Found recursive reference in Protobuf schema, which can not be processed by Spark by default: <fieldDescriptor>. try setting the option `circularReferenceDepth` as 0 or 1 or 2. Going beyond 3 levels of recursion is not allowed."
]
},
"RENAME_SRC_PATH_NOT_FOUND" : {