diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml
index b934c7f831a2..36c9efb3d8c3 100644
--- a/connector/protobuf/pom.xml
+++ b/connector/protobuf/pom.xml
@@ -113,20 +113,54 @@
         <groupId>com.github.os72</groupId>
         <artifactId>protoc-jar-maven-plugin</artifactId>
         <version>3.11.4</version>
         <executions>
           <execution>
             <phase>generate-test-sources</phase>
+            <id>v2-protos</id>
             <goals>
               <goal>run</goal>
             </goals>
             <configuration>
               <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
               <protocVersion>${protobuf.version}</protocVersion>
               <inputDirectories>
-                <include>src/test/resources/protobuf</include>
+                <include>src/test/protobuf/v2</include>
               </inputDirectories>
-              <addSources>test</addSources>
+              <outputTargets>
+                <outputTarget>
+                  <type>java</type>
+                  <addSources>test</addSources>
+                </outputTarget>
+                <outputTarget>
+                  <type>descriptor</type>
+                  <addSources>test</addSources>
+                  <outputDirectory>target/scala-${scala.binary.version}/test-classes</outputDirectory>
+                  <outputDirectorySuffix>descriptor-set-v2</outputDirectorySuffix>
+                </outputTarget>
+              </outputTargets>
             </configuration>
           </execution>
+          <execution>
+            <phase>generate-test-sources</phase>
+            <id>v3-protos</id>
+            <goals>
+              <goal>run</goal>
+            </goals>
+            <configuration>
+              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
+              <protocVersion>${protobuf.version}</protocVersion>
+              <inputDirectories>
+                <include>src/test/protobuf/v3</include>
+              </inputDirectories>
+              <outputTargets>
+                <outputTarget>
+                  <type>java</type>
+                  <addSources>test</addSources>
+                </outputTarget>
+                <outputTarget>
+                  <type>descriptor</type>
+                  <addSources>test</addSources>
+                  <outputDirectory>target/scala-${scala.binary.version}/test-classes</outputDirectory>
+                  <outputDirectorySuffix>descriptor-set-v3</outputDirectorySuffix>
+                </outputTarget>
+              </outputTargets>
+            </configuration>
+          </execution>
         </executions>
diff --git a/connector/protobuf/src/test/protobuf/v2/catalyst_types.proto b/connector/protobuf/src/test/protobuf/v2/catalyst_types.proto
new file mode 100644
index 000000000000..f2e11486a448
--- /dev/null
+++ b/connector/protobuf/src/test/protobuf/v2/catalyst_types.proto
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package org.apache.spark.sql.protobuf.protos.v2;
+option java_outer_classname = "CatalystTypes";
+
+// TODO: import one or more protobuf files.
+
+message BooleanMsg {
+ optional bool bool_type = 1;
+}
+message IntegerMsg {
+ optional int32 int32_type = 1;
+}
+message DoubleMsg {
+ optional double double_type = 1;
+}
+message FloatMsg {
+ optional float float_type = 1;
+}
+message BytesMsg {
+ optional bytes bytes_type = 1;
+}
+message StringMsg {
+ optional string string_type = 1;
+}
+
+message Person {
+ optional string name = 1;
+ optional int32 age = 2;
+}
+
+message Bad {
+ optional bytes col_0 = 1;
+ optional double col_1 = 2;
+ optional string col_2 = 3;
+ optional float col_3 = 4;
+ optional int64 col_4 = 5;
+}
+
+message Actual {
+ optional string col_0 = 1;
+ optional int32 col_1 = 2;
+ optional float col_2 = 3;
+ optional bool col_3 = 4;
+ optional double col_4 = 5;
+}
+
+message oldConsumer {
+ optional string key = 1;
+}
+
+message newProducer {
+ optional string key = 1;
+ optional int32 value = 2;
+}
+
+message newConsumer {
+ optional string key = 1;
+ optional int32 value = 2;
+ optional Actual actual = 3;
+}
+
+message oldProducer {
+ optional string key = 1;
+}
diff --git a/connector/protobuf/src/test/protobuf/v2/functions_suite.proto b/connector/protobuf/src/test/protobuf/v2/functions_suite.proto
new file mode 100644
index 000000000000..0002ebe55f35
--- /dev/null
+++ b/connector/protobuf/src/test/protobuf/v2/functions_suite.proto
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package org.apache.spark.sql.protobuf.protos.v2;
+
+option java_outer_classname = "SimpleMessageProtos";
+
+message SimpleMessageJavaTypes {
+ optional int64 id = 1;
+ optional string string_value = 2;
+ optional int32 int32_value = 3;
+ optional int64 int64_value = 4;
+ optional double double_value = 5;
+ optional float float_value = 6;
+ optional bool bool_value = 7;
+ optional bytes bytes_value = 8;
+}
+
+message SimpleMessage {
+ optional int64 id = 1;
+ optional string string_value = 2;
+ optional int32 int32_value = 3;
+ optional uint32 uint32_value = 4;
+ optional sint32 sint32_value = 5;
+ optional fixed32 fixed32_value = 6;
+ optional sfixed32 sfixed32_value = 7;
+ optional int64 int64_value = 8;
+ optional uint64 uint64_value = 9;
+ optional sint64 sint64_value = 10;
+ optional fixed64 fixed64_value = 11;
+ optional sfixed64 sfixed64_value = 12;
+ optional double double_value = 13;
+ optional float float_value = 14;
+ optional bool bool_value = 15;
+ optional bytes bytes_value = 16;
+}
+
+message SimpleMessageRepeated {
+ optional string key = 1;
+ optional string value = 2;
+ enum NestedEnum {
+ NESTED_NOTHING = 0;
+ NESTED_FIRST = 1;
+ NESTED_SECOND = 2;
+ }
+ repeated string rstring_value = 3;
+ repeated int32 rint32_value = 4;
+ repeated bool rbool_value = 5;
+ repeated int64 rint64_value = 6;
+ repeated float rfloat_value = 7;
+ repeated double rdouble_value = 8;
+ repeated bytes rbytes_value = 9;
+ repeated NestedEnum rnested_enum = 10;
+}
+
+message BasicMessage {
+ optional int64 id = 1;
+ optional string string_value = 2;
+ optional int32 int32_value = 3;
+ optional int64 int64_value = 4;
+ optional double double_value = 5;
+ optional float float_value = 6;
+ optional bool bool_value = 7;
+ optional bytes bytes_value = 8;
+}
+
+message RepeatedMessage {
+ repeated BasicMessage basic_message = 1;
+}
+
+message SimpleMessageMap {
+ optional string key = 1;
+ optional string value = 2;
+ map<string, string> string_mapdata = 3;
+ map<int32, int32> int32_mapdata = 4;
+ map<uint32, uint32> uint32_mapdata = 5;
+ map<sint32, sint32> sint32_mapdata = 6;
+ map<fixed32, fixed32> float32_mapdata = 7;
+ map<sfixed32, sfixed32> sfixed32_mapdata = 8;
+ map<int64, int64> int64_mapdata = 9;
+ map<uint64, uint64> uint64_mapdata = 10;
+ map<sint64, sint64> sint64_mapdata = 11;
+ map<fixed64, fixed64> fixed64_mapdata = 12;
+ map<sfixed64, sfixed64> sfixed64_mapdata = 13;
+ map<string, double> double_mapdata = 14;
+ map<string, float> float_mapdata = 15;
+ map<bool, bool> bool_mapdata = 16;
+ map<string, bytes> bytes_mapdata = 17;
+}
+
+message BasicEnumMessage {
+ enum BasicEnum {
+ NOTHING = 0;
+ FIRST = 1;
+ SECOND = 2;
+ }
+}
+
+message SimpleMessageEnum {
+ optional string key = 1;
+ optional string value = 2;
+ enum NestedEnum {
+ NESTED_NOTHING = 0;
+ NESTED_FIRST = 1;
+ NESTED_SECOND = 2;
+ }
+ optional BasicEnumMessage.BasicEnum basic_enum = 3;
+ optional NestedEnum nested_enum = 4;
+}
+
+
+message OtherExample {
+ optional string other = 1;
+}
+
+message IncludedExample {
+ optional string included = 1;
+ optional OtherExample other = 2;
+}
+
+message MultipleExample {
+ optional IncludedExample included_example = 1;
+}
+
+message recursiveA {
+ optional string keyA = 1;
+ optional recursiveB messageB = 2;
+}
+
+message recursiveB {
+ optional string keyB = 1;
+ optional recursiveA messageA = 2;
+}
+
+message recursiveC {
+ optional string keyC = 1;
+ optional recursiveD messageD = 2;
+}
+
+message recursiveD {
+ optional string keyD = 1;
+ repeated recursiveC messageC = 2;
+}
+
+message requiredMsg {
+ optional string key = 1;
+ optional int32 col_1 = 2;
+ optional string col_2 = 3;
+ optional int32 col_3 = 4;
+}
+
+// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto
+message Timestamp {
+ optional int64 seconds = 1;
+ optional int32 nanos = 2;
+}
+
+message timeStampMsg {
+ optional string key = 1;
+ optional Timestamp stmp = 2;
+}
+// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/duration.proto
+message Duration {
+ optional int64 seconds = 1;
+ optional int32 nanos = 2;
+}
+
+message durationMsg {
+ optional string key = 1;
+ optional Duration duration = 2;
+}
+
+message ProtoWithDefaults {
+ optional string user_name = 1;
+ required int32 id = 2;
+ optional int32 api_quota = 3 [default = 100]; // Default 100 qps.
+ optional string location = 4 [default = "Unknown"];
+}
\ No newline at end of file
diff --git a/connector/protobuf/src/test/protobuf/v2/pyspark_test.proto b/connector/protobuf/src/test/protobuf/v2/pyspark_test.proto
new file mode 100644
index 000000000000..336c86cf7e79
--- /dev/null
+++ b/connector/protobuf/src/test/protobuf/v2/pyspark_test.proto
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package org.apache.spark.sql.protobuf;
+option java_outer_classname = "SimpleMessageProtos";
+
+message SimpleMessage {
+ optional int32 age = 1;
+ optional string name = 2;
+ optional int64 score = 3;
+}
\ No newline at end of file
diff --git a/connector/protobuf/src/test/protobuf/v2/serde_suite.proto b/connector/protobuf/src/test/protobuf/v2/serde_suite.proto
new file mode 100644
index 000000000000..c401766c9f0d
--- /dev/null
+++ b/connector/protobuf/src/test/protobuf/v2/serde_suite.proto
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto2";
+
+package org.apache.spark.sql.protobuf.protos.v2;
+option java_outer_classname = "SerdeSuiteProtos";
+
+/* Clean Message*/
+message SerdeBasicMessage {
+ optional Foo foo = 1;
+}
+
+message Foo {
+ optional int32 bar = 1;
+}
+
+/* Field Type missMatch in root Message*/
+message MissMatchTypeInRoot {
+ optional int64 foo = 1;
+}
+
+/* Field bar missing from protobuf and Available in SQL*/
+message FieldMissingInProto {
+ optional MissingField foo = 1;
+}
+
+message MissingField {
+ optional int64 barFoo = 1;
+}
+
+/* Deep-nested field bar type missMatch Message*/
+message MissMatchTypeInDeepNested {
+ optional TypeMissNested top = 1;
+}
+
+message TypeMissNested {
+ optional TypeMiss foo = 1;
+}
+
+message TypeMiss {
+ optional int64 bar = 1;
+}
+
+/* Field boo missing from SQL root, but available in Protobuf root*/
+message FieldMissingInSQLRoot {
+ optional Foo foo = 1;
+ optional int32 boo = 2;
+}
+
+/* Field baz missing from SQL nested and available in Protobuf nested*/
+message FieldMissingInSQLNested {
+ optional Baz foo = 1;
+}
+
+message Baz {
+ optional int32 bar = 1;
+ optional int32 baz = 2;
+}
\ No newline at end of file
diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/protobuf/v3/catalyst_types.proto
similarity index 79%
rename from connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
rename to connector/protobuf/src/test/protobuf/v3/catalyst_types.proto
index 1deb193438c2..f8db4f382a5a 100644
--- a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
+++ b/connector/protobuf/src/test/protobuf/v3/catalyst_types.proto
@@ -14,12 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
-// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/catalyst_types.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto
syntax = "proto3";
-package org.apache.spark.sql.protobuf.protos;
+package org.apache.spark.sql.protobuf.protos.v3;
option java_outer_classname = "CatalystTypes";
// TODO: import one or more protobuf files.
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/protobuf/v3/functions_suite.proto
similarity index 90%
rename from connector/protobuf/src/test/resources/protobuf/functions_suite.proto
rename to connector/protobuf/src/test/protobuf/v3/functions_suite.proto
index 60f8c2621415..f394500e1465 100644
--- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+++ b/connector/protobuf/src/test/protobuf/v3/functions_suite.proto
@@ -14,13 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-// To compile and create test class:
-// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto
-// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/functions_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto
syntax = "proto3";
-package org.apache.spark.sql.protobuf.protos;
+package org.apache.spark.sql.protobuf.protos.v3;
option java_outer_classname = "SimpleMessageProtos";
diff --git a/connector/protobuf/src/test/resources/protobuf/pyspark_test.proto b/connector/protobuf/src/test/protobuf/v3/pyspark_test.proto
similarity index 66%
rename from connector/protobuf/src/test/resources/protobuf/pyspark_test.proto
rename to connector/protobuf/src/test/protobuf/v3/pyspark_test.proto
index 8750371349a0..f48fde9010b7 100644
--- a/connector/protobuf/src/test/resources/protobuf/pyspark_test.proto
+++ b/connector/protobuf/src/test/protobuf/v3/pyspark_test.proto
@@ -14,17 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-// TODO(SPARK-40777): Instead of saving .desc files in resources, generate during build.
-// To compile and create test class:
-// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/pyspark_test.proto
-// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/pyspark_test.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/pyspark_test.proto
syntax = "proto3";
package org.apache.spark.sql.protobuf;
option java_outer_classname = "SimpleMessageProtos";
-
message SimpleMessage {
int32 age = 1;
string name = 2;
diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/protobuf/v3/serde_suite.proto
similarity index 78%
rename from connector/protobuf/src/test/resources/protobuf/serde_suite.proto
rename to connector/protobuf/src/test/protobuf/v3/serde_suite.proto
index a7459213a87b..be9eb6c815f7 100644
--- a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto
+++ b/connector/protobuf/src/test/protobuf/v3/serde_suite.proto
@@ -14,13 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-// To compile and create test class:
-// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto
-// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/serde_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto
syntax = "proto3";
-package org.apache.spark.sql.protobuf.protos;
+package org.apache.spark.sql.protobuf.protos.v3;
option java_outer_classname = "SerdeSuiteProtos";
/* Clean Message*/
diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc
deleted file mode 100644
index 59255b488a03..000000000000
Binary files a/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc and /dev/null differ
diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc
deleted file mode 100644
index 6e3a39672772..000000000000
Binary files a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and /dev/null differ
diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.desc b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc
deleted file mode 100644
index 3d1847eecc5c..000000000000
Binary files a/connector/protobuf/src/test/resources/protobuf/serde_suite.desc and /dev/null differ
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
index 19774a2ad07e..9484bb8ecce6 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters}
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
-import org.apache.spark.sql.protobuf.protos.CatalystTypes.BytesMsg
+import org.apache.spark.sql.protobuf.protos.v3.CatalystTypes.BytesMsg
import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters}
import org.apache.spark.sql.sources.{EqualTo, Not}
import org.apache.spark.sql.test.SharedSparkSession
@@ -33,34 +33,41 @@ import org.apache.spark.unsafe.types.UTF8String
class ProtobufCatalystDataConversionSuite
extends SparkFunSuite
+ with ProtobufTestUtils
with SharedSparkSession
with ExpressionEvalHelper {
- private val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
- private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$"
+ private val descPathV2 = descFilePathV2("catalyst_types.desc")
+ private val descPathV3 = descFilePathV3("catalyst_types.desc")
+ private val outerClass = "CatalystTypes"
private def checkResultWithEval(
data: Literal,
- descFilePath: String,
messageName: String,
expected: Any): Unit = {
- withClue("(Eval check with Java class name)") {
- val className = s"$javaClassNamePrefix$messageName"
- checkEvaluation(
- ProtobufDataToCatalyst(
- CatalystDataToProtobuf(data, className),
- className,
- descFilePath = None),
- prepareExpectedResult(expected))
- }
- withClue("(Eval check with descriptor file)") {
- checkEvaluation(
- ProtobufDataToCatalyst(
- CatalystDataToProtobuf(data, messageName, Some(descFilePath)),
- messageName,
- descFilePath = Some(descFilePath)),
- prepareExpectedResult(expected))
+ // Check all 16 combinations:
+ // 4 variations for how the record is written and 4 variations for how it is read.
+
+ val variations = Seq(
+ (messageName, Some(descPathV2), "V2 descriptor file"),
+ (javaClassFullNameV2(outerClass, messageName), None, "V2 Java class"),
+ (messageName, Some(descPathV3), "V3 descriptor file"),
+ (javaClassFullNameV3(outerClass, messageName), None, "V3 Java class")
+ )
+
+ for ( // 16 combinations of the 4 variations above.
+ (writerMessageName, writerDesc, writerClue) <- variations;
+ (readerMessageName, readerDesc, readerClue) <- variations
+ ) {
+ withClue(s"(Eval check with $writerClue for writing, $readerClue for reading)") {
+ checkEvaluation(
+ ProtobufDataToCatalyst(
+ CatalystDataToProtobuf(data, writerMessageName, writerDesc),
+ readerMessageName,
+ descFilePath = readerDesc),
+ prepareExpectedResult(expected))
+ }
}
}
@@ -139,7 +146,6 @@ class ProtobufCatalystDataConversionSuite
checkResultWithEval(
input,
- testFileDesc,
messageName,
input.eval())
}
@@ -162,7 +168,7 @@ class ProtobufCatalystDataConversionSuite
// Verify Java class deserializer matches with descriptor based serializer.
val javaDescriptor = ProtobufUtils
- .buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName")
+ .buildDescriptorFromJavaClass(javaClassFullNameV3(outerClass, messageName))
assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType)
val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType, filters)
.deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray))
@@ -189,7 +195,7 @@ class ProtobufCatalystDataConversionSuite
val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema)
val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema)
val input = Literal.create(converter(data), actualSchema)
- checkUnsupportedRead(input, testFileDesc, "Actual", "Bad")
+ checkUnsupportedRead(input, descPathV3, "Actual", "Bad")
}
}
@@ -199,7 +205,7 @@ class ProtobufCatalystDataConversionSuite
.add("name", "string")
.add("age", "int")
- val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Person")
+ val descriptor = ProtobufUtils.buildDescriptor(descPathV3, "Person")
val dynamicMessage = DynamicMessage
.newBuilder(descriptor)
.setField(descriptor.findFieldByName("name"), "Maxim")
@@ -207,16 +213,16 @@ class ProtobufCatalystDataConversionSuite
.build()
val expectedRow = Some(InternalRow(UTF8String.fromString("Maxim"), 39))
- checkDeserialization(testFileDesc, "Person", dynamicMessage, expectedRow)
+ checkDeserialization(descPathV3, "Person", dynamicMessage, expectedRow)
checkDeserialization(
- testFileDesc,
+ descPathV3,
"Person",
dynamicMessage,
expectedRow,
new OrderedFilters(Seq(EqualTo("age", 39)), sqlSchema))
checkDeserialization(
- testFileDesc,
+ descPathV3,
"Person",
dynamicMessage,
None,
@@ -233,15 +239,15 @@ class ProtobufCatalystDataConversionSuite
.build()
val expected = InternalRow(Array[Byte](97, 48, 53))
- checkDeserialization(testFileDesc, "BytesMsg", bytesProto, Some(expected))
+ checkDeserialization(descPathV3, "BytesMsg", bytesProto, Some(expected))
}
test("Full names for message using descriptor file") {
- val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg")
+ val withShortName = ProtobufUtils.buildDescriptor(descPathV3, "BytesMsg")
assert(withShortName.findFieldByName("bytes_type") != null)
val withFullName = ProtobufUtils.buildDescriptor(
- testFileDesc, "org.apache.spark.sql.protobuf.BytesMsg")
+ descPathV3, "org.apache.spark.sql.protobuf.protos.v3.BytesMsg")
assert(withFullName.findFieldByName("bytes_type") != null)
}
}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 72280fb0d9e2..ef7cad250534 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -25,31 +25,43 @@ import com.google.protobuf.{ByteString, DynamicMessage}
import org.apache.spark.sql.{Column, QueryTest, Row}
import org.apache.spark.sql.functions.{lit, struct}
-import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated
-import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
+import org.apache.spark.sql.protobuf.protos.v3.SimpleMessageProtos.SimpleMessageRepeated
+import org.apache.spark.sql.protobuf.protos.v3.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType}
-class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Serializable {
+class ProtobufFunctionsSuite
+ extends QueryTest
+ with ProtobufTestUtils
+ with SharedSparkSession
+ with Serializable {
import testImplicits._
- val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/")
- private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$"
+ private val descPathV2 = descFilePathV2("functions_suite.desc")
+ private val descPathV3 = descFilePathV3("functions_suite.desc")
+ private val outerClass = "SimpleMessageProtos"
/**
- * Runs the given closure twice. Once with descriptor file and second time with Java class name.
+ * Runs the given closure 4 times: with the v2 and v3 descriptor files, and with the v2 and v3 Java classes.
*/
private def checkWithFileAndClassName(messageName: String)(
fn: (String, Option[String]) => Unit): Unit = {
- withClue("(With descriptor file)") {
- fn(messageName, Some(testFileDesc))
- }
- withClue("(With Java class name)") {
- fn(s"$javaClassNamePrefix$messageName", None)
+
+ val variations = Seq(
+ (messageName, Some(descPathV2), "V2 descriptor file"),
+ (javaClassFullNameV2(outerClass, messageName), None, "V2 Java class"),
+ (messageName, Some(descPathV3), "V3 descriptor file"),
+ (javaClassFullNameV3(outerClass, messageName), None, "V3 Java class")
+ )
+
+ for ((messageName, descOpt, clue) <- variations) {
+ withClue(s"(With $clue)") {
+ fn(messageName, descOpt)
}
+ }
}
// A wrapper to invoke the right variable of from_protobuf() depending on arguments.
@@ -124,17 +136,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
case (name, descFilePathOpt) =>
val fromProtoDF = df.select(
from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from"))
- val toProtoDF = fromProtoDF.select(
- to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
- val toFromProtoDF = toProtoDF.select(
- from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
- checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
+ val toProtoDF = fromProtoDF.select(
+ to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to"))
+ val toFromProtoDF = toProtoDF.select(
+ from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from"))
+ checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*"))
}
}
test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") {
- val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
- val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+ val repeatedMessageDesc = ProtobufUtils.buildDescriptor(descPathV3, "RepeatedMessage")
+ val basicMessageDesc = ProtobufUtils.buildDescriptor(descPathV3, "BasicMessage")
val basicMessage = DynamicMessage
.newBuilder(basicMessageDesc)
@@ -170,8 +182,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") {
- val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage")
- val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+ val repeatedMessageDesc = ProtobufUtils.buildDescriptor(descPathV2, "RepeatedMessage")
+ val basicMessageDesc = ProtobufUtils.buildDescriptor(descPathV2, "BasicMessage")
val basicMessage1 = DynamicMessage
.newBuilder(basicMessageDesc)
@@ -221,7 +233,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("roundtrip in from_protobuf and to_protobuf - Map") {
- val messageMapDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageMap")
+ val messageMapDesc = ProtobufUtils.buildDescriptor(descPathV3, "SimpleMessageMap")
val mapStr1 = DynamicMessage
.newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry"))
@@ -315,8 +327,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("roundtrip in from_protobuf and to_protobuf - Enum") {
- val messageEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageEnum")
- val basicEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicEnumMessage")
+ val messageEnumDesc = ProtobufUtils.buildDescriptor(descPathV2, "SimpleMessageEnum")
+ val basicEnumDesc = ProtobufUtils.buildDescriptor(descPathV2, "BasicEnumMessage")
val dynamicMessage = DynamicMessage
.newBuilder(messageEnumDesc)
@@ -324,7 +336,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
.setField(messageEnumDesc.findFieldByName("value"), "value")
.setField(
messageEnumDesc.findFieldByName("nested_enum"),
- messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING"))
+ messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_NOTHING"))
.setField(
messageEnumDesc.findFieldByName("nested_enum"),
messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST"))
@@ -351,9 +363,9 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("roundtrip in from_protobuf and to_protobuf - Multiple Message") {
- val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, "MultipleExample")
- val messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, "IncludedExample")
- val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, "OtherExample")
+ val messageMultiDesc = ProtobufUtils.buildDescriptor(descPathV2, "MultipleExample")
+ val messageIncludeDesc = ProtobufUtils.buildDescriptor(descPathV2, "IncludedExample")
+ val messageOtherDesc = ProtobufUtils.buildDescriptor(descPathV3, "OtherExample")
val otherMessage = DynamicMessage
.newBuilder(messageOtherDesc)
@@ -386,8 +398,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("Handle recursive fields in Protobuf schema, A->B->A") {
- val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA")
- val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB")
+ val schemaA = ProtobufUtils.buildDescriptor(descPathV3, "recursiveA")
+ val schemaB = ProtobufUtils.buildDescriptor(descPathV3, "recursiveB")
val messageBForA = DynamicMessage
.newBuilder(schemaB)
@@ -422,8 +434,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("Handle recursive fields in Protobuf schema, C->D->Array(C)") {
- val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC")
- val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD")
+ val schemaC = ProtobufUtils.buildDescriptor(descPathV2, "recursiveC")
+ val schemaD = ProtobufUtils.buildDescriptor(descPathV2, "recursiveD")
val messageDForC = DynamicMessage
.newBuilder(schemaD)
@@ -458,7 +470,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("Handle extra fields : oldProducer -> newConsumer") {
- val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val testFileDesc = descFilePathV2("catalyst_types.desc")
val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer")
val newConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "newConsumer")
@@ -498,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
}
test("Handle extra fields : newProducer -> oldConsumer") {
- val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/")
+ val testFileDesc = descFilePathV3("catalyst_types.desc")
val newProducer = ProtobufUtils.buildDescriptor(testFileDesc, "newProducer")
val oldConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "oldConsumer")
@@ -541,35 +553,38 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
schema
)
- val toProtobuf = inputDf.select(
- functions.to_protobuf($"requiredMsg", "requiredMsg", testFileDesc)
- .as("to_proto"))
-
- val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]]
-
- val messageDescriptor = ProtobufUtils.buildDescriptor(testFileDesc, "requiredMsg")
- val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary)
-
- assert(actualMessage.getField(messageDescriptor.findFieldByName("key"))
- == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
- assert(actualMessage.getField(messageDescriptor.findFieldByName("col_2"))
- == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0))
- assert(actualMessage.getField(messageDescriptor.findFieldByName("col_1")) == 0)
- assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0)
-
- val fromProtoDf = toProtobuf.select(
- functions.from_protobuf($"to_proto", "requiredMsg", testFileDesc) as 'from_proto)
-
- assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0)
- == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
- assert(fromProtoDf.select("from_proto.col_2").take(1).toSeq(0).get(0)
- == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0))
- assert(fromProtoDf.select("from_proto.col_1").take(1).toSeq(0).get(0) == null)
- assert(fromProtoDf.select("from_proto.col_3").take(1).toSeq(0).get(0) == null)
+ checkWithFileAndClassName("requiredMsg") {
+ case (name, descFilePathOpt) =>
+ val toProtobuf = inputDf.select(
+ to_protobuf_wrapper($"requiredMsg", name, descFilePathOpt)
+ .as("to_proto"))
+
+ val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]]
+
+ val messageDescriptor = ProtobufUtils.buildDescriptor(name, descFilePathOpt)
+ val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary)
+
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("key"))
+ == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("col_2"))
+ == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0))
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("col_1")) == 0)
+ assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0)
+
+ val fromProtoDf = toProtobuf.select(
+ from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'from_proto)
+
+ assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0)
+ == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("from_proto.col_2").take(1).toSeq(0).get(0)
+ == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0))
+ assert(fromProtoDf.select("from_proto.col_1").take(1).toSeq(0).get(0) == null)
+ assert(fromProtoDf.select("from_proto.col_3").take(1).toSeq(0).get(0) == null)
+ }
}
test("from_protobuf filter to_protobuf") {
- val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+ val basicMessageDesc = ProtobufUtils.buildDescriptor(descPathV3, "BasicMessage")
val basicMessage = DynamicMessage
.newBuilder(basicMessageDesc)
@@ -680,4 +695,64 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri
=== inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0))
}
}
+
+ test("Default and required values in version 2 protobufs") {
+ // Verifies that V2 protobufs with default values have correct values returned.
+
+ // Proto:
+ // message ProtoWithDefaults {
+ // optional string user_name = 1;
+ // required int32 id = 2;
+ // optional int32 api_quota = 3 [default = 100]; // Default 100 qps.
+ // optional string location = 4 [default = "Unknown"];
+ // }
+
+ Seq( // Only V2 variations. This proto does not exist in V3
+ ("ProtoWithDefaults", Some(descPathV2), "V2 descriptor file"),
+ (javaClassFullNameV2(outerClass, "ProtoWithDefaults"), None, "V2 Java class"),
+ ).foreach {
+ case (name, descFilePathOpt, clue) =>
+ withClue(clue) {
+ val schema = Seq[Array[Byte]]()
+ .toDF("bytes")
+ .select(from_protobuf_wrapper($"bytes", name, descFilePathOpt).as("proto"))
+ .schema
+
+ val inputRows = Seq(
+ Row(Row("user-1", 1, 10, "CA")), // All the fields are set.
+ Row(Row("user-2", 2, null, null)), // Defaults.
+ Row(Row(null, 3, 100, "NY")) // Explicit default for "api_quota".
+ )
+
+ val inputDf = spark.createDataFrame(
+ spark.sparkContext.parallelize(inputRows),
+ schema
+ )
+
+ assert(
+ !schema.head.dataType.asInstanceOf[StructType].fields(1).nullable,
+ "id should not be nullable"
+ )
+
+ val rows = inputDf // Do a round trip with to_protobuf & from_protobuf
+ .select(
+ to_protobuf_wrapper($"proto", name, descFilePathOpt).as("proto_bytes")
+ )
+ .select(
+ from_protobuf_wrapper($"proto_bytes", name, descFilePathOpt).as("proto")
+ )
+
+ val results = rows.collect().toSeq.sortBy(_.getStruct(0).getInt(1)) // sort by id.
+
+ // First row should be same as input.
+ assert(results(0) == inputRows(0))
+
+ // Second row should have defaults for 'api_quota' and 'location'.
+ assert(results(1) == Row(Row("user-2", 2, 100, "Unknown")))
+
+ // Third row should return null for optional field user_name. So same as input.
+ assert(results(2) == inputRows(2))
+ }
+ }
+ }
}
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index efc02524e68d..9fc1337476cc 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -30,126 +30,173 @@ import org.apache.spark.sql.types.{IntegerType, StructType}
* Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on
* those classes.
*/
-class ProtobufSerdeSuite extends SharedSparkSession {
+class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestUtils {
import ProtoSerdeSuite._
import ProtoSerdeSuite.MatchType._
- val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/")
- private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$"
+ private val descPathV2 = descFilePathV2("serde_suite.desc")
+ private val descPathV3 = descFilePathV3("serde_suite.desc")
+ private val outerClass = "SerdeSuiteProtos"
+
+ /**
+ * Runs the given closure 4 times: with the v2 and v3 descriptor files, and with the v2 and v3 Java classes.
+ */
+ private def checkWithFileAndClassName(messageName: String)(
+ fn: (String, Option[String]) => Unit): Unit = {
+
+ val variations = Seq(
+ (messageName, Some(descPathV2), "V2 descriptor file"),
+ (javaClassFullNameV2(outerClass, messageName), None, "V2 Java class"),
+ (messageName, Some(descPathV3), "V3 descriptor file"),
+ (javaClassFullNameV3(outerClass, messageName), None, "V3 Java class")
+ )
+
+ for ((messageName, descOpt, clue) <- variations) {
+ withClue(s"(With $clue)") {
+ fn(messageName, descOpt)
+ }
+ }
+ }
+
test("Test basic conversion") {
withFieldMatchType { fieldMatch =>
- val (top, nest) = fieldMatch match {
- case BY_NAME => ("foo", "bar")
- }
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage")
+ checkWithFileAndClassName("SerdeBasicMessage") {
+ case (name, descPathOpt) =>
+ val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt)
- val dynamicMessageFoo = DynamicMessage
- .newBuilder(protoFile.getFile.findMessageTypeByName("Foo"))
- .setField(protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), 10902)
- .build()
+ val dynamicMessageFoo = DynamicMessage
+ .newBuilder(protoFile.getFile.findMessageTypeByName("Foo"))
+ .setField(protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), 10902)
+ .build()
- val dynamicMessage = DynamicMessage
- .newBuilder(protoFile)
- .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo)
- .build()
+ val dynamicMessage = DynamicMessage
+ .newBuilder(protoFile)
+ .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo)
+ .build()
- val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
- val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
+ val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
+ val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch)
- assert(
- serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage)
+ assert(
+ serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage)
+ }
}
}
test("Fail to convert with field type mismatch") {
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot")
-
withFieldMatchType { fieldMatch =>
- assertFailedConversionMessage(
- protoFile,
- Deserializer,
- fieldMatch,
- "Cannot convert Protobuf field 'foo' to SQL field 'foo' because schema is incompatible " +
- s"(protoType = org.apache.spark.sql.protobuf.MissMatchTypeInRoot.foo " +
- s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})".stripMargin)
-
- assertFailedConversionMessage(
- protoFile,
- Serializer,
- fieldMatch,
- s"Cannot convert SQL field 'foo' to Protobuf field 'foo' because schema is incompatible " +
- s"""(sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)""")
+ checkWithFileAndClassName("MissMatchTypeInRoot") {
+ case (name, descPathOpt) =>
+ val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt)
+
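+ // The expected error message embeds the generated package name (...protos.v2 vs
+ // ...protos.v3), so work out which version this variation uses.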
+ val version = descPathOpt match { // A bit hacky, but ok for now.
+ case Some(path) => if (path == descPathV2) 2 else 3
+ case None => if (name.contains(".v2.")) 2 else 3
+ }
+
+ assertFailedConversionMessage(
+ protoFile,
+ Deserializer,
+ fieldMatch,
+ "Cannot convert Protobuf field 'foo' to SQL field 'foo' because schema is " +
+ s"incompatible (protoType = " +
+ s"org.apache.spark.sql.protobuf.protos.v$version.MissMatchTypeInRoot.foo " +
+ s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})")
+
+ assertFailedConversionMessage(
+ protoFile,
+ Serializer,
+ fieldMatch,
+ s"Cannot convert SQL field 'foo' to Protobuf field 'foo' because schema is " +
+ s"incompatible (sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)")
+ }
}
}
test("Fail to convert with missing nested Protobuf fields for serializer") {
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInProto")
+ checkWithFileAndClassName("FieldMissingInProto") {
+ case (name, descPathOpt) =>
+ val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt)
- val nonnullCatalyst = new StructType()
- .add("foo", new StructType().add("bar", IntegerType, nullable = false))
+ val nonnullCatalyst = new StructType()
+ .add("foo", new StructType().add("bar", IntegerType, nullable = false))
- // serialize fails whether or not 'bar' is nullable
- val byNameMsg = "Cannot find field 'foo.bar' in Protobuf schema"
- assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg)
- assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst)
+ // serialize fails whether or not 'bar' is nullable
+ val byNameMsg = "Cannot find field 'foo.bar' in Protobuf schema"
+ assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg)
+ assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst)
+ }
}
test("Fail to convert with deeply nested field type mismatch") {
- val protoFile = ProtobufUtils.buildDescriptorFromJavaClass(
- s"${javaClassNamePrefix}MissMatchTypeInDeepNested"
- )
+
val catalyst = new StructType().add("top", CATALYST_STRUCT)
withFieldMatchType { fieldMatch =>
- assertFailedConversionMessage(
- protoFile,
- Deserializer,
- fieldMatch,
- s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " +
- s"is incompatible (protoType = org.apache.spark.sql.protobuf.protos.TypeMiss.bar " +
- s"LABEL_OPTIONAL LONG INT64, sqlType = INT)",
- catalyst)
-
- assertFailedConversionMessage(
- protoFile,
- Serializer,
- fieldMatch,
- "Cannot convert SQL field 'top.foo.bar' to Protobuf field 'top.foo.bar' because schema " +
- """is incompatible (sqlType = INT, protoType = LONG)""",
- catalyst)
+ checkWithFileAndClassName("MissMatchTypeInDeepNested") {
+ case (name, descPathOpt) =>
+ val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt)
+
+ val version = descPathOpt match {
+ case Some(path) => if (path == descPathV2) 2 else 3
+ case None => if (name.contains(".v2.")) 2 else 3
+ }
+
+ assertFailedConversionMessage(
+ protoFile,
+ Deserializer,
+ fieldMatch,
+ s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' " +
+ s"because schema is incompatible (protoType = " +
+ s"org.apache.spark.sql.protobuf.protos.v$version.TypeMiss.bar " +
+ s"LABEL_OPTIONAL LONG INT64, sqlType = INT)",
+ catalyst)
+
+ assertFailedConversionMessage(
+ protoFile,
+ Serializer,
+ fieldMatch,
+ "Cannot convert SQL field 'top.foo.bar' to Protobuf field 'top.foo.bar' because " +
+ "schema is incompatible (sqlType = INT, protoType = LONG)",
+ catalyst)
+ }
}
}
test("Fail to convert with missing Catalyst fields") {
- val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot")
-
- // serializing with extra fails if extra field is missing in SQL Schema
- assertFailedConversionMessage(
- protoFile,
- Serializer,
- BY_NAME,
- "Found field 'boo' in Protobuf schema but there is no match in the SQL schema")
-
- /* deserializing should work regardless of whether the extra field is missing
- in SQL Schema or not */
- withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
- withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
-
- val protoNestedFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested")
-
- // serializing with extra fails if extra field is missing in SQL Schema
- assertFailedConversionMessage(
- protoNestedFile,
- Serializer,
- BY_NAME,
- "Found field 'foo.baz' in Protobuf schema but there is no match in the SQL schema")
-
- /* deserializing should work regardless of whether the extra field is missing
+
+ checkWithFileAndClassName("FieldMissingInSQLRoot") {
+ case (name, descPathOpt) =>
+ val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt)
+ // serializing with extra fails if extra field is missing in SQL Schema
+ assertFailedConversionMessage(
+ protoFile,
+ Serializer,
+ BY_NAME,
+ "Found field 'boo' in Protobuf schema but there is no match in the SQL schema")
+
+ // deserializing should work regardless of whether the extra field is missing
+ // in SQL Schema or not.
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _))
+
+ val nestedMessageName = name.replace("FieldMissingInSQLRoot", "FieldMissingInSQLNested")
+ val protoNestedFile = ProtobufUtils.buildDescriptor(nestedMessageName, descPathOpt)
+
+ // serializing with extra fails if extra field is missing in SQL Schema
+ assertFailedConversionMessage(
+ protoNestedFile,
+ Serializer,
+ BY_NAME,
+ "Found field 'foo.baz' in Protobuf schema but there is no match in the SQL schema")
+
+ /* deserializing should work regardless of whether the extra field is missing
in SQL Schema or not */
- withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
- withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
+ withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _))
+ }
}
/**
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestUtils.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestUtils.scala
new file mode 100644
index 000000000000..afc967a63e1e
--- /dev/null
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.protobuf
+
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * A trait used by all the Protobuf function suites.
+ */
+private[sql] trait ProtobufTestUtils extends SQLTestUtils {
+
+ val javaPackageV2 = "org.apache.spark.sql.protobuf.protos.v2"
+ val javaPackageV3 = "org.apache.spark.sql.protobuf.protos.v3"
+
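+ // Generated message classes are nested inside the outer class, so the JVM class name
+ // uses a '$' separator (written as "$$" inside the interpolated strings below).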
+ def javaClassFullNameV2(outerClass: String, innerClass: String): String =
+ s"$javaPackageV2.$outerClass$$$innerClass"
+
+ def javaClassFullNameV3(outerClass: String, innerClass: String): String =
+ s"$javaPackageV3.$outerClass$$$innerClass"
+
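+ // Descriptor sets for the test protos are generated during the build (see the protoc
+ // plugin configuration updated in this change) instead of being checked in as resources.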
+ def descFilePathV2(descFileName: String): String =
+ testFile(s"descriptor-set-v2/$descFileName").replace("file:/", "/")
+
+ def descFilePathV3(descFileName: String): String =
+ testFile(s"descriptor-set-v3/$descFileName").replace("file:/", "/")
+}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 33883a2efaa5..6577e9ef5e42 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -723,10 +723,11 @@ object SparkProtobuf {
dependencyOverrides += "com.google.protobuf" % "protobuf-java" % protoVersion,
- (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources",
-
+ // Protobuf V3 code is generated here. V2 code is generated by separate settings below.
(Test / PB.targets) := Seq(
- PB.gens.java -> target.value / "generated-test-sources"
+ PB.gens.java -> target.value / "generated-test-sources",
+ PB.gens.descriptorSet ->
+ target.value / "generated-test-sources/protobuf/descriptor-set.desc",
),
(assembly / test) := { },
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 2059d868c7c1..49f741569894 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -101,7 +101,7 @@ def from_protobuf(
assert sc is not None and sc._jvm is not None
try:
jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
- _to_java_column(data), messageName, descFilePath, options or {}
+ _to_java_column(data), descFilePath, messageName, options or {}
)
except TypeError as e:
if str(e) == "'JavaPackage' object is not callable":
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index ae425419c540..7c55805a5a94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -463,7 +463,10 @@ private[sql] trait SQLTestUtilsBase
* Returns full path to the given file in the resource folder
*/
protected def testFile(fileName: String): String = {
- Thread.currentThread().getContextClassLoader.getResource(fileName).toString
+ Option(Thread.currentThread().getContextClassLoader.getResource(fileName)) match {
+ case Some(path) => path.toString
+ case None => throw new IllegalArgumentException(s"Resource not found: $fileName")
+ }
}
/**