Closed
Changes from 10 commits
@@ -39,7 +39,7 @@ private[protobuf] case class ProtobufDataToCatalyst(
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)

override lazy val dataType: DataType = {
- val dt = SchemaConverters.toSqlType(messageDescriptor).dataType
+ val dt = SchemaConverters.toSqlType(messageDescriptor, protobufOptions).dataType
parseMode match {
// With PermissiveMode, the output Catalyst row might contain columns of null values for
// corrupt records, even if some of the columns are not nullable in the user-provided schema.
@@ -157,6 +157,8 @@ private[sql] class ProtobufDeserializer(

case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal)

case (MESSAGE, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal)

What is this for? For handling limited recursion?

Contributor Author

yes, correct.


Could you add a comment noting that we might be dropping data here? It will not be easy for a future reader to see.
We could have an option to error out if the actual data has more recursion than the configured depth.


// TODO: we can avoid boxing if future version of Protobuf provide primitive accessors.
case (BOOLEAN, BooleanType) =>
(updater, ordinal, value) => updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
@@ -235,7 +237,7 @@ private[sql] class ProtobufDeserializer(
writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage])
updater.set(ordinal, row)

- case (MESSAGE, ArrayType(st: StructType, containsNull)) =>
+ case (MESSAGE, ArrayType(st: DataType, containsNull)) =>
Contributor

It's not clear why we need to make this change. It does not look OneOf-related. Could you clarify why this is needed? Does it cover any specific case?

Contributor Author

@mposdev21 Actually, we don't need this here at all; we can add it at the top along with the other ArrayType cases.

newArrayWriter(protoType, protoPath, catalystPath, st, containsNull)

case (ENUM, StringType) =>
@@ -38,6 +38,12 @@ private[sql] class ProtobufOptions(

val parseMode: ParseMode =
parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)

// Setting the circularReferenceDepth to 0 allows the field to be recursed once, setting
// it to 1 allows it to be recursed twice, and setting it to 2 allows it to be recursed
// thrice. A circularReferenceDepth value greater than 2 is not allowed. If not
// specified, it will default to -1, which disables recursive fields.

'-1' implies recursive fields are not allowed.
("disables" does not clearly imply that it will be an error.)


Also warn that if the Protobuf record has more depth for recursive fields than allowed here, it will be truncated to the allowed depth. This implies that some fields are discarded from the record.

Could you add a simple example in the comment showing the resulting Spark schema when this is set to '0' and '2'?
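
As a rough illustration of the truncation discussed in this thread, here is a minimal, self-contained sketch in plain Scala. It does not use Spark or Protobuf classes: `Field`, `Schema`, `Truncated`, `toSchema`, and `nesting` are hypothetical names, and the `Status` test message is reduced to two fields. It assumes the semantics visible in this diff (a per-type visit counter, an error when the depth is negative, and a NullType-like placeholder once the counter exceeds the configured depth), so the exact behaviour of the final implementation may differ.

```scala
// Hypothetical, simplified model; none of these types come from Spark or Protobuf.
sealed trait FieldType
case object Primitive extends FieldType
final case class MessageRef(typeName: String) extends FieldType
final case class Field(name: String, tpe: FieldType)

sealed trait Schema
case object Leaf extends Schema       // a primitive column
case object Truncated extends Schema  // stands in for NullType
final case class Struct(fields: List[(String, Schema)]) extends Schema

object RecursionDepthSketch {
  // message Status { int32 id = 1; Status status = 3; } (trade_time omitted for brevity)
  val messages: Map[String, List[Field]] = Map(
    "Status" -> List(Field("id", Primitive), Field("status", MessageRef("Status"))))

  // `seen` maps a message type name to how many times it was entered on the current path.
  private def fieldSchema(field: Field, maxDepth: Int, seen: Map[String, Int]): Schema =
    field.tpe match {
      case Primitive => Leaf
      case MessageRef(typeName) =>
        val count = seen.getOrElse(typeName, 0)
        if (count > 0 && maxDepth < 0) {
          // Recursion is not allowed at all.
          throw new IllegalArgumentException(s"Found recursive reference in field ${field.name}")
        } else if (count > 0 && count > maxDepth) {
          // Limit reached: any deeper data for this field would be dropped.
          Truncated
        } else {
          val updated = seen + (typeName -> (count + 1))
          Struct(messages(typeName).map(f => f.name -> fieldSchema(f, maxDepth, updated)))
        }
    }

  def toSchema(root: String, maxDepth: Int): Schema =
    Struct(messages(root).map(f => f.name -> fieldSchema(f, maxDepth, Map.empty[String, Int])))

  // Number of nested Struct levels, used only to compare results.
  def nesting(schema: Schema): Int = schema match {
    case Struct(fields) =>
      1 + fields.map { case (_, s) => nesting(s) }.foldLeft(0)((a, b) => math.max(a, b))
    case _ => 0
  }
}
```

In this model, `RecursionDepthSketch.toSchema("Status", 0)` nests the `status` field once before replacing it with the placeholder, `toSchema("Status", 2)` nests it three times, and a negative depth throws on the first recursive reference.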

val circularReferenceDepth: Int = parameters.getOrElse("circularReferenceDepth", "-1").toInt

Suggestion for renaming this option: "recursive.fields.max.depth"

circularReferenceDepth reads like a code variable name rather than a user-facing option.


If we go with that, we could rename the variable as well to 'recursiveFieldMaxDepth' (but this is your choice).

}

private[sql] object ProtobufOptions {
@@ -40,19 +40,25 @@ object SchemaConverters {
*
* @since 3.4.0
*/
- def toSqlType(descriptor: Descriptor): SchemaType = {
- toSqlTypeHelper(descriptor)
+ def toSqlType(
+ descriptor: Descriptor,
+ protobufOptions: ProtobufOptions = ProtobufOptions(Map.empty)): SchemaType = {
+ toSqlTypeHelper(descriptor, protobufOptions)
}

- def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized {
+ def toSqlTypeHelper(
+ descriptor: Descriptor,
+ protobufOptions: ProtobufOptions): SchemaType = ScalaReflectionLock.synchronized {
Contributor

not related to this PR, but why would we lock ScalaReflectionLock here?


Yeah, I just noticed. Not sure if we need it.
@SandishKumarHN could we remove this in a follow-up?

SchemaType(
- StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, Set.empty)).toArray),
+ StructType(descriptor.getFields.asScala.flatMap(
+ structFieldFor(_, Map.empty, protobufOptions: ProtobufOptions)).toArray),
nullable = true)
}

def structFieldFor(
fd: FieldDescriptor,
- existingRecordNames: Set[String]): Option[StructField] = {
+ existingRecordNames: Map[String, Int],
Contributor

Can we add comments to explain what the map key and value mean here?


+1

Contributor Author

@cloud-fan added a comment.

+ protobufOptions: ProtobufOptions): Option[StructField] = {
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
val dataType = fd.getJavaType match {
case INT => Some(IntegerType)
@@ -81,9 +87,17 @@ object SchemaConverters {
fd.getMessageType.getFields.forEach { field =>
field.getName match {
case "key" =>
- keyType = structFieldFor(field, existingRecordNames).get.dataType
+ keyType =
+ structFieldFor(
+ field,
+ existingRecordNames,
+ protobufOptions).get.dataType
case "value" =>
- valueType = structFieldFor(field, existingRecordNames).get.dataType
+ valueType =
+ structFieldFor(
+ field,
+ existingRecordNames,
+ protobufOptions).get.dataType
}
}
return Option(
@@ -92,14 +106,26 @@ object SchemaConverters {
MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType,
nullable = false))
case MESSAGE =>
- if (existingRecordNames.contains(fd.getFullName)) {
+ // Setting the circularReferenceDepth to 0 allows the field to be recursed once, setting
+ // it to 1 allows it to be recursed twice, and setting it to 2 allows it to be recursed
+ // thrice. A circularReferenceDepth value greater than 2 is not allowed. If not
+ // specified, it will default to -1, which disables recursive fields.
+ val recordName = fd.getMessageType.getFullName
Contributor

Suggested change
- val recordName = fd.getMessageType.getFullName
+ val recordName = fd.getFullName

Are they the same? The previous code uses fd.getFullName.


Good catch. I think the previous code was incorrect. We need to verify whether the same Protobuf type was seen before in this DFS traversal.
@SandishKumarHN what was the unit test that verified recursion?

Contributor Author

@cloud-fan fd.getFullName gives the fully qualified name including the field name, whereas we need the fully qualified type name. We made this decision above.

Here is the difference:

println(s"${fd.getFullName} : ${fd.getMessageType.getFullName}")

org.apache.spark.sql.protobuf.protos.Employee.ic : org.apache.spark.sql.protobuf.protos.IC
org.apache.spark.sql.protobuf.protos.IC.icManager : org.apache.spark.sql.protobuf.protos.Employee
org.apache.spark.sql.protobuf.protos.Employee.ic : org.apache.spark.sql.protobuf.protos.IC
org.apache.spark.sql.protobuf.protos.IC.icManager : org.apache.spark.sql.protobuf.protos.Employee
org.apache.spark.sql.protobuf.protos.Employee.em : org.apache.spark.sql.protobuf.protos.EM

@rangadi The previous code used fd.getFullName (the fully qualified name including the field name), which was enough to detect recursion, because before this change we simply threw an error on any recursive field.
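
To make the distinction concrete, the following plain-Scala sketch replays one traversal path as (field full name, message type full name) pairs and looks for repeats under each key. It is only an illustration: `PathEntry`, `path`, and `repeated` are hypothetical names, package prefixes are shortened to `protos`, and the EM entries are inferred from the `Employee` and `EM` messages in the test proto rather than copied from the log output.

```scala
// Hypothetical illustration; no Protobuf classes are used.
final case class PathEntry(fieldFullName: String, typeFullName: String)

object RecursionKeySketch {
  // One traversal path: Employee.ic -> IC.icManager -> Employee.em -> EM.emManager
  val path: List[PathEntry] = List(
    PathEntry("protos.Employee.ic", "protos.IC"),
    PathEntry("protos.IC.icManager", "protos.Employee"),
    PathEntry("protos.Employee.em", "protos.EM"),
    PathEntry("protos.EM.emManager", "protos.Employee"))

  // Keys that occur more than once along the path.
  def repeated(keys: List[String]): Set[String] =
    keys.groupBy(identity).collect { case (key, hits) if hits.size > 1 => key }.toSet

  // Keyed by field name: every entry on this path is distinct.
  val byFieldName: Set[String] = repeated(path.map(_.fieldFullName))

  // Keyed by message type name: Employee is entered twice.
  val byTypeName: Set[String] = repeated(path.map(_.typeFullName))
}
```

On this path the field-name key sees four distinct entries, while the type-name key shows `Employee` being re-entered through two different fields, which is the information a per-type depth counter relies on.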

+ if (existingRecordNames.contains(recordName) &&

Better to remove the 'return' statement.
How about:

val recursiveDepth = existingRecordNames.getOrElse(recordName, 0)
if (recursiveDepth == 0 || // No recursion
    (protobufOptions.circularReferenceDepth >= 0 &&
      recursiveDepth <= (protobufOptions.circularReferenceDepth + 1))) { // Recursion is within the allowed limit.
  val newRecordNames = existingRecordNames + (recordName -> (recursiveDepth + 1))
  ...
} else if (protobufOptions.circularReferenceDepth >= 0) {
  // Recursion is allowed, but we reached the limit. Truncate.
  return Some(StructField(fd.getName, NullType, ...))
} else { // Recursion is not allowed
  throw ...
}

@rangadi, Dec 15, 2022

Scratch the above suggestion.
Instead you could add 'else' to what you have and remove 'return'. That is simpler.

Contributor Author

@rangadi thanks for the review, I have made all changes you suggested.

+ protobufOptions.circularReferenceDepth < 0 ) {
throw QueryCompilationErrors.foundRecursionInProtobufSchema(fd.toString())
+ } else if (existingRecordNames.contains(recordName) &&
+ existingRecordNames.getOrElse(recordName, 0)
+ > protobufOptions.circularReferenceDepth) {
+ return Some(StructField(fd.getName, NullType, nullable = false))

Why is nullable false?

}
- val newRecordNames = existingRecordNames + fd.getFullName

+ val newRecordNames = existingRecordNames +
+ (recordName -> (existingRecordNames.getOrElse(recordName, 0) + 1))

Option(
fd.getMessageType.getFields.asScala
- .flatMap(structFieldFor(_, newRecordNames))
+ .flatMap(structFieldFor(_, newRecordNames, protobufOptions))
.toSeq)
.filter(_.nonEmpty)
.map(StructType.apply)
Binary file not shown.
116 changes: 115 additions & 1 deletion connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -170,4 +170,118 @@ message timeStampMsg {
message durationMsg {
string key = 1;
Duration duration = 2;
- }
+ }

message OneOfEvent {

Are you testing both OneOf and recursion in the same message? Could you split them into separate messages?

Contributor Author

@rangadi I see a lot of use cases with a "payload" Oneof field that contains recursive fields, so I thought combining Oneof with recursion would be a good test. I will separate them.


The combined one is fine; we could keep it. It would be better to have simpler, separate tests as well.


nice

string key = 1;
oneof payload {

What do one-of fields look like in the Spark schema? Could you give an example? I could not see the schema in the unit tests.

Contributor Author

@rangadi The "Oneof" field is of message type; a Oneof will be converted to a struct type.

int32 col_1 = 2;
string col_2 = 3;
int64 col_3 = 4;
}
repeated string col_4 = 5;
}

message EventWithRecursion {
int32 key = 1;
messageA a = 2;
}
message messageA {
EventWithRecursion a = 1;
messageB b = 2;
}
message messageB {
EventWithRecursion aa = 1;
messageC c = 2;
}
message messageC {
EventWithRecursion aaa = 1;
int32 key= 2;
}

message Employee {
string firstName = 1;
string lastName = 2;
oneof role {

Do we need so many fields for 'OneOf'? How about just 2 or 3? It will simplify testing.

IC ic = 3;
EM em = 4;
EM2 em2 = 5;
Director dir = 6;
SeniorDirector sDir = 7;
VP vp = 8;
SVP svp = 9;
CTO cto = 10;
CEO ceo = 11;
}
}

message IC {
repeated string skills = 1;
Employee icManager = 2; // EM or EM2 or Director..
}

message EM {
int64 teamsize = 1;
Employee emManager = 2; // EM2 or Director..
}

message EM2 {
int64 teamsize = 1;
Employee em2Manager = 2; // Director or Senior Director..
}

message Director {
int64 teamsize = 1;
Employee dirManager = 2; // Senior Director or VP..
}

message SeniorDirector {
int64 teamsize = 1;
Employee sdManager = 2; // VP or SVP...
}

message VP {
int64 teamsize = 1;
Employee vpManager = 2; // SVP or CTO...
}

message SVP {
int64 teamsize = 1;
Employee svpManager = 2; // CTO or CEO
}

message CTO {
int64 teamsize = 1;
Employee ctoManager = 2; // CEO
}

message CEO {
int64 teamsize = 1;
Employee ceoManager = 2; // null
}

message OneOfEventWithRecursion {
string key = 1;
oneof payload {
EventRecursiveA recursiveA = 3;
EventRecursiveB recursiveB = 6;
}
string value = 7;
}

message EventRecursiveA {
OneOfEventWithRecursion recursiveA = 1;
string key = 2;
}

message EventRecursiveB {
string key = 1;
string value = 2;
OneOfEventWithRecursion recursiveA = 3;
}

message Status {
int32 id = 1;
Timestamp trade_time = 2;
Status status = 3;
}