23 changes: 22 additions & 1 deletion connector/protobuf/pom.xml
@@ -83,7 +83,6 @@
<version>${protobuf.version}</version>
<scope>compile</scope>
</dependency>

</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -110,6 +109,28 @@
</relocations>
</configuration>
</plugin>
<plugin>
<groupId>com.github.os72</groupId>
<artifactId>protoc-jar-maven-plugin</artifactId>
<version>3.11.4</version>
<!-- Generates Java classes for tests. TODO(Raghu): Generate descriptor files too. -->
<executions>
<execution>
<phase>generate-test-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
<protocVersion>${protobuf.version}</protocVersion>
<inputDirectories>
<include>src/test/resources/protobuf</include>
</inputDirectories>
<addSources>test</addSources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
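For orientation: the plugin above runs protoc over every .proto file under src/test/resources/protobuf during the generate-test-sources phase and adds the generated Java sources to the test build. A minimal sketch of how a suite might then use a generated class (the SimpleMessage builder and its setId field are assumptions for illustration, not part of this change):

import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessage

// Build a record with the generated builder API and serialize it, e.g. to feed
// a binary column of a test DataFrame.
val bytes: Array[Byte] = SimpleMessage.newBuilder()
  .setId(1L)
  .build()
  .toByteArray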
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
@@ -25,17 +25,17 @@ import org.apache.spark.sql.types.{BinaryType, DataType}

private[protobuf] case class CatalystDataToProtobuf(
child: Expression,
descFilePath: String,
messageName: String)
messageName: String,
descFilePath: Option[String] = None)
extends UnaryExpression {

override def dataType: DataType = BinaryType

@transient private lazy val protoType =
ProtobufUtils.buildDescriptor(descFilePath, messageName)
@transient private lazy val protoDescriptor =
ProtobufUtils.buildDescriptor(messageName, descFilePathOpt = descFilePath)

@transient private lazy val serializer =
new ProtobufSerializer(child.dataType, protoType, child.nullable)
new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable)

override def nullSafeEval(input: Any): Any = {
val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage]
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType}

private[protobuf] case class ProtobufDataToCatalyst(
child: Expression,
descFilePath: String,
messageName: String,
options: Map[String, String])
descFilePath: Option[String] = None,
options: Map[String, String] = Map.empty)
extends UnaryExpression
with ExpectsInputTypes {

@@ -55,10 +55,14 @@ private[protobuf] case class ProtobufDataToCatalyst(
private lazy val protobufOptions = ProtobufOptions(options)

@transient private lazy val messageDescriptor =
ProtobufUtils.buildDescriptor(descFilePath, messageName)
ProtobufUtils.buildDescriptor(messageName, descFilePath)
// TODO: Avoid carrying the file name. Read the contents of the descriptor file only once
// at the start and reuse the buffer for subsequent runs. Otherwise, it could
// cause inconsistencies if the file contents are changed by the user later.
// Same for the write side in [[CatalystDataToProtobuf]].

@transient private lazy val fieldsNumbers =
messageDescriptor.getFields.asScala.map(f => f.getNumber)
messageDescriptor.getFields.asScala.map(f => f.getNumber).toSet

@transient private lazy val deserializer = new ProtobufDeserializer(messageDescriptor, dataType)

@@ -108,18 +112,18 @@
val binary = input.asInstanceOf[Array[Byte]]
try {
result = DynamicMessage.parseFrom(messageDescriptor, binary)
val unknownFields = result.getUnknownFields
if (!unknownFields.asMap().isEmpty) {
unknownFields.asMap().keySet().asScala.map { number =>
{
if (fieldsNumbers.contains(number)) {
return handleException(
new Throwable(s"Type mismatch encountered for field:" +
s" ${messageDescriptor.getFields.get(number)}"))
}
}
}
// If the Java class is available, it is likely more efficient to parse with it than using
// DynamicMessage. Can consider it in the future if parsing overhead is noticeable.

result.getUnknownFields.asMap().keySet().asScala.find(fieldsNumbers.contains(_)) match {
case Some(number) =>
// Unknown fields contain a field with the same number as a known field. This must be
// due to a schema mismatch between the writer and the reader here.
throw new IllegalArgumentException(s"Type mismatch encountered for field:" +
s" ${messageDescriptor.getFields.get(number)}")
case None =>
}

val deserialized = deserializer.deserialize(result)
assert(
deserialized.isDefined,
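To illustrate the unknown-field check above: when the writer's schema gives a field number a different wire type than the reader's descriptor expects, DynamicMessage.parseFrom does not fail; it parks the field in the unknown-field set. A known field number showing up there is therefore the signal for a writer/reader schema mismatch. A hedged sketch (readerDescriptor and writerBytes are hypothetical inputs):

import scala.collection.JavaConverters._
import com.google.protobuf.DynamicMessage

// The writer serialized field 2 as a string; the reader's descriptor declares
// it as an int32. Parsing succeeds, but field 2 lands among the unknown fields.
val parsed = DynamicMessage.parseFrom(readerDescriptor, writerBytes)
val knownNumbers = readerDescriptor.getFields.asScala.map(_.getNumber).toSet

// This is the condition the expression above reports as a type mismatch.
val mismatch = parsed.getUnknownFields.asMap().keySet().asScala
  .exists(n => knownNumbers.contains(n))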
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -33,20 +33,21 @@ object functions {
*
* @param data
* the binary column.
* @param descFilePath
* the protobuf descriptor in Message GeneratedMessageV3 format.
* @param messageName
* the Protobuf message name to look for in the descriptor file.
* @param descFilePath
* the path to the Protobuf descriptor file (a serialized FileDescriptorSet).
* @since 3.4.0
*/
@Experimental
def from_protobuf(
data: Column,
descFilePath: String,
messageName: String,
descFilePath: String,
options: java.util.Map[String, String]): Column = {
new Column(
ProtobufDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap))
ProtobufDataToCatalyst(data.expr, messageName, Some(descFilePath), options.asScala.toMap)
)
}

/**
@@ -57,30 +58,63 @@
*
* @param data
* the binary column.
* @param descFilePath
* the protobuf descriptor in Message GeneratedMessageV3 format.
* @param messageName
* the Protobuf message name to look for in the descriptor file.
* @param descFilePath
* the path to the Protobuf descriptor file (a serialized FileDescriptorSet).
* @since 3.4.0
*/
@Experimental
def from_protobuf(data: Column, descFilePath: String, messageName: String): Column = {
new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, Map.empty))
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageName, descFilePath = Some(descFilePath)))
// TODO: Add an option for the user to provide the descriptor file content as a buffer. This
// gives flexibility in how the content is fetched.
}
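A hedged usage sketch for the descriptor-file variant (df, its binary column "payload", the message name, and the path are hypothetical). Note the new argument order, messageName before descFilePath:

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// /tmp/messages.desc would typically be produced with:
//   protoc --include_imports --descriptor_set_out=/tmp/messages.desc simple_message.proto
val parsed = df.select(
  from_protobuf(col("payload"), "SimpleMessage", "/tmp/messages.desc").as("event"))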

/**
* Converts a binary column of Protobuf format into its corresponding Catalyst value. The
* specified schema must match the actual schema of the read data, otherwise the behavior is
* undefined: it may fail or return an arbitrary result. To deserialize the data with a
* compatible and evolved schema, the expected Protobuf schema can be set via the option
* protoSchema.
*
* @param data
* the binary column.
* @param messageClassName
* The Protobuf class name. E.g. <code>org.spark.examples.protobuf.ExampleEvent</code>.
* @since 3.4.0
*/
@Experimental
def from_protobuf(data: Column, messageClassName: String): Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageClassName))
}
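A sketch of this new class-name variant (df and its column are hypothetical; the class name reuses the example from the doc above). The compiled Protobuf Java class must be on the classpath, since the descriptor is extracted from it via reflection:

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// No descriptor file needed; the class name must include its package prefix.
val parsed = df.select(
  from_protobuf(col("payload"), "org.spark.examples.protobuf.ExampleEvent").as("event"))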

/**
* Converts a column into binary of protobuf format.
*
* @param data
* the data column.
* @param descFilePath
* the protobuf descriptor in Message GeneratedMessageV3 format.
* @param messageName
* the Protobuf message name to look for in the descriptor file.
* @param descFilePath
* the path to the Protobuf descriptor file (a serialized FileDescriptorSet).
* @since 3.4.0
*/
@Experimental
def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
new Column(CatalystDataToProtobuf(data.expr, messageName, Some(descFilePath)))
}

/**
* Converts a column into binary of protobuf format.
*
* @param data
* the data column.
* @param messageClassName
* The Protobuf class name. E.g. <code>org.spark.examples.protobuf.ExampleEvent</code>.
* @since 3.4.0
*/
@Experimental
def to_protobuf(data: Column, descFilePath: String, messageName: String): Column = {
new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName))
def to_protobuf(data: Column, messageClassName: String): Column = {
new Column(CatalystDataToProtobuf(data.expr, messageClassName))
}
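For completeness, a hedged round-trip sketch combining the two class-name variants (df and the column names are hypothetical):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

// Struct column -> Protobuf bytes -> struct column; lossless as long as the
// struct schema matches the message definition.
val roundTripped = df
  .select(to_protobuf(col("event"), "org.spark.examples.protobuf.ExampleEvent").as("payload"))
  .select(from_protobuf(col("payload"), "org.spark.examples.protobuf.ExampleEvent").as("event"))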
}
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -22,13 +22,14 @@ import java.util.Locale

import scala.collection.JavaConverters._

import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException}
import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

private[sql] object ProtobufUtils extends Logging {

@@ -132,23 +133,63 @@
}
}

def buildDescriptor(descFilePath: String, messageName: String): Descriptor = {
val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(descFilePath)
var result: Descriptors.Descriptor = null;
/**
* Builds a Protobuf message descriptor, either from a Java class or from a serialized
* descriptor read from a file.
* @param messageName
* Protobuf message name or Java class name.
* @param descFilePathOpt
* When the file path is set, the descriptor and its dependencies are read from the file.
* Otherwise `messageName` is treated as a Java class name.
* @return the message descriptor.
*/
def buildDescriptor(messageName: String, descFilePathOpt: Option[String]): Descriptor = {
descFilePathOpt match {
case Some(filePath) => buildDescriptor(descFilePath = filePath, messageName)
case None => buildDescriptorFromJavaClass(messageName)
}
}

for (descriptor <- fileDescriptor.getMessageTypes.asScala) {
if (descriptor.getName().equals(messageName)) {
result = descriptor
}
/**
* Loads the given Protobuf class and returns the Protobuf descriptor for it.
*/
def buildDescriptorFromJavaClass(protobufClassName: String): Descriptor = {
val protobufClass = try {
Utils.classForName(protobufClassName)
} catch {
case _: ClassNotFoundException =>
val hasDots = protobufClassName.contains(".")
throw new IllegalArgumentException(
s"Could not load Protobuf class with name '$protobufClassName'" +
(if (hasDots) "" else ". Ensure the class name includes the package prefix.")
)
}

if (!classOf[Message].isAssignableFrom(protobufClass)) {
throw new IllegalArgumentException(s"$protobufClassName is not a Protobuf message type")
// TODO: Need to support V2. This might work with V2 classes too.
}

// Extract the descriptor from Protobuf message.
protobufClass
.getDeclaredMethod("getDescriptor")
.invoke(null)
.asInstanceOf[Descriptor]
}

def buildDescriptor(descFilePath: String, messageName: String): Descriptor = {
val descriptor = parseFileDescriptor(descFilePath).getMessageTypes.asScala.find { desc =>
desc.getName == messageName || desc.getFullName == messageName
}

if (null == result) {
throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor");
descriptor match {
case Some(d) => d
case None =>
throw new RuntimeException(s"Unable to locate Message '$messageName' in Descriptor")
}
result
}
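A short sketch of the file-based path above (the paths are hypothetical); note that the lookup now matches either the simple or the full message name:

// The descriptor file is typically generated with:
//   protoc --include_imports --descriptor_set_out=/tmp/messages.desc simple_message.proto
val descriptor = ProtobufUtils.buildDescriptor("SimpleMessage", Some("/tmp/messages.desc"))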

def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = {
private def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = {
var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
try {
val dscFile = new BufferedInputStream(new FileInputStream(descFilePath))
connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -66,6 +66,10 @@ object SchemaConverters {
Some(DayTimeIntervalType.defaultConcreteType)
case MESSAGE if fd.getMessageType.getName == "Timestamp" =>
Some(TimestampType)
// FIXME: Is the above accurate? Users can define their own messages named "Timestamp" that
// are not expected to map to TimestampType in Spark. How about verifying the fields?
// Same for "Duration". Only the Timestamp & Duration protos defined in the
// google.protobuf package should default to the corresponding Catalyst types.
case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry =>
var keyType: DataType = NullType
var valueType: DataType = NullType
connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
@@ -19,9 +19,11 @@

syntax = "proto3";

package org.apache.spark.sql.protobuf;
package org.apache.spark.sql.protobuf.protos;
option java_outer_classname = "CatalystTypes";

// TODO: import one or more protobuf files.

message BooleanMsg {
bool bool_type = 1;
}
connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -20,7 +20,7 @@

syntax = "proto3";

package org.apache.spark.sql.protobuf;
package org.apache.spark.sql.protobuf.protos;

option java_outer_classname = "SimpleMessageProtos";

@@ -119,7 +119,7 @@ message SimpleMessageEnum {
string key = 1;
string value = 2;
enum NestedEnum {
ESTED_NOTHING = 0;
ESTED_NOTHING = 0; // TODO: Fix the name.
NESTED_FIRST = 1;
NESTED_SECOND = 2;
}
connector/protobuf/src/test/resources/protobuf/serde_suite.proto
@@ -20,11 +20,11 @@

syntax = "proto3";

package org.apache.spark.sql.protobuf;
option java_outer_classname = "SimpleMessageProtos";
package org.apache.spark.sql.protobuf.protos;
option java_outer_classname = "SerdeSuiteProtos";

/* Clean Message */
message BasicMessage {
message SerdeBasicMessage {
Foo foo = 1;
}
