diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml
index 0515f128b8d63..b934c7f831a2b 100644
--- a/connector/protobuf/pom.xml
+++ b/connector/protobuf/pom.xml
@@ -83,7 +83,6 @@
        <version>${protobuf.version}</version>
        <scope>compile</scope>
-
    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
@@ -110,6 +109,28 @@
+      <plugin>
+        <groupId>com.github.os72</groupId>
+        <artifactId>protoc-jar-maven-plugin</artifactId>
+        <version>3.11.4</version>
+        <executions>
+          <execution>
+            <phase>generate-test-sources</phase>
+            <goals>
+              <goal>run</goal>
+            </goals>
+            <configuration>
+              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
+              <protocVersion>${protobuf.version}</protocVersion>
+              <inputDirectories>
+                <include>src/test/resources/protobuf</include>
+              </inputDirectories>
+              <addSources>test</addSources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
index 145100268c232..b9f7907ea8ca6 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala
@@ -25,17 +25,17 @@ import org.apache.spark.sql.types.{BinaryType, DataType}
private[protobuf] case class CatalystDataToProtobuf(
child: Expression,
- descFilePath: String,
- messageName: String)
+ messageName: String,
+ descFilePath: Option[String] = None)
extends UnaryExpression {
override def dataType: DataType = BinaryType
- @transient private lazy val protoType =
- ProtobufUtils.buildDescriptor(descFilePath, messageName)
+ @transient private lazy val protoDescriptor =
+ ProtobufUtils.buildDescriptor(messageName, descFilePathOpt = descFilePath)
@transient private lazy val serializer =
- new ProtobufSerializer(child.dataType, protoType, child.nullable)
+ new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable)
override def nullSafeEval(input: Any): Any = {
val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage]
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
index f08f876799723..cad2442f10c17 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, Struc
private[protobuf] case class ProtobufDataToCatalyst(
child: Expression,
- descFilePath: String,
messageName: String,
- options: Map[String, String])
+ descFilePath: Option[String] = None,
+ options: Map[String, String] = Map.empty)
extends UnaryExpression
with ExpectsInputTypes {
@@ -55,10 +55,14 @@ private[protobuf] case class ProtobufDataToCatalyst(
private lazy val protobufOptions = ProtobufOptions(options)
@transient private lazy val messageDescriptor =
- ProtobufUtils.buildDescriptor(descFilePath, messageName)
+ ProtobufUtils.buildDescriptor(messageName, descFilePath)
+ // TODO: Avoid carrying the file name. Read the contents of the descriptor file only once
+ // at the start; the rest of the runs should reuse the buffer. Otherwise, it could cause
+ // inconsistencies if the user changes the file contents later.
+ // Same for the write side in [[CatalystDataToProtobuf]].
@transient private lazy val fieldsNumbers =
- messageDescriptor.getFields.asScala.map(f => f.getNumber)
+ messageDescriptor.getFields.asScala.map(f => f.getNumber).toSet
@transient private lazy val deserializer = new ProtobufDeserializer(messageDescriptor, dataType)
@@ -108,18 +112,18 @@ private[protobuf] case class ProtobufDataToCatalyst(
val binary = input.asInstanceOf[Array[Byte]]
try {
result = DynamicMessage.parseFrom(messageDescriptor, binary)
- val unknownFields = result.getUnknownFields
- if (!unknownFields.asMap().isEmpty) {
- unknownFields.asMap().keySet().asScala.map { number =>
- {
- if (fieldsNumbers.contains(number)) {
- return handleException(
- new Throwable(s"Type mismatch encountered for field:" +
- s" ${messageDescriptor.getFields.get(number)}"))
- }
- }
- }
+ // If the Java class is available, it is likely more efficient to parse with it than using
+ // DynamicMessage. Can consider it in the future if parsing overhead is noticeable.
+
+ result.getUnknownFields.asMap().keySet().asScala.find(fieldsNumbers.contains(_)) match {
+ case Some(number) =>
+ // The unknown fields contain a field with the same number as a known field. This must
+ // be due to a schema mismatch between the writer and the reader.
+ throw new IllegalArgumentException(s"Type mismatch encountered for field:" +
+ s" ${messageDescriptor.getFields.get(number)}")
+ case None =>
}
+
val deserialized = deserializer.deserialize(result)
assert(
deserialized.isDefined,
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
index 283d1ca8c412c..af30de40dad04 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -33,20 +33,21 @@ object functions {
*
* @param data
* the binary column.
- * @param descFilePath
- * the protobuf descriptor in Message GeneratedMessageV3 format.
* @param messageName
* the protobuf message name to look for in descriptorFile.
+ * @param descFilePath
+ * the path to the Protobuf descriptor file.
* @since 3.4.0
*/
@Experimental
def from_protobuf(
data: Column,
- descFilePath: String,
messageName: String,
+ descFilePath: String,
options: java.util.Map[String, String]): Column = {
new Column(
- ProtobufDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap))
+ ProtobufDataToCatalyst(data.expr, messageName, Some(descFilePath), options.asScala.toMap)
+ )
}
/**
@@ -57,15 +58,34 @@ object functions {
*
* @param data
* the binary column.
- * @param descFilePath
- * the protobuf descriptor in Message GeneratedMessageV3 format.
* @param messageName
* the protobuf MessageName to look for in descriptorFile.
+ * @param descFilePath
+ * the path to the Protobuf descriptor file.
* @since 3.4.0
*/
@Experimental
- def from_protobuf(data: Column, descFilePath: String, messageName: String): Column = {
- new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, Map.empty))
+ def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
+ new Column(ProtobufDataToCatalyst(data.expr, messageName, descFilePath = Some(descFilePath)))
+ // TODO: Add an option for user to provide descriptor file content as a buffer. This
+ // gives flexibility in how the content is fetched.
+ }
+
+ /**
+ * Converts a binary column of Protobuf format into its corresponding Catalyst value. The
+ * specified schema must match the actual schema of the read data, otherwise the behavior is
+ * undefined: it may fail or return an arbitrary result. To deserialize the data with a
+ * compatible and evolved schema, the expected Protobuf schema can be set via the option
+ * protoSchema.
+ *
+ * @param data
+ * the binary column.
+ * @param messageClassName
+ * The Protobuf class name. E.g. org.spark.examples.protobuf.ExampleEvent.
+ * @since 3.4.0
+ */
+ @Experimental
+ def from_protobuf(data: Column, messageClassName: String): Column = {
+ new Column(ProtobufDataToCatalyst(data.expr, messageClassName))
}
/**
@@ -73,14 +93,28 @@ object functions {
*
* @param data
* the data column.
- * @param descFilePath
- * the protobuf descriptor in Message GeneratedMessageV3 format.
* @param messageName
* the protobuf MessageName to look for in descriptorFile.
+ * @param descFilePath
+ * the path to the Protobuf descriptor file.
+ * @since 3.4.0
+ */
+ @Experimental
+ def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
+ new Column(CatalystDataToProtobuf(data.expr, messageName, Some(descFilePath)))
+ }
+
+ /**
+ * Converts a column into binary of protobuf format.
+ *
+ * @param data
+ * the data column.
+ * @param messageClassName
+ * The Protobuf class name. E.g. org.spark.examples.protobuf.ExampleEvent.
* @since 3.4.0
*/
@Experimental
- def to_protobuf(data: Column, descFilePath: String, messageName: String): Column = {
- new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName))
+ def to_protobuf(data: Column, messageClassName: String): Column = {
+ new Column(CatalystDataToProtobuf(data.expr, messageClassName))
}
}
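
For context, a minimal usage sketch of the two from_protobuf/to_protobuf variants added above. The DataFrame `df` with a binary `value` column, the descriptor path, and the generated class name are hypothetical placeholders, and a SparkSession named `spark` is assumed:

  import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}
  import spark.implicits._

  // Variant 1: look up "SimpleMessage" inside a compiled descriptor file.
  val parsed = df.select(
    from_protobuf($"value", "SimpleMessage", "/tmp/functions_suite.desc").as("event"))

  // Variant 2: use a generated Protobuf Java class that is already on the classpath;
  // no descriptor file is needed.
  val parsedFromClass = df.select(
    from_protobuf($"value", "org.example.protos.SimpleMessageProtos$SimpleMessage").as("event"))

  // Either variant round-trips back to binary with to_protobuf.
  val reserialized = parsed.select(
    to_protobuf($"event", "SimpleMessage", "/tmp/functions_suite.desc").as("value"))
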
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
index 5ad043142a2d2..fa2ec9b7cd462 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala
@@ -22,13 +22,14 @@ import java.util.Locale
import scala.collection.JavaConverters._
-import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException}
+import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
private[sql] object ProtobufUtils extends Logging {
@@ -132,23 +133,63 @@ private[sql] object ProtobufUtils extends Logging {
}
}
- def buildDescriptor(descFilePath: String, messageName: String): Descriptor = {
- val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(descFilePath)
- var result: Descriptors.Descriptor = null;
+ /**
+ * Builds a Protobuf message descriptor either from the Java class or from the serialized
+ * descriptor read from the file.
+ * @param messageName
+ * Protobuf message name or Java class name.
+ * @param descFilePathOpt
+ * When the file path is set, the descriptor and its dependencies are read from the file.
+ * Otherwise, `messageName` is treated as the Java class name.
+ * @return the Protobuf message descriptor.
+ */
+ def buildDescriptor(messageName: String, descFilePathOpt: Option[String]): Descriptor = {
+ descFilePathOpt match {
+ case Some(filePath) => buildDescriptor(descFilePath = filePath, messageName)
+ case None => buildDescriptorFromJavaClass(messageName)
+ }
+ }
- for (descriptor <- fileDescriptor.getMessageTypes.asScala) {
- if (descriptor.getName().equals(messageName)) {
- result = descriptor
- }
+ /**
+ * Loads the given Protobuf class and returns the Protobuf descriptor for it.
+ */
+ def buildDescriptorFromJavaClass(protobufClassName: String): Descriptor = {
+ val protobufClass = try {
+ Utils.classForName(protobufClassName)
+ } catch {
+ case _: ClassNotFoundException =>
+ val hasDots = protobufClassName.contains(".")
+ throw new IllegalArgumentException(
+ s"Could not load Protobuf class with name '$protobufClassName'" +
+ (if (hasDots) "" else ". Ensure the class name includes package prefix.")
+ )
+ }
+
+ if (!classOf[Message].isAssignableFrom(protobufClass)) {
+ throw new IllegalArgumentException(s"$protobufClassName is not a Protobuf message type")
+ // TODO: Need to support V2. This might work with V2 classes too.
+ }
+
+ // Extract the descriptor from Protobuf message.
+ protobufClass
+ .getDeclaredMethod("getDescriptor")
+ .invoke(null)
+ .asInstanceOf[Descriptor]
+ }
+
+ def buildDescriptor(descFilePath: String, messageName: String): Descriptor = {
+ val descriptor = parseFileDescriptor(descFilePath).getMessageTypes.asScala.find { desc =>
+ desc.getName == messageName || desc.getFullName == messageName
}
- if (null == result) {
- throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor");
+ descriptor match {
+ case Some(d) => d
+ case None =>
+ throw new RuntimeException(s"Unable to locate Message '$messageName' in Descriptor")
}
- result
}
- def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = {
+ private def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = {
var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null
try {
val dscFile = new BufferedInputStream(new FileInputStream(descFilePath))
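
The descriptor file referenced throughout (`descFilePath`) is a serialized FileDescriptorSet, typically produced with `protoc --descriptor_set_out`. A minimal standalone sketch of reading one, along the lines of `parseFileDescriptor` above (the file path is a placeholder):

  import java.io.{BufferedInputStream, FileInputStream}
  import com.google.protobuf.DescriptorProtos

  val in = new BufferedInputStream(new FileInputStream("/tmp/functions_suite.desc"))
  val fileDescriptorSet =
    try DescriptorProtos.FileDescriptorSet.parseFrom(in)
    finally in.close()
  // Each FileDescriptorProto in the set can then be built into a Descriptors.FileDescriptor
  // (resolving its dependencies) and searched for the requested message by name or full name.
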
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
index e385b816abe70..4fca06fb5d8ba 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala
@@ -66,6 +66,10 @@ object SchemaConverters {
Some(DayTimeIntervalType.defaultConcreteType)
case MESSAGE if fd.getMessageType.getName == "Timestamp" =>
Some(TimestampType)
+ // FIXME: Is the above accurate? Users can have protos named "Timestamp" that are not
+ // expected to map to TimestampType in Spark. How about verifying the fields?
+ // Same for "Duration". Only the Timestamp & Duration protos defined in the
+ // google.protobuf package should default to the corresponding Catalyst types.
case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry =>
var keyType: DataType = NullType
var valueType: DataType = NullType
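
A standalone sketch of the stricter check the FIXME above suggests: key the mapping on the full proto name so that only the well-known google.protobuf types map to Catalyst types. This is only an illustration, not part of the change; the helper name is hypothetical and it assumes `fd` is a MESSAGE-typed field descriptor:

  import com.google.protobuf.Descriptors.FieldDescriptor
  import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, TimestampType}

  // Returns a Catalyst type only for the well-known google.protobuf messages.
  def wellKnownCatalystType(fd: FieldDescriptor): Option[DataType] =
    fd.getMessageType.getFullName match {
      case "google.protobuf.Timestamp" => Some(TimestampType)
      // DayTimeIntervalType.DEFAULT corresponds to the defaultConcreteType used above.
      case "google.protobuf.Duration" => Some(DayTimeIntervalType.DEFAULT)
      case _ => None
    }
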
diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
index 54e6bc18df153..1deb193438c20 100644
--- a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
+++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
@@ -19,9 +19,11 @@
syntax = "proto3";
-package org.apache.spark.sql.protobuf;
+package org.apache.spark.sql.protobuf.protos;
option java_outer_classname = "CatalystTypes";
+// TODO: import one or more protobuf files.
+
message BooleanMsg {
bool bool_type = 1;
}
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
index f38c041b799ec..60f8c26214153 100644
--- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
@@ -20,7 +20,7 @@
syntax = "proto3";
-package org.apache.spark.sql.protobuf;
+package org.apache.spark.sql.protobuf.protos;
option java_outer_classname = "SimpleMessageProtos";
@@ -119,7 +119,7 @@ message SimpleMessageEnum {
string key = 1;
string value = 2;
enum NestedEnum {
- ESTED_NOTHING = 0;
+ ESTED_NOTHING = 0; // TODO: Fix the name.
NESTED_FIRST = 1;
NESTED_SECOND = 2;
}
diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
index 1e3065259aa02..a7459213a87b2 100644
--- a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
+++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
@@ -20,11 +20,11 @@
syntax = "proto3";
-package org.apache.spark.sql.protobuf;
-option java_outer_classname = "SimpleMessageProtos";
+package org.apache.spark.sql.protobuf.protos;
+option java_outer_classname = "SerdeSuiteProtos";
/* Clean Message*/
-message BasicMessage {
+message SerdeBasicMessage {
Foo foo = 1;
}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index b730ebb4fea80..19774a2ad07e4 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters}
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
+import org.apache.spark.sql.protobuf.protos.CatalystTypes.BytesMsg
import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
import org.apache.spark.sql.sources.{EqualTo, Not}
import org.apache.spark.sql.test.SharedSparkSession
@@ -35,18 +36,32 @@ class ProtobufCatalystDataConversionSuite
with SharedSparkSession
with ExpressionEvalHelper {
- private def checkResult(
+ private val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$"
+
+ private def checkResultWithEval(
data: Literal,
descFilePath: String,
messageName: String,
expected: Any): Unit = {
- checkEvaluation(
- ProtobufDataToCatalyst(
- CatalystDataToProtobuf(data, descFilePath, messageName),
- descFilePath,
- messageName,
- Map.empty),
- prepareExpectedResult(expected))
+
+ withClue("(Eval check with Java class name)") {
+ val className = s"$javaClassNamePrefix$messageName"
+ checkEvaluation(
+ ProtobufDataToCatalyst(
+ CatalystDataToProtobuf(data, className),
+ className,
+ descFilePath = None),
+ prepareExpectedResult(expected))
+ }
+ withClue("(Eval check with descriptor file)") {
+ checkEvaluation(
+ ProtobufDataToCatalyst(
+ CatalystDataToProtobuf(data, messageName, Some(descFilePath)),
+ messageName,
+ descFilePath = Some(descFilePath)),
+ prepareExpectedResult(expected))
+ }
}
protected def checkUnsupportedRead(
@@ -55,10 +70,11 @@ class ProtobufCatalystDataConversionSuite
actualSchema: String,
badSchema: String): Unit = {
- val binary = CatalystDataToProtobuf(data, descFilePath, actualSchema)
+ val binary = CatalystDataToProtobuf(data, actualSchema, Some(descFilePath))
intercept[Exception] {
- ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "FAILFAST")).eval()
+ ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" -> "FAILFAST"))
+ .eval()
}
val expected = {
@@ -73,7 +89,7 @@ class ProtobufCatalystDataConversionSuite
}
checkEvaluation(
- ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "PERMISSIVE")),
+ ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" -> "PERMISSIVE")),
expected)
}
@@ -99,26 +115,32 @@ class ProtobufCatalystDataConversionSuite
StructType(StructField("bytes_type", BinaryType, nullable = true) :: Nil),
StructType(StructField("string_type", StringType, nullable = true) :: Nil))
- private val catalystTypesToProtoMessages: Map[DataType, String] = Map(
- IntegerType -> "IntegerMsg",
- DoubleType -> "DoubleMsg",
- FloatType -> "FloatMsg",
- BinaryType -> "BytesMsg",
- StringType -> "StringMsg")
+ private val catalystTypesToProtoMessages: Map[DataType, (String, Any)] = Map(
+ IntegerType -> ("IntegerMsg", 0),
+ DoubleType -> ("DoubleMsg", 0.0d),
+ FloatType -> ("FloatMsg", 0.0f),
+ BinaryType -> ("BytesMsg", ByteString.empty().toByteArray),
+ StringType -> ("StringMsg", ""))
testingTypes.foreach { dt =>
val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1)
- val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
test(s"single $dt with seed $seed") {
+
+ val (messageName, defaultValue) = catalystTypesToProtoMessages(dt.fields(0).dataType)
+
val rand = new scala.util.Random(seed)
- val data = RandomDataGenerator.forType(dt, rand = rand).get.apply()
+ val generator = RandomDataGenerator.forType(dt, rand = rand).get
+ var data = generator()
+ while (data.asInstanceOf[Row].get(0) == defaultValue) // Do not use default values, since
+ data = generator() // from_protobuf() returns null in v3.
+
val converter = CatalystTypeConverters.createToCatalystConverter(dt)
val input = Literal.create(converter(data), dt)
- checkResult(
+ checkResultWithEval(
input,
- filePath,
- catalystTypesToProtoMessages(dt.fields(0).dataType),
+ testFileDesc,
+ messageName,
input.eval())
}
}
@@ -137,6 +159,15 @@ class ProtobufCatalystDataConversionSuite
val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray)
val deserialized = deserializer.deserialize(dynMsg)
+
+ // Verify that the Java-class-based deserializer matches the descriptor-file-based deserializer.
+ val javaDescriptor = ProtobufUtils
+ .buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName")
+ assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType)
+ val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType, filters)
+ .deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray))
+ assert(deserialized == javaDeserialized)
+
expected match {
case None => assert(deserialized.isEmpty)
case Some(d) =>
@@ -145,7 +176,6 @@ class ProtobufCatalystDataConversionSuite
}
test("Handle unsupported input of message type") {
- val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
val actualSchema = StructType(
Seq(
StructField("col_0", StringType, nullable = false),
@@ -165,7 +195,6 @@ class ProtobufCatalystDataConversionSuite
test("filter push-down to Protobuf deserializer") {
- val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
val sqlSchema = new StructType()
.add("name", "string")
.add("age", "int")
@@ -196,17 +225,23 @@ class ProtobufCatalystDataConversionSuite
test("ProtobufDeserializer with binary type") {
- val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53))
- val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
-
- val dynamicMessage = DynamicMessage
- .newBuilder(descriptor)
- .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb))
+ val bytesProto = BytesMsg
+ .newBuilder()
+ .setBytesType(ByteString.copyFrom(bb))
.build()
val expected = InternalRow(Array[Byte](97, 48, 53))
- checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected))
+ checkDeserialization(testFileDesc, "BytesMsg", bytesProto, Some(expected))
+ }
+
+ test("Full names for message using descriptor file") {
+ val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
+ assert(withShortName.findFieldByName("bytes_type") != null)
+
+ val withFullName = ProtobufUtils.buildDescriptor(
+ testFileDesc, "org.apache.spark.sql.protobuf.protos.BytesMsg")
+ assert(withFullName.findFieldByName("bytes_type") != null)
}
}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 4e9bc1c1c287a..72280fb0d9e2d 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -23,8 +23,10 @@ import scala.collection.JavaConverters._
import com.google.protobuf.{ByteString, DynamicMessage}
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{Column, QueryTest, Row}
import org.apache.spark.sql.functions.{lit, struct}
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated
+import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.test.SharedSparkSession
@@ -35,6 +37,39 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
import testImplicits._
val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/")
+ private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
+
+ /**
+ * Runs the given closure twice: once with the descriptor file and once with the Java class name.
+ */
+ private def checkWithFileAndClassName(messageName: String)(
+ fn: (String, Option[String]) => Unit): Unit = {
+ withClue("(With descriptor file)") {
+ fn(messageName, Some(testFileDesc))
+ }
+ withClue("(With Java class name)") {
+ fn(s"$javaClassNamePrefix$messageName", None)
+ }
+ }
+
+ // A wrapper to invoke the right variant of from_protobuf() depending on the arguments.
+ private def from_protobuf_wrapper(
+ col: Column, messageName: String, descFilePathOpt: Option[String]): Column = {
+ descFilePathOpt match {
+ case Some(descFilePath) => functions.from_protobuf(col, messageName, descFilePath)
+ case None => functions.from_protobuf(col, messageName)
+ }
+ }
+
+ // A wrapper to invoke the right variant of to_protobuf() depending on the arguments.
+ private def to_protobuf_wrapper(
+ col: Column, messageName: String, descFilePathOpt: Option[String]): Column = {
+ descFilePathOpt match {
+ case Some(descFilePath) => functions.to_protobuf(col, messageName, descFilePath)
+ case None => functions.to_protobuf(col, messageName)
+ }
+ }
+
test("roundtrip in to_protobuf and from_protobuf - struct") {
val df = spark
@@ -56,44 +91,45 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"),
lit(true).as("bool_value"),
lit("0".getBytes).as("bytes_value")).as("SimpleMessage"))
- val protoStructDF = df.select(
- functions.to_protobuf($"SimpleMessage", testFileDesc, "SimpleMessage").as("proto"))
- val actualDf = protoStructDF.select(
- functions.from_protobuf($"proto", testFileDesc, "SimpleMessage").as("proto.*"))
- checkAnswer(actualDf, df)
+
+ checkWithFileAndClassName("SimpleMessage") {
+ case (name, descFilePathOpt) =>
+ val protoStructDF = df.select(
+ to_protobuf_wrapper($"SimpleMessage", name, descFilePathOpt).as("proto"))
+ val actualDf = protoStructDF.select(
+ from_protobuf_wrapper($"proto", name, descFilePathOpt).as("proto.*"))
+ checkAnswer(actualDf, df)
+ }
}
test("roundtrip in from_protobuf and to_protobuf - Repeated") {
- val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated")
- val dynamicMessage = DynamicMessage
- .newBuilder(descriptor)
- .setField(descriptor.findFieldByName("key"), "key")
- .setField(descriptor.findFieldByName("value"), "value")
- .addRepeatedField(descriptor.findFieldByName("rbool_value"), false)
- .addRepeatedField(descriptor.findFieldByName("rbool_value"), true)
- .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654d)
- .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654d)
- .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f)
- .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f)
- .addRepeatedField(
- descriptor.findFieldByName("rnested_enum"),
- descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING"))
- .addRepeatedField(
- descriptor.findFieldByName("rnested_enum"),
- descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST"))
+ val protoMessage = SimpleMessageRepeated
+ .newBuilder()
+ .setKey("key")
+ .setValue("value")
+ .addRboolValue(false)
+ .addRboolValue(true)
+ .addRdoubleValue(1092092.654d)
+ .addRdoubleValue(1092093.654d)
+ .addRfloatValue(10903.0f)
+ .addRfloatValue(10902.0f)
+ .addRnestedEnum(NestedEnum.ESTED_NOTHING)
+ .addRnestedEnum(NestedEnum.NESTED_FIRST)
.build()
- val df = Seq(dynamicMessage.toByteArray).toDF("value")
- val fromProtoDF = df.select(
- functions.from_protobuf($"value", testFileDesc, "SimpleMessageRepeated").as("value_from"))
- val toProtoDF = fromProtoDF.select(
- functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageRepeated").as("value_to"))
- val toFromProtoDF = toProtoDF.select(
- functions
- .from_protobuf($"value_to", testFileDesc, "SimpleMessageRepeated")
- .as("value_to_from"))
- checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ val df = Seq(protoMessage.toByteArray).toDF("value")
+
+ checkWithFileAndClassName("SimpleMessageRepeated") {
+ case (name, descFilePathOpt) =>
+ val fromProtoDF = df.select(
+ from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
}
test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") {
@@ -120,13 +156,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
.build()
val df = Seq(dynamicMessage.toByteArray).toDF("value")
- val fromProtoDF = df.select(
- functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from"))
- val toProtoDF = fromProtoDF.select(
- functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to"))
- val toFromProtoDF = toProtoDF.select(
- functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from"))
- checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+
+ checkWithFileAndClassName("RepeatedMessage") {
+ case (name, descFilePathOpt) =>
+ val fromProtoDF = df.select(
+ from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
}
test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") {
@@ -167,13 +207,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
.build()
val df = Seq(dynamicMessage.toByteArray).toDF("value")
- val fromProtoDF = df.select(
- functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from"))
- val toProtoDF = fromProtoDF.select(
- functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to"))
- val toFromProtoDF = toProtoDF.select(
- functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from"))
- checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+
+ checkWithFileAndClassName("RepeatedMessage") {
+ case (name, descFilePathOpt) =>
+ val fromProtoDF = df.select(
+ from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
}
test("roundtrip in from_protobuf and to_protobuf - Map") {
@@ -257,13 +301,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
.build()
val df = Seq(dynamicMessage.toByteArray).toDF("value")
- val fromProtoDF = df.select(
- functions.from_protobuf($"value", testFileDesc, "SimpleMessageMap").as("value_from"))
- val toProtoDF = fromProtoDF.select(
- functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageMap").as("value_to"))
- val toFromProtoDF = toProtoDF.select(
- functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageMap").as("value_to_from"))
- checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+
+ checkWithFileAndClassName("SimpleMessageMap") {
+ case (name, descFilePathOpt) =>
+ val fromProtoDF = df.select(
+ from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
}
test("roundtrip in from_protobuf and to_protobuf - Enum") {
@@ -289,13 +337,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
.build()
val df = Seq(dynamicMessage.toByteArray).toDF("value")
- val fromProtoDF = df.select(
- functions.from_protobuf($"value", testFileDesc, "SimpleMessageEnum").as("value_from"))
- val toProtoDF = fromProtoDF.select(
- functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageEnum").as("value_to"))
- val toFromProtoDF = toProtoDF.select(
- functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageEnum").as("value_to_from"))
- checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+
+ checkWithFileAndClassName("SimpleMessageEnum") {
+ case (name, descFilePathOpt) =>
+ val fromProtoDF = df.select(
+ from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
}
test("roundtrip in from_protobuf and to_protobuf - Multiple Message") {
@@ -320,13 +372,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
.build()
val df = Seq(dynamicMessage.toByteArray).toDF("value")
- val fromProtoDF = df.select(
- functions.from_protobuf($"value", testFileDesc, "MultipleExample").as("value_from"))
- val toProtoDF = fromProtoDF.select(
- functions.to_protobuf($"value_from", testFileDesc, "MultipleExample").as("value_to"))
- val toFromProtoDF = toProtoDF.select(
- functions.from_protobuf($"value_to", testFileDesc, "MultipleExample").as("value_to_from"))
- checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+
+ checkWithFileAndClassName("MultipleExample") {
+ case (name, descFilePathOpt) =>
+ val fromProtoDF = df.select(
+ from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from"))
+ val toProtoDF = fromProtoDF.select(
+ to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ }
}
test("Handle recursive fields in Protobuf schema, A->B->A") {
@@ -352,15 +408,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
val df = Seq(messageB.toByteArray).toDF("messageB")
- val e = intercept[IncompatibleSchemaException] {
- df.select(
- functions.from_protobuf($"messageB", testFileDesc, "recursiveB").as("messageFromProto"))
- .show()
+ checkWithFileAndClassName("recursiveB") {
+ case (name, descFilePathOpt) =>
+ val e = intercept[IncompatibleSchemaException] {
+ df.select(
+ from_protobuf_wrapper($"messageB", name, descFilePathOpt).as("messageFromProto"))
+ .show()
+ }
+ assert(e.getMessage.contains(
+ "Found recursive reference in Protobuf schema, which can not be processed by Spark:"
+ ))
}
- val expectedMessage = s"""
- |Found recursive reference in Protobuf schema, which can not be processed by Spark:
- |org.apache.spark.sql.protobuf.recursiveB.messageA""".stripMargin
- assert(e.getMessage == expectedMessage)
}
test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
@@ -386,16 +444,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
val df = Seq(messageD.toByteArray).toDF("messageD")
- val e = intercept[IncompatibleSchemaException] {
- df.select(
- functions.from_protobuf($"messageD", testFileDesc, "recursiveD").as("messageFromProto"))
- .show()
+ checkWithFileAndClassName("recursiveD") {
+ case (name, descFilePathOpt) =>
+ val e = intercept[IncompatibleSchemaException] {
+ df.select(
+ from_protobuf_wrapper($"messageD", name, descFilePathOpt).as("messageFromProto"))
+ .show()
+ }
+ assert(e.getMessage.contains(
+ "Found recursive reference in Protobuf schema, which can not be processed by Spark:"
+ ))
}
- val expectedMessage =
- s"""
- |Found recursive reference in Protobuf schema, which can not be processed by Spark:
- |org.apache.spark.sql.protobuf.recursiveD.messageC""".stripMargin
- assert(e.getMessage == expectedMessage)
}
test("Handle extra fields : oldProducer -> newConsumer") {
@@ -411,17 +470,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
val df = Seq(oldProducerMessage.toByteArray).toDF("oldProducerData")
val fromProtoDf = df.select(
functions
- .from_protobuf($"oldProducerData", testFileDesc, "newConsumer")
+ .from_protobuf($"oldProducerData", "newConsumer", testFileDesc)
.as("fromProto"))
val toProtoDf = fromProtoDf.select(
functions
- .to_protobuf($"fromProto", testFileDesc, "newConsumer")
+ .to_protobuf($"fromProto", "newConsumer", testFileDesc)
.as("toProto"))
val toProtoDfToFromProtoDf = toProtoDf.select(
functions
- .from_protobuf($"toProto", testFileDesc, "newConsumer")
+ .from_protobuf($"toProto", "newConsumer", testFileDesc)
.as("toProtoToFromProto"))
val actualFieldNames =
@@ -452,7 +511,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
val df = Seq(newProducerMessage.toByteArray).toDF("newProducerData")
val fromProtoDf = df.select(
functions
- .from_protobuf($"newProducerData", testFileDesc, "oldConsumer")
+ .from_protobuf($"newProducerData", "oldConsumer", testFileDesc)
.as("oldConsumerProto"))
val expectedFieldNames = oldConsumer.getFields.asScala.map(f => f.getName)
@@ -481,8 +540,9 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
)),
schema
)
+
val toProtobuf = inputDf.select(
- functions.to_protobuf($"requiredMsg", testFileDesc, "requiredMsg")
+ functions.to_protobuf($"requiredMsg", "requiredMsg", testFileDesc)
.as("to_proto"))
val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]]
@@ -498,7 +558,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0)
val fromProtoDf = toProtobuf.select(
- functions.from_protobuf($"to_proto", testFileDesc, "requiredMsg") as 'from_proto)
+ functions.from_protobuf($"to_proto", "requiredMsg", testFileDesc) as 'from_proto)
assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0)
== inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
@@ -526,16 +586,20 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
.build()
val df = Seq(basicMessage.toByteArray).toDF("value")
- val resultFrom = df
- .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample)
- .where("sample.string_value == \"slam\"")
- val resultToFrom = resultFrom
- .select(functions.to_protobuf($"sample", testFileDesc, "BasicMessage") as 'value)
- .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample)
- .where("sample.string_value == \"slam\"")
+ checkWithFileAndClassName("BasicMessage") {
+ case (name, descFilePathOpt) =>
+ val resultFrom = df
+ .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample)
+ .where("sample.string_value == \"slam\"")
+
+ val resultToFrom = resultFrom
+ .select(to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'value)
+ .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample)
+ .where("sample.string_value == \"slam\"")
- assert(resultFrom.except(resultToFrom).isEmpty)
+ assert(resultFrom.except(resultToFrom).isEmpty)
+ }
}
test("Handle TimestampType between to_protobuf and from_protobuf") {
@@ -556,22 +620,24 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
schema
)
- val toProtoDf = inputDf
- .select(functions.to_protobuf($"timeStampMsg", testFileDesc, "timeStampMsg") as 'to_proto)
+ checkWithFileAndClassName("timeStampMsg") {
+ case (name, descFilePathOpt) =>
+ val toProtoDf = inputDf
+ .select(to_protobuf_wrapper($"timeStampMsg", name, descFilePathOpt) as 'to_proto)
- val fromProtoDf = toProtoDf
- .select(functions.from_protobuf($"to_proto", testFileDesc, "timeStampMsg") as 'timeStampMsg)
- fromProtoDf.show(truncate = false)
+ val fromProtoDf = toProtoDf
+ .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'timeStampMsg)
- val actualFields = fromProtoDf.schema.fields.toList
- val expectedFields = inputDf.schema.fields.toList
+ val actualFields = fromProtoDf.schema.fields.toList
+ val expectedFields = inputDf.schema.fields.toList
- assert(actualFields.size === expectedFields.size)
- assert(actualFields === expectedFields)
- assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)
- === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0))
- assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)
- === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0))
+ assert(actualFields.size === expectedFields.size)
+ assert(actualFields === expectedFields)
+ assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)
+ === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)
+ === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0))
+ }
}
test("Handle DayTimeIntervalType between to_protobuf and from_protobuf") {
@@ -595,21 +661,23 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
schema
)
- val toProtoDf = inputDf
- .select(functions.to_protobuf($"durationMsg", testFileDesc, "durationMsg") as 'to_proto)
+ checkWithFileAndClassName("durationMsg") {
+ case (name, descFilePathOpt) =>
+ val toProtoDf = inputDf
+ .select(to_protobuf_wrapper($"durationMsg", name, descFilePathOpt) as 'to_proto)
- val fromProtoDf = toProtoDf
- .select(functions.from_protobuf($"to_proto", testFileDesc, "durationMsg") as 'durationMsg)
+ val fromProtoDf = toProtoDf
+ .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'durationMsg)
- val actualFields = fromProtoDf.schema.fields.toList
- val expectedFields = inputDf.schema.fields.toList
-
- assert(actualFields.size === expectedFields.size)
- assert(actualFields === expectedFields)
- assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0)
- === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0))
- assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0)
- === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0))
+ val actualFields = fromProtoDf.schema.fields.toList
+ val expectedFields = inputDf.schema.fields.toList
+ assert(actualFields.size === expectedFields.size)
+ assert(actualFields === expectedFields)
+ assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0)
+ === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0)
+ === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0))
+ }
}
}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 37c59743e7714..efc02524e68db 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -36,6 +36,7 @@ class ProtobufSerdeSuite extends SharedSparkSession {
import ProtoSerdeSuite.MatchType._
val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/")
+ private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"
test("Test basic conversion") {
withFieldMatchType { fieldMatch =>
@@ -96,7 +97,9 @@ class ProtobufSerdeSuite extends SharedSparkSession {
}
test("Fail to convert with deeply nested field type mismatch") {
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested")
+ val protoFile = ProtobufUtils.buildDescriptorFromJavaClass(
+ s"${javaClassNamePrefix}MissMatchTypeInDeepNested"
+ )
val catalyst = new StructType().add("top", CATALYST_STRUCT)
withFieldMatchType { fieldMatch =>
@@ -105,8 +108,8 @@ class ProtobufSerdeSuite extends SharedSparkSession {
Deserializer,
fieldMatch,
s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " +
- s"is incompatible (protoType = org.apache.spark.sql.protobuf.TypeMiss.bar " +
- s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin,
+ s"is incompatible (protoType = org.apache.spark.sql.protobuf.protos.TypeMiss.bar " +
+ s"LABEL_OPTIONAL LONG INT64, sqlType = INT)",
catalyst)
assertFailedConversionMessage(
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index e5a48080e833a..cc103e4ab00ac 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -716,8 +716,10 @@ object SparkProtobuf {
dependencyOverrides += "com.google.protobuf" % "protobuf-java" % protoVersion,
- (Compile / PB.targets) := Seq(
- PB.gens.java -> (Compile / sourceManaged).value,
+ (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources",
+
+ (Test / PB.targets) := Seq(
+ PB.gens.java -> target.value / "generated-test-sources"
),
(assembly / test) := { },
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 9f8b90095dfd9..2059d868c7c12 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -31,8 +31,8 @@
def from_protobuf(
data: "ColumnOrName",
- descFilePath: str,
messageName: str,
+ descFilePath: str,
options: Optional[Dict[str, str]] = None,
) -> Column:
"""
@@ -48,10 +48,10 @@ def from_protobuf(
----------
data : :class:`~pyspark.sql.Column` or str
the binary column.
- descFilePath : str
- the protobuf descriptor in Message GeneratedMessageV3 format.
messageName: str
the protobuf message name to look for in descriptor file.
+ descFilePath : str
+ the path to the Protobuf descriptor file.
options : dict, optional
options to control how the protobuf record is parsed.
@@ -80,10 +80,10 @@ def from_protobuf(
... f.flush()
... message_name = 'SimpleMessage'
... proto_df = df.select(
- ... to_protobuf(df.value, desc_file_path, message_name).alias("value"))
+ ... to_protobuf(df.value, message_name, desc_file_path).alias("value"))
... proto_df.show(truncate=False)
... proto_df = proto_df.select(
- ... from_protobuf(proto_df.value, desc_file_path, message_name).alias("value"))
+ ... from_protobuf(proto_df.value, message_name, desc_file_path).alias("value"))
... proto_df.show(truncate=False)
+----------------------------------------+
|value |
@@ -101,7 +101,7 @@ def from_protobuf(
assert sc is not None and sc._jvm is not None
try:
jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
- _to_java_column(data), descFilePath, messageName, options or {}
+ _to_java_column(data), messageName, descFilePath, options or {}
)
except TypeError as e:
if str(e) == "'JavaPackage' object is not callable":
@@ -110,7 +110,7 @@ def from_protobuf(
return Column(jc)
-def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Column:
+def to_protobuf(data: "ColumnOrName", messageName: str, descFilePath: str) -> Column:
"""
Converts a column into binary of protobuf format.
@@ -120,10 +120,10 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co
----------
data : :class:`~pyspark.sql.Column` or str
the data column.
- descFilePath : str
- the protobuf descriptor in Message GeneratedMessageV3 format.
messageName: str
the protobuf message name to look for in descriptor file.
+ descFilePath : str
+ the path to the Protobuf descriptor file.
Notes
-----
@@ -150,7 +150,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co
... f.flush()
... message_name = 'SimpleMessage'
... proto_df = df.select(
- ... to_protobuf(df.value, desc_file_path, message_name).alias("suite"))
+ ... to_protobuf(df.value, message_name, desc_file_path).alias("suite"))
... proto_df.show(truncate=False)
+-------------------------------------------+
|suite |
@@ -162,7 +162,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co
assert sc is not None and sc._jvm is not None
try:
jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf(
- _to_java_column(data), descFilePath, messageName
+ _to_java_column(data), messageName, descFilePath
)
except TypeError as e:
if str(e) == "'JavaPackage' object is not callable":