diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index 0515f128b8d63..b934c7f831a2b 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -83,7 +83,6 @@ ${protobuf.version} compile - target/scala-${scala.binary.version}/classes @@ -110,6 +109,28 @@ + + com.github.os72 + protoc-jar-maven-plugin + 3.11.4 + + + + generate-test-sources + + run + + + com.google.protobuf:protoc:${protobuf.version} + ${protobuf.version} + + src/test/resources/protobuf + + test + + + + diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala index 145100268c232..b9f7907ea8ca6 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala @@ -25,17 +25,17 @@ import org.apache.spark.sql.types.{BinaryType, DataType} private[protobuf] case class CatalystDataToProtobuf( child: Expression, - descFilePath: String, - messageName: String) + messageName: String, + descFilePath: Option[String] = None) extends UnaryExpression { override def dataType: DataType = BinaryType - @transient private lazy val protoType = - ProtobufUtils.buildDescriptor(descFilePath, messageName) + @transient private lazy val protoDescriptor = + ProtobufUtils.buildDescriptor(messageName, descFilePathOpt = descFilePath) @transient private lazy val serializer = - new ProtobufSerializer(child.dataType, protoType, child.nullable) + new ProtobufSerializer(child.dataType, protoDescriptor, child.nullable) override def nullSafeEval(input: Any): Any = { val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage] diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala index f08f876799723..cad2442f10c17 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, Struc private[protobuf] case class ProtobufDataToCatalyst( child: Expression, - descFilePath: String, messageName: String, - options: Map[String, String]) + descFilePath: Option[String] = None, + options: Map[String, String] = Map.empty) extends UnaryExpression with ExpectsInputTypes { @@ -55,10 +55,14 @@ private[protobuf] case class ProtobufDataToCatalyst( private lazy val protobufOptions = ProtobufOptions(options) @transient private lazy val messageDescriptor = - ProtobufUtils.buildDescriptor(descFilePath, messageName) + ProtobufUtils.buildDescriptor(messageName, descFilePath) + // TODO: Avoid carrying the file name. Read the contents of descriptor file only once + // at the start. Rest of the runs should reuse the buffer. Otherwise, it could + // cause inconsistencies if the file contents are changed the user after a few days. + // Same for the write side in [[CatalystDataToProtobuf]]. 
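// [Reviewer sketch, not part of this patch] One way to implement the TODO above is to
// read the descriptor file eagerly on the driver and carry the serialized bytes in the
// expression instead of the path, so executors never depend on the file still being
// present or unchanged. `descFilePath` is the only name reused from the patch; the
// helpers below are illustrative.
import java.nio.file.{Files, Paths}
import com.google.protobuf.DescriptorProtos.FileDescriptorSet

def readDescriptorBytes(descFilePath: String): Array[Byte] =
  // Read once on the driver; the resulting bytes can be captured by the expression.
  Files.readAllBytes(Paths.get(descFilePath))

def parseDescriptorSet(descBytes: Array[Byte]): FileDescriptorSet =
  // Executors rebuild the descriptor set from the captured bytes, never from the file.
  FileDescriptorSet.parseFrom(descBytes)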
@transient private lazy val fieldsNumbers = - messageDescriptor.getFields.asScala.map(f => f.getNumber) + messageDescriptor.getFields.asScala.map(f => f.getNumber).toSet @transient private lazy val deserializer = new ProtobufDeserializer(messageDescriptor, dataType) @@ -108,18 +112,18 @@ private[protobuf] case class ProtobufDataToCatalyst( val binary = input.asInstanceOf[Array[Byte]] try { result = DynamicMessage.parseFrom(messageDescriptor, binary) - val unknownFields = result.getUnknownFields - if (!unknownFields.asMap().isEmpty) { - unknownFields.asMap().keySet().asScala.map { number => - { - if (fieldsNumbers.contains(number)) { - return handleException( - new Throwable(s"Type mismatch encountered for field:" + - s" ${messageDescriptor.getFields.get(number)}")) - } - } - } + // If the Java class is available, it is likely more efficient to parse with it than using + // DynamicMessage. Can consider it in the future if parsing overhead is noticeable. + + result.getUnknownFields.asMap().keySet().asScala.find(fieldsNumbers.contains(_)) match { + case Some(number) => + // Unknown fields contain a field with same number as a known field. Must be due to + // mismatch of schema between writer and reader here. + throw new IllegalArgumentException(s"Type mismatch encountered for field:" + + s" ${messageDescriptor.getFields.get(number)}") + case None => } + val deserialized = deserializer.deserialize(result) assert( deserialized.isDefined, diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 283d1ca8c412c..af30de40dad04 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -33,20 +33,21 @@ object functions { * * @param data * the binary column. - * @param descFilePath - * the protobuf descriptor in Message GeneratedMessageV3 format. * @param messageName * the protobuf message name to look for in descriptorFile. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. * @since 3.4.0 */ @Experimental def from_protobuf( data: Column, - descFilePath: String, messageName: String, + descFilePath: String, options: java.util.Map[String, String]): Column = { new Column( - ProtobufDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap)) + ProtobufDataToCatalyst(data.expr, messageName, Some(descFilePath), options.asScala.toMap) + ) } /** @@ -57,15 +58,34 @@ object functions { * * @param data * the binary column. - * @param descFilePath - * the protobuf descriptor in Message GeneratedMessageV3 format. * @param messageName * the protobuf MessageName to look for in descriptorFile. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. * @since 3.4.0 */ @Experimental - def from_protobuf(data: Column, descFilePath: String, messageName: String): Column = { - new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, Map.empty)) + def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { + new Column(ProtobufDataToCatalyst(data.expr, messageName, descFilePath = Some(descFilePath))) + // TODO: Add an option for user to provide descriptor file content as a buffer. This + // gives flexibility in how the content is fetched. + } + + /** + * Converts a binary column of Protobuf format into its corresponding catalyst value. 
The + * specified schema must match actual schema of the read data, otherwise the behavior is + * undefined: it may fail or return arbitrary result. To deserialize the data with a compatible + * and evolved schema, the expected Protobuf schema can be set via the option protoSchema. + * + * @param data + * the binary column. + * @param messageClassName + * The Protobuf class name. E.g. org.spark.examples.protobuf.ExampleEvent. + * @since 3.4.0 + */ + @Experimental + def from_protobuf(data: Column, messageClassName: String): Column = { + new Column(ProtobufDataToCatalyst(data.expr, messageClassName)) } /** @@ -73,14 +93,28 @@ object functions { * * @param data * the data column. - * @param descFilePath - * the protobuf descriptor in Message GeneratedMessageV3 format. * @param messageName * the protobuf MessageName to look for in descriptorFile. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @since 3.4.0 + */ + @Experimental + def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = { + new Column(CatalystDataToProtobuf(data.expr, messageName, Some(descFilePath))) + } + + /** + * Converts a column into binary of protobuf format. + * + * @param data + * the data column. + * @param messageClassName + * The Protobuf class name. E.g. org.spark.examples.protobuf.ExampleEvent. * @since 3.4.0 */ @Experimental - def to_protobuf(data: Column, descFilePath: String, messageName: String): Column = { - new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName)) + def to_protobuf(data: Column, messageClassName: String): Column = { + new Column(CatalystDataToProtobuf(data.expr, messageClassName)) } } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala index 5ad043142a2d2..fa2ec9b7cd462 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala @@ -22,13 +22,14 @@ import java.util.Locale import scala.collection.JavaConverters._ -import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException} +import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException, Message} import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils private[sql] object ProtobufUtils extends Logging { @@ -132,23 +133,63 @@ private[sql] object ProtobufUtils extends Logging { } } - def buildDescriptor(descFilePath: String, messageName: String): Descriptor = { - val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(descFilePath) - var result: Descriptors.Descriptor = null; + /** + * Builds Protobuf message descriptor either from the Java class or from serialized descriptor + * read from the file. + * @param messageName + * Protobuf message name or Java class name. + * @param descFilePathOpt + * When the file name set, the descriptor and it's dependencies are read from the file. Other + * the `messageName` is treated as Java class name. 
+ * @return + */ + def buildDescriptor(messageName: String, descFilePathOpt: Option[String]): Descriptor = { + descFilePathOpt match { + case Some(filePath) => buildDescriptor(descFilePath = filePath, messageName) + case None => buildDescriptorFromJavaClass(messageName) + } + } - for (descriptor <- fileDescriptor.getMessageTypes.asScala) { - if (descriptor.getName().equals(messageName)) { - result = descriptor - } + /** + * Loads the given protobuf class and returns Protobuf descriptor for it. + */ + def buildDescriptorFromJavaClass(protobufClassName: String): Descriptor = { + val protobufClass = try { + Utils.classForName(protobufClassName) + } catch { + case _: ClassNotFoundException => + val hasDots = protobufClassName.contains(".") + throw new IllegalArgumentException( + s"Could not load Protobuf class with name '$protobufClassName'" + + (if (hasDots) "" else ". Ensure the class name includes package prefix.") + ) + } + + if (!classOf[Message].isAssignableFrom(protobufClass)) { + throw new IllegalArgumentException(s"$protobufClassName is not a Protobuf message type") + // TODO: Need to support V2. This might work with V2 classes too. + } + + // Extract the descriptor from Protobuf message. + protobufClass + .getDeclaredMethod("getDescriptor") + .invoke(null) + .asInstanceOf[Descriptor] + } + + def buildDescriptor(descFilePath: String, messageName: String): Descriptor = { + val descriptor = parseFileDescriptor(descFilePath).getMessageTypes.asScala.find { desc => + desc.getName == messageName || desc.getFullName == messageName } - if (null == result) { - throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor"); + descriptor match { + case Some(d) => d + case None => + throw new RuntimeException(s"Unable to locate Message '$messageName' in Descriptor") } - result } - def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = { + private def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = { var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null try { val dscFile = new BufferedInputStream(new FileInputStream(descFilePath)) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala index e385b816abe70..4fca06fb5d8ba 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -66,6 +66,10 @@ object SchemaConverters { Some(DayTimeIntervalType.defaultConcreteType) case MESSAGE if fd.getMessageType.getName == "Timestamp" => Some(TimestampType) + // FIXME: Is the above accurate? Users can have protos named "Timestamp" but are not + // expected to be TimestampType in Spark. How about verifying fields? + // Same for "Duration". Only the Timestamp & Duration protos defined in + // google.protobuf package should default to corresponding Catalylist types. 
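// [Reviewer sketch, not part of this patch] For the FIXME above, one option is to key
// the special-casing on the descriptor's full name, so only the well-known
// google.protobuf types map to Catalyst TimestampType / DayTimeIntervalType, while user
// messages that merely happen to be named "Timestamp" or "Duration" fall through to a
// regular struct. The helper names are assumptions.
import com.google.protobuf.Descriptors.FieldDescriptor

def isWellKnownTimestamp(fd: FieldDescriptor): Boolean =
  fd.getJavaType == FieldDescriptor.JavaType.MESSAGE &&
    fd.getMessageType.getFullName == "google.protobuf.Timestamp"

def isWellKnownDuration(fd: FieldDescriptor): Boolean =
  fd.getJavaType == FieldDescriptor.JavaType.MESSAGE &&
    fd.getMessageType.getFullName == "google.protobuf.Duration"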
case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry => var keyType: DataType = NullType var valueType: DataType = NullType diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto index 54e6bc18df153..1deb193438c20 100644 --- a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto +++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto @@ -19,9 +19,11 @@ syntax = "proto3"; -package org.apache.spark.sql.protobuf; +package org.apache.spark.sql.protobuf.protos; option java_outer_classname = "CatalystTypes"; +// TODO: import one or more protobuf files. + message BooleanMsg { bool bool_type = 1; } diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto index f38c041b799ec..60f8c26214153 100644 --- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto +++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto @@ -20,7 +20,7 @@ syntax = "proto3"; -package org.apache.spark.sql.protobuf; +package org.apache.spark.sql.protobuf.protos; option java_outer_classname = "SimpleMessageProtos"; @@ -119,7 +119,7 @@ message SimpleMessageEnum { string key = 1; string value = 2; enum NestedEnum { - ESTED_NOTHING = 0; + ESTED_NOTHING = 0; // TODO: Fix the name. NESTED_FIRST = 1; NESTED_SECOND = 2; } diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto index 1e3065259aa02..a7459213a87b2 100644 --- a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto +++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto @@ -20,11 +20,11 @@ syntax = "proto3"; -package org.apache.spark.sql.protobuf; -option java_outer_classname = "SimpleMessageProtos"; +package org.apache.spark.sql.protobuf.protos; +option java_outer_classname = "SerdeSuiteProtos"; /* Clean Message*/ -message BasicMessage { +message SerdeBasicMessage { Foo foo = 1; } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index b730ebb4fea80..19774a2ad07e4 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters} import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.protobuf.protos.CatalystTypes.BytesMsg import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters} import org.apache.spark.sql.sources.{EqualTo, Not} import org.apache.spark.sql.test.SharedSparkSession @@ -35,18 +36,32 @@ class ProtobufCatalystDataConversionSuite with SharedSparkSession with ExpressionEvalHelper { - private def checkResult( + private val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + private val javaClassNamePrefix = 
"org.apache.spark.sql.protobuf.protos.CatalystTypes$" + + private def checkResultWithEval( data: Literal, descFilePath: String, messageName: String, expected: Any): Unit = { - checkEvaluation( - ProtobufDataToCatalyst( - CatalystDataToProtobuf(data, descFilePath, messageName), - descFilePath, - messageName, - Map.empty), - prepareExpectedResult(expected)) + + withClue("(Eval check with Java class name)") { + val className = s"$javaClassNamePrefix$messageName" + checkEvaluation( + ProtobufDataToCatalyst( + CatalystDataToProtobuf(data, className), + className, + descFilePath = None), + prepareExpectedResult(expected)) + } + withClue("(Eval check with descriptor file)") { + checkEvaluation( + ProtobufDataToCatalyst( + CatalystDataToProtobuf(data, messageName, Some(descFilePath)), + messageName, + descFilePath = Some(descFilePath)), + prepareExpectedResult(expected)) + } } protected def checkUnsupportedRead( @@ -55,10 +70,11 @@ class ProtobufCatalystDataConversionSuite actualSchema: String, badSchema: String): Unit = { - val binary = CatalystDataToProtobuf(data, descFilePath, actualSchema) + val binary = CatalystDataToProtobuf(data, actualSchema, Some(descFilePath)) intercept[Exception] { - ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "FAILFAST")).eval() + ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" -> "FAILFAST")) + .eval() } val expected = { @@ -73,7 +89,7 @@ class ProtobufCatalystDataConversionSuite } checkEvaluation( - ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "PERMISSIVE")), + ProtobufDataToCatalyst(binary, badSchema, Some(descFilePath), Map("mode" -> "PERMISSIVE")), expected) } @@ -99,26 +115,32 @@ class ProtobufCatalystDataConversionSuite StructType(StructField("bytes_type", BinaryType, nullable = true) :: Nil), StructType(StructField("string_type", StringType, nullable = true) :: Nil)) - private val catalystTypesToProtoMessages: Map[DataType, String] = Map( - IntegerType -> "IntegerMsg", - DoubleType -> "DoubleMsg", - FloatType -> "FloatMsg", - BinaryType -> "BytesMsg", - StringType -> "StringMsg") + private val catalystTypesToProtoMessages: Map[DataType, (String, Any)] = Map( + IntegerType -> ("IntegerMsg", 0), + DoubleType -> ("DoubleMsg", 0.0d), + FloatType -> ("FloatMsg", 0.0f), + BinaryType -> ("BytesMsg", ByteString.empty().toByteArray), + StringType -> ("StringMsg", "")) testingTypes.foreach { dt => val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1) - val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") test(s"single $dt with seed $seed") { + + val (messageName, defaultValue) = catalystTypesToProtoMessages(dt.fields(0).dataType) + val rand = new scala.util.Random(seed) - val data = RandomDataGenerator.forType(dt, rand = rand).get.apply() + val generator = RandomDataGenerator.forType(dt, rand = rand).get + var data = generator() + while (data.asInstanceOf[Row].get(0) == defaultValue) // Do not use default values, since + data = generator() // from_protobuf() returns null in v3. 
+ val converter = CatalystTypeConverters.createToCatalystConverter(dt) val input = Literal.create(converter(data), dt) - checkResult( + checkResultWithEval( input, - filePath, - catalystTypesToProtoMessages(dt.fields(0).dataType), + testFileDesc, + messageName, input.eval()) } } @@ -137,6 +159,15 @@ class ProtobufCatalystDataConversionSuite val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray) val deserialized = deserializer.deserialize(dynMsg) + + // Verify Java class deserializer matches with descriptor based serializer. + val javaDescriptor = ProtobufUtils + .buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName") + assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType) + val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType, filters) + .deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray)) + assert(deserialized == javaDeserialized) + expected match { case None => assert(deserialized.isEmpty) case Some(d) => @@ -145,7 +176,6 @@ class ProtobufCatalystDataConversionSuite } test("Handle unsupported input of message type") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") val actualSchema = StructType( Seq( StructField("col_0", StringType, nullable = false), @@ -165,7 +195,6 @@ class ProtobufCatalystDataConversionSuite test("filter push-down to Protobuf deserializer") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") val sqlSchema = new StructType() .add("name", "string") .add("age", "int") @@ -196,17 +225,23 @@ class ProtobufCatalystDataConversionSuite test("ProtobufDeserializer with binary type") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53)) - val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg") - - val dynamicMessage = DynamicMessage - .newBuilder(descriptor) - .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb)) + val bytesProto = BytesMsg + .newBuilder() + .setBytesType(ByteString.copyFrom(bb)) .build() val expected = InternalRow(Array[Byte](97, 48, 53)) - checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected)) + checkDeserialization(testFileDesc, "BytesMsg", bytesProto, Some(expected)) + } + + test("Full names for message using descriptor file") { + val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg") + assert(withShortName.findFieldByName("bytes_type") != null) + + val withFullName = ProtobufUtils.buildDescriptor( + testFileDesc, "org.apache.spark.sql.protobuf.BytesMsg") + assert(withFullName.findFieldByName("bytes_type") != null) } } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 4e9bc1c1c287a..72280fb0d9e2d 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -23,8 +23,10 @@ import scala.collection.JavaConverters._ import com.google.protobuf.{ByteString, DynamicMessage} -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{Column, QueryTest, Row} import org.apache.spark.sql.functions.{lit, struct} +import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated +import 
org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum import org.apache.spark.sql.protobuf.utils.ProtobufUtils import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException import org.apache.spark.sql.test.SharedSparkSession @@ -35,6 +37,39 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri import testImplicits._ val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/") + private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$" + + /** + * Runs the given closure twice. Once with descriptor file and second time with Java class name. + */ + private def checkWithFileAndClassName(messageName: String)( + fn: (String, Option[String]) => Unit): Unit = { + withClue("(With descriptor file)") { + fn(messageName, Some(testFileDesc)) + } + withClue("(With Java class name)") { + fn(s"$javaClassNamePrefix$messageName", None) + } + } + + // A wrapper to invoke the right variable of from_protobuf() depending on arguments. + private def from_protobuf_wrapper( + col: Column, messageName: String, descFilePathOpt: Option[String]): Column = { + descFilePathOpt match { + case Some(descFilePath) => functions.from_protobuf(col, messageName, descFilePath) + case None => functions.from_protobuf(col, messageName) + } + } + + // A wrapper to invoke the right variable of to_protobuf() depending on arguments. + private def to_protobuf_wrapper( + col: Column, messageName: String, descFilePathOpt: Option[String]): Column = { + descFilePathOpt match { + case Some(descFilePath) => functions.to_protobuf(col, messageName, descFilePath) + case None => functions.to_protobuf(col, messageName) + } + } + test("roundtrip in to_protobuf and from_protobuf - struct") { val df = spark @@ -56,44 +91,45 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"), lit(true).as("bool_value"), lit("0".getBytes).as("bytes_value")).as("SimpleMessage")) - val protoStructDF = df.select( - functions.to_protobuf($"SimpleMessage", testFileDesc, "SimpleMessage").as("proto")) - val actualDf = protoStructDF.select( - functions.from_protobuf($"proto", testFileDesc, "SimpleMessage").as("proto.*")) - checkAnswer(actualDf, df) + + checkWithFileAndClassName("SimpleMessage") { + case (name, descFilePathOpt) => + val protoStructDF = df.select( + to_protobuf_wrapper($"SimpleMessage", name, descFilePathOpt).as("proto")) + val actualDf = protoStructDF.select( + from_protobuf_wrapper($"proto", name, descFilePathOpt).as("proto.*")) + checkAnswer(actualDf, df) + } } test("roundtrip in from_protobuf and to_protobuf - Repeated") { - val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated") - val dynamicMessage = DynamicMessage - .newBuilder(descriptor) - .setField(descriptor.findFieldByName("key"), "key") - .setField(descriptor.findFieldByName("value"), "value") - .addRepeatedField(descriptor.findFieldByName("rbool_value"), false) - .addRepeatedField(descriptor.findFieldByName("rbool_value"), true) - .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654d) - .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654d) - .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f) - .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f) - .addRepeatedField( - descriptor.findFieldByName("rnested_enum"), - 
descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) - .addRepeatedField( - descriptor.findFieldByName("rnested_enum"), - descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + val protoMessage = SimpleMessageRepeated + .newBuilder() + .setKey("key") + .setValue("value") + .addRboolValue(false) + .addRboolValue(true) + .addRdoubleValue(1092092.654d) + .addRdoubleValue(1092093.654d) + .addRfloatValue(10903.0f) + .addRfloatValue(10902.0f) + .addRnestedEnum(NestedEnum.ESTED_NOTHING) + .addRnestedEnum(NestedEnum.NESTED_FIRST) .build() - val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "SimpleMessageRepeated").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageRepeated").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions - .from_protobuf($"value_to", testFileDesc, "SimpleMessageRepeated") - .as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + val df = Seq(protoMessage.toByteArray).toDF("value") + + checkWithFileAndClassName("SimpleMessageRepeated") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") { @@ -120,13 +156,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("RepeatedMessage") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") { @@ -167,13 +207,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, 
"RepeatedMessage").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("RepeatedMessage") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Map") { @@ -257,13 +301,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "SimpleMessageMap").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageMap").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageMap").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("SimpleMessageMap") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Enum") { @@ -289,13 +337,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "SimpleMessageEnum").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageEnum").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageEnum").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("SimpleMessageEnum") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("roundtrip in from_protobuf and to_protobuf - Multiple Message") { @@ -320,13 +372,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(dynamicMessage.toByteArray).toDF("value") - val fromProtoDF = df.select( - functions.from_protobuf($"value", testFileDesc, "MultipleExample").as("value_from")) - val toProtoDF = fromProtoDF.select( - functions.to_protobuf($"value_from", 
testFileDesc, "MultipleExample").as("value_to")) - val toFromProtoDF = toProtoDF.select( - functions.from_protobuf($"value_to", testFileDesc, "MultipleExample").as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + + checkWithFileAndClassName("MultipleExample") { + case (name, descFilePathOpt) => + val fromProtoDF = df.select( + from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } } test("Handle recursive fields in Protobuf schema, A->B->A") { @@ -352,15 +408,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(messageB.toByteArray).toDF("messageB") - val e = intercept[IncompatibleSchemaException] { - df.select( - functions.from_protobuf($"messageB", testFileDesc, "recursiveB").as("messageFromProto")) - .show() + checkWithFileAndClassName("recursiveB") { + case (name, descFilePathOpt) => + val e = intercept[IncompatibleSchemaException] { + df.select( + from_protobuf_wrapper($"messageB", name, descFilePathOpt).as("messageFromProto")) + .show() + } + assert(e.getMessage.contains( + "Found recursive reference in Protobuf schema, which can not be processed by Spark:" + )) } - val expectedMessage = s""" - |Found recursive reference in Protobuf schema, which can not be processed by Spark: - |org.apache.spark.sql.protobuf.recursiveB.messageA""".stripMargin - assert(e.getMessage == expectedMessage) } test("Handle recursive fields in Protobuf schema, C->D->Array(C)") { @@ -386,16 +444,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(messageD.toByteArray).toDF("messageD") - val e = intercept[IncompatibleSchemaException] { - df.select( - functions.from_protobuf($"messageD", testFileDesc, "recursiveD").as("messageFromProto")) - .show() + checkWithFileAndClassName("recursiveD") { + case (name, descFilePathOpt) => + val e = intercept[IncompatibleSchemaException] { + df.select( + from_protobuf_wrapper($"messageD", name, descFilePathOpt).as("messageFromProto")) + .show() + } + assert(e.getMessage.contains( + "Found recursive reference in Protobuf schema, which can not be processed by Spark:" + )) } - val expectedMessage = - s""" - |Found recursive reference in Protobuf schema, which can not be processed by Spark: - |org.apache.spark.sql.protobuf.recursiveD.messageC""".stripMargin - assert(e.getMessage == expectedMessage) } test("Handle extra fields : oldProducer -> newConsumer") { @@ -411,17 +470,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(oldProducerMessage.toByteArray).toDF("oldProducerData") val fromProtoDf = df.select( functions - .from_protobuf($"oldProducerData", testFileDesc, "newConsumer") + .from_protobuf($"oldProducerData", "newConsumer", testFileDesc) .as("fromProto")) val toProtoDf = fromProtoDf.select( functions - .to_protobuf($"fromProto", testFileDesc, "newConsumer") + .to_protobuf($"fromProto", "newConsumer", testFileDesc) .as("toProto")) val toProtoDfToFromProtoDf = toProtoDf.select( functions - .from_protobuf($"toProto", testFileDesc, "newConsumer") + .from_protobuf($"toProto", "newConsumer", testFileDesc) 
.as("toProtoToFromProto")) val actualFieldNames = @@ -452,7 +511,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri val df = Seq(newProducerMessage.toByteArray).toDF("newProducerData") val fromProtoDf = df.select( functions - .from_protobuf($"newProducerData", testFileDesc, "oldConsumer") + .from_protobuf($"newProducerData", "oldConsumer", testFileDesc) .as("oldConsumerProto")) val expectedFieldNames = oldConsumer.getFields.asScala.map(f => f.getName) @@ -481,8 +540,9 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri )), schema ) + val toProtobuf = inputDf.select( - functions.to_protobuf($"requiredMsg", testFileDesc, "requiredMsg") + functions.to_protobuf($"requiredMsg", "requiredMsg", testFileDesc) .as("to_proto")) val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]] @@ -498,7 +558,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0) val fromProtoDf = toProtobuf.select( - functions.from_protobuf($"to_proto", testFileDesc, "requiredMsg") as 'from_proto) + functions.from_protobuf($"to_proto", "requiredMsg", testFileDesc) as 'from_proto) assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0) == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) @@ -526,16 +586,20 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .build() val df = Seq(basicMessage.toByteArray).toDF("value") - val resultFrom = df - .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) - .where("sample.string_value == \"slam\"") - val resultToFrom = resultFrom - .select(functions.to_protobuf($"sample", testFileDesc, "BasicMessage") as 'value) - .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) - .where("sample.string_value == \"slam\"") + checkWithFileAndClassName("BasicMessage") { + case (name, descFilePathOpt) => + val resultFrom = df + .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample) + .where("sample.string_value == \"slam\"") + + val resultToFrom = resultFrom + .select(to_protobuf_wrapper($"sample", name, descFilePathOpt) as 'value) + .select(from_protobuf_wrapper($"value", name, descFilePathOpt) as 'sample) + .where("sample.string_value == \"slam\"") - assert(resultFrom.except(resultToFrom).isEmpty) + assert(resultFrom.except(resultToFrom).isEmpty) + } } test("Handle TimestampType between to_protobuf and from_protobuf") { @@ -556,22 +620,24 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri schema ) - val toProtoDf = inputDf - .select(functions.to_protobuf($"timeStampMsg", testFileDesc, "timeStampMsg") as 'to_proto) + checkWithFileAndClassName("timeStampMsg") { + case (name, descFilePathOpt) => + val toProtoDf = inputDf + .select(to_protobuf_wrapper($"timeStampMsg", name, descFilePathOpt) as 'to_proto) - val fromProtoDf = toProtoDf - .select(functions.from_protobuf($"to_proto", testFileDesc, "timeStampMsg") as 'timeStampMsg) - fromProtoDf.show(truncate = false) + val fromProtoDf = toProtoDf + .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'timeStampMsg) - val actualFields = fromProtoDf.schema.fields.toList - val expectedFields = inputDf.schema.fields.toList + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList - assert(actualFields.size === expectedFields.size) - 
assert(actualFields === expectedFields) - assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0) - === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)) - assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0) - === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)) + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)) + } } test("Handle DayTimeIntervalType between to_protobuf and from_protobuf") { @@ -595,21 +661,23 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri schema ) - val toProtoDf = inputDf - .select(functions.to_protobuf($"durationMsg", testFileDesc, "durationMsg") as 'to_proto) + checkWithFileAndClassName("durationMsg") { + case (name, descFilePathOpt) => + val toProtoDf = inputDf + .select(to_protobuf_wrapper($"durationMsg", name, descFilePathOpt) as 'to_proto) - val fromProtoDf = toProtoDf - .select(functions.from_protobuf($"to_proto", testFileDesc, "durationMsg") as 'durationMsg) + val fromProtoDf = toProtoDf + .select(from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'durationMsg) - val actualFields = fromProtoDf.schema.fields.toList - val expectedFields = inputDf.schema.fields.toList - - assert(actualFields.size === expectedFields.size) - assert(actualFields === expectedFields) - assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0) - === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0)) - assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0) - === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0)) + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0)) + } } } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala index 37c59743e7714..efc02524e68db 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -36,6 +36,7 @@ class ProtobufSerdeSuite extends SharedSparkSession { import ProtoSerdeSuite.MatchType._ val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/") + private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$" test("Test basic conversion") { withFieldMatchType { fieldMatch => @@ -96,7 +97,9 @@ class ProtobufSerdeSuite extends SharedSparkSession { } test("Fail to convert with deeply nested field type mismatch") { - val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested") + val protoFile = ProtobufUtils.buildDescriptorFromJavaClass( + s"${javaClassNamePrefix}MissMatchTypeInDeepNested" + ) val catalyst = new 
StructType().add("top", CATALYST_STRUCT) withFieldMatchType { fieldMatch => @@ -105,8 +108,8 @@ class ProtobufSerdeSuite extends SharedSparkSession { Deserializer, fieldMatch, s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " + - s"is incompatible (protoType = org.apache.spark.sql.protobuf.TypeMiss.bar " + - s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin, + s"is incompatible (protoType = org.apache.spark.sql.protobuf.protos.TypeMiss.bar " + + s"LABEL_OPTIONAL LONG INT64, sqlType = INT)", catalyst) assertFailedConversionMessage( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e5a48080e833a..cc103e4ab00ac 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -716,8 +716,10 @@ object SparkProtobuf { dependencyOverrides += "com.google.protobuf" % "protobuf-java" % protoVersion, - (Compile / PB.targets) := Seq( - PB.gens.java -> (Compile / sourceManaged).value, + (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources", + + (Test / PB.targets) := Seq( + PB.gens.java -> target.value / "generated-test-sources" ), (assembly / test) := { }, diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py index 9f8b90095dfd9..2059d868c7c12 100644 --- a/python/pyspark/sql/protobuf/functions.py +++ b/python/pyspark/sql/protobuf/functions.py @@ -31,8 +31,8 @@ def from_protobuf( data: "ColumnOrName", - descFilePath: str, messageName: str, + descFilePath: str, options: Optional[Dict[str, str]] = None, ) -> Column: """ @@ -48,10 +48,10 @@ def from_protobuf( ---------- data : :class:`~pyspark.sql.Column` or str the binary column. - descFilePath : str - the protobuf descriptor in Message GeneratedMessageV3 format. messageName: str the protobuf message name to look for in descriptor file. + descFilePath : str + the protobuf descriptor in Message GeneratedMessageV3 format. options : dict, optional options to control how the protobuf record is parsed. @@ -80,10 +80,10 @@ def from_protobuf( ... f.flush() ... message_name = 'SimpleMessage' ... proto_df = df.select( - ... to_protobuf(df.value, desc_file_path, message_name).alias("value")) + ... to_protobuf(df.value, message_name, desc_file_path).alias("value")) ... proto_df.show(truncate=False) ... proto_df = proto_df.select( - ... from_protobuf(proto_df.value, desc_file_path, message_name).alias("value")) + ... from_protobuf(proto_df.value, message_name, desc_file_path).alias("value")) ... proto_df.show(truncate=False) +----------------------------------------+ |value | @@ -101,7 +101,7 @@ def from_protobuf( assert sc is not None and sc._jvm is not None try: jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf( - _to_java_column(data), descFilePath, messageName, options or {} + _to_java_column(data), messageName, descFilePath, options or {} ) except TypeError as e: if str(e) == "'JavaPackage' object is not callable": @@ -110,7 +110,7 @@ def from_protobuf( return Column(jc) -def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Column: +def to_protobuf(data: "ColumnOrName", messageName: str, descFilePath: str) -> Column: """ Converts a column into binary of protobuf format. @@ -120,10 +120,10 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co ---------- data : :class:`~pyspark.sql.Column` or str the data column. - descFilePath : str - the protobuf descriptor in Message GeneratedMessageV3 format. 
messageName: str the protobuf message name to look for in descriptor file. + descFilePath : str + the protobuf descriptor in Message GeneratedMessageV3 format. Notes ----- @@ -150,7 +150,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co ... f.flush() ... message_name = 'SimpleMessage' ... proto_df = df.select( - ... to_protobuf(df.value, desc_file_path, message_name).alias("suite")) + ... to_protobuf(df.value, message_name, desc_file_path).alias("suite")) ... proto_df.show(truncate=False) +-------------------------------------------+ |suite | @@ -162,7 +162,7 @@ def to_protobuf(data: "ColumnOrName", descFilePath: str, messageName: str) -> Co assert sc is not None and sc._jvm is not None try: jc = sc._jvm.org.apache.spark.sql.protobuf.functions.to_protobuf( - _to_java_column(data), descFilePath, messageName + _to_java_column(data), messageName, descFilePath ) except TypeError as e: if str(e) == "'JavaPackage' object is not callable":
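A quick sketch of the reordered Scala API after this change (illustrative only: `spark`, a DataFrame `df` with a binary column `value`, and the local descriptor path are assumed; `org.spark.examples.protobuf.ExampleEvent` is the example class name used in the scaladoc above):

import org.apache.spark.sql.protobuf.functions._
import spark.implicits._

// Descriptor-file variant: the message name now comes before the descriptor path.
val parsed = df.select(
  from_protobuf($"value", "SimpleMessage", "/path/to/functions_suite.desc").as("event"))

// Java-class variant: no descriptor file; the compiled Protobuf class must be on the classpath.
val parsedFromClass = df.select(
  from_protobuf($"value", "org.spark.examples.protobuf.ExampleEvent").as("event"))

// to_protobuf mirrors the same two shapes.
val reserialized = parsedFromClass.select(
  to_protobuf($"event", "org.spark.examples.protobuf.ExampleEvent").as("value"))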
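On the test change in ProtobufCatalystDataConversionSuite that regenerates rows until they differ from the proto default value: in proto3, a scalar field holding its default value is not written on the wire, so a round trip cannot distinguish "unset" from "set to the default". A minimal illustration, assuming the generated test class from catalyst_types.proto:

import org.apache.spark.sql.protobuf.protos.CatalystTypes.BooleanMsg

val defaultOnly = BooleanMsg.newBuilder().setBoolType(false).build()
assert(defaultOnly.toByteArray.isEmpty) // nothing is serialized for the default value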